Skip to content

IdModel assert during buildLoopGraph `nvfuser::LoopPromotionMapBuilder::findPromotionOfLoopGroup` #5391

@jjsjann123

Description

@jjsjann123

Hit the following error message:

Error from segmentation group 1:  INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1170, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
Expected covered_it != exact_covered_ids.end() . No covered group info for idg{38}

With backtrace:

(gdb) bt
#0  nvfuser::nvfCheckFail (func=0xaaaaac1d58b0 "findPromotionOfLoopGroup",
    file=0xaaaaac1d4868 "/opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp", line=1170,
    msg=" INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1170, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. \nExpected covered_it "...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:267
#1  0x0000aaaaab1e08b4 in nvfuser::nvfErrorFail (func=0xaaaaac1d58b0 "findPromotionOfLoopGroup",
    file=0xaaaaac1d4868 "/opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp", line=1170,
    condMsg=0xaaaaac1d57f8 " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1170, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ",
    userMsg="Expected covered_it != exact_covered_ids.end() . No covered group info for idg{38}") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:277
#2  0x0000aaaaab3e2b48 in nvfuser::LoopPromotionMapBuilder::findPromotionOfLoopGroup (this=0xfffcd5db97b0,
    loop_group=std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Val*, std::hash<nvfuser::Val*> >> (use count 4, weak count 0) = {...},
    iel_graph=..., iel_promotion_map=std::unordered_map with 6 elements = {...}, exact_covered_ids=std::unordered_map with 22 elements = {...},
    terminal_loop_ids=...) at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1170
#3  0x0000aaaaab3e26e8 in nvfuser::LoopPromotionMapBuilder::projectIELPromotionToLoopGraph (this=0xfffcd5db97b0, iel_graph=...,
    iel_promotion_map=std::unordered_map with 6 elements = {...}, loop_graph=..., inlining_info=...)
    at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1094
#4  0x0000aaaaab3dfc24 in nvfuser::LoopPromotionMapBuilder::build (this=0xfffcd5db97b0)
    at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:485
#5  0x0000aaaaab3e356c in nvfuser::LoopPromotionMapBuilder::get (id_model=..., inlining_info=..., callback=0x0,
    force_full_loop_promotion_analysis=false) at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1274
#6  0x0000aaaaab39a6cc in nvfuser::IdModel::buildLoopGraph (this=0xfffc4c0227f0, force_full_loop_promotion_analysis=false)
    at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:1011
#7  0x0000aaaaab39a8f0 in nvfuser::IdModel::buildAllGraphs (this=0xfffc4c0227f0) at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:1059
#8  0x0000aaaaab395154 in nvfuser::IdModel::IdModel (this=0xfffc4c0227f0, fusion=0xfffc4c0514f0, build_graphs=true, allow_self_mapping=false,
    validate=false, loop_promotion_map_builder_callback=0x0) at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:152
#9  0x0000aaaaaae720e8 in std::make_unique<nvfuser::IdModel, nvfuser::Fusion*&, bool, bool, bool> () at /usr/include/c++/13/bits/unique_ptr.h:1070
#10 0x0000aaaaaae67700 in nvfuser::GpuLower::analysis (this=0xfffc4c048180, fusion=0xfffc4c000e30)
    at /opt/pytorch/nvfuser/csrc/device_lower/lower2device.cpp:466
#11 0x0000aaaaaae66274 in nvfuser::GpuLower::GpuLower (this=0xfffc4c048180, fusion=0xfffc4c000e30, cparams=...)
    at /opt/pytorch/nvfuser/csrc/device_lower/lower2device.cpp:247
#12 0x0000aaaaab8dad9c in std::make_unique<nvfuser::GpuLower, nvfuser::Fusion*&, nvfuser::CompileParams&> ()
    at /usr/include/c++/13/bits/unique_ptr.h:1070
