
TorchInductor stride mismatch when compiling cuequivariance custom op #223

@danielkovtun

Description

Describe the bug
When enabling torch.compile on a model that uses cuequivariance triangle attention ops, TorchInductor generates code that asserts an incorrect stride layout for a tensor at runtime, causing a hard crash during the first compiled forward pass.

The failure occurs after Dynamo tracing succeeds and Inductor code is generated, during execution of the compiled model. The error indicates that Inductor expects a non-contiguous stride layout for the output of a custom CUDA op (cuequivariance.triangle_attention_mask), but the runtime tensor is actually contiguous.

This suggests an incorrect fake / meta kernel definition for the custom op, leading Inductor to specialize on invalid stride metadata.
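To illustrate the suspected failure mode, here is a minimal stand-in (not cuequivariance's actual registration, whose signature I don't know): a fake kernel that builds its output as a transposed view of a dense buffer advertises exactly the kind of stride pattern seen below, even though the real kernel returns a contiguous tensor.

import torch

# Stand-in custom op: the "real" implementation returns a contiguous tensor.
@torch.library.custom_op("mylib::triangle_attention_mask", mutates_args=())
def triangle_attention_mask(q: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q, memory_format=torch.contiguous_format)

# Hypothetical buggy fake/meta kernel: it builds the output as a transposed view
# of a (B, S1, S2, H, D) buffer, so for (1, 384, 4, 384, 32) it advertises strides
# (18874368, 49152, 32, 128, 1) instead of the contiguous
# (18874368, 49152, 12288, 32, 1). Inductor trusts this metadata.
@triangle_attention_mask.register_fake
def _(q):
    B, S1, H, S2, D = q.shape
    return q.new_empty((B, S1, S2, H, D)).transpose(2, 3)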

To Reproduce
Steps to reproduce the behavior:

  1. Remove @torch.compiler.disable from, e.g., triangle attention and triangle multiplication
  2. Run the training script with compilation of, e.g., PairformerModule enabled (model.compile_pairformer=true)
  3. Dynamo config used (see the sketch after this list):
compiled_autograd=True
optimize_ddp=True
capture_scalar_outputs=True
  4. Observe the crash during the first compiled execution
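For reference, the setup corresponds roughly to the sketch below; the Dynamo flags are the ones listed above, while the compile call on the pairformer stack is only illustrative of what model.compile_pairformer=true does in our codebase (the attribute name is hypothetical).

import torch
import torch._dynamo

# Dynamo / compiled-autograd configuration used for the failing run
torch._dynamo.config.compiled_autograd = True
torch._dynamo.config.optimize_ddp = True
torch._dynamo.config.capture_scalar_outputs = True

def compile_pairformer(model: torch.nn.Module) -> torch.nn.Module:
    # Illustrative stand-in for model.compile_pairformer=true: compile only
    # the pairformer stack with Inductor.
    model.pairformer = torch.compile(model.pairformer, backend="inductor")
    return model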

Observed error:

AssertionError: expected size 4==4, stride 12288==32 at dim=2;
expected size 384==384, stride 32==128 at dim=3

This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
Use torch.library.opcheck to test your custom op.
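Following that hint, the fake kernel can be cross-checked against the real kernel with torch.library.opcheck. A minimal sketch against the stand-in op above (a real check would need cuequivariance's actual argument list, which I am not reproducing here):

import torch
from torch.library import opcheck

# With a fake kernel that advertises the wrong strides, opcheck is expected to
# flag the mismatch between fake and real output metadata.
q = torch.randn(1, 384, 4, 384, 32, device="cuda", dtype=torch.bfloat16)
opcheck(torch.ops.mylib.triangle_attention_mask.default, (q,))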

The crash occurs in TorchInductor-generated code:

assert_size_stride(buf63, (1, 384, 4, 384, 32), (18874368, 49152, 32, 128, 1))

Expected behavior
torch.compile should either:

  • Correctly infer and specialize on the actual runtime strides of the tensor returned by cuequivariance.triangle_attention_mask, or
  • Reject the custom op during tracing with a clear error if its meta kernel does not faithfully represent runtime behavior.

In particular, Inductor should not hard-assert on incorrect stride metadata that contradicts the real tensor layout.
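If the root cause is indeed the fake kernel, the fix on the op side would presumably be to advertise the layout the CUDA kernel actually produces. In terms of the stand-in op above, something like:

# Sketch of a faithful fake kernel (replacing the transposed one above): same
# shape, dtype, and device, with standard contiguous strides matching the runtime output.
@triangle_attention_mask.register_fake
def _(q):
    return torch.empty_like(q, memory_format=torch.contiguous_format)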

GPU HW/SW:

  • GPU: NVIDIA H100 80GB HBM3
  • Driver version: 580.95.05
  • nvidia-smi CUDA version: 13.0 (driver capability)
  • CUDA toolkit version: 12.6, from pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime
  • torch / NGC docker version: 2.7.0+cu126 (pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime)
  • triton version: 3.3.0
  • cuequivariance: 0.8.0
    (The driver supports CUDA ≥ 13; the runtime/toolkit inside the container is CUDA 12.6. Note: this was also observed prior to driver 13.0.)

Additional context
For a tensor of shape:

(1, 384, 4, 384, 32)

The actual contiguous strides at runtime are:

(18874368, 49152, 12288, 32, 1)

This matches the standard contiguous layout.
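A quick check in plain PyTorch confirms the arithmetic:

import torch

t = torch.empty(1, 384, 4, 384, 32)
print(t.stride())         # (18874368, 49152, 12288, 32, 1)
print(t.is_contiguous())  # True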

Inductor instead (incorrectly) expects:

(18874368, 49152, 32, 128, 1)

This implies:

  • dim=2 stride = 32 (should be 12288)
  • dim=3 stride = 128 (should be 32)

This stride pattern is not contiguous; it corresponds to a dense layout with dims 2 and 3 transposed in memory (see below), which does not match the op's actual output.
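For comparison, the strides Inductor expects are what a contiguous (1, 384, 384, 4, 32) buffer looks like after swapping dims 2 and 3, which is not how the op's output is actually laid out:

import torch

u = torch.empty(1, 384, 384, 4, 32).transpose(2, 3)
print(u.shape)    # torch.Size([1, 384, 4, 384, 32])
print(u.stride()) # (18874368, 49152, 32, 128, 1)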

From Inductor output:

buf62 = torch.ops.cuequivariance.triangle_attention_mask.default(
    reinterpret_tensor(buf50, (1, 384, 4, 384, 32), (18874368, 49152, 32, 128, 1), 0),
    ...
)
buf63 = buf62
assert_size_stride(buf63, (1, 384, 4, 384, 32), (18874368, 49152, 32, 128, 1))

From the post-grad graph:

triangle_attention_mask:
"bf16[1, 384, 4, 384, 32][18874368, 49152, 32, 128, 1]cuda:0"

This strongly suggests that the fake / meta kernel for triangle_attention_mask advertises incorrect stride metadata, which Inductor then trusts. Note: this occurs under both bf16 mixed precision and fp32.

  • Dynamo tracing and Inductor code generation both succeed; failure occurs only at runtime.
  • The error message explicitly points to an incorrect fake/meta kernel, consistent with the observed behavior.
