Skip to content

Commit 4dadf17

Browse files
committed
Merge pull request opencv#15168 from dkurt:dnn_onnx_15120
2 parents d3cf0d2 + f9f1604 commit 4dadf17

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

modules/dnn/src/onnx/onnx_importer.cpp

+37-4
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,20 @@ void ONNXImporter::populateNet(Net dstNet)
465465
}
466466
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
467467
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
468+
}
469+
else if (layer_type == "Split")
470+
{
471+
DictValue splits = layerParams.get("split");
472+
const int numSplits = splits.size();
473+
CV_Assert(numSplits > 1);
474+
475+
std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
476+
for (int i = 1; i < splits.size() - 1; ++i)
477+
{
478+
slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
479+
}
480+
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
481+
layerParams.type = "Slice";
468482
}
469483
else if (layer_type == "Add" || layer_type == "Sum")
470484
{
@@ -486,6 +500,11 @@ void ONNXImporter::populateNet(Net dstNet)
486500
layerParams.type = "Eltwise";
487501
}
488502
}
503+
else if (layer_type == "Max")
504+
{
505+
layerParams.type = "Eltwise";
506+
layerParams.set("operation", "max");
507+
}
489508
else if (layer_type == "Sub")
490509
{
491510
Mat blob = getBlob(node_proto, constBlobs, 1);
@@ -741,6 +760,16 @@ void ONNXImporter::populateNet(Net dstNet)
741760
{
742761
layerParams.type = "Permute";
743762
replaceLayerParam(layerParams, "perm", "order");
763+
764+
CV_Assert(node_proto.input_size() == 1);
765+
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
766+
{
767+
std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
768+
runLayer(layerParams, inputs, transposed);
769+
CV_Assert(transposed.size() == 1);
770+
constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
771+
continue;
772+
}
744773
}
745774
else if (layer_type == "Unsqueeze")
746775
{
@@ -906,8 +935,10 @@ void ONNXImporter::populateNet(Net dstNet)
906935
}
907936

908937
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
909-
layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
910-
938+
for (int i = 0; i < node_proto.output_size(); ++i)
939+
{
940+
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
941+
}
911942

912943
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
913944
for (int j = 0; j < node_proto.input_size(); j++) {
@@ -924,8 +955,10 @@ void ONNXImporter::populateNet(Net dstNet)
924955
// Compute shape of output blob for this layer.
925956
Ptr<Layer> layer = dstNet.getLayer(id);
926957
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
927-
CV_Assert(!layerOutShapes.empty());
928-
outShapes[layerParams.name] = layerOutShapes[0];
958+
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
959+
{
960+
outShapes[node_proto.output(i)] = layerOutShapes[i];
961+
}
929962
}
930963
}
931964

modules/dnn/test/test_onnx_importer.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ TEST_P(Test_ONNX_layers, Softmax)
348348
testONNXModels("log_softmax", npy, 0, 0, false, false);
349349
}
350350

351+
TEST_P(Test_ONNX_layers, Split_EltwiseMax)
352+
{
353+
if (backend == DNN_BACKEND_INFERENCE_ENGINE)
354+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
355+
testONNXModels("split_max");
356+
}
357+
351358
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
352359

353360
class Test_ONNX_nets : public Test_ONNX_layers

0 commit comments

Comments
 (0)