@@ -87,10 +87,8 @@ class Pass {
87
87
PassOptimizationType pass_optimization_type;
88
88
89
89
public:
90
- Pass (
91
- PassType pass_type,
92
- PassEfficiency pass_efficiency,
93
- PassOptimizationType pass_optimization_type);
90
+ Pass (PassType pass_type, PassEfficiency pass_efficiency,
91
+ PassOptimizationType pass_optimization_type);
94
92
virtual ~Pass ();
95
93
96
94
PassType getPassType () const {
@@ -105,34 +103,30 @@ class Pass {
105
103
virtual PassAnalysisType getPassAnalysisType () const = 0;
106
104
virtual std::string getPassName () const = 0;
107
105
108
- virtual bool initializePass (Graph&) {
106
+ virtual bool initializePass (Graph &) {
109
107
return false ;
110
108
}
111
- virtual bool finalizePass (Graph&) {
109
+ virtual bool finalizePass (Graph &) {
112
110
return false ;
113
111
}
114
- virtual std::shared_ptr<PostPassAnalysis> runPass (Graph& graph) = 0;
112
+ virtual std::shared_ptr<PostPassAnalysis> runPass (Graph & graph) = 0;
115
113
116
114
protected:
117
115
// Iterates through the elements in the graph and counts the number of times
118
116
// the transform is successfully run.
119
117
unsigned int DescendOnGraphAttributesAndCount (
120
- Node* n,
121
- std::function<unsigned int (Graph&)> fn);
118
+ Node *n, std::function<unsigned int (Graph &)> fn);
122
119
// A more general version of the function above that doesn't constrain the
123
120
// return type of fn.
124
- void DescendOnGraphAttributesUnconstrained (
125
- Node* n,
126
- std::function<void (Graph&)> fn);
121
+ void DescendOnGraphAttributesUnconstrained (Node *n,
122
+ std::function<void (Graph &)> fn);
127
123
};
128
124
129
125
class ImmutablePass : Pass {
130
126
public:
131
127
explicit ImmutablePass ()
132
- : Pass(
133
- PassType::Immutable,
134
- PassEfficiency::Complete,
135
- PassOptimizationType::None) {}
128
+ : Pass(PassType::Immutable, PassEfficiency::Complete,
129
+ PassOptimizationType::None) {}
136
130
~ImmutablePass () override ;
137
131
};
138
132
@@ -143,17 +137,16 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
143
137
// but this complicates the memory model. Also since all passes come from
144
138
// GlobalPassRegistry which already utilizes smart pointers we don't have to
145
139
// worry about memory leaks from passes.
146
- Pass* pass;
140
+ Pass * pass;
147
141
unsigned int num_positive_transforms;
148
142
bool initialization_done;
149
143
bool finalization_done;
150
144
151
145
public:
152
- explicit CountBasedPassAnalysis (
153
- Pass* pass,
154
- unsigned int num_positive_transforms,
155
- bool initialization_done,
156
- bool finalization_done);
146
+ explicit CountBasedPassAnalysis (Pass *pass,
147
+ unsigned int num_positive_transforms,
148
+ bool initialization_done,
149
+ bool finalization_done);
157
150
158
151
bool graphChanged () {
159
152
return this ->num_positive_transforms > 0 ;
@@ -165,7 +158,7 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
165
158
// Whether or not a repeated application of the pass might be useful.
166
159
bool fixedPointOptimizationNeeded () {
167
160
return this ->graphChanged () &&
168
- pass->getPassEfficiency () == PassEfficiency::Partial;
161
+ pass->getPassEfficiency () == PassEfficiency::Partial;
169
162
}
170
163
};
171
164
@@ -177,29 +170,28 @@ struct CountBasedPassAnalysis : PostPassAnalysis {
177
170
// patternMatchPredicate.
178
171
class PredicateBasedPass : public Pass {
179
172
public:
180
- explicit PredicateBasedPass (
181
- PassType pass_type,
182
- PassEfficiency pass_efficiency,
183
- PassOptimizationType pass_optimization_type)
173
+ explicit PredicateBasedPass (PassType pass_type,
174
+ PassEfficiency pass_efficiency,
175
+ PassOptimizationType pass_optimization_type)
184
176
: Pass(pass_type, pass_efficiency, pass_optimization_type) {}
185
177
~PredicateBasedPass () override ;
186
178
187
- virtual bool patternMatchPredicate (Node* node) = 0;
179
+ virtual bool patternMatchPredicate (Node * node) = 0;
188
180
// Run transform is given the current node in the iterator, a reference to the
189
181
// current graph as well as a reference describing how to treat the current
190
182
// node in the iterator post transform. Run transform is then responsible for
191
183
// running the actual transform as well as describing how to treat the
192
184
// iterator node. By default the current node will not call destroy. Do not
193
185
// internally delete node instead set the correct destroy_current type.
194
- virtual bool
195
- runTransform (Node* node, Graph& graph, NodeDestroyType& destroy_current) = 0 ;
186
+ virtual bool runTransform (Node *node, Graph &graph,
187
+ NodeDestroyType & destroy_current) = 0;
196
188
197
- std::shared_ptr<PostPassAnalysis> runPass (Graph& graph) override ;
189
+ std::shared_ptr<PostPassAnalysis> runPass (Graph & graph) override ;
198
190
PassAnalysisType getPassAnalysisType () const override ;
199
191
200
192
static int getOpsetVersion (const Graph &g) {
201
193
// this hack is due to `opset_versions_mutable` doesn't have a const version
202
- Graph &mut_g = const_cast <Graph&>(g);
194
+ Graph &mut_g = const_cast <Graph &>(g);
203
195
for (const OpSetID &opset : mut_g.opset_versions_mutable ()) {
204
196
if (opset.domain () == " " ) {
205
197
return opset.version ();
@@ -209,16 +201,15 @@ class PredicateBasedPass : public Pass {
209
201
}
210
202
211
203
private:
212
- unsigned int _runPassInternal (Graph& graph);
204
+ unsigned int _runPassInternal (Graph & graph);
213
205
};
214
206
215
207
// The most general pass which allows the user to run a pass given only a graph.
216
208
class FullGraphBasedPass : public Pass {
217
209
public:
218
- explicit FullGraphBasedPass (
219
- PassType pass_type,
220
- PassEfficiency pass_efficiency,
221
- PassOptimizationType pass_optimization_type)
210
+ explicit FullGraphBasedPass (PassType pass_type,
211
+ PassEfficiency pass_efficiency,
212
+ PassOptimizationType pass_optimization_type)
222
213
: Pass(pass_type, pass_efficiency, pass_optimization_type) {}
223
214
~FullGraphBasedPass () override ;
224
215
};
@@ -236,7 +227,7 @@ inline bool areTwoValuesBothInputOrOutput(const Value *value1,
236
227
const bool is_input =
237
228
value->node ()->kind () == kCaptured ||
238
229
std::find (graph->inputs ().rbegin (), graph->inputs ().rend (), value) !=
239
- graph->inputs ().rend ();
230
+ graph->inputs ().rend ();
240
231
return is_output || is_input;
241
232
};
242
233
return IsInputOrOutput (value1) && IsInputOrOutput (value2);
@@ -264,5 +255,5 @@ inline bool tryReplacingAllUsesWith(Node *oldNode, Node *newNode) {
264
255
return true ;
265
256
}
266
257
267
- } // namespace optimization
268
- } // namespace ONNX_NAMESPACE
258
+ } // namespace optimization
259
+ } // namespace ONNX_NAMESPACE
0 commit comments