Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 61d56b8

Browse files
authored
Make INTERPRETER fully support dynamic (#4944)
* revert to old interpreter * Cleanup * more cleanup * cleanup * checkpoint * checkpoint * Builds and runs * Passing tests * All tests pass * Change to kick off build * Add all ops to interpreter * Enable the option for dynamic support on interpreter * Fix clamp API * Fix compile error * Fix mlir * Update manifest * Update manifest * Add support for dynamic clamp * checkpoint * Fix merge mistakes * Fix clamp * Cleanup debug * binary elementwise working * dynamic_abc working * Fix resetting output tensor shapes * Move resetting of output tensor partial shapes until after validated * Reset output Tensor partial shapes in INTERPRETER call method * cleanup * checkpoint * Revert testing * add hyperbolic functions * Add support for v1::Reshape * Add support for ReduceMax * Add support for compute_output_shape to reduction ops * Cleanup some dynamic helper functions * Add reduction ops * Cleanup * Add a reset() method to reset a tensor originally created dynamic back to its original state * cleanup * Revert set output change * Revert change * Revert change * Update manifest * Add function to compute binary elementwise output size * More ops made dynamic * Fix migration error * cleanup * Add support for v1::Transpose * Checkpoint * Set manifest * Fix strided slice set_output * Fix some set types calls when output shape is unknown * Fix dynslice * Remove obsolete method * Remove obsolete method * re-enable op downgrading to get some unsupported op unit tests to pass
1 parent b33c316 commit 61d56b8

32 files changed

+879
-503
lines changed

src/ngraph/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ set (SRC
486486
op/util/arithmetic_reductions_keep_dims.hpp
487487
op/util/attr_types.cpp
488488
op/util/attr_types.hpp
489+
op/util/binary_elementwise.cpp
490+
op/util/binary_elementwise.hpp
489491
op/util/binary_elementwise_arithmetic.cpp
490492
op/util/binary_elementwise_arithmetic.hpp
491493
op/util/binary_elementwise_comparison.cpp

src/ngraph/op/concat.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ void op::v0::Concat::validate_and_infer_types()
9898
}
9999
else
100100
{
101-
set_output_type(0, inputs_et, PartialShape::dynamic(concatenation_axis_output_dim));
101+
set_output_type(0, inputs_et, PartialShape::dynamic());
102102
}
103103
}
104104

src/ngraph/op/experimental/dyn_slice.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void op::v0::DynSlice::validate_and_infer_types()
107107
}
108108
else
109109
{
110-
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
110+
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
111111
}
112112
}
113113

src/ngraph/op/experimental/dyn_slice.hpp

-3
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ namespace ngraph
7070
const OutputVector& deltas) override;
7171

7272
private:
73-
/// Helper method to compute output shape
74-
Shape compute_output_shape() const;
75-
7673
AxisSet m_lower_bounds_mask;
7774
AxisSet m_upper_bounds_mask;
7875
AxisSet m_new_axis;

src/ngraph/op/reshape.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ void op::v1::Reshape::validate_and_infer_types()
369369
}
370370
else
371371
{
372-
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
372+
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
373373
}
374374
}
375375

src/ngraph/op/strided_slice.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ void op::v1::StridedSlice::validate_and_infer_types()
197197
}
198198
else
199199
{
200-
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(data_rank));
200+
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
201201
}
202202
}
203203

src/ngraph/op/strided_slice.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ namespace ngraph
105105
bool evaluate(const HostTensorVector& outputs,
106106
const HostTensorVector& inputs) const override;
107107

108+
AxisSet convert_mask_to_axis_set(const std::vector<int64_t>& mask) const;
109+
108110
protected:
109111
void generate_adjoints(autodiff::Adjoints& adjoints,
110112
const OutputVector& deltas) override;
111113

112114
private:
113-
AxisSet convert_mask_to_axis_set(const std::vector<int64_t>& mask) const;
114-
115115
std::vector<int64_t> m_begin_mask;
116116
std::vector<int64_t> m_end_mask;
117117
std::vector<int64_t> m_new_axis_mask;

src/ngraph/op/util/arithmetic_reductions_keep_dims.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,23 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
9797
ArithmeticReduction::validate_and_infer_types();
9898
}
9999
}
100+
101+
Shape op::util::ArithmeticReductionKeepDims::compute_output_shape(
102+
const Shape& input_shape, const AxisSet& reduction_axes) const
103+
{
104+
Shape output_shape;
105+
size_t index = 0;
106+
for (auto dim : input_shape)
107+
{
108+
if (reduction_axes.find(index) == reduction_axes.end())
109+
{
110+
output_shape.push_back(dim);
111+
}
112+
else if (m_keep_dims)
113+
{
114+
output_shape.push_back(1);
115+
}
116+
index++;
117+
}
118+
return output_shape;
119+
}

