|
| 1 | +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +#pragma once |
| 15 | +#include <fstream> |
| 16 | +#include "fastdeploy/vision.h" |
| 17 | + |
| 18 | +std::vector<std::string> stringSplit(const std::string& str, char delim) { |
| 19 | + std::stringstream ss(str); |
| 20 | + std::string item; |
| 21 | + std::vector<std::string> elems; |
| 22 | + while (std::getline(ss, item, delim)) { |
| 23 | + if (!item.empty()) { |
| 24 | + elems.push_back(item); |
| 25 | + } |
| 26 | + } |
| 27 | + return elems; |
| 28 | +} |
| 29 | + |
| 30 | + |
| 31 | +bool CompareDetResult(const fastdeploy::vision::DetectionResult& res, |
| 32 | + const std::string& det_result_file) { |
| 33 | + std::ifstream res_str(det_result_file); |
| 34 | + if (!res_str.is_open()) { |
| 35 | + std::cout<< "Could not open detect result file : " |
| 36 | + << det_result_file <<"\n"<< std::endl; |
| 37 | + return false; |
| 38 | + } |
| 39 | + int obj_num = 0; |
| 40 | + while (!res_str.eof()) { |
| 41 | + std::string line; |
| 42 | + std::getline(res_str, line); |
| 43 | + if (line.find("DetectionResult") == line.npos |
| 44 | + && line.find(",") != line.npos ) { |
| 45 | + auto strs = stringSplit(line, ','); |
| 46 | + if (strs.size() != 6) { |
| 47 | + std::cout<< "Failed to parse result file : " |
| 48 | + << det_result_file <<"\n"<< std::endl; |
| 49 | + return false; |
| 50 | + } |
| 51 | + std::vector<float> vals; |
| 52 | + for (auto str : strs) { |
| 53 | + vals.push_back(atof(str.c_str())); |
| 54 | + } |
| 55 | + if (abs(res.scores[obj_num] - vals[4]) > 0.3) { |
| 56 | + std::cout<< "Score error, the result is: " |
| 57 | + << res.scores[obj_num] << " but the expected is: " |
| 58 | + << vals[4] << std::endl; |
| 59 | + return false; |
| 60 | + } |
| 61 | + if (abs(res.label_ids[obj_num] - vals[5]) > 0) { |
| 62 | + std::cout<< "label error, the result is: " |
| 63 | + << res.label_ids[obj_num] << " but the expected is: " |
| 64 | + << vals[5] <<std::endl; |
| 65 | + return false; |
| 66 | + } |
| 67 | + std::array<float, 4> boxes = res.boxes[obj_num++]; |
| 68 | + for (auto i = 0; i < 4; i++) { |
| 69 | + if (abs(boxes[i] - vals[i]) > 5) { |
| 70 | + std::cout<< "position error, the result is: " |
| 71 | + << boxes[i] << " but the expected is: " << vals[i] <<std::endl; |
| 72 | + return false; |
| 73 | + } |
| 74 | + } |
| 75 | + } |
| 76 | + } |
| 77 | + return true; |
| 78 | +} |
| 79 | + |
| 80 | + |
| 81 | +bool CompareClsResult(const fastdeploy::vision::ClassifyResult& res, |
| 82 | + const std::string& cls_result_file) { |
| 83 | + std::ifstream res_str(cls_result_file); |
| 84 | + if (!res_str.is_open()) { |
| 85 | + std::cout<< "Could not open detect result file : " |
| 86 | + << cls_result_file << "\n" << std::endl; |
| 87 | + return false; |
| 88 | + } |
| 89 | + int obj_num = 0; |
| 90 | + while (!res_str.eof()) { |
| 91 | + std::string line; |
| 92 | + std::getline(res_str, line); |
| 93 | + if (line.find("label_ids") != line.npos |
| 94 | + && line.find(":") != line.npos) { |
| 95 | + auto strs = stringSplit(line, ':'); |
| 96 | + if (strs.size() != 2) { |
| 97 | + std::cout<< "Failed to parse result file : " |
| 98 | + << cls_result_file <<"\n"<< std::endl; |
| 99 | + return false; |
| 100 | + } |
| 101 | + int32_t label = static_cast<int32_t>(atof(strs[1].c_str())); |
| 102 | + if (res.label_ids[obj_num] != label) { |
| 103 | + std::cout<< "label error, the result is: " |
| 104 | + << res.label_ids[obj_num] << " but the expected is: " |
| 105 | + << label<< "\n" << std::endl; |
| 106 | + return false; |
| 107 | + } |
| 108 | + } else if (line.find("scores") != line.npos |
| 109 | + && line.find(":") != line.npos) { |
| 110 | + auto strs = stringSplit(line, ':'); |
| 111 | + if (strs.size() != 2) { |
| 112 | + std::cout<< "Failed to parse result file : " |
| 113 | + << cls_result_file << "\n" << std::endl; |
| 114 | + return false; |
| 115 | + } |
| 116 | + float score = atof(strs[1].c_str()); |
| 117 | + if (abs(res.scores[obj_num] - score) > 1e-1) { |
| 118 | + std::cout << "score error, the result is: " |
| 119 | + << res.scores[obj_num] << " but the expected is: " |
| 120 | + << score << "\n" << std::endl; |
| 121 | + return false; |
| 122 | + } else { |
| 123 | + obj_num++; |
| 124 | + } |
| 125 | + } else if (line.size()) { |
| 126 | + std::cout << "Unknown File. \n" << std::endl; |
| 127 | + return false; |
| 128 | + } |
| 129 | + } |
| 130 | + return true; |
| 131 | +} |
| 132 | + |
| 133 | +bool WriteSegResult(const fastdeploy::vision::SegmentationResult& res, |
| 134 | + const std::string& seg_result_file) { |
| 135 | + std::ofstream res_str(seg_result_file); |
| 136 | + if (!res_str.is_open()) { |
| 137 | + std::cerr<< "Could not open segmentation result file : " |
| 138 | + << seg_result_file <<" to write.\n"<< std::endl; |
| 139 | + return false; |
| 140 | + } |
| 141 | + std::string out; |
| 142 | + out = ""; |
| 143 | + // save shape |
| 144 | + for (auto shape : res.shape) { |
| 145 | + out += std::to_string(shape) + ","; |
| 146 | + } |
| 147 | + out += "\n"; |
| 148 | + // save label |
| 149 | + for (auto label : res.label_map) { |
| 150 | + out += std::to_string(label) + ","; |
| 151 | + } |
| 152 | + out += "\n"; |
| 153 | + // save score |
| 154 | + if (res.contain_score_map) { |
| 155 | + for (auto score : res.score_map) { |
| 156 | + out += std::to_string(score) + ","; |
| 157 | + } |
| 158 | + } |
| 159 | + res_str << out; |
| 160 | + return true; |
| 161 | +} |
| 162 | + |
| 163 | +bool CompareSegResult(const fastdeploy::vision::SegmentationResult& res, |
| 164 | + const std::string& seg_result_file) { |
| 165 | + std::ifstream res_str(seg_result_file); |
| 166 | + if (!res_str.is_open()) { |
| 167 | + std::cout<< "Could not open detect result file : " |
| 168 | + << seg_result_file <<"\n"<< std::endl; |
| 169 | + return false; |
| 170 | + } |
| 171 | + std::string line; |
| 172 | + std::getline(res_str, line); |
| 173 | + if (line.find(",") == line.npos) { |
| 174 | + std::cout << "Unexpected File." << std::endl; |
| 175 | + return false; |
| 176 | + } |
| 177 | + // check shape diff |
| 178 | + auto shape_strs = stringSplit(line, ','); |
| 179 | + std::vector<int64_t> shape; |
| 180 | + for (auto str : shape_strs) { |
| 181 | + shape.push_back(static_cast<int64_t>(atof(str.c_str()))); |
| 182 | + } |
| 183 | + if (shape.size() != res.shape.size()) { |
| 184 | + std::cout << "Output shape and expected shape size mismatch, shape size: " |
| 185 | + << res.shape.size() << " expected shape size: " |
| 186 | + << shape.size() << std::endl; |
| 187 | + return false; |
| 188 | + } |
| 189 | + for (auto i = 0; i < res.shape.size(); i++) { |
| 190 | + if (res.shape[i] != shape[i]) { |
| 191 | + std::cout << "Output Shape and expected shape mismatch, shape: " |
| 192 | + << res.shape[i] << " expected: " << shape[i] << std::endl; |
| 193 | + return false; |
| 194 | + } |
| 195 | + } |
| 196 | + std::cout << "Shape check passed!" << std::endl; |
| 197 | + |
| 198 | + std::getline(res_str, line); |
| 199 | + if (line.find(",") == line.npos) { |
| 200 | + std::cout << "Unexpected File." << std::endl; |
| 201 | + return false; |
| 202 | + } |
| 203 | + // check label |
| 204 | + auto label_strs = stringSplit(line, ','); |
| 205 | + std::vector<uint8_t> labels; |
| 206 | + for (auto str : label_strs) { |
| 207 | + labels.push_back(static_cast<uint8_t>(atof(str.c_str()))); |
| 208 | + } |
| 209 | + if (labels.size() != res.label_map.size()) { |
| 210 | + std::cout << "Output labels and expected shape size mismatch." << std::endl; |
| 211 | + return false; |
| 212 | + } |
| 213 | + for (auto i = 0; i < res.label_map.size(); i++) { |
| 214 | + if (res.label_map[i] != labels[i]) { |
| 215 | + std::cout << "Output labels and expected labels mismatch." << std::endl; |
| 216 | + return false; |
| 217 | + } |
| 218 | + } |
| 219 | + std::cout << "Label check passed!" << std::endl; |
| 220 | + |
| 221 | + // check score_map |
| 222 | + if (res.contain_score_map) { |
| 223 | + auto scores_strs = stringSplit(line, ','); |
| 224 | + std::vector<float> scores; |
| 225 | + for (auto str : scores_strs) { |
| 226 | + scores.push_back(static_cast<float>(atof(str.c_str()))); |
| 227 | + } |
| 228 | + if (scores.size() != res.score_map.size()) { |
| 229 | + std::cout << "Output scores and expected score_map size mismatch." |
| 230 | + <<std::endl; |
| 231 | + return false; |
| 232 | + } |
| 233 | + for (auto i = 0; i < res.score_map.size(); i++) { |
| 234 | + if (abs(res.score_map[i] - scores[i]) > 3e-1) { |
| 235 | + std::cout << "Output scores and expected scores mismatch." |
| 236 | + << std::endl; |
| 237 | + return false; |
| 238 | + } |
| 239 | + } |
| 240 | + } |
| 241 | + return true; |
| 242 | +} |
0 commit comments