Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit c5586c2

Browse files
authored
Merge pull request #15 from LuxDL/ap/fastpaths
Add fast and type stable paths for certain datastructures
2 parents 0816490 + 8b741dc commit c5586c2

File tree

4 files changed

+20
-22
lines changed

4 files changed

+20
-22
lines changed

Project.toml

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxDeviceUtils"
22
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -14,15 +14,13 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515

1616
[weakdeps]
17-
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1817
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1918
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
2019
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
2120
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2221
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2322

2423
[extensions]
25-
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
2624
LuxDeviceUtilsFillArraysExt = "FillArrays"
2725
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
2826
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
@@ -32,7 +30,6 @@ LuxDeviceUtilsZygoteExt = "Zygote"
3230
[compat]
3331
Adapt = "3"
3432
ChainRulesCore = "1"
35-
ComponentArrays = "0.13, 0.14"
3633
FillArrays = "0.13, 1"
3734
Functors = "0.2, 0.3, 0.4"
3835
LuxAMDGPU = "0.1"
@@ -45,7 +42,6 @@ Zygote = "0.6"
4542
julia = "1.6"
4643

4744
[extras]
48-
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4945
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
5046
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
5147
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"

ext/LuxDeviceUtilsComponentArraysExt.jl

-10
This file was deleted.

src/LuxDeviceUtils.jl

+18-7
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,25 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
209209
"""
210210
@inline cpu_device() = LuxCPUDevice()
211211

212-
(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
213-
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
214-
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
215-
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)
216-
217-
for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
212+
# Dispatches for Different Data Structures
213+
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
214+
# For all other types we rely on fmap which means we lose type stability.
215+
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
216+
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
217+
ldev = Symbol("Lux$(dev)Device")
218+
ladaptor = Symbol("Lux$(dev)Adaptor")
218219
@eval begin
219-
function (::$dev)(::LuxCore.AbstractExplicitLayer)
220+
function (::$(ldev))(x::AbstractArray)
221+
fn = Base.Fix1(adapt, $(ladaptor)())
222+
return _isbitsarray(x) ? fn(x) : map(fn, x)
223+
end
224+
(::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x)
225+
(dev::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(dev(values(x)))
226+
function (::$(ldev))(x)
227+
_isleaf(x) && return adapt($(ladaptor)(), x)
228+
return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf)
229+
end
230+
function (::$(ldev))(::LuxCore.AbstractExplicitLayer)
220231
throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`."))
221232
end
222233
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1010
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1111

1212
[compat]
13+
ComponentArrays = "0.14.1"
1314
julia = "1.6"

0 commit comments

Comments
 (0)