@@ -20,6 +20,10 @@ class SparkRapidsTestCase(DataprocTestCase):
   def verify_spark_instance(self, name):
     self.assert_instance_command(name, "nvidia-smi")
 
+  def verify_pyspark(self, name):
+    # Verify that pyspark works
+    self.assert_instance_command(name, "echo 'from pyspark.sql import SparkSession ; SparkSession.builder.getOrCreate()' | pyspark -c spark.executor.resource.gpu.amount=1 -c spark.task.resource.gpu.amount=0.01", 1)
+
   def verify_mig_instance(self, name):
     self.assert_instance_command(name,
         "/usr/bin/nvidia-smi --query-gpu=mig.mode.current --format=csv,noheader | uniq | xargs -I % test % = 'Enabled'")
@@ -114,13 +118,22 @@ def test_spark_rapids_sql(self, configuration, machine_suffixes, accelerator):
     # Only need to do this once
     self.verify_spark_job_sql()
 
-  @parameterized.parameters(("STANDARD", ["w-0"], GPU_T4, "12.4.0", "550.54.14"))
+  @parameterized.parameters(
+      ("STANDARD", ["w-0"], GPU_T4, "11.8.0", "525.147.05"),
+      ("STANDARD", ["w-0"], GPU_T4, "12.4.0", "550.54.14"),
+      ("STANDARD", ["w-0"], GPU_T4, "12.6.2", "560.35.03")
+  )
   def test_non_default_cuda_versions(self, configuration, machine_suffixes,
                                      accelerator, cuda_version, driver_version):
 
     if self.getImageOs() == "rocky":
       self.skipTest("Not supported for Rocky OS")
 
+    if pkg_resources.parse_version(cuda_version) > pkg_resources.parse_version("12.0") \
+        and ((self.getImageOs() == 'ubuntu' and self.getImageVersion() <= pkg_resources.parse_version("2.0")) or
+             (self.getImageOs() == 'debian' and self.getImageVersion() <= pkg_resources.parse_version("2.1"))):
+      self.skipTest("CUDA > 12.0 not supported on older debian/ubuntu releases")
+
     if self.getImageVersion() <= pkg_resources.parse_version("2.0"):
       self.skipTest("Not supported in 2.0 and earlier images")
 
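The added skip condition relies on pkg_resources.parse_version to compare CUDA and image versions numerically rather than lexically. A small illustrative sketch, assuming setuptools' pkg_resources is available (the module the test already uses):

import pkg_resources

parse = pkg_resources.parse_version

# CUDA versions compare component-wise, so "12.6.2" > "12.0" and
# "12.10" > "12.9", which a plain string comparison would get wrong.
assert parse("12.6.2") > parse("12.0")
assert parse("12.10") > parse("12.9")

# Image version gates: Ubuntu images <= 2.0 and Debian images <= 2.1
# are skipped when the requested CUDA version is newer than 12.0.
assert parse("2.0") <= parse("2.0")
assert not (parse("2.1") <= parse("2.0"))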