Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions .github/workflows/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,26 @@ jobs:
python -m pip install --upgrade pip
python -m pip install --upgrade uv
python -m uv pip install -U pytest "jax[cuda12]"
python -m uv pip install nvidia-cusolver-cu12==11.7.3.90
python -m uv pip install nvidia-cublas-cu12
python -m uv pip install jax-triton triton==3.3.1
# python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch
python -m uv pip install cuequivariance-ops-cu12 cuequivariance-ops-jax-cu12

# Add NVIDIA CUDA libraries to LD_LIBRARY_PATH
SITE_PACKAGES=$(python -c "import site; print(' '.join(site.getsitepackages()))")
CUDA_LIB_DIRS=$(find $SITE_PACKAGES -path "*/nvidia/*/lib" -type d 2>/dev/null | tr '\n' ':')
export LD_LIBRARY_PATH="$CUDA_LIB_DIRS$LD_LIBRARY_PATH"

python -c "import cuequivariance_ops; print('cueop', cuequivariance_ops.__version__)"
python -c "import cuequivariance_ops_jax; print('cueopx', cuequivariance_ops_jax.__version__)"

python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance_jax

# python -c "import cuequivariance; print('cue', cuequivariance.__version__)"
# python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)"
python -c "import cuequivariance; print('cue', cuequivariance.__version__)"
python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)"


- name: Test with pytest
run: |
# XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax
echo "skipping tests"
# Set up CUDA library path for tests
SITE_PACKAGES=$(python -c "import site; print(' '.join(site.getsitepackages()))")
CUDA_LIB_DIRS=$(find $SITE_PACKAGES -path "*/nvidia/*/lib" -type d 2>/dev/null | tr '\n' ':')
export LD_LIBRARY_PATH="$CUDA_LIB_DIRS$LD_LIBRARY_PATH"
XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax
Loading