diff --git a/source/shape/ShapeInterp.cpp b/source/shape/ShapeInterp.cpp
index b4d7bab825..e87e2ec243 100644
--- a/source/shape/ShapeInterp.cpp
+++ b/source/shape/ShapeInterp.cpp
@@ -22,9 +22,11 @@ class InterpComputer : public SizeComputer {
         auto& output = outputs[0]->buffer();
         int w = 0;
         int h = 0;
+        int d = 0;
         const int inputSize = (int)inputs.size();
-        auto iw = inputs[0]->width();
-        auto ih = inputs[0]->height();
+        auto iw = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[4].extent : inputs[0]->width();
+        auto ih = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[3].extent : inputs[0]->height();
+        auto id = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[2].extent : 0;
         // copy dims
         memcpy(output.dim, input.dim, sizeof(halide_dimension_t) * input.dimensions);
         outputs[0]->buffer().dimensions = inputs[0]->dimensions();
@@ -58,9 +60,13 @@ class InterpComputer : public SizeComputer {
             // get output dims
             w = interp->outputWidth();
             h = interp->outputHeight();
-            if (w == 0 || h == 0) {
+            d = interp->outputDepth();
+            if (w == 0 || h == 0 || (inputs[0]->dimensions() == 5 && d == 0)) {
                 w = iw * interp->widthScale();
                 h = ih * interp->heightScale();
+                if (inputs[0]->dimensions() == 5) {
+                    d = id * interp->depthScale();
+                }
             }
         } else {
             // For mnn model from tensorflow
@@ -89,6 +95,11 @@ class InterpComputer : public SizeComputer {
         } else {
             output.dim[3].extent = w;
             output.dim[2].extent = h;
+            if (inputs[0]->dimensions() == 5) {
+                output.dim[4].extent = w;
+                output.dim[3].extent = h;
+                output.dim[2].extent = d;
+            }
         }
         return true;
     }
@@ -118,4 +129,5 @@ class InterpComputer : public SizeComputer {
 };
 
 REGISTER_SHAPE_INPUTS(InterpComputer, OpType_Interp, {1});
+REGISTER_SHAPE_INPUTS(InterpComputer, OpType_Interp3D, {1});
 } // namespace MNN
diff --git a/source/shape/ShapeRegister.cpp b/source/shape/ShapeRegister.cpp
index 0c41c8ddb1..c4b30e1b9d 100644
--- a/source/shape/ShapeRegister.cpp
+++ b/source/shape/ShapeRegister.cpp
@@ -5,6 +5,7 @@ extern void ___ShapeRasterComputer__OpType_Raster__();
 extern void ___PriorBoxComputer__OpType_PriorBox__();
 extern void ___ShapeBroadcastTo__OpType_BroadcastTo__();
 extern void ___InterpComputer__OpType_Interp__();
+extern void ___InterpComputer__OpType_Interp3D__();
 extern void ___CropSizeComputer__OpType_Crop__();
 extern void ___MatMulSizeComputer__OpType_MatMul__();
 extern void ___MatMulSizeComputer__OpType_BatchMatMul__();
@@ -129,6 +130,7 @@ ___ShapeRasterComputer__OpType_Raster__();
 ___PriorBoxComputer__OpType_PriorBox__();
 ___ShapeBroadcastTo__OpType_BroadcastTo__();
 ___InterpComputer__OpType_Interp__();
+___InterpComputer__OpType_Interp3D__();
 ___CropSizeComputer__OpType_Crop__();
 ___MatMulSizeComputer__OpType_MatMul__();
 ___MatMulSizeComputer__OpType_BatchMatMul__();
diff --git a/test/expr/ExprResizeComputeTest.cpp b/test/expr/ExprResizeComputeTest.cpp
index 23c1d7d9ed..5de1e68f2d 100644
--- a/test/expr/ExprResizeComputeTest.cpp
+++ b/test/expr/ExprResizeComputeTest.cpp
@@ -7,7 +7,9 @@
 //
 
 #include <MNN/expr/ExprCreator.hpp>
+#include "MNN_generated.h"
 #include "MNNTestSuite.h"
+#include "TestUtils.h"
 
 using namespace MNN::Express;
 
@@ -87,6 +89,38 @@ class ExprResizeComputeTest : public MNNTestCase {
                 }
             }
         }
+        {
+            auto x = _Input({1, 2, 3, 4, 5}, NCHW, halide_type_of<float>());
+            auto inputPtr = x->writeMap<float>();
+            for (int i = 0; i < x->getInfo()->size; ++i) {
+                inputPtr[i] = (float)i;
+            }
+            x->unMap();
+
+            std::unique_ptr<MNN::InterpT> interp(new MNN::InterpT);
+            interp->resizeType = 1;
+            interp->widthScale = 2.0f;
+            interp->heightScale = 2.0f;
+            interp->depthScale = 2.0f;
+
+            std::unique_ptr<MNN::OpT> op(new MNN::OpT);
+            op->type = MNN::OpType_Interp3D;
+            op->main.type = MNN::OpParameter_Interp;
+            op->main.value = interp.release();
+
+            auto y = Variable::create(Expr::create(op.get(), {x}));
+            auto yShape = y->getInfo();
+            if (yShape == nullptr) {
+                return false;
+            }
+            const std::vector<int> expectedDim = {1, 2, 6, 8, 10};
+            if (!checkVector<int>(yShape->dim.data(), expectedDim.data(), 5, 0)) {
+                return false;
+            }
+            if (nullptr == y->readMap<float>()) {
+                return false;
+            }
+        }
         return true;
     }
 };
diff --git a/tools/converter/CMakeLists.txt b/tools/converter/CMakeLists.txt
index 07a3da727a..1e757c9f82 100644
--- a/tools/converter/CMakeLists.txt
+++ b/tools/converter/CMakeLists.txt
@@ -76,6 +76,22 @@ IF(MNN_BUILD_CONVERTER)
   target_link_libraries(MNNDump2Json MNNConvertDeps)
   add_executable(TestConvertResult ${CMAKE_CURRENT_LIST_DIR}/source/TestConvertResult.cpp)
   target_link_libraries(TestConvertResult MNNConvertDeps)
+  add_executable(TestOnnxResize
+    ${CMAKE_CURRENT_LIST_DIR}/source/TestOnnxResize.cpp
+    ${COMMON_SRC}
+    ${MNN_CONVERTER_BACKENDS_OBJECTS}
+    ${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/flatbuffers/src/util.cpp
+    $
+  )
+  target_link_libraries(TestOnnxResize ${MNN_DEPS} ${Protobuf_LIBRARIES})
+  add_executable(TestOnnxEinsum
+    ${CMAKE_CURRENT_LIST_DIR}/source/TestOnnxEinsum.cpp
+    ${COMMON_SRC}
+    ${MNN_CONVERTER_BACKENDS_OBJECTS}
+    ${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/flatbuffers/src/util.cpp
+    $
+  )
+  target_link_libraries(TestOnnxEinsum ${MNN_DEPS} ${Protobuf_LIBRARIES})
   add_executable(TestPassManager ${CMAKE_CURRENT_LIST_DIR}/source/TestPassManager.cpp)
   target_link_libraries(TestPassManager MNNConvertDeps)
   target_link_libraries(MNNConvert MNNConvertDeps)
diff --git a/tools/converter/source/TestOnnxEinsum.cpp b/tools/converter/source/TestOnnxEinsum.cpp
new file mode 100644
index 0000000000..e769cb3233
--- /dev/null
+++ b/tools/converter/source/TestOnnxEinsum.cpp
@@ -0,0 +1,265 @@
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+#include <MNN/Interpreter.hpp>
+#include "MNN_generated.h"
+#include "cli.hpp"
+#include "onnx.pb.h"
+
+static void addTensorShape(onnx::ValueInfoProto* valueInfo, const std::string& name, const std::vector<int>& dims) {
+    valueInfo->set_name(name);
+    auto* tensorType = valueInfo->mutable_type()->mutable_tensor_type();
+    tensorType->set_elem_type(onnx::TensorProto_DataType_FLOAT);
+    auto* shape = tensorType->mutable_shape();
+    for (int dim : dims) {
+        shape->add_dim()->set_dim_value(dim);
+    }
+}
+
+static void addFloatInitializer(onnx::GraphProto* graph, const std::string& name, const std::vector<int64_t>& dims,
+                                const std::vector<float>& values) {
+    auto* tensor = graph->add_initializer();
+    tensor->set_name(name);
+    tensor->set_data_type(onnx::TensorProto_DataType_FLOAT);
+    for (auto dim : dims) {
+        tensor->add_dims(dim);
+    }
+    for (auto value : values) {
+        tensor->add_float_data(value);
+    }
+}
+
+static void addInt64Initializer(onnx::GraphProto* graph, const std::string& name, const std::vector<int64_t>& dims,
+                                const std::vector<int64_t>& values) {
+    auto* tensor = graph->add_initializer();
+    tensor->set_name(name);
+    tensor->set_data_type(onnx::TensorProto_DataType_INT64);
+    for (auto dim : dims) {
+        tensor->add_dims(dim);
+    }
+    for (auto value : values) {
+        tensor->add_int64_data(value);
+    }
+}
+
+static bool saveModel(const onnx::ModelProto& model, const std::string& fileName) {
+    std::ofstream output(fileName, std::ios::binary | std::ios::trunc);
+    return model.SerializeToOstream(&output);
+}
+
+static bool compareVector(const float* got, const std::vector<float>& expected, float tolerance = 1e-5f) {
+    for (size_t i = 0; i < expected.size(); ++i) {
+        if (std::fabs(got[i] - expected[i]) > tolerance) {
+            std::fprintf(stderr, "mismatch at %zu, expect=%f, got=%f\n", i, expected[i], got[i]);
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool runMNNModel(const std::string& modelPath, const std::vector<std::pair<std::string, std::vector<float>>>& inputs,
+                        const std::vector<float>& expectedOutput) {
+    std::unique_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile(modelPath.c_str()));
+    if (!net) {
+        return false;
+    }
+    MNN::ScheduleConfig config;
+    config.type = MNN_FORWARD_CPU;
+    auto session = net->createSession(config);
+    if (!session) {
+        return false;
+    }
+    for (const auto& item : inputs) {
+        auto* inputTensor = net->getSessionInput(session, item.first.c_str());
+        if (!inputTensor) {
+            return false;
+        }
+        MNN::Tensor hostTensor(inputTensor, inputTensor->getDimensionType());
+        ::memcpy(hostTensor.host<float>(), item.second.data(), item.second.size() * sizeof(float));
+        inputTensor->copyFromHostTensor(&hostTensor);
+    }
+    if (net->runSession(session) != MNN::NO_ERROR) {
+        return false;
+    }
+    auto* outputTensor = net->getSessionOutput(session, nullptr);
+    if (!outputTensor) {
+        return false;
+    }
+    MNN::Tensor hostOutput(outputTensor, outputTensor->getDimensionType());
+    outputTensor->copyToHostTensor(&hostOutput);
+    return compareVector(hostOutput.host<float>(), expectedOutput);
+}
+
+static bool convertOnnx(const std::string& onnxModel, const std::string& mnnModel) {
+    modelConfig config;
+    config.model = modelConfig::ONNX;
+    config.modelFile = onnxModel;
+    config.MNNModel = mnnModel;
+    config.keepInputFormat = true;
+    return MNN::Cli::convertModel(config);
+}
+
+static onnx::ModelProto makeConcatEinsumModel() {
+    onnx::ModelProto model;
+    model.set_ir_version(8);
+    model.mutable_opset_import()->Add()->set_version(13);
+    auto* graph = model.mutable_graph();
+    graph->set_name("ConcatEinsum");
+
+    addTensorShape(graph->add_input(), "x", {2, 3});
+    addTensorShape(graph->add_input(), "y", {2, 3});
+    addTensorShape(graph->add_output(), "out", {2, 3});
+
+    addFloatInitializer(graph, "weight", {2}, {1.5f, -0.5f});
+    addInt64Initializer(graph, "axes", {1}, {0});
+
+    auto* unsqueezeX = graph->add_node();
+    unsqueezeX->set_op_type("Unsqueeze");
+    unsqueezeX->add_input("x");
+    unsqueezeX->add_input("axes");
+    unsqueezeX->add_output("x_unsqueezed");
+
+    auto* unsqueezeY = graph->add_node();
+    unsqueezeY->set_op_type("Unsqueeze");
+    unsqueezeY->add_input("y");
+    unsqueezeY->add_input("axes");
+    unsqueezeY->add_output("y_unsqueezed");
+
+    auto* concat = graph->add_node();
+    concat->set_op_type("Concat");
+    concat->add_input("x_unsqueezed");
+    concat->add_input("y_unsqueezed");
+    concat->add_output("stacked");
+    auto* axis = concat->add_attribute();
+    axis->set_name("axis");
+    axis->set_i(0);
+
+    auto* einsum = graph->add_node();
+    einsum->set_op_type("Einsum");
+    einsum->add_input("weight");
+    einsum->add_input("stacked");
+    einsum->add_output("out");
+    auto* equation = einsum->add_attribute();
+    equation->set_name("equation");
+    equation->set_s("i,i...->...");
+    return model;
+}
+
+static onnx::ModelProto makeReduceEinsumModel() {
+    onnx::ModelProto model;
+    model.set_ir_version(8);
+    model.mutable_opset_import()->Add()->set_version(13);
+    auto* graph = model.mutable_graph();
+    graph->set_name("ReduceEinsum");
+
+    addTensorShape(graph->add_input(), "stacked", {2, 2, 3});
+    addTensorShape(graph->add_output(), "out", {2, 3});
+    addFloatInitializer(graph, "weight", {2}, {1.5f, -0.5f});
+
+    auto* einsum = graph->add_node();
+    einsum->set_op_type("Einsum");
+    einsum->add_input("weight");
+    einsum->add_input("stacked");
+    einsum->add_output("out");
+    auto* equation = einsum->add_attribute();
+    equation->set_name("equation");
+    equation->set_s("i,i...->...");
+    return model;
+}
+
+static onnx::ModelProto makeConcat3EinsumModel() {
+    onnx::ModelProto model;
+    model.set_ir_version(8);
+    model.mutable_opset_import()->Add()->set_version(13);
+    auto* graph = model.mutable_graph();
+    graph->set_name("Concat3Einsum");
+
+    addTensorShape(graph->add_input(), "x", {2, 3});
+    addTensorShape(graph->add_input(), "y", {2, 3});
+    addTensorShape(graph->add_input(), "z", {2, 3});
+    addTensorShape(graph->add_output(), "out", {2, 3});
+
+    addFloatInitializer(graph, "weight", {3}, {1.0f, -2.0f, 0.5f});
+    addInt64Initializer(graph, "axes", {1}, {0});
+
+    auto* unsqueezeX = graph->add_node();
+    unsqueezeX->set_op_type("Unsqueeze");
+    unsqueezeX->add_input("x");
+    unsqueezeX->add_input("axes");
+    unsqueezeX->add_output("x_unsqueezed");
+
+    auto* unsqueezeY = graph->add_node();
+    unsqueezeY->set_op_type("Unsqueeze");
+    unsqueezeY->add_input("y");
+    unsqueezeY->add_input("axes");
+    unsqueezeY->add_output("y_unsqueezed");
+
+    auto* unsqueezeZ = graph->add_node();
+    unsqueezeZ->set_op_type("Unsqueeze");
+    unsqueezeZ->add_input("z");
+    unsqueezeZ->add_input("axes");
+    unsqueezeZ->add_output("z_unsqueezed");
+
+    auto* concat = graph->add_node();
+    concat->set_op_type("Concat");
+    concat->add_input("x_unsqueezed");
+    concat->add_input("y_unsqueezed");
+    concat->add_input("z_unsqueezed");
+    concat->add_output("stacked");
+    auto* axis = concat->add_attribute();
+    axis->set_name("axis");
+    axis->set_i(0);
+
+    auto* einsum = graph->add_node();
+    einsum->set_op_type("Einsum");
+    einsum->add_input("weight");
+    einsum->add_input("stacked");
+    einsum->add_output("out");
+    auto* equation = einsum->add_attribute();
+    equation->set_name("equation");
+    equation->set_s("i,i...->...");
+    return model;
+}
+
+int main() {
+    const std::string concatOnnx = "/tmp/mnn_concat_einsum.onnx";
+    const std::string concatMnn = "/tmp/mnn_concat_einsum.mnn";
+    const std::string reduceOnnx = "/tmp/mnn_reduce_einsum.onnx";
+    const std::string reduceMnn = "/tmp/mnn_reduce_einsum.mnn";
+    const std::string concat3Onnx = "/tmp/mnn_concat3_einsum.onnx";
+    const std::string concat3Mnn = "/tmp/mnn_concat3_einsum.mnn";
+
+    bool ok = saveModel(makeConcatEinsumModel(), concatOnnx);
+    ok = saveModel(makeReduceEinsumModel(), reduceOnnx) && ok;
+    ok = saveModel(makeConcat3EinsumModel(), concat3Onnx) && ok;
+
+    ok = convertOnnx(concatOnnx, concatMnn) && ok;
+    ok = convertOnnx(reduceOnnx, reduceMnn) && ok;
+    ok = convertOnnx(concat3Onnx, concat3Mnn) && ok;
+
+    const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+    const std::vector<float> y = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
+    const std::vector<float> expected = {-1.5f, 0.5f, 2.5f, 4.5f, 6.5f, 8.5f};
+    ok = runMNNModel(concatMnn, {{"x", x}, {"y", y}}, expected) && ok;
+
+    const std::vector<float> stacked = {
+        1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+        6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
+    };
+    ok = runMNNModel(reduceMnn, {{"stacked", stacked}}, expected) && ok;
+
+    const std::vector<float> z = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f};
+    const std::vector<float> expected3 = {-11.0f, -7.5f, -5.0f, -1.5f, 1.0f, 4.5f};
+    ok = runMNNModel(concat3Mnn, {{"x", x}, {"y", y}, {"z", z}}, expected3) && ok;
+
+    ::remove(concatOnnx.c_str());
+    ::remove(concatMnn.c_str());
+    ::remove(reduceOnnx.c_str());
+    ::remove(reduceMnn.c_str());
+    ::remove(concat3Onnx.c_str());
+    ::remove(concat3Mnn.c_str());
+    return ok ? 0 : 1;
+}
diff --git a/tools/converter/source/TestOnnxResize.cpp b/tools/converter/source/TestOnnxResize.cpp
new file mode 100644
index 0000000000..349bea59fa
--- /dev/null
+++ b/tools/converter/source/TestOnnxResize.cpp
@@ -0,0 +1,166 @@
+#include <cstdio>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "MNN_generated.h"
+#include "onnx.pb.h"
+#include "onnxConverter.hpp"
+
+static void addTensorShape(onnx::ValueInfoProto* valueInfo, const std::string& name, const std::vector<int>& dims) {
+    valueInfo->set_name(name);
+    auto* tensorType = valueInfo->mutable_type()->mutable_tensor_type();
+    tensorType->set_elem_type(onnx::TensorProto_DataType_FLOAT);
+    auto* shape = tensorType->mutable_shape();
+    for (int dim : dims) {
+        shape->add_dim()->set_dim_value(dim);
+    }
+}
+
+static void addFloatInitializer(onnx::GraphProto* graph, const std::string& name, const std::vector<int64_t>& dims,
+                                const std::vector<float>& values) {
+    auto* tensor = graph->add_initializer();
+    tensor->set_name(name);
+    tensor->set_data_type(onnx::TensorProto_DataType_FLOAT);
+    for (auto dim : dims) {
+        tensor->add_dims(dim);
+    }
+    for (auto value : values) {
+        tensor->add_float_data(value);
+    }
+}
+
+static void addIntInitializer(onnx::GraphProto* graph, const std::string& name, const std::vector<int64_t>& dims,
+                              const std::vector<int>& values) {
+    auto* tensor = graph->add_initializer();
+    tensor->set_name(name);
+    tensor->set_data_type(onnx::TensorProto_DataType_INT32);
+    for (auto dim : dims) {
+        tensor->add_dims(dim);
+    }
+    for (auto value : values) {
+        tensor->add_int32_data(value);
+    }
+}
+
+static std::string writeModel(const onnx::ModelProto& model, const std::string& fileName) {
+    std::ofstream output(fileName, std::ios::binary | std::ios::trunc);
+    model.SerializeToOstream(&output);
+    return fileName;
+}
+
+static std::unique_ptr<MNN::OpT> makeMetaOp() {
+    std::unique_ptr<MNN::OpT> meta(new MNN::OpT);
+    meta->type = MNN::OpType_Extra;
+    meta->main.type = MNN::OpParameter_Extra;
+    meta->main.value = new MNN::ExtraT;
+    meta->main.AsExtra()->type = "Meta";
+    meta->main.AsExtra()->engine = "MNN";
+    return meta;
+}
+
+static MNN::OpT* findOp(MNN::NetT* net, const std::string& name) {
+    for (auto& op : net->oplists) {
+        if (op->name == name) {
+            return op.get();
+        }
+    }
+    return nullptr;
+}
+
+static bool runConvert(const std::string& modelPath, const std::string& opName, MNN::OpType expectedType,
+                       int expectedInputs) {
+    std::unique_ptr<MNN::NetT> net(new MNN::NetT);
+    auto meta = makeMetaOp();
+    std::vector<std::string> inputNames;
+    if (onnx2MNNNet(modelPath, "MNN", net, meta.get(), inputNames) != 0) {
+        return false;
+    }
+    auto* resizeOp = findOp(net.get(), opName);
+    if (resizeOp == nullptr) {
+        return false;
+    }
+    if (resizeOp->type != expectedType) {
+        return false;
+    }
+    if ((int)resizeOp->inputIndexes.size() != expectedInputs) {
+        return false;
+    }
+    return true;
+}
+
+static onnx::ModelProto makeResizeModel(const std::vector<int>& inputShape, bool useSizes) {
+    onnx::ModelProto model;
+    model.set_ir_version(8);
+    model.mutable_opset_import()->Add()->set_version(16);
+    auto* graph = model.mutable_graph();
+    graph->set_name("ResizeTest");
+
+    addTensorShape(graph->add_input(), "input", inputShape);
+    addTensorShape(graph->add_output(), "output", inputShape);
+
+    auto* node = graph->add_node();
+    node->set_op_type("Resize");
+    node->add_input("input");
+    node->add_input("");
+    if (useSizes) {
+        node->add_input("");
+        node->add_input("sizes");
+    } else {
+        node->add_input("scales");
+    }
+    node->add_output("resize_node");
+    auto* attr = node->add_attribute();
+    attr->set_name("mode");
+    attr->set_s("nearest");
+
+    if (useSizes) {
+        std::vector<int> sizes(inputShape.begin(), inputShape.end());
+        if (inputShape.size() == 3) {
+            sizes[2] *= 2;
+        } else if (inputShape.size() == 5) {
+            sizes[2] *= 2;
+            sizes[3] *= 2;
+            sizes[4] *= 2;
+        }
+        addIntInitializer(graph, "sizes", {(int64_t)sizes.size()}, sizes);
+    } else {
+        std::vector<float> scales(inputShape.size(), 1.0f);
+        if (inputShape.size() == 3) {
+            scales[2] = 2.0f;
+        } else if (inputShape.size() == 5) {
+            scales[2] = 2.0f;
+            scales[3] = 2.0f;
+            scales[4] = 2.0f;
+        }
+        addFloatInitializer(graph, "scales", {(int64_t)scales.size()}, scales);
+    }
+    return model;
+}
+
+int main() {
+    const std::string rank3Scales = "/tmp/mnn_resize_rank3_scales.onnx";
+    const std::string rank3Sizes = "/tmp/mnn_resize_rank3_sizes.onnx";
+    const std::string rank4Scales = "/tmp/mnn_resize_rank4_scales.onnx";
+    const std::string rank4Sizes = "/tmp/mnn_resize_rank4_sizes.onnx";
+    const std::string rank5Scales = "/tmp/mnn_resize_rank5_scales.onnx";
+
+    writeModel(makeResizeModel({2, 3, 5}, false), rank3Scales);
+    writeModel(makeResizeModel({2, 3, 5}, true), rank3Sizes);
+    writeModel(makeResizeModel({1, 2, 3, 4}, false), rank4Scales);
+    writeModel(makeResizeModel({1, 2, 3, 4}, true), rank4Sizes);
+    writeModel(makeResizeModel({1, 2, 3, 4, 5}, false), rank5Scales);
+
+    bool ok = true;
+    ok = runConvert(rank3Scales, "resize_node", MNN::OpType_Interp, 2) && ok;
+    ok = runConvert(rank3Sizes, "resize_node", MNN::OpType_Interp, 2) && ok;
+    ok = runConvert(rank4Scales, "resize_node", MNN::OpType_Interp, 1) && ok;
+    ok = runConvert(rank4Sizes, "resize_node", MNN::OpType_Interp, 1) && ok;
+    ok = runConvert(rank5Scales, "resize_node", MNN::OpType_Interp3D, 1) && ok;
+
+    ::remove(rank3Scales.c_str());
+    ::remove(rank3Sizes.c_str());
+    ::remove(rank4Scales.c_str());
+    ::remove(rank4Sizes.c_str());
+    ::remove(rank5Scales.c_str());
+    return ok ? 0 : 1;
+}
diff --git a/tools/converter/source/onnx/ResizeOnnx.cpp b/tools/converter/source/onnx/ResizeOnnx.cpp
new file mode 100644
index 0000000000..c5c91f848f
--- /dev/null
+++ b/tools/converter/source/onnx/ResizeOnnx.cpp
@@ -0,0 +1,138 @@
+#include <memory>
+#include <string>
+#include <vector>
+#include "onnxOpConverter.hpp"
+#include "logkit.h"
+
+DECLARE_OP_CONVERTER(ResizeOnnx);
+
+MNN::OpType ResizeOnnx::opType() {
+    return MNN::OpType_Interp;
+}
+
+MNN::OpParameter ResizeOnnx::type() {
+    return MNN::OpParameter_Interp;
+}
+
+void ResizeOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
+                     OnnxScope* scope) {
+    std::unique_ptr<MNN::InterpT> resizeParam(new MNN::InterpT);
+    bool bakedScales = false;
+    bool bakedSizes = false;
+    bool is3DResize = false;
+    std::string resizeMode = "";
+    std::string coordMode = "half_pixel";
+    std::string nearestMode = "round_prefer_floor";
+    float cubicFactor = -0.75f;
+    for (int i = 0; i < onnxNode->attribute_size(); ++i) {
+        const auto& attr = onnxNode->attribute(i);
+        const auto& key = attr.name();
+        if (key == "mode") {
+            resizeMode = attr.s();
+        } else if (key == "coordinate_transformation_mode") {
+            coordMode = attr.s();
+        } else if (key == "nearest_mode") {
+            nearestMode = attr.s();
+        } else if (key == "cubic_coeff_a") {
+            cubicFactor = attr.f();
+        }
+    }
+
+    if (resizeMode == "nearest") {
+        if (nearestMode == "round_prefer_floor") {
+            resizeParam->resizeType = 4;
+        } else if (nearestMode == "floor") {
+            resizeParam->resizeType = 1;
+        } else {
+            LOG(ERROR) << "Don't support " << nearestMode << " nearest mode, use round_prefer_floor instead";
+            resizeParam->resizeType = 4;
+        }
+    } else if (resizeMode == "bilinear" || resizeMode == "linear") {
+        resizeParam->resizeType = 2;
+    } else if (resizeMode == "cubic") {
+        resizeParam->resizeType = 3;
+        resizeParam->cubicCoeffA = cubicFactor;
+    } else {
+        LOG(ERROR) << "Unsupported Resize mode " << resizeMode << ", use bilinear instead";
+        resizeParam->resizeType = 2;
+    }
+
+    resizeParam->alignCorners = (coordMode == "align_corners");
+    resizeParam->halfPixelCenters = (coordMode == "half_pixel");
+#define SET_MODE(str, c)      \
+    if (coordMode == str)     \
+    resizeParam->ctm = MNN::CoordinateTransformationMode_##c
+    SET_MODE("align_corners", AlignCorners);
+    SET_MODE("half_pixel", HalfPixels);
+    SET_MODE("pytorch_half_pixel", PytorchHalfPixels);
+    SET_MODE("tf_half_pixel_for_nn", TensorflowHalfPixels);
+    SET_MODE("tf_crop_and_resize", TensorflowCropAndResize);
+    SET_MODE("asymmetric", Asymmetric);
+#undef SET_MODE
+
+    // If scales / sizes are constant, bake them into Interp and drop the extra shape input.
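+    // Hypothetical example (illustrative values only): a constant `scales` initializer of
+    // {1.0f, 1.0f, 2.0f, 2.0f} on a rank-4 input bakes widthScale = heightScale = 2.0f and
+    // leaves a single data input; a rank-5 {1, 1, 2, 2, 2} additionally bakes
+    // depthScale = 2.0f and the op is emitted as Interp3D.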
+    if (onnxNode->input_size() >= 3) {
+        const auto& scalesName = onnxNode->input(2);
+        auto iter = scope->mInitializers.find(scalesName);
+        if (iter != scope->mInitializers.end()) {
+            std::unique_ptr<MNN::OpT> tempOp(new MNN::OpT);
+            std::unique_ptr<MNN::BlobT> blob(onnxOpConverter::convertTensorToBlob(iter->second, scope->mModelDir, tempOp.get()));
+            if (blob && !blob->float32s.empty()) {
+                auto size = (int)blob->float32s.size();
+                if (size == 4 || size == 5) {
+                    is3DResize = is3DResize || size == 5;
+                    resizeParam->heightScale = blob->float32s[size - 2];
+                    resizeParam->widthScale = blob->float32s[size - 1];
+                    if (size == 5) {
+                        resizeParam->depthScale = blob->float32s[2];
+                    }
+                    bakedScales = true;
+                }
+            }
+        }
+    }
+    if (onnxNode->input_size() >= 4) {
+        const auto& sizesName = onnxNode->input(3);
+        auto iter = scope->mInitializers.find(sizesName);
+        if (iter != scope->mInitializers.end()) {
+            std::unique_ptr<MNN::OpT> tempOp(new MNN::OpT);
+            std::unique_ptr<MNN::BlobT> blob(onnxOpConverter::convertTensorToBlob(iter->second, scope->mModelDir, tempOp.get()));
+            if (blob && !blob->int32s.empty()) {
+                auto size = (int)blob->int32s.size();
+                if (size == 4 || size == 5) {
+                    is3DResize = is3DResize || size == 5;
+                    if (size == 5) {
+                        resizeParam->outputDepth = blob->int32s[2];
+                    }
+                    resizeParam->outputHeight = blob->int32s[size - 2];
+                    resizeParam->outputWidth = blob->int32s[size - 1];
+                    bakedSizes = true;
+                }
+            }
+        }
+    }
+
+    std::vector<int> inputIndexes;
+    auto dataIndex = scope->lookupTensor(onnxNode->input(0));
+    if (dataIndex >= 0) {
+        inputIndexes.emplace_back(dataIndex);
+    }
+    if (!bakedSizes && onnxNode->input_size() >= 4 && !onnxNode->input(3).empty()) {
+        auto sizesIndex = scope->lookupTensor(onnxNode->input(3));
+        if (sizesIndex >= 0) {
+            inputIndexes.emplace_back(sizesIndex);
+        }
+    } else if (!bakedScales && onnxNode->input_size() >= 3 && !onnxNode->input(2).empty()) {
+        auto scalesIndex = scope->lookupTensor(onnxNode->input(2));
+        if (scalesIndex >= 0) {
+            inputIndexes.emplace_back(scalesIndex);
+        }
+    }
+    dstOp->inputIndexes = std::move(inputIndexes);
+    if (is3DResize) {
+        dstOp->type = MNN::OpType_Interp3D;
+    }
+
+    dstOp->main.value = resizeParam.release();
+    dstOp->defaultDimentionFormat = MNN::MNN_DATA_FORMAT_NCHW;
+}
+
+REGISTER_CONVERTER(ResizeOnnx, Resize);
diff --git a/tools/converter/source/onnx/onnxConverter.cpp b/tools/converter/source/onnx/onnxConverter.cpp
index 77efa8f256..8290bacde4 100644
--- a/tools/converter/source/onnx/onnxConverter.cpp
+++ b/tools/converter/source/onnx/onnxConverter.cpp
@@ -147,10 +147,13 @@
             int inputIdx = scope->lookupTensor(onnxNode.input(k));
             if (inputIdx < 0) {
                 LOG(INFO) << "Check it out ==> " << MNNOp->name << " has empty input, the index is " << k;
+                if (opType == "Resize") {
+                    continue;
+                }
             }
             MNNOp->inputIndexes.push_back(inputIdx);
         }
-        for (int k = onnxNode.input_size() - 1; k >= 0 && MNNOp->inputIndexes[k] < 0; --k) {
+        for (int k = (int)MNNOp->inputIndexes.size() - 1; k >= 0 && MNNOp->inputIndexes[k] < 0; --k) {
             MNNOp->inputIndexes.pop_back();
         }
         for (int k = 0; k < onnxNode.output_size(); k++) {
diff --git a/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp b/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp
index c5f63ebd45..4d0b6a4736 100644
--- a/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp
+++ b/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp
@@ -33,6 +33,7 @@ class OnnxEinsumTransform : public OnnxExtraManager::Transform {
             MNN_ERROR("Can't convert Einsum for invalid Equation\n");
Equation\n"); return nullptr; } + std::string rawEquation = equation; // Turn ... to . bool hasPrefix = false; { @@ -45,13 +46,21 @@ class OnnxEinsumTransform : public OnnxExtraManager::Transform { } // Remove space std::vector valid; + std::vector rawValid; for (int i=0; i"); if (pos == std::string::npos) { MNN_ERROR("Can't convert Einsum for no support Equation:%s\n", equation.c_str()); @@ -59,8 +68,8 @@ class OnnxEinsumTransform : public OnnxExtraManager::Transform { } auto left = equation.substr(0, pos); auto right = equation.substr(pos+2, equation.size()); - if (expr->inputs().size() == 1 ){ - auto currentVar = expr->inputs()[0]; + if (inputs.size() == 1 ){ + auto currentVar = inputs[0]; std::map outputPos; for (int i=0; iinputs()[0]; - auto var1 = expr->inputs()[1]; + auto var0 = inputs[0]; + auto var1 = inputs[1]; + if (rawEquation == "i,i...->...") { + auto concatExpr = var1->expr().first; + auto weightExpr = var0->expr().first; + if (concatExpr != nullptr && weightExpr != nullptr && weightExpr->get() != nullptr) { + auto concatOp = concatExpr->get(); + auto weightOp = weightExpr->get(); + if (concatOp != nullptr && concatOp->type() == OpType_Concat && concatOp->main_as_Axis() != nullptr && + concatOp->main_as_Axis()->axis() == 0 && weightOp->type() == OpType_Const && weightOp->main_as_Blob() != nullptr) { + auto weightBlob = weightOp->main_as_Blob(); + auto weightPtr = weightBlob->float32s() != nullptr ? weightBlob->float32s()->data() : nullptr; + auto terms = concatExpr->inputs(); + int weightSize = 1; + if (weightBlob->dims() != nullptr) { + for (int i = 0; i < weightBlob->dims()->size(); ++i) { + weightSize *= weightBlob->dims()->data()[i]; + } + } + if (weightPtr != nullptr && weightSize == terms.size()) { + VARP output; + for (int i = 0; i < terms.size(); ++i) { + auto term = terms[i]; + auto termExpr = term->expr().first; + if (termExpr != nullptr) { + auto termOp = termExpr->get(); + if (termOp != nullptr && termOp->type() == OpType_Unsqueeze && !termExpr->inputs().empty()) { + term = termExpr->inputs()[0]; + } + } + auto current = term * _Scalar(weightPtr[i]); + output = (output == nullptr) ? current : (output + current); + } + if (output.get() != nullptr) { + output->setName(expr->name()); + return output->expr().first; + } + } + } + } + VARP scale; + auto input0Info = var0->getInfo(); + auto input1Info = var1->getInfo(); + if (input0Info != nullptr && input1Info != nullptr) { + std::vector scaleShape = input0Info->dim; + scaleShape.resize(input1Info->dim.size(), 1); + scale = _Reshape(var0, scaleShape); + } else { + auto one = _Unsqueeze(_Scalar(1), {0}); + auto rank = _Rank(var1); + auto ones = _Fill(_Unsqueeze(rank - _Scalar(1), {0}), one); + auto dynamicShape = _Concat({_Shape(var0, NCHW), ones}, 0); + scale = _Reshape(var0, dynamicShape); + } + auto output = _ReduceSum(scale * var1, {0}, false); + output->setName(expr->name()); + return output->expr().first; + } // dim = 4 if (right.size() == 4 && input0.size() == 4 && input1.size() == 4) { // batch align: