Skip to content

Commit aa0f9c7

Browse files
feat: style transfer(android) (#51)
## Description This pr introduces implementation of style transfer with openCV for android ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent f556b8a commit aa0f9c7

File tree

7 files changed

+177
-105
lines changed

7 files changed

+177
-105
lines changed

android/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,6 @@ dependencies {
9898
implementation "com.facebook.react:react-native:+"
9999
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
100100
implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT'
101+
implementation 'org.opencv:opencv:4.10.0'
101102
implementation("com.squareup.okhttp3:okhttp:4.9.2")
102103
}

android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package com.swmansion.rnexecutorch
22

3-
import android.graphics.BitmapFactory
4-
import android.net.Uri
3+
import android.util.Log
54
import com.facebook.react.bridge.Promise
65
import com.facebook.react.bridge.ReactApplicationContext
76
import com.swmansion.rnexecutorch.models.StyleTransferModel
8-
import com.swmansion.rnexecutorch.utils.BitmapUtils
97
import com.swmansion.rnexecutorch.utils.ETError
8+
import com.swmansion.rnexecutorch.utils.ImageProcessor
9+
import org.opencv.android.OpenCVLoader
1010

1111
class StyleTransfer(reactContext: ReactApplicationContext) :
1212
NativeStyleTransferSpec(reactContext) {
@@ -15,6 +15,14 @@ class StyleTransfer(reactContext: ReactApplicationContext) :
1515

1616
companion object {
1717
const val NAME = "StyleTransfer"
18+
19+
init {
20+
if(!OpenCVLoader.initLocal()){
21+
Log.d("rn_executorch", "OpenCV not loaded")
22+
} else {
23+
Log.d("rn_executorch", "OpenCV loaded")
24+
}
25+
}
1826
}
1927

2028
override fun loadModule(modelSource: String, promise: Promise) {
@@ -29,15 +37,8 @@ class StyleTransfer(reactContext: ReactApplicationContext) :
2937

3038
override fun forward(input: String, promise: Promise) {
3139
try {
32-
val uri = Uri.parse(input)
33-
val bitmapInputStream = reactApplicationContext.contentResolver.openInputStream(uri)
34-
val rawBitmap = BitmapFactory.decodeStream(bitmapInputStream)
35-
bitmapInputStream!!.close()
36-
37-
val output = styleTransferModel.runModel(rawBitmap)
38-
val outputUri = BitmapUtils.saveToTempFile(output, "test")
39-
40-
promise.resolve(outputUri.toString())
40+
val output = styleTransferModel.runModel(ImageProcessor.readImage(input))
41+
promise.resolve(ImageProcessor.saveToTempFile(reactApplicationContext, output))
4142
}catch(e: Exception){
4243
promise.reject(e.message!!, e.message)
4344
}

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import com.swmansion.rnexecutorch.utils.ETError
55
import com.swmansion.rnexecutorch.utils.Fetcher
66
import org.pytorch.executorch.EValue
77
import org.pytorch.executorch.Module
8+
import org.pytorch.executorch.Tensor
89

910

1011
abstract class BaseModel<Input, Output>(val context: Context) {
@@ -39,5 +40,5 @@ abstract class BaseModel<Input, Output>(val context: Context) {
3940

4041
protected abstract fun preprocess(input: Input): Input
4142

42-
protected abstract fun postprocess(input: Output): Output
43+
protected abstract fun postprocess(input: Tensor): Output
4344
}
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,40 @@
11
package com.swmansion.rnexecutorch.models
22

3-
import android.graphics.Bitmap
4-
import android.util.Log
53
import com.facebook.react.bridge.ReactApplicationContext
6-
import com.swmansion.rnexecutorch.utils.TensorUtils
7-
import org.pytorch.executorch.EValue
8-
9-
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Bitmap, Bitmap>(reactApplicationContext) {
10-
override fun runModel(input: Bitmap): Bitmap {
11-
val processedData = preprocess(input)
12-
val inputTensor = TensorUtils.bitmapToFloat32Tensor(processedData)
13-
14-
Log.d("RnExecutorch", module.numberOfInputs.toString())
15-
for (i in 0 until module.numberOfInputs) {
16-
Log.d("RnExecutorch", module.getInputType(i).toString())
17-
for(shape in module.getInputShape(i)){
18-
Log.d("RnExecutorch", shape.toString())
19-
}
20-
}
21-
22-
Log.d("RnExecutorch", module.numberOfOutputs.toString())
23-
for(i in 0 until module.numberOfOutputs){
24-
Log.d("RnExecutorch", module.getOutputType(i).toString())
25-
for(shape in module.getOutputShape(i)){
26-
Log.d("RnExecutorch", shape.toString())
27-
}
28-
}
29-
30-
val outputTensor = forward(EValue.from(inputTensor))
31-
val outputData = postprocess(TensorUtils.float32TensorToBitmap(outputTensor[0].toTensor()))
32-
33-
return outputData
4+
import com.swmansion.rnexecutorch.utils.ImageProcessor
5+
import org.opencv.core.Mat
6+
import org.opencv.core.Size
7+
import org.opencv.imgproc.Imgproc
8+
import org.pytorch.executorch.Tensor
9+
10+
11+
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Mat>(reactApplicationContext) {
12+
private lateinit var originalSize: Size
13+
14+
private fun getModelImageSize(): Size {
15+
val inputShape = module.getInputShape(0)
16+
val width = inputShape[inputShape.lastIndex]
17+
val height = inputShape[inputShape.lastIndex - 1]
18+
19+
return Size(height.toDouble(), width.toDouble())
20+
}
21+
22+
override fun preprocess(input: Mat): Mat {
23+
originalSize = input.size()
24+
Imgproc.resize(input, input, getModelImageSize())
25+
return input
3426
}
3527

36-
override fun preprocess(input: Bitmap): Bitmap {
37-
val inputBitmap = Bitmap.createScaledBitmap(
38-
input,
39-
640, 640, true
40-
)
41-
return inputBitmap
28+
override fun postprocess(input: Tensor): Mat {
29+
val modelShape = getModelImageSize()
30+
val result = ImageProcessor.EValueToMat(input.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
31+
Imgproc.resize(result, result, originalSize)
32+
return result
4233
}
4334

44-
override fun postprocess(input: Bitmap): Bitmap {
45-
val scaledUpBitmap = Bitmap.createScaledBitmap(
46-
input,
47-
1280, 1280, true
48-
)
49-
return scaledUpBitmap
35+
override fun runModel(input: Mat): Mat {
36+
val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0))
37+
val outputTensor = forward(inputTensor)
38+
return postprocess(outputTensor[0].toTensor())
5039
}
5140
}

android/src/main/java/com/swmansion/rnexecutorch/utils/BitmapUtils.kt

-44
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
package com.swmansion.rnexecutorch.utils
22

33
enum class ETError(val code: Int) {
4+
UndefinedError(0x65),
5+
ModuleNotLoaded(0x66),
6+
FileWriteFailed(0x67),
47
InvalidModelPath(0xff),
5-
6-
// System errors
8+
9+
// System errors
710
Ok(0x00),
811
Internal(0x01),
912
InvalidState(0x02),
1013
EndOfMethod(0x03),
11-
14+
1215
// Logical errors
1316
NotSupported(0x10),
1417
NotImplemented(0x11),
1518
InvalidArgument(0x12),
1619
InvalidType(0x13),
1720
OperatorMissing(0x14),
18-
21+
1922
// Resource errors
2023
NotFound(0x20),
2124
MemoryAllocationFailed(0x21),
2225
AccessFailed(0x22),
2326
InvalidProgram(0x23),
24-
27+
2528
// Delegate errors
2629
DelegateInvalidCompatibility(0x30),
2730
DelegateMemoryAllocationFailed(0x31),
2831
DelegateInvalidHandle(0x32);
29-
}
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import android.content.Context
4+
import android.net.Uri
5+
import android.util.Base64
6+
import org.opencv.core.CvType
7+
import org.opencv.core.Mat
8+
import org.opencv.imgcodecs.Imgcodecs
9+
import org.pytorch.executorch.EValue
10+
import org.pytorch.executorch.Tensor
11+
import java.io.File
12+
import java.io.InputStream
13+
import java.net.URL
14+
import java.util.UUID
15+
16+
17+
class ImageProcessor {
18+
companion object {
19+
fun matToEValue(mat: Mat, shape: LongArray): EValue {
20+
val pixelCount = mat.cols() * mat.rows()
21+
val floatArray = FloatArray(pixelCount * 3)
22+
23+
for (i in 0 until pixelCount) {
24+
val row = i / mat.cols()
25+
val col = i % mat.cols()
26+
val pixel = mat.get(row, col)
27+
28+
if (mat.type() == CvType.CV_8UC3 || mat.type() == CvType.CV_8UC4) {
29+
val b = pixel[0] / 255.0f
30+
val g = pixel[1] / 255.0f
31+
val r = pixel[2] / 255.0f
32+
33+
floatArray[i] = r.toFloat()
34+
floatArray[pixelCount + i] = g.toFloat()
35+
floatArray[2 * pixelCount + i] = b.toFloat()
36+
}
37+
}
38+
39+
return EValue.from(Tensor.fromBlob(floatArray, shape))
40+
}
41+
42+
fun EValueToMat(array: FloatArray, width: Int, height: Int): Mat {
43+
val mat = Mat(height, width, CvType.CV_8UC3)
44+
45+
val pixelCount = width * height
46+
for (i in 0 until pixelCount) {
47+
val row = i / width
48+
val col = i % width
49+
50+
val r = (array[i] * 255).toInt().toByte()
51+
val g = (array[pixelCount + i] * 255).toInt().toByte()
52+
val b = (array[2 * pixelCount + i] * 255).toInt().toByte()
53+
54+
val color = byteArrayOf(b, g, r)
55+
mat.put(row, col, color)
56+
}
57+
return mat
58+
}
59+
60+
fun saveToTempFile(context: Context, mat: Mat): String {
61+
try {
62+
val uniqueID = UUID.randomUUID().toString()
63+
val tempFile = File(context.cacheDir, "rn_executorch_$uniqueID.png")
64+
Imgcodecs.imwrite(tempFile.absolutePath, mat)
65+
66+
return "file://${tempFile.absolutePath}"
67+
}catch (e: Exception) {
68+
throw Exception(ETError.FileWriteFailed.toString())
69+
}
70+
}
71+
72+
fun readImage(source: String): Mat {
73+
val inputImage: Mat
74+
75+
val uri = Uri.parse(source)
76+
val scheme = uri.scheme ?: ""
77+
78+
when {
79+
scheme.equals("data", ignoreCase = true) -> {
80+
//base64
81+
val parts = source.split(",", limit = 2)
82+
if (parts.size < 2) throw IllegalArgumentException(ETError.InvalidArgument.toString())
83+
84+
val encodedString = parts[1]
85+
val data = Base64.decode(encodedString, Base64.DEFAULT)
86+
87+
val encodedData = Mat(1, data.size, CvType.CV_8UC1).apply {
88+
put(0, 0, data)
89+
}
90+
inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR)
91+
}
92+
scheme.equals("file", ignoreCase = true) -> {
93+
//device storage
94+
val path = uri.path
95+
inputImage = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR)
96+
}
97+
else -> {
98+
//external source
99+
val url = URL(source)
100+
val connection = url.openConnection()
101+
connection.connect()
102+
103+
val inputStream: InputStream = connection.getInputStream()
104+
val data = inputStream.readBytes()
105+
inputStream.close()
106+
107+
val encodedData = Mat(1, data.size, CvType.CV_8UC1).apply {
108+
put(0, 0, data)
109+
}
110+
inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR)
111+
}
112+
}
113+
114+
if (inputImage.empty()) {
115+
throw IllegalArgumentException(ETError.InvalidArgument.toString())
116+
}
117+
118+
return inputImage
119+
}
120+
}
121+
}

0 commit comments

Comments
 (0)