@@ -68,6 +68,11 @@ Return a tuple of supported GPU backends.
68
68
69
69
This is not the list of functional backends on the system, but rather backends which
70
70
`Lux.jl` supports.
71
+
72
+ !!! warning
73
+
74
+ `Metal.jl` support is **extremely** experimental and most things are not expected to
75
+ work.
71
76
"""
72
77
supported_gpu_backends () = map (_get_device_name, GPU_DEVICES)
73
78
@@ -87,8 +92,7 @@ Selects GPU device based on the following criteria:
87
92
"""
88
93
function gpu_device (; force_gpu_usage:: Bool = false ):: AbstractLuxDevice
89
94
if GPU_DEVICE[] != = nothing
90
- force_gpu_usage &&
91
- ! (GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
95
+ force_gpu_usage && ! (GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
92
96
throw (LuxDeviceSelectionException ())
93
97
return GPU_DEVICE[]
94
98
end
@@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
202
206
"""
203
207
@inline cpu_device () = LuxCPUDevice ()
204
208
205
- (:: LuxCPUDevice )(x) = fmap (x -> adapt ( LuxCPUAdaptor (), x ), x; exclude= _isleaf)
206
- (:: LuxCUDADevice )(x) = fmap (x -> adapt ( LuxCUDAAdaptor (), x ), x; exclude= _isleaf)
207
- (:: LuxAMDGPUDevice )(x) = fmap (x -> adapt ( LuxAMDGPUAdaptor (), x ), x; exclude= _isleaf)
208
- (:: LuxMetalDevice )(x) = fmap (x -> adapt ( LuxMetalAdaptor (), x ), x; exclude= _isleaf)
209
+ (:: LuxCPUDevice )(x) = fmap (Base . Fix1 (adapt, LuxCPUAdaptor ()), x; exclude= _isleaf)
210
+ (:: LuxCUDADevice )(x) = fmap (Base . Fix1 (adapt, LuxCUDAAdaptor ()), x; exclude= _isleaf)
211
+ (:: LuxAMDGPUDevice )(x) = fmap (Base . Fix1 (adapt, LuxAMDGPUAdaptor ()), x; exclude= _isleaf)
212
+ (:: LuxMetalDevice )(x) = fmap (Base . Fix1 (adapt, LuxMetalAdaptor ()), x; exclude= _isleaf)
209
213
210
214
for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
211
215
@eval begin
0 commit comments