Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit db95b0a

Browse files
authored
Merge pull request #9 from LuxDL/ap/nopkgid
Use `__is_functional` & `__is_loaded` instead of PkgIDs
2 parents 10ccab1 + f2d4d62 commit db95b0a

5 files changed

+34
-36
lines changed

Project.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxDeviceUtils"
22
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,7 +12,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1212
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15-
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
1615

1716
[weakdeps]
1817
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -39,7 +38,6 @@ LuxCore = "0.1.4"
3938
Metal = "0.4, 0.5"
4039
PackageExtensionCompat = "1"
4140
Preferences = "1"
42-
TruncatedStacktraces = "1"
4341
Zygote = "0.6"
4442
julia = "1.6"
4543

ext/LuxDeviceUtilsLuxAMDGPUExt.jl

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import ChainRulesCore as CRC
66

77
__init__() = reset_gpu_device!()
88

9+
LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true
10+
LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
11+
912
# Device Transfer
1013
## To GPU
1114
adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)

ext/LuxDeviceUtilsLuxCUDAExt.jl

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import ChainRulesCore as CRC
66

77
__init__() = reset_gpu_device!()
88

9+
LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true
10+
LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
11+
912
# Device Transfer
1013
## To GPU
1114
adapt_storage(::LuxCUDAAdaptor, x) = cu(x)

ext/LuxDeviceUtilsMetalExt.jl

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import ChainRulesCore as CRC
66

77
__init__() = reset_gpu_device!()
88

9+
LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
10+
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()
11+
912
# Device Transfer
1013
## To GPU
1114
adapt_storage(::LuxMetalAdaptor, x) = mtl(x)

src/LuxDeviceUtils.jl

+24-33
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ module LuxDeviceUtils
22

33
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
44
import Adapt: adapt, adapt_storage
5-
import Base: PkgId, UUID
6-
import TruncatedStacktraces
75

86
using PackageExtensionCompat
97
function __init__()
@@ -17,41 +15,33 @@ export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
1715
abstract type AbstractLuxDevice <: Function end
1816
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
1917

18+
__is_functional(::AbstractLuxDevice) = false
19+
__is_loaded(::AbstractLuxDevice) = false
20+
2021
struct LuxCPUDevice <: AbstractLuxDevice end
22+
struct LuxCUDADevice <: AbstractLuxGPUDevice end
23+
struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end
24+
struct LuxMetalDevice <: AbstractLuxGPUDevice end
2125

22-
Base.@kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice
23-
name::String = "CUDA"
24-
pkgid::PkgId = PkgId(UUID("d0bbae9a-e099-4d5b-a835-1c6931763bda"), "LuxCUDA")
25-
end
26+
__is_functional(::LuxCPUDevice) = true
27+
__is_loaded(::LuxCPUDevice) = true
2628

27-
Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice
28-
name::String = "AMDGPU"
29-
pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU")
30-
end
29+
_get_device_name(::LuxCPUDevice) = "CPU"
30+
_get_device_name(::LuxCUDADevice) = "CUDA"
31+
_get_device_name(::LuxAMDGPUDevice) = "AMDGPU"
32+
_get_device_name(::LuxMetalDevice) = "Metal"
3133

32-
Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice
33-
name::String = "Metal"
34-
pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal")
35-
end
34+
_get_triggerpkg_name(::LuxCPUDevice) = ""
35+
_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA"
36+
_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU"
37+
_get_triggerpkg_name(::LuxMetalDevice) = "Metal"
3638

3739
Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev))
3840

3941
struct LuxDeviceSelectionException <: Exception end
4042

4143
function Base.showerror(io::IO, e::LuxDeviceSelectionException)
42-
print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")
43-
if !TruncatedStacktraces.VERBOSE[]
44-
println(io, TruncatedStacktraces.VERBOSE_MSG)
45-
end
46-
end
47-
48-
@generated function _get_device_name(t::T) where {T <: AbstractLuxDevice}
49-
return hasfield(T, :name) ? :(t.name) : :("")
50-
end
51-
52-
@generated function _get_trigger_pkgid(t::T) where {T <: AbstractLuxDevice}
53-
return hasfield(T, :pkgid) ? :(t.pkgid) :
54-
:(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux"))
44+
return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")
5545
end
5646

5747
# Order is important here
@@ -125,32 +115,33 @@ function _get_gpu_device(; force_gpu_usage::Bool)
125115
else
126116
@debug "Using GPU backend set in preferences: $backend."
127117
device = GPU_DEVICES[idx]
128-
if !haskey(Base.loaded_modules, device.pkgid)
118+
if !__is_loaded(device)
129119
@warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded.
130120
Ignoring the Preferences backend!!!
131121
Please load the package and call this function again to respect the Preferences backend.""" maxlog=1
132122
else
133-
if getproperty(Base.loaded_modules[device.pkgid], :functional)()
123+
if __is_functional(device)
134124
@debug "Using GPU backend: $(_get_device_name(device))."
135125
return device
136126
else
137-
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1
127+
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional.
128+
Defaulting to automatic GPU Backend selection." maxlog=1
138129
end
139130
end
140131
end
141132
end
142133

143134
@debug "Running automatic GPU backend selection..."
144135
for device in GPU_DEVICES
145-
if haskey(Base.loaded_modules, device.pkgid)
136+
if __is_loaded(device)
146137
@debug "Trying backend: $(_get_device_name(device))."
147-
if getproperty(Base.loaded_modules[device.pkgid], :functional)()
138+
if __is_functional(device)
148139
@debug "Using GPU backend: $(_get_device_name(device))."
149140
return device
150141
end
151142
@debug "GPU backend: $(_get_device_name(device)) is not functional."
152143
else
153-
@debug "Trigger package for backend ($(_get_device_name(device))): $((device.pkgid)) not loaded."
144+
@debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded."
154145
end
155146
end
156147

0 commit comments

Comments
 (0)