diff --git a/onnxoptimizer/pass_registry.h b/onnxoptimizer/pass_registry.h
index 7fa910cdf..08b0ea0ab 100644
--- a/onnxoptimizer/pass_registry.h
+++ b/onnxoptimizer/pass_registry.h
@@ -34,6 +34,7 @@
 #include "onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h"
 #include "onnxoptimizer/passes/fuse_pad_into_conv.h"
 #include "onnxoptimizer/passes/fuse_transpose_into_gemm.h"
+#include "onnxoptimizer/passes/fuse_constant_reshape.h"
 #include "onnxoptimizer/passes/lift_lexical_references.h"
 #include "onnxoptimizer/passes/nop.h"
 #include "onnxoptimizer/passes/split.h"
@@ -62,6 +63,7 @@ struct GlobalPassRegistry {
     registerPass();
     registerPass();
     registerPass();
+    registerPass<FuseConstantReshape>();
     registerPass();
     registerPass();
     registerPass();
diff --git a/onnxoptimizer/passes/fuse_constant_reshape.h b/onnxoptimizer/passes/fuse_constant_reshape.h
new file mode 100644
index 000000000..fec887783
--- /dev/null
+++ b/onnxoptimizer/passes/fuse_constant_reshape.h
@@ -0,0 +1,122 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+// Before:
+//   B = Reshape(Constant)
+// After:
+//   B = Constant (the same constant, reshaped in place)
+
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "onnx/defs/tensor_util.h"
+#include "onnxoptimizer/pass.h"
+
+namespace ONNX_NAMESPACE {
+namespace optimization {
+
+struct FuseConstantReshape final : public PredicateBasedPass {
+  explicit FuseConstantReshape()
+      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
+                           PassOptimizationType::Compute) {}
+  std::string getPassName() const override {
+    return "fuse_constant_reshape";
+  }
+
+  bool patternMatchPredicate(Node* node) override {
+    return node->kind() == kReshape &&
+           node->inputs()[0]->node()->kind() == kConstant;
+  }
+
+  bool runTransform(Node* n, Graph& graph,
+                    NodeDestroyType& destroy_current) override {
+    destroy_current = NodeDestroyType::DestroyZero;
+
+    // Only fuse when the Constant is consumed solely by this Reshape.
+    if (n->inputs()[0]->uses().size() > 1) {
+      return false;
+    }
+
+    Node* reshape = n;
+    Node* constant = n->inputs()[0]->node();
+
+    // Extract the target shape.
+    std::vector<int64_t> shape;
+    if (reshape->hasAttribute(kshape)) {
+      // opset 5 and below: shape is an attribute
+      shape = reshape->is(kshape);
+    } else {
+      // opset 6 and above: shape is the second input and must itself be a
+      // Constant that is not used elsewhere
+      if (reshape->inputs()[1]->node()->kind() != kConstant) {
+        return false;
+      }
+      if (reshape->inputs()[1]->uses().size() > 1) {
+        return false;
+      }
+      Node* shape_const = reshape->inputs()[1]->node();
+      Tensor t = shape_const->t(kvalue);
+      shape = ParseData<int64_t>(&t);
+    }
+
+    int64_t allow_zero = 0;
+    const Symbol sym = Symbol("allowzero");
+    if (reshape->hasAttribute(sym)) {
+      allow_zero = reshape->i(sym);
+    }
+
+    Tensor t = constant->t(kvalue);
+    const auto& ori_size = t.sizes();
+
+    // Process 0 in shape: with allowzero == 0 (the default), a zero copies
+    // the corresponding dimension from the input tensor.
+    if (allow_zero == 0) {
+      for (size_t i = 0; i < shape.size(); ++i) {
+        if (shape[i] == 0) {
+          if (ori_size.size() <= i) {
+            // illegal situation: no input dimension to copy from
+            return false;
+          }
+          shape[i] = ori_size[i];
+        }
+      }
+    }
+
+    // Process -1 in shape: at most one dimension may be inferred.
+    int count_of_unknown = 0;
+    int index_of_unknown = -1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (shape[i] == -1) {
+        count_of_unknown += 1;
+        index_of_unknown = static_cast<int>(i);
+      }
+    }
+    if (count_of_unknown > 1) {
+      // illegal situation: more than one -1
+      return false;
+    }
+    if (index_of_unknown >= 0) {
+      const int64_t numel =
+          std::accumulate(ori_size.begin(), ori_size.end(), int64_t{1},
+                          std::multiplies<int64_t>());
+      // The running product over `shape` still contains the -1, so negating
+      // the quotient recovers the inferred dimension.
+      const int64_t known =
+          std::accumulate(shape.begin(), shape.end(), int64_t{1},
+                          std::multiplies<int64_t>());
+      if (known == 0) {
+        // a literal zero dimension cannot be combined with -1
+        return false;
+      }
+      shape[index_of_unknown] = -1 * numel / known;
+    }
+
+    // Rewrite the constant tensor's shape in place.
+    t.sizes().clear();
+    t.sizes().insert(t.sizes().begin(), shape.begin(), shape.end());
+    constant->t_(kvalue, std::move(t));
+
+    // Update the Constant's output value info and bypass the Reshape.
+    constant->output()->setSizes(reshape->output()->sizes());
+    constant->output()->setElemType(reshape->output()->elemType());
+    const bool replacing_success =
+        tryReplacingAllUsesWith(reshape->output(), reshape->inputs()[0]);
+    if (!replacing_success) {
+      return false;
+    }
+    destroy_current = NodeDestroyType::DestroyOne;
+    return true;
+  }
+};
+
+}  // namespace optimization
+}  // namespace ONNX_NAMESPACE
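For reference, the standalone sketch below (not part of the patch) spells out the shape-resolution rules the pass applies before rewriting the constant: with allowzero == 0 a 0 in the target shape copies the corresponding input dimension, and a single -1 is inferred from the remaining element count. The helper name ResolveShape and the example dimensions are illustrative only.

// Minimal sketch of the Reshape target-shape resolution used by the pass.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

static bool ResolveShape(const std::vector<int64_t>& input_dims,
                         std::vector<int64_t> shape, int allowzero,
                         std::vector<int64_t>* out) {
  const int64_t numel =
      std::accumulate(input_dims.begin(), input_dims.end(), int64_t{1},
                      std::multiplies<int64_t>());
  int unknown = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0 && allowzero == 0) {
      if (i >= input_dims.size()) return false;  // no dimension to copy
      shape[i] = input_dims[i];
    } else if (shape[i] == -1) {
      if (unknown != -1) return false;  // at most one -1 is allowed
      unknown = static_cast<int>(i);
    }
  }
  if (unknown != -1) {
    int64_t known = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (static_cast<int>(i) != unknown) known *= shape[i];
    }
    if (known == 0 || numel % known != 0) return false;  // cannot infer
    shape[unknown] = numel / known;
  }
  *out = shape;
  return true;
}

int main() {
  std::vector<int64_t> resolved;
  // A 2x3x4 constant reshaped with shape = {0, -1}: 0 copies dim 2, -1 -> 12.
  if (ResolveShape({2, 3, 4}, {0, -1}, /*allowzero=*/0, &resolved)) {
    for (int64_t d : resolved) std::cout << d << ' ';  // prints "2 12"
    std::cout << '\n';
  }
  return 0;
}

Run on a 2x3x4 constant with target shape {0, -1}, this prints "2 12", which is the shape the pass would store into the constant tensor's sizes().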
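A minimal driver, assuming the Optimize entry point declared in onnxoptimizer/optimize.h, could exercise only the new pass by name; the file paths below are placeholders and this is a sketch, not part of the patch.

// Hypothetical usage sketch: run only fuse_constant_reshape on a model file.
#include <fstream>
#include <string>
#include <vector>

#include "onnx/onnx_pb.h"
#include "onnxoptimizer/optimize.h"

int main() {
  ONNX_NAMESPACE::ModelProto model;
  std::ifstream in("model.onnx", std::ios::binary);  // placeholder path
  model.ParseFromIstream(&in);

  // Pass names are the strings returned by getPassName().
  const std::vector<std::string> passes{"fuse_constant_reshape"};
  ONNX_NAMESPACE::ModelProto optimized =
      ONNX_NAMESPACE::optimization::Optimize(model, passes);

  std::ofstream out("model_opt.onnx", std::ios::binary);  // placeholder path
  optimized.SerializeToOstream(&out);
  return 0;
}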