-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathObjectDetectionUtils.mm
84 lines (74 loc) · 2.54 KB
/
ObjectDetectionUtils.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include "ObjectDetectionUtils.hpp"
#include "Constants.h"
#include <algorithm>
#include <map>
#include <vector>
/// Maps a numeric class label to its display name.
///
/// The label is converted to an int (fractional part discarded) and looked up
/// in `cocoLabelsMap`.
///
/// @param label Class label stored as a float.
/// @return The name found in `cocoLabelsMap`, or @"unknown" when the label has
///         no entry or the stored name cannot be decoded as UTF-8. Never nil.
NSString *floatLabelToNSString(float label) {
  static NSString *const kUnknownLabel = @"unknown";

  const int key = static_cast<int>(label);
  const auto entry = cocoLabelsMap.find(key);
  if (entry == cocoLabelsMap.end()) {
    return kUnknownLabel;
  }

  // +stringWithUTF8String: returns nil when the bytes are not valid UTF-8.
  // detectionToNSDictionary places this result in a dictionary literal, which
  // raises on nil, so fall back to the placeholder instead of propagating nil.
  NSString *name = [NSString stringWithUTF8String:entry->second.c_str()];
  return name ?: kUnknownLabel;
}
/// Converts a Detection into a Foundation dictionary.
///
/// Resulting layout:
///   @"bbox"  -> dictionary with NSNumber values under @"x1", @"y1", @"x2", @"y2"
///   @"label" -> NSString produced by floatLabelToNSString(detection.label)
///   @"score" -> NSNumber wrapping detection.score
///
/// @param detection The detection to convert.
/// @return A new immutable dictionary with the keys listed above.
NSDictionary *detectionToNSDictionary(const Detection &detection) {
  NSDictionary *boundingBox = @{
    @"x1" : @(detection.x1),
    @"y1" : @(detection.y1),
    @"x2" : @(detection.x2),
    @"y2" : @(detection.y2),
  };
  NSString *labelName = floatLabelToNSString(detection.label);
  NSNumber *confidence = @(detection.score);

  return @{
    @"bbox" : boundingBox,
    @"label" : labelName,
    @"score" : confidence,
  };
}
/// Computes the intersection-over-union of two detections' boxes.
///
/// The overlap rectangle is spanned by the larger of the two x1/y1 values and
/// the smaller of the two x2/y2 values; a negative extent on either axis is
/// clamped to zero, so disjoint boxes have an intersection of 0.
///
/// @param a First detection.
/// @param b Second detection.
/// @return intersection / union, or 0 when the union area is not strictly
///         positive (e.g. both boxes have zero area), so the result is never
///         NaN or infinite because of a zero denominator.
float iou(const Detection &a, const Detection &b) {
  const float left = std::max(a.x1, b.x1);
  const float top = std::max(a.y1, b.y1);
  const float right = std::min(a.x2, b.x2);
  const float bottom = std::min(a.y2, b.y2);

  const float intersectionArea =
      std::max(0.0f, right - left) * std::max(0.0f, bottom - top);
  const float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
  const float unionArea = areaA + areaB - intersectionArea;

  // Written as a negated '>' so that a NaN union also takes this branch.
  if (!(unionArea > 0.0f)) {
    return 0.0f;
  }
  return intersectionArea / unionArea;
}
/// Greedy per-label non-maximum suppression.
///
/// Detections are sorted by label (ascending) and, within a label, by score
/// (descending). Walking that order, every detection that has not been
/// suppressed is kept, and it suppresses each later detection with the same
/// label whose iou() with it is strictly greater than `iouThreshold`.
///
/// A kept detection is never compared against itself, so the loop terminates
/// for any threshold (including values >= 1) and for boxes whose iou() with
/// themselves is not greater than the threshold.
///
/// @param detections Candidates; taken by value and sorted locally.
/// @param iouThreshold Overlap above which a lower-ranked detection of the
///        same label is dropped.
/// @return The surviving detections, ordered by label and then by descending
///         score. Empty when `detections` is empty.
std::vector<Detection> nms(std::vector<Detection> detections,
                           float iouThreshold) {
  if (detections.empty()) {
    return {};
  }

  // Sort by label, then by score (highest first).
  std::sort(detections.begin(), detections.end(),
            [](const Detection &a, const Detection &b) {
              if (a.label == b.label) {
                return a.score > b.score;
              }
              return a.label < b.label;
            });

  const size_t count = detections.size();
  std::vector<bool> suppressed(count, false);
  std::vector<Detection> result;
  result.reserve(count);

  for (size_t i = 0; i < count; ++i) {
    if (suppressed[i]) {
      continue;
    }
    const Detection &current = detections[i];
    result.push_back(current);

    // Same-label detections are contiguous after the sort, so the scan can
    // stop at the first detection with a different label.
    for (size_t j = i + 1;
         j < count && detections[j].label == current.label; ++j) {
      if (!suppressed[j] && iou(current, detections[j]) > iouThreshold) {
        suppressed[j] = true;
      }
    }
  }
  return result;
}