|
1 | 1 | #import "ImageSegmentationModel.h"
|
| 2 | +#import <unordered_set> |
2 | 3 | #import "../../utils/ImageProcessor.h"
|
3 | 4 | #import "../../utils/Numerical.h"
|
4 | 5 | #import "opencv2/opencv.hpp"
|
5 | 6 | #import "Constants.h"
|
6 | 7 |
|
7 | 8 | @interface ImageSegmentationModel ()
|
8 | 9 | - (NSArray *)preprocess:(cv::Mat &)input;
|
9 |
| - - (NSDictionary *)postprocess:(NSArray *)output; |
| 10 | + - (NSDictionary *)postprocess:(NSArray *)output |
| 11 | + returnClasses:(NSArray *)classesOfInterest; |
10 | 12 | @end
|
11 | 13 |
|
12 | 14 | @implementation ImageSegmentationModel {
|
@@ -35,7 +37,8 @@ - (NSArray *)preprocess:(cv::Mat &)input {
|
35 | 37 | return modelInput;
|
36 | 38 | }
|
37 | 39 |
|
38 |
| -- (NSDictionary *)postprocess:(NSArray *)output { |
| 40 | +- (NSDictionary *)postprocess:(NSArray *)output |
| 41 | + returnClasses:(NSArray *)classesOfInterest{ |
39 | 42 | cv::Size modelImageSize = [self getModelImageSize];
|
40 | 43 |
|
41 | 44 | std::size_t numLabels = deeplabv3_resnet50_labels.size();
|
@@ -87,25 +90,34 @@ - (NSDictionary *)postprocess:(NSArray *)output {
|
87 | 90 | maxArg.at<int>(row, col) = maxArgIndex;
|
88 | 91 | }
|
89 | 92 |
|
| 93 | + std::unordered_set<std::string> labelSet; |
| 94 | + |
| 95 | + for (id label in classesOfInterest) { |
| 96 | + labelSet.insert(std::string([label UTF8String])); |
| 97 | + } |
| 98 | + |
90 | 99 | NSMutableDictionary *result = [NSMutableDictionary dictionary];
|
91 |
| - |
| 100 | + |
92 | 101 | // Convert to NSArray and populate the final dictionary
|
93 | 102 | for (std::size_t label = 0; label < numLabels; ++label) {
|
94 |
| - NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str()); |
95 |
| - NSMutableArray *arr = matToNSArray<double>(resizedLabelScores[label]); |
96 |
| - result[labelString] = arr; |
| 103 | + if (labelSet.contains(deeplabv3_resnet50_labels[label])){ |
| 104 | + NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str()); |
| 105 | + NSArray *arr = matToNSArray<double>(resizedLabelScores[label]); |
| 106 | + result[labelString] = arr; |
| 107 | + } |
97 | 108 | }
|
98 | 109 |
|
99 | 110 | result[@"argmax"] = matToNSArray<int>(maxArg);
|
100 | 111 |
|
101 | 112 | return result;
|
102 | 113 | }
|
103 | 114 |
|
104 |
| -- (NSDictionary *)runModel:(cv::Mat &)input { |
| 115 | +- (NSDictionary *)runModel:(cv::Mat &)input |
| 116 | + returnClasses:(NSArray *)classesOfInterest { |
105 | 117 | NSArray *modelInput = [self preprocess:input];
|
106 | 118 | NSArray *result = [self forward:modelInput];
|
107 | 119 |
|
108 |
| - NSDictionary *output = [self postprocess:result[0]]; |
| 120 | + NSDictionary *output = [self postprocess:result[0] returnClasses:classesOfInterest]; |
109 | 121 |
|
110 | 122 | return output;
|
111 | 123 | }
|
|
0 commit comments