Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 2866e27

Browse files
authored
Merge pull request #5 from LuxDL/ap/simplify
Simplify Accelerator Loading + Aqua
2 parents 983e735 + 5f1e2f2 commit 2866e27

12 files changed

+40
-32
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxDeviceUtils"
22
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg)](https://buildkite.com/julialang/luxdeviceutils-dot-jl)
99
[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl)
1010
[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils)
11+
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
1112

1213
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
1314
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

docs/src/index.md

+6
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,11 @@ gpu_backend!
3737
```@docs
3838
cpu_device
3939
gpu_device
40+
```
41+
42+
### Miscellaneous
43+
44+
```@docs
45+
reset_gpu_device!
4046
supported_gpu_backends
4147
```

ext/LuxDeviceUtilsLuxAMDGPUExt.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random
44
import Adapt: adapt_storage, adapt
55
import ChainRulesCore as CRC
66

7-
function __init__()
8-
LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true
9-
return
10-
end
7+
__init__() = reset_gpu_device!()
118

129
# Device Transfer
1310
## To GPU

ext/LuxDeviceUtilsLuxCUDAExt.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random
44
import Adapt: adapt_storage, adapt
55
import ChainRulesCore as CRC
66

7-
function __init__()
8-
LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true
9-
return
10-
end
7+
__init__() = reset_gpu_device!()
118

129
# Device Transfer
1310
## To GPU

ext/LuxDeviceUtilsMetalExt.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ using ChainRulesCore, LuxDeviceUtils, Metal, Random
44
import Adapt: adapt_storage, adapt
55
import ChainRulesCore as CRC
66

7-
function __init__()
8-
LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[] = true
9-
return
10-
end
7+
__init__() = reset_gpu_device!()
118

129
# Device Transfer
1310
## To GPU

src/LuxDeviceUtils.jl

+17-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module LuxDeviceUtils
22

3-
using Functors, LuxCore, Preferences, Random, SparseArrays
3+
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
44
import Adapt: adapt, adapt_storage
55
import Base: PkgId, UUID
66

@@ -9,12 +9,10 @@ function __init__()
99
@require_extensions
1010
end
1111

12-
export gpu_backend!, supported_gpu_backends
12+
export gpu_backend!, supported_gpu_backends, reset_gpu_device!
1313
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
1414
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
1515

16-
const ACCELERATOR_STATE_CHANGED = Ref{Bool}(false)
17-
1816
abstract type AbstractLuxDevice <: Function end
1917
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
2018

@@ -58,6 +56,16 @@ const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice())
5856

5957
const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)
6058

59+
"""
60+
reset_gpu_device!()
61+
62+
Resets the selected GPU device. This is useful when automatic GPU selection needs to be
63+
run again.
64+
"""
65+
function reset_gpu_device!()
66+
return GPU_DEVICE[] = nothing
67+
end
68+
6169
"""
6270
supported_gpu_backends() -> Tuple{String, ...}
6371
@@ -85,17 +93,14 @@ Selects GPU device based on the following criteria:
8593
4. If nothing works, an error is thrown.
8694
"""
8795
function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
88-
if !ACCELERATOR_STATE_CHANGED[]
89-
if GPU_DEVICE[] !== nothing
90-
force_gpu_usage &&
91-
!(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
92-
throw(LuxDeviceSelectionException())
93-
return GPU_DEVICE[]
94-
end
96+
if GPU_DEVICE[] !== nothing
97+
force_gpu_usage &&
98+
!(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
99+
throw(LuxDeviceSelectionException())
100+
return GPU_DEVICE[]
95101
end
96102

97103
device = _get_gpu_device(; force_gpu_usage)
98-
ACCELERATOR_STATE_CHANGED[] = false
99104
GPU_DEVICE[] = device
100105

101106
return device

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
23
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
34
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
45
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/amdgpu.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
using LuxAMDGPU
1111

1212
@testset "Loaded Trigger Package" begin
13-
@test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
13+
@test LuxDeviceUtils.GPU_DEVICE[] === nothing
1414

1515
if LuxAMDGPU.functional()
1616
@info "LuxAMDGPU is functional"
@@ -22,7 +22,7 @@ using LuxAMDGPU
2222
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
2323
force_gpu_usage=true)
2424
end
25-
@test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
25+
@test LuxDeviceUtils.GPU_DEVICE[] !== nothing
2626
end
2727

2828
using FillArrays, Zygote # Extensions

test/cuda.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
using LuxCUDA
1111

1212
@testset "Loaded Trigger Package" begin
13-
@test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
13+
@test LuxDeviceUtils.GPU_DEVICE[] === nothing
1414

1515
if LuxCUDA.functional()
1616
@info "LuxCUDA is functional"
@@ -22,7 +22,7 @@ using LuxCUDA
2222
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
2323
force_gpu_usage=true)
2424
end
25-
@test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
25+
@test LuxDeviceUtils.GPU_DEVICE[] !== nothing
2626
end
2727

2828
using FillArrays, Zygote # Extensions

test/metal.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
using Metal
1111

1212
@testset "Loaded Trigger Package" begin
13-
@test LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
13+
@test LuxDeviceUtils.GPU_DEVICE[] === nothing
1414

1515
if Metal.functional()
1616
@info "Metal is functional"
@@ -22,7 +22,7 @@ using Metal
2222
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
2323
force_gpu_usage=true)
2424
end
25-
@test !LuxDeviceUtils.ACCELERATOR_STATE_CHANGED[]
25+
@test LuxDeviceUtils.GPU_DEVICE[] !== nothing
2626
end
2727

2828
using FillArrays, Zygote # Extensions

test/runtests.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SafeTestsets, Test, Pkg
1+
using Aqua, SafeTestsets, Test, Pkg
22
using LuxCore, LuxDeviceUtils
33

44
const GROUP = get(ENV, "GROUP", "CUDA")
@@ -41,4 +41,8 @@ end
4141
end
4242
end
4343
end
44+
45+
@testset "Aqua Tests" begin
46+
Aqua.test_all(LuxDeviceUtils; piracy=false)
47+
end
4448
end

0 commit comments

Comments
 (0)