Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions source/shape/ShapeInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ class InterpComputer : public SizeComputer {
auto& output = outputs[0]->buffer();
int w = 0;
int h = 0;
int d = 0;
const int inputSize = (int)inputs.size();
auto iw = inputs[0]->width();
auto ih = inputs[0]->height();
auto iw = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[4].extent : inputs[0]->width();
auto ih = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[3].extent : inputs[0]->height();
auto id = inputs[0]->dimensions() > 4 ? inputs[0]->buffer().dim[2].extent : 0;
// copy dims
memcpy(output.dim, input.dim, sizeof(halide_dimension_t) * input.dimensions);
outputs[0]->buffer().dimensions = inputs[0]->dimensions();
Expand Down Expand Up @@ -58,9 +60,13 @@ class InterpComputer : public SizeComputer {
// get output dims
w = interp->outputWidth();
h = interp->outputHeight();
if (w == 0 || h == 0) {
d = interp->outputDepth();
if (w == 0 || h == 0 || (inputs[0]->dimensions() == 5 && d == 0)) {
w = iw * interp->widthScale();
h = ih * interp->heightScale();
if (inputs[0]->dimensions() == 5) {
d = id * interp->depthScale();
}
}
} else {
// For mnn model from tensorflow
Expand Down Expand Up @@ -89,6 +95,11 @@ class InterpComputer : public SizeComputer {
} else {
output.dim[3].extent = w;
output.dim[2].extent = h;
if (inputs[0]->dimensions() == 5) {
output.dim[4].extent = w;
output.dim[3].extent = h;
output.dim[2].extent = d;
}
}
return true;
}
Expand Down Expand Up @@ -118,4 +129,5 @@ class InterpComputer : public SizeComputer {
};

REGISTER_SHAPE_INPUTS(InterpComputer, OpType_Interp, {1});
REGISTER_SHAPE_INPUTS(InterpComputer, OpType_Interp3D, {1});
} // namespace MNN
2 changes: 2 additions & 0 deletions source/shape/ShapeRegister.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ extern void ___ShapeRasterComputer__OpType_Raster__();
extern void ___PriorBoxComputer__OpType_PriorBox__();
extern void ___ShapeBroadcastTo__OpType_BroadcastTo__();
extern void ___InterpComputer__OpType_Interp__();
extern void ___InterpComputer__OpType_Interp3D__();
extern void ___CropSizeComputer__OpType_Crop__();
extern void ___MatMulSizeComputer__OpType_MatMul__();
extern void ___MatMulSizeComputer__OpType_BatchMatMul__();
Expand Down Expand Up @@ -129,6 +130,7 @@ ___ShapeRasterComputer__OpType_Raster__();
___PriorBoxComputer__OpType_PriorBox__();
___ShapeBroadcastTo__OpType_BroadcastTo__();
___InterpComputer__OpType_Interp__();
___InterpComputer__OpType_Interp3D__();
___CropSizeComputer__OpType_Crop__();
___MatMulSizeComputer__OpType_MatMul__();
___MatMulSizeComputer__OpType_BatchMatMul__();
Expand Down
34 changes: 34 additions & 0 deletions test/expr/ExprResizeComputeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//

#include <MNN/expr/ExprCreator.hpp>
#include "MNN_generated.h"
#include "MNNTestSuite.h"
#include "TestUtils.h"

using namespace MNN::Express;

Expand Down Expand Up @@ -87,6 +89,38 @@ class ExprResizeComputeTest : public MNNTestCase {
}
}
}
{
auto x = _Input({1, 2, 3, 4, 5}, NCHW, halide_type_of<float>());
auto inputPtr = x->writeMap<float>();
for (int i = 0; i < x->getInfo()->size; ++i) {
inputPtr[i] = (float)i;
}
x->unMap();

std::unique_ptr<MNN::InterpT> interp(new MNN::InterpT);
interp->resizeType = 1;
interp->widthScale = 2.0f;
interp->heightScale = 2.0f;
interp->depthScale = 2.0f;

std::unique_ptr<MNN::OpT> op(new MNN::OpT);
op->type = MNN::OpType_Interp3D;
op->main.type = MNN::OpParameter_Interp;
op->main.value = interp.release();

auto y = Variable::create(Expr::create(op.get(), {x}));
auto yShape = y->getInfo();
if (yShape == nullptr) {
return false;
}
const std::vector<int> expectedDim = {1, 2, 6, 8, 10};
if (!checkVector<int>(yShape->dim.data(), expectedDim.data(), 5, 0)) {
return false;
}
if (nullptr == y->readMap<float>()) {
return false;
}
}
return true;
}
};
Expand Down
16 changes: 16 additions & 0 deletions tools/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ IF(MNN_BUILD_CONVERTER)
target_link_libraries(MNNDump2Json MNNConvertDeps)
add_executable(TestConvertResult ${CMAKE_CURRENT_LIST_DIR}/source/TestConvertResult.cpp)
target_link_libraries(TestConvertResult MNNConvertDeps)
add_executable(TestOnnxResize
${CMAKE_CURRENT_LIST_DIR}/source/TestOnnxResize.cpp
${COMMON_SRC}
${MNN_CONVERTER_BACKENDS_OBJECTS}
${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/flatbuffers/src/util.cpp
$<TARGET_OBJECTS:MNNUtils>
)
target_link_libraries(TestOnnxResize ${MNN_DEPS} ${Protobuf_LIBRARIES})
add_executable(TestOnnxEinsum
${CMAKE_CURRENT_LIST_DIR}/source/TestOnnxEinsum.cpp
${COMMON_SRC}
${MNN_CONVERTER_BACKENDS_OBJECTS}
${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/flatbuffers/src/util.cpp
$<TARGET_OBJECTS:MNNUtils>
)
target_link_libraries(TestOnnxEinsum ${MNN_DEPS} ${Protobuf_LIBRARIES})
add_executable(TestPassManager ${CMAKE_CURRENT_LIST_DIR}/source/TestPassManager.cpp)
target_link_libraries(TestPassManager MNNConvertDeps)
target_link_libraries(MNNConvert MNNConvertDeps)
Expand Down
Loading
Loading