Skip to content

Commit fe00d6b

Browse files
authored
Send a request to config.json when downloading model from HF repository (#31)
## Description

Currently, the HuggingFace download count is not incremented when models are downloaded. HF counts model downloads by counting the requests sent to the config.json file in the repository root (a HEAD request is enough). This PR sends a HEAD request to config.json when the model source is our organization.

Source: https://huggingface.co/docs/hub/models-download-stats

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

<!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. -->

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes

<!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 0ae6023 commit fe00d6b

File tree

2 files changed

+184
-122
lines changed

2 files changed

+184
-122
lines changed

android/src/main/java/com/rnexecutorch/Fetcher.kt

+161-122
Original file line numberDiff line numberDiff line change
@@ -5,104 +5,135 @@ import okhttp3.Call
55
import okhttp3.Callback
66
import okhttp3.OkHttpClient
77
import okhttp3.Request
8+
import okhttp3.RequestBody
89
import okhttp3.Response
910
import okio.IOException
1011
import java.io.File
1112
import java.io.FileOutputStream
1213
import java.net.URL
1314

1415
enum class ResourceType {
15-
TOKENIZER,
16-
MODEL
16+
TOKENIZER,
17+
MODEL,
1718
}
1819

1920
class Fetcher {
20-
companion object {
21-
private fun saveResponseToFile(
22-
response: Response,
23-
directory: File,
24-
fileName: String
25-
): File {
26-
val file = File(directory.path, fileName)
27-
file.outputStream().use { outputStream ->
28-
response.body?.byteStream()?.copyTo(outputStream)
29-
}
30-
return file
21+
companion object {
22+
/**
 * Writes the body of [response] to a file named [fileName] inside [directory].
 *
 * The response is closed before this function returns, so the caller must not read
 * from it afterwards. If the response carries no body, the file is created empty.
 *
 * @param response HTTP response whose body is persisted; it is consumed and closed here.
 * @param directory directory in which the file is created.
 * @param fileName name of the file to create (or overwrite) inside [directory].
 * @return the file the body was written to.
 */
private fun saveResponseToFile(
    response: Response,
    directory: File,
    fileName: String,
): File {
    val file = File(directory.path, fileName)
    // `use` closes the response (and with it the body) even when the copy fails,
    // so the underlying connection is always released.
    response.use { openResponse ->
        file.outputStream().use { outputStream ->
            openResponse.body?.byteStream()?.copyTo(outputStream)
        }
    }
    return file
}
33+
34+
private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
35+
return when (resourceType) {
36+
ResourceType.TOKENIZER -> {
37+
fileName.endsWith(".bin")
3138
}
3239

33-
private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
34-
return when (resourceType) {
35-
ResourceType.TOKENIZER -> {
36-
fileName.endsWith(".bin")
37-
}
40+
ResourceType.MODEL -> {
41+
fileName.endsWith(".pte")
42+
}
43+
}
44+
}
3845

39-
ResourceType.MODEL -> {
40-
fileName.endsWith(".pte")
41-
}
42-
}
46+
private fun extractFileName(url: URL): String {
47+
if (url.path == "/assets/") {
48+
val pathSegments = url.toString().split('/')
49+
return pathSegments[pathSegments.size - 1].split("?")[0]
50+
} else if (url.protocol == "file") {
51+
val localPath = url.toString().split("://")[1]
52+
val file = File(localPath)
53+
if (file.exists()) {
54+
return localPath
4355
}
4456

45-
private fun extractFileName(url: URL): String {
46-
if (url.path == "/assets/") {
47-
val pathSegments = url.toString().split('/')
48-
return pathSegments[pathSegments.size - 1].split("?")[0]
49-
} else if (url.protocol == "file") {
50-
val localPath = url.toString().split("://")[1]
51-
val file = File(localPath)
52-
if (file.exists()) {
53-
return localPath
57+
throw Exception("file_not_found")
58+
} else {
59+
return url.path.substringAfterLast('/')
60+
}
61+
}
62+
63+
private fun fetchModel(
64+
file: File,
65+
validFile: File,
66+
client: OkHttpClient,
67+
url: URL,
68+
onComplete: (String?, Exception?) -> Unit,
69+
listener: ProgressResponseBody.ProgressListener? = null,
70+
) {
71+
val request = Request.Builder().url(url).build()
72+
client.newCall(request).enqueue(object : Callback {
73+
override fun onFailure(call: Call, e: IOException) {
74+
onComplete(null, e)
75+
}
76+
77+
override fun onResponse(call: Call, response: Response) {
78+
if (!response.isSuccessful) {
79+
onComplete(null, Exception("download_error"))
80+
}
81+
82+
response.body?.let { body ->
83+
val progressBody = listener?.let { ProgressResponseBody(body, it) }
84+
val inputStream = progressBody?.source()?.inputStream()
85+
inputStream?.use { input ->
86+
FileOutputStream(file).use { output ->
87+
val buffer = ByteArray(2048)
88+
var bytesRead: Int
89+
while (input.read(buffer).also { bytesRead = it } != -1) {
90+
output.write(buffer, 0, bytesRead)
5491
}
92+
}
93+
}
5594

56-
throw Exception("file_not_found")
95+
if (file.renameTo(validFile)) {
96+
onComplete(validFile.absolutePath, null)
5797
} else {
58-
return url.path.substringAfterLast('/')
98+
onComplete(null, Exception("Failed to move the file to the valid location"))
5999
}
100+
}
60101
}
102+
})
103+
}
61104

62-
private fun fetchModel(file: File, validFile: File, client: OkHttpClient, url: URL, onComplete: (String?, Exception?) -> Unit,
63-
listener: ProgressResponseBody.ProgressListener? = null){
64-
val request = Request.Builder().url(url).build()
65-
client.newCall(request).enqueue(object : Callback {
66-
override fun onFailure(call: Call, e: IOException) {
67-
onComplete(null, e)
68-
}
105+
/**
 * Tells whether [url] points into the Software Mansion organization on HuggingFace.
 *
 * The host must be `huggingface.co` and the path must start with `/software-mansion/`.
 * The host is compared case-insensitively: host names are case-insensitive, but
 * `URL.host` returns the host exactly as it was written in the URL.
 *
 * @param url URL of the resource about to be downloaded.
 * @return `true` when both the host and the path prefix match.
 */
private fun isUrlPointingToHfRepo(url: URL): Boolean {
    val expectedHost = "huggingface.co"
    val expectedPathPrefix = "/software-mansion/"
    return url.host.equals(expectedHost, ignoreCase = true) &&
        url.path.startsWith(expectedPathPrefix)
}
69113

70-
override fun onResponse(call: Call, response: Response) {
71-
if (!response.isSuccessful) {
72-
onComplete(null, Exception("download_error"))
73-
}
74-
75-
response.body?.let { body ->
76-
val progressBody = listener?.let { ProgressResponseBody(body, it) }
77-
val inputStream = progressBody?.source()?.inputStream()
78-
inputStream?.use { input ->
79-
FileOutputStream(file).use { output ->
80-
val buffer = ByteArray(2048)
81-
var bytesRead: Int
82-
while (input.read(buffer).also { bytesRead = it } != -1) {
83-
output.write(buffer, 0, bytesRead)
84-
}
85-
}
86-
}
87-
88-
if (file.renameTo(validFile)) {
89-
onComplete(validFile.absolutePath, null)
90-
} else {
91-
onComplete(null, Exception("Failed to move the file to the valid location"))
92-
}
93-
}
94-
}
95-
})
96-
}
114+
/**
 * Builds the URL of the `config.json` file on the `main` revision of the repository
 * that [modelUrl] belongs to.
 *
 * The repository root is everything in the path before the `resolve` path segment, e.g.
 * `/owner/repo/resolve/v1/model.pte` becomes `/owner/repo/resolve/main/config.json`.
 * Protocol, host and an explicit port (if any) are carried over from [modelUrl].
 *
 * @param modelUrl URL of a file inside a HuggingFace-style repository.
 * @return URL of `resolve/main/config.json` under the same repository root.
 */
private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
    val segments = modelUrl.path.split('/').filter { it.isNotEmpty() }
    // Match "resolve" as a whole path segment so that a repository whose name merely
    // ends with the word (e.g. "my-resolve") is not cut in the middle.
    val resolveIndex = segments.indexOf("resolve")
    // Without a "resolve" segment, fall back to the first two segments, assuming the usual
    // "<owner>/<repo>" layout, instead of gluing the suffix onto the file name.
    val repoSegments = if (resolveIndex >= 0) segments.take(resolveIndex) else segments.take(2)
    val repoPath = repoSegments.joinToString(separator = "") { "/$it" }
    // A port of -1 (none given in modelUrl) makes URL use the protocol's default port.
    return URL(modelUrl.protocol, modelUrl.host, modelUrl.port, "$repoPath/resolve/main/config.json")
}
97119

98-
fun downloadResource(
99-
context: Context,
100-
client: OkHttpClient,
101-
url: URL,
102-
resourceType: ResourceType,
103-
onComplete: (String?, Exception?) -> Unit,
104-
listener: ProgressResponseBody.ProgressListener? = null
105-
) {
120+
/**
 * Synchronously executes an HTTP request against [url] and returns the response.
 *
 * The call runs on the calling thread via `execute()`, and any exception thrown by it
 * propagates to the caller. The returned [Response] is handed over without being closed.
 *
 * @param url target of the request.
 * @param method HTTP method name, e.g. `"HEAD"`.
 * @param body request body, or `null` for methods that carry none.
 * @param client OkHttp client used to perform the call.
 * @return the response produced by the call.
 */
private fun sendRequestToUrl(url: URL, method: String, body: RequestBody?, client: OkHttpClient): Response =
    client.newCall(
        Request.Builder().url(url).method(method, body).build(),
    ).execute()
128+
129+
fun downloadResource(
130+
context: Context,
131+
client: OkHttpClient,
132+
url: URL,
133+
resourceType: ResourceType,
134+
onComplete: (String?, Exception?) -> Unit,
135+
listener: ProgressResponseBody.ProgressListener? = null,
136+
) {
106137
/*
107138
Fetching model and tokenizer file
108139
1. Extract file name from provided URL
@@ -115,57 +146,65 @@ class Fetcher {
115146
6. If the file does not exist, and is a tokenizer, fetch the file
116147
7. If the file is a model, fetch the file with ProgressResponseBody
117148
*/
118-
val fileName: String
119-
120-
try {
121-
fileName = extractFileName(url)
122-
} catch (e: Exception) {
123-
onComplete(null, e)
124-
return
125-
}
126-
127-
if(fileName.contains("/")){
128-
onComplete(fileName, null)
129-
return
130-
}
131-
132-
if (!hasValidExtension(fileName, resourceType)) {
133-
onComplete(null, Exception("invalid_resource_extension"))
134-
return
135-
}
136-
137-
var tempFile = File(context.filesDir, fileName)
138-
if(tempFile.exists()){
139-
tempFile.delete()
140-
}
149+
val fileName: String
150+
151+
try {
152+
fileName = extractFileName(url)
153+
} catch (e: Exception) {
154+
onComplete(null, e)
155+
return
156+
}
157+
158+
if (fileName.contains("/")) {
159+
onComplete(fileName, null)
160+
return
161+
}
162+
163+
if (!hasValidExtension(fileName, resourceType)) {
164+
onComplete(null, Exception("invalid_resource_extension"))
165+
return
166+
}
167+
168+
var tempFile = File(context.filesDir, fileName)
169+
if (tempFile.exists()) {
170+
tempFile.delete()
171+
}
172+
173+
val modelsDirectory = File(context.filesDir, "models").apply {
174+
if (!exists()) {
175+
mkdirs()
176+
}
177+
}
141178

142-
val modelsDirectory = File(context.filesDir, "models").apply {
143-
if (!exists()) {
144-
mkdirs()
145-
}
146-
}
179+
var validFile = File(modelsDirectory, fileName)
180+
if (validFile.exists()) {
181+
onComplete(validFile.absolutePath, null)
182+
return
183+
}
147184

148-
var validFile = File(modelsDirectory, fileName)
149-
if (validFile.exists()) {
150-
onComplete(validFile.absolutePath, null)
151-
return
152-
}
185+
if (resourceType == ResourceType.TOKENIZER) {
186+
val request = Request.Builder().url(url).build()
187+
val response = client.newCall(request).execute()
153188

154-
if (resourceType == ResourceType.TOKENIZER) {
155-
val request = Request.Builder().url(url).build()
156-
val response = client.newCall(request).execute()
189+
if (!response.isSuccessful) {
190+
onComplete(null, Exception("download_error"))
191+
return
192+
}
157193

158-
if (!response.isSuccessful) {
159-
onComplete(null, Exception("download_error"))
160-
return
161-
}
194+
validFile = saveResponseToFile(response, modelsDirectory, fileName)
195+
onComplete(validFile.absolutePath, null)
196+
return
197+
}
162198

163-
validFile = saveResponseToFile(response, modelsDirectory, fileName)
164-
onComplete(validFile.absolutePath, null)
165-
return
166-
}
199+
// If the URL points to a Software Mansion HuggingFace repo, send a HEAD
200+
// request to its config.json file; this increments the HF download counter:
201+
// https://huggingface.co/docs/hub/models-download-stats
202+
if (isUrlPointingToHfRepo(url)) {
203+
val configUrl = resolveConfigUrlFromModelUrl(url)
204+
sendRequestToUrl(configUrl, "HEAD", null, client)
205+
}
167206

168-
fetchModel(tempFile, validFile, client, url, onComplete, listener)
169-
}
207+
fetchModel(tempFile, validFile, client, url, onComplete, listener)
170208
}
171-
}
209+
}
210+
}

ios/utils/LargeFileFetcher.mm

+23
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTas
5454
}
5555
}
5656

57+
// Starts a HEAD request to `url` on `_session`. Nothing in this method inspects the
// response or reports a failure.
// NOTE(review): if `_session` has a delegate that handles download completion, confirm
// that the completion of this data task is not mistaken for a finished download.
- (void)sendHeadRequestToURL:(NSURL *)url {
  // +[NSURL URLWithString:] returns nil for malformed strings; there is nothing to
  // request in that case, so skip instead of building a request without a URL.
  if (url == nil) {
    return;
  }
  NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url];
  [request setHTTPMethod:@"HEAD"];
  NSURLSessionDataTask *dataTask = [_session dataTaskWithRequest:request];
  [dataTask resume];
}
63+
5764
- (void)startDownloadingFileFromURL:(NSURL *)url {
5865
//Check if file is a valid url, if not check if it's path to local file
5966
if (![Fetcher isValidURL:url]) {
@@ -77,6 +84,22 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
7784
[self executeCompletionWithSuccess:filePath];
7885
return;
7986
}
87+
88+
// If the URL points to a Software Mansion HuggingFace repo, send a HEAD
89+
// request to its config.json file; this increments the HF download counter:
90+
// https://huggingface.co/docs/hub/models-download-stats
91+
NSString *huggingFaceOrgNSString = @"https://huggingface.co/software-mansion/";
92+
NSString *modelURLNSString = [url absoluteString];
93+
94+
if ([modelURLNSString hasPrefix:huggingFaceOrgNSString]) {
95+
NSRange resolveRange = [modelURLNSString rangeOfString:@"resolve"];
96+
if (resolveRange.location != NSNotFound) {
97+
NSString *configURLNSString = [modelURLNSString substringToIndex:resolveRange.location + resolveRange.length];
98+
configURLNSString = [configURLNSString stringByAppendingString:@"/main/config.json"];
99+
NSURL *configNSURL = [NSURL URLWithString:configURLNSString];
100+
[self sendHeadRequestToURL:configNSURL];
101+
}
102+
}
80103

81104
//Cancel all running background download tasks and start new one
82105
_destination = filePath;

0 commit comments

Comments
 (0)