4343 ).Hidden ().Default ("" ).String ()
4444)
4545
46- // Custom errors.
47- var (
48- errUnsupportedGPUVendor = errors .New ("unsupported gpu vendor" )
49- )
50-
5146// Regexes.
5247var (
5348 pciBusIDRegex = regexp .MustCompile (`(?P<domain>[0-9a-fA-F]+):(?P<bus>[0-9a-fA-F]+):(?P<slot>[0-9a-fA-F]+)\.(?P<function>[0-9a-fA-F]+)` )
@@ -191,6 +186,52 @@ type DeviceAttrsShared struct {
191186 DecCount uint64 `xml:"decoder_count"`
192187}
193188
189+ // UnmarshalXML implements the xml.Unmarshaler interface.
190+ func (p * DeviceAttrsShared ) UnmarshalXML (d * xml.Decoder , start xml.StartElement ) error {
191+ // Unmarshal into temporary struct.
192+ // We are just doing this in case of edge cases where
193+ // count values are not available. In that case xml.Unmarshal
194+ // might return error due to not able to parse. We are trying
195+ // to address those cases so that we get some default values
196+ // instead of error.
197+ var tmp struct {
198+ XMLName xml.Name `xml:"shared"`
199+ SMCount string `xml:"multiprocessor_count"`
200+ CECount string `xml:"copy_engine_count"`
201+ EncCount string `xml:"encoder_count"`
202+ DecCount string `xml:"decoder_count"`
203+ }
204+
205+ if err := d .DecodeElement (& tmp , & start ); err != nil {
206+ return err
207+ }
208+
209+ var err error
210+
211+ p .XMLName = tmp .XMLName
212+
213+ // In case of errors set count to 1. This is especially important for
214+ // SMCount as we compute SMFrac with it and if we set it zero, fractions
215+ // will be NaN.
216+ if p .SMCount , err = strconv .ParseUint (tmp .SMCount , 10 , 64 ); err != nil {
217+ p .SMCount = 1
218+ }
219+
220+ if p .CECount , err = strconv .ParseUint (tmp .CECount , 10 , 64 ); err != nil {
221+ p .CECount = 1
222+ }
223+
224+ if p .EncCount , err = strconv .ParseUint (tmp .EncCount , 10 , 64 ); err != nil {
225+ p .EncCount = 1
226+ }
227+
228+ if p .DecCount , err = strconv .ParseUint (tmp .DecCount , 10 , 64 ); err != nil {
229+ p .DecCount = 1
230+ }
231+
232+ return nil
233+ }
234+
194235type DeviceAttrs struct {
195236 XMLName xml.Name `xml:"device_attributes"`
196237 Shared DeviceAttrsShared `xml:"shared"`
@@ -207,22 +248,107 @@ type MIGDevice struct {
207248 UUID string
208249}
209250
251+ // UnmarshalXML implements the xml.Unmarshaler interface.
252+ func (p * MIGDevice ) UnmarshalXML (d * xml.Decoder , start xml.StartElement ) error {
253+ // Unmarshal into temporary struct
254+ // We are just doing this in case of edge cases where
255+ // count values are not available. In that case xml.Unmarshal
256+ // might return error due to not able to parse. We are trying
257+ // to address those cases so that we get some default values
258+ // instead of error.
259+ var tmp struct {
260+ XMLName xml.Name `xml:"mig_device"`
261+ Index string `xml:"index"`
262+ GPUInstID string `xml:"gpu_instance_id"`
263+ ComputeInstID string `xml:"compute_instance_id"`
264+ DeviceAttrs DeviceAttrs `xml:"device_attributes"`
265+ FBMemory Memory `xml:"fb_memory_usage"`
266+ Bar1Memory Memory `xml:"bar1_memory_usage"`
267+ }
268+
269+ if err := d .DecodeElement (& tmp , & start ); err != nil {
270+ return err
271+ }
272+
273+ var err error
274+
275+ // In case of errors return more intuitive error message.
276+ if p .Index , err = strconv .ParseUint (tmp .Index , 10 , 64 ); err != nil {
277+ return fmt .Errorf ("invalid mig index %s: %w" , tmp .Index , err )
278+ }
279+
280+ if p .GPUInstID , err = strconv .ParseUint (tmp .GPUInstID , 10 , 64 ); err != nil {
281+ return fmt .Errorf ("invalid mig gpu instance id %s: %w" , tmp .GPUInstID , err )
282+ }
283+
284+ if p .ComputeInstID , err = strconv .ParseUint (tmp .ComputeInstID , 10 , 64 ); err != nil {
285+ return fmt .Errorf ("invalid mig compute instance id %s: %w" , tmp .ComputeInstID , err )
286+ }
287+
288+ p .DeviceAttrs = tmp .DeviceAttrs
289+ p .FBMemory = tmp .FBMemory
290+ p .Bar1Memory = tmp .Bar1Memory
291+ p .XMLName = tmp .XMLName
292+
293+ return nil
294+ }
295+
210296type MIGDevices struct {
211297 XMLName xml.Name `xml:"mig_devices"`
212298 Devices []MIGDevice `xml:"mig_device"`
213299}
214300
215- type ProcessInfo struct {
216- XMLName xml.Name `xml:"process_info"`
217- GPUInstID uint64 `xml:"gpu_instance_id"`
218- ComputeInstID uint64 `xml:"compute_instance_id"`
219- PID uint64 `xml:"pid"`
220- }
221-
222- type Processes struct {
223- XMLName xml.Name `xml:"processes"`
224- ProcessInfos []ProcessInfo `xml:"process_info"`
225- }
301+ // type ProcessInfo struct {
302+ // XMLName xml.Name `xml:"process_info"`
303+ // GPUInstID uint64 `xml:"gpu_instance_id"`
304+ // ComputeInstID uint64 `xml:"compute_instance_id"`
305+ // PID uint64 `xml:"pid"`
306+ // }
307+
308+ // func (p *ProcessInfo) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
309+ // // Check if tag is valid
310+ // if start.Name.Local != "process_info" {
311+ // return fmt.Errorf("invalid start tag %s for process_info", start.Name.Local)
312+ // }
313+
314+ // // Unmarshal into temporary struct
315+ // var tmp struct {
316+ // XMLName xml.Name `xml:"process_info"`
317+ // GPUInstID string `xml:"gpu_instance_id"`
318+ // ComputeInstID string `xml:"compute_instance_id"`
319+ // PID uint64 `xml:"pid"`
320+ // }
321+
322+ // if err := d.DecodeElement(&tmp, &start); err != nil {
323+ // return err
324+ // }
325+
326+ // // Now check if GPUInstID and ComputeInstID can be converted
327+ // // to uint64
328+ // var gpuInstID, computeInstID uint64
329+
330+ // var err error
331+
332+ // if gpuInstID, err = strconv.ParseUint(tmp.GPUInstID, 10, 64); err != nil {
333+ // gpuInstID = 9999
334+ // }
335+
336+ // if computeInstID, err = strconv.ParseUint(tmp.ComputeInstID, 10, 64); err != nil {
337+ // computeInstID = 9999
338+ // }
339+
340+ // p.XMLName = tmp.XMLName
341+ // p.GPUInstID = gpuInstID
342+ // p.ComputeInstID = computeInstID
343+ // p.PID = tmp.PID
344+
345+ // return nil
346+ // }
347+
348+ // type Processes struct {
349+ // XMLName xml.Name `xml:"processes"`
350+ // ProcessInfos []ProcessInfo `xml:"process_info"`
351+ // }
226352
227353type VirtMode struct {
228354 XMLName xml.Name `xml:"gpu_virtualization_mode"`
@@ -246,7 +372,7 @@ type NvidiaGPU struct {
246372 MIGDevices MIGDevices `xml:"mig_devices"`
247373 UUID string `xml:"uuid"`
248374 MinorNumber string `xml:"minor_number"`
249- Processes Processes `xml:"processes"`
375+ // Processes Processes `xml:"processes"` // Ignore for the moment, we are not using it.
250376}
251377
252378type NVIDIASMILog struct {
@@ -432,14 +558,16 @@ func NewGPUSMI(k8sClient *ceems_k8s.Client, logger *slog.Logger) (*GPUSMI, error
432558 // Detect GPU device vendors
433559 vendors , err = detectVendors ()
434560 if err != nil {
435- logger .Warn ("Failed to detect GPU devices" )
561+ logger .Error ("Failed to detect GPU devices" , "err" , err )
436562
437563 return nil , fmt .Errorf ("failed to detect devices: %w" , err )
438564 }
439565 }
440566
441567 // If no vendors found return early
442568 if len (vendors ) == 0 {
569+ logger .Debug ("No GPU devices from supported vendors detected" )
570+
443571 return & GPUSMI {logger : logger }, nil
444572 }
445573
@@ -478,11 +606,11 @@ func NewGPUSMI(k8sClient *ceems_k8s.Client, logger *slog.Logger) (*GPUSMI, error
478606 }
479607 }
480608
481- // If k8sClient is not nil, figure out which containers are running the drivers
482- ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
483- defer cancel ()
484-
485609 if k8sClient != nil {
610+ // If k8sClient is not nil, figure out which containers are running the drivers
611+ ctx , cancel := context .WithTimeout (context .Background (), 30 * time .Second )
612+ defer cancel ()
613+
486614 for iv , v := range vendors {
487615 var contName , smiCmd string
488616
@@ -536,26 +664,11 @@ func (g *GPUSMI) Discover() error {
536664 for _ , vendor := range g .vendors {
537665 var devs []Device
538666
539- // Keep checking for GPU devices with a timeout of 1 minute
540- // When GPU drivers are not loaded yet, this strategy can be
541- // handy to wait for the drivers to load and for SMI commands to
542- // enumurate GPUs.
543- for start := time .Now (); time .Since (start ) < time .Minute ; {
544- // If errored out, sleep for a while and attempt to get devices again
545- devs , err = g .gpuDevices (vendor )
546- if err != nil && ! errors .Is (err , errUnsupportedGPUVendor ) {
547- time .Sleep (10 * time .Second )
548- } else {
549- g .Devices = append (g .Devices , devs ... )
550-
551- break
552- }
553- }
554-
555- // If we end up here with non nil error, we could not find GPUs for
556- // this vendor. Add to errs
667+ devs , err = g .gpuDevices (vendor )
557668 if err != nil {
558- errs = errors .Join (errs , err )
669+ errs = errors .Join (errs , fmt .Errorf ("failed to fetch GPU devices from vendor %s: %w" , vendor .name , err ))
670+ } else {
671+ g .Devices = append (g .Devices , devs ... )
559672 }
560673 }
561674
@@ -715,6 +828,12 @@ func (g *GPUSMI) ReindexGPUs(orderMap string) {
715828
716829// print emits debug logs with GPU details.
717830func (g * GPUSMI ) print () {
831+ // When GPUs are found, emit an info log
832+ if len (g .Devices ) > 0 {
833+ g .logger .Info ("GPU Devices found" , "num_devices" , len (g .Devices ))
834+ }
835+
836+ // Emit device details for debug logs
718837 for _ , gpu := range g .Devices {
719838 g .logger .Debug ("GPU device" , "vendor" , gpu .vendorID , "details" , gpu )
720839
@@ -732,7 +851,7 @@ func (g *GPUSMI) gpuDevices(vendor vendor) ([]Device, error) {
732851 case amd :
733852 return g .amdGPUDevices (vendor )
734853 default :
735- return nil , fmt . Errorf ( "only NVIDIA and AMD GPU devices are supported: %w" , errUnsupportedGPUVendor )
854+ return nil , nil
736855 }
737856}
738857
@@ -951,8 +1070,8 @@ func parseNvidiaSmiOutput(cmdOutput []byte) ([]Device, error) {
9511070
9521071 // Read XML byte array into gpu
9531072 var nvidiaSMILog NVIDIASMILog
954- if err := xml .Unmarshal (cmdOutput , & nvidiaSMILog ); err != nil { //nolint:musttag
955- return nil , err
1073+ if err := xml .Unmarshal (cmdOutput , & nvidiaSMILog ); err != nil {
1074+ return nil , fmt . Errorf ( "failed to parse nvidia-smi xml log %w" , err )
9561075 }
9571076
9581077 // NOTE: Ensure that we sort the devices using PCI address
0 commit comments