src/ngraph/op/util/arithmetic_reductions_keep_dims.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ namespace ngraph
4747
bool get_keep_dims() const { return m_keep_dims; }
4848
void set_keep_dims(bool keep_dims) { m_keep_dims = keep_dims; }
4949

50+
Shape compute_output_shape(const Shape& input_shape,
51+
const AxisSet& reduction_axes) const;
52+
5053
private:
5154
bool m_keep_dims = false;
5255
};
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2020 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#include "ngraph/op/util/binary_elementwise.hpp"
18+
19+
using namespace ngraph;
20+
using namespace std;
21+
22+
op::util::BinaryElementwise::BinaryElementwise() {}
23+
24+
op::util::BinaryElementwise::BinaryElementwise(const Output<Node>& arg0,
25+
const Output<Node>& arg1,
26+
const AutoBroadcastSpec& autob)
27+
: Op({arg0, arg1})
28+
, m_autob(autob)
29+
{
30+
}
31+
32+
bool op::util::BinaryElementwise::visit_attributes(AttributeVisitor& visitor)
33+
{
34+
visitor.on_attribute("auto_broadcast", m_autob);
35+
return true;
36+
}
37+
38+
Shape op::util::BinaryElementwise::compute_output_shape(const Shape& arg0_shape,
39+
const Shape& arg1_shape) const
40+
{
41+
PartialShape arg0_partial_shape = arg0_shape;
42+
if (m_autob.m_type == op::AutoBroadcastType::NONE)
43+
{
44+
NGRAPH_CHECK(PartialShape::merge_into(arg0_partial_shape, arg1_shape),
45+
"Argument shapes are inconsistent.");
46+
}
47+
else if (m_autob.m_type == op::AutoBroadcastType::NUMPY ||
48+
m_autob.m_type == op::AutoBroadcastType::PDPD)
49+
{
50+
NGRAPH_CHECK(PartialShape::broadcast_merge_into(arg0_partial_shape, arg1_shape, m_autob),
51+
"Argument shapes are inconsistent.");
52+
}
53+
else
54+
{
55+
NGRAPH_CHECK(false, "Unsupported auto broadcast specification");
56+
}
57+
return arg0_partial_shape.get_shape();
58+
}
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2020 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#pragma once
18+
19+
#include "ngraph/op/op.hpp"
20+
#include "ngraph/op/util/attr_types.hpp"
21+
22+
namespace ngraph
23+
{
24+
namespace op
25+
{
26+
namespace util
27+
{
28+
class NGRAPH_API BinaryElementwise : public Op
29+
{
30+
protected:
31+
BinaryElementwise();
32+
33+
/// \brief Constructs a binary elementwise operation.
34+
///
35+
/// \param arg0 Output that produces the first input tensor.
36+
/// \param arg1 Output that produces the second input tensor.
37+
BinaryElementwise(const Output<Node>& arg0,
38+
const Output<Node>& arg1,
39+
const AutoBroadcastSpec& autob);
40+
41+
public:
42+
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
43+
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
44+
bool supports_auto_broadcast() const override { return true; }
45+
bool visit_attributes(AttributeVisitor& visitor) override;
46+
Shape compute_output_shape(const Shape& a, const Shape& b) const;
47+
48+
protected:
49+
AutoBroadcastSpec m_autob;
50+
};
51+
}
52+
}
53+
}

src/ngraph/op/util/binary_elementwise_arithmetic.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@
1616

1717
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
1818
#include "ngraph/attribute_visitor.hpp"
19+
#include "ngraph/log.hpp"
1920

2021
using namespace std;
2122
using namespace ngraph;
2223

2324
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob)
24-
: m_autob(autob)
25+
: BinaryElementwise()
2526
{
2627
}
2728

2829
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Output<Node>& arg0,
2930
const Output<Node>& arg1,
3031
const AutoBroadcastSpec& autob)
31-
: Op({arg0, arg1})
32-
, m_autob(autob)
32+
: BinaryElementwise(arg0, arg1, autob)
3333
{
3434
}
3535

@@ -40,6 +40,5 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
4040

4141
bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
4242
{
43-
visitor.on_attribute("auto_broadcast", m_autob);
44-
return true;
43+
return BinaryElementwise::visit_attributes(visitor);
4544
}

src/ngraph/op/util/binary_elementwise_arithmetic.hpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
#pragma once
1818

19-
#include "ngraph/op/op.hpp"
20-
#include "ngraph/op/util/attr_types.hpp"
19+
#include "ngraph/op/util/binary_elementwise.hpp"
2120

