The triangle_multiplicative_update Triton kernel does not work with Haiku & JAX on B200/B300 NVIDIA GPUs (but H200, H100, H800, 5090, 4090, and PRO 6000 all work!)
my test code:

```python
import jax
import jax.random as jrandom
import haiku as hk
from cuequivariance_jax import triangle_multiplicative_update


class TritonNetwork(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)

    def triangle_multiplicative_update_block(self, pair_act):
        assert len(pair_act.shape) == 4
        key = jax.random.key(0)
        pair_act = triangle_multiplicative_update(
            x=pair_act,
            direction='outgoing',  # 'outgoing' or 'incoming'
            key=key,
        )
        return pair_act

    def __call__(self, batch):
        act = self.triangle_multiplicative_update_block(batch['feat'])
        return act


def forward_triton(batch):
    network = TritonNetwork()
    return network(batch)


forward = hk.transform(forward_triton)

L = 256
feat_shape = (1, L, L, 128)
key = jrandom.PRNGKey(0)
feat_batch = {'feat': jrandom.normal(key, feat_shape)}

print('step1:')
params = forward.init(key, feat_batch)
print('step2:')
for i in range(10):
    key, apply_key = jax.random.split(key, 2)
    output = forward.apply(params, apply_key, feat_batch)
    # print(output)
    print(output.shape)
```

The code dies or gets stuck at random.
However, falling back to the JAX-based triangle_multiplicative_update works, by modifying the fallback option in cuequivariance_jax/triangle/triangle_multiplicative_update.py (function triangle_multiplicative_update):
```python
# Gated dual gemm
ab = sigmoid_gated_dual_gemm(
    x,
    g_in_weight,
    p_in_weight,
    b1=g_in_bias,
    b2=p_in_bias,
    mask=mask,
    transpose_out=True,
    precision=precision,
    fallback=False,  # <- set this line to False to use the JAX-based code
)
a, b = jnp.split(ab, 2, axis=0)
```

So I think something in the sigmoid_gated_dual_gemm kernel does not work with B200/B300 NVIDIA GPUs, but I cannot give more details because I do not have the source code of the kernel.
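For anyone trying to narrow this down: a sigmoid-gated dual GEMM presumably computes two matmuls over the same input, with one result gated through a sigmoid, roughly `sigmoid(x @ W_g + b_g) * (x @ W_p + b_p)`. Below is a minimal pure-JAX sketch of that formulation (the function name, weight layout, and mask handling are my assumptions, not the library's actual code), which could serve as a reference to diff against the Triton kernel's output on B200/B300:

```python
import jax
import jax.numpy as jnp


def gated_dual_gemm_ref(x, g_weight, p_weight, g_bias=None, p_bias=None, mask=None):
    # Two GEMMs sharing the same input; one branch is squashed through a
    # sigmoid and gates the other. Mask placement is an assumption.
    g = x @ g_weight
    p = x @ p_weight
    if g_bias is not None:
        g = g + g_bias
    if p_bias is not None:
        p = p + p_bias
    if mask is not None:
        g = g * mask
        p = p * mask
    return jax.nn.sigmoid(g) * p


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (8, 16))
wg = jax.random.normal(k2, (16, 32))
wp = jax.random.normal(k3, (16, 32))
out = gated_dual_gemm_ref(x, wg, wp)
print(out.shape)  # (8, 32)
```

Comparing such a reference against the kernel output on a working GPU (e.g. H100) and on B200/B300 would show whether the kernel hangs before producing output or produces wrong values.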
my environment:

```
python                         3.11
cuequivariance-jax             0.7.0rc2
cuequivariance-ops-jax-cu12    0.7.0
jax                            0.6.0
jax-cuda12-pjrt                0.6.0
jax-cuda12-plugin              0.6.0
jax-triton                     0.3.0
jaxlib                         0.6.0
jaxtyping                      0.2.34
nvidia-cublas-cu12             12.9.1.4
nvidia-cuda-cupti-cu12         12.9.79
nvidia-cuda-nvcc-cu12          12.9.86
nvidia-cuda-nvrtc-cu12         12.9.86
nvidia-cuda-runtime-cu12       12.9.79
nvidia-cudnn-cu12              9.15.0.57
nvidia-cufft-cu12              11.4.1.4
nvidia-cusolver-cu12           11.7.5.82
nvidia-cusparse-cu12           12.5.10.65
nvidia-ml-py                   13.580.82
nvidia-nccl-cu12               2.28.7
nvidia-nvjitlink-cu12          12.9.86
nvidia-nvshmem-cu12            3.4.5
haiku                          0.0.15
```
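Since the failure seems architecture-specific, it may also help to attach exactly what JAX sees on the failing machine; Blackwell parts report a different device kind and compute capability than Hopper. A small diagnostic sketch using only public JAX APIs:

```python
import jax

# Dump jax/jaxlib versions and the CUDA components JAX detected.
jax.print_environment_info()

# List visible accelerators; device_kind names the GPU model
# (e.g. "NVIDIA B200" vs. "NVIDIA H100 80GB HBM3").
for d in jax.devices():
    print(d.platform, d.device_kind)
```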