Commit 44c7eb8

Use expo-file-system and expo-asset to fetch files (#97)
## Description

1. Use expo-file-system and expo-asset to fetch files (see the Kotlin sketch below for the resulting native-side contract).
2. Enable background model downloads.
3. Add downloadProgress functionality to useClassification, useObjectDetection, useStyleTransfer, useExecutorchModule and their corresponding hookless implementations ClassificationModule, ObjectDetectionModule, StyleTransferModule, ExecutorchModule.
4. Add functionality allowing for listing downloaded files and models.

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Related issues

- #77

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent e319e0b · commit 44c7eb8

33 files changed: +926 -975 lines
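
On the Android side, the practical effect of item 1 is visible in the hunks below: model and tokenizer sources now arrive as already-downloaded `file://` URIs, and the native modules only resolve them to filesystem paths with `java.net.URL`. A minimal sketch of that resolution, using a hypothetical local URI that is not taken from the diff:

```kotlin
import java.net.URL

fun main() {
  // Hypothetical file:// URI, such as expo-file-system would hand over from the JS side.
  val modelSource = "file:///data/user/0/com.example.app/files/model.pte"

  // URL(...).path drops the scheme and yields the plain filesystem path,
  // which is what Module.load / LlamaModule expect.
  val path = URL(modelSource).path
  println(path) // prints /data/user/0/com.example.app/files/model.pte
}
```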

android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt (+3 -14)

@@ -6,9 +6,9 @@ import com.facebook.react.bridge.ReactApplicationContext
 import com.facebook.react.bridge.ReadableArray
 import com.swmansion.rnexecutorch.utils.ArrayUtils
 import com.swmansion.rnexecutorch.utils.ETError
-import com.swmansion.rnexecutorch.utils.Fetcher
 import com.swmansion.rnexecutorch.utils.TensorUtils
 import org.pytorch.executorch.Module
+import java.net.URL

 class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
   private lateinit var module: Module
@@ -18,19 +18,8 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
   }

   override fun loadModule(modelSource: String, promise: Promise) {
-    Fetcher.downloadModel(
-      reactApplicationContext,
-      modelSource,
-    ) { path, error ->
-      if (error != null) {
-        promise.reject(error.message!!, ETError.InvalidModelSource.toString())
-        return@downloadModel
-      }
-
-      module = Module.load(path)
-      promise.resolve(0)
-      return@downloadModel
-    }
+    module = Module.load(URL(modelSource).path)
+    promise.resolve(0)
   }

   override fun loadMethod(methodName: String, promise: Promise) {
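
For readability, here is how `loadModule` reads after this hunk, assembled from the `+` and unchanged context lines above (the comment is added for explanation):

```kotlin
override fun loadModule(modelSource: String, promise: Promise) {
  // modelSource is now expected to be a local file:// URI fetched on the JS side
  // via expo-file-system / expo-asset; no download happens in native code anymore.
  module = Module.load(URL(modelSource).path)
  promise.resolve(0)
}
```

Unlike the previous callback-based version, nothing rejects with `ETError.InvalidModelSource` here; a bad source would presumably surface as an exception from `URL(...)` or `Module.load` instead.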

android/src/main/java/com/swmansion/rnexecutorch/LLM.kt (+5 -72)

@@ -3,22 +3,19 @@ package com.swmansion.rnexecutorch
 import android.util.Log
 import com.facebook.react.bridge.Promise
 import com.facebook.react.bridge.ReactApplicationContext
-import com.swmansion.rnexecutorch.utils.Fetcher
-import com.swmansion.rnexecutorch.utils.ProgressResponseBody
-import com.swmansion.rnexecutorch.utils.ResourceType
 import com.swmansion.rnexecutorch.utils.llms.ChatRole
 import com.swmansion.rnexecutorch.utils.llms.ConversationManager
 import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
 import org.pytorch.executorch.LlamaCallback
 import org.pytorch.executorch.LlamaModule
+import java.net.URL

 class LLM(reactContext: ReactApplicationContext) :
   NativeLLMSpec(reactContext), LlamaCallback {

   private var llamaModule: LlamaModule? = null
   private var tempLlamaResponse = StringBuilder()
   private lateinit var conversationManager: ConversationManager
-  private var isFetching = false

   override fun getName(): String {
     return NAME
@@ -37,84 +34,20 @@ class LLM(reactContext: ReactApplicationContext) :
     Log.d("rn_executorch", "TPS: $tps")
   }

-  private fun updateDownloadProgress(progress: Float) {
-    emitOnDownloadProgress((progress / 100).toDouble())
-  }
-
-  private fun downloadResource(
-    url: String,
-    resourceType: ResourceType,
-    isLargeFile: Boolean = false,
-    callback: (path: String?, error: Exception?) -> Unit,
-  ) {
-    Fetcher.downloadResource(
-      reactApplicationContext, url, resourceType, isLargeFile,
-      { path, error -> callback(path, error) },
-      object : ProgressResponseBody.ProgressListener {
-        override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
-          val progress = (bytesRead * 100 / contentLength).toFloat()
-          updateDownloadProgress(progress)
-          if (done) {
-            isFetching = false
-          }
-        }
-      })
-  }
-
-  private fun initializeLlamaModule(modelPath: String, tokenizerPath: String, promise: Promise) {
-    llamaModule = LlamaModule(1, modelPath, tokenizerPath, 0.7f)
-    isFetching = false
-    this.tempLlamaResponse.clear()
-    promise.resolve("Model loaded successfully")
-  }
-
   override fun loadLLM(
     modelSource: String,
     tokenizerSource: String,
     systemPrompt: String,
     contextWindowLength: Double,
     promise: Promise
   ) {
-    if (isFetching) {
-      promise.reject("Model is fetching", "Model is fetching")
-      return
-    }
-
     try {
       this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt)
-
-      isFetching = true
-
-      downloadResource(
-        tokenizerSource,
-        ResourceType.TOKENIZER
-      ) tokenizerDownload@{ tokenizerPath, error ->
-        if (error != null) {
-          promise.reject("Download Error", "Tokenizer download failed: ${error.message}")
-          isFetching = false
-          return@tokenizerDownload
-        }
-
-        downloadResource(
-          modelSource,
-          ResourceType.MODEL,
-          isLargeFile = true
-        ) modelDownload@{ modelPath, modelError ->
-          if (modelError != null) {
-            promise.reject(
-              "Download Error",
-              "Model download failed: ${modelError.message}"
-            )
-            isFetching = false
-            return@modelDownload
-          }
-
-          initializeLlamaModule(modelPath!!, tokenizerPath!!, promise)
-        }
-      }
+      llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f)
+      this.tempLlamaResponse.clear()
+      promise.resolve("Model loaded successfully")
     } catch (e: Exception) {
-      promise.reject("Download Error", e.message)
-      isFetching = false
+      promise.reject("Model loading failed", e.message)
     }
   }

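
Assembled the same way, the post-change `loadLLM` reduces to the following (comments added; everything else comes from the hunk above):

```kotlin
override fun loadLLM(
  modelSource: String,
  tokenizerSource: String,
  systemPrompt: String,
  contextWindowLength: Double,
  promise: Promise
) {
  try {
    this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt)
    // Both sources are local file:// URIs; downloading and download-progress
    // reporting moved to the JS side, so the isFetching bookkeeping is gone.
    llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f)
    this.tempLlamaResponse.clear()
    promise.resolve("Model loaded successfully")
  } catch (e: Exception) {
    promise.reject("Model loading failed", e.message)
  }
}
```

Note that the rejection code changes from "Download Error" to "Model loading failed", which callers matching on the old code would need to account for.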

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt (+2 -12)

@@ -2,26 +2,16 @@ package com.swmansion.rnexecutorch.models

 import android.content.Context
 import com.swmansion.rnexecutorch.utils.ETError
-import com.swmansion.rnexecutorch.utils.Fetcher
 import org.pytorch.executorch.EValue
 import org.pytorch.executorch.Module
-import org.pytorch.executorch.Tensor
+import java.net.URL


 abstract class BaseModel<Input, Output>(val context: Context) {
   protected lateinit var module: Module

   fun loadModel(modelSource: String) {
-    Fetcher.downloadModel(
-      context,
-      modelSource
-    ) { path, error ->
-      if (error != null) {
-        throw Error(error.message!!)
-      }
-
-      module = Module.load(path)
-    }
+    module = Module.load(URL(modelSource).path)
   }

   protected fun forward(input: EValue): Array<EValue> {
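
With the `Fetcher` callback gone, `loadModel` is now synchronous and failures propagate as exceptions rather than through a download callback. A hedged usage sketch under that assumption (the wrapper function and sample source are illustrative and not part of the diff; the log tag reuses the one seen in LLM.kt):

```kotlin
import android.util.Log
import com.swmansion.rnexecutorch.models.BaseModel

// Illustrative helper only: `model` is any concrete BaseModel subclass instance.
fun loadOrLog(model: BaseModel<*, *>, modelSource: String) {
  try {
    // modelSource must already be a local file:// URI fetched on the JS side;
    // a malformed URI throws MalformedURLException, and a missing or invalid
    // file makes Module.load fail, so errors now surface right here.
    model.loadModel(modelSource)
  } catch (e: Exception) {
    Log.e("rn_executorch", "Model loading failed: ${e.message}")
  }
}
```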
