Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit cfb396c

Browse files
committed
Add fast and type stable paths for certain datastructures
1 parent 0816490 commit cfb396c

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxDeviceUtils"
22
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/LuxDeviceUtils.jl

+18-7
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,25 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
209209
"""
210210
@inline cpu_device() = LuxCPUDevice()
211211

212-
(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
213-
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
214-
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
215-
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)
216-
217-
for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
212+
# Dispatches for Different Data Structures
213+
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
214+
# For all other types we rely on fmap which means we lose type stability.
215+
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
216+
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
217+
ldev = Symbol("Lux$(dev)Device")
218+
ladaptor = Symbol("Lux$(dev)Adaptor")
218219
@eval begin
219-
function (::$dev)(::LuxCore.AbstractExplicitLayer)
220+
function (::$(ldev))(x::AbstractArray)
221+
fn = Base.Fix1(adapt, $(ladaptor)())
222+
return _isbitsarray(x) ? fn(x) : map(fn, x)
223+
end
224+
(::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x)
225+
(::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}($(ldev)(values(x)))
226+
function (::$(ldev))(x)
227+
_isleaf(x) && return adapt($(ladaptor)(), x)
228+
return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf)
229+
end
230+
function (::$(ldev))(::LuxCore.AbstractExplicitLayer)
220231
throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`."))
221232
end
222233
end

0 commit comments

Comments
 (0)