@@ -2,8 +2,6 @@ module LuxDeviceUtils
2
2
3
3
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
4
4
import Adapt: adapt, adapt_storage
5
- import Base: PkgId, UUID
6
- import TruncatedStacktraces
7
5
8
6
using PackageExtensionCompat
9
7
function __init__ ()
@@ -17,41 +15,33 @@ export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
17
15
abstract type AbstractLuxDevice <: Function end
18
16
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
19
17
18
+ __is_functional (:: AbstractLuxDevice ) = false
19
+ __is_loaded (:: AbstractLuxDevice ) = false
20
+
20
21
struct LuxCPUDevice <: AbstractLuxDevice end
22
+ struct LuxCUDADevice <: AbstractLuxGPUDevice end
23
+ struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end
24
+ struct LuxMetalDevice <: AbstractLuxGPUDevice end
21
25
22
- Base. @kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice
23
- name:: String = " CUDA"
24
- pkgid:: PkgId = PkgId (UUID (" d0bbae9a-e099-4d5b-a835-1c6931763bda" ), " LuxCUDA" )
25
- end
26
+ __is_functional (:: LuxCPUDevice ) = true
27
+ __is_loaded (:: LuxCPUDevice ) = true
26
28
27
- Base . @kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice
28
- name :: String = " AMDGPU "
29
- pkgid :: PkgId = PkgId ( UUID ( " 83120cb1-ca15-4f04-bf3b-6967d2e6b60b " ), " LuxAMDGPU " )
30
- end
29
+ _get_device_name ( :: LuxCPUDevice ) = " CPU "
30
+ _get_device_name ( :: LuxCUDADevice ) = " CUDA "
31
+ _get_device_name ( :: LuxAMDGPUDevice ) = " AMDGPU "
32
+ _get_device_name ( :: LuxMetalDevice ) = " Metal "
31
33
32
- Base . @kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice
33
- name :: String = " Metal "
34
- pkgid :: PkgId = PkgId ( UUID ( " dde4c033-4e86-420c-a63e-0dd931031962 " ), " Metal " )
35
- end
34
+ _get_triggerpkg_name ( :: LuxCPUDevice ) = " "
35
+ _get_triggerpkg_name ( :: LuxCUDADevice ) = " LuxCUDA "
36
+ _get_triggerpkg_name ( :: LuxAMDGPUDevice ) = " LuxAMDGPU "
37
+ _get_triggerpkg_name ( :: LuxMetalDevice ) = " Metal "
36
38
37
39
Base. show (io:: IO , dev:: AbstractLuxDevice ) = print (io, nameof (dev))
38
40
39
41
struct LuxDeviceSelectionException <: Exception end
40
42
41
43
function Base. showerror (io:: IO , e:: LuxDeviceSelectionException )
42
- print (io, " LuxDeviceSelectionException(No functional GPU device found!!)" )
43
- if ! TruncatedStacktraces. VERBOSE[]
44
- println (io, TruncatedStacktraces. VERBOSE_MSG)
45
- end
46
- end
47
-
48
- @generated function _get_device_name (t:: T ) where {T <: AbstractLuxDevice }
49
- return hasfield (T, :name ) ? :(t. name) : :(" " )
50
- end
51
-
52
- @generated function _get_trigger_pkgid (t:: T ) where {T <: AbstractLuxDevice }
53
- return hasfield (T, :pkgid ) ? :(t. pkgid) :
54
- :(PkgId (UUID (" b2108857-7c20-44ae-9111-449ecde12c47" ), " Lux" ))
44
+ return print (io, " LuxDeviceSelectionException(No functional GPU device found!!)" )
55
45
end
56
46
57
47
# Order is important here
@@ -125,32 +115,33 @@ function _get_gpu_device(; force_gpu_usage::Bool)
125
115
else
126
116
@debug " Using GPU backend set in preferences: $backend ."
127
117
device = GPU_DEVICES[idx]
128
- if ! haskey (Base . loaded_modules, device. pkgid )
118
+ if ! __is_loaded ( device)
129
119
@warn """ Trying to use backend: $(_get_device_name (device)) but the trigger package $(device. pkgid) is not loaded.
130
120
Ignoring the Preferences backend!!!
131
121
Please load the package and call this function again to respect the Preferences backend.""" maxlog= 1
132
122
else
133
- if getproperty (Base . loaded_modules[ device. pkgid], :functional )( )
123
+ if __is_functional ( device)
134
124
@debug " Using GPU backend: $(_get_device_name (device)) ."
135
125
return device
136
126
else
137
- @warn " GPU backend: $(_get_device_name (device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog= 1
127
+ @warn " GPU backend: $(_get_device_name (device)) set via Preferences.jl is not functional.
128
+ Defaulting to automatic GPU Backend selection." maxlog= 1
138
129
end
139
130
end
140
131
end
141
132
end
142
133
143
134
@debug " Running automatic GPU backend selection..."
144
135
for device in GPU_DEVICES
145
- if haskey (Base . loaded_modules, device. pkgid )
136
+ if __is_loaded ( device)
146
137
@debug " Trying backend: $(_get_device_name (device)) ."
147
- if getproperty (Base . loaded_modules[ device. pkgid], :functional )( )
138
+ if __is_functional ( device)
148
139
@debug " Using GPU backend: $(_get_device_name (device)) ."
149
140
return device
150
141
end
151
142
@debug " GPU backend: $(_get_device_name (device)) is not functional."
152
143
else
153
- @debug " Trigger package for backend ($(_get_device_name (device)) ): $((device. pkgid )) not loaded."
144
+ @debug " Trigger package for backend ($(_get_device_name (device)) ): $(_get_trigger_pkgname (device)) not loaded."
154
145
end
155
146
end
156
147
0 commit comments