Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent self mapping in the AlmostExact graph #3926

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
55 changes: 40 additions & 15 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,22 @@ void mapThroughLoopSwizzles(ValGraph& graph) {

} // namespace

void IdModel::assertNoSelfMapping() {
const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT);
void IdModel::assertNoSelfMapping(const ValGraph& graph) const {
for (TensorView* tv : tvs_) {
std::optional<SelfMapping> self_mapping = hasSelfMapping(tv, exact_graph);
std::optional<SelfMapping> self_mapping = hasSelfMapping(tv, graph);
if (self_mapping.has_value()) {
tv->fusion()->print();
}
NVF_CHECK(
!self_mapping.has_value(),
"Unsupported domain mapping detected in ",
tv,
tv->toString(),
". ",
self_mapping->where,
" domains, ",
self_mapping->id1,
self_mapping->id1->toString(),
" and ",
self_mapping->id2,
self_mapping->id2->toString(),
", are mapped with each other.");
}
}
Expand Down Expand Up @@ -413,6 +415,12 @@ ValGraph& IdModel::buildExactGraph() {

graph.validateConsistency();

// Make sure there's no self mapping in the Exact graph as that
// would invalidate lowering assumptions.
if (!allow_self_mapping_) {
assertNoSelfMapping(graph);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this check to buildExactGraph

}

if (isOptionEnabled(EnableOption::IdModelExtraValidation)) {
checkStaticExtentGroups(graph);
}
Expand Down Expand Up @@ -495,6 +503,25 @@ ValGraph& IdModel::buildAlmostExactGraph() {

auto& almost_exact_graph = idGraph(IdMappingMode::ALMOSTEXACT);

for (TensorView* tv : tvs_) {
if (tv->hasRoot()) {
for (auto id : tv->getRootDomain()) {
almost_exact_graph.setUnmappable(
{tv->getRootDomain().begin(), tv->getRootDomain().end()});
}
}
for (auto id : tv->getLogicalDomain()) {
forbidden_pairs[id].insert(
tv->getLogicalDomain().begin(), tv->getLogicalDomain().end());
}
for (auto id : tv->getLoopDomain()) {
forbidden_pairs[id].insert(
tv->getLoopDomain().begin(), tv->getLoopDomain().end());
}
}

almost_exact_graph.do_not_map_vals_ = forbidden_pairs;

// Maps the iter domain pairs returned by calling isTrivialExpr on
// every expression in the graph.

Expand All @@ -514,7 +541,6 @@ ValGraph& IdModel::buildAlmostExactGraph() {
// Map through trivial expressions
for (auto mapped_id_group : mapped_ids) {
for (auto id : mapped_id_group) {
// almost_exact_graph.mapVals(mapped_id_group.front(), id);
ids_to_map.emplace_back(mapped_id_group.front(), id);
}
}
Expand All @@ -527,6 +553,13 @@ ValGraph& IdModel::buildAlmostExactGraph() {

almost_exact_graph.validateConsistency();

// Even when EXACT has no self mapping, there was a case where
// ALMOSTEXACT had self mapping (see issue #3919). ALMOSTEXACT is used
// in indexing, which assumes the graph has no self mapping.
if (!allow_self_mapping_) {
assertNoSelfMapping(almost_exact_graph);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this check for AlmostExact

}

if (isOptionEnabled(EnableOption::IdModelExtraValidation)) {
checkStaticExtentGroups(almost_exact_graph);
}
Expand Down Expand Up @@ -851,15 +884,7 @@ void IdModel::buildAllGraphs() {
validator->checkExactGraphEquivalence(idGraph(IdMappingMode::EXACT));
}

// Make sure there's no self mapping in the Exact graph as that
// would invalidate lowering assumptions.
if (!allow_self_mapping_) {
assertNoSelfMapping();
}

buildAlmostExactGraph();
// Skip validating the almost exact graph as the IdModel graph also
// maps non-size-one broadcast domains

buildPermissiveGraph();
// Validation is not implemented when compliment mapping is enabled
Expand Down
6 changes: 3 additions & 3 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class IdModel : public PolymorphicBase {
const std::vector<Expr*>& exprs,
const std::vector<TensorView*>& additional_tvs = {},
bool build_graphs = false,
bool allow_self_mapping = false,
bool allow_self_mapping = true,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to not check self mapping by default since that restriction only matters for indexing.

LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback =
nullptr);

Expand All @@ -129,7 +129,7 @@ class IdModel : public PolymorphicBase {
IdModel(
Fusion* fusion,
bool build_graphs = false,
bool allow_self_mapping = false,
bool allow_self_mapping = true,
bool validate = false,
LoopPromotionMapBuilderCallback* loop_promotion_map_builder_callback =
nullptr);
Expand Down Expand Up @@ -286,7 +286,7 @@ class IdModel : public PolymorphicBase {
const StatefulInliningInfo& info);

// Errors if self mapping occurs
void assertNoSelfMapping();
void assertNoSelfMapping(const ValGraph& graph) const;

// Loop graph represents the loop structure of the given fusion, so
// there must not be any mapping between the loop domains of each
Expand Down
2 changes: 1 addition & 1 deletion csrc/preseg_passes/exact_mapped_extent_substitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void exactMappedExtentSubstitution(Fusion* fusion) {
std::unordered_map<Val*, Val*> replacement_map;

// Build the exact graph
IdModel id_model(fusion, false, false, false);
IdModel id_model(fusion);
id_model.buildExactGraph();
const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
const DisjointSets<Val*>& id_sets = exact_graph.disjointValSets();
Expand Down
37 changes: 37 additions & 0 deletions csrc/val_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,10 @@ void ValGraph::mapVals(Val* val0, Val* val1) {
return;
}

if (areUnmappable(val0, val1)) {
return;
}

// Definitions and uses are based on the groups of id0 and id1, don't merge
// them into a single group until we grab all definitions and uses for later
// processing.
Expand Down Expand Up @@ -565,6 +569,39 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
return true;
}

// Registers val0 and val1 as a pair that must never be mapped together.
// The restriction is stored in both directions so that it can be
// looked up starting from either Val.
void ValGraph::setUnmappable(Val* val0, Val* val1) {
  auto& blocked_for_val0 = unmappable_vals_[val0];
  blocked_for_val0.insert(val1);

  auto& blocked_for_val1 = unmappable_vals_[val1];
  blocked_for_val1.insert(val0);
}

// Registers every distinct pair of Vals in the given list as
// unmappable with each other. Lists with fewer than two Vals have no
// pairs and are a no-op.
void ValGraph::setUnmappable(const std::vector<Val*>& vals) {
  // Iterate i over the full size rather than size() - 1: size() is
  // unsigned, so subtracting one from an empty list would wrap around
  // to a huge value. For the last i, the inner range [i + 1, size) is
  // simply empty.
  const auto num_vals = vals.size();
  for (const auto i : c10::irange(num_vals)) {
    for (const auto j : c10::irange(i + 1, num_vals)) {
      setUnmappable(vals.at(i), vals.at(j));
    }
  }
}

// Returns true if any member of val0's group has been registered as
// unmappable with any member of val1's group.
bool ValGraph::areUnmappable(Val* val0, Val* val1) const {
  const ValGroup& group_a = toGroup(val0);
  const ValGroup& group_b = toGroup(val1);

  for (Val* member_a : *group_a) {
    const auto entry = unmappable_vals_.find(member_a);
    if (entry != unmappable_vals_.end()) {
      // member_a has restrictions; see if any of them hits group_b
      const auto& blocked = entry->second;
      for (Val* member_b : *group_b) {
        if (blocked.find(member_b) != blocked.end()) {
          return true;
        }
      }
    }
  }

  return false;
}

void ValGraph::validateConsistency() const {
// Check the consistency of the mapping information. Specifically:
// 1. All ValGroup and ExprGroup sets are not empty. This may not be
Expand Down
12 changes: 12 additions & 0 deletions csrc/val_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,12 @@ class ValGraph {
return false;
}

// Mark val0 and val1 as Vals that should not be mapped with each other
void setUnmappable(Val* val0, Val* val1);

// Mark every pair of Vals in the given list as Vals that should not be mapped with each other
void setUnmappable(const std::vector<Val*>& vals);

private:
// Map expr0 and expr1 with each other, update unique_definitions_
// unique_uses_
Expand All @@ -340,6 +346,9 @@ class ValGraph {
// Returns true if expressions were mapped through.
bool mapThroughExpr(Expr* first, Expr* second, bool forward);

// Check if val0 and val1 are marked as unmappable
bool areUnmappable(Val* val0, Val* val1) const;

private:
// If propagate_through_exprs_ = false, then mapThroughExpr will not be called
// as a consequence of calling mapVals. As well as mapThroughExpr will not be
Expand All @@ -361,6 +370,9 @@ class ValGraph {
std::unordered_map<ValGroup, ExprGroups> unique_definitions_;

std::unordered_map<ValGroup, ExprGroups> unique_uses_;

// Mapping of a Val to a set of Vals that should not be mapped with it
std::unordered_map<Val*, std::unordered_set<Val*>> unmappable_vals_;
};

struct ValGroupAndItsGraph {
Expand Down
41 changes: 40 additions & 1 deletion tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ TEST_F(IdModelTest, DetectSelfMapping) {
fusion.addOutput(tv2);

EXPECT_THAT(
[&]() { IdModel id_model(&fusion, /*build_graphs=*/true); },
[&]() {
IdModel id_model(
&fusion, /*build_graphs=*/true, /*allow_self_mapping=*/false);
},
::testing::ThrowsMessage<nvfuser::nvfError>(
::testing::HasSubstr("are mapped with each other")));
}
Expand Down Expand Up @@ -2102,6 +2105,42 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) {
// loop_graph.toGroup(tv10->axis(1)), loop_graph.toGroup(tv10->axis(2)));
}

// When two broadcast IDs are merged, the two input IDs and the
// output ID can all be considered trivially mapped. However, doing so
// could cause self mappings in a loop domain, which violates the
// assumption of TensorIndexer. (For example, in this test case, tv1's
// loop domain has two padded IDs of extent 3. If the merge of tv0
// triggers mappings of the two broadcast IDs of tv0, the two root IDs
// of tv1 would be mapped too in the AlmostExact graph, which then
// means the two logical IDs of tv1 would also be mapped.) This should
// be fixed by avoiding mappings that could result in self mapping.
//
// This is also a repro of issue #3919.
// Verifies that building the AlmostExact graph does not map the two
// logical IDs of a padded tensor with each other after the producer's
// two size-one IDs are merged.
TEST_F(IdModelTest, SelfMappingInAlmostExactGraph) {
Fusion fusion;
FusionGuard fg(&fusion);

// Fusion input of shape [1, 1]
auto tv0 = makeConcreteTensor({1, 1});
fusion.addInput(tv0);

// Padded output of shape [3, 3]; every pad width is the constant one
auto tv1 =
pad(tv0,
{fusion.oneVal(), fusion.oneVal(), fusion.oneVal(), fusion.oneVal()});

fusion.addOutput(tv1);

// Merge the two IDs of tv0. This is the merge that could trigger the
// unwanted mappings in the AlmostExact graph.
tv0->merge(0);

// Build only the AlmostExact graph and inspect it directly
IdModel id_model(&fusion);
const auto& almost_exact = id_model.buildAlmostExactGraph();
// The two logical IDs of tv1 must stay in separate groups
EXPECT_FALSE(almost_exact.disjointValSets().strictAreMapped(
tv1->getLogicalDomain().at(0), tv1->getLogicalDomain().at(1)))
<< "Should not be mapped: " << tv1->getLogicalDomain().at(0)->toString()
<< ", " << tv1->getLogicalDomain().at(1)->toString();
}

namespace {
bool iterDomainsAreMapped(
const IdModel& id_model,
Expand Down
Loading