-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathFetcher.kt
171 lines (149 loc) · 6.01 KB
/
Fetcher.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
package com.swmansion.rnexecutorch
import android.content.Context
import okhttp3.Call
import okhttp3.Callback
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okio.IOException
import java.io.File
import java.io.FileOutputStream
import java.net.URL
/**
 * Kind of resource handled by `Fetcher.downloadResource`.
 *
 * The kind decides which file extension is accepted: [TOKENIZER] expects `.bin`,
 * [MODEL] expects `.pte`.
 */
enum class ResourceType { TOKENIZER, MODEL }
/**
 * Resolves model and tokenizer resources to a local file path.
 *
 * Remote resources are downloaded into `<filesDir>/models` and reused on later calls.
 */
class Fetcher {
    companion object {
        /**
         * Streams the body of [response] into `directory/fileName`.
         *
         * @return the written file.
         * @throws Exception with message "download_error" when the response has no body.
         * Any error raised while copying is rethrown after the partial file is deleted.
         */
        private fun saveResponseToFile(
            response: Response,
            directory: File,
            fileName: String
        ): File {
            // Writing an empty file here would later be mistaken for a cached resource.
            val body = response.body ?: throw Exception("download_error")
            val file = File(directory, fileName)
            try {
                body.byteStream().use { input ->
                    file.outputStream().use { output -> input.copyTo(output) }
                }
            } catch (e: Exception) {
                // A truncated file in the models directory would be treated as a valid cache hit.
                file.delete()
                throw e
            }
            return file
        }

        /** Returns true when [fileName] ends with the extension required for [resourceType]. */
        private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean =
            when (resourceType) {
                ResourceType.TOKENIZER -> fileName.endsWith(".bin")
                ResourceType.MODEL -> fileName.endsWith(".pte")
            }

        /**
         * Derives a file name (or, for `file` URLs, a local path) from [url].
         *
         * - Path exactly `/assets/`: the last `/`-separated segment of the full URL, without
         *   its `?` query. NOTE(review): presumably React Native dev-server asset URLs that
         *   carry the real path in the query string; confirm against callers.
         * - `file` protocol: everything after `://`, returned only if that file exists.
         * - Anything else: the last segment of the URL path.
         *
         * @throws Exception "file_not_found" when a `file` URL points to a missing file.
         */
        private fun extractFileName(url: URL): String = when {
            url.path == "/assets/" ->
                url.toString().split('/').last().substringBefore('?')
            url.protocol == "file" -> {
                val localPath = url.toString().split("://")[1]
                if (!File(localPath).exists()) {
                    throw Exception("file_not_found")
                }
                localPath
            }
            else -> url.path.substringAfterLast('/')
        }

        /**
         * Copies the body of [response] into [file].
         *
         * When [listener] is non-null, the body is read through [ProgressResponseBody],
         * presumably so download progress is reported to the listener. Otherwise the raw
         * body is read directly.
         *
         * @throws Exception "download_error" for an unsuccessful response or a missing body.
         */
        private fun writeResponseBody(
            response: Response,
            file: File,
            listener: ProgressResponseBody.ProgressListener?
        ) {
            // Never persist an error page as if it were a model.
            if (!response.isSuccessful) {
                throw Exception("download_error")
            }
            val body = response.body ?: throw Exception("download_error")
            // Fall back to the plain body so downloads without a listener still write data.
            val input = listener?.let { ProgressResponseBody(body, it).source().inputStream() }
                ?: body.byteStream()
            input.use { source ->
                FileOutputStream(file).use { output -> source.copyTo(output) }
            }
        }

        /**
         * Asynchronously downloads a model to the temporary [file], then renames it to [validFile].
         *
         * [onComplete] is invoked exactly once from OkHttp's callback. On success it receives
         * the absolute path of [validFile]; on any failure it receives an exception, and the
         * temporary file is removed.
         */
        private fun fetchModel(
            file: File,
            validFile: File,
            client: OkHttpClient,
            url: URL,
            onComplete: (String?, Exception?) -> Unit,
            listener: ProgressResponseBody.ProgressListener? = null
        ) {
            val request = Request.Builder().url(url).build()
            client.newCall(request).enqueue(object : Callback {
                override fun onFailure(call: Call, e: IOException) {
                    onComplete(null, e)
                }

                override fun onResponse(call: Call, response: Response) {
                    try {
                        // `use` closes the response so the connection is released.
                        response.use { writeResponseBody(it, file, listener) }
                    } catch (e: Exception) {
                        // Drop the partial download so the next attempt starts clean.
                        file.delete()
                        onComplete(null, e)
                        return
                    }
                    if (file.renameTo(validFile)) {
                        onComplete(validFile.absolutePath, null)
                    } else {
                        file.delete()
                        onComplete(null, Exception("Failed to move the file to the valid location"))
                    }
                }
            })
        }

        /**
         * Synchronously downloads a tokenizer into `directory/fileName` on the calling thread.
         *
         * [onComplete] receives the absolute path of the saved file. On failure it receives
         * the exception instead of letting it propagate. NOTE(review): a blocking
         * `execute()` must not run on the Android main thread; in that case the resulting
         * exception is now reported through [onComplete].
         */
        private fun fetchTokenizer(
            client: OkHttpClient,
            url: URL,
            directory: File,
            fileName: String,
            onComplete: (String?, Exception?) -> Unit
        ) {
            val savedFile = try {
                val request = Request.Builder().url(url).build()
                client.newCall(request).execute().use { response ->
                    if (!response.isSuccessful) {
                        throw Exception("download_error")
                    }
                    saveResponseToFile(response, directory, fileName)
                }
            } catch (e: Exception) {
                onComplete(null, e)
                return
            }
            // Called outside the try so an exception thrown by the callback is not re-reported.
            onComplete(savedFile.absolutePath, null)
        }

        /**
         * Resolves [url] to a local file path and passes it to [onComplete].
         *
         * Steps:
         * 1. Extract a file name from [url]. Extraction errors go to [onComplete].
         * 2. If the name contains `/`, it is an existing local path and is returned as-is.
         * 3. Check the extension: `.bin` for [ResourceType.TOKENIZER], `.pte` for
         *    [ResourceType.MODEL]. Otherwise report "invalid_resource_extension".
         * 4. Remove any stale temporary file in `filesDir` and ensure `filesDir/models` exists.
         * 5. If the file is already in `models`, return its path.
         * 6. Otherwise download it: tokenizers synchronously, models asynchronously with
         *    optional progress reporting through [listener].
         *
         * [onComplete] receives either a path or an exception. It is called on the caller's
         * thread except for model downloads, where it runs on OkHttp's callback thread.
         */
        fun downloadResource(
            context: Context,
            client: OkHttpClient,
            url: URL,
            resourceType: ResourceType,
            onComplete: (String?, Exception?) -> Unit,
            listener: ProgressResponseBody.ProgressListener? = null
        ) {
            val fileName = try {
                extractFileName(url)
            } catch (e: Exception) {
                onComplete(null, e)
                return
            }
            if (fileName.contains("/")) {
                onComplete(fileName, null)
                return
            }
            if (!hasValidExtension(fileName, resourceType)) {
                onComplete(null, Exception("invalid_resource_extension"))
                return
            }

            val tempFile = File(context.filesDir, fileName)
            if (tempFile.exists()) {
                tempFile.delete()
            }
            val modelsDirectory = File(context.filesDir, "models").apply {
                if (!exists()) {
                    mkdirs()
                }
            }
            val validFile = File(modelsDirectory, fileName)
            if (validFile.exists()) {
                onComplete(validFile.absolutePath, null)
                return
            }

            when (resourceType) {
                ResourceType.TOKENIZER ->
                    fetchTokenizer(client, url, modelsDirectory, fileName, onComplete)
                ResourceType.MODEL ->
                    fetchModel(tempFile, validFile, client, url, onComplete, listener)
            }
        }
    }
}