Skip to content

Commit 9ce6378

Browse files
authored
add .clang-format and apply (onnx#34)
Signed-off-by: daquexian <[email protected]>
1 parent 3d4d0d1 commit 9ce6378

31 files changed

+377
-405
lines changed

.clang-format

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
BasedOnStyle: Google
2+
AllowShortBlocksOnASingleLine: false
3+
AllowShortCaseLabelsOnASingleLine: false
4+
AllowShortFunctionsOnASingleLine: Empty
5+
AllowShortLoopsOnASingleLine: false
6+
AllowShortIfStatementsOnASingleLine: false
7+

examples/onnx_optimizer_exec.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
#include <onnxoptimizer/optimize.h>
66

7-
#include <onnx/onnx_pb.h>
87
#include <onnx/checker.h>
8+
#include <onnx/onnx_pb.h>
99

1010
#include <fstream>
1111

@@ -19,9 +19,7 @@ int main(int argc, char **argv) {
1919
}
2020
onnx::checker::check_model(model);
2121
const auto new_model = onnx::optimization::Optimize(
22-
model,
23-
onnx::optimization::GetFuseAndEliminationPass()
24-
);
22+
model, onnx::optimization::GetFuseAndEliminationPass());
2523
onnx::checker::check_model(new_model);
2624
std::ofstream ofs(argv[2]);
2725
success = new_model.SerializePartialToOstream(&ofs);

onnxoptimizer/optimize.h

+7-9
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct Optimizer {
2424
static GlobalPassRegistry passes;
2525

2626
public:
27-
Optimizer(const std::vector<std::string>& names, const bool fixed_point);
27+
Optimizer(const std::vector<std::string> &names, const bool fixed_point);
2828
~Optimizer();
2929

3030
ModelProto optimize(const ModelProto &mp_in) {
@@ -88,12 +88,10 @@ const std::vector<std::string> GetAvailablePasses();
8888

8989
const std::vector<std::string> GetFuseAndEliminationPass();
9090

91-
ModelProto Optimize(
92-
const ModelProto& mp_in,
93-
const std::vector<std::string>& names);
91+
ModelProto Optimize(const ModelProto &mp_in,
92+
const std::vector<std::string> &names);
9493

95-
ModelProto OptimizeFixed(
96-
const ModelProto& mp_in,
97-
const std::vector<std::string>& names);
98-
} // namespace optimization
99-
} // namespace ONNX_NAMESPACE
94+
ModelProto OptimizeFixed(const ModelProto &mp_in,
95+
const std::vector<std::string> &names);
96+
} // namespace optimization
97+
} // namespace ONNX_NAMESPACE

onnxoptimizer/pass.h

