Skip to content

Commit aeb4df6

Browse files
authored
[CINN] Clear OpInferSymbolicShapeCache when lowering subgraph (#72248)
1 parent eaa90a9 commit aeb4df6

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ CreateGroupShapeOrDataExprs(
234234

235235
local_shape_analysis.RegisterSymbolConstraintFromShapeAnalysis(
236236
global_shape_analysis);
237+
local_shape_analysis.ClearOpInferSymbolicShapeCache();
237238
for (const auto& item : dim_expr_map) {
238239
local_shape_analysis.AddEqualCstr(item.first, item.second);
239240
}

paddle/pir/include/dialect/shape/utils/shape_analysis.h

+6
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ class IR_API InferSymbolicShapeContext {
136136
std::optional<InferSymbolicShapeCacheValue> GetOpInferSymbolicShapeCache(
137137
const InferSymbolicShapeCacheKey& op_infer_cache_key) const;
138138

139+
void ClearOpInferSymbolicShapeCache();
140+
139141
const symbol::ConstraintsManager& constraints_manager() const {
140142
return constraints_manager_;
141143
}
@@ -250,6 +252,10 @@ class IR_API ShapeConstraintIRAnalysis final
250252
return context_.constraints_manager();
251253
}
252254

255+
void ClearOpInferSymbolicShapeCache() {
256+
context_.ClearOpInferSymbolicShapeCache();
257+
}
258+
253259
void SetInputDynamicDimSpec(
254260
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec);
255261

paddle/pir/src/dialect/shape/utils/shape_analysis.cc

+4
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,10 @@ InferSymbolicShapeContext::GetOpInferSymbolicShapeCache(
480480
return std::nullopt;
481481
}
482482

483+
void InferSymbolicShapeContext::ClearOpInferSymbolicShapeCache() {
484+
infer_symbolic_shape_cache_.clear();
485+
}
486+
483487
bool InferSymbolicShapeContext::HasPredefinedDimExprForInputName(
484488
const std::string& input_name) const {
485489
return predefined_dimexpr_map_for_inputs_.count(input_name) != 0;

0 commit comments

Comments
 (0)