
Bug: cuEquivariance-jax does not support B200 / B300 | cuda12 #209

@guyujun

Description


The triangle_multiplicative_update Triton kernel does not work with Haiku & JAX on B200/B300 NVIDIA GPUs (but H200, H100, H800, 5090, 4090, and PRO6000 all work!).

My test code:

import jax
from cuequivariance_jax import triangle_multiplicative_update
import haiku as hk
import jax.random as jrandom

class TritonNetwork(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)

    def triangle_multiplicative_update_block(self, pair_act):
        # pair_act has shape (batch, L, L, channels)
        assert len(pair_act.shape) == 4
        key = jax.random.key(0)  # fixed key, just for this repro

        pair_act = triangle_multiplicative_update(
            x=pair_act,
            direction='outgoing',   # 'outgoing' or 'incoming'
            key=key
        )
        return pair_act

    def __call__(self, batch):
        act = self.triangle_multiplicative_update_block(batch['feat'])
        return act


def forward_triton(batch):
    network = TritonNetwork()
    return network(batch)

forward = hk.transform(forward_triton)
L = 256
feat_shape = (1, L, L, 128)

key = jrandom.PRNGKey(0)
feat_batch = {'feat': jrandom.normal(key, feat_shape)}
print('step1:')
params = forward.init(key, feat_batch)

print('step2:')
for i in range(10):  # repeated applies; this is where it dies/hangs on B200/B300
    key, apply_key = jax.random.split(key, 2)
    output = forward.apply(params, apply_key, feat_batch)
    # print(output)
    print(output.shape)

The code dies and gets stuck at random points.

However, falling back to the JAX-based triangle_multiplicative_update works.

I did this by modifying the fallback option inside triangle_multiplicative_update in cuequivariance_jax/triangle/triangle_multiplicative_update.py:

# Gated dual gemm
ab = sigmoid_gated_dual_gemm(
    x,
    g_in_weight,
    p_in_weight,
    b1=g_in_bias,
    b2=p_in_bias,
    mask=mask,
    transpose_out=True,
    precision=precision,
    fallback=False,   # <- set this line to False to use the JAX-based code
)
a, b = jnp.split(ab, 2, axis=0)
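For others hitting this, the same override can probably be applied at runtime without editing the installed file, via a monkeypatch. Only the module path and the fallback keyword come from the snippet above; treat the rest as a sketch, not a supported API, and apply it before the first forward.init / forward.apply call:

# Hypothetical runtime patch: forces the fallback flag without editing
# site-packages. Flip the value if the semantics differ on your install.
import cuequivariance_jax.triangle.triangle_multiplicative_update as tmu

_orig_gemm = tmu.sigmoid_gated_dual_gemm

def _patched_gemm(*args, **kwargs):
    kwargs["fallback"] = False  # the value that selected the JAX path above
    return _orig_gemm(*args, **kwargs)

tmu.sigmoid_gated_dual_gemm = _patched_gemm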

So I think something in the sigmoid_gated_dual_gemm kernel does not work on B200/B300 NVIDIA GPUs,

but I can't give more details, because I don't have the source code of the kernel.
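To help isolate it, a minimal direct call to sigmoid_gated_dual_gemm might reproduce the hang without Haiku in the loop. The import path and all shapes below are my guesses (I only know the argument names from the call site above), so treat this purely as a sketch:

import jax
from cuequivariance_jax import sigmoid_gated_dual_gemm  # import path is an assumption

# Hypothetical shapes, loosely following the triangle-update call above;
# every dimension here is a guess, not taken from the kernel source.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (256 * 256, 128))        # flattened (L*L, channels) pair activations
g_in_weight = jax.random.normal(k2, (128, 128))
p_in_weight = jax.random.normal(k3, (128, 128))

# Toggle the fallback kwarg (as in the diff above) to compare the Triton
# and JAX code paths in isolation.
ab = sigmoid_gated_dual_gemm(x, g_in_weight, p_in_weight, transpose_out=True)
jax.block_until_ready(ab)   # on B200/B300 I would expect the hang to show up here
print(ab.shape)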


My environment (all packages installed from PyPI):

python                       3.11
cuequivariance-jax           0.7.0rc2
cuequivariance-ops-jax-cu12  0.7.0
jax                          0.6.0
jax-cuda12-pjrt              0.6.0
jax-cuda12-plugin            0.6.0
jax-triton                   0.3.0
jaxlib                       0.6.0
jaxtyping                    0.2.34
nvidia-cublas-cu12           12.9.1.4
nvidia-cuda-cupti-cu12       12.9.79
nvidia-cuda-nvcc-cu12        12.9.86
nvidia-cuda-nvrtc-cu12       12.9.86
nvidia-cuda-runtime-cu12     12.9.79
nvidia-cudnn-cu12            9.15.0.57
nvidia-cufft-cu12            11.4.1.4
nvidia-cusolver-cu12         11.7.5.82
nvidia-cusparse-cu12         12.5.10.65
nvidia-ml-py                 13.580.82
nvidia-nccl-cu12             2.28.7
nvidia-nvjitlink-cu12        12.9.86
nvidia-nvshmem-cu12          3.4.5
haiku                        0.0.15
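In case it helps triage, here is a quick check of the device and compute capability the process actually sees, a small sketch using nvidia-ml-py (already in the list above):

# Report GPU name and compute capability via NVML.
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
name = pynvml.nvmlDeviceGetName(handle)
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
print(f"{name}: sm_{major}{minor}")  # a Blackwell B200 should report sm_100
pynvml.nvmlShutdown()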
