@@ -216,8 +216,6 @@ function set_cuda_version() {
216216 readonly CUDA_FULL_VERSION
217217}
218218
219- set_cuda_version
220-
221219function is_cuda12() ( set +x ; [[ " ${CUDA_VERSION%% .* } " == " 12" ]] ; )
222220function le_cuda12() ( set +x ; version_le " ${CUDA_VERSION%% .* } " " 12" ; )
223221function ge_cuda12() ( set +x ; version_ge " ${CUDA_VERSION%% .* } " " 12" ; )
@@ -273,39 +271,27 @@ function set_driver_version() {
273271 fi
274272}
275273
276- set_driver_version
277-
278- readonly MIN_ROCKY8_CUDNN8_VERSION=" 8.0.5.39"
279- readonly DEFAULT_CUDNN8_VERSION=" 8.3.1.22"
280- readonly DEFAULT_CUDNN9_VERSION=" 9.1.0.70"
281-
282- # Parameters for NVIDIA-provided cuDNN library
283- readonly DEFAULT_CUDNN_VERSION=${CUDNN_FOR_CUDA["${CUDA_VERSION}"]}
284- CUDNN_VERSION=$( get_metadata_attribute ' cudnn-version' " ${DEFAULT_CUDNN_VERSION} " )
285- function is_cudnn8() ( set +x ; [[ " ${CUDNN_VERSION%% .* } " == " 8" ]] ; )
286- function is_cudnn9() ( set +x ; [[ " ${CUDNN_VERSION%% .* } " == " 9" ]] ; )
287- # The minimum cuDNN version supported by rocky is ${MIN_ROCKY8_CUDNN8_VERSION}
288- if is_rocky && (version_lt " ${CUDNN_VERSION} " " ${MIN_ROCKY8_CUDNN8_VERSION} " ) ; then
289- CUDNN_VERSION=" ${MIN_ROCKY8_CUDNN8_VERSION} "
290- elif (ge_ubuntu20 || ge_debian12) && is_cudnn8 ; then
291- # cuDNN v8 is not distribution for ubuntu20+, debian12
292- CUDNN_VERSION=" ${DEFAULT_CUDNN9_VERSION} "
293- elif (le_ubuntu18 || le_debian11) && is_cudnn9 ; then
294- # cuDNN v9 is not distributed for ubuntu18, debian10, debian11 ; fall back to 8
295- CUDNN_VERSION=" 8.8.0.121"
296- fi
297- readonly CUDNN_VERSION
298-
299- readonly DEFAULT_NCCL_VERSION=${NCCL_FOR_CUDA["${CUDA_VERSION}"]}
300- readonly NCCL_VERSION=$( get_metadata_attribute ' nccl-version' ${DEFAULT_NCCL_VERSION} )
301-
302- # Parameters for NVIDIA-provided Debian GPU driver
303- readonly DEFAULT_USERSPACE_URL=" https://us.download.nvidia.com/XFree86/Linux-x86_64/${DRIVER_VERSION} /NVIDIA-Linux-x86_64-${DRIVER_VERSION} .run"
304-
305- readonly USERSPACE_URL=$( get_metadata_attribute ' gpu-driver-url' " ${DEFAULT_USERSPACE_URL} " )
274+ function set_cudnn_version() {
275+ readonly MIN_ROCKY8_CUDNN8_VERSION=" 8.0.5.39"
276+ readonly DEFAULT_CUDNN8_VERSION=" 8.3.1.22"
277+ readonly DEFAULT_CUDNN9_VERSION=" 9.1.0.70"
278+
279+ # Parameters for NVIDIA-provided cuDNN library
280+ readonly DEFAULT_CUDNN_VERSION=${CUDNN_FOR_CUDA["${CUDA_VERSION}"]}
281+ CUDNN_VERSION=$( get_metadata_attribute ' cudnn-version' " ${DEFAULT_CUDNN_VERSION} " )
282+ # The minimum cuDNN version supported by rocky is ${MIN_ROCKY8_CUDNN8_VERSION}
283+ if ( is_rocky && version_lt " ${CUDNN_VERSION} " " ${MIN_ROCKY8_CUDNN8_VERSION} " ) ; then
284+ CUDNN_VERSION=" ${MIN_ROCKY8_CUDNN8_VERSION} "
285+ elif (ge_ubuntu20 || ge_debian12) && is_cudnn8 ; then
286+ # cuDNN v8 is not distribution for ubuntu20+, debian12
287+ CUDNN_VERSION=" ${DEFAULT_CUDNN9_VERSION} "
288+ elif (le_ubuntu18 || le_debian11) && is_cudnn9 ; then
289+ # cuDNN v9 is not distributed for ubuntu18, debian10, debian11 ; fall back to 8
290+ CUDNN_VERSION=" 8.8.0.121"
291+ fi
292+ readonly CUDNN_VERSION
293+ }
306294
307- USERSPACE_FILENAME=" $( echo ${USERSPACE_URL} | perl -pe ' s{^.+/}{}' ) "
308- readonly USERSPACE_FILENAME
309295
310296# Short name for urls
311297if is_ubuntu22 ; then
@@ -330,15 +316,14 @@ else
330316 nccl_shortname=" ${shortname} "
331317fi
332318
333- # Parameters for NVIDIA-provided package repositories
334- readonly NVIDIA_BASE_DL_URL=' https://developer.download.nvidia.com/compute'
335- readonly NVIDIA_REPO_URL=" ${NVIDIA_BASE_DL_URL} /cuda/repos/${shortname} /x86_64"
319+ function set_nv_urls() {
320+ # Parameters for NVIDIA-provided package repositories
321+ readonly NVIDIA_BASE_DL_URL=' https://developer.download.nvidia.com/compute'
322+ readonly NVIDIA_REPO_URL=" ${NVIDIA_BASE_DL_URL} /cuda/repos/${shortname} /x86_64"
336323
337- # Parameters for NVIDIA-provided NCCL library
338- readonly DEFAULT_NCCL_REPO_URL=" ${NVIDIA_BASE_DL_URL} /machine-learning/repos/${nccl_shortname} /x86_64/nvidia-machine-learning-repo-${nccl_shortname} _1.0.0-1_amd64.deb"
339- NCCL_REPO_URL=$( get_metadata_attribute ' nccl-repo-url' " ${DEFAULT_NCCL_REPO_URL} " )
340- readonly NCCL_REPO_URL
341- readonly NCCL_REPO_KEY=" ${NVIDIA_BASE_DL_URL} /machine-learning/repos/${nccl_shortname} /x86_64/7fa2af80.pub" # 3bf863cc.pub
324+ # Parameter for NVIDIA-provided Rocky Linux GPU driver
325+ readonly NVIDIA_ROCKY_REPO_URL=" ${NVIDIA_REPO_URL} /cuda-${shortname} .repo"
326+ }
342327
343328function set_cuda_runfile_url() {
344329 local MAX_DRIVER_VERSION
@@ -436,11 +421,7 @@ function set_cuda_runfile_url() {
436421 fi
437422}
438423
439- set_cuda_runfile_url
440-
441- # Parameter for NVIDIA-provided Rocky Linux GPU driver
442- readonly NVIDIA_ROCKY_REPO_URL=" ${NVIDIA_REPO_URL} /cuda-${shortname} .repo"
443-
424+ function set_cudnn_tarball_url() {
444425CUDNN_TARBALL=" cudnn-${CUDA_VERSION} -linux-x64-v${CUDNN_VERSION} .tgz"
445426CUDNN_TARBALL_URL=" ${NVIDIA_BASE_DL_URL} /redist/cudnn/v${CUDNN_VERSION% .* } /${CUDNN_TARBALL} "
446427if ( version_ge " ${CUDNN_VERSION} " " 8.3.1.22" ); then
@@ -460,6 +441,7 @@ if ( version_ge "${CUDA_VERSION}" "12.0" ); then
460441fi
461442readonly CUDNN_TARBALL
462443readonly CUDNN_TARBALL_URL
444+ }
463445
464446# Whether to install NVIDIA-provided or OS-provided GPU driver
465447GPU_DRIVER_PROVIDER=$( get_metadata_attribute ' gpu-driver-provider' ' NVIDIA' )
@@ -610,6 +592,9 @@ function uninstall_local_cudnn8_repo() {
610592}
611593
612594function install_nvidia_nccl() {
595+ readonly DEFAULT_NCCL_VERSION=${NCCL_FOR_CUDA["${CUDA_VERSION}"]}
596+ readonly NCCL_VERSION=$( get_metadata_attribute ' nccl-version' ${DEFAULT_NCCL_VERSION} )
597+
613598 is_complete nccl && return
614599
615600 if is_cuda11 && is_debian12 ; then
@@ -1044,6 +1029,13 @@ function build_driver_from_packages() {
10441029}
10451030
10461031function install_nvidia_userspace_runfile() {
1032+ # Parameters for NVIDIA-provided Debian GPU driver
1033+ readonly DEFAULT_USERSPACE_URL=" https://us.download.nvidia.com/XFree86/Linux-x86_64/${DRIVER_VERSION} /NVIDIA-Linux-x86_64-${DRIVER_VERSION} .run"
1034+
1035+ readonly USERSPACE_URL=$( get_metadata_attribute ' gpu-driver-url' " ${DEFAULT_USERSPACE_URL} " )
1036+
1037+ USERSPACE_FILENAME=" $( echo ${USERSPACE_URL} | perl -pe ' s{^.+/}{}' ) "
1038+ readonly USERSPACE_FILENAME
10471039
10481040 # This .run file contains NV's OpenGL implementation as well as
10491041 # nvidia optimized implementations of the gtk+ 2,3 stack(s) not
@@ -1565,6 +1557,10 @@ function install_dependencies() {
15651557}
15661558
15671559function prepare_gpu_env(){
1560+ # set_support_matrix
1561+
1562+ set_cuda_version
1563+ set_driver_version
15681564
15691565 set +e
15701566 gpu_count=" $( grep -i PCI_ID=10DE /sys/bus/pci/devices/* /uevent | wc -l) "
@@ -1588,6 +1584,11 @@ function prepare_gpu_env(){
15881584
15891585 # determine whether we have nvidia-smi installed and working
15901586 nvsmi
1587+
1588+ set_nv_urls
1589+ set_cuda_runfile_url
1590+ set_cudnn_version
1591+ set_cudnn_tarball_url
15911592}
15921593
15931594# Hold all NVIDIA-related packages from upgrading unintenionally or services like unattended-upgrades
0 commit comments