Skip to content

Commit e91ce11

Browse files
feat: ExecuTorch bindings for android (#38)
## Description Added TurboModule(ETModule) which let's user use ExecuTorch Module methods such as: - loadMethod - loadModule - loadForward - forward ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 8ff6076 commit e91ce11

17 files changed

+480
-59
lines changed

android/libs/executorch-llama.aar

-241 KB
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import com.facebook.react.bridge.Promise
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.facebook.react.bridge.ReadableArray
6+
import com.swmansion.rnexecutorch.utils.ArrayUtils
7+
import com.swmansion.rnexecutorch.utils.Fetcher
8+
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
9+
import com.swmansion.rnexecutorch.utils.ResourceType
10+
import com.swmansion.rnexecutorch.utils.TensorUtils
11+
import okhttp3.OkHttpClient
12+
import org.pytorch.executorch.Module
13+
import org.pytorch.executorch.Tensor
14+
import java.net.URL
15+
16+
class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
17+
private lateinit var module: Module
18+
private val client = OkHttpClient()
19+
20+
override fun getName(): String {
21+
return NAME
22+
}
23+
24+
private fun downloadModel(
25+
url: URL, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
26+
) {
27+
Fetcher.downloadResource(reactApplicationContext,
28+
client,
29+
url,
30+
resourceType,
31+
{ path, error -> callback(path, error) },
32+
object : ProgressResponseBody.ProgressListener {
33+
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
34+
}
35+
})
36+
}
37+
38+
override fun loadModule(modelPath: String, promise: Promise) {
39+
try {
40+
downloadModel(
41+
URL(modelPath), ResourceType.MODEL
42+
) { path, error ->
43+
if (error != null) {
44+
promise.reject(error.message!!, "-1")
45+
return@downloadModel
46+
}
47+
48+
module = Module.load(path)
49+
promise.resolve(0)
50+
return@downloadModel
51+
}
52+
} catch (e: Exception) {
53+
promise.reject(e.message!!, "-1")
54+
}
55+
}
56+
57+
override fun loadMethod(methodName: String, promise: Promise) {
58+
val result = module.loadMethod(methodName)
59+
if (result != 0) {
60+
promise.reject("Method loading failed", result.toString())
61+
return
62+
}
63+
64+
promise.resolve(result)
65+
}
66+
67+
override fun forward(
68+
input: ReadableArray,
69+
shape: ReadableArray,
70+
inputType: Double,
71+
promise: Promise
72+
) {
73+
try {
74+
val executorchInput =
75+
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())
76+
77+
lateinit var result: Tensor
78+
module.forward(executorchInput)[0].toTensor().also { result = it }
79+
80+
promise.resolve(ArrayUtils.createReadableArray(result))
81+
return
82+
} catch (e: IllegalArgumentException) {
83+
//The error is thrown when transformation to Tensor fails
84+
promise.reject("Forward Failed Execution", "18")
85+
return
86+
} catch (e: Exception) {
87+
//Executorch forward method throws an exception with a message: "Method forward failed with code XX"
88+
val exceptionCode = e.message!!.substring(e.message!!.length - 2)
89+
promise.reject("Forward Failed Execution", exceptionCode)
90+
return
91+
}
92+
}
93+
94+
companion object {
95+
const val NAME = "ETModule"
96+
}
97+
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchModule.kt

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
package com.swmansion.rnexecutorch
22

3-
import com.facebook.react.bridge.ReactApplicationContext
43
import android.os.Build
54
import android.util.Log
65
import androidx.annotation.RequiresApi
76
import com.facebook.react.bridge.Promise
8-
import com.facebook.react.bridge.ReactContextBaseJavaModule
9-
import com.facebook.react.bridge.ReactMethod
7+
import com.facebook.react.bridge.ReactApplicationContext
8+
import com.swmansion.rnexecutorch.utils.Fetcher
9+
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
10+
import com.swmansion.rnexecutorch.utils.ResourceType
11+
import com.swmansion.rnexecutorch.utils.llms.ChatRole
12+
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
13+
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
1014
import okhttp3.OkHttpClient
11-
import okhttp3.Request
12-
import org.pytorch.executorch.LlamaModule
1315
import org.pytorch.executorch.LlamaCallback
14-
import java.io.File
16+
import org.pytorch.executorch.LlamaModule
1517
import java.net.URL
1618

