@@ -5,104 +5,135 @@ import okhttp3.Call
55import okhttp3.Callback
66import okhttp3.OkHttpClient
77import okhttp3.Request
8+ import okhttp3.RequestBody
89import okhttp3.Response
910import okio.IOException
1011import java.io.File
1112import java.io.FileOutputStream
1213import java.net.URL
1314
1415enum class ResourceType {
15- TOKENIZER ,
16- MODEL
16+ TOKENIZER ,
17+ MODEL ,
1718}
1819
1920class Fetcher {
20- companion object {
21- private fun saveResponseToFile (
22- response : Response ,
23- directory : File ,
24- fileName : String
25- ): File {
26- val file = File (directory.path, fileName)
27- file.outputStream().use { outputStream ->
28- response.body?.byteStream()?.copyTo(outputStream)
29- }
30- return file
21+ companion object {
22+ private fun saveResponseToFile (
23+ response : Response ,
24+ directory : File ,
25+ fileName : String ,
26+ ): File {
27+ val file = File (directory.path, fileName)
28+ file.outputStream().use { outputStream ->
29+ response.body?.byteStream()?.copyTo(outputStream)
30+ }
31+ return file
32+ }
33+
34+ private fun hasValidExtension (fileName : String , resourceType : ResourceType ): Boolean {
35+ return when (resourceType) {
36+ ResourceType .TOKENIZER -> {
37+ fileName.endsWith(" .bin" )
3138 }
3239
33- private fun hasValidExtension ( fileName : String , resourceType : ResourceType ): Boolean {
34- return when (resourceType) {
35- ResourceType . TOKENIZER -> {
36- fileName.endsWith( " .bin " )
37- }
40+ ResourceType . MODEL -> {
41+ fileName.endsWith( " .pte " )
42+ }
43+ }
44+ }
3845
39- ResourceType .MODEL -> {
40- fileName.endsWith(" .pte" )
41- }
42- }
46+ private fun extractFileName (url : URL ): String {
47+ if (url.path == " /assets/" ) {
48+ val pathSegments = url.toString().split(' /' )
49+ return pathSegments[pathSegments.size - 1 ].split(" ?" )[0 ]
50+ } else if (url.protocol == " file" ) {
51+ val localPath = url.toString().split(" ://" )[1 ]
52+ val file = File (localPath)
53+ if (file.exists()) {
54+ return localPath
4355 }
4456
45- private fun extractFileName (url : URL ): String {
46- if (url.path == " /assets/" ) {
47- val pathSegments = url.toString().split(' /' )
48- return pathSegments[pathSegments.size - 1 ].split(" ?" )[0 ]
49- } else if (url.protocol == " file" ) {
50- val localPath = url.toString().split(" ://" )[1 ]
51- val file = File (localPath)
52- if (file.exists()) {
53- return localPath
57+ throw Exception (" file_not_found" )
58+ } else {
59+ return url.path.substringAfterLast(' /' )
60+ }
61+ }
62+
63+ private fun fetchModel (
64+ file : File ,
65+ validFile : File ,
66+ client : OkHttpClient ,
67+ url : URL ,
68+ onComplete : (String? , Exception ? ) -> Unit ,
69+ listener : ProgressResponseBody .ProgressListener ? = null,
70+ ) {
71+ val request = Request .Builder ().url(url).build()
72+ client.newCall(request).enqueue(object : Callback {
73+ override fun onFailure (call : Call , e : IOException ) {
74+ onComplete(null , e)
75+ }
76+
77+ override fun onResponse (call : Call , response : Response ) {
78+ if (! response.isSuccessful) {
79+ onComplete(null , Exception (" download_error" ))
80+ }
81+
82+ response.body?.let { body ->
83+ val progressBody = listener?.let { ProgressResponseBody (body, it) }
84+ val inputStream = progressBody?.source()?.inputStream()
85+ inputStream?.use { input ->
86+ FileOutputStream (file).use { output ->
87+ val buffer = ByteArray (2048 )
88+ var bytesRead: Int
89+ while (input.read(buffer).also { bytesRead = it } != - 1 ) {
90+ output.write(buffer, 0 , bytesRead)
5491 }
92+ }
93+ }
5594
56- throw Exception (" file_not_found" )
95+ if (file.renameTo(validFile)) {
96+ onComplete(validFile.absolutePath, null )
5797 } else {
58- return url.path.substringAfterLast( ' / ' )
98+ onComplete( null , Exception ( " Failed to move the file to the valid location " ) )
5999 }
100+ }
60101 }
102+ })
103+ }
61104
62- private fun fetchModel (file : File , validFile : File , client : OkHttpClient , url : URL , onComplete : (String? , Exception ? ) -> Unit ,
63- listener : ProgressResponseBody .ProgressListener ? = null){
64- val request = Request .Builder ().url(url).build()
65- client.newCall(request).enqueue(object : Callback {
66- override fun onFailure (call : Call , e : IOException ) {
67- onComplete(null , e)
68- }
105+ private fun isUrlPointingToHfRepo (url : URL ): Boolean {
106+ val expectedHost = " huggingface.co"
107+ val expectedPathPrefix = " /software-mansion/"
108+ if (url.host != expectedHost) {
109+ return false
110+ }
111+ return url.path.startsWith(expectedPathPrefix)
112+ }
69113
70- override fun onResponse (call : Call , response : Response ) {
71- if (! response.isSuccessful) {
72- onComplete(null , Exception (" download_error" ))
73- }
74-
75- response.body?.let { body ->
76- val progressBody = listener?.let { ProgressResponseBody (body, it) }
77- val inputStream = progressBody?.source()?.inputStream()
78- inputStream?.use { input ->
79- FileOutputStream (file).use { output ->
80- val buffer = ByteArray (2048 )
81- var bytesRead: Int
82- while (input.read(buffer).also { bytesRead = it } != - 1 ) {
83- output.write(buffer, 0 , bytesRead)
84- }
85- }
86- }
87-
88- if (file.renameTo(validFile)) {
89- onComplete(validFile.absolutePath, null )
90- } else {
91- onComplete(null , Exception (" Failed to move the file to the valid location" ))
92- }
93- }
94- }
95- })
96- }
114+ private fun resolveConfigUrlFromModelUrl (modelUrl : URL ): URL {
115+ // Create a new URL using the base URL and append the desired path
116+ val baseUrl = modelUrl.protocol + " ://" + modelUrl.host + modelUrl.path.substringBefore(" resolve/" )
117+ return URL (baseUrl + " resolve/main/config.json" )
118+ }
97119
98- fun downloadResource (
99- context : Context ,
100- client : OkHttpClient ,
101- url : URL ,
102- resourceType : ResourceType ,
103- onComplete : (String? , Exception ? ) -> Unit ,
104- listener : ProgressResponseBody .ProgressListener ? = null
105- ) {
120+ private fun sendRequestToUrl (url : URL , method : String , body : RequestBody ? , client : OkHttpClient ): Response {
121+ val request = Request .Builder ()
122+ .url(url)
123+ .method(method, body)
124+ .build()
125+ val response = client.newCall(request).execute()
126+ return response
127+ }
128+
129+ fun downloadResource (
130+ context : Context ,
131+ client : OkHttpClient ,
132+ url : URL ,
133+ resourceType : ResourceType ,
134+ onComplete : (String? , Exception ? ) -> Unit ,
135+ listener : ProgressResponseBody .ProgressListener ? = null,
136+ ) {
106137 /*
107138 Fetching model and tokenizer file
108139 1. Extract file name from provided URL
@@ -115,57 +146,65 @@ class Fetcher {
115146 6. If the file does not exist, and is a tokenizer, fetch the file
116147 7. If the file is a model, fetch the file with ProgressResponseBody
117148 */
118- val fileName: String
119-
120- try {
121- fileName = extractFileName(url)
122- } catch (e: Exception ) {
123- onComplete(null , e)
124- return
125- }
126-
127- if (fileName.contains(" /" )){
128- onComplete(fileName, null )
129- return
130- }
131-
132- if (! hasValidExtension(fileName, resourceType)) {
133- onComplete(null , Exception (" invalid_resource_extension" ))
134- return
135- }
136-
137- var tempFile = File (context.filesDir, fileName)
138- if (tempFile.exists()){
139- tempFile.delete()
140- }
149+ val fileName: String
150+
151+ try {
152+ fileName = extractFileName(url)
153+ } catch (e: Exception ) {
154+ onComplete(null , e)
155+ return
156+ }
157+
158+ if (fileName.contains(" /" )) {
159+ onComplete(fileName, null )
160+ return
161+ }
162+
163+ if (! hasValidExtension(fileName, resourceType)) {
164+ onComplete(null , Exception (" invalid_resource_extension" ))
165+ return
166+ }
167+
168+ var tempFile = File (context.filesDir, fileName)
169+ if (tempFile.exists()) {
170+ tempFile.delete()
171+ }
172+
173+ val modelsDirectory = File (context.filesDir, " models" ).apply {
174+ if (! exists()) {
175+ mkdirs()
176+ }
177+ }
141178
142- val modelsDirectory = File (context.filesDir, " models " ). apply {
143- if (! exists()) {
144- mkdirs( )
145- }
146- }
179+ var validFile = File (modelsDirectory, fileName)
180+ if (validFile. exists()) {
181+ onComplete(validFile.absolutePath, null )
182+ return
183+ }
147184
148- var validFile = File (modelsDirectory, fileName)
149- if (validFile.exists()) {
150- onComplete(validFile.absolutePath, null )
151- return
152- }
185+ if (resourceType == ResourceType .TOKENIZER ) {
186+ val request = Request .Builder ().url(url).build()
187+ val response = client.newCall(request).execute()
153188
154- if (resourceType == ResourceType .TOKENIZER ) {
155- val request = Request .Builder ().url(url).build()
156- val response = client.newCall(request).execute()
189+ if (! response.isSuccessful) {
190+ onComplete(null , Exception (" download_error" ))
191+ return
192+ }
157193
158- if ( ! response.isSuccessful) {
159- onComplete(null , Exception ( " download_error " ) )
160- return
161- }
194+ validFile = saveResponseToFile( response, modelsDirectory, fileName)
195+ onComplete(validFile.absolutePath, null )
196+ return
197+ }
162198
163- validFile = saveResponseToFile(response, modelsDirectory, fileName)
164- onComplete(validFile.absolutePath, null )
165- return
166- }
199+ // If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
200+ // request to the config.json file, this increments HF download counter
201+ // https://huggingface.co/docs/hub/models-download-stats
202+ if (isUrlPointingToHfRepo(url)) {
203+ val configUrl = resolveConfigUrlFromModelUrl(url)
204+ sendRequestToUrl(configUrl, " HEAD" , null , client)
205+ }
167206
168- fetchModel(tempFile, validFile, client, url, onComplete, listener)
169- }
207+ fetchModel(tempFile, validFile, client, url, onComplete, listener)
170208 }
171- }
209+ }
210+ }
0 commit comments