#13 0x0000aaaaab8ce1ac in nvfuser::CompiledKernel::CompiledKernel(nvfuser::Fusion*, nvfuser::CompileParams, c10::Device, nvfuser::SchedulerType, long, long, long, long, std::vector<std::function<void (nvfuser::GpuLower*)>, std::allocator<std::function<void (nvfuser::GpuLower*)> > > const&, std::vector<std::function<void (nvfuser::kir::Kernel*)>, std::allocator<std::function<void (nvfuser::kir::Kernel*)> > > const&) (this=0xfffc4c05b420,
    fusion=0xfffc4c000e30, compile_params=..., device=..., scheduler_type=nvfuser::SchedulerType::PointWise, fusion_id=0, concrete_id=1,
    runtime_id=0, group_id=1, pre_lowering_hooks=std::vector of length 0, capacity 0, post_lowering_hooks=std::vector of length 0, capacity 0)
    at /opt/pytorch/nvfuser/csrc/runtime/compiled_kernel.cpp:1248
#14 0x0000aaaaab914b0c in std::make_unique<nvfuser::CompiledKernel, nvfuser::Fusion*&, nvfuser::CompileParams&, c10::Device&, nvfuser::SchedulerType&, long&, long&, long&, long&, std::vector<std::function<void (nvfuser::GpuLower*)>, std::allocator<std::function<void (nvfuser::GpuLower*)> > >&, std::vector<std::function<void (nvfuser::kir::Kernel*)>, std::allocator<std::function<void (nvfuser::kir::Kernel*)> > >&>(nvfuser::Fusion*&, nvfuser::CompileParams&, c10::Device&, nvfuser::SchedulerType&, long&, long&, long&, long&, std::vector<std::function<void (nvfuser::GpuLower*)>, std::allocator<std::function<void (nvfuser::GpuLower*)> > >&, std::vector<std::function<void (nvfuser::kir::Kernel*)>, std::allocator<std::function<void (nvfuser::kir::Kernel*)> > >&) () at /usr/include/c++/13/bits/unique_ptr.h:1070
#15 0x0000aaaaab9036fc in nvfuser::KernelExecutor::compile (this=0xfffc4c01fc10, fusion=0xfffc4c000e30, args=..., launch_constraints=...,
    compile_params=..., scheduler_type=nvfuser::SchedulerType::PointWise) at /opt/pytorch/nvfuser/csrc/runtime/executor.cpp:241
#16 0x0000aaaaab9258e8 in nvfuser::ExecutorDispatch::compile (executor=0xfffc4c01fc10, fusion=0xfffc4c000e30, args=..., params=0xaaaab2c66e30)
    at /opt/pytorch/nvfuser/csrc/runtime/executor_dispatch.cpp:109
#17 0x0000aaaaab993eac in nvfuser::FusionKernelRuntime::compileKernel (this=0xaaaab17340f0, args=..., sg=0xaaaab2c67ef0)
    at /opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp:777
#18 0x0000aaaaab991b7c in operator() (__closure=0xfffc4c000b70) at /opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp:433
#19 0x0000aaaaab995598 in std::__invoke_impl<void, nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder)::<lambda()>&>(std::__invoke_other, struct {...} &) (__f=...) at /usr/include/c++/13/bits/invoke.h:61
#20 0x0000aaaaab9950e4 in std::__invoke_r<void, nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder)::<lambda()>&>(struct {...} &) (__fn=...) at /usr/include/c++/13/bits/invoke.h:111
#21 0x0000aaaaab994e50 in std::_Function_handler<void(), nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder)::<lambda()> >::_M_invoke(const std::_Any_data &) (__functor=...) at /usr/include/c++/13/bits/std_function.h:290

Some debug info:

(gdb) f 2
#2  0x0000aaaaab3e2b48 in nvfuser::LoopPromotionMapBuilder::findPromotionOfLoopGroup (this=0xfffcd5db97b0, 
    loop_group=std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Val*, std::hash<nvfuser::Val*> >> (use count 4, weak count 0) = {...}, 
    iel_graph=..., iel_promotion_map=std::unordered_map with 6 elements = {...}, exact_covered_ids=std::unordered_map with 22 elements = {...}, 
    terminal_loop_ids=...) at /opt/pytorch/nvfuser/csrc/id_model/loop_promotion.cpp:1170