2221
namespace ngraph
2322
{
@@ -51,7 +50,7 @@ namespace ngraph
5150
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
5251
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors (after auto broadcasting). |
5352
// clang-format on
54-
class NGRAPH_API BinaryElementwiseArithmetic : public Op
53+
class NGRAPH_API BinaryElementwiseArithmetic : public BinaryElementwise
5554
{
5655
protected:
5756
BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob);
@@ -66,15 +65,8 @@ namespace ngraph
6665

6766
public:
6867
void validate_and_infer_types() override;
69-
70-
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
71-
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
7268
bool is_binary_elementwise_arithmetic() const override { return true; }
73-
bool supports_auto_broadcast() const override { return true; }
7469
bool visit_attributes(AttributeVisitor& visitor) override;
75-
76-
private:
77-
AutoBroadcastSpec m_autob;
7870
};
7971
}
8072
}

src/ngraph/op/util/binary_elementwise_comparison.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison() {}
2525
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Output<Node>& arg0,
2626
const Output<Node>& arg1,
2727
const AutoBroadcastSpec& autob)
28-
: Op({arg0, arg1})
29-
, m_autob(autob)
28+
: BinaryElementwise(arg0, arg1, autob)
3029
{
3130
}
3231

@@ -40,6 +39,5 @@ void op::util::BinaryElementwiseComparison::validate_and_infer_types()
4039

4140
bool op::util::BinaryElementwiseComparison::visit_attributes(AttributeVisitor& visitor)
4241
{
43-
visitor.on_attribute("auto_broadcast", m_autob);
44-
return true;
42+
return BinaryElementwise::visit_attributes(visitor);
4543
}

src/ngraph/op/util/binary_elementwise_comparison.hpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
#pragma once
1818

19-
#include "ngraph/op/op.hpp"
20-
#include "ngraph/op/util/attr_types.hpp"
19+
#include "ngraph/op/util/binary_elementwise.hpp"
2120

2221
namespace ngraph
2322
{
@@ -51,7 +50,7 @@ namespace ngraph
5150
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
5251
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
5352
// clang-format on
54-
class NGRAPH_API BinaryElementwiseComparison : public Op
53+
class NGRAPH_API BinaryElementwiseComparison : public BinaryElementwise
5554
{
5655
protected:
5756
/// \brief Constructs a binary elementwise comparison operation.
@@ -68,15 +67,8 @@ namespace ngraph
6867

6968
public:
7069
void validate_and_infer_types() override;
71-
72-
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
73-
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
74-
bool supports_auto_broadcast() const override { return true; }
7570
bool is_binary_elementwise_comparison() const override { return true; }
7671
bool visit_attributes(AttributeVisitor& visitor) override;
77-
78-
private:
79-
AutoBroadcastSpec m_autob;
8072
};
8173
}
8274
}

src/ngraph/op/util/binary_elementwise_logical.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ op::util::BinaryElementwiseLogical::BinaryElementwiseLogical() {}
2525
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<Node>& arg0,
2626
const Output<Node>& arg1,
2727
const AutoBroadcastSpec& autob)
28-
: Op({arg0, arg1})
29-
, m_autob(autob)
28+
: BinaryElementwise(arg0, arg1, autob)
3029
{
3130
}
3231

@@ -37,6 +36,5 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types()
3736

3837
bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor)
3938
{
40-
visitor.on_attribute("auto_broadcast", m_autob);
41-
return true;
39+
return BinaryElementwise::visit_attributes(visitor);
4240
}

src/ngraph/op/util/binary_elementwise_logical.hpp

+2-9
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#pragma once
1818

19-
#include "ngraph/op/op.hpp"
19+
#include "ngraph/op/util/binary_elementwise.hpp"
2020

2121
namespace ngraph
2222
{
@@ -50,7 +50,7 @@ namespace ngraph
5050
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
5151
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
5252
// clang-format on
53-
class NGRAPH_API BinaryElementwiseLogical : public Op
53+
class NGRAPH_API BinaryElementwiseLogical : public BinaryElementwise
5454
{
5555
protected:
5656
BinaryElementwiseLogical();
@@ -65,15 +65,8 @@ namespace ngraph
6565

6666
public:
6767
void validate_and_infer_types() override;
68-
69-
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
70-
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
71-
bool supports_auto_broadcast() const override { return true; }
7268
bool is_binary_elementwise_logical() const override { return true; }
7369
bool visit_attributes(AttributeVisitor& visitor) override;
74-
75-
private:
76-
AutoBroadcastSpec m_autob;
7770
};
7871
}
7972
}

src/ngraph/runtime/cpu/unit_test.manifest

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ CPU.onnx_dyn_shapes_flatten_neg_axis
3131

3232
CPU.onnx_dyn_shapes_flatten_axis
3333
CPU.onnx_dyn_shapes_flatten_neg_axis
34+
CPU.onnx_dyn_model_hardmax
3435

3536
# Need use evaluate, only applicable to INTERPRETER
3637
CPU.non_zero

0 commit comments

Comments
 (0)