Skip to content

Commit e3e3ed7

Browse files
Support arm64 for CUDA feature (#1155)
* Support arm64 for CUDA image * Update install.sh * Update install.sh * Update note --------- Co-authored-by: Álvaro Rausell Guiard <[email protected]>
1 parent 805ce77 commit e3e3ed7

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

src/nvidia-cuda/install.sh

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,24 @@ export DEBIAN_FRONTEND=noninteractive
4444

4545
check_packages wget ca-certificates
4646

47+
# Determine system architecture and set NVIDIA repository URL accordingly
48+
ARCH=$(uname -m)
49+
case $ARCH in
50+
x86_64)
51+
NVIDIA_ARCH="x86_64"
52+
;;
53+
aarch64 | arm64)
54+
NVIDIA_ARCH="arm64"
55+
;;
56+
*)
57+
echo "Unsupported architecture: $ARCH"
58+
exit 1
59+
;;
60+
esac
61+
4762
# Add NVIDIA's package repository to apt so that we can download packages
48-
# Always use the ubuntu2004 repo because the other repos (e.g., debian11) are missing packages
4963
# Updating the repo to ubuntu2204 as ubuntu 20.04 is going out of support.
50-
NVIDIA_REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64"
64+
NVIDIA_REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$NVIDIA_ARCH"
5165
KEYRING_PACKAGE="cuda-keyring_1.0-1_all.deb"
5266
KEYRING_PACKAGE_URL="$NVIDIA_REPO_URL/$KEYRING_PACKAGE"
5367
KEYRING_PACKAGE_PATH="$(mktemp -d)"
@@ -62,6 +76,10 @@ nvtx_pkg="cuda-nvtx-${CUDA_VERSION/./-}"
6276
toolkit_pkg="cuda-toolkit-${CUDA_VERSION/./-}"
6377
if ! apt-cache show "$cuda_pkg"; then
6478
echo "The requested version of CUDA is not available: CUDA $CUDA_VERSION"
79+
if [ "$NVIDIA_ARCH" = "arm64" ]; then
80+
echo "Note: arm64 supports limited CUDA versions. Please check available versions:"
81+
echo "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64"
82+
fi
6583
exit 1
6684
fi
6785

@@ -93,6 +111,9 @@ if [ "$INSTALL_CUDNN" = "true" ]; then
93111

94112
if ! apt-cache show "$cudnn_pkg_version"; then
95113
echo "The requested version of cuDNN is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION"
114+
if [ "$NVIDIA_ARCH" = "arm64" ]; then
115+
echo "Note: arm64 has limited cuDNN package availability"
116+
fi
96117
exit 1
97118
fi
98119

@@ -112,6 +133,9 @@ if [ "$INSTALL_CUDNNDEV" = "true" ]; then
112133
fi
113134
if ! apt-cache show "$cudnn_dev_pkg_version"; then
114135
echo "The requested version of cuDNN development package is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION"
136+
if [ "$NVIDIA_ARCH" = "arm64" ]; then
137+
echo "Note: arm64 has limited cuDNN development package availability"
138+
fi
115139
exit 1
116140
fi
117141

0 commit comments

Comments
 (0)