1170        NVF_ERROR(
(gdb) p exact_group->vector()
$1 = std::vector of length 1, capacity 1 = {0xfffc4c069470}
(gdb) p exact_group->vector()[0]->toString(0)
$2 = "?S38{2429}"

The fusion looks like:

(gdb) f 11
#11 0x0000aaaaaae66274 in nvfuser::GpuLower::GpuLower (this=0xfffc4c048180, fusion=0xfffc4c000e30, cparams=...)
    at /opt/pytorch/nvfuser/csrc/device_lower/lower2device.cpp:247
247       analysis(fusion);
(gdb) p fusion->printMath(1)
Inputs:
  T1_g___bfloat[iS147{( ceilDiv(40, blockDim.x) )}, iS149{2048}, iS150{1}, iS146{8}, iS148{blockDim.x}]
  T0_g___bfloat[iS117{( ceilDiv(640, blockDim.x) )}, iS119{2048}, iS120{1}, iS116{8}, iS118{blockDim.x}]
  T2_g_int[iS5{3}]
  T3_g_int[iS6{3}]
Outputs:
  T14_g___e4m3[iblockIdx.x159{( ceilDiv(40, blockDim.x) )}, iblockIdx.y161{2048}, iUS162{1}, iS158{8}, ithreadIdx.x160{blockDim.x}] ca_pos( 5 )
  T15_g___e2m1[iblockIdx.x61{( ceilDiv(640, blockDim.x) )}, iblockIdx.y63{2048}, iUS64{1}, iV60{8}, ithreadIdx.x62{blockDim.x}] ca_pos( 3 ) produce_
pos( 3 )

%kernel_math {
T16_l___bfloat[iblockIdx.x141{( ceilDiv(40, blockDim.x) )}, iblockIdx.y143{2048}, iUS144{1}, iS140{8}, ithreadIdx.x142{blockDim.x}] ca_pos( 3 )
   = Set( T1_g___bfloat[iS147{( ceilDiv(40, blockDim.x) )}, iS149{2048}, iS150{1}, iS146{8}, iS148{blockDim.x}], cache_op=AllLevels )
T5_l_float[iblockIdx.x135{( ceilDiv(40, blockDim.x) )}, iblockIdx.y137{2048}, iUS138{1}, iS134{8}, ithreadIdx.x136{blockDim.x}] ca_pos( 5 ) produce_
pos( 3 )
   = __bfloat2float(T16_l___bfloat[iblockIdx.x141{( ceilDiv(40, blockDim.x) )}, iblockIdx.y143{2048}, iUS144{1}, iS140{8}, ithreadIdx.x142{blockDim.
x}] ca_pos( 3 ));
T6_l_float[iblockIdx.x129{( ceilDiv(40, blockDim.x) )}, iblockIdx.y131{2048}, iUS132{1}, iS128{8}, ithreadIdx.x130{blockDim.x}] ca_pos( 5 ) produce_
pos( 5 )
   = T5_l_float[iblockIdx.x135{( ceilDiv(40, blockDim.x) )}, iblockIdx.y137{2048}, iUS138{1}, iS134{8}, ithreadIdx.x136{blockDim.x}] ca_pos( 5 ) pro
duce_pos( 3 )
   / double(6);
T7_l_float[iblockIdx.x123{( ceilDiv(40, blockDim.x) )}, iblockIdx.y125{2048}, iUS126{1}, iS122{8}, ithreadIdx.x124{blockDim.x}] ca_pos( 5 ) produce_
pos( 5 )
   = clamp(T6_l_float[iblockIdx.x129{( ceilDiv(40, blockDim.x) )}, iblockIdx.y131{2048}, iUS132{1}, iS128{8}, ithreadIdx.x130{blockDim.x}] ca_pos( 5
 ) produce_pos( 5 )
  , double(0.015625)
  , double(448));
T13_l___e4m3[iblockIdx.x153{( ceilDiv(40, blockDim.x) )}, iblockIdx.y155{2048}, iUS156{1}, iS152{8}, ithreadIdx.x154{blockDim.x}] produce_pos( 5 )
   = __float2e4m3(T7_l_float[iblockIdx.x123{( ceilDiv(40, blockDim.x) )}, iblockIdx.y125{2048}, iUS126{1}, iS122{8}, ithreadIdx.x124{blockDim.x}] ca
_pos( 5 ) produce_pos( 5 ));
T14_g___e4m3[iblockIdx.x159{( ceilDiv(40, blockDim.x) )}, iblockIdx.y161{2048}, iUS162{1}, iS158{8}, ithreadIdx.x160{blockDim.x}] ca_pos( 5 )
   = preprocessGroupedMatmulInputSf(
    input = T13_l___e4m3[iblockIdx.x153{( ceilDiv(40, blockDim.x) )}, iblockIdx.y155{2048}, iUS156{1}, iS152{8}, ithreadIdx.x154{blockDim.x}] produc
e_pos( 5 ),
    input_offsets = T2_g_int[iS5{3}],
    output_offsets = T3_g_int[iS6{3}],
    layout = Block128x4
  )
T17_l___bfloat[iblockIdx.x110{( ceilDiv(640, blockDim.x) )}, iblockIdx.y112{2048}, iUS113{1}, iV109{8}, ithreadIdx.x111{blockDim.x}] ca_pos( 3 )
   = Set( T0_g___bfloat[iS117{( ceilDiv(640, blockDim.x) )}, iS119{2048}, iS120{1}, iS116{8}, iS118{blockDim.x}], cache_op=Streaming )
T4_l_float[iblockIdx.x103{( ceilDiv(640, blockDim.x) )}, iblockIdx.y105{2048}, iUS106{1}, iS102{8}, ithreadIdx.x104{blockDim.x}] ca_pos( 5 ) produce
_pos( 3 )
   = __bfloat2float(T17_l___bfloat[iblockIdx.x110{( ceilDiv(640, blockDim.x) )}, iblockIdx.y112{2048}, iUS113{1}, iV109{8}, ithreadIdx.x111{blockDim
.x}] ca_pos( 3 ));
T8_l_float[iblockIdx.x96{( ceilDiv(40, blockDim.x) )}, iblockIdx.y98{2048}, iUS99{1}, iS95{8}, ithreadIdx.x97{blockDim.x}] ca_pos( 5 ) produce_pos( 
5 )
   = broadcast( T7_l_float[iblockIdx.x123{( ceilDiv(40, blockDim.x) )}, iblockIdx.y125{2048}, iUS126{1}, iS122{8}, ithreadIdx.x124{blockDim.x}] ca_p
os( 5 ) produce_pos( 5 ), flags = {false, false, true} )
T9_l_float[iblockIdx.x89{( ceilDiv(640, blockDim.x) )}, iblockIdx.y91{2048}, iUS92{1}, iS88{8}, ithreadIdx.x90{blockDim.x}] ca_pos( 5 ) produce_pos(
 5 )
   = T4_l_float[iblockIdx.x103{( ceilDiv(640, blockDim.x) )}, iblockIdx.y105{2048}, iUS106{1}, iS102{8}, ithreadIdx.x104{blockDim.x}] ca_pos( 5 ) pr
oduce_pos( 3 )
   / T8_l_float[iblockIdx.x96{( ceilDiv(40, blockDim.x) )}, iblockIdx.y98{2048}, iUS99{1}, iS95{8}, ithreadIdx.x97{blockDim.x}] ca_pos( 5 ) produce_
pos( 5 );
T10_l_float[iblockIdx.x82{( ceilDiv(640, blockDim.x) )}, iblockIdx.y84{2048}, iUS85{1}, iS81{8}, ithreadIdx.x83{blockDim.x}] ca_pos( 3 ) produce_pos
( 5 )
   = clamp(T9_l_float[iblockIdx.x89{( ceilDiv(640, blockDim.x) )}, iblockIdx.y91{2048}, iUS92{1}, iS88{8}, ithreadIdx.x90{blockDim.x}] ca_pos( 5 ) p
roduce_pos( 5 )
  , double(-448)
  , double(448));
T11_l___e2m1[iblockIdx.x75{( ceilDiv(640, blockDim.x) )}, iblockIdx.y77{2048}, iUS78{1}, iV74{8}, ithreadIdx.x76{blockDim.x}] ca_pos( 3 ) produce_po
s( 3 )
   = __float2e2m1(T10_l_float[iblockIdx.x82{( ceilDiv(640, blockDim.x) )}, iblockIdx.y84{2048}, iUS85{1}, iS81{8}, ithreadIdx.x83{blockDim.x}] ca_po
s( 3 ) produce_pos( 5 ));
T18_l___e2m1[iblockIdx.x68{( ceilDiv(640, blockDim.x) )}, iblockIdx.y70{2048}, iUS71{1}, iS67{8}, ithreadIdx.x69{blockDim.x}] ca_pos( 3 ) produce_po
s( 3 )
   = SegmenterSet( T11_l___e2m1[iblockIdx.x75{( ceilDiv(640, blockDim.x) )}, iblockIdx.y77{2048}, iUS78{1}, iV74{8}, ithreadIdx.x76{blockDim.x}] ca_
pos( 3 ) produce_pos( 3 ) )
T15_g___e2m1[iblockIdx.x61{( ceilDiv(640, blockDim.x) )}, iblockIdx.y63{2048}, iUS64{1}, iV60{8}, ithreadIdx.x62{blockDim.x}] ca_pos( 3 ) produce_po
s( 3 )
   = Set( T18_l___e2m1[iblockIdx.x68{( ceilDiv(640, blockDim.x) )}, iblockIdx.y70{2048}, iUS71{1}, iS67{8}, ithreadIdx.x69{blockDim.x}] ca_pos( 3 ) 
produce_pos( 3 ), cache_op=Streaming )
} // %kernel_math

$3 = void

The ID causing the issue (`?S38{2429}`) comes from the allocation domain of the layout op's output:

T14_g___e4m3[iblockIdx.x159{( ceilDiv(40, blockDim.x) )}, iblockIdx.y161{2048}, iUS162{1}, iS158{8}, ithreadIdx.x160{blockDim.x}] ca_pos( 5 )
 logical domain : (iS36{2048}, iS37{320})
 allocation domain : (?S38{2429}, ?S39{( ( 323 / 4 ) * 4 )})
 contiguity: t t
  Split: iS37{320} by factor 8 -> iS157{40}, iS158{8}
  Split: iS157{40} by factor blockDim.x -> iblockIdx.x159{( ceilDiv(40, blockDim.x) )}, ithreadIdx.x160{blockDim.x}
  Split: iS36{2048} by factor 1 -> iblockIdx.y161{2048}, iUS162{1}
 loop domain : (iblockIdx.x159{( ceilDiv(40, blockDim.x) )}, iblockIdx.y161{2048}, iUS162{1}, iS158{8}, ithreadIdx.x160{blockDim.x})

The repro can be found below (the original link was lost in the text export).

For convenience, here's the C++ fusion:


  auto inp = makeContigConcreteTensor({2048, 320, 16}, DataType::BFloat16);
  auto sf = makeContigConcreteTensor({2048, 320}, DataType::BFloat16);
  // FIXME: this should be 128. i.e. number of groups, fix Masaki's scripts
  // later.
  auto in_offset = makeContigConcreteTensor({3}, DataType::Int32);
  auto out_offset = makeContigConcreteTensor({3}, DataType::Int32);

  fusion.addInput(inp);
  fusion.addInput(sf);
  fusion.addInput(in_offset);
  fusion.addInput(out_offset);

  auto max_fp4 = IrBuilder::create<Val>(6, DataType::Double);
  auto max_fp8 = IrBuilder::create<Val>(448, DataType::Double);
  auto min_fp8 = IrBuilder::create<Val>(-448, DataType::Double);
  auto eps = IrBuilder::create<Val>(0.015625, DataType::Double);

  auto T81 = castOp(DataType::Float, inp);
  auto T77 = castOp(DataType::Float, sf);
  auto T78 = div(T77, max_fp4);
  auto T79 = clamp(T78, eps, max_fp8);
  auto T80 = broadcast(T79, {false, false, true});
  auto T82 = div(T81, T80);
  auto T83 = clamp(T82, min_fp8, max_fp8);
  auto T146 = castOp(DataType::Float4_e2m1fn, T83);
  auto T155 = reshape(T146, {2048, 320, 16}, {2048, 320 * 16});

  auto T86 = castOp(DataType::Float8_e4m3fn, T79);
  auto T87 = preprocessGroupedMatmulInputSf(
      T86, in_offset, out_offset, BlockScalingFactorLayout::Block128x4);

  fusion.addOutput(T155);
  fusion.addOutput(T87);

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions