Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit a79ca1d

Browse files
committed
Add adapt_structure for CA
1 parent db95b0a commit a79ca1d

7 files changed

+49
-8
lines changed

.JuliaFormatter.toml

-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ always_use_return = true
44
margin = 92
55
indent = 4
66
format_docstrings = true
7-
join_lines_based_on_source = false
87
separate_kwargs_with_semicolon = true
98
always_for_in = true

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxDeviceUtils"
22
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -14,13 +14,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515

1616
[weakdeps]
17+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1718
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1819
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
1920
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
2021
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

2324
[extensions]
25+
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
2426
LuxDeviceUtilsFillArraysExt = "FillArrays"
2527
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
2628
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
@@ -30,6 +32,7 @@ LuxDeviceUtilsZygoteExt = "Zygote"
3032
[compat]
3133
Adapt = "3"
3234
ChainRulesCore = "1"
35+
ComponentArrays = "0.13, 0.14"
3336
FillArrays = "0.13, 1"
3437
Functors = "0.2, 0.3, 0.4"
3538
LuxAMDGPU = "0.1"
@@ -42,6 +45,7 @@ Zygote = "0.6"
4245
julia = "1.6"
4346

4447
[extras]
48+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4549
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4650
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
4751
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module LuxDeviceUtilsComponentArraysExt
2+
3+
# FIXME: Needs upstreaming
4+
using Adapt, ComponentArrays
5+
6+
function Adapt.adapt_structure(to, ca::ComponentArray)
7+
return ComponentArray(adapt(to, getdata(ca)), getaxes(ca))
8+
end
9+
10+
end

src/LuxDeviceUtils.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ Return a tuple of supported GPU backends.
6868
6969
This is not the list of functional backends on the system, but rather backends which
7070
`Lux.jl` supports.
71+
72+
!!! warning
73+
74+
`Metal.jl` support is **extremely** experimental and most things are not expected to
75+
work.
7176
"""
7277
supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)
7378

@@ -87,8 +92,7 @@ Selects GPU device based on the following criteria:
8792
"""
8893
function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
8994
if GPU_DEVICE[] !== nothing
90-
force_gpu_usage &&
91-
!(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
95+
force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
9296
throw(LuxDeviceSelectionException())
9397
return GPU_DEVICE[]
9498
end
@@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
202206
"""
203207
@inline cpu_device() = LuxCPUDevice()
204208

205-
(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf)
206-
(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf)
207-
(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf)
208-
(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf)
209+
(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
210+
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
211+
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
212+
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)
209213

210214
for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
211215
@eval begin

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
34
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
45
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
56
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/component_arrays.jl

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using LuxDeviceUtils, ComponentArrays, Random
2+
3+
@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin
4+
dev = LuxCPUDevice()
5+
ps = (; weight=randn(10, 1), bias=randn(1))
6+
7+
ps_ca = ps |> ComponentArray
8+
9+
ps_ca_dev = ps_ca |> dev
10+
11+
@test ps_ca_dev isa ComponentArray
12+
13+
@test ps_ca_dev.weight == ps.weight
14+
@test ps_ca_dev.bias == ps.bias
15+
16+
@test ps_ca_dev == (ps |> dev |> ComponentArray)
17+
end

test/runtests.jl

+6
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,10 @@ end
4747
Aqua.test_all(LuxDeviceUtils; piracy=false)
4848
end
4949
end
50+
51+
@testset "Others" begin
52+
@safetestset "Component Arrays" begin
53+
include("component_arrays.jl")
54+
end
55+
end
5056
end

0 commit comments

Comments
 (0)