Skip to content

Commit 50142f6

Browse files
committed
slightly better variable declaration ordering ; it is better still in the templates/ directory from GoogleCloudDataproc#1282
1 parent 1d2166c commit 50142f6

File tree

1 file changed

+48
-47
lines changed

1 file changed

+48
-47
lines changed

gpu/install_gpu_driver.sh

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ function set_cuda_version() {
216216
readonly CUDA_FULL_VERSION
217217
}
218218

219-
set_cuda_version
220-
221219
function is_cuda12() ( set +x ; [[ "${CUDA_VERSION%%.*}" == "12" ]] ; )
222220
function le_cuda12() ( set +x ; version_le "${CUDA_VERSION%%.*}" "12" ; )
223221
function ge_cuda12() ( set +x ; version_ge "${CUDA_VERSION%%.*}" "12" ; )
@@ -273,39 +271,27 @@ function set_driver_version() {
273271
fi
274272
}
275273

276-
set_driver_version
277-
278-
readonly MIN_ROCKY8_CUDNN8_VERSION="8.0.5.39"
279-
readonly DEFAULT_CUDNN8_VERSION="8.3.1.22"
280-
readonly DEFAULT_CUDNN9_VERSION="9.1.0.70"
281-
282-
# Parameters for NVIDIA-provided cuDNN library
283-
readonly DEFAULT_CUDNN_VERSION=${CUDNN_FOR_CUDA["${CUDA_VERSION}"]}
284-
CUDNN_VERSION=$(get_metadata_attribute 'cudnn-version' "${DEFAULT_CUDNN_VERSION}")
285-
function is_cudnn8() ( set +x ; [[ "${CUDNN_VERSION%%.*}" == "8" ]] ; )
286-
function is_cudnn9() ( set +x ; [[ "${CUDNN_VERSION%%.*}" == "9" ]] ; )
287-
# The minimum cuDNN version supported by rocky is ${MIN_ROCKY8_CUDNN8_VERSION}
288-
if is_rocky && (version_lt "${CUDNN_VERSION}" "${MIN_ROCKY8_CUDNN8_VERSION}") ; then
289-
CUDNN_VERSION="${MIN_ROCKY8_CUDNN8_VERSION}"
290-
elif (ge_ubuntu20 || ge_debian12) && is_cudnn8 ; then
291-
# cuDNN v8 is not distributed for ubuntu20+, debian12
292-
CUDNN_VERSION="${DEFAULT_CUDNN9_VERSION}"
293-
elif (le_ubuntu18 || le_debian11) && is_cudnn9 ; then
294-
# cuDNN v9 is not distributed for ubuntu18, debian10, debian11 ; fall back to 8
295-
CUDNN_VERSION="8.8.0.121"
296-
fi
297-
readonly CUDNN_VERSION
298-
299-
readonly DEFAULT_NCCL_VERSION=${NCCL_FOR_CUDA["${CUDA_VERSION}"]}
300-
readonly NCCL_VERSION=$(get_metadata_attribute 'nccl-version' ${DEFAULT_NCCL_VERSION})
301-
302-
# Parameters for NVIDIA-provided Debian GPU driver
303-
readonly DEFAULT_USERSPACE_URL="https://us.download.nvidia.com/XFree86/Linux-x86_64/${DRIVER_VERSION}/NVIDIA-Linux-x86_64-${DRIVER_VERSION}.run"
304-
305-
readonly USERSPACE_URL=$(get_metadata_attribute 'gpu-driver-url' "${DEFAULT_USERSPACE_URL}")
274+
function set_cudnn_version() {
275+
readonly MIN_ROCKY8_CUDNN8_VERSION="8.0.5.39"
276+
readonly DEFAULT_CUDNN8_VERSION="8.3.1.22"
277+
readonly DEFAULT_CUDNN9_VERSION="9.1.0.70"
278+
279+
# Parameters for NVIDIA-provided cuDNN library
280+
readonly DEFAULT_CUDNN_VERSION=${CUDNN_FOR_CUDA["${CUDA_VERSION}"]}
281+
CUDNN_VERSION=$(get_metadata_attribute 'cudnn-version' "${DEFAULT_CUDNN_VERSION}")
282+
# The minimum cuDNN version supported by rocky is ${MIN_ROCKY8_CUDNN8_VERSION}
283+
if ( is_rocky && version_lt "${CUDNN_VERSION}" "${MIN_ROCKY8_CUDNN8_VERSION}" ) ; then
284+
CUDNN_VERSION="${MIN_ROCKY8_CUDNN8_VERSION}"
285+
elif (ge_ubuntu20 || ge_debian12) && is_cudnn8 ; then
286+
# cuDNN v8 is not distributed for ubuntu20+, debian12
287+
CUDNN_VERSION="${DEFAULT_CUDNN9_VERSION}"
288+
elif (le_ubuntu18 || le_debian11) && is_cudnn9 ; then
289+
# cuDNN v9 is not distributed for ubuntu18, debian10, debian11 ; fall back to 8
290+
CUDNN_VERSION="8.8.0.121"
291+
fi
292+
readonly CUDNN_VERSION
293+
}
306294

307-
USERSPACE_FILENAME="$(echo ${USERSPACE_URL} | perl -pe 's{^.+/}{}')"
308-
readonly USERSPACE_FILENAME
309295

310296
# Short name for urls
311297
if is_ubuntu22 ; then
@@ -330,15 +316,14 @@ else
330316
nccl_shortname="${shortname}"
331317
fi
332318

