Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -680,12 +680,20 @@ class PodSpecBuilder {
@PackageScope
String getAcceleratorType(AcceleratorResource accelerator) {

def type = accelerator.type ?: 'nvidia.com'
// Default to standard NVIDIA GPU if left entirely blank.
def type = accelerator.type?.toLowerCase() ?: 'nvidia.com'

if ( type.contains('/') )
// Assume the user has fully specified the resource type.
return type

// Map common vendor shorthands to their standard K8s Extended Resource strings.
if (type =~ /\b(nvidia|tesla|ampere|h100|a100)\b/) return 'nvidia.com/gpu'
if (type =~ /\b(amd|radeon|instinct)\b/) return 'amd.com/gpu'
if (type =~ /\b(tpu|google)\b/) return 'google.com/tpu'
if (type =~ /\b(neuron|inferentia|trainium|aws)\b/) return 'aws.amazon.com/neuron'
if (type =~ /\b(intel|gaudi)\b/) return 'gpu.intel.com/i915'

// Assume we're using GPU and update as necessary.
if( !type.contains('.') ) type += '.com'
type += '/gpu'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,47 +613,24 @@ class PodSpecBuilderTest extends Specification {
given:
def builder = new PodSpecBuilder()

when:
def res = builder.addAcceleratorResources(new AcceleratorResource(request:2, limit: 5), null)
then:
res.requests == ['nvidia.com/gpu': 2]
res.limits == ['nvidia.com/gpu': 5]

when:
res = builder.addAcceleratorResources(new AcceleratorResource(limit: 5, type:'foo'), null)
then:
res.requests == ['foo.com/gpu': 5]
res.limits == ['foo.com/gpu': 5]

when:
res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type:'foo.org'), null)
then:
res.requests == ['foo.org/gpu': 5]
res.limits == null

when:
res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type: 'foo.org'), [requests: [cpu: 2]])
then:
res.requests == [cpu: 2, 'foo.org/gpu': 5]
res.limits == null

when:
res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, limit: 10, type: 'foo.org'), [requests: [cpu: 2]])
then:
res.requests == [cpu: 2, 'foo.org/gpu': 5]
res.limits == ['foo.org/gpu': 10]

when:
res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, type:'example.com/fpga'), null)
then:
res.requests == ['example.com/fpga': 5]
res.limits == null
expect:
def res = builder.addAcceleratorResources(new AcceleratorResource(request: req, limit: lim, type: inputType), existing)
res.requests == expectedReq
res.limits == expectedLim

when:
res = builder.addAcceleratorResources(new AcceleratorResource(request: 5, limit: 10, type: 'example.com/fpga'), [requests: [cpu: 2]])
then:
res.requests == [cpu: 2, 'example.com/fpga': 5]
res.limits == ['example.com/fpga': 10]
where:
inputType | req | lim | existing | expectedReq | expectedLim
null | 2 | 5 | null | ['nvidia.com/gpu': 2] | ['nvidia.com/gpu': 5]
'foo' | null| 5 | null | ['foo.com/gpu': 5] | ['foo.com/gpu': 5]
'foo.org' | 5 | null| null | ['foo.org/gpu': 5] | null
'foo.org' | 5 | null| [requests: [cpu: 2]] | [cpu: 2, 'foo.org/gpu': 5] | null
'foo.org' | 5 | 10 | [requests: [cpu: 2]] | [cpu: 2, 'foo.org/gpu': 5] | ['foo.org/gpu': 10]
'example.com/fpga' | 5 | null| null | ['example.com/fpga': 5] | null
'nvidia-tesla-k80' | 4 | 4 | null | ['nvidia.com/gpu': 4] | ['nvidia.com/gpu': 4]
'tpu-v5-lite-podslice' | 8 | 8 | null | ['google.com/tpu': 8] | ['google.com/tpu': 8]
'amd-instinct-mi300' | 2 | 2 | null | ['amd.com/gpu': 2] | ['amd.com/gpu': 2]
'aws-trn1-32xlarge' | 8 | 8 | null | ['aws.amazon.com/neuron': 8] | ['aws.amazon.com/neuron': 8]
'intel-gaudi-3' | 1 | 1 | null | ['gpu.intel.com/i915': 1] | ['gpu.intel.com/i915': 1]
}

def 'should add resources limits' () {
Expand Down
Loading