@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- // clang-format off
16
15
#include " paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
17
16
18
17
#include < algorithm>
@@ -31,7 +30,6 @@ limitations under the License. */
31
30
#include " paddle/fluid/framework/convert_utils.h"
32
31
#include " paddle/fluid/platform/enforce.h"
33
32
#include " paddle/fluid/platform/errors.h"
34
- // clang-format on
35
33
36
34
namespace paddle {
37
35
namespace framework {
@@ -79,7 +77,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
79
77
for (auto & feed_pair : input_tensors_) {
80
78
const auto & feed_name = feed_pair.first ;
81
79
const auto * tensor = feed_pair.second ;
82
- PADDLE_ENFORCE_NE (tensor, nullptr ,
80
+ PADDLE_ENFORCE_NE (tensor,
81
+ nullptr ,
83
82
platform::errors::PreconditionNotMet (
84
83
" The input variable %s's tensor cannot be NULL,"
85
84
" we need the variable's dtype and shape from tensor." ,
@@ -96,7 +95,8 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
96
95
}
97
96
98
97
PADDLE_ENFORCE_NE (
99
- feed_map[feed_name].shape .size (), 0UL ,
98
+ feed_map[feed_name].shape .size (),
99
+ 0UL ,
100
100
platform::errors::PreconditionNotMet (
101
101
" The input variable %s's tensor shape cannot be empty,"
102
102
" we need the variable's dtype and shape from tensor." ,
@@ -136,7 +136,8 @@ CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) {
136
136
137
137
for (const auto & param_name : parameter_names) {
138
138
PADDLE_ENFORCE_GT (
139
- feed_map.count (param_name), 0UL ,
139
+ feed_map.count (param_name),
140
+ 0UL ,
140
141
platform::errors::NotFound (" Cannot find parameter %s from input list,"
141
142
" please add the tensor into input." ,
142
143
param_name.c_str ()));
@@ -162,12 +163,12 @@ CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) {
162
163
163
164
std::vector<Node*> CinnGraphSymbolization::TopologicalSort () const {
164
165
std::unordered_set<Node*> op_nodes;
165
- std::for_each (graph_. Nodes (). begin (), graph_. Nodes (). end (),
166
- [&op_nodes](Node* n) {
167
- if (n->IsOp ()) {
168
- op_nodes.emplace (n);
169
- }
170
- });
166
+ std::for_each (
167
+ graph_. Nodes (). begin (), graph_. Nodes (). end (), [&op_nodes](Node* n) {
168
+ if (n->IsOp ()) {
169
+ op_nodes.emplace (n);
170
+ }
171
+ });
171
172
172
173
std::unordered_map<Node*, std::unordered_map<Node*, size_t >> adj_list;
173
174
std::unordered_map<Node*, size_t > in_degrees;
@@ -210,7 +211,8 @@ std::vector<Node*> CinnGraphSymbolization::TopologicalSort() const {
210
211
}
211
212
}
212
213
213
- PADDLE_ENFORCE_EQ (sorted_ops.size (), op_nodes.size (),
214
+ PADDLE_ENFORCE_EQ (sorted_ops.size (),
215
+ op_nodes.size (),
214
216
platform::errors::PreconditionNotMet (
215
217
" The sorting graph contains cycles." ));
216
218
return sorted_ops;
@@ -234,7 +236,8 @@ void CinnGraphSymbolization::RunOp(const CinnOpDesc& op_desc,
234
236
const OpMapperContext& ctx) const {
235
237
const auto & op_type = op_desc.Type ();
236
238
auto * kernel = ::cinn::frontend::OpMapperRegistry::Global ()->Find (op_type);
237
- PADDLE_ENFORCE_NE (kernel, nullptr ,
239
+ PADDLE_ENFORCE_NE (kernel,
240
+ nullptr ,
238
241
platform::errors::NotFound (
239
242
" Op %s is Not Supported by CINN, please register"
240
243
" this op in the CINN repo." ,
@@ -256,10 +259,12 @@ std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
256
259
std::unordered_set<std::string> fetch_names;
257
260
fetch_names.reserve (fetch_var_names_.size ());
258
261
std::for_each (
259
- fetch_var_names_.begin (), fetch_var_names_.end (),
262
+ fetch_var_names_.begin (),
263
+ fetch_var_names_.end (),
260
264
[this , &fetch_names](const std::string& name) {
261
265
PADDLE_ENFORCE_EQ (
262
- var_model_to_program_map_.count (name), 1 ,
266
+ var_model_to_program_map_.count (name),
267
+ 1 ,
263
268
platform::errors::PreconditionNotMet (
264
269
" Cannot find %s in var_model_to_program_map_" , name.c_str ()));
265
270
fetch_names.insert (var_model_to_program_map_.at (name));
@@ -276,8 +281,12 @@ ::cinn::frontend::Program CinnGraphSymbolization::operator()() {
276
281
auto feed_map = GetFeedInfoMapFromInput ();
277
282
auto cinn_scope = CreateCinnScope (feed_map);
278
283
279
- OpMapperContext ctx (*cinn_scope, target_, &builder, &var_map_,
280
- &var_model_to_program_map_, &fetch_var_names_);
284
+ OpMapperContext ctx (*cinn_scope,
285
+ target_,
286
+ &builder,
287
+ &var_map_,
288
+ &var_model_to_program_map_,
289
+ &fetch_var_names_);
281
290
// add all tensor's feed info into context
282
291
for (auto & feed_pair : feed_map) {
283
292
ctx.AddFeedInfo (feed_pair.first , feed_pair.second );
0 commit comments