Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit b3bef22

Browse files
committed
test: RNG movement
1 parent 46fccbb commit b3bef22

File tree

5 files changed

+40
-0
lines changed

5 files changed

+40
-0
lines changed

test/amdgpu_tests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
5757
@test ps_xpu.e == ps.e
5858
@test ps_xpu.d == ps.d
5959
@test ps_xpu.rng_default isa rngType
60+
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
61+
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
6062
@test ps_xpu.rng == ps.rng
63+
@test get_device(ps_xpu.rng) === nothing
64+
@test get_device_type(ps_xpu.rng) <: Nothing
6165

6266
if MLDataDevices.functional(AMDGPUDevice)
6367
@test ps_xpu.one_elem isa ROCArray
@@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
8387
@test ps_cpu.e == ps.e
8488
@test ps_cpu.d == ps.d
8589
@test ps_cpu.rng_default isa Random.TaskLocalRNG
90+
@test get_device(ps_cpu.rng_default) === nothing
91+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8692
@test ps_cpu.rng == ps.rng
93+
@test get_device(ps_cpu.rng) === nothing
94+
@test get_device_type(ps_cpu.rng) <: Nothing
8795

8896
if MLDataDevices.functional(AMDGPUDevice)
8997
@test ps_cpu.one_elem isa Array

test/cuda_tests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
5656
@test ps_xpu.e == ps.e
5757
@test ps_xpu.d == ps.d
5858
@test ps_xpu.rng_default isa rngType
59+
@test get_device(ps_xpu.rng_default) isa CUDADevice
60+
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
5961
@test ps_xpu.rng == ps.rng
62+
@test get_device(ps_xpu.rng) === nothing
63+
@test get_device_type(ps_xpu.rng) <: Nothing
6064

6165
if MLDataDevices.functional(CUDADevice)
6266
@test ps_xpu.one_elem isa CuArray
@@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
8286
@test ps_cpu.e == ps.e
8387
@test ps_cpu.d == ps.d
8488
@test ps_cpu.rng_default isa Random.TaskLocalRNG
89+
@test get_device(ps_cpu.rng_default) === nothing
90+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8591
@test ps_cpu.rng == ps.rng
92+
@test get_device(ps_cpu.rng) === nothing
93+
@test get_device_type(ps_cpu.rng) <: Nothing
8694

8795
if MLDataDevices.functional(CUDADevice)
8896
@test ps_cpu.one_elem isa Array

test/metal_tests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
5555
@test ps_xpu.e == ps.e
5656
@test ps_xpu.d == ps.d
5757
@test ps_xpu.rng_default isa rngType
58+
@test get_device(ps_xpu.rng_default) isa MetalDevice
59+
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
5860
@test ps_xpu.rng == ps.rng
61+
@test get_device(ps_xpu.rng) === nothing
62+
@test get_device_type(ps_xpu.rng) <: Nothing
5963

6064
if MLDataDevices.functional(MetalDevice)
6165
@test ps_xpu.one_elem isa MtlArray
@@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
8185
@test ps_cpu.e == ps.e
8286
@test ps_cpu.d == ps.d
8387
@test ps_cpu.rng_default isa Random.TaskLocalRNG
88+
@test get_device(ps_cpu.rng_default) === nothing
89+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8490
@test ps_cpu.rng == ps.rng
91+
@test get_device(ps_cpu.rng) === nothing
92+
@test get_device_type(ps_cpu.rng) <: Nothing
8593

8694
if MLDataDevices.functional(MetalDevice)
8795
@test ps_cpu.one_elem isa Array

test/oneapi_tests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
5555
@test ps_xpu.e == ps.e
5656
@test ps_xpu.d == ps.d
5757
@test ps_xpu.rng_default isa rngType
58+
@test get_device(ps_xpu.rng_default) isa oneAPIDevice
59+
@test get_device_type(ps_xpu.rng_default) <: oneAPIDevice
5860
@test ps_xpu.rng == ps.rng
61+
@test get_device(ps_xpu.rng) === nothing
62+
@test get_device_type(ps_xpu.rng) <: Nothing
5963

6064
if MLDataDevices.functional(oneAPIDevice)
6165
@test ps_xpu.one_elem isa oneArray
@@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
8185
@test ps_cpu.e == ps.e
8286
@test ps_cpu.d == ps.d
8387
@test ps_cpu.rng_default isa Random.TaskLocalRNG
88+
@test get_device(ps_cpu.rng_default) === nothing
89+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8490
@test ps_cpu.rng == ps.rng
91+
@test get_device(ps_cpu.rng) === nothing
92+
@test get_device_type(ps_cpu.rng) <: Nothing
8593

8694
if MLDataDevices.functional(oneAPIDevice)
8795
@test ps_cpu.one_elem isa Array

test/xla_tests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ using FillArrays, Zygote # Extensions
5454
@test ps_xpu.e == ps.e
5555
@test ps_xpu.d == ps.d
5656
@test ps_xpu.rng_default isa rngType
57+
@test get_device(ps_xpu.rng_default) === nothing
58+
@test get_device_type(ps_xpu.rng_default) <: Nothing
5759
@test ps_xpu.rng == ps.rng
60+
@test get_device(ps_xpu.rng) === nothing
61+
@test get_device_type(ps_xpu.rng) <: Nothing
5862

5963
if MLDataDevices.functional(XLADevice)
6064
@test ps_xpu.one_elem isa Reactant.RArray
@@ -80,7 +84,11 @@ using FillArrays, Zygote # Extensions
8084
@test ps_cpu.e == ps.e
8185
@test ps_cpu.d == ps.d
8286
@test ps_cpu.rng_default isa Random.TaskLocalRNG
87+
@test get_device(ps_cpu.rng_default) === nothing
88+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8389
@test ps_cpu.rng == ps.rng
90+
@test get_device(ps_cpu.rng) === nothing
91+
@test get_device_type(ps_cpu.rng) <: Nothing
8492

8593
if MLDataDevices.functional(XLADevice)
8694
@test ps_cpu.one_elem isa Array

0 commit comments

Comments
 (0)