+31-40
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,8 @@ class Pass {
8787
PassOptimizationType pass_optimization_type;
8888

8989
public:
90-
Pass(
91-
PassType pass_type,
92-
PassEfficiency pass_efficiency,
93-
PassOptimizationType pass_optimization_type);
90+
Pass(PassType pass_type, PassEfficiency pass_efficiency,
91+
PassOptimizationType pass_optimization_type);
9492
virtual ~Pass();
9593

9694
PassType getPassType() const {
@@ -105,34 +103,30 @@ class Pass {
105103
virtual PassAnalysisType getPassAnalysisType() const = 0;
106104
virtual std::string getPassName() const = 0;
107105

108-
virtual bool initializePass(Graph&) {
106+
virtual bool initializePass(Graph &) {
109107
return false;
110108
}
111-
virtual bool finalizePass(Graph&) {
109+
virtual bool finalizePass(Graph &) {
112110
return false;
113111
}
114-
virtual std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) = 0;
112+
virtual std::shared_ptr<PostPassAnalysis> runPass(Graph &graph) = 0;
115113

116114
protected:
117115
// Iterates through the elements in the graph and counts the number of times
118116
// the transform is successfully run.
119117
unsigned int DescendOnGraphAttributesAndCount(
120-
Node* n,
121-
std::function<unsigned int(Graph&)> fn);
118+
Node *n, std::function<unsigned int(Graph &)> fn);
122119
// A more general version of the function above that doesn't constrain the
123120
// return type of fn.
124-
void DescendOnGraphAttributesUnconstrained(
125-
Node* n,
126-
std::function<void(Graph&)> fn);
121+
void DescendOnGraphAttributesUnconstrained(Node *n,
122+
std::function<void(Graph &)> fn);
127123
};
128124

129125
class ImmutablePass : Pass {
130126
public:
131127
explicit ImmutablePass()
132-
: Pass(
133-
PassType::Immutable,
134-
PassEfficiency::Complete,
135-
PassOptimizationType::None) {}
128+
: Pass(PassType::Immutable, PassEfficiency::Complete,
129+
PassOptimizationType::None) {}
136130
~ImmutablePass() override;
137131
};
138132

@@ -143,17 +137,16 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
143137
// but this complicates the memory model. Also since all passes come from
144138
// GlobalPassRegistry which already utilizes smart pointers we don't have to
145139
// worry about memory leaks from passes.
146-
Pass* pass;
140+
Pass *pass;
147141
unsigned int num_positive_transforms;
148142
bool initialization_done;
149143
bool finalization_done;
150144

151145
public:
152-
explicit CountBasedPassAnalysis(
153-
Pass* pass,
154-
unsigned int num_positive_transforms,
155-
bool initialization_done,
156-
bool finalization_done);
146+
explicit CountBasedPassAnalysis(Pass *pass,
147+
unsigned int num_positive_transforms,
148+
bool initialization_done,
149+
bool finalization_done);
157150

158151
bool graphChanged() {
159152
return this->num_positive_transforms > 0;
@@ -165,7 +158,7 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
165158
// Whether or not a repeated application of the pass might be useful.
166159
bool fixedPointOptimizationNeeded() {
167160
return this->graphChanged() &&
168-
pass->getPassEfficiency() == PassEfficiency::Partial;
161+
pass->getPassEfficiency() == PassEfficiency::Partial;
169162
}
170163
};
171164

@@ -177,29 +170,28 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
177170
// patternMatchPredicate.
178171
class PredicateBasedPass : public Pass {
179172
public:
180-
explicit PredicateBasedPass(
181-
PassType pass_type,
182-
PassEfficiency pass_efficiency,
183-
PassOptimizationType pass_optimization_type)
173+
explicit PredicateBasedPass(PassType pass_type,
174+
PassEfficiency pass_efficiency,
175+
PassOptimizationType pass_optimization_type)
184176
: Pass(pass_type, pass_efficiency, pass_optimization_type) {}
185177
~PredicateBasedPass() override;
186178

187-
virtual bool patternMatchPredicate(Node* node) = 0;
179+
virtual bool patternMatchPredicate(Node *node) = 0;
188180
// Run transform is given the current node in the iterator, a reference to the
189181
// current graph as well as a reference describing how to treat the current
190182
// node in the iterator post transform. Run transform is then responsible for
191183
// running the actual transform as well as describing how to treat the
192184
// iterator node. By default the current node will not call destroy. Do not
193185
// internally delete node instead set the correct destroy_current type.
194-
virtual bool
195-
runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) = 0;
186+
virtual bool runTransform(Node *node, Graph &graph,
187+
NodeDestroyType &destroy_current) = 0;
196188

197-
std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) override;
189+
std::shared_ptr<PostPassAnalysis> runPass(Graph &graph) override;
198190
PassAnalysisType getPassAnalysisType() const override;
199191

200192
static int getOpsetVersion(const Graph &g) {
201193
// this hack is due to `opset_versions_mutable` doesn't have a const version
202-
Graph &mut_g = const_cast<Graph&>(g);
194+
Graph &mut_g = const_cast<Graph &>(g);
203195
for (const OpSetID &opset : mut_g.opset_versions_mutable()) {
204196
if (opset.domain() == "") {
205197
return opset.version();
@@ -209,16 +201,15 @@ class PredicateBasedPass : public Pass {
209201
}
210202

211203
private:
212-
unsigned int _runPassInternal(Graph& graph);
204+
unsigned int _runPassInternal(Graph &graph);
213205
};
214206

215207
// The most general pass which allows the user to run a pass given only a graph.
216208
class FullGraphBasedPass : public Pass {
217209
public:
218-
explicit FullGraphBasedPass(
219-
PassType pass_type,
220-
PassEfficiency pass_efficiency,
221-
PassOptimizationType pass_optimization_type)
210+
explicit FullGraphBasedPass(PassType pass_type,
211+
PassEfficiency pass_efficiency,
212+
PassOptimizationType pass_optimization_type)
222213
: Pass(pass_type, pass_efficiency, pass_optimization_type) {}
223214
~FullGraphBasedPass() override;
224215
};
@@ -236,7 +227,7 @@ inline bool areTwoValuesBothInputOrOutput(const Value *value1,
236227
const bool is_input =
237228
value->node()->kind() == kCaptured ||
238229
std::find(graph->inputs().rbegin(), graph->inputs().rend(), value) !=
239-
graph->inputs().rend();
230+
graph->inputs().rend();
240231
return is_output || is_input;
241232
};
242233
return IsInputOrOutput(value1) && IsInputOrOutput(value2);
@@ -264,5 +255,5 @@ inline bool tryReplacingAllUsesWith(Node *oldNode, Node *newNode) {
264255
return true;
265256
}
266257

267-
} // namespace optimization
268-
} // namespace ONNX_NAMESPACE
258+
} // namespace optimization
259+
} // namespace ONNX_NAMESPACE

