Prevent self mapping in the AlmostExact graph #3926

Open: wants to merge 11 commits into main

naoyam (Collaborator) commented Feb 20, 2025

Fixes #3919

The root cause of #3919 is a tensor whose loop domain has a self mapping. Specifically:

```
// [1, 1]
tv0: [b1, b2]
// [3, 3]
tv1 = pad(tv0, {1, 1, 1, 1});

tv0->merge(0, 1);
```

In this case, because the merge of tv0 involves only broadcast IDs, both input IDs and the output ID are mapped together in the AlmostExact graph (see here). The same happens in ComputeAtMap. As a result, b1 and b2 are mapped, which in turn maps the two logical IDs of tv1, since they are produced by identical resize ops applied to mapped inputs.
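To make the mapping chain concrete, here is a small self-contained sketch. It uses a toy union-find over integer IDs rather than nvFuser's ValGraph, and a single forward-propagation pass over two pad (resize) ops. The names `ToyDisjointSets`, `ToyResize`, `propagateThroughResizes`, and `tv1HasSelfMapping` are hypothetical. It assumes, as described above, that a broadcast-only merge maps both of its inputs with its output, and that two resize ops with mapped inputs and equal pad widths get their outputs mapped.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Toy union-find over integer IDs; a stand-in for the disjoint sets inside
// a ValGraph, not the nvFuser API.
struct ToyDisjointSets {
  std::vector<int> parent;

  explicit ToyDisjointSets(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }

  int find(int x) {
    assert(x >= 0 && x < static_cast<int>(parent.size()));
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];  // path halving
      x = parent[x];
    }
    return x;
  }

  void unite(int a, int b) { parent[find(a)] = find(b); }
  bool mapped(int a, int b) { return find(a) == find(b); }
};

// A 1-D resize (pad) op: input ID, output ID, and left/right pad widths.
struct ToyResize {
  int in;
  int out;
  int left;
  int right;
};

// Single forward pass: if two resize ops have mapped inputs and equal pad
// widths, map their outputs.
void propagateThroughResizes(ToyDisjointSets& sets,
                             const std::vector<ToyResize>& ops) {
  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (std::size_t j = i + 1; j < ops.size(); ++j) {
      const ToyResize& a = ops[i];
      const ToyResize& b = ops[j];
      if (sets.mapped(a.in, b.in) && a.left == b.left && a.right == b.right) {
        sets.unite(a.out, b.out);
      }
    }
  }
}

// IDs: b1, b2 (tv0 logical), m (output of tv0->merge(0, 1)),
// i1, i2 (tv1 logical, produced by padding b1 and b2).
enum ToyId { B1 = 0, B2, M, I1, I2, kNumIds };

// Returns true if tv1's two logical IDs end up in the same group.
bool tv1HasSelfMapping(bool map_broadcast_merge) {
  ToyDisjointSets sets(kNumIds);
  if (map_broadcast_merge) {
    // Broadcast-only merge: both inputs are mapped with the output,
    // so b1 and b2 become mapped transitively.
    sets.unite(B1, M);
    sets.unite(B2, M);
  }
  // pad(tv0, {1, 1, 1, 1}) pads each dimension by 1 on both sides.
  propagateThroughResizes(sets, {{B1, I1, 1, 1}, {B2, I2, 1, 1}});
  return sets.mapped(I1, I2);
}
```

`tv1HasSelfMapping(true)` returns true because the two padded IDs end up in one group, while `tv1HasSelfMapping(false)` returns false.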

While mapping b1 and b2 may have no actual effect since they are just broadcast IDs, mapping the logical IDs of tv1 is problematic because they are concrete IDs. In #3919, this manifests as an error in the contiguity analysis, which assumes there is no self mapping.

The fix in this PR simply avoids mapping b1 and b2. More specifically, I extended ValGraph so that certain Vals can be excluded from being mapped with each other. In the AlmostExact graph, two IDs that would otherwise be trivially mapped are no longer mapped if doing so could result in self mapping. This may not sound ideal since they do have the same extent, but until the self mapping limitation is addressed, I think it is a reasonable workaround.
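The constraint can be sketched with a toy graph that keeps explicit groups plus a symmetric "unmappable" relation, checked at group granularity in the same spirit as `ValGraph::areUnmappable`. This is an illustrative model under simplifying assumptions (integer IDs, no expression propagation); `ToyGraph` and its methods are hypothetical and are not nvFuser's API.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <vector>

// Toy graph with explicit groups and a symmetric "unmappable" relation,
// loosely modeled on ValGraph::setUnmappable / areUnmappable / mapVals.
class ToyGraph {
 public:
  explicit ToyGraph(int n) : group_of_(n) {
    assert(n > 0);
    for (int i = 0; i < n; ++i) {
      group_of_[i] = i;  // every val starts in its own group
      groups_[i] = {i};
    }
  }

  // Mark every pair (i, j), i < j, as unmappable; safe for empty input.
  void setUnmappable(const std::vector<int>& vals) {
    for (std::size_t i = 0; i < vals.size(); ++i) {
      for (std::size_t j = i + 1; j < vals.size(); ++j) {
        unmappable_[vals[i]].insert(vals[j]);
        unmappable_[vals[j]].insert(vals[i]);
      }
    }
  }

  // True if any member of a's group is unmappable with any member of b's.
  bool areUnmappable(int a, int b) const {
    for (int va : groups_.at(group_of_[a])) {
      auto it = unmappable_.find(va);
      if (it == unmappable_.end()) {
        continue;
      }
      for (int vb : groups_.at(group_of_[b])) {
        if (it->second.count(vb) != 0) {
          return true;
        }
      }
    }
    return false;
  }

  // Merge the groups of a and b unless that would violate the constraint.
  void mapVals(int a, int b) {
    const int ga = group_of_[a];
    const int gb = group_of_[b];
    if (ga == gb || areUnmappable(a, b)) {
      return;
    }
    for (int v : groups_[gb]) {
      group_of_[v] = ga;
      groups_[ga].insert(v);
    }
    groups_.erase(gb);
  }

  bool mapped(int a, int b) const { return group_of_[a] == group_of_[b]; }

 private:
  std::vector<int> group_of_;                 // val -> group id
  std::map<int, std::set<int>> groups_;       // group id -> members
  std::map<int, std::set<int>> unmappable_;   // val -> unmappable vals
};
```

With IDs 0 = b1, 1 = b2, and 2 = the merge output, and with b1 and b2 marked unmappable, mapping b1 with the merge output succeeds, but the subsequent attempt to map b2 with it is refused because the merged group already contains b1. Checking whole groups rather than only the two Vals is what keeps the constraint from being bypassed through transitivity.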

Note: I am not 100% confident that this has no negative side effects, since the AlmostExact graph now lacks some mappings between IDs that do have the same extent. However, self mapping would certainly break TensorIndexer, so I think this is still reasonable.

naoyam (Collaborator, Author) commented Feb 20, 2025

!test --diff

github-actions bot commented Feb 20, 2025

Review updated until commit 37ac5d4

Description


Changes walkthrough 📝

Relevant files

Enhancement

csrc/id_model/id_model.cpp (+34/-15): Add self mapping validation and unmappable marking
  • Modify assertNoSelfMapping to accept a graph parameter
  • Add self mapping validation in buildExactGraph and buildAlmostExactGraph
  • Mark root, logical, and loop domains as unmappable in buildAlmostExactGraph

csrc/val_graph.cpp (+37/-0): Add unmappable Vals handling
  • Add setUnmappable methods to mark Vals as unmappable
  • Add areUnmappable method to check if Vals are marked as unmappable
  • Modify mapVals to respect unmappable Vals

csrc/val_graph.h (+12/-0): Add unmappable Vals handling declarations
  • Add setUnmappable and areUnmappable method declarations
  • Add unmappable_vals_ member variable

Miscellaneous

csrc/preseg_passes/exact_mapped_extent_substitution.cpp (+1/-1): Update IdModel constructor call
  • Update IdModel constructor call to use default parameters

csrc/id_model/id_model.h (+3/-3): Update IdModel constructors and method signature
  • Update IdModel constructors to default allow_self_mapping to true
  • Modify assertNoSelfMapping method signature

Tests

tests/cpp/test_id_model.cpp (+40/-1): Update and add tests for self mapping
  • Update DetectSelfMapping test to use default allow_self_mapping
  • Add SelfMappingInAlmostExactGraph test to reproduce issue #3919 ("Test failed with Error replaying transforms in contiguous ID checker, expected iS10{9} to be in the active ID set")

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Self Mapping Check

Ensure that the self-mapping check is robust and does not introduce false positives or negatives. Verify that the logic for identifying self-mapping is correct and that it covers all relevant cases.

    void IdModel::assertNoSelfMapping(const ValGraph& graph) const {
      for (TensorView* tv : tvs_) {
        std::optional<SelfMapping> self_mapping = hasSelfMapping(tv, graph);
        if (self_mapping.has_value()) {
          tv->fusion()->print();
        }
        NVF_CHECK(
            !self_mapping.has_value(),
            "Unsupported domain mapping detected in ",
            tv->toString(),
            ". ",
            self_mapping->where,
            " domains, ",
            self_mapping->id1->toString(),
            " and ",
            self_mapping->id2->toString(),
            ", are mapped with each other.");
      }
    }
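The property being asserted can be sketched as follows: within one domain of a tensor, no two distinct IDs may belong to the same group of the graph. `findSelfMapping` is a hypothetical helper that takes the group index of each ID in a single domain; the real `hasSelfMapping` works on `ValGraph` groups and also reports which domain (`where`) contains the mapping.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Given the group index of each ID in one domain (e.g. a tensor's logical
// domain), return the positions of the first two distinct IDs that share a
// group, or std::nullopt if the domain has no self mapping.
std::optional<std::pair<std::size_t, std::size_t>> findSelfMapping(
    const std::vector<int>& group_of_each_id) {
  for (std::size_t i = 0; i < group_of_each_id.size(); ++i) {
    for (std::size_t j = i + 1; j < group_of_each_id.size(); ++j) {
      if (group_of_each_id[i] == group_of_each_id[j]) {
        return std::make_pair(i, j);
      }
    }
  }
  return std::nullopt;
}
```

For example, a domain whose IDs fall into groups `{3, 5, 3}` reports positions 0 and 2, while `{0, 1}` reports no self mapping.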
Unmappable Logic

Review the logic for marking and checking unmappable values. Ensure that the unmappable logic does not inadvertently prevent valid mappings and that it correctly identifies values that should not be mapped.

    void ValGraph::mapVals(Val* val0, Val* val1) {
      if (val0 == val1) {
        return;
      }
    
      if (disjointValSets().strictAreMapped(val0, val1)) {
        return;
      }
    
      if (areUnmappable(val0, val1)) {
        return;
      }
    
      // Definitions and uses are based on the groups of val0 and val1. Don't
      // merge them into a single group until all definitions and uses have
      // been collected for later processing.
      const ValGroup orig_val_group0 = toGroup(val0);
      const ValGroup orig_val_group1 = toGroup(val1);
    
      // Note that getDefinitions and getUses return references, which
      // will be invalidated once unique_definitions_ and unique_uses_ are
      // updated
      const ExprGroups orig_defs0 = getDefinitions(orig_val_group0);
      const ExprGroups orig_defs1 = getDefinitions(orig_val_group1);
      const ExprGroups orig_uses0 = getUses(orig_val_group0);
      const ExprGroups orig_uses1 = getUses(orig_val_group1);
    
      // Map the vals together before traversing definitions and uses, since
      // the traversal may rely on val0 and val1 already being mapped.
      disjoint_vals_.mapEntries(val0, val1);
      auto new_val_group = toGroup(val0);
    
      unique_definitions_[new_val_group] = orig_defs0.computeUnion(orig_defs1);
      unique_uses_[new_val_group] = orig_uses0.computeUnion(orig_uses1);
    
      // Propagate on uses
      if (!orig_uses0.empty() && !orig_uses1.empty()) {
        for (const ExprGroup& use_group_1 : orig_uses1) {
          NVF_ERROR(use_group_1.get() != nullptr);
          NVF_ERROR(!use_group_1->empty());
          for (const ExprGroup& use_group_0 : orig_uses0) {
            NVF_ERROR(use_group_0.get() != nullptr);
            NVF_ERROR(!use_group_0->empty());
            if (use_group_0 == use_group_1) {
              continue;
            }
            Expr* use0 = use_group_0->front();
            Expr* use1 = use_group_1->front();
            maybeMapThroughExprs(use0, use1, true);
          }
        }
      }
    
      // Propagate on definitions
      if (!orig_defs0.empty() && !orig_defs1.empty()) {
        for (const ExprGroup& def_group_1 : orig_defs1) {
          NVF_ERROR(def_group_1.get() != nullptr);
          NVF_ERROR(!def_group_1->empty());
          for (const ExprGroup& def_group_0 : orig_defs0) {
            NVF_ERROR(def_group_0.get() != nullptr);
            NVF_ERROR(!def_group_0->empty());
            if (def_group_0 == def_group_1) {
              continue;
            }
            auto def0 = def_group_0->front();
            auto def1 = def_group_1->front();
            maybeMapThroughExprs(def0, def1, false);
          }
        }
      }
    
      unique_definitions_.erase(orig_val_group0);
      unique_definitions_.erase(orig_val_group1);
      unique_uses_.erase(orig_val_group0);
      unique_uses_.erase(orig_val_group1);
    }
    
    void ValGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) {
      // By default, expressions are mapped only when everything is
      // matched, i.e., inputs, outputs and attributes are all mapped or
      // equal. When the propagation is allowed, as long as the inputs are
      // mapped and the attributes are equal, we propagate the mappings to
      // the outputs and the expressions.
      // In either case, it should always be true that when two
      // expressions are mapped, their inputs and outputs are also mapped,
      // respectively, and vice versa.
    
      if (!exprsMap(expr0, expr1, forward)) {
        return;
      }
    
      // The expr inputs (or outputs, when !forward) are mapped. If
      // propagate_through_exprs_ is true, map the exprs and outputs. If not,
      // map the exprs only when both inputs and outputs are mapped. Since
      // exprsMap makes sure inputs or outputs are mapped, only outputs or
      // inputs need to be checked.
      if (propagate_through_exprs_) {
        mapExprs(expr0, expr1);
        mapThroughExpr(expr0, expr1, forward);
      } else if (
          (forward &&
           outputGroups(toGroup(expr0)) == outputGroups(toGroup(expr1))) ||
          (!forward &&
           inputGroups(toGroup(expr0)) == inputGroups(toGroup(expr1)))) {
        mapExprs(expr0, expr1);
      }
    }
    
    void ValGraph::mapExprs(Expr* expr0, Expr* expr1) {
      if (expr0 == expr1) {
        return;
      }
    
      if (disjointExprSets().strictAreMapped(expr0, expr1)) {
        return;
      }
    
      // Note that non-reference copies are required here as they may be
      // removed by mapEntries
      const ExprGroup expr0_orig_group = toGroup(expr0);
      const ExprGroup expr1_orig_group = toGroup(expr1);
    
      disjoint_exprs_.mapEntries(expr0, expr1);
    
      const ExprGroup& expr_new_group = toGroup(expr0);
    
      // Update unique uses
      for (auto& [producer_group, use_groups] : unique_uses_) {
        if (use_groups.has(expr0_orig_group) || use_groups.has(expr1_orig_group)) {
          use_groups.erase(expr0_orig_group);
          use_groups.erase(expr1_orig_group);
          use_groups.pushBack(expr_new_group);
        }
      }
    
      // Update unique definitions
      for (auto& [consumer_group, def_groups] : unique_definitions_) {
        if (def_groups.has(expr0_orig_group) || def_groups.has(expr1_orig_group)) {
          def_groups.erase(expr0_orig_group);
          def_groups.erase(expr1_orig_group);
          def_groups.pushBack(expr_new_group);
        }
      }
    }
    
    bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
      if (first == nullptr || second == nullptr) {
        return false;
      }
    
      if (!exprsMap(first, second, forward)) {
        return false;
      }
    
      NVF_ERROR(
          propagate_through_exprs_,
          "Asked to propagate expression mappings on a graph that has propagate_through_exprs_ disabled.");
    
      const auto& first_ids = forward ? first->outputs() : first->inputs();
      const auto& second_ids = forward ? second->outputs() : second->inputs();
    
      NVF_ERROR(
          first_ids.size() == second_ids.size(),
          "This should be unreachable: if transformation expressions match, their numbers of inputs and outputs should match as well.\nHowever, found:\n",
          first->toString(),
          "\nand\n",
          second->toString());
      for (auto out_i : c10::irange(first_ids.size())) {
        mapVals(first_ids[out_i], second_ids[out_i]);
      }
    
      return true;
    }
    
    void ValGraph::setUnmappable(Val* val0, Val* val1) {
      unmappable_vals_[val0].insert(val1);
      unmappable_vals_[val1].insert(val0);
    }
    
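    void ValGraph::setUnmappable(const std::vector<Val*>& vals) {
      // Mark every pair (i, j) with i < j. Looping over the full size
      // also handles an empty vals without size_t underflow.
      for (const auto i : c10::irange(vals.size())) {
        for (const auto j : c10::irange(i + 1, vals.size())) {
          setUnmappable(vals.at(i), vals.at(j));
        }
      }
    }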
    
    bool ValGraph::areUnmappable(Val* val0, Val* val1) const {
      const ValGroup& val_group0 = toGroup(val0);
      const ValGroup& val_group1 = toGroup(val1);
    
      for (const auto v0 : *val_group0) {
        auto it = unmappable_vals_.find(v0);
        if (it == unmappable_vals_.end()) {
          continue;
        }
        const auto& unmappable_val_set = it->second;
        if (std::any_of(val_group1->begin(), val_group1->end(), [&](Val* v1) {
              return unmappable_val_set.count(v1);
            })) {
          return true;
        }
      }
    
      return false;
    }
    
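The two modes described in the `maybeMapThroughExprs` comments can be summarized as a small decision function. This is a hypothetical restatement for illustration (the enum and function names are not part of nvFuser), assuming the forward direction and that `exprsMap` succeeds exactly when the inputs are mapped and the attributes are equal.

```cpp
#include <cassert>

// Outcome of trying to map two expressions in the forward direction.
enum class ToyMapOutcome { kNothing, kMapExprsOnly, kMapExprsAndOutputs };

// Hypothetical restatement of the forward case of maybeMapThroughExprs.
ToyMapOutcome forwardMapOutcome(bool inputs_mapped,
                                bool attributes_equal,
                                bool outputs_already_mapped,
                                bool propagate_through_exprs) {
  // exprsMap fails: nothing is mapped.
  if (!inputs_mapped || !attributes_equal) {
    return ToyMapOutcome::kNothing;
  }
  // Propagation enabled: map the exprs and push the mapping to the outputs.
  if (propagate_through_exprs) {
    return ToyMapOutcome::kMapExprsAndOutputs;
  }
  // Propagation disabled: map the exprs only if the outputs already match.
  return outputs_already_mapped ? ToyMapOutcome::kMapExprsOnly
                                : ToyMapOutcome::kNothing;
}
```

With propagation enabled, mapped inputs are enough to map both the expressions and their outputs; with it disabled, the expressions are mapped only when their outputs are already mapped, so no new Val mappings are created.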
Test Coverage

Ensure that the added tests cover all scenarios where self-mapping could occur and that they effectively validate the fix. Consider adding more edge cases to the tests.

    TEST_F(IdModelTest, DetectSelfMapping) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeConcreteTensor({2, 2});
      fusion.addInput(tv0);
      auto tv1 = transpose(tv0, 0, 1);
      auto tv2 = add(tv0, tv1);
      fusion.addOutput(tv2);
    
      EXPECT_THAT(
          [&]() {
            IdModel id_model(
                &fusion, /*build_graphs=*/true, /*allow_self_mapping=*/false);
          },
          ::testing::ThrowsMessage<nvfuser::nvfError>(
              ::testing::HasSubstr("are mapped with each other")));
    }
    
    TEST_F(IdModelTest, PerTensorSelfMapping) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      TensorView* x0 = makeConcreteTensor({2, 2});
      fusion.addInput(x0);
      TensorView* x1 = makeConcreteTensor({2, 2});
      fusion.addInput(x1);
    
      TensorView* y0 = transpose(x0, 0, 1);
      y0 = add(x0, y0);
      fusion.addOutput(y0);
    
      TensorView* y1 = transpose(x1, 0, 1);
      fusion.addOutput(y1);
    
      IdModel id_model(&fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true);
      const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
      EXPECT_TRUE(hasSelfMapping(y0, exact_graph).has_value());
      EXPECT_FALSE(hasSelfMapping(y1, exact_graph).has_value());
    }
    
    namespace {
    
    // Get n-th parent expr traversing through the first input of each
    // parent
    Expr* getParentExpr(Val* val, int n) {
      for (int i = 0; i < n - 1; ++i) {
        NVF_ERROR(val->definition() != nullptr);
        val = val->definition()->input(0);
      }
      NVF_ERROR(val->definition() != nullptr);
      return val->definition();
    };
    
    IterDomain* getParentId(IterDomain* id, int n) {
      for (int i = 0; i < n; ++i) {
        NVF_ERROR(id->definition() != nullptr);
        NVF_ERROR(id->definition()->input(0)->isA<IterDomain>());
        id = id->definition()->input(0)->as<IterDomain>();
      }
      NVF_ERROR(id != nullptr);
      return id;
    };
    
    // Get the n-th descendant by traversing a sibling
    IterDomain* getChildId(IterDomain* id, int n, int sibling_idx = 0) {
      for (int i = 0; i < n; ++i) {
        NVF_ERROR(!id->uses().empty());
        NVF_ERROR(id->uses().front()->output(sibling_idx)->isA<IterDomain>());
        id = id->uses().front()->output(sibling_idx)->as<IterDomain>();
      }
      NVF_ERROR(id != nullptr);
      return id;
    };
    
    template <typename ValType>
    ValType* getValByName(const std::vector<ValType*>& vals, StmtNameType name) {
      if (auto it = std::find_if(
              vals.begin(),
              vals.end(),
              [&](auto val) { return val->name() == name; });
          it != vals.end()) {
        return *it;
      } else {
        return nullptr;
      }
    }
    
    IterDomain* getChildIdByName(IterDomain* id, StmtNameType name) {
      auto named_val = getValByName(ir_utils::consumerValsOf(id), name);
      NVF_ERROR(named_val != nullptr, "Cannot find a child ID named ", name);
      NVF_ERROR(named_val->isA<IterDomain>());
      return named_val->as<IterDomain>();
    };
    
    // Helper class to test IdModel
    class IdModelTester : public LoopPromotionMapBuilderCallback {
     public:
      // Do not automatically build the graphs
      IdModelTester(Fusion* fusion) {
        id_model = std::make_unique<IdModel>(
            fusion,
            /*build_graphs=*/false,
            /*allow_self_mapping=*/false,
            /*validate=*/true,
            /*loop_promotion_map_builder_callback=*/this);
    
        // Only build the loop graph
        id_model->buildLoopGraph(/*force_full_loop_promotion_analysis=*/true);
      }
    
      void postStep1(
          const std::unordered_map<ValGroup, IterDomain*>&
              iel_logical_resolution_map,
          const ValGraph& iel_graph) override {
        this->iel_graph = iel_graph;
        // this->iel_graph is a copy of the original IEL graph. The given
        // map is for the original graph and needs to be updated.
        s1_logical_resolution_map =
            updateValGroupIdMap(iel_logical_resolution_map, this->iel_graph);
      }
    
      void postStep2(
          const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map,
          const ValGraph& iel_graph) override {
        s2_iel_promotion_map =
            updateValGroupIdMap(iel_promotion_map, this->iel_graph);
      }
    
      void postStep3(const std::unordered_map<ValGroup, IterDomain*>&
                         loop_promotion_map) override {
        s3_loop_graph = id_model->idGraph(IdMappingMode::LOOP);
        s3_loop_promotion_map =
            updateValGroupIdMap(loop_promotion_map, s3_loop_graph);
      }
    
      void postStep4(
          const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map,
          const ValGraph& iel_graph) override {
        s4_iel_promotion_map =
            updateValGroupIdMap(iel_promotion_map, this->iel_graph);
      }
    
      void postStep5(const std::unordered_map<ValGroup, IterDomain*>&
                         loop_promotion_map) override {
        s5_loop_graph = id_model->idGraph(IdMappingMode::LOOP);
        s5_loop_promotion_map =
            updateValGroupIdMap(loop_promotion_map, s5_loop_graph);
      }
    
      void print(std::ostream& os) const {
        os << "Step 1 results:\n";
        for (const auto& [g, id] : s1_logical_resolution_map) {
          os << nvfuser::toString(g) << " -> " << id->toString() << std::endl;
        }
        os << "Step 2 results:\n";
        for (const auto& [g, id] : s2_iel_promotion_map) {
          os << nvfuser::toString(g) << " -> " << id->toString() << std::endl;
        }
        os << "Step 3 results:\n";
        for (const auto& [g, id] : s3_loop_promotion_map) {
          os << nvfuser::toString(g) << " -> " << id->toString() << std::endl;
        }
        os << "Step 4 results:\n";
        for (const auto& [g, id] : s4_iel_promotion_map) {
          os << nvfuser::toString(g) << " -> " << id->toString() << std::endl;
        }
        os << "Step 5 results:\n";
        for (const auto& [g, id] : s5_loop_promotion_map) {
          os << nvfuser::toString(g) << " -> " << id->toString() << std::endl;
        }
      }
    
      std::unique_ptr<IdModel> id_model;
      ValGraph iel_graph;
      std::unordered_map<ValGroup, IterDomain*> s1_logical_resolution_map;
      std::unordered_map<ValGroup, IterDomain*> s2_iel_promotion_map;
      ValGraph s3_loop_graph;
      std::unordered_map<ValGroup, IterDomain*> s3_loop_promotion_map;
      std::unordered_map<ValGroup, IterDomain*> s4_iel_promotion_map;
      ValGraph s5_loop_graph;
      std::unordered_map<ValGroup, IterDomain*> s5_loop_promotion_map;
    };
    
    // Test if id is resolved to an ID that is exact mapped with
    // ref_id. If ref_id is nullptr, test that id has no
    // resolution.
    void validateIELResolution(
        IterDomain* id,
        IterDomain* ref_id,
        const IdModelTester& tester,
        const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map) {
      const auto& iel_graph = tester.iel_graph;
      const auto& exact_graph = tester.id_model->idGraph(IdMappingMode::EXACT);
      const auto& loop_graph = tester.id_model->idGraph(IdMappingMode::LOOP);
    
      const auto& iel_group = iel_graph.toGroup(id);
      auto iel_promotion_map_it = iel_promotion_map.find(iel_group);
      if (ref_id != nullptr) {
        ASSERT_TRUE(iel_promotion_map_it != iel_promotion_map.end())
            << "IEL promotion not found for: " << nvfuser::toString(iel_group);
        ASSERT_FALSE(ref_id->isBroadcast());
        auto promotion_id = iel_promotion_map_it->second;
        ASSERT_TRUE(
            exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id))
            << "Unexpected promotion. "
            << "Expected: " << ref_id->toString()
            << ". Actual: " << promotion_id->toString();
        ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id))
            << "Promotion of " << id->toString()
            << " not mapped in the loop graph: " << promotion_id->toString();
      } else {
        ASSERT_TRUE(iel_promotion_map_it == iel_promotion_map.end())
            << "Promotion should not exist for: " << nvfuser::toString(iel_group)
            << ", but found: " << iel_promotion_map_it->second->toString();
      }
    }
    
    // Check if each domain gets promoted to a proper domain after the
    // Step 2 IEL propagation. It is assumed that the proper promotion is
    // the corresponding domain in the unique consumer tensor, which is
    // the case with most of the test fusions.
    void checkStep2Results(Fusion* fusion, const IdModelTester& tester) {
      const auto& iel_graph = tester.iel_graph;
      const auto& iel_promotion_map = tester.s2_iel_promotion_map;
    
      auto getPromotedDomain = [&](IterDomain* id) -> IterDomain* {
        if (auto it = iel_promotion_map.find(iel_graph.toGroup(id));
            it != iel_promotion_map.end()) {
          return it->second;
        } else {
          return nullptr;
        }
      };
    
      for (auto tv : fusion->allTvs()) {
        // If there's no broadcast or it isn't inlined, there's no
        // promotion
        if (std::none_of(
                tv->getLogicalDomain().begin(),
                tv->getLogicalDomain().end(),
                [](auto id) { return id->isBroadcast(); }) ||
            (tv->getComputeAtPosition() == 0 &&
             tv->getMaxProducerPosition() == 0)) {
          // Make sure there's no promotion of any of the IDs of this tensor
          for (auto id : tv->domain()->allIDs()) {
            auto promoted_id = getPromotedDomain(id);
            ASSERT_EQ(promoted_id, nullptr)
                << "Expected no mapping for " << id->toString()
                << " but found to be mapped to: " << promoted_id->toString();
          }
          continue;
        }
    
        auto consumers = ir_utils::consumerTvsOf(tv);
        ASSERT_EQ(consumers.size(), 1) << "Assumed to have one consumer";
        TensorView* c_tv = consumers.at(0);
        const auto p2c = BestEffortReplay::replayCasP(
                             c_tv, tv, -1, PairwiseLogicalDomainMap(tv, c_tv))
                             .getReplay();
    
        for (auto p_id : tv->domain()->allIDs()) {
          // Root domains are already done at Step 1
          if (std::find(
                  tv->getLogicalDomain().begin(),
                  tv->getLogicalDomain().end(),
                  p_id) != tv->getLogicalDomain().end()) {
            continue;
          }
    
          // If no broadcast is involved, nothing should be promoted
          auto p_id_dep_vals = DependencyCheck::getAllValsBetween(
              {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
              {p_id});
          if (std::find_if(
                  p_id_dep_vals.begin(), p_id_dep_vals.end(), [](Val* dep_id) {
                    return dep_id->as<IterDomain>()->isBroadcast();
                  }) == p_id_dep_vals.end()) {
            auto promoted_id = getPromotedDomain(p_id);
            ASSERT_EQ(promoted_id, nullptr)
                << "Expected no mapping for " << p_id->toString()
                << " but found to be mapped to: " << promoted_id->toString();
            continue;
          }
    
          // p_id should be promoted to c_id
          auto c_id = p2c.at(p_id);
          validateIELResolution(p_id, c_id, tester, iel_promotion_map);
        }
      }
    }
    
    // Validate the loop promotion map at Step 3. This validation ensures
    // the promotion map is exactly the same as a given reference
    // map. Since a valid promotion map may not be unique, exact
    // equality is not strictly required. However, as long as everything
    // is done deterministically, the resulting map should always be the
    // same, so requiring exact equality also helps ensure determinism.
    void checkStep3Results(
        const IdModelTester& tester,
        const std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>&
            ref_promotion_map) {
      const auto& loop_graph = tester.s3_loop_graph;
      const auto& loop_promotion_map = tester.s3_loop_promotion_map;
    
      for (const auto& loop_group : loop_graph.disjointValSets().disjointSets()) {
        auto promotion_it = loop_promotion_map.find(loop_group);
        ASSERT_NE(promotion_it, loop_promotion_map.end())
            << "No promotion found for: " << nvfuser::toString(loop_group);
        IterDomain* promotion_id = promotion_it->second;
    
        auto ref_promotion_it = std::find_if(
            ref_promotion_map.begin(),
            ref_promotion_map.end(),
            [&](const auto& ref_promotion) {
              return ref_promotion.first == loop_group->set();
            });
    
        // Self promotion omitted in the reference
        if (ref_promotion_it == ref_promotion_map.end()) {
          ASSERT_EQ(loop_group->size(), 1);
          ASSERT_EQ(loop_group->front(), promotion_id)
              << "Expected promotion: " << loop_group->front()->toString()
              << ". Actual: " << promotion_id->toString();
          continue;
        }
    
        auto ref_promotion_id = ref_promotion_it->second;
        ASSERT_EQ(promotion_id, ref_promotion_id)
            << "Expected promotion: " << ref_promotion_id->toString()
            << ". Actual: " << promotion_id->toString();
    
        ASSERT_EQ(loop_graph.toGroup(promotion_id), loop_group)
            << "Loop group promoted to a non-mapped domain. Loop group: "
            << nvfuser::toString(loop_group)
            << ". Promotion: " << promotion_id->toString();
      }
    }
    
    void checkStep4Results(
        const IdModelTester& tester,
        const std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>&
            ref_promotion_map) {
      const auto& iel_promotion_map = tester.s4_iel_promotion_map;
    
      EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size())
          << "Mismatched Step-4 result map. "
          << "Expected to have " << ref_promotion_map.size()
          << " mappings but found " << iel_promotion_map.size();
    
      for (const auto& ref_promotion_pair : ref_promotion_map) {
        const auto& ref_promotion_group = ref_promotion_pair.first;
        const auto& ref_promotion_id = ref_promotion_pair.second;
    
        auto iel_promotion_it = std::find_if(
            iel_promotion_map.begin(),
            iel_promotion_map.end(),
            [&](const auto& iel_promotion) {
              return iel_promotion.first->set() == ref_promotion_group;
            });
    
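        ASSERT_NE(iel_promotion_it, iel_promotion_map.end())
            << "No Step-4 promotion found for a reference group. Expected: "
            << ref_promotion_id->toString();
        auto iel_promotion_id = iel_promotion_it->second;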
        EXPECT_EQ(ref_promotion_id, iel_promotion_id)
            << "Expected promotion: " << ref_promotion_id->toString()
            << ". Actual: " << iel_promotion_id->toString();
      }
    }
    
    void checkStep5Results(
        const IdModelTester& tester,
        const std::unordered_map<TensorView*, std::vector<IterDomain*>>&
            ref_promotion_map) {
      const auto& loop_graph = tester.s5_loop_graph;
      const auto& loop_promotion_map = tester.s5_loop_promotion_map;
    
      // Record if each entry of ref_promotion_map is found
      std::vector<bool> ref_promotion_map_found(ref_promotion_map.size(), false);
    
      for (const auto& [tv, ref_promotion_domains] : ref_promotion_map) {
        ASSERT_EQ(ref_promotion_domains.size(), tv->nDims())
            << "Invalid number of domains: "
            << toDelimitedString(ref_promotion_domains);
        for (const auto i : c10::irange(tv->nDims())) {
          IterDomain* loop_id = tv->axis(i);
          const ValGroup& loop_group = loop_graph.toGroup(loop_id);
    
          auto promotion_it = loop_promotion_map.find(loop_group);
          ASSERT_NE(promotion_it, loop_promotion_map.end())
              << "No promotion found for: " << nvfuser::toString(loop_group);
    
          IterDomain* promotion_id = promotion_it->second;
    
          ASSERT_EQ(promotion_id, ref_promotion_domains.at(i))
              << "Expected promotion: " << ref_promotion_domains.at(i)->toString()
              << ". Actual: " << promotion_id->toString();
    
          ASSERT_EQ(loop_graph.toGroup(promotion_id), loop_group)
              << "Loop group promoted to a non-mapped domain. Loop group: "
              << nvfuser::toString(loop_group)
              << ". Promotion: " << promotion_id->toString();
        }
      }
    }
    
    // Create a fusion where a valid concrete ID is missing, so the ComputeAtMap
    // processing will fail. The concrete ID has to be created, not just looked
    // up. This fusion cannot be lowered yet because the current indexing
    // cannot generate correct indices. Also used in FusionIndexing19 and as
    // Example 2 in the design doc about Loop Promotion Analysis.
    std::unique_ptr<Fusion> createFusionWithMultipleResolutionPaths() {
      std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      auto tv0 = makeConcreteTensor({7});
      fusion.addInput(tv0);
    
      auto tv1 = set(tv0);
    
      auto tv2 = broadcast(tv1, {false, true});
    
      auto tv3 = makeConcreteTensor({7, 11});
      fusion.addInput(tv3);
    
      auto tv4 = add(tv3, tv2);
      auto tv5 = broadcast(tv4, {false, false, true});
      // tv5[7, 11, 1]
    
      auto tv6 = broadcast(tv1, {false, true});
    
      auto tv7 = makeConcreteTensor({7, 13});
      fusion.addInput(tv7);
      auto tv8 = add(tv7, tv6);
      auto tv9 = broadcast(tv8, {false, true, false});
      // tv9[7, 1, 13]
    
      auto tv10 = add(tv5, tv9);
      fusion.addOutput(tv10);
    
      // tv10[7, 11, 13]
      tv10->merge(0)->merge(0);
      // tv10[7*11*13]
      tv10->split(0, 5)->split(0, 3);
      // tv10[7*11*13//5//3, 3, 5]
    
      TransformPropagatorWithCheck propagator(tv10);
      MaxLogicalDomainInfoSpanningTree(tv10).traverse(&propagator);
    
      std::vector<TensorView*> tensors_to_inline{tv1, tv2, tv4, tv6, tv8};
      for (auto tensor : tensors_to_inline) {
        tensor->inlineAt(1);
      }
    
      return fusion_ptr;
    }
    
    // Check the results of ValGraphStmtSort. Only the ordering of
    // ExprGroups is checked for now as it's likely sufficient.
    //
    // ref_order: The order must be exactly the same as indicated by
    // this list. While other orders may also satisfy the topological
    // ordering, we need deterministic ordering, so the results should
    // always be the same.
    void checkSortingResults(
        const ValGraph& graph,
        const ExprGroups& sorted_expr_groups,
        const ValGroups& sorted_val_groups,
        const std::vector<Expr*>& ref_order) {
      // Make sure sorted_expr_groups covers all Expr groups
      const std::unordered_set<ExprGroup>& ref_expr_group_set{
          graph.disjointExprSets().disjointSets().begin(),
          graph.disjointExprSets().disjointSets().end()};
      std::unordered_set<ExprGroup> sorted_expr_group_set{
          sorted_expr_groups.begin(), sorted_expr_groups.end()};
      ASSERT_EQ(sorted_expr_group_set, ref_expr_group_set)
          << "Mismatched ExprGroups.";
    
      // Make sure sorted_val_groups covers all Val groups
      const std::unordered_set<ValGroup>& ref_val_group_set{
          graph.disjointValSets().disjointSets().begin(),
          graph.disjointValSets().disjointSets().end()};
      std::unordered_set<ValGroup> sorted_val_group_set{
          sorted_val_groups.begin(), sorted_val_groups.end()};
      ASSERT_EQ(sorted_val_group_set, ref_val_group_set) << "Mismatched ValGroups.";
    
      // Check the ordering
      ASSERT_EQ(sorted_expr_groups.size(), ref_order.size());
      for (const auto i : c10::irange(ref_order.size())) {
        Expr* ref_expr = ref_order.at(i);
        const ExprGroup& eg = sorted_expr_groups.at(i);
        ASSERT_TRUE(eg->has(ref_expr))
            << "Mismatch detected at " << i << "-th expr group. "
            << "Expected: " << nvfuser::toString(graph.toGroup(ref_expr)) << ", "
            << ref_expr->toString() << ". Actual: " << nvfuser::toString(eg) << ", "
            << eg->front()->toString();
      }
    }
    
    } // namespace
    
    // Sorting test with a trivial fusion
    TEST_F(IdModelTest, ValGraphStmtSort1) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeSymbolicTensor(2);
      fusion.addInput(tv0);
      auto tv1 = makeSymbolicTensor(2);
      fusion.addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion.addOutput(tv2);
    
      // No ID expr yet. checkSortingResults validates the expression
      // order, but since there's no expr, it just makes sure exprs() and
      // vals() return all the val and expr groups.
      {
        IdModel id_model(&fusion, /*build_graphs=*/false);
        const ValGraph& vg = id_model.buildExactGraph();
        ValGraphStmtSort vg_stmt_sort(vg);
        checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), {});
      }
    
      // Add ID exprs. Just apply a merge-and-split pattern to all
      // tensors.
      tv2->merge(0)->split(0, 4);
      TransformPropagator propagator(tv2);
      MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator);
    
      // The exact graph should just map all IDs of the tensors. The
      // ordering of the exprs should be the merge and then the split.
      {
        IdModel id_model(&fusion, /*build_graphs=*/false);
    
        const ValGraph& vg = id_model.buildExactGraph();
        ValGraphStmtSort vg_stmt_sort(vg);
    
        // Reference expr order: merge, split
        std::vector<Expr*> ref_order;
        ref_order.push_back(getParentExpr(tv2->axis(0), 2));
        ref_order.push_back(getParentExpr(tv2->axis(0), 1));
    
        checkSortingResults(
            vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order);
      }
    }
    
    // Sorting test with a disconnected graph
    TEST_F(IdModelTest, ValGraphStmtSort2) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeSymbolicTensor(2);
      fusion.addInput(tv0);
      auto tv1 = set(tv0);
      fusion.addOutput(tv1);
    
      auto tv2 = makeSymbolicTensor(2);
      fusion.addInput(tv2);
      auto tv3 = set(tv2);
      fusion.addOutput(tv3);
    
      // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3},
      // are not connected
    
      for (auto tv : fusion.allTvs()) {
        tv->merge(0)->split(0, 4);
      }
    
      // Since the two tensor groups are disconnected, there's no ordering
      // between their ID exprs. A correct ordering only needs to place
      // each merge before the split of the same tensor; there's no order
      // between the tv1 and tv3 exprs. For example,
      // these are all valid:
      //
      // tv1 merge -> tv3 merge -> tv1 split -> tv3 split
      // tv1 merge -> tv1 split -> tv3 merge -> tv3 split
      // tv3 merge -> tv3 split -> tv1 merge -> tv1 split
      // tv3 merge -> tv1 merge -> tv3 split -> tv1 split
      //
      // Here, the actual order returned by ValGraphStmtSort is the first
      // one. Since it should be deterministic, we check if the returned
      // expr vector is indeed ordered that way.
    
      IdModel id_model(&fusion, /*build_graphs=*/false);
    
      const ValGraph& vg = id_model.buildExactGraph();
      ValGraphStmtSort vg_stmt_sort(vg);
    
      std::vector<Expr*> ref_order;
      ref_order.push_back(getParentExpr(tv1->axis(0), 2));
      ref_order.push_back(getParentExpr(tv3->axis(0), 2));
      ref_order.push_back(getParentExpr(tv1->axis(0), 1));
      ref_order.push_back(getParentExpr(tv3->axis(0), 1));
    
      checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order);
    }
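
    // Illustrative sketch only (hypothetical names; not the actual
    // ValGraphStmtSort implementation): one way to make a topological
    // sort deterministic is Kahn's algorithm with a FIFO worklist seeded
    // in ascending node order. Nodes 0..num_nodes-1 stand for exprs, and
    // an edge {from, to} means "from must precede to". For the
    // disconnected graph of ValGraphStmtSort2, with 0: tv1 merge,
    // 1: tv1 split, 2: tv3 merge, 3: tv3 split and edges {0, 1} and
    // {2, 3}, this returns {0, 2, 1, 3}, i.e., the first valid order
    // listed in that test. Assumes <utility> and <vector> are already
    // included by this file.
    namespace id_model_sort_sketch {

    // Returns the nodes in a deterministic topological order, or an empty
    // vector if a cycle prevents some node from ever becoming ready.
    inline std::vector<int> deterministicTopoSort(
        int num_nodes,
        const std::vector<std::pair<int, int>>& edges) {
      std::vector<int> in_degree(num_nodes, 0);
      std::vector<std::vector<int>> successors(num_nodes);
      for (const auto& [from, to] : edges) {
        successors.at(from).push_back(to);
        ++in_degree.at(to);
      }

      // FIFO worklist implemented as a vector plus a read cursor. Nodes
      // without predecessors are seeded in ascending order.
      std::vector<int> order;
      for (int node = 0; node < num_nodes; ++node) {
        if (in_degree[node] == 0) {
          order.push_back(node);
        }
      }
      for (size_t head = 0; head < order.size(); ++head) {
        const int node = order[head];
        for (int succ : successors[node]) {
          if (--in_degree[succ] == 0) {
            order.push_back(succ);
          }
        }
      }

      // Nodes on a cycle never reach in-degree zero.
      if (static_cast<int>(order.size()) != num_nodes) {
        return {};
      }
      return order;
    }

    } // namespace id_model_sort_sketch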
    
    // Sorting with a trivial ExprGroup, i.e., an ExprGroup whose input and
    // output are mapped as the same ValGroup. It's effectively a cyclic
    // dependency and the graph is no longer a DAG.
    TEST_F(IdModelTest, ValGraphStmtSort3) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeSymbolicTensor(2);
      fusion.addInput(tv0);
      auto tv1 = makeSymbolicTensor(2);
      fusion.addInput(tv1);
      auto tv2 = add(tv0, tv1);
      fusion.addOutput(tv2);
    
      auto tv3 = makeSymbolicTensor(2);
      fusion.addInput(tv3);
      auto tv4 = set(tv3);
      fusion.addOutput(tv4);
    
      // Merge and split by one. The split input and output will be mapped.
      for (auto tv : {tv0, tv1, tv2}) {
        tv->merge(0)->split(0, 1);
      }
    
      // Also test an isolated trivial expr. Note that tv3 and tv4 are not
      // connected with tv0, tv1 and tv2.
      tv4->merge(0)->split(0, 1);
    
      IdModel id_model(&fusion, /*build_graphs=*/false);
      ValGraph vg = id_model.buildExactGraph();
    
      // Map the split-by-1 input and output
      vg.mapVals(tv2->axis(0), tv2->axis(0)->definition()->input(0));
      vg.mapVals(tv4->axis(0), tv4->axis(0)->definition()->input(0));
    
      ValGraphStmtSort vg_stmt_sort(vg);
    
      std::vector<Expr*> ref_order;
      ref_order.push_back(getParentExpr(tv2->axis(0), 2));
      ref_order.push_back(getParentExpr(tv4->axis(0), 2));
      ref_order.push_back(getParentExpr(tv2->axis(0), 1));
      ref_order.push_back(getParentExpr(tv4->axis(0), 1));
    
      checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order);
    }
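
    // Illustrative sketch only (hypothetical names; not necessarily how
    // ValGraphStmtSort handles this case) of why a trivial ExprGroup
    // stalls a naive sort. Here an expr is ready once all of its input
    // groups are visited, and a group is visited once every expr defining
    // it is visited. For a split by 1 whose input and output map to the
    // same group G, G waits on the split and the split waits on G, so
    // neither is ever visited. Not counting an expr as a definition of a
    // group that is also one of its inputs removes that self-dependency.
    // With groups {0, 1, 2}, a merge {0, 1} -> {2} and a split {2} -> {2},
    // sortExprs returns {0} when ignore_trivial_defs is false (stuck) and
    // {0, 1} when it is true. Assumes <algorithm> and <vector> are
    // already included by this file.
    namespace id_model_trivial_expr_sketch {

    struct ExprNode {
      std::vector<int> inputs; // Input group IDs
      std::vector<int> outputs; // Output group IDs
    };

    // Returns expr indices in visit order. The result is shorter than
    // `exprs` if the sort gets stuck on a cyclic dependency.
    inline std::vector<int> sortExprs(
        int num_groups,
        const std::vector<ExprNode>& exprs,
        bool ignore_trivial_defs) {
      // True if `expr` should not count as a definition of `group`.
      auto is_trivial_def = [&](const ExprNode& expr, int group) {
        return ignore_trivial_defs &&
            std::find(expr.inputs.begin(), expr.inputs.end(), group) !=
            expr.inputs.end();
      };

      // Number of not-yet-visited defining exprs of each group. A group
      // is visited once its count reaches zero.
      std::vector<int> pending_defs(num_groups, 0);
      for (const auto& expr : exprs) {
        for (int out : expr.outputs) {
          if (!is_trivial_def(expr, out)) {
            ++pending_defs.at(out);
          }
        }
      }

      std::vector<int> order;
      std::vector<bool> expr_visited(exprs.size(), false);
      bool progress = true;
      while (progress) {
        progress = false;
        for (size_t i = 0; i < exprs.size(); ++i) {
          const ExprNode& expr = exprs[i];
          const bool ready = std::all_of(
              expr.inputs.begin(), expr.inputs.end(), [&](int in) {
                return pending_defs[in] == 0;
              });
          if (expr_visited[i] || !ready) {
            continue;
          }
          expr_visited[i] = true;
          order.push_back(static_cast<int>(i));
          progress = true;
          for (int out : expr.outputs) {
            if (!is_trivial_def(expr, out)) {
              --pending_defs[out];
            }
          }
        }
      }
      return order;
    }

    } // namespace id_model_trivial_expr_sketch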
    
    // Sorting test with the same fusion as Indexing19
    TEST_F(IdModelTest, ValGraphStmtSort4) {
      auto fusion = createFusionWithMultipleResolutionPaths();
      FusionGuard fg(fusion.get());
      auto all_tvs = fusion->allTvs();
    
      // Since this fusion is not supported by ComputeAtMap, the
      // validation flag must be false
      IdModel id_model(fusion.get(), false, false, false);
      id_model.buildExactGraph();
      const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT);
    
      ValGraphStmtSort vg_stmt_sort(vg);
    
      auto tv1 = getValByName(all_tvs, 1);
      auto tv2 = getValByName(all_tvs, 2);
      auto tv4 = getValByName(all_tvs, 4);
      auto tv5 = getValByName(all_tvs, 5);
      auto tv6 = getValByName(all_tvs, 6);
      auto tv8 = getValByName(all_tvs, 8);
      auto tv9 = getValByName(all_tvs, 9);
      auto tv10 = getValByName(all_tvs, 10);
    
      // Expected reference order:
      //
      // exprg{39}: Merge iS2 bS3
      // exprg{57}: Merge iS11 bS12
      // exprg{17}: Merge iS17 bS18
      // exprg{51 63}: Merge iS15 iS16
      // exprg{69 73}: Split iS1
      // exprg{9 25 33 45}: Merge iS20 iS21
      // exprg{41}: Split iS46
      // exprg{59}: Split iS61
      // exprg{19}: Merge iS29 iS19
      // exprg{53 65}: Split iS56
      // exprg{71 75}: Split iS71
      // exprg{11}: Merge iS23 iS22
      // exprg{27}: Merge iS35 bS10
      // exprg{35 47}: Split iS41
      // exprg{43}: Split iS47
      // exprg{61}: Split iS62
      // exprg{21}: Split iS30
      // exprg{55 67}: Split iS57
      // exprg{13}: Split iS24
      // exprg{29}: Split iS36
      // exprg{37 49}: Split iS42
      // exprg{23}: Split iS31
      // exprg{15}: Split iS25
      // exprg{31}: Split iS37
    
      std::vector<Expr*> ref_order;
      ref_order.push_back(getParentExpr(tv2->axis(0), 3));
      ref_order.push_back(getParentExpr(tv6->axis(0), 3));
      ref_order.push_back(getParentExpr(tv9->axis(0), 4));
      ref_order.push_back(getParentExpr(tv8->axis(0), 3));
      ref_order.push_back(getParentExpr(tv1->axis(0), 2));
      ref_order.push_back(getParentExpr(tv10->axis(0), 4));
      ref_order.push_back(getParentExpr(tv2->axis(0), 2));
      ref_order.push_back(getParentExpr(tv6->axis(0), 2));
      ref_order.push_back(getParentExpr(tv9->axis(0), 3));
      ref_order.push_back(getParentExpr(tv8->axis(0), 2));
      ref_order.push_back(getParentExpr(tv1->axis(0), 1));
      ref_order.push_back(getParentExpr(tv10->axis(0), 3));
      ref_order.push_back(getParentExpr(tv5->axis(0), 3));
      ref_order.push_back(getParentExpr(tv4->axis(0), 2));
      ref_order.push_back(getParentExpr(tv2->axis(0), 1));
      ref_order.push_back(getParentExpr(tv6->axis(0), 1));
      ref_order.push_back(getParentExpr(tv9->axis(0), 2));
      ref_order.push_back(getParentExpr(tv8->axis(0), 1));
      ref_order.push_back(getParentExpr(tv10->axis(0), 2));
      ref_order.push_back(getParentExpr(tv5->axis(0), 2));
      ref_order.push_back(getParentExpr(tv4->axis(0), 1));
      ref_order.push_back(getParentExpr(tv9->axis(0), 1));
      ref_order.push_back(getParentExpr(tv10->axis(0), 1));
      ref_order.push_back(getParentExpr(tv5->axis(0), 1));
    
      checkSortingResults(vg, vg_stmt_sort.exprs(), vg_stmt_sort.vals(), ref_order);
    }
    
    // Testing loop promotion with a simple broadcast pattern
    TEST_F(IdModelTest, LoopPromotion1) {
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      auto t0 = makeSymbolicTensor(1);
      fusion->addInput(t0);
      auto t1 = makeSymbolicTensor(2);
      fusion->addInput(t1);
      auto t2 = broadcast(t0, {true, false});
      auto t3 = add(t2, t1);
      fusion->addOutput(t3);
    
      {
        IdModelTester tester(fusion.get());
    
        // Nothing is inlined, so there should be no resolution
        ASSERT_TRUE(tester.s1_logical_resolution_map.empty());
      }
    
      t2->inlineAt(2);
      ASSERT_EQ(t2->getComputeAtPosition(), 2);
    
      {
        IdModelTester tester(fusion.get());
    
        // Check Step 1 results
        // t2 is now fully inlined. Its logical broadcast domain should be
        // resolved with the corresponding domain of t3
        validateIELResolution(
            t2->getLogicalDomain().at(0),
            t3->getLogicalDomain().at(0),
            tester,
            tester.s1_logical_resolution_map);
    
        // Check Step 2 results
        // Nothing to propagate in this fusion, so s2_iel_promotion_map
        // should be equivalent to s1_logical_resolution_map
        ASSERT_EQ(tester.s1_logical_resolution_map, tester.s2_iel_promotion_map)
            << "Unexpected IEL promotion map";
    
        // Check Step 3 results. See the design doc for the expected results
        std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
            s3_reference_map = {
                {std::unordered_set<Val*>{t2->axis(0), t3->axis(0)}, t3->axis(0)},
                {std::unordered_set<Val*>{t2->axis(1), t3->axis(1)}, t3->axis(1)}};
    
        checkStep3Results(tester, s3_reference_map);
    
        ASSERT_TRUE(tester.s4_iel_promotion_map.empty())
            << "No step-4 IEL promotion expected";
      }
    }
    
    // Test with a fusion with progressive broadcasting
    TEST_F(IdModelTest, LoopPromotion2) {
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      auto t0 = makeSymbolicTensor(1);
      fusion->addInput(t0);
      auto t1 = makeSymbolicTensor(3);
      fusion->addInput(t1);
    
      auto t2 = broadcast(t0, {true, false});
      auto t3 = broadcast(t2, {true, false, false});
      auto t4 = add(t3, t1);
      fusion->addOutput(t4);
    
      inlineMost();
    
      IdModelTester tester(fusion.get());
    
      // Check Step 1 results
      // Validate t2 and t3 as they have logical broadcast domains
      validateIELResolution(
          t2->getLogicalDomain().at(0),
          t4->getLogicalDomain().at(1),
          tester,
          tester.s1_logical_resolution_map);
    
      validateIELResolution(
          t3->getLogicalDomain().at(0),
          t4->getLogicalDomain().at(0),
          tester,
          tester.s1_logical_resolution_map);
    
      // Check Step 2 results
      // Nothing to propagate in this fusion, so s2_iel_promotion_map
      // should be equivalent to s1_logical_resolution_map
      ASSERT_EQ(tester.s1_logical_resolution_map, tester.s2_iel_promotion_map)
          << "Unexpected IEL promotion map";
    
      // Check Step 3 results. See the design doc for the expected results
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s3_reference_map = {
              {std::unordered_set<Val*>{t2->axis(0), t3->axis(1), t4->axis(1)},
               t4->axis(1)},
              {std::unordered_set<Val*>{t2->axis(1), t3->axis(2), t4->axis(2)},
               t4->axis(2)},
              {std::unordered_set<Val*>{t3->axis(0), t4->axis(0)}, t4->axis(0)}};
    
      checkStep3Results(tester, s3_reference_map);
    
      ASSERT_TRUE(tester.s4_iel_promotion_map.empty())
          << "No step-4 IEL promotion expected";
    }
    
    // Multiple inlined and non-inlined broadcast domains
    TEST_F(IdModelTest, LoopPromotion3) {
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      auto tv0 = makeSymbolicTensor(2);
      fusion->addInput(tv0);
      auto tv1 = makeSymbolicTensor(4);
      fusion->addInput(tv1);
    
      auto tv2 = broadcast(tv0, {false, true, false, true});
      auto tv3 = add(tv2, tv1);
      fusion->addOutput(tv3);
    
      // tv3: [i0, i1, i2, i3] -> [i0*i1, i2*i3]
      tv3->merge(0);
      tv3->merge(1);
    
      TransformPropagatorWithCheck propagator(tv3);
      MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator);
    
      tv2->inlineAt(1);
    
      // tv2: [i0*b1, i2*b3] ca(1)
      // tv3: [i0*i1, i2*i3]
    
      IdModelTester tester(fusion.get());
    
      // Check Step 1 results
      // The b1 broadcast domain of tv2 should be resolved as it's inlined,
      // but b3 should not.
      validateIELResolution(
          tv2->getLogicalDomain().at(1),
          tv3->getLogicalDomain().at(1),
          tester,
          tester.s1_logical_resolution_map);
    
      validateIELResolution(
          tv2->getLogicalDomain().at(3),
          nullptr,
          tester,
          tester.s1_logical_resolution_map);
    
      // Check Step 2 results
      validateIELResolution(
          tv2->axis(0), tv3->axis(0), tester, tester.s2_iel_promotion_map);
    
      validateIELResolution(
          tv2->axis(1), nullptr, tester, tester.s2_iel_promotion_map);
    
      // Check Step 3 results. See the design doc for the expected results
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s3_reference_map = {
              {std::unordered_set<Val*>{
                   tv2->axis(0),
                   tv2->getLogicalDomain().at(0),
                   tv2->getLogicalDomain().at(1),
                   tv3->axis(0),
                   tv3->getLogicalDomain().at(0),
                   tv3->getLogicalDomain().at(1)},
               tv3->axis(0)}};
    
      checkStep3Results(tester, s3_reference_map);
    
      ASSERT_TRUE(tester.s4_iel_promotion_map.empty())
          << "No step-4 IEL promotion expected";
    }
    
    // Test root resolution with a fusion with an outer split.
    // Currently, invalid code is generated.
    //
    // Used as Example 1 in the design doc about Loop
    // Promotion Analysis.
    TEST_F(IdModelTest, LoopPromotion4) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({1, 4});
      fusion.addInput(tv0);
      auto tv1 = makeContigConcreteTensor({3, 4});
      fusion.addInput(tv1);
    
      auto tv2 = set(tv0);
      auto tv3 = set(tv1);
      auto tv4 = add(tv2, tv3);
      fusion.addOutput(tv4);
    
      // [i0, i1]
      tv4->merge(0);
      // [i0*i1]
      tv4->split(0, 4, false); // outer split
      // [4, i0*i1/4]
    
      TransformPropagator propagator(tv4);
      MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator);
    
      for (auto tv : fusion.allTvs()) {
        tv->inlineAt(-2);
      }
    
      IdModelTester tester(&fusion);
    
      // Verify that all tensors with logical broadcast domains have the
      // correct resolutions
      for (auto tv : fusion.allTvs()) {
        // Skip tensors that have no broadcast domain or are not inlined
        if (std::none_of(
                tv->getLogicalDomain().begin(),
                tv->getLogicalDomain().end(),
                [](auto id) { return id->isBroadcast(); }) ||
            tv->getComputeAtPosition() == 0) {
          continue;
        }
    
        switch (tv->name()) {
          case 2:
            // T2_l[ iS20{4}, iS21{( ceilDiv(( 1 * 4 ), 4) )} ] ca_pos( 1 )
            //  root domain : (bS4{1}, iS5{4})
            validateIELResolution(
                tv->getLogicalDomain().at(0),
                tv4->getLogicalDomain().at(0),
                tester,
                tester.s1_logical_resolution_map);
            break;
          default:
            FAIL() << "Unexpected tensor: " << tv->toString();
        }
      }
    
      checkStep2Results(&fusion, tester);
    
      auto id10 = getChildIdByName(tv4->getLogicalDomain()[0], 10);
      auto id11 = getChildIdByName(id10, 11);
      auto id12 = getChildIdByName(id10, 12);
      auto id13 = getChildIdByName(tv3->getLogicalDomain()[0], 13);
      auto id15 = getChildIdByName(id13, 15);
      auto id19 = getChildIdByName(tv2->getLogicalDomain()[0], 19);
      auto id25 = getChildIdByName(id10, 25);
      auto id26 = getChildIdByName(id10, 26);
    
      // Check Step 3 results. See the design doc for the expected results
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s3_reference_map = {// 4, 6, 8 -> 8
                              {std::unordered_set<Val*>{
                                   tv2->getLogicalDomain().at(0),
                                   tv3->getLogicalDomain().at(0),
                                   tv4->getLogicalDomain().at(0)},
                               tv4->getLogicalDomain().at(0)},
                              // 5, 7, 9 -> 9
                              {std::unordered_set<Val*>{
                                   tv2->getLogicalDomain().at(1),
                                   tv3->getLogicalDomain().at(1),
                                   tv4->getLogicalDomain().at(1)},
                               tv4->getLogicalDomain().at(1)},
                              // 10, 13, 19 -> 10
                              {std::unordered_set<Val*>{id10, id13, id19}, id10},
                              // 11, 14, 20, 25 -> 11
                              {std::unordered_set<Val*>{
                                   tv2->axis(0), tv3->axis(0), tv4->axis(0), id25},
                               id11},
                              // 21, 26 -> 26
                              {std::unordered_set<Val*>{tv2->axis(1), id26}, id26}};
    
      checkStep3Results(tester, s3_reference_map);
    
      ASSERT_EQ(id10->name(), 10);
      auto id27 = getChildIdByName(id10, 27);
      auto id28 = getChildIdByName(id10, 28);
    
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s4_reference_map = {// 20 -> 27
                              {std::unordered_set<Val*>{tv2->axis(0)}, id27},
                              // 21 -> 28
                              {std::unordered_set<Val*>{tv2->axis(1)}, id28}};
    
      checkStep4Results(tester, s4_reference_map);
    
      // Check Step 5 results. See the design doc for the expected results
      std::unordered_map<TensorView*, std::vector<IterDomain*>> s5_reference_map = {
          {tv2, {id11, id28}},
          {tv3, {id11, id15}},
          {tv4, {id11, id12}},
      };
    
      checkStep5Results(tester, s5_reference_map);
    }
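
    // Worked example (plain integer arithmetic with hypothetical helper
    // names; not nvFuser's IterDomain API) of the extents produced by
    // merge and split. Splitting an extent N by a factor f yields
    // [ceilDiv(N, f), f] for an inner split and [f, ceilDiv(N, f)] for an
    // outer split such as tv4->split(0, 4, false) in LoopPromotion4. In
    // that test, tv4 merges [3, 4] into 12, which the outer split turns
    // into [4, 3]. tv2 merges [1, 4] into 4, giving [4, ceilDiv(4, 4)] =
    // [4, 1], consistent with the iS21{( ceilDiv(( 1 * 4 ), 4) )} extent
    // shown in the test. The same rule gives [67, 3, 5] for tv10 in
    // createFusionWithMultipleResolutionPaths (7 * 11 * 13 = 1001).
    // Assumes <utility> is already included by this file.
    namespace id_model_split_sketch {

    constexpr long long ceilDivExtent(long long a, long long b) {
      return (a + b - 1) / b;
    }

    // Returns {outer extent, inner extent} of splitting `extent` by
    // `factor`.
    constexpr std::pair<long long, long long> splitExtents(
        long long extent,
        long long factor,
        bool inner_split) {
      return inner_split
          ? std::make_pair(ceilDivExtent(extent, factor), factor)
          : std::make_pair(factor, ceilDivExtent(extent, factor));
    }

    // LoopPromotion4: outer split by 4 of the merged extents.
    static_assert(splitExtents(3 * 4, 4, false) == std::make_pair(4LL, 3LL));
    static_assert(splitExtents(1 * 4, 4, false) == std::make_pair(4LL, 1LL));
    // tv10: two inner splits, by 5 and then by 3.
    static_assert(
        splitExtents(7 * 11 * 13, 5, true) == std::make_pair(201LL, 5LL));
    static_assert(splitExtents(201, 3, true) == std::make_pair(67LL, 3LL));

    } // namespace id_model_split_sketch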
    
    // Test root resolution with the same fusion as Indexing1
    TEST_F(IdModelTest, LoopPromotion5) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeSymbolicTensor(3);
      auto tv1 = makeSymbolicTensor(4);
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = add(tv0, IrBuilder::create<Val>(1.0));
      auto tv3 = broadcast(tv2, {true, false, false, false});
      auto tv4 = add(tv3, tv1);
    
      fusion.addOutput(tv4);
    
      tv4->merge(0);
      tv4->merge(0);
      tv4->merge(0);
    
      tv4->split(0, 128);
      tv4->split(0, 4);
    
      tv2->computeAt(tv4, 1);
    
      tv4->axis(0)->parallelize(ParallelType::BIDx);
      tv4->axis(1)->parallelize(ParallelType::Unroll);
      tv4->axis(2)->parallelize(ParallelType::TIDx);
    
      tv3->axis(1)->parallelize(ParallelType::Unroll);
      tv3->axis(2)->parallelize(ParallelType::TIDx);
    
      tv2->axis(1)->parallelize(ParallelType::Unroll);
      tv2->axis(2)->parallelize(ParallelType::TIDx);
    
      auto all_tvs = fusion.allTvs();
    
      IdModelTester tester(&fusion);
    
      // Check Step 1 results
      for (auto tv : all_tvs) {
        // Skip tensors that have no broadcast domain or are not inlined
        if (std::none_of(
                tv->getLogicalDomain().begin(),
                tv->getLogicalDomain().end(),
                [](auto id) { return id->isBroadcast(); }) ||
            tv->getComputeAtPosition() == 0) {
          continue;
        }
    
        switch (tv->name()) {
          case 3:
            // T3_l[ iS30{( ceilDiv(( ceilDiv(( ( ( 1 * i0 ) * i2 ) * i3 ), 128) ),
            // 4) )}, iUR31{4}, ithreadIdx.x29{128} ] ca_pos( 1 ) produce_pos( 1 )
            //  root domain : (bS10{1}, iS11{i0}, iS12{i2}, iS13{i3})
            validateIELResolution(
                tv->getLogicalDomain().at(0),
                tv4->getLogicalDomain().at(0),
                tester,
                tester.s1_logical_resolution_map);
            break;
          default:
            FAIL() << "Unexpected tensor: " << tv->toString();
        }
      }
    
      // Check Step 2 results
      checkStep2Results(&fusion, tester);
    
      auto id19 = getParentId(tv4->axis(0), 3);
      ASSERT_EQ(id19->name(), 19);
      auto id20 = getParentId(tv4->axis(0), 2);
      ASSERT_EQ(id20->name(), 20);
      auto id21 = getChildIdByName(id20, 21);
      auto id22 = getChildIdByName(id20, 22);
      auto id23 = getChildIdByName(id21, 23);
      auto id24 = getChildIdByName(id21, 24);
      auto id38 = getChildIdByName(id20, 38);
      auto id39 = getChildIdByName(id20, 39);
      auto id40 = getChildIdByName(id38, 40);
      auto id41 = getChildIdByName(id38, 41);
    
      // Check Step 3 results. See the design doc for the expected results
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s3_reference_map = {
              // 7, 10, 11, 25, 14, 15, 18 -> 18
              {std::unordered_set<Val*>{
                   tv2->getLogicalDomain().at(0),
                   tv3->getLogicalDomain().at(0),
                   tv3->getLogicalDomain().at(1),
                   getParentId(tv3->axis(0), 4),
                   tv4->getLogicalDomain().at(0),
                   tv4->getLogicalDomain().at(1),
                   getParentId(tv4->axis(0), 4)},
               getParentId(tv4->axis(0), 4)},
              // 8, 12, 16 -> 16
              {std::unordered_set<Val*>{
                   tv2->getLogicalDomain().at(1),
                   tv3->getLogicalDomain().at(2),
                   tv4->getLogicalDomain().at(2)},
               tv4->getLogicalDomain().at(2)},
              // 9, 13, 17 -> 17
              {std::unordered_set<Val*>{
                   tv2->getLogicalDomain().at(2),
                   tv3->getLogicalDomain().at(3),
                   tv4->getLogicalDomain().at(3)},
               tv4->getLogicalDomain().at(3)},
              // 32, 26, 19 -> 19
              {std::unordered_set<Val*>{
                   getParentId(tv2->axis(0), 3),
                   getParentId(tv3->axis(0), 3),
                   getParentId(tv4->axis(0), 3)},
               getParentId(tv4->axis(0), 3)},
              // 33, 27, 20 -> 20
              {std::unordered_set<Val*>{
                   getParentId(tv2->axis(0), 2),
                   getParentId(tv3->axis(0), 2),
                   getParentId(tv4->axis(0), 2)},
               getParentId(tv4->axis(0), 2)},
              // 21, 28, 34, 38 -> 21
              {std::unordered_set<Val*>{
                   getParentId(tv2->axis(0), 1),
                   getParentId(tv3->axis(0), 1),
                   id21,
                   id38},
               getParentId(tv4->axis(0), 1)},
              // 29, 39 -> 39
              {std::unordered_set<Val*>{tv3->axis(2), id39}, id39},
              // 31, 41 -> 41
              {std::unordered_set<Val*>{tv3->axis(1), id41}, id41},
              // 23, 30, 36, 40 -> 23
              {std::unordered_set<Val*>{tv2->axis(0), tv3->axis(0), id23, id40},
               id23},
          };
    
      checkStep3Results(tester, s3_reference_map);
    
      auto id42 = getChildIdByName(id20, 42);
      auto id43 = getChildIdByName(id20, 43);
      auto id48 = getChildIdByName(id42, 48);
      auto id49 = getChildIdByName(id42, 49);
    
      auto id44 = getChildIdByName(id20, 44);
      auto id45 = getChildIdByName(id20, 45);
      auto id50 = getChildIdByName(id44, 50);
      auto id51 = getChildIdByName(id44, 51);
    
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s4_reference_map = {
              // 34 -> 42
              {std::unordered_set<Val*>{getParentId(tv2->axis(0), 1)}, id42},
              // 35 -> 43
              {std::unordered_set<Val*>{tv2->axis(2)}, id43},
              // 36 -> 48
              {std::unordered_set<Val*>{tv2->axis(0)}, id48},
              // 37 -> 49
              {std::unordered_set<Val*>{tv2->axis(1)}, id49},
              // 28 -> 44
              {std::unordered_set<Val*>{getParentId(tv3->axis(0), 1)}, id44},
              // 29 -> 45
              {std::unordered_set<Val*>{tv3->axis(2)}, id45},
              // 30 -> 50
              {std::unordered_set<Val*>{tv3->axis(0)}, id50},
              // 31 -> 51
              {std::unordered_set<Val*>{tv3->axis(1)}, id51}};
    
      checkStep4Results(tester, s4_reference_map);
    
      // Check Step 5 results. See the design doc for the expected results
      std::unordered_map<TensorView*, std::vector<IterDomain*>> s5_reference_map = {
          {tv2, {id23, id49, id43}},
          {tv3, {id23, id51, id45}},
          {tv4, {id23, id24, id22}},
      };
    
      checkStep5Results(tester, s5_reference_map);
    }
    
    // Test root resolution with the same fusion as Indexing19
    TEST_F(IdModelTest, LoopPromotion6) {
      auto fusion = createFusionWithMultipleResolutionPaths();
      FusionGuard fg(fusion.get());
      auto all_tvs = fusion->allTvs();
    
      IdModelTester tester(fusion.get());
    
      auto tv1 = getValByName(all_tvs, 1);
      auto tv2 = getValByName(all_tvs, 2);
      auto tv4 = getValByName(all_tvs, 4);
      auto tv5 = getValByName(all_tvs, 5);
      auto tv6 = getValByName(all_tvs, 6);
      auto tv8 = getValByName(all_tvs, 8);
      auto tv9 = getValByName(all_tvs, 9);
    
      // Check Step 1 results
      for (auto tv : all_tvs) {
        // Skip tensors that have no broadcast domain or are not inlined
        if (std::none_of(
                tv->getLogicalDomain().begin(),
                tv->getLogicalDomain().end(),
                [](auto id) { return id->isBroadcast(); }) ||
            tv->getComputeAtPosition() == 0) {
          continue;
        }
    
        switch (tv->name()) {
          case 2:
            // T2_l[ iS49{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS50{3},
            // iS48{5} ] ca_pos( 1 ) produce_pos( 1 )
            //  root domain : (iS2{7}, bS3{1})
            // Resolution: Resolved by the immediate consumer (T4)
            validateIELResolution(
                tv->getLogicalDomain().at(1),
                tv4->getLogicalDomain().at(1),
                tester,
                tester.s1_logical_resolution_map);
            break;
          case 5:
            // T5_l[ iS39{( ceilDiv(( ceilDiv(( ( 7 * 11 ) * 1 ), 5) ), 3) )},
            // iS40{3}, iS38{5} ] produce_pos( 1 )
            //  root domain : (iS8{7}, iS9{11}, bS10{1})
            // Resolution: T5 is not inlined to the immediate consumer,
            // T10. Resolution is done with the other path from T1, such
            // as T8 or T9.
            validateIELResolution(
                tv->getLogicalDomain().at(2),
                tv9->getLogicalDomain().at(2),
                tester,
                tester.s1_logical_resolution_map);
            break;
          case 6:
            // T6_l[ iS64{( ceilDiv(( ceilDiv(( 7 * 1 ), 5) ), 3) )}, iS65{3},
            // iS63{5} ] ca_pos( 1 ) produce_pos( 1 )
            //  root domain : (iS11{7}, bS12{1})
            // Resolution: Resolved by the immediate consumer (T8)
            validateIELResolution(
                tv->getLogicalDomain().at(1),
                tv8->getLogicalDomain().at(1),
                tester,
                tester.s1_logical_resolution_map);
            break;
          case 9:
            // T9_l[ iS33{( ceilDiv(( ceilDiv(( ( 7 * 1 ) * 13 ), 5) ), 3) )},
            // iS34{3}, iS32{5} ] produce_pos( 1 )
            //  root domain : (iS17{7}, bS18{1}, iS19{13})
            // Resolution: T9 is not inlined to the immediate consumer,
            // T10. Resolution is done with the other path from T1, such
            // as T4 or T5.
            validateIELResolution(
                tv->getLogicalDomain().at(1),
                tv5->getLogicalDomain().at(1),
                tester,
                tester.s1_logical_resolution_map);
            break;
          default:
            FAIL() << "Unexpected tensor: " << tv->toString();
        }
      }
    
      checkStep2Results(fusion.get(), tester);
    
      // 83 -> 89, 90
      // 89 -> 93, 94
      auto id83 = getChildIdByName(tv9->getLogicalDomain().at(2), 83);
      auto id89 = getChildIdByName(id83, 89);
      auto id90 = getChildIdByName(id83, 90);
      auto id93 = getChildIdByName(id89, 93);
      auto id94 = getChildIdByName(id89, 94);
    
      // 84 -> 91, 92
      // 91 -> 95, 96
      auto id84 = getChildIdByName(tv9->getLogicalDomain().at(2), 84);
      auto id91 = getChildIdByName(id84, 91);
      auto id92 = getChildIdByName(id84, 92);
      auto id95 = getChildIdByName(id91, 95);
      auto id96 = getChildIdByName(id91, 96);
    
      // 35 -> 79, 80
      // 79 -> 85, 86
      auto id35 = getChildIdByName(tv5->getLogicalDomain().at(0), 35);
      auto id79 = getChildIdByName(id35, 79);
      auto id80 = getChildIdByName(id35, 80);
      auto id85 = getChildIdByName(id79, 85);
      auto id86 = getChildIdByName(id79, 86);
    
      // 56 -> 81, 82
      // 81 -> 87, 88
      auto id56 = getChildIdByName(tv8->getLogicalDomain().at(0), 56);
      auto id81 = getChildIdByName(id56, 81);
      auto id82 = getChildIdByName(id56, 82);
      auto id87 = getChildIdByName(id81, 87);
      auto id88 = getChildIdByName(id81, 88);
    
      // Check Step 3 results. See the design doc for the expected results
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s3_reference_map = {
              // 1 2 3 6 7 8 9 10 11 12 15 16 17 18 19 29 30 35 36 41 46 56 61
              // 83 84 -> 84
              {std::unordered_set<Val*>{
                   tv1->getLogicalDomain().at(0),
                   tv2->getLogicalDomain().at(0),
                   tv2->getLogicalDomain().at(1),
                   getChildId(tv2->getLogicalDomain().at(0), 1),
                   tv4->getLogicalDomain().at(0),
                   tv4->getLogicalDomain().at(1),
                   getChildId(tv4->getLogicalDomain().at(0), 1),
                   tv5->getLogicalDomain().at(0),
                   tv5->getLogicalDomain().at(1),
                   tv5->getLogicalDomain().at(2),
                   getChildId(tv5->getLogicalDomain().at(0), 1),
                   getChildId(tv5->getLogicalDomain().at(2), 1),
                   tv6->getLogicalDomain().at(0),
                   tv6->getLogicalDomain().at(1),
                   getChildId(tv6->getLogicalDomain().at(0), 1),
                   tv8->getLogicalDomain().at(0),
                   tv8->getLogicalDomain().at(1),
                   getChildId(tv8->getLogicalDomain().at(0), 1),
                   tv9->getLogicalDomain().at(0),
                   tv9->getLogicalDomain().at(1),
                   tv9->getLogicalDomain().at(2),
                   getChildId(tv9->getLogicalDomain().at(0), 1),
                   getChildId(tv9->getLogicalDomain().at(0), 2),
                   id83,
                   id84},
               id84},
              // 31 37 42 47 57 62 71 79 81 89 91 -> 91
              {std::unordered_set<Val*>{
                   getChildId(tv1->getLogicalDomain().at(0), 1),
                   getChildId(tv2->getLogicalDomain().at(0), 2),
                   getChildId(tv4->getLogicalDomain().at(0), 2),
                   getChildId(tv5->getLogicalDomain().at(0), 3),
                   getChildId(tv6->getLogicalDomain().at(0), 2),
                   getChildId(tv8->getLogicalDomain().at(0), 2),
                   getChildId(tv9->getLogicalDomain().at(0), 3),
                   id79,
                   id81,
                   id89,
                   id91},
               id91},
              // 33 39 44 49 59 64 73 85 87 93 95 -> 95
              {std::unordered_set<Val*>{
                   tv1->axis(0),
                   tv2->axis(0),
                   tv4->axis(0),
                   tv5->axis(0),
                   tv6->axis(0),
                   tv8->axis(0),
                   tv9->axis(0),
                   id85,
                   id87,
                   id93,
                   id95},
               id95},
              // 48 80 -> 80
              {std::unordered_set<Val*>{tv2->axis(2), id80}, id80},
              // 50 86 -> 86
              {std::unordered_set<Val*>{tv2->axis(1), id86}, id86},
              // 40 96 -> 96
              {std::unordered_set<Val*>{tv5->axis(1), id96}, id96},
              // 63 82 -> 82
              {std::unordered_set<Val*>{tv6->axis(2), id82}, id82},
              // 65 88 -> 88
              {std::unordered_set<Val*>{tv6->axis(1), id88}, id88},
              // 34 94 -> 94
              {std::unordered_set<Val*>{tv9->axis(1), id94}, id94},
              // 38 92 -> 92
              {std::unordered_set<Val*>{tv5->axis(2), id92}, id92},
              // 32 90 -> 90
              {std::unordered_set<Val*>{tv9->axis(2), id90}, id90},
          };
    
      checkStep3Results(tester, s3_reference_map);
    
      // For tv1
      auto id97 = getChildIdByName(id84, 97);
      auto id98 = getChildIdByName(id84, 98);
      auto id105 = getChildIdByName(id97, 105);
      auto id106 = getChildIdByName(id97, 106);
    
      // For tv2
      auto id99 = getChildIdByName(id84, 99);
      auto id100 = getChildIdByName(id84, 100);
      auto id109 = getChildIdByName(id99, 109);
      auto id110 = getChildIdByName(id99, 110);
    
      // For tv6
      auto id101 = getChildIdByName(id84, 101);
      auto id102 = getChildIdByName(id84, 102);
      auto id111 = getChildIdByName(id101, 111);
      auto id112 = getChildIdByName(id101, 112);
    
      // For tv4
      auto id107 = getChildIdByName(id84, 107);
      auto id108 = getChildIdByName(id84, 108);
      auto id119 = getChildIdByName(id107, 119);
      auto id120 = getChildIdByName(id107, 120);
    
      // For tv5
      auto id117 = getChildIdByName(id84, 117);
      auto id118 = getChildIdByName(id84, 118);
      auto id123 = getChildIdByName(id117, 123);
      auto id124 = getChildIdByName(id117, 124);
    
      // For tv8
      auto id103 = getChildIdByName(id84, 103);
      auto id104 = getChildIdByName(id84, 104);
      auto id115 = getChildIdByName(id103, 115);
      auto id116 = getChildIdByName(id103, 116);
    
      // For tv9
      auto id113 = getChildIdByName(id84, 113);
      auto id114 = getChildIdByName(id84, 114);
      auto id121 = getChildIdByName(id113, 121);
      auto id122 = getChildIdByName(id113, 122);
    
      std::vector<std::pair<std::unordered_set<Val*>, IterDomain*>>
          s4_reference_map = {
              // tv1: 71 -> 97
              {std::unordered_set<Val*>{getParentId(tv1->axis(0), 1)}, id97},
              // tv1: 72 -> 98
              {std::unordered_set<Val*>{tv1->axis(2)}, id98},
              // tv1: 73 -> 105
              {std::unordered_set<Val*>{tv1->axis(0)}, id105},
              // tv1: 74 -> 106
              {std::unordered_set<Val*>{tv1->axis(1)}, id106},
              // tv2: 47 -> 99
              {std::unordered_set<Val*>{getParentId(tv2->axis(0), 1)}, id99},
              // tv2: 48 -> 100
              {std::uno...
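
The expected results in the test above are written as a vector of (set of IDs, reference ID) pairs. The following is a minimal sketch of how such a reference map could be checked against a computed resolution. It assumes that every ID in a group must resolve to that group's reference ID. `Id`, `ReferenceMap`, and `matchesReferenceMap` are hypothetical names, and IDs are plain strings rather than `Val*`/`IterDomain*`. This is not the actual implementation of `checkStep3Results`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for Val*/IterDomain*: IDs are identified by name.
using Id = std::string;

// Each entry pairs a group of IDs with the reference ID they should share,
// mirroring the shape of s3_reference_map and s4_reference_map above.
using ReferenceMap = std::vector<std::pair<std::unordered_set<Id>, Id>>;

// Returns true only if every ID in every group has a computed reference and
// that reference equals the group's expected reference ID.
bool matchesReferenceMap(
    const std::unordered_map<Id, Id>& computed,
    const ReferenceMap& expected) {
  for (const auto& [group, reference] : expected) {
    for (const Id& id : group) {
      auto it = computed.find(id);
      if (it == computed.end() || it->second != reference) {
        return false;
      }
    }
  }
  return true;
}
```

Take the "48 80 -> 80" and "50 86 -> 86" entries from the test above. A computed map that resolves 50 to 80 instead of 86 fails this check, and so does one that leaves 50 unresolved.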

    @naoyam
    Collaborator Author

    naoyam commented Feb 20, 2025

    !test --diff

    @naoyam force-pushed the restrict_almost_exact_mapping branch from fc99e68 to e0e55e9 on February 20, 2025 23:24
    @naoyam
    Collaborator Author

    naoyam commented Feb 20, 2025

    !test --diff

    @naoyam changed the title from "Do not include merging of two broadcast IDs in trivially mapped IDs" to "Prevent self mapping in the AlmostExact graph" on Feb 21, 2025
    // Make sure there's no self mapping in the Exact graph as that
    // would invalidate lowering assumptions.
    if (!allow_self_mapping_) {
      assertNoSelfMapping(graph);
    }
    Collaborator Author

    Moved this check to buildExactGraph

    // had self mapping (see issue #3919). ALMOSTEXACT is used in
    // indexing, which assumes the graph has no self mapping.
    if (!allow_self_mapping_) {
      assertNoSelfMapping(almost_exact_graph);
    }
    Collaborator Author

    Added this check for the AlmostExact graph.
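
The following small sketch of the issue #3919 scenario makes the self-mapping condition concrete. It assumes a plain union-find in place of the real ID graph. If the merge of tv0's two broadcast IDs is treated as trivial, b1 and b2 land in the same group. The identical resizes that produce tv1's logical IDs then map those IDs too. `ToyIdGraph`, `hasSelfMapping`, and `tv1HasSelfMapping` are hypothetical names, and none of this is nvFuser's ValGraph or `assertNoSelfMapping`.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Toy union-find over integer ID numbers: a hypothetical stand-in for the
// disjoint sets behind an ID graph, not nvFuser's ValGraph.
class ToyIdGraph {
 public:
  explicit ToyIdGraph(int num_ids) : parent_(num_ids) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }

  int find(int id) {
    while (parent_[id] != id) {
      parent_[id] = parent_[parent_[id]];  // path halving
      id = parent_[id];
    }
    return id;
  }

  void map(int a, int b) {
    parent_[find(a)] = find(b);
  }

  bool mapped(int a, int b) {
    return find(a) == find(b);
  }

 private:
  std::vector<int> parent_;
};

// A domain is self mapped if two distinct IDs in it share a group.
bool hasSelfMapping(ToyIdGraph& graph, const std::vector<int>& domain) {
  for (std::size_t i = 0; i < domain.size(); ++i) {
    for (std::size_t j = i + 1; j < domain.size(); ++j) {
      if (graph.mapped(domain[i], domain[j])) {
        return true;
      }
    }
  }
  return false;
}

// Hypothetical numbering for the issue #3919 scenario:
//   0 = b1, 1 = b2 (tv0's broadcast IDs), 2 = output of tv0->merge(0, 1),
//   3, 4 = tv1's logical IDs, produced by identical resizes of b1 and b2.
bool tv1HasSelfMapping(bool map_trivial_merge) {
  ToyIdGraph graph(5);
  if (map_trivial_merge) {
    // Treating the merge of two broadcast IDs as trivial maps both inputs
    // to the output, so b1 and b2 end up in the same group.
    graph.map(0, 2);
    graph.map(1, 2);
  }
  // Stand-in for forward propagation: identical resize ops on mapped
  // inputs produce mapped outputs.
  if (graph.mapped(0, 1)) {
    graph.map(3, 4);
  }
  return hasSelfMapping(graph, {3, 4});
}
```

`tv1HasSelfMapping(true)` returns true and `tv1HasSelfMapping(false)` returns false. This mirrors the workaround of not mapping IDs whose mapping would cause self mapping.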

    @@ -116,7 +116,7 @@ class IdModel : public PolymorphicBase {
           const std::vector<Expr*>& exprs,
           const std::vector<TensorView*>& additional_tvs = {},
           bool build_graphs = false,
    -      bool allow_self_mapping = false,
    +      bool allow_self_mapping = true,
    Collaborator Author

    Changed to not check self mapping by default since that restriction only matters for indexing.
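
With this default, the check becomes opt-in. Below is a minimal sketch of that pattern. The names `ToyGraphs` and `validateGraphs` are hypothetical, and boolean flags stand in for the real self-mapping analysis. Callers that do not index keep the default and tolerate self mapping. A caller that relies on the no-self-mapping assumption, such as indexing, would pass `allow_self_mapping = false` to turn self mapping into a hard error.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical summary of the built graphs: the flags stand in for the
// result of an actual self-mapping analysis on each graph.
struct ToyGraphs {
  bool exact_has_self_mapping = false;
  bool almost_exact_has_self_mapping = false;
};

// Opt-in validation: with the default (allow_self_mapping = true) nothing is
// checked; callers whose later passes assume no self mapping pass false.
void validateGraphs(const ToyGraphs& graphs, bool allow_self_mapping = true) {
  if (allow_self_mapping) {
    return;
  }
  if (graphs.exact_has_self_mapping) {
    throw std::runtime_error("self mapping found in the Exact graph");
  }
  if (graphs.almost_exact_has_self_mapping) {
    throw std::runtime_error("self mapping found in the AlmostExact graph");
  }
}
```

Throwing `std::runtime_error` is a choice made for this sketch only; the real check may report the error differently.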

    @naoyam marked this pull request as ready for review on February 21, 2025 03:08
    @naoyam
    Collaborator Author

    naoyam commented Feb 21, 2025

    !test --diff

    @naoyam
    Collaborator Author

    naoyam commented Feb 21, 2025

    !test --diff

    @xwang233
    Collaborator

    !test --dev
