Skip to content

Commit e96d46f

Browse files
NorbertKlockiewiczjakmrochmjkbmkopcins
authored
feat: BaseModel abstract class, primitive Style Transfer, functions for retrieving model metadata (#44)
## Description This pull request introduces Model base class on top of which every other model should be built, there's also a primitive implementation of StyleTransferModel and functions for retrieving model metadata. Fix for bindings allowing to return multiple outputs. **__During code review don't focus on StyleTransfer Module as it will change and it's for test purposes right now__** Don't review: - StyleTransfer.kt - StyleTransferModel.kt - StyleTransfer.h - StyleTransfer.mm - StyleTransferModel.mm - StyleTransferModel.h - BitmapUtils.kt - Imageprocessor.h/Imageprocessor.m - TensorUtils.kt ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: jakmro <[email protected]> Co-authored-by: Jakub Chmura <[email protected]> Co-authored-by: mkopcins <[email protected]>
1 parent 0be2992 commit e96d46f

File tree

41 files changed

+1180
-103
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1180
-103
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt

+21-41
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,35 @@
11
package com.swmansion.rnexecutorch
22

3+
import com.facebook.react.bridge.Arguments
34
import com.facebook.react.bridge.Promise
45
import com.facebook.react.bridge.ReactApplicationContext
56
import com.facebook.react.bridge.ReadableArray
67
import com.swmansion.rnexecutorch.utils.ArrayUtils
8+
import com.swmansion.rnexecutorch.utils.ETError
79
import com.swmansion.rnexecutorch.utils.Fetcher
8-
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
9-
import com.swmansion.rnexecutorch.utils.ResourceType
1010
import com.swmansion.rnexecutorch.utils.TensorUtils
11-
import okhttp3.OkHttpClient
1211
import org.pytorch.executorch.Module
13-
import org.pytorch.executorch.Tensor
14-
import java.net.URL
1512

1613
class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
1714
private lateinit var module: Module
18-
private val client = OkHttpClient()
1915

2016
override fun getName(): String {
2117
return NAME
2218
}
2319

24-
private fun downloadModel(
25-
url: String, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
26-
) {
27-
Fetcher.downloadResource(reactApplicationContext,
28-
client,
29-
url,
30-
resourceType,
31-
false,
32-
{ path, error -> callback(path, error) },
33-
object : ProgressResponseBody.ProgressListener {
34-
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
35-
}
36-
})
37-
}
38-
3920
override fun loadModule(modelPath: String, promise: Promise) {
40-
try {
41-
downloadModel(
42-
modelPath, ResourceType.MODEL
43-
) { path, error ->
44-
if (error != null) {
45-
promise.reject(error.message!!, "-1")
46-
return@downloadModel
47-
}
48-
49-
module = Module.load(path)
50-
promise.resolve(0)
21+
Fetcher.downloadModel(
22+
reactApplicationContext,
23+
modelPath,
24+
) { path, error ->
25+
if (error != null) {
26+
promise.reject(error.message!!, ETError.InvalidModelPath.toString())
5127
return@downloadModel
5228
}
53-
} catch (e: Exception) {
54-
promise.reject(e.message!!, "-1")
29+
30+
module = Module.load(path)
31+
promise.resolve(0)
32+
return@downloadModel
5533
}
5634
}
5735

@@ -75,19 +53,21 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
7553
val executorchInput =
7654
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())
7755

78-
lateinit var result: Tensor
79-
module.forward(executorchInput)[0].toTensor().also { result = it }
56+
val result = module.forward(executorchInput)
57+
val resultArray = Arguments.createArray()
58+
59+
for (evalue in result) {
60+
resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor()))
61+
}
8062

81-
promise.resolve(ArrayUtils.createReadableArray(result))
63+
promise.resolve(resultArray)
8264
return
8365
} catch (e: IllegalArgumentException) {
8466
//The error is thrown when transformation to Tensor fails
85-
promise.reject("Forward Failed Execution", "18")
67+
promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
8668
return
8769
} catch (e: Exception) {
88-
//Executorch forward method throws an exception with a message: "Method forward failed with code XX"
89-
val exceptionCode = e.message!!.substring(e.message!!.length - 2)
90-
promise.reject("Forward Failed Execution", exceptionCode)
70+
promise.reject("Forward Failed Execution", e.message!!)
9171
return
9272
}
9373
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchModule.kt

+1-4
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,15 @@ import com.swmansion.rnexecutorch.utils.ResourceType
99
import com.swmansion.rnexecutorch.utils.llms.ChatRole
1010
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
1111
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
12-
import okhttp3.OkHttpClient
1312
import org.pytorch.executorch.LlamaCallback
1413
import org.pytorch.executorch.LlamaModule
15-
import java.net.URL
1614