1719
class RnExecutorchModule(reactContext: ReactApplicationContext) :

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+21-13
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,40 @@ import com.facebook.react.module.model.ReactModuleInfo
77
import com.facebook.react.module.model.ReactModuleInfoProvider
88
import com.facebook.react.uimanager.ViewManager
99

10-
1110
class RnExecutorchPackage : TurboReactPackage() {
1211
override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
1312
return listOf()
1413
}
1514

16-
override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
17-
if (name == RnExecutorchModule.NAME) {
18-
RnExecutorchModule(reactContext)
19-
} else {
20-
null
21-
}
15+
override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
16+
if (name == RnExecutorchModule.NAME) {
17+
RnExecutorchModule(reactContext)
18+
} else if (name == ETModule.NAME) {
19+
ETModule(reactContext)
20+
} else {
21+
null
22+
}
2223

23-
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
24-
return ReactModuleInfoProvider {
24+
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
25+
return ReactModuleInfoProvider {
2526
val moduleInfos: MutableMap<String, ReactModuleInfo> = HashMap()
2627
moduleInfos[RnExecutorchModule.NAME] = ReactModuleInfo(
2728
RnExecutorchModule.NAME,
2829
RnExecutorchModule.NAME,
2930
false, // canOverrideExistingModule
3031
false, // needsEagerInit
31-
true, // hasConstants
3232
false, // isCxxModule
33-
true // isTurboModule
33+
true,
34+
)
35+
moduleInfos[ETModule.NAME] = ReactModuleInfo(
36+
ETModule.NAME,
37+
ETModule.NAME,
38+
false, // canOverrideExistingModule
39+
false, // needsEagerInit
40+
false, // isCxxModule
41+
true
3442
)
3543
moduleInfos
3644
}
37-
}
38-
}
45+
}
46+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import com.facebook.react.bridge.Arguments
4+
import com.facebook.react.bridge.ReadableArray
5+
import org.pytorch.executorch.DType
6+
import org.pytorch.executorch.Tensor
7+
8+
class ArrayUtils {
9+
companion object {
10+
fun createByteArray(input: ReadableArray): ByteArray {
11+
val byteArray = ByteArray(input.size())
12+
for (i in 0 until input.size()) {
13+
byteArray[i] = input.getInt(i).toByte()
14+
}
15+
return byteArray
16+
}
17+
18+
fun createIntArray(input: ReadableArray): IntArray {
19+
val intArray = IntArray(input.size())
20+
for (i in 0 until input.size()) {
21+
intArray[i] = input.getInt(i)
22+
}
23+
return intArray
24+
}
25+
26+
fun createFloatArray(input: ReadableArray): FloatArray {
27+
val floatArray = FloatArray(input.size())
28+
for (i in 0 until input.size()) {
29+
floatArray[i] = input.getDouble(i).toFloat()
30+
}
31+
return floatArray
32+
}
33+
34+
fun createLongArray(input: ReadableArray): LongArray {
35+
val longArray = LongArray(input.size())
36+
for (i in 0 until input.size()) {
37+
longArray[i] = input.getInt(i).toLong()
38+
}
39+
return longArray
40+
}
41+
42+
fun createDoubleArray(input: ReadableArray): DoubleArray {
43+
val doubleArray = DoubleArray(input.size())
44+
for (i in 0 until input.size()) {
45+
doubleArray[i] = input.getDouble(i)
46+
}
47+
return doubleArray
48+
}
49+
50+
fun createReadableArray(result: Tensor): ReadableArray {
51+
val resultArray = Arguments.createArray()
52+
when (result.dtype()) {
53+
DType.UINT8 -> {
54+
val byteArray = result.dataAsByteArray
55+
for (i in byteArray) {
56+
resultArray.pushInt(i.toInt())
57+
}
58+
}
59+
60+
DType.INT32 -> {
61+
val intArray = result.dataAsIntArray
62+
for (i in intArray) {
63+
resultArray.pushInt(i)
64+
}
65+
}
66+
67+
DType.FLOAT -> {
68+
val longArray = result.dataAsFloatArray
69+
for (i in longArray) {
70+
resultArray.pushDouble(i.toDouble())
71+
}
72+
}
73+
74+
DType.DOUBLE -> {
75+
val floatArray = result.dataAsDoubleArray
76+
for (i in floatArray) {
77+
resultArray.pushDouble(i)
78+
}
79+
}
80+
81+
DType.INT64 -> {
82+
val doubleArray = result.dataAsLongArray
83+
for (i in doubleArray) {
84+
resultArray.pushLong(i)
85+
}
86+
}
87+
88+
else -> {
89+
throw IllegalArgumentException("Invalid dtype: ${result.dtype()}")
90+
}
91+
}
92+
93+
return resultArray
94+
}
95+
}
96+
}