333-
# Parameters for NVIDIA-provided package repositories
334-
readonly NVIDIA_BASE_DL_URL='https://developer.download.nvidia.com/compute'
335-
readonly NVIDIA_REPO_URL="${NVIDIA_BASE_DL_URL}/cuda/repos/${shortname}/x86_64"
319+
function set_nv_urls() {
320+
# Parameters for NVIDIA-provided package repositories
321+
readonly NVIDIA_BASE_DL_URL='https://developer.download.nvidia.com/compute'
322+
readonly NVIDIA_REPO_URL="${NVIDIA_BASE_DL_URL}/cuda/repos/${shortname}/x86_64"
336323

337-
# Parameters for NVIDIA-provided NCCL library
338-
readonly DEFAULT_NCCL_REPO_URL="${NVIDIA_BASE_DL_URL}/machine-learning/repos/${nccl_shortname}/x86_64/nvidia-machine-learning-repo-${nccl_shortname}_1.0.0-1_amd64.deb"
339-
NCCL_REPO_URL=$(get_metadata_attribute 'nccl-repo-url' "${DEFAULT_NCCL_REPO_URL}")
340-
readonly NCCL_REPO_URL
341-
readonly NCCL_REPO_KEY="${NVIDIA_BASE_DL_URL}/machine-learning/repos/${nccl_shortname}/x86_64/7fa2af80.pub" # 3bf863cc.pub
324+
# Parameter for NVIDIA-provided Rocky Linux GPU driver
325+
readonly NVIDIA_ROCKY_REPO_URL="${NVIDIA_REPO_URL}/cuda-${shortname}.repo"
326+
}
342327

343328
function set_cuda_runfile_url() {
344329
local MAX_DRIVER_VERSION
@@ -436,11 +421,7 @@ function set_cuda_runfile_url() {
436421
fi
437422
}
438423

439-
set_cuda_runfile_url
440-
441-
# Parameter for NVIDIA-provided Rocky Linux GPU driver
442-
readonly NVIDIA_ROCKY_REPO_URL="${NVIDIA_REPO_URL}/cuda-${shortname}.repo"
443-
424+
function set_cudnn_tarball_url() {
444425
CUDNN_TARBALL="cudnn-${CUDA_VERSION}-linux-x64-v${CUDNN_VERSION}.tgz"
445426
CUDNN_TARBALL_URL="${NVIDIA_BASE_DL_URL}/redist/cudnn/v${CUDNN_VERSION%.*}/${CUDNN_TARBALL}"
446427
if ( version_ge "${CUDNN_VERSION}" "8.3.1.22" ); then
@@ -460,6 +441,7 @@ if ( version_ge "${CUDA_VERSION}" "12.0" ); then
460441
fi
461442
readonly CUDNN_TARBALL
462443
readonly CUDNN_TARBALL_URL
444+
}
463445

464446
# Whether to install NVIDIA-provided or OS-provided GPU driver
465447
GPU_DRIVER_PROVIDER=$(get_metadata_attribute 'gpu-driver-provider' 'NVIDIA')
@@ -610,6 +592,9 @@ function uninstall_local_cudnn8_repo() {
610592
}
611593

612594
function install_nvidia_nccl() {
595+
readonly DEFAULT_NCCL_VERSION=${NCCL_FOR_CUDA["${CUDA_VERSION}"]}
596+
readonly NCCL_VERSION=$(get_metadata_attribute 'nccl-version' ${DEFAULT_NCCL_VERSION})
597+
613598
is_complete nccl && return
614599

615600
if is_cuda11 && is_debian12 ; then
@@ -1044,6 +1029,13 @@ function build_driver_from_packages() {
10441029
}
10451030

10461031
function install_nvidia_userspace_runfile() {
1032+
# Parameters for NVIDIA-provided Debian GPU driver
1033+
readonly DEFAULT_USERSPACE_URL="https://us.download.nvidia.com/XFree86/Linux-x86_64/${DRIVER_VERSION}/NVIDIA-Linux-x86_64-${DRIVER_VERSION}.run"
1034+
1035+
readonly USERSPACE_URL=$(get_metadata_attribute 'gpu-driver-url' "${DEFAULT_USERSPACE_URL}")
1036+
1037+
USERSPACE_FILENAME="$(echo ${USERSPACE_URL} | perl -pe 's{^.+/}{}')"
1038+
readonly USERSPACE_FILENAME
10471039

10481040
# This .run file contains NV's OpenGL implementation as well as
10491041
# nvidia optimized implementations of the gtk+ 2,3 stack(s) not
@@ -1565,6 +1557,10 @@ function install_dependencies() {
15651557
}
15661558

15671559
function prepare_gpu_env(){
1560+
#set_support_matrix
1561+
1562+
set_cuda_version
1563+
set_driver_version
15681564

15691565
set +e
15701566
gpu_count="$(grep -i PCI_ID=10DE /sys/bus/pci/devices/*/uevent | wc -l)"
@@ -1588,6 +1584,11 @@ function prepare_gpu_env(){
15881584

15891585
# determine whether we have nvidia-smi installed and working
15901586
nvsmi
1587+
1588+
set_nv_urls
1589+
set_cuda_runfile_url
1590+
set_cudnn_version
1591+
set_cudnn_tarball_url
15911592
}
15921593

15931594
# Hold all NVIDIA-related packages from upgrading unintentionally or via services like unattended-upgrades

0 commit comments

Comments
 (0)