diff --git a/plugins/nf-k8s/src/main/nextflow/k8s/model/PodSpecBuilder.groovy b/plugins/nf-k8s/src/main/nextflow/k8s/model/PodSpecBuilder.groovy index 2c70832039..41b57545e5 100644 --- a/plugins/nf-k8s/src/main/nextflow/k8s/model/PodSpecBuilder.groovy +++ b/plugins/nf-k8s/src/main/nextflow/k8s/model/PodSpecBuilder.groovy @@ -680,12 +680,20 @@ class PodSpecBuilder { @PackageScope String getAcceleratorType(AcceleratorResource accelerator) { - def type = accelerator.type ?: 'nvidia.com' + // Default to standard NVIDIA GPU if left entirely blank. + def type = accelerator.type?.toLowerCase() ?: 'nvidia.com' if ( type.contains('/') ) // Assume the user has fully specified the resource type. return type + // Map common vendor shorthands to their standard K8s Extended Resource strings. + if (type =~ /\b(nvidia|tesla|ampere|h100|a100)\b/) return 'nvidia.com/gpu' + if (type =~ /\b(amd|radeon|instinct)\b/) return 'amd.com/gpu' + if (type =~ /\b(tpu|google)\b/) return 'google.com/tpu' + if (type =~ /\b(neuron|inferentia|trainium|aws)\b/) return 'aws.amazon.com/neuron' + if (type =~ /\b(intel|gaudi)\b/) return 'gpu.intel.com/i915' + // Assume we're using GPU and update as necessary. if( !type.contains('.') ) type += '.com' type += '/gpu' diff --git a/plugins/nf-k8s/src/test/nextflow/k8s/model/PodSpecBuilderTest.groovy b/plugins/nf-k8s/src/test/nextflow/k8s/model/PodSpecBuilderTest.groovy index d3ea3f9d57..3069fd9760 100644 --- a/plugins/nf-k8s/src/test/nextflow/k8s/model/PodSpecBuilderTest.groovy +++ b/plugins/nf-k8s/src/test/nextflow/k8s/model/PodSpecBuilderTest.groovy @@ -613,47 +613,24 @@ class PodSpecBuilderTest extends Specification { given: def builder = new PodSpecBuilder() - when: - def res = builder.addAcceleratorResources(new AcceleratorResource(request:2, limit: 5), null) - then: - res.requests == ['nvidia.com/gpu': 2] - res.limits == ['nvidia.com/gpu': 5] - - when: - res = builder.addAcceleratorResources(new AcceleratorResource(limit: 5, type:'foo'), null) - then: - res.requests == ['foo.com/gpu': 5] - res.limits == ['foo.com/gpu': 5] - - when: - res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type:'foo.org'), null) - then: - res.requests == ['foo.org/gpu': 5] - res.limits == null - - when: - res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type: 'foo.org'), [requests: [cpu: 2]]) - then: - res.requests == [cpu: 2, 'foo.org/gpu': 5] - res.limits == null - - when: - res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, limit: 10, type: 'foo.org'), [requests: [cpu: 2]]) - then: - res.requests == [cpu: 2, 'foo.org/gpu': 5] - res.limits == ['foo.org/gpu': 10] - - when: - res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type:'example.com/fpga'), null) - then: - res.requests == ['example.com/fpga': 5] - res.limits == null + expect: + def res = builder.addAcceleratorResources(new AcceleratorResource(request: req, limit: lim, type: inputType), existing) + res.requests == expectedReq + res.limits == expectedLim - when: - res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, limit: 10, type: 'example.com/fpga'), [requests: [cpu: 2]]) - then: - res.requests == [cpu: 2, 'example.com/fpga': 5] - res.limits == ['example.com/fpga': 10] + where: + inputType | req | lim | existing | expectedReq | expectedLim + null | 2 | 5 | null | ['nvidia.com/gpu': 2] | ['nvidia.com/gpu': 5] + 'foo' | null| 5 | null | ['foo.com/gpu': 5] | ['foo.com/gpu': 5] + 'foo.org' | 5 | null| null | ['foo.org/gpu': 5] | null + 'foo.org' | 5 | null| [requests: [cpu: 2]] | [cpu: 2, 'foo.org/gpu': 5] | null + 'foo.org' | 5 | 10 | [requests: [cpu: 2]] | [cpu: 2, 'foo.org/gpu': 5] | ['foo.org/gpu': 10] + 'example.com/fpga' | 5 | null| null | ['example.com/fpga': 5] | null + 'nvidia-tesla-k80' | 4 | 4 | null | ['nvidia.com/gpu': 4] | ['nvidia.com/gpu': 4] + 'tpu-v5-lite-podslice' | 8 | 8 | null | ['google.com/tpu': 8] | ['google.com/tpu': 8] + 'amd-instinct-mi300' | 2 | 2 | null | ['amd.com/gpu': 2] | ['amd.com/gpu': 2] + 'aws-trn1-32xlarge' | 8 | 8 | null | ['aws.amazon.com/neuron': 8] | ['aws.amazon.com/neuron': 8] + 'intel-gaudi-3' | 1 | 1 | null | ['gpu.intel.com/i915': 1] | ['gpu.intel.com/i915': 1] } def 'should add resources limits' () {