feat: pre-commit code formatting #107

Merged — 6 commits, Mar 10, 2025

Changes from all commits:
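The diffs below are almost entirely mechanical formatting churn — quote style normalized in YAML, trailing newlines added at end of file, Kotlin signatures wrapped one parameter per line — which is exactly what a pre-commit hook automates. The PR's actual hook configuration is not shown in this excerpt, so the repos, revisions, and hook ids in this sketch are illustrative assumptions rather than the project's real setup:

```yaml
# Hypothetical .pre-commit-config.yaml — repos, revs, and the local ktlint
# entry are assumptions for illustration; only the hook *effects* (EOF
# newlines, quote style, Kotlin rewrapping) are visible in this PR's diffs.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: end-of-file-fixer    # would produce the trailing-newline changes below
      - id: trailing-whitespace
  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v3.1.0
    hooks:
      - id: prettier             # single-quote normalization in the .yml files
  - repo: local
    hooks:
      - id: ktlint
        name: ktlint
        entry: ktlint --format
        language: system
        files: \.kt$             # Kotlin signature wrapping and cleanup
```

Once such a file is in place, `pre-commit install` registers the hooks to run on every commit, and `pre-commit run --all-files` applies them to the whole tree in one pass — which is what produces a formatting-only PR like this one.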
.github/ISSUE_TEMPLATE/bug-report.yml — 2 changes: 1 addition & 1 deletion

@@ -163,4 +163,4 @@ body:
     options:
       - 'Yes'
   validations:
-    required: true
+    required: true
.github/ISSUE_TEMPLATE/feature-request.yml — 2 changes: 1 addition & 1 deletion

@@ -47,4 +47,4 @@ body:
     description: >
       Add any other relevant context, code examples, or screenshots that should be considered.
   validations:
-    required: false
+    required: false
.github/pull_request_template.md — 10 changes: 9 additions & 1 deletion (resulting file shown)

@@ -1,30 +1,38 @@
## Description

<!-- Provide a concise and descriptive summary of the changes implemented in this PR. -->

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [ ] iOS
- [ ] Android

### Testing instructions

<!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. -->

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes

<!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
.github/workflows/build-ios-llama-example.yml — 5 changes: 2 additions & 3 deletions

@@ -17,12 +17,12 @@ on:
 jobs:
   build:
     if: github.repository == 'software-mansion/react-native-executorch'
-    name: "Example app iOS build check"
+    name: 'Example app iOS build check'
     runs-on: macos-latest
     steps:
       - name: Check out Git repository
         uses: actions/checkout@v4

       - name: Install node dependencies
         working-directory: examples/llama
         run: yarn
@@ -42,4 +42,3 @@ jobs:
           -destination 'platform=iOS Simulator,name=iPhone 16 Pro' \
           build \
           CODE_SIGNING_ALLOWED=NO | xcbeautify
-
.github/workflows/docs-build-check.yml — 2 changes: 1 addition & 1 deletion

@@ -31,4 +31,4 @@ jobs:
         run: mkdir -p .yarn/cache && yarn
       - name: Generate docs
         working-directory: ${{ env.WORKING_DIRECTORY }}
-        run: yarn build
+        run: yarn build
.watchmanconfig — 2 changes: 1 addition & 1 deletion

@@ -1 +1 @@
-{}
+{}
.yarnrc.yml — 4 changes: 2 additions & 2 deletions

@@ -3,8 +3,8 @@ nmHoistingLimits: workspaces

 plugins:
   - path: .yarn/plugins/@yarnpkg/plugin-interactive-tools.cjs
-    spec: "@yarnpkg/plugin-interactive-tools"
+    spec: '@yarnpkg/plugin-interactive-tools'
   - path: .yarn/plugins/@yarnpkg/plugin-workspace-tools.cjs
-    spec: "@yarnpkg/plugin-workspace-tools"
+    spec: '@yarnpkg/plugin-workspace-tools'

 yarnPath: .yarn/releases/yarn-3.6.1.cjs
android/src/main/java/com/swmansion/rnexecutorch/Classification.kt — 31 changes: 18 additions & 13 deletions

@@ -1,32 +1,36 @@
 package com.swmansion.rnexecutorch

 import android.util.Log
+import com.facebook.react.bridge.Arguments
 import com.facebook.react.bridge.Promise
 import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.WritableMap
 import com.swmansion.rnexecutorch.models.classification.ClassificationModel
 import com.swmansion.rnexecutorch.utils.ETError
 import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.android.OpenCVLoader
-import com.facebook.react.bridge.Arguments
-import com.facebook.react.bridge.WritableMap

-class Classification(reactContext: ReactApplicationContext) :
-  NativeClassificationSpec(reactContext) {
-
+class Classification(
+  reactContext: ReactApplicationContext,
+) : NativeClassificationSpec(reactContext) {
   private lateinit var classificationModel: ClassificationModel

   companion object {
     const val NAME = "Classification"

     init {
-      if(!OpenCVLoader.initLocal()){
+      if (!OpenCVLoader.initLocal()) {
         Log.d("rn_executorch", "OpenCV not loaded")
       } else {
         Log.d("rn_executorch", "OpenCV loaded")
       }
     }
   }

-  override fun loadModule(modelSource: String, promise: Promise) {
+  override fun loadModule(
+    modelSource: String,
+    promise: Promise,
+  ) {
     try {
       classificationModel = ClassificationModel(reactApplicationContext)
       classificationModel.loadModel(modelSource)
@@ -36,24 +40,25 @@
     }
   }

-  override fun forward(input: String, promise: Promise) {
+  override fun forward(
+    input: String,
+    promise: Promise,
+  ) {
     try {
       val image = ImageProcessor.readImage(input)
       val output = classificationModel.runModel(image)

       val writableMap: WritableMap = Arguments.createMap()

       for ((key, value) in output) {
         writableMap.putDouble(key, value.toDouble())
       }

       promise.resolve(writableMap)
-    }catch(e: Exception){
+    } catch (e: Exception) {
       promise.reject(e.message!!, e.message)
     }
   }

-  override fun getName(): String {
-    return NAME
-  }
+  override fun getName(): String = NAME
 }
android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt — 47 changes: 28 additions & 19 deletions

@@ -11,19 +11,26 @@
 import org.pytorch.executorch.EValue
 import org.pytorch.executorch.Module
 import java.net.URL

-class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
+class ETModule(
+  reactContext: ReactApplicationContext,
+) : NativeETModuleSpec(reactContext) {
   private lateinit var module: Module
-  private var reactApplicationContext = reactContext;
-  override fun getName(): String {
-    return NAME
-  }
+  private var reactApplicationContext = reactContext
+
+  override fun getName(): String = NAME

-  override fun loadModule(modelSource: String, promise: Promise) {
+  override fun loadModule(
+    modelSource: String,
+    promise: Promise,
+  ) {
     module = Module.load(URL(modelSource).path)
     promise.resolve(0)
   }

-  override fun loadMethod(methodName: String, promise: Promise) {
+  override fun loadMethod(
+    methodName: String,
+    promise: Promise,
+  ) {
     val result = module.loadMethod(methodName)
     if (result != 0) {
       promise.reject("Method loading failed", result.toString())
@@ -37,35 +44,37 @@
     inputs: ReadableArray,
     shapes: ReadableArray,
     inputTypes: ReadableArray,
-    promise: Promise
+    promise: Promise,
   ) {
     val inputEValues = ArrayList<EValue>()
     try {
       for (i in 0 until inputs.size()) {
-        val currentInput = inputs.getArray(i)
-          ?: throw Exception(ETError.InvalidArgument.code.toString())
-        val currentShape = shapes.getArray(i)
-          ?: throw Exception(ETError.InvalidArgument.code.toString())
+        val currentInput =
+          inputs.getArray(i)
+            ?: throw Exception(ETError.InvalidArgument.code.toString())
+        val currentShape =
+          shapes.getArray(i)
+            ?: throw Exception(ETError.InvalidArgument.code.toString())
         val currentInputType = inputTypes.getInt(i)

-        val currentEValue = TensorUtils.getExecutorchInput(
-          currentInput,
-          ArrayUtils.createLongArray(currentShape),
-          currentInputType
-        )
+        val currentEValue =
+          TensorUtils.getExecutorchInput(
+            currentInput,
+            ArrayUtils.createLongArray(currentShape),
+            currentInputType,
+          )

         inputEValues.add(currentEValue)
       }

-      val forwardOutputs = module.forward(*inputEValues.toTypedArray());
+      val forwardOutputs = module.forward(*inputEValues.toTypedArray())
       val outputArray = Arguments.createArray()

       for (output in forwardOutputs) {
         val arr = ArrayUtils.createReadableArrayFromTensor(output.toTensor())
         outputArray.pushArray(arr)
       }
       promise.resolve(outputArray)
-
     } catch (e: IllegalArgumentException) {
       // The error is thrown when transformation to Tensor fails
       promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
android/src/main/java/com/swmansion/rnexecutorch/LLM.kt — 30 changes: 18 additions & 12 deletions

@@ -4,23 +4,23 @@
 import android.util.Log
 import com.facebook.react.bridge.Promise
 import com.facebook.react.bridge.ReactApplicationContext
 import com.facebook.react.bridge.ReadableArray
+import com.swmansion.rnexecutorch.utils.ArrayUtils
 import com.swmansion.rnexecutorch.utils.llms.ChatRole
 import com.swmansion.rnexecutorch.utils.llms.ConversationManager
 import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
 import org.pytorch.executorch.LlamaCallback
 import org.pytorch.executorch.LlamaModule
-import com.swmansion.rnexecutorch.utils.ArrayUtils
 import java.net.URL

-class LLM(reactContext: ReactApplicationContext) : NativeLLMSpec(reactContext), LlamaCallback {
-
+class LLM(
+  reactContext: ReactApplicationContext,
+) : NativeLLMSpec(reactContext),
+  LlamaCallback {
   private var llamaModule: LlamaModule? = null
   private var tempLlamaResponse = StringBuilder()
   private lateinit var conversationManager: ConversationManager

-  override fun getName(): String {
-    return NAME
-  }
+  override fun getName(): String = NAME

   override fun initialize() {
     super.initialize()
@@ -41,12 +41,15 @@
     systemPrompt: String,
     messageHistory: ReadableArray,
     contextWindowLength: Double,
-    promise: Promise
+    promise: Promise,
   ) {
     try {
-      this.conversationManager = ConversationManager(
-        contextWindowLength.toInt(), systemPrompt, ArrayUtils.createMapArray<String>(messageHistory)
-      )
+      this.conversationManager =
+        ConversationManager(
+          contextWindowLength.toInt(),
+          systemPrompt,
+          ArrayUtils.createMapArray<String>(messageHistory),
+        )
       llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f)
       this.tempLlamaResponse.clear()
       promise.resolve("Model loaded successfully")
@@ -55,7 +58,10 @@
     }
   }

-  override fun runInference(input: String, promise: Promise) {
+  override fun runInference(
+    input: String,
+    promise: Promise,
+  ) {
     this.conversationManager.addResponse(input, ChatRole.USER)
     val conversation = this.conversationManager.getConversation()

@@ -66,7 +72,7 @@
     // generated sequence length is larger than specified in the JNI callback, hence we check if EOT
     // is there and if not, we append it to the output and emit the EOT token to the JS side.
     if (!this.tempLlamaResponse.endsWith(END_OF_TEXT_TOKEN)) {
-      this.onResult(END_OF_TEXT_TOKEN);
+      this.onResult(END_OF_TEXT_TOKEN)
     }

     // We want to add the LLM response to the conversation once all the tokens are generated.