Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 46fccbb

Browse files
committed
feat: handle RNGs and undef arrays gracefully
1 parent 6c3a4a7 commit 46fccbb

7 files changed

+16
-11
lines changed

.buildkite/pipeline.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
steps:
22
- label: "Triggering Pipelines (Pull Request)"
3-
if: "build.pull_request.base_branch == 'main'"
3+
if: build.branch != "main" && build.tag == null
44
agents:
55
queue: "juliagpu"
66
plugins:

ext/MLDataDevicesAMDGPUExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
4949
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
5050
return Internal.get_device(parent_x)
5151
end
52+
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())
5253

5354
Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
55+
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice
5456

5557
# Set Device
5658
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)

ext/MLDataDevicesCUDAExt.jl

+4
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray)
2929
return MLDataDevices.get_device(parent_x)
3030
end
3131
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
32+
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())
33+
Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device())
3234

3335
Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
36+
Internal.get_device_type(::CUDA.RNG) = CUDADevice
37+
Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice
3438

3539
# Set Device
3640
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)

ext/MLDataDevicesGPUArraysExt.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt
22

33
using Adapt: Adapt
44
using GPUArrays: GPUArrays
5-
using MLDataDevices: CPUDevice
5+
using MLDataDevices: Internal, CPUDevice
66
using Random: Random
77

88
Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()
99

10+
Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
11+
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)
12+
1013
end

src/internal.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,16 @@ end
129129

130130
for op in (:get_device, :get_device_type)
131131
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
132+
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
132133
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
133-
$(cpu_ret_val)..."
134+
$(unknown_ret_val)..."
134135

135136
@eval begin
136137
function $(op)(x::AbstractArray{T}) where {T}
137138
if recursive_array_eltype(T)
138139
if any(!isassigned(x, i) for i in eachindex(x))
139140
@warn $(not_assigned_msg)
140-
return $(cpu_ret_val)
141+
return $(unknown_ret_val)
141142
end
142143
return mapreduce(MLDataDevices.$(op), combine_devices, x)
143144
end

src/public.jl

-5
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
232232
!!! note
233233
234234
Trigger Packages must be loaded for this to return the correct device.
235-
236-
!!! warning
237-
238-
RNG types currently don't participate in device determination. We will remove this
239-
restriction in the future.
240235
"""
241236

242237
# Query Device from Array

test/misc_tests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,6 @@ end
154154
@testset "undefined references array" begin
155155
x = Matrix{Any}(undef, 10, 10)
156156

157-
@test get_device(x) isa CPUDevice
158-
@test get_device_type(x) <: CPUDevice
157+
@test get_device(x) isa MLDataDevices.UnknownDevice
158+
@test get_device_type(x) <: MLDataDevices.UnknownDevice
159159
end

0 commit comments

Comments
 (0)