Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 37429fb

Browse files
committed
test: functions and closures
1 parent b3bef22 commit 37429fb

File tree

5 files changed

+105
-0
lines changed

5 files changed

+105
-0
lines changed

test/amdgpu_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ using FillArrays, Zygote # Extensions
126126
end
127127
end
128128

129+
@testset "Functions" begin
130+
if MLDataDevices.functional(AMDGPUDevice)
131+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
132+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
133+
134+
f(x, y) = () -> (x, x .^ 2, y)
135+
136+
ff = f([1, 2, 3], 1)
137+
@test get_device(ff) isa CPUDevice
138+
@test get_device_type(ff) <: CPUDevice
139+
140+
ff_xpu = ff |> AMDGPUDevice()
141+
@test get_device(ff_xpu) isa AMDGPUDevice
142+
@test get_device_type(ff_xpu) <: AMDGPUDevice
143+
144+
ff_cpu = ff_xpu |> cpu_device()
145+
@test get_device(ff_cpu) isa CPUDevice
146+
@test get_device_type(ff_cpu) <: CPUDevice
147+
end
148+
end
149+
129150
@testset "Wrapped Arrays" begin
130151
if MLDataDevices.functional(AMDGPUDevice)
131152
x = rand(10, 10) |> AMDGPUDevice()

test/cuda_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,27 @@ using FillArrays, Zygote # Extensions
151151
end
152152
end
153153

154+
@testset "Functions" begin
155+
if MLDataDevices.functional(CUDADevice)
156+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
157+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
158+
159+
f(x, y) = () -> (x, x .^ 2, y)
160+
161+
ff = f([1, 2, 3], 1)
162+
@test get_device(ff) isa CPUDevice
163+
@test get_device_type(ff) <: CPUDevice
164+
165+
ff_xpu = ff |> CUDADevice()
166+
@test get_device(ff_xpu) isa CUDADevice
167+
@test get_device_type(ff_xpu) <: CUDADevice
168+
169+
ff_cpu = ff_xpu |> cpu_device()
170+
@test get_device(ff_cpu) isa CPUDevice
171+
@test get_device_type(ff_cpu) <: CPUDevice
172+
end
173+
end
174+
154175
@testset "Wrapped Arrays" begin
155176
if MLDataDevices.functional(CUDADevice)
156177
x = rand(10, 10) |> CUDADevice()

test/metal_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions
115115
end
116116
end
117117

118+
@testset "Functions" begin
119+
if MLDataDevices.functional(MetalDevice)
120+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
121+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
122+
123+
f(x, y) = () -> (x, x .^ 2, y)
124+
125+
ff = f([1, 2, 3], 1)
126+
@test get_device(ff) isa CPUDevice
127+
@test get_device_type(ff) <: CPUDevice
128+
129+
ff_xpu = ff |> MetalDevice()
130+
@test get_device(ff_xpu) isa MetalDevice
131+
@test get_device_type(ff_xpu) <: MetalDevice
132+
133+
ff_cpu = ff_xpu |> cpu_device()
134+
@test get_device(ff_cpu) isa CPUDevice
135+
@test get_device_type(ff_cpu) <: CPUDevice
136+
end
137+
end
138+
118139
@testset "Wrapper Arrays" begin
119140
if MLDataDevices.functional(MetalDevice)
120141
x = rand(Float32, 10, 10) |> MetalDevice()

test/oneapi_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions
115115
end
116116
end
117117

118+
@testset "Functions" begin
119+
if MLDataDevices.functional(oneAPIDevice)
120+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
121+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
122+
123+
f(x, y) = () -> (x, x .^ 2, y)
124+
125+
ff = f([1, 2, 3], 1)
126+
@test get_device(ff) isa CPUDevice
127+
@test get_device_type(ff) <: CPUDevice
128+
129+
ff_xpu = ff |> oneAPIDevice()
130+
@test get_device(ff_xpu) isa oneAPIDevice
131+
@test get_device_type(ff_xpu) <: oneAPIDevice
132+
133+
ff_cpu = ff_xpu |> cpu_device()
134+
@test get_device(ff_cpu) isa CPUDevice
135+
@test get_device_type(ff_cpu) <: CPUDevice
136+
end
137+
end
138+
118139
@testset "Wrapper Arrays" begin
119140
if MLDataDevices.functional(oneAPIDevice)
120141
x = rand(10, 10) |> oneAPIDevice()

test/xla_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,27 @@ using FillArrays, Zygote # Extensions
114114
end
115115
end
116116

117+
@testset "Functions" begin
118+
if MLDataDevices.functional(XLADevice)
119+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
120+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
121+
122+
f(x, y) = () -> (x, x .^ 2, y)
123+
124+
ff = f([1, 2, 3], 1)
125+
@test get_device(ff) isa CPUDevice
126+
@test get_device_type(ff) <: CPUDevice
127+
128+
ff_xpu = ff |> XLADevice()
129+
@test get_device(ff_xpu) isa XLADevice
130+
@test get_device_type(ff_xpu) <: XLADevice
131+
132+
ff_cpu = ff_xpu |> cpu_device()
133+
@test get_device(ff_cpu) isa CPUDevice
134+
@test get_device_type(ff_cpu) <: CPUDevice
135+
end
136+
end
137+
117138
@testset "Wrapped Arrays" begin
118139
if MLDataDevices.functional(XLADevice)
119140
x = rand(10, 10) |> XLADevice()

0 commit comments

Comments
 (0)