Reduction scheduler fails to recognize iter domains not captured by reference #3811

Open
naoyam opened this issue Feb 3, 2025 · 0 comments

The following fusion is scheduled by the reduction scheduler without segmentation, which should not happen.

TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalID) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  auto tv0 = makeContigConcreteTensor({1, -1});
  fusion.addInput(tv0);
  auto tv1 = makeContigTensor(2);
  fusion.addInput(tv1);

  auto tv2 = sum(tv0, {0, 1});
  fusion.addOutput(tv2);
  auto tv3 = add(tv0, tv1);
  fusion.addOutput(tv3);

  fusion.printMath();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({1, 100}, options);
  auto t1 = at::randn({5, 100}, options);
  std::vector<c10::IValue> inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

Output of fusion.printMath():

Inputs:
  T0_g_float[bS0{1}, iS1{i1}]
  T1_g_float[iS2{i4}, iS3{i5}]
Outputs:
  T3_g_float[rS5{i1}]
  T4_g_float[iS6{i4}, iS7{i1}]

%kernel_math {
T2_l_float[iS4{i1}]
   = squeeze( T0_g_float[bS0{1}, iS1{i1}] )
T3_g_float[rS5{i1}]
   = reduction( T2_l_float[iS4{i1}], op = add, initial value = float(0), allreduce = false )
T4_g_float[iS6{i4}, iS7{i1}]
   = T0_g_float[bS0{1}, iS1{i1}]
   + T1_g_float[iS2{i4}, iS3{i5}];
} // %kernel_math

The reduction tensor, T3, is chosen as the reference for the reduction part. However, T3 has no ID that is connected with iS6 or iS2, so it cannot be used as the reference for the whole fusion. The scheduler nevertheless accepted the fusion without segmentation, yielding:

Inputs:
  T0_g_float[iS49{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, iS48{blockDim.x}, iS50{1}, iS46{2}, bS0{1}]
  T1_g_float[iS67{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, iS66{blockDim.x}, iS68{1}, iS64{2}, iS2{i4}]
Outputs:
  T3_g_float[]
  T4_g_float[iS73{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x72{blockDim.x}_p, iUS74{1}, iV70{2}, iS17{i4}] ca_pos( 3 ) produce_pos( 3 )

%kernel_math {
T5_l_float[iS43{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x42{blockDim.x}_p, iUS44{1}, iV40{2}, bS13{1}] ca_pos( 3 )
   = Set( T0_g_float[iS49{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, iS48{blockDim.x}, iS50{1}, iS46{2}, bS0{1}], cache_op=Streaming )
T2_l_float[iS37{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x36{blockDim.x}_p, iUS38{1}, iS34{2}] ca_pos( 4 ) produce_pos( 3 )
   = squeeze( T5_l_float[iS43{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x42{blockDim.x}_p, iUS44{1}, iV40{2}, bS13{1}] ca_pos( 3 ) )
T9_l_float[rS30{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}rf, ithreadIdx.x29{blockDim.x}rf_p, rUS31{1}rf, rS27{2}rf] produce_pos( 4 )
   = reduction( T2_l_float[iS37{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x36{blockDim.x}_p, iUS38{1}, iS34{2}] ca_pos( 4 ) produce_pos( 3 ), op = add, initial value = float(0), allreduce = false )
T7_l_float[rthreadIdx.x32{blockDim.x}_p]
   = reduction( T9_l_float[rS30{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}rf, ithreadIdx.x29{blockDim.x}rf_p, rUS31{1}rf, rS27{2}rf] produce_pos( 4 ), op = add, initial value = float(0), allreduce = false )
T3_g_float[]
   = Set( T7_l_float[rthreadIdx.x32{blockDim.x}_p], cache_op=Streaming )
T6_l_float[iS61{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x60{blockDim.x}_p, iUS62{1}, iV58{2}, iS15{i4}] ca_pos( 3 )
   = Set( T1_g_float[iS67{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, iS66{blockDim.x}, iS68{1}, iS64{2}, iS2{i4}], cache_op=Streaming )
T8_l_float[iS55{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x54{blockDim.x}_p, iUS56{1}, iS52{2}, iS6{i4}] ca_pos( 3 ) produce_pos( 3 )
   = T5_l_float[iS43{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x42{blockDim.x}_p, iUS44{1}, iV40{2}, bS13{1}] ca_pos( 3 )
   + T6_l_float[iS61{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x60{blockDim.x}_p, iUS62{1}, iV58{2}, iS15{i4}] ca_pos( 3 );
T4_g_float[iS73{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x72{blockDim.x}_p, iUS74{1}, iV70{2}, iS17{i4}] ca_pos( 3 ) produce_pos( 3 )
   = Set( T8_l_float[iS55{( ceilDiv(( ceilDiv(i1, 2) ), blockDim.x) )}, ithreadIdx.x54{blockDim.x}_p, iUS56{1}, iS52{2}, iS6{i4}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
} // %kernel_math

As expected, because of the dangling IDs, T6 has an unscheduled ID, iS15, and as a result, its execution fails with:

C++ exception with description " INTERNAL ASSERT FAILED at "/home/nmaruyama/nvfuser/debug2/csrc/runtime/compiled_kernel.cpp":1320, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Allocations must be based on constant integers for local memory. However, found: T6_l_float[iS61{( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 2) ), blockDim.x) )}, ithreadIdx.x60{blockDim.x}_p, iUS62{1}, iV58{2}, iS81{( (( (( getMetaData(T1) )).logical_size ))[0] )}] ca_pos( 3 ), T8_l_float[iS55{( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 2) ), blockDim.x) )}, ithreadIdx.x54{blockDim.x}_p, iUS56{1}, iS52{2}, iS83{( (( (( getMetaData(T1) )).logical_size ))[0] )}] ca_pos( 3 ) produce_pos( 3 ), T6_l_float[iS61{( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 2) ), blockDim.x) )}, ithreadIdx.x60{blockDim.x}_p, iUS62{1}, iV58{2}, iS81{( (( (( getMetaData(T1) )).logical_size ))[0] )}] ca_pos( 3 ), T8_l_float[iS55{( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 2) ), blockDim.x) )}, ithreadIdx.x54{blockDim.x}_p, iUS56{1}, iS52{2}, iS83{( (( (( getMetaData(T1) )).logical_size ))[0] )}] ca_pos( 3 ) produce_pos( 3 ),  have dynamic allocations but are placed in local memory.
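
To make the link between the unscheduled ID and this assertion concrete, here is a rough sketch of how a local-memory allocation size could be derived from a scheduled loop domain. It is a simplified model, not nvFuser's actual allocation logic: ToyIterDomain, toyLocalAllocSize and toyT6Scheduled are hypothetical names, and the rule used (multiply the extents of the non-parallelized domains at or to the right of ca_pos) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Toy stand-in for an iteration domain; this is not an nvFuser type.
struct ToyIterDomain {
  // Extent if it is a compile-time constant, std::nullopt if symbolic (e.g. i4).
  std::optional<int64_t> const_extent;
  // True if bound to threadIdx/blockIdx, which needs no per-thread storage.
  bool parallelized = false;
};

// Assumed rule: the local buffer covers the non-parallelized domains at or to
// the right of ca_pos. Returns std::nullopt if any such extent is symbolic,
// i.e. the size is not a constant integer.
std::optional<int64_t> toyLocalAllocSize(
    const std::vector<ToyIterDomain>& loop_domain,
    std::size_t ca_pos) {
  assert(ca_pos <= loop_domain.size());
  int64_t size = 1;
  for (std::size_t i = ca_pos; i < loop_domain.size(); ++i) {
    const ToyIterDomain& id = loop_domain[i];
    if (id.parallelized) {
      continue;
    }
    if (!id.const_extent.has_value()) {
      return std::nullopt;
    }
    size *= *id.const_extent;
  }
  return size;
}

// T6 as scheduled above:
// [iS61, ithreadIdx.x60, iUS62{1}, iV58{2}, iS15{i4}] ca_pos( 3 )
std::vector<ToyIterDomain> toyT6Scheduled() {
  return {
      ToyIterDomain{std::nullopt, false}, // iS61, symbolic extent
      ToyIterDomain{std::nullopt, true}, // ithreadIdx.x60, parallelized
      ToyIterDomain{1, false}, // iUS62{1}
      ToyIterDomain{2, false}, // iV58{2}
      ToyIterDomain{std::nullopt, false}, // iS15{i4}, the dangling ID
  };
}
```

Under this model, toyLocalAllocSize(toyT6Scheduled(), 3) has no constant value because the dangling iS15{i4} sits to the right of ca_pos. Without that trailing ID, the size would be the constant 2 from the vectorized domain.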

The problem seems to be in canScheduleCompileTime of the reduction scheduler. It should detect such dangling IDs and reject any fusion that has them. It does have some related checks, such as hasPostReductionBCast, but they are not sufficient.
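
A minimal sketch of what such a dangling-ID check could look like, written against a toy model rather than the nvFuser API. ToyDisjointSets, toyFindDanglingIds and toyDanglingIdsForThisIssue are hypothetical names. IDs are plain integers taken from the numeric suffixes in the printMath output above, and the mapped pairs are my reading of that output. The broadcast ID bS0 is left out on the assumption that broadcast IDs do not need to be covered by the reference.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Minimal union-find over integer IDs; a stand-in for an ID mapping graph.
class ToyDisjointSets {
 public:
  explicit ToyDisjointSets(int n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }

  int find(int x) {
    assert(x >= 0 && x < static_cast<int>(parent_.size()));
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]]; // path halving
      x = parent_[x];
    }
    return x;
  }

  void unite(int a, int b) {
    parent_[find(a)] = find(b);
  }

 private:
  std::vector<int> parent_;
};

// Returns the IDs in ids_to_cover that are not mapped, directly or
// transitively, to any ID of the reference tensor.
std::vector<int> toyFindDanglingIds(
    int num_ids,
    const std::vector<std::pair<int, int>>& mapped_pairs,
    const std::vector<int>& reference_ids,
    const std::vector<int>& ids_to_cover) {
  ToyDisjointSets sets(num_ids);
  for (const auto& p : mapped_pairs) {
    sets.unite(p.first, p.second);
  }
  std::vector<bool> covered(num_ids, false);
  for (int id : reference_ids) {
    covered[sets.find(id)] = true;
  }
  std::vector<int> dangling;
  for (int id : ids_to_cover) {
    if (!covered[sets.find(id)]) {
      dangling.push_back(id);
    }
  }
  return dangling;
}

// The fusion from this issue, with T3 (rS5) as the reference.
std::vector<int> toyDanglingIdsForThisIssue() {
  const std::vector<std::pair<int, int>> mapped_pairs = {
      {1, 4}, // iS1 -> iS4 (squeeze)
      {4, 5}, // iS4 -> rS5 (reduction)
      {1, 7}, // iS1 -> iS7 (add)
      {3, 7}, // iS3 -> iS7 (add)
      {2, 6}, // iS2 -> iS6 (add)
  };
  return toyFindDanglingIds(
      /*num_ids=*/8, mapped_pairs, /*reference_ids=*/{5},
      /*ids_to_cover=*/{1, 2, 3, 4, 5, 6, 7});
}
```

For this fusion the sketch reports iS2 and iS6 as dangling, which matches the observation that T3 has no ID connected with them. A compile-time check along these lines could reject the fusion whenever the returned list is non-empty, leaving it to segmentation.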

jjsjann123 self-assigned this on Feb 3, 2025
naoyam added the bug label on Feb 3, 2025