1715
class RnExecutorchModule(reactContext: ReactApplicationContext) :
1816
NativeRnExecutorchSpec(reactContext), LlamaCallback {
1917

2018
private var llamaModule: LlamaModule? = null
2119
private var tempLlamaResponse = StringBuilder()
2220
private lateinit var conversationManager: ConversationManager
23-
private val client = OkHttpClient()
2421
private var isFetching = false
2522

2623
override fun getName(): String {
@@ -51,7 +48,7 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
5148
callback: (path: String?, error: Exception?) -> Unit,
5249
) {
5350
Fetcher.downloadResource(
54-
reactApplicationContext, client, url, resourceType, isLargeFile,
51+
reactApplicationContext, url, resourceType, isLargeFile,
5552
{ path, error -> callback(path, error) },
5653
object : ProgressResponseBody.ProgressListener {
5754
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+11
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class RnExecutorchPackage : TurboReactPackage() {
1717
RnExecutorchModule(reactContext)
1818
} else if (name == ETModule.NAME) {
1919
ETModule(reactContext)
20+
} else if(name == StyleTransfer.NAME){
21+
StyleTransfer(reactContext)
2022
} else {
2123
null
2224
}
@@ -40,6 +42,15 @@ class RnExecutorchPackage : TurboReactPackage() {
4042
false, // isCxxModule
4143
true
4244
)
45+
46+
moduleInfos[StyleTransfer.NAME] = ReactModuleInfo(
47+
StyleTransfer.NAME,
48+
StyleTransfer.NAME,
49+
false, // canOverrideExistingModule
50+
false, // needsEagerInit
51+
false, // isCxxModule
52+
true
53+
)
4354
moduleInfos
4455
}
4556
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.graphics.BitmapFactory
4+
import android.net.Uri
5+
import com.facebook.react.bridge.Promise
6+
import com.facebook.react.bridge.ReactApplicationContext
7+
import com.swmansion.rnexecutorch.models.StyleTransferModel
8+
import com.swmansion.rnexecutorch.utils.BitmapUtils
9+
import com.swmansion.rnexecutorch.utils.ETError
10+
11+
class StyleTransfer(reactContext: ReactApplicationContext) :
12+
NativeStyleTransferSpec(reactContext) {
13+
14+
private lateinit var styleTransferModel: StyleTransferModel
15+
16+
companion object {
17+
const val NAME = "StyleTransfer"
18+
}
19+
20+
override fun loadModule(modelSource: String, promise: Promise) {
21+
try {
22+
styleTransferModel = StyleTransferModel(reactApplicationContext)
23+
styleTransferModel.loadModel(modelSource)
24+
promise.resolve(0)
25+
} catch (e: Exception) {
26+
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
27+
}
28+
}
29+
30+
override fun forward(input: String, promise: Promise) {
31+
try {
32+
val uri = Uri.parse(input)
33+
val bitmapInputStream = reactApplicationContext.contentResolver.openInputStream(uri)
34+
val rawBitmap = BitmapFactory.decodeStream(bitmapInputStream)
35+
bitmapInputStream!!.close()
36+
37+
val output = styleTransferModel.runModel(rawBitmap)
38+
val outputUri = BitmapUtils.saveToTempFile(output, "test")
39+
40+
promise.resolve(outputUri.toString())
41+
}catch(e: Exception){
42+
promise.reject(e.message!!, e.message)
43+
}
44+
}
45+
46+
override fun getName(): String {
47+
return NAME
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.swmansion.rnexecutorch.models
2+
3+
import android.content.Context
4+
import com.swmansion.rnexecutorch.utils.ETError
5+
import com.swmansion.rnexecutorch.utils.Fetcher
6+
import org.pytorch.executorch.EValue
7+
import org.pytorch.executorch.Module
8+
9+
10+
abstract class BaseModel<Input, Output>(val context: Context) {
11+
protected lateinit var module: Module
12+
13+
fun loadModel(modelSource: String) {
14+
Fetcher.downloadModel(
15+
context,
16+
modelSource
17+
) { path, error ->
18+
if (error != null) {
19+
throw Error(error.message!!)
20+
}
21+
22+
module = Module.load(path)
23+
}
24+
}
25+
26+
protected fun forward(input: EValue): Array<EValue> {
27+
try {
28+
val result = module.forward(input)
29+
return result
30+
} catch (e: IllegalArgumentException) {
31+
//The error is thrown when transformation to Tensor fails
32+
throw Error(ETError.InvalidArgument.code.toString())
33+
} catch (e: Exception) {
34+
throw Error(e.message!!)
35+
}
36+
}
37+
38+
abstract fun runModel(input: Input): Output
39+
40+
protected abstract fun preprocess(input: Input): Input
41+
42+
protected abstract fun postprocess(input: Output): Output
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package com.swmansion.rnexecutorch.models
2+
3+
import android.graphics.Bitmap
4+
import android.util.Log
5+
import com.facebook.react.bridge.ReactApplicationContext
6+
import com.swmansion.rnexecutorch.utils.TensorUtils
7+
import org.pytorch.executorch.EValue
8+
9+
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Bitmap, Bitmap>(reactApplicationContext) {
10+
override fun runModel(input: Bitmap): Bitmap {
11+
val processedData = preprocess(input)
12+
val inputTensor = TensorUtils.bitmapToFloat32Tensor(processedData)
13+
14+
Log.d("RnExecutorch", module.numberOfInputs.toString())
15+
for (i in 0 until module.numberOfInputs) {
16+
Log.d("RnExecutorch", module.getInputType(i).toString())
17+
for(shape in module.getInputShape(i)){
18+
Log.d("RnExecutorch", shape.toString())
19+
}
20+
}
21+
22+
Log.d("RnExecutorch", module.numberOfOutputs.toString())
23+
for(i in 0 until module.numberOfOutputs){
24+
Log.d("RnExecutorch", module.getOutputType(i).toString())
25+
for(shape in module.getOutputShape(i)){
26+
Log.d("RnExecutorch", shape.toString())
27+
}
28+
}
29+
30+
val outputTensor = forward(EValue.from(inputTensor))
31+
val outputData = postprocess(TensorUtils.float32TensorToBitmap(outputTensor[0].toTensor()))
32+
33+
return outputData
34+
}
35+
36+
override fun preprocess(input: Bitmap): Bitmap {
37+
val inputBitmap = Bitmap.createScaledBitmap(
38+
input,
39+
640, 640, true
40+
)
41+
return inputBitmap
42+
}
43+
44+
override fun postprocess(input: Bitmap): Bitmap {
45+
val scaledUpBitmap = Bitmap.createScaledBitmap(
46+
input,
47+
1280, 1280, true
48+
)
49+
return scaledUpBitmap
50+
}
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import android.graphics.Bitmap
4+
import android.graphics.Matrix
5+
import android.net.Uri
6+
import androidx.core.net.toUri
7+
import java.io.File
8+
import java.io.FileOutputStream
9+
import java.io.IOException
10+
11+
class BitmapUtils {
12+
companion object {
13+
fun saveToTempFile(bitmap: Bitmap, fileName: String): Uri {
14+
val tempFile = File.createTempFile(fileName, ".png")
15+
var outputStream : FileOutputStream? = null
16+
try {
17+
outputStream = FileOutputStream(tempFile)
18+
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream)
19+
} catch (e: IOException) {
20+
e.printStackTrace()
21+
}
22+
finally {
23+
outputStream?.close()
24+
}
25+
return tempFile.toUri()
26+
}
27+
28+
private fun rotateBitmap(bitmap: Bitmap, angle: Float): Bitmap {
29+
val matrix = Matrix()
30+
matrix.postRotate(angle)
31+
return Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
32+
}
33+
34+
private fun flipBitmap(bitmap: Bitmap, horizontal: Boolean, vertical: Boolean): Bitmap {
35+
val matrix = Matrix()
36+
matrix.preScale(
37+
if (horizontal) -1f else 1f,
38+
if (vertical) -1f else 1f
39+
)
40+
return Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
41+
}
42+
}
43+
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
enum class ETError(val code: Int) {
4+
InvalidModelPath(0xff),
5+
6+
// System errors
7+
Ok(0x00),
8+
Internal(0x01),
9+
InvalidState(0x02),
10+
EndOfMethod(0x03),
11+
12+
// Logical errors
13+
NotSupported(0x10),
14+
NotImplemented(0x11),
15+
InvalidArgument(0x12),
16+
InvalidType(0x13),
17+
OperatorMissing(0x14),
18+
19+
// Resource errors
20+
NotFound(0x20),
21+
MemoryAllocationFailed(0x21),
22+
AccessFailed(0x22),
23+
InvalidProgram(0x23),
24+
25+
// Delegate errors
26+
DelegateInvalidCompatibility(0x30),
27+
DelegateMemoryAllocationFailed(0x31),
28+
DelegateInvalidHandle(0x32);
29+
}

0 commit comments

Comments
 (0)