Skip to content

Commit da6a3cf

Browse files
committed
expanding non-default version tests ; adding utility function to verify pyspark ; disk size correction
1 parent 8f2bb2c commit da6a3cf

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

spark-rapids/test_spark_rapids.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ class SparkRapidsTestCase(DataprocTestCase):
2020
def verify_spark_instance(self, name):
2121
self.assert_instance_command(name, "nvidia-smi")
2222

23+
def verify_pyspark(self, name):
24+
# Verify that pyspark works
25+
self.assert_instance_command(name, "echo 'from pyspark.sql import SparkSession ; SparkSession.builder.getOrCreate()' | pyspark -c spark.executor.resource.gpu.amount=1 -c spark.task.resource.gpu.amount=0.01", 1)
26+
2327
def verify_mig_instance(self, name):
2428
self.assert_instance_command(name,
2529
"/usr/bin/nvidia-smi --query-gpu=mig.mode.current --format=csv,noheader | uniq | xargs -I % test % = 'Enabled'")
@@ -114,13 +118,22 @@ def test_spark_rapids_sql(self, configuration, machine_suffixes, accelerator):
114118
# Only need to do this once
115119
self.verify_spark_job_sql()
116120

117-
@parameterized.parameters(("STANDARD", ["w-0"], GPU_T4, "12.4.0", "550.54.14"))
121+
@parameterized.parameters(
122+
("STANDARD", ["w-0"], GPU_T4, "11.8.0", "525.147.05"),
123+
("STANDARD", ["w-0"], GPU_T4, "12.4.0", "550.54.14"),
124+
("STANDARD", ["w-0"], GPU_T4, "12.6.2", "560.35.03")
125+
)
118126
def test_non_default_cuda_versions(self, configuration, machine_suffixes,
119127
accelerator, cuda_version, driver_version):
120128

121129
if self.getImageOs() == "rocky":
122130
self.skipTest("Not supported for Rocky OS")
123131

132+
if pkg_resources.parse_version(cuda_version) > pkg_resources.parse_version("12.0") \
133+
and ( ( self.getImageOs() == 'ubuntu' and self.getImageVersion() <= pkg_resources.parse_version("2.0") ) or \
134+
( self.getImageOs() == 'debian' and self.getImageVersion() <= pkg_resources.parse_version("2.1") ) ):
135+
self.skipTest("CUDA > 12.0 not supported on older debian/ubuntu releases")
136+
124137
if self.getImageVersion() <= pkg_resources.parse_version("2.0"):
125138
self.skipTest("Not supported in 2.0 and earlier images")
126139

0 commit comments

Comments
 (0)