Skip to content

Commit da8cc8d

Browse files
committed
Change segmentation labels to uppercase on Android
1 parent e99dc84 commit da8cc8d

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
11
package com.swmansion.rnexecutorch.models.imagesegmentation
22

3-
val deeplabv3_resnet50_labels: Array<String> = arrayOf(
4-
"background", "aeroplane", "bicycle", "bird", "boat",
5-
"bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
6-
"dog", "horse", "motorbike", "person", "pottedplant", "sheep",
7-
"sofa", "train", "tvmonitor"
8-
)
3+
val deeplabv3_resnet50_labels: Array<String> =
4+
arrayOf(
5+
"BACKGROUND",
6+
"AEROPLANE",
7+
"BICYCLE",
8+
"BIRD",
9+
"BOAT",
10+
"BOTTLE",
11+
"BUS",
12+
"CAR",
13+
"CAT",
14+
"CHAIR",
15+
"COW",
16+
"DININGTABLE",
17+
"DOG",
18+
"HORSE",
19+
"MOTORBIKE",
20+
"PERSON",
21+
"POTTEDPLANT",
22+
"SHEEP",
23+
"SOFA",
24+
"TRAIN",
25+
"TVMONITOR",
26+
)

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ class ImageSegmentationModel(
9494
classesOfInterest: ReadableArray,
9595
resize: Boolean,
9696
): WritableMap {
97-
val output = output[0].toTensor().dataAsFloatArray
97+
val outputData = output[0].toTensor().dataAsFloatArray
9898
val modelSize = getModelImageSize()
9999
val numLabels = deeplabv3_resnet50_labels.size
100100

101-
require(output.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }
101+
require(outputData.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }
102102

103103
val outputSize = if (resize) originalSize else modelSize
104104

105-
val extractedResults = extractResults(output, numLabels, resize)
105+
val extractedResults = extractResults(outputData, numLabels, resize)
106106

107107
val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)
108108

@@ -124,7 +124,7 @@ class ImageSegmentationModel(
124124
}
125125

126126
res.putArray(
127-
"argmax",
127+
"ARGMAX",
128128
ArrayUtils.createReadableArrayFromIntArray(argMax),
129129
)
130130

0 commit comments

Comments
 (0)