@@ -5,104 +5,135 @@ import okhttp3.Call
5
5
import okhttp3.Callback
6
6
import okhttp3.OkHttpClient
7
7
import okhttp3.Request
8
+ import okhttp3.RequestBody
8
9
import okhttp3.Response
9
10
import okio.IOException
10
11
import java.io.File
11
12
import java.io.FileOutputStream
12
13
import java.net.URL
13
14
14
15
enum class ResourceType {
15
- TOKENIZER ,
16
- MODEL
16
+ TOKENIZER ,
17
+ MODEL ,
17
18
}
18
19
19
20
class Fetcher {
20
- companion object {
21
- private fun saveResponseToFile (
22
- response : Response ,
23
- directory : File ,
24
- fileName : String
25
- ): File {
26
- val file = File (directory.path, fileName)
27
- file.outputStream().use { outputStream ->
28
- response.body?.byteStream()?.copyTo(outputStream)
29
- }
30
- return file
21
+ companion object {
22
+ private fun saveResponseToFile (
23
+ response : Response ,
24
+ directory : File ,
25
+ fileName : String ,
26
+ ): File {
27
+ val file = File (directory.path, fileName)
28
+ file.outputStream().use { outputStream ->
29
+ response.body?.byteStream()?.copyTo(outputStream)
30
+ }
31
+ return file
32
+ }
33
+
34
+ private fun hasValidExtension (fileName : String , resourceType : ResourceType ): Boolean {
35
+ return when (resourceType) {
36
+ ResourceType .TOKENIZER -> {
37
+ fileName.endsWith(" .bin" )
31
38
}
32
39
33
- private fun hasValidExtension ( fileName : String , resourceType : ResourceType ): Boolean {
34
- return when (resourceType) {
35
- ResourceType . TOKENIZER -> {
36
- fileName.endsWith( " .bin " )
37
- }
40
+ ResourceType . MODEL -> {
41
+ fileName.endsWith( " .pte " )
42
+ }
43
+ }
44
+ }
38
45
39
- ResourceType .MODEL -> {
40
- fileName.endsWith(" .pte" )
41
- }
42
- }
46
+ private fun extractFileName (url : URL ): String {
47
+ if (url.path == " /assets/" ) {
48
+ val pathSegments = url.toString().split(' /' )
49
+ return pathSegments[pathSegments.size - 1 ].split(" ?" )[0 ]
50
+ } else if (url.protocol == " file" ) {
51
+ val localPath = url.toString().split(" ://" )[1 ]
52
+ val file = File (localPath)
53
+ if (file.exists()) {
54
+ return localPath
43
55
}
44
56
45
- private fun extractFileName (url : URL ): String {
46
- if (url.path == " /assets/" ) {
47
- val pathSegments = url.toString().split(' /' )
48
- return pathSegments[pathSegments.size - 1 ].split(" ?" )[0 ]
49
- } else if (url.protocol == " file" ) {
50
- val localPath = url.toString().split(" ://" )[1 ]
51
- val file = File (localPath)
52
- if (file.exists()) {
53
- return localPath
57
+ throw Exception (" file_not_found" )
58
+ } else {
59
+ return url.path.substringAfterLast(' /' )
60
+ }
61
+ }
62
+
63
+ private fun fetchModel (
64
+ file : File ,
65
+ validFile : File ,
66
+ client : OkHttpClient ,
67
+ url : URL ,
68
+ onComplete : (String? , Exception ? ) -> Unit ,
69
+ listener : ProgressResponseBody .ProgressListener ? = null,
70
+ ) {
71
+ val request = Request .Builder ().url(url).build()
72
+ client.newCall(request).enqueue(object : Callback {
73
+ override fun onFailure (call : Call , e : IOException ) {
74
+ onComplete(null , e)
75
+ }
76
+
77
+ override fun onResponse (call : Call , response : Response ) {
78
+ if (! response.isSuccessful) {
79
+ onComplete(null , Exception (" download_error" ))
80
+ }
81
+
82
+ response.body?.let { body ->
83
+ val progressBody = listener?.let { ProgressResponseBody (body, it) }
84
+ val inputStream = progressBody?.source()?.inputStream()
85
+ inputStream?.use { input ->
86
+ FileOutputStream (file).use { output ->
87
+ val buffer = ByteArray (2048 )
88
+ var bytesRead: Int
89
+ while (input.read(buffer).also { bytesRead = it } != - 1 ) {
90
+ output.write(buffer, 0 , bytesRead)
54
91
}
92
+ }
93
+ }
55
94
56
- throw Exception (" file_not_found" )
95
+ if (file.renameTo(validFile)) {
96
+ onComplete(validFile.absolutePath, null )
57
97
} else {
58
- return url.path.substringAfterLast( ' / ' )
98
+ onComplete( null , Exception ( " Failed to move the file to the valid location " ) )
59
99
}
100
+ }
60
101
}
102
+ })
103
+ }
61
104
62
- private fun fetchModel (file : File , validFile : File , client : OkHttpClient , url : URL , onComplete : (String? , Exception ? ) -> Unit ,
63
- listener : ProgressResponseBody .ProgressListener ? = null){
64
- val request = Request .Builder ().url(url).build()
65
- client.newCall(request).enqueue(object : Callback {
66
- override fun onFailure (call : Call , e : IOException ) {
67
- onComplete(null , e)
68
- }
105
+ private fun isUrlPointingToHfRepo (url : URL ): Boolean {
106
+ val expectedHost = " huggingface.co"
107
+ val expectedPathPrefix = " /software-mansion/"
108
+ if (url.host != expectedHost) {
109
+ return false
110
+ }
111
+ return url.path.startsWith(expectedPathPrefix)
112
+ }
69
113
70
- override fun onResponse (call : Call , response : Response ) {
71
- if (! response.isSuccessful) {
72
- onComplete(null , Exception (" download_error" ))
73
- }
74
-
75
- response.body?.let { body ->
76
- val progressBody = listener?.let { ProgressResponseBody (body, it) }
77
- val inputStream = progressBody?.source()?.inputStream()
78
- inputStream?.use { input ->
79
- FileOutputStream (file).use { output ->
80
- val buffer = ByteArray (2048 )
81
- var bytesRead: Int
82
- while (input.read(buffer).also { bytesRead = it } != - 1 ) {
83
- output.write(buffer, 0 , bytesRead)
84
- }
85
- }
86
- }
87
-
88
- if (file.renameTo(validFile)) {
89
- onComplete(validFile.absolutePath, null )
90
- } else {
91
- onComplete(null , Exception (" Failed to move the file to the valid location" ))
92
- }
93
- }
94
- }
95
- })
96
- }
114
+ private fun resolveConfigUrlFromModelUrl (modelUrl : URL ): URL {
115
+ // Create a new URL using the base URL and append the desired path
116
+ val baseUrl = modelUrl.protocol + " ://" + modelUrl.host + modelUrl.path.substringBefore(" resolve/" )
117
+ return URL (baseUrl + " resolve/main/config.json" )
118
+ }
97
119
98
- fun downloadResource (
99
- context : Context ,
100
- client : OkHttpClient ,
101
- url : URL ,
102
- resourceType : ResourceType ,
103
- onComplete : (String? , Exception ? ) -> Unit ,
104
- listener : ProgressResponseBody .ProgressListener ? = null
105
- ) {
120
+ private fun sendRequestToUrl (url : URL , method : String , body : RequestBody ? , client : OkHttpClient ): Response {
121
+ val request = Request .Builder ()
122
+ .url(url)
123
+ .method(method, body)
124
+ .build()
125
+ val response = client.newCall(request).execute()
126
+ return response
127
+ }
128
+
129
+ fun downloadResource (
130
+ context : Context ,
131
+ client : OkHttpClient ,
132
+ url : URL ,
133
+ resourceType : ResourceType ,
134
+ onComplete : (String? , Exception ? ) -> Unit ,
135
+ listener : ProgressResponseBody .ProgressListener ? = null,
136
+ ) {
106
137
/*
107
138
Fetching model and tokenizer file
108
139
1. Extract file name from provided URL
@@ -115,57 +146,65 @@ class Fetcher {
115
146
6. If the file does not exist, and is a tokenizer, fetch the file
116
147
7. If the file is a model, fetch the file with ProgressResponseBody
117
148
*/
118
- val fileName: String
119
-
120
- try {
121
- fileName = extractFileName(url)
122
- } catch (e: Exception ) {
123
- onComplete(null , e)
124
- return
125
- }
126
-
127
- if (fileName.contains(" /" )){
128
- onComplete(fileName, null )
129
- return
130
- }
131
-
132
- if (! hasValidExtension(fileName, resourceType)) {
133
- onComplete(null , Exception (" invalid_resource_extension" ))
134
- return
135
- }
136
-
137
- var tempFile = File (context.filesDir, fileName)
138
- if (tempFile.exists()){
139
- tempFile.delete()
140
- }
149
+ val fileName: String
150
+
151
+ try {
152
+ fileName = extractFileName(url)
153
+ } catch (e: Exception ) {
154
+ onComplete(null , e)
155
+ return
156
+ }
157
+
158
+ if (fileName.contains(" /" )) {
159
+ onComplete(fileName, null )
160
+ return
161
+ }
162
+
163
+ if (! hasValidExtension(fileName, resourceType)) {
164
+ onComplete(null , Exception (" invalid_resource_extension" ))
165
+ return
166
+ }
167
+
168
+ var tempFile = File (context.filesDir, fileName)
169
+ if (tempFile.exists()) {
170
+ tempFile.delete()
171
+ }
172
+
173
+ val modelsDirectory = File (context.filesDir, " models" ).apply {
174
+ if (! exists()) {
175
+ mkdirs()
176
+ }
177
+ }
141
178
142
- val modelsDirectory = File (context.filesDir, " models " ). apply {
143
- if (! exists()) {
144
- mkdirs( )
145
- }
146
- }
179
+ var validFile = File (modelsDirectory, fileName)
180
+ if (validFile. exists()) {
181
+ onComplete(validFile.absolutePath, null )
182
+ return
183
+ }
147
184
148
- var validFile = File (modelsDirectory, fileName)
149
- if (validFile.exists()) {
150
- onComplete(validFile.absolutePath, null )
151
- return
152
- }
185
+ if (resourceType == ResourceType .TOKENIZER ) {
186
+ val request = Request .Builder ().url(url).build()
187
+ val response = client.newCall(request).execute()
153
188
154
- if (resourceType == ResourceType .TOKENIZER ) {
155
- val request = Request .Builder ().url(url).build()
156
- val response = client.newCall(request).execute()
189
+ if (! response.isSuccessful) {
190
+ onComplete(null , Exception (" download_error" ))
191
+ return
192
+ }
157
193
158
- if ( ! response.isSuccessful) {
159
- onComplete(null , Exception ( " download_error " ) )
160
- return
161
- }
194
+ validFile = saveResponseToFile( response, modelsDirectory, fileName)
195
+ onComplete(validFile.absolutePath, null )
196
+ return
197
+ }
162
198
163
- validFile = saveResponseToFile(response, modelsDirectory, fileName)
164
- onComplete(validFile.absolutePath, null )
165
- return
166
- }
199
+ // If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
200
+ // request to the config.json file, this increments HF download counter
201
+ // https://huggingface.co/docs/hub/models-download-stats
202
+ if (isUrlPointingToHfRepo(url)) {
203
+ val configUrl = resolveConfigUrlFromModelUrl(url)
204
+ sendRequestToUrl(configUrl, " HEAD" , null , client)
205
+ }
167
206
168
- fetchModel(tempFile, validFile, client, url, onComplete, listener)
169
- }
207
+ fetchModel(tempFile, validFile, client, url, onComplete, listener)
170
208
}
171
- }
209
+ }
210
+ }
0 commit comments