@@ -465,6 +465,20 @@ void ONNXImporter::populateNet(Net dstNet)
465
465
}
466
466
layerParams.set (" begin" , DictValue::arrayInt (&begin[0 ], begin.size ()));
467
467
layerParams.set (" end" , DictValue::arrayInt (&end[0 ], end.size ()));
468
+ }
469
+ else if (layer_type == " Split" )
470
+ {
471
+ DictValue splits = layerParams.get (" split" );
472
+ const int numSplits = splits.size ();
473
+ CV_Assert (numSplits > 1 );
474
+
475
+ std::vector<int > slicePoints (numSplits - 1 , splits.get <int >(0 ));
476
+ for (int i = 1 ; i < splits.size () - 1 ; ++i)
477
+ {
478
+ slicePoints[i] = slicePoints[i - 1 ] + splits.get <int >(i - 1 );
479
+ }
480
+ layerParams.set (" slice_point" , DictValue::arrayInt (&slicePoints[0 ], slicePoints.size ()));
481
+ layerParams.type = " Slice" ;
468
482
}
469
483
else if (layer_type == " Add" || layer_type == " Sum" )
470
484
{
@@ -486,6 +500,11 @@ void ONNXImporter::populateNet(Net dstNet)
486
500
layerParams.type = " Eltwise" ;
487
501
}
488
502
}
503
+ else if (layer_type == " Max" )
504
+ {
505
+ layerParams.type = " Eltwise" ;
506
+ layerParams.set (" operation" , " max" );
507
+ }
489
508
else if (layer_type == " Sub" )
490
509
{
491
510
Mat blob = getBlob (node_proto, constBlobs, 1 );
@@ -741,6 +760,16 @@ void ONNXImporter::populateNet(Net dstNet)
741
760
{
742
761
layerParams.type = " Permute" ;
743
762
replaceLayerParam (layerParams, " perm" , " order" );
763
+
764
+ CV_Assert (node_proto.input_size () == 1 );
765
+ if (constBlobs.find (node_proto.input (0 )) != constBlobs.end ())
766
+ {
767
+ std::vector<Mat> inputs (1 , getBlob (node_proto, constBlobs, 0 )), transposed;
768
+ runLayer (layerParams, inputs, transposed);
769
+ CV_Assert (transposed.size () == 1 );
770
+ constBlobs.insert (std::make_pair (layerParams.name , transposed[0 ]));
771
+ continue ;
772
+ }
744
773
}
745
774
else if (layer_type == " Unsqueeze" )
746
775
{
@@ -906,8 +935,10 @@ void ONNXImporter::populateNet(Net dstNet)
906
935
}
907
936
908
937
int id = dstNet.addLayer (layerParams.name , layerParams.type , layerParams);
909
- layer_id.insert (std::make_pair (layerParams.name , LayerInfo (id, 0 )));
910
-
938
+ for (int i = 0 ; i < node_proto.output_size (); ++i)
939
+ {
940
+ layer_id.insert (std::make_pair (node_proto.output (i), LayerInfo (id, i)));
941
+ }
911
942
912
943
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
913
944
for (int j = 0 ; j < node_proto.input_size (); j++) {
@@ -924,8 +955,10 @@ void ONNXImporter::populateNet(Net dstNet)
924
955
// Compute shape of output blob for this layer.
925
956
Ptr <Layer> layer = dstNet.getLayer (id);
926
957
layer->getMemoryShapes (layerInpShapes, 0 , layerOutShapes, layerInternalShapes);
927
- CV_Assert (!layerOutShapes.empty ());
928
- outShapes[layerParams.name ] = layerOutShapes[0 ];
958
+ for (int i = 0 ; i < node_proto.output_size () && i < (int )layerOutShapes.size (); ++i)
959
+ {
960
+ outShapes[node_proto.output (i)] = layerOutShapes[i];
961
+ }
929
962
}
930
963
}
931
964
0 commit comments