onnxoptimizer/pass_manager.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ class FixedPointPassManager : public GeneralPassManager {
5050
std::shared_ptr<PassManagerAnalysis> run(Graph& graph) override;
5151
};
5252

53-
} // namespace optimization
54-
} // namespace ONNX_NAMESPACE
53+
} // namespace optimization
54+
} // namespace ONNX_NAMESPACE

onnxoptimizer/pass_registry.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ struct GlobalPassRegistry {
8484

8585
std::shared_ptr<Pass> find(std::string pass_name) {
8686
auto it = this->passes.find(pass_name);
87-
ONNX_ASSERTM(
88-
it != this->passes.end(), "pass %s is unknown.", pass_name.c_str());
87+
ONNX_ASSERTM(it != this->passes.end(), "pass %s is unknown.",
88+
pass_name.c_str());
8989
return it->second;
9090
}
9191
const std::vector<std::string> GetAvailablePasses();
@@ -99,5 +99,5 @@ struct GlobalPassRegistry {
9999
passes[pass->getPassName()] = pass;
100100
}
101101
};
102-
} // namespace optimization
103-
} // namespace ONNX_NAMESPACE
102+
} // namespace optimization
103+
} // namespace ONNX_NAMESPACE

onnxoptimizer/passes/eliminate_deadend.h

+4-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
* SPDX-License-Identifier: Apache-2.0
33
*/
44

5-
65
// ATTENTION: The code in this file is highly EXPERIMENTAL.
76
// Adventurous users should note that the APIs will probably change.
87
#pragma once
@@ -11,10 +10,8 @@ namespace ONNX_NAMESPACE {
1110
namespace optimization {
1211
struct EliminateDeadEnd final : public FullGraphBasedPass {
1312
explicit EliminateDeadEnd()
14-
: FullGraphBasedPass(
15-
PassType::Nop,
16-
PassEfficiency::Complete,
17-
PassOptimizationType::Compute) {}
13+
: FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete,
14+
PassOptimizationType::Compute) {}
1815
std::string getPassName() const override {
1916
return "eliminate_deadend";
2017
}
@@ -39,5 +36,5 @@ struct EliminateDeadEnd final : public FullGraphBasedPass {
3936
new CountBasedPassAnalysis(this, nodes_removed, false, false));
4037
}
4138
};
42-
} // namespace optimization
43-
} // namespace ONNX_NAMESPACE
39+
} // namespace optimization
40+
} // namespace ONNX_NAMESPACE

onnxoptimizer/passes/eliminate_duplicate_initializer.h

+30-30
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
// and removes the added initializers from inputs after
2424
// optimization. That makes us cannot distinguish the
2525
// initializers really in the inputs and the initializers
26-
// not in the inputs. While only the latter can be eliminated,
26+
// not in the inputs. While only the latter can be eliminated,
2727
// we eliminate all duplicated initializers instead. That
2828
// may cause unexpected behavior in some rare cases.
2929

