Skip to content

Commit 6e02df8

Browse files
committed
Add arg max map to the segmentation result
1 parent e7e726d commit 6e02df8

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h

+13
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,16 @@
66
- (NSDictionary *)runModel:(cv::Mat &)input;
77

88
@end
9+
10+
template <typename T>
11+
NSArray* matToNSArray(const cv::Mat& mat) {
12+
std::size_t numPixels = mat.rows * mat.cols;
13+
NSMutableArray *arr = [[NSMutableArray alloc] initWithCapacity:numPixels];
14+
15+
for (std::size_t x = 0; x < mat.rows; ++x) {
16+
for (std::size_t y = 0; y < mat.cols; ++y) {
17+
arr[x * mat.cols + y] = @(mat.at<T>(x, y));
18+
}
19+
}
20+
return arr;
21+
}

ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm

+14-9
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ - (NSDictionary *)postprocess:(NSArray *)output {
4646
NSAssert(outputSize == numLabels * numModelPixels,
4747
@"Model generated unexpected output size.");
4848

49-
5049
// For each label extract it's matrix and rescale it to the original size
5150
std::vector<cv::Mat> resizedLabelScores(numLabels);
5251
for (std::size_t label = 0; label < numLabels; ++label) {
@@ -61,6 +60,8 @@ - (NSDictionary *)postprocess:(NSArray *)output {
6160
cv::resize(labelMat, resizedLabelScores[label], originalSize);
6261
}
6362

63+
cv::Mat maxArg = cv::Mat(originalSize, CV_32S);
64+
6465
// For each pixel apply softmax across all the labels
6566
for (std::size_t pixel = 0; pixel < numOriginalPixels; ++pixel) {
6667
int row = pixel / originalSize.width;
@@ -73,26 +74,30 @@ - (NSDictionary *)postprocess:(NSArray *)output {
7374

7475
std::vector<double> adjustedScores = softmax(scores);
7576

77+
std::size_t maxArgIndex = 0;
78+
double maxArgVal = 0;
7679
for (std::size_t label = 0; label < numLabels; ++label) {
7780
resizedLabelScores[label].at<double>(row, col) = adjustedScores[label];
81+
if (adjustedScores[label] > maxArgVal) {
82+
maxArgIndex = label;
83+
maxArgVal = adjustedScores[label];
84+
}
7885
}
86+
87+
maxArg.at<int>(row, col) = maxArgIndex;
7988
}
8089

8190
NSMutableDictionary *result = [NSMutableDictionary dictionary];
8291

92+
// Convert to NSArray and populate the final dictionary
8393
for (std::size_t label = 0; label < numLabels; ++label) {
8494
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
85-
NSMutableArray *arr = [[NSMutableArray alloc] initWithCapacity:numOriginalPixels];
86-
87-
for (std::size_t x = 0; x < originalSize.height; ++x) {
88-
for (std::size_t y = 0; y < originalSize.width; ++y) {
89-
arr[x * originalSize.width + y] = @(resizedLabelScores[label].at<double>(x, y));
90-
}
91-
}
92-
95+
NSMutableArray *arr = matToNSArray<double>(resizedLabelScores[label]);
9396
result[labelString] = arr;
9497
}
9598

99+
result[@"argmax"] = matToNSArray<int>(maxArg);
100+
96101
return result;
97102
}
98103

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
export const classLabels = new Map<number, string>([
2+
[0, 'background'],
3+
[1, 'aeroplane'],
4+
[2, 'bicycle'],
5+
[3, 'bird'],
6+
[4, 'boat'],
7+
[5, 'bottle'],
8+
[6, 'bus'],
9+
[7, 'car'],
10+
[8, 'cat'],
11+
[9, 'chair'],
12+
[10, 'cow'],
13+
[11, 'diningtable'],
14+
[12, 'dog'],
15+
[13, 'horse'],
16+
[14, 'motorbike'],
17+
[15, 'person'],
18+
[16, 'pottedplant'],
19+
[17, 'sheep'],
20+
[18, 'sofa'],
21+
[19, 'train'],
22+
[20, 'tvmonitor'],
23+
]);

0 commit comments

Comments
 (0)