diff --git a/support/defaults.go b/support/defaults.go index c165b22..137556d 100644 --- a/support/defaults.go +++ b/support/defaults.go @@ -5,7 +5,9 @@ package support // *********************** const ( - RayVersion = "2.35.0" - RayImage = "quay.io/modh/ray:2.35.0-py39-cu121" - RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61" + RayVersion = "2.35.0" + RayImage = "quay.io/modh/ray:2.35.0-py39-cu121" + RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61" + RayTorchCudaImage = "quay.io/rhoai/ray:2.35.0-py39-cu121-torch24-fa26" + RayTorchROCmImage = "quay.io/rhoai/ray:2.35.0-py39-rocm61-torch24-fa26" ) diff --git a/support/environment.go b/support/environment.go index 3de25e6..98e8886 100644 --- a/support/environment.go +++ b/support/environment.go @@ -25,10 +25,12 @@ const ( // The environment variables hereafter can be used to change the components // used for testing. - CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION" - CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE" - CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE" - CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE" + CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION" + CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE" + CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE" + CodeFlareTestRayTorchCudaImage = "CODEFLARE_TEST_RAY_TORCH_CUDA_IMAGE" + CodeFlareTestRayTorchROCmImage = "CODEFLARE_TEST_RAY_TORCH_ROCM_IMAGE" + CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE" // The testing output directory, to write output files into. CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR" @@ -83,6 +85,14 @@ func GetRayROCmImage() string { return lookupEnvOrDefault(CodeFlareTestRayROCmImage, RayROCmImage) } +func GetRayTorchCudaImage() string { + return lookupEnvOrDefault(CodeFlareTestRayTorchCudaImage, RayTorchCudaImage) +} + +func GetRayTorchROCmImage() string { + return lookupEnvOrDefault(CodeFlareTestRayTorchROCmImage, RayTorchROCmImage) +} + func GetPyTorchImage() string { return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime") }