@@ -103,40 +103,40 @@ struct EliminateDuplicateInitializer final : public FullGraphBasedPass {
103103
Tensor i_tensor = *iter_i_initializer;
104104
Value *i_value = input_map.find(i_tensor.name())->second;
105105

106-
#define DO_COMPARISON(data_type) \
107-
const std::vector<data_type> i_data = ParseData<data_type>(&i_tensor); \
108-
for (auto iter_j = iter_i + 1; iter_j != pair.second.end(); ++iter_j) { \
109-
const auto iter_j_initializer = graph.getInitializer(*iter_j); \
110-
if (iter_j_initializer == graph.initializers().end()) { \
111-
visited.insert(*iter_j); \
112-
continue; \
113-
} \
114-
Tensor j_tensor = *iter_j_initializer; \
115-
if (i_tensor.elem_type() != j_tensor.elem_type()) { \
116-
continue; \
117-
} else { \
118-
const std::vector<data_type> j_data = ParseData<data_type>(&j_tensor); \
119-
if (std::equal(i_data.begin(), i_data.end(), j_data.begin())) { \
120-
visited.insert(*iter_j); \
121-
Value *j_value = input_map.find(j_tensor.name())->second; \
122-
j_value->replaceAllUsesWith(i_value); \
123-
graph.eraseInitializerAndInput(j_value); \
124-
initializers_removed++; \
125-
} \
126-
} \
106+
#define DO_COMPARISON(data_type) \
107+
const std::vector<data_type> i_data = ParseData<data_type>(&i_tensor); \
108+
for (auto iter_j = iter_i + 1; iter_j != pair.second.end(); ++iter_j) { \
109+
const auto iter_j_initializer = graph.getInitializer(*iter_j); \
110+
if (iter_j_initializer == graph.initializers().end()) { \
111+
visited.insert(*iter_j); \
112+
continue; \
113+
} \
114+
Tensor j_tensor = *iter_j_initializer; \
115+
if (i_tensor.elem_type() != j_tensor.elem_type()) { \
116+
continue; \
117+
} else { \
118+
const std::vector<data_type> j_data = ParseData<data_type>(&j_tensor); \
119+
if (std::equal(i_data.begin(), i_data.end(), j_data.begin())) { \
120+
visited.insert(*iter_j); \
121+
Value *j_value = input_map.find(j_tensor.name())->second; \
122+
j_value->replaceAllUsesWith(i_value); \
123+
graph.eraseInitializerAndInput(j_value); \
124+
initializers_removed++; \
125+
} \
126+
} \
127127
}
128-
#define CASE_DO_COMPARISON(ONNX_DTYPE_SUFFIX, CPP_DTYPE) \
129-
case ONNX_NAMESPACE::TensorProto_DataType_##ONNX_DTYPE_SUFFIX: { \
130-
DO_COMPARISON(CPP_DTYPE) \
131-
break; \
128+
#define CASE_DO_COMPARISON(ONNX_DTYPE_SUFFIX, CPP_DTYPE) \
129+
case ONNX_NAMESPACE::TensorProto_DataType_##ONNX_DTYPE_SUFFIX: { \
130+
DO_COMPARISON(CPP_DTYPE) \
131+
break; \
132132
}
133133
switch (i_tensor.elem_type()) {
134134
CASE_DO_COMPARISON(FLOAT, float)
135135
CASE_DO_COMPARISON(DOUBLE, double)
136136
CASE_DO_COMPARISON(INT32, int32_t)
137137
CASE_DO_COMPARISON(INT64, int64_t)
138-
default:
139-
break;
138+
default:
139+
break;
140140
}
141141
#undef CASE_DO_COMPARISON
142142
#undef DO_COMPARISON
@@ -150,5 +150,5 @@ struct EliminateDuplicateInitializer final : public FullGraphBasedPass {
150150
new CountBasedPassAnalysis(this, initializers_removed, false, false));
151151
}
152152
};
153-
} // namespace optimization
154-
} // namespace ONNX_NAMESPACE
153+
} // namespace optimization
154+
} // namespace ONNX_NAMESPACE

onnxoptimizer/passes/eliminate_identity.h

+11-11
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ namespace optimization {
1414

1515
struct EliminateIdentity final : public PredicateBasedPass {
1616
explicit EliminateIdentity()
17-
: PredicateBasedPass(
18-
PassType::Nop,
19-
PassEfficiency::Complete,
20-
PassOptimizationType::Compute) {}
17+
: PredicateBasedPass(PassType::Nop, PassEfficiency::Complete,
18+
PassOptimizationType::Compute) {}
2119

2220
std::string getPassName() const override {
2321
return "eliminate_identity";
@@ -26,15 +24,17 @@ struct EliminateIdentity final : public PredicateBasedPass {
2624
bool patternMatchPredicate(Node* node) override {
2725
return node->kind() == kIdentity;
2826
}
29-
bool runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current)
30-
override {
31-
32-
const bool replacing_success = tryReplacingAllUsesWith(node->output(), node->input());
33-
if (!replacing_success) { return false; }
27+
bool runTransform(Node* node, Graph& graph,
28+
NodeDestroyType& destroy_current) override {
29+
const bool replacing_success =
30+
tryReplacingAllUsesWith(node->output(), node->input());
31+
if (!replacing_success) {
32+
return false;
33+
}
3434
destroy_current = NodeDestroyType::DestroyOne;
3535
return true;
3636
}
3737
};
3838

39-
} // namespace optimization
40-
} // namespace ONNX_NAMESPACE
39+
} // namespace optimization
40+
} // namespace ONNX_NAMESPACE

0 commit comments

Comments
 (0)