android/src/main/java/com/swmansion/rnexecutorch/Fetcher.kt android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt

+22-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.swmansion.rnexecutorch
1+
package com.swmansion.rnexecutorch.utils
22

33
import android.content.Context
44
import okhttp3.Call
@@ -113,11 +113,17 @@ class Fetcher {
113113

114114
private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
115115
// Create a new URL using the base URL and append the desired path
116-
val baseUrl = modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
116+
val baseUrl =
117+
modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
117118
return URL(baseUrl + "resolve/main/config.json")
118119
}
119120

120-
private fun sendRequestToUrl(url: URL, method: String, body: RequestBody?, client: OkHttpClient): Response {
121+
private fun sendRequestToUrl(
122+
url: URL,
123+
method: String,
124+
body: RequestBody?,
125+
client: OkHttpClient
126+
): Response {
121127
val request = Request.Builder()
122128
.url(url)
123129
.method(method, body)
@@ -134,18 +140,18 @@ class Fetcher {
134140
onComplete: (String?, Exception?) -> Unit,
135141
listener: ProgressResponseBody.ProgressListener? = null,
136142
) {
137-
/*
138-
Fetching model and tokenizer file
139-
1. Extract file name from provided URL
140-
2. If file name contains / it means that the file is local and we should return the path
141-
3. Check if the file has a valid extension
142-
a. For tokenizer, the extension should be .bin
143-
b. For model, the extension should be .pte
144-
4. Check if models directory exists, if not create it
145-
5. Check if the file already exists in the models directory, if yes return the path
146-
6. If the file does not exist, and is a tokenizer, fetch the file
147-
7. If the file is a model, fetch the file with ProgressResponseBody
148-
*/
143+
/*
144+
Fetching model and tokenizer file
145+
1. Extract file name from provided URL
146+
2. If file name contains / it means that the file is local and we should return the path
147+
3. Check if the file has a valid extension
148+
a. For tokenizer, the extension should be .bin
149+
b. For model, the extension should be .pte
150+
4. Check if models directory exists, if not create it
151+
5. Check if the file already exists in the models directory, if yes return the path
152+
6. If the file does not exist, and is a tokenizer, fetch the file
153+
7. If the file is a model, fetch the file with ProgressResponseBody
154+
*/
149155
val fileName: String
150156

151157
try {
@@ -165,7 +171,7 @@ class Fetcher {
165171
return
166172
}
167173

168-
var tempFile = File(context.filesDir, fileName)
174+
val tempFile = File(context.filesDir, fileName)
169175
if (tempFile.exists()) {
170176
tempFile.delete()
171177
}

android/src/main/java/com/swmansion/rnexecutorch/ProgressResponseBody.kt android/src/main/java/com/swmansion/rnexecutorch/utils/ProgressResponseBody.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.swmansion.rnexecutorch
1+
package com.swmansion.rnexecutorch.utils
22

33
import okhttp3.MediaType
44
import okhttp3.ResponseBody
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import com.facebook.react.bridge.ReadableArray
4+
import org.pytorch.executorch.EValue
5+
import org.pytorch.executorch.Tensor
6+
7+
class TensorUtils {
8+
companion object {
9+
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
10+
try {
11+
when (type) {
12+
0 -> {
13+
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
14+
return EValue.from(inputTensor)
15+
}
16+
17+
1 -> {
18+
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
19+
return EValue.from(inputTensor)
20+
}
21+
22+
2 -> {
23+
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
24+
return EValue.from(inputTensor)
25+
}
26+
27+
3 -> {
28+
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
29+
return EValue.from(inputTensor)
30+
}
31+
32+
4 -> {
33+
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
34+
return EValue.from(inputTensor)
35+
}
36+
37+
else -> {
38+
throw IllegalArgumentException("Invalid input type: $type")
39+
}
40+
}
41+
} catch (e: IllegalArgumentException) {
42+
throw e
43+
}
44+
}
45+
}
46+
}

0 commit comments

Comments
 (0)