# nvmath-python: NVIDIA Math Libraries for the Python Ecosystem

nvmath-python brings the power of the NVIDIA math libraries to the Python ecosystem.
The package aims to provide intuitive pythonic APIs giving users full access to all
features offered by NVIDIA's libraries in a variety of execution spaces. nvmath-python
works seamlessly with existing Python array/tensor frameworks and focuses on providing
functionality that is missing from those frameworks.

## Some Examples
Below are a few representative examples showcasing the three main categories of
features nvmath-python offers: host, device, and distributed APIs.

### Host APIs

Host APIs are called from host code but can execute in any supported execution space
(CPU or GPU). The following example shows how to compute a matrix multiplication on
CuPy matrices; the problem sizes in the snippet are illustrative. The nvmath-python API
gives access to *all* parameters of the underlying NVIDIA cuBLASLt library, a feature
that distinguishes nvmath-python from other wrappings of NVIDIA's C-API libraries.

```python
import cupy as cp
import nvmath

# Prepare sample input data (sizes and values are illustrative).
m, n, k = 123, 456, 789
a = cp.random.rand(m, k)
b = cp.random.rand(k, n)

# Perform the multiplication. Any option of the underlying cuBLASLt library (epilogs,
# scaling factors, compute types, ...) can be passed here; the defaults are used for
# brevity.
result = nvmath.linalg.advanced.matmul(a, b)

# Synchronize the default stream before inspecting the result.
cp.cuda.get_current_stream().synchronize()

print(f"Input types = {type(a), type(b)}, device = {a.device, b.device}")
print(f"Result type = {type(result)}, device = {result.device}")
```

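Host operands from NumPy work just as well: the host APIs accept array objects from the
supported frameworks and return the result in the same framework as the inputs. As a
minimal sketch, assuming NumPy operands are accepted by the same
`nvmath.linalg.advanced.matmul` call (shapes are illustrative):

```python
import numpy as np
import nvmath

# Host (NumPy) operands; shapes are illustrative.
a = np.random.rand(128, 64)
b = np.random.rand(64, 32)

# The result is returned as a NumPy ndarray, matching the input framework.
result = nvmath.linalg.advanced.matmul(a, b)
print(f"Result type = {type(result)}, shape = {result.shape}")
```
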
nvmath-python provides the ability to write custom prologs and epilogs for FFT functions
as Python functions and compile them to LTO-IR. For example, to have unitary scaling for
an FFT, we can define an epilog which rescales the output by `1/sqrt(N)`.

```python
import cupy as cp
import nvmath
import math

# Create the data for the batched 1-D FFT.
B, N = 256, 1024
a = cp.random.rand(B, N, dtype=cp.float64) + 1j * cp.random.rand(B, N, dtype=cp.float64)

# Compute the normalization factor for unitary transforms.
norm_factor = 1.0 / math.sqrt(N)

# Define the epilog function for the FFT.
def rescale(data_out, offset, data, user_info, unused):
    data_out[offset] = data * norm_factor

# Compile the epilog to LTO-IR.
with cp.cuda.Device():
    epilog = nvmath.fft.compile_epilog(rescale, "complex128", "complex128")

# Perform the forward FFT, applying the rescaling as an epilog.
r = nvmath.fft.fft(a, axes=[-1], epilog={"ltoir": epilog})

# Finally, test that the fused FFT result matches an equivalent unitary FFT computed
# with CuPy.
s = cp.fft.fftn(a, axes=[-1], norm="ortho")

assert cp.allclose(r, s)
```
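Prologs are defined analogously on the input side. As a minimal sketch, assuming the
prolog callback takes `(data_in, offset, user_info, unused)` and returns the loaded
element, the same unitary scaling can be applied while the input is read:

```python
# Rescale each element as it is loaded, instead of as it is stored.
def rescale_on_load(data_in, offset, user_info, unused):
    return data_in[offset] * norm_factor

# Compile the prolog to LTO-IR and pass it to the FFT the same way as the epilog above.
with cp.cuda.Device():
    prolog = nvmath.fft.compile_prolog(rescale_on_load, "complex128", "complex128")

r2 = nvmath.fft.fft(a, axes=[-1], prolog={"ltoir": prolog})
assert cp.allclose(r2, s)
```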
### Device-side APIs

nvmath-python exposes NVIDIA's device-side (Dx) APIs. This allows developers to call
NVIDIA library functions inside their custom device kernels. For example, a Numba JIT
function can call cuFFT in order to implement FFT-based convolution. The helper and
problem sizes in the snippet below are illustrative.

```python
import numpy as np
from numba import cuda

from nvmath.device import fft


# Helper to create random complex test data (illustrative; any complex data works).
def random_complex(shape, real_dtype):
    real = np.random.randn(*shape).astype(real_dtype)
    imag = np.random.randn(*shape).astype(real_dtype)
    return real + 1.0j * imag


def main():
    # Problem sizes are illustrative.
    size = 64
    ffts_per_block = 2
    batch_size = ffts_per_block

    # Create the forward and inverse FFT device functions, specialized to run
    # cooperatively within a thread block.
    FFT_fwd = fft(
        fft_type="c2c",
        fft_direction="forward",
        precision=np.float32,
        size=size,
        ffts_per_block=ffts_per_block,
        elements_per_thread=2,
        execution="Block",
    )
    FFT_inv = fft(
        fft_type="c2c",
        fft_direction="inverse",
        precision=np.float32,
        size=size,
        ffts_per_block=ffts_per_block,
        elements_per_thread=2,
        execution="Block",
    )

    # Define a numba jit function targeting CUDA devices
    @cuda.jit
    def f(signal, filter):
        thread_data = cuda.local.array(
            shape=(FFT_fwd.storage_size,), dtype=FFT_fwd.value_type,
        )
        shared_mem = cuda.shared.array(shape=(0,), dtype=FFT_fwd.value_type)

        fft_id = (cuda.blockIdx.x * ffts_per_block) + cuda.threadIdx.y
        if fft_id >= batch_size:
            return
        offset = cuda.threadIdx.x

        # Load this thread's elements of the signal into registers.
        for i in range(FFT_fwd.elements_per_thread):
            thread_data[i] = signal[fft_id, offset + i * FFT_fwd.stride]

        # Call the cuFFTDx FFT function from *inside* your custom function.
        FFT_fwd(thread_data, shared_mem)

        # Apply the filter pointwise in the frequency domain.
        for i in range(FFT_fwd.elements_per_thread):
            thread_data[i] *= filter[fft_id, offset + i * FFT_fwd.stride]

        FFT_inv(thread_data, shared_mem)

        # Store the convolved signal back to global memory.
        for i in range(FFT_fwd.elements_per_thread):
            signal[fft_id, offset + i * FFT_fwd.stride] = thread_data[i]

    data = random_complex((ffts_per_block, size), np.float32)
    filter = random_complex((ffts_per_block, size), np.float32)

    data_d = cuda.to_device(data)
    filter_d = cuda.to_device(filter)

    # Launch one block; the FFT descriptor provides the block dimensions and the
    # dynamic shared memory size.
    f[1, FFT_fwd.block_dim, 0, FFT_fwd.shared_memory_size](data_d, filter_d)
    cuda.synchronize()

    # data_test now holds the (unnormalized) FFT-based convolution of data with filter.
    data_test = data_d.copy_to_host()
    print(f"Output shape = {data_test.shape}, dtype = {data_test.dtype}")


if __name__ == "__main__":
    main()
```

### Distributed APIs

Distributed APIs are called from host code but execute on a distributed
(multi-node, multi-GPU) system. The following example shows the use of the
function-form distributed FFT with CuPy ndarrays:

```python
import cupy as cp
from mpi4py import MPI

import nvmath.distributed
from nvmath.distributed.distribution import Slab

# Initialize nvmath.distributed.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nranks = comm.Get_size()
device_id = rank % cp.cuda.runtime.getDeviceCount()
nvmath.distributed.initialize(device_id, comm, backends=["nvshmem"])

# The global 3-D FFT size is (512, 256, 512).
# In this example, the input data is distributed across processes according to
# the cuFFTMp Slab distribution on the X axis.
shape = 512 // nranks, 256, 512

# cuFFTMp uses the NVSHMEM PGAS model for distributed computation, which requires GPU
# operands to be on the symmetric heap.
a = nvmath.distributed.allocate_symmetric_memory(shape, cp, dtype=cp.complex128)
# a is a cupy ndarray and can be operated on using in-place cupy operations.
with cp.cuda.Device(device_id):
    a[:] = cp.random.rand(*shape, dtype=cp.float64) + 1j * cp.random.rand(
        *shape, dtype=cp.float64
    )

# Forward FFT.
# In this example, the forward FFT operand is distributed according to the Slab.X
# distribution. With reshape=False, the FFT result will be distributed according to
# the Slab.Y distribution.
b = nvmath.distributed.fft.fft(a, distribution=Slab.X, options={"reshape": False})

# Distributed FFT performs computations in-place. The result is stored in the same
# buffer as operand a. Note, however, that operand b has a different shape (due to
# the Slab.Y distribution).
if rank == 0:
    print(f"Shape of a on rank {rank} is {a.shape}")
    print(f"Shape of b on rank {rank} is {b.shape}")

# Inverse FFT.
# Recall from the previous transform that the inverse FFT operand is distributed
# according to Slab.Y. With reshape=False, the inverse FFT result will be distributed
# according to the Slab.X distribution.
c = nvmath.distributed.fft.ifft(b, distribution=Slab.Y, options={"reshape": False})

# The shape of c is the same as a (due to the Slab.X distribution). Once again, note
# that a, b and c share the same symmetric memory buffer (distributed FFT operations
# are in-place).
if rank == 0:
    print(f"Shape of c on rank {rank} is {c.shape}")

# Synchronize the default stream.
with cp.cuda.Device(device_id):
    cp.cuda.get_current_stream().synchronize()

if rank == 0:
    print(f"Input type = {type(a)}, device = {a.device}")
    print(f"FFT output type = {type(b)}, device = {b.device}")
    print(f"IFFT output type = {type(c)}, device = {c.device}")

# GPU operands on the symmetric heap are not garbage-collected and the user is
# responsible for freeing any that they own (this deallocation is a collective
# operation that must be called by all processes at the same point in the execution).
# All cuFFTMp operations are in-place (a, b, and c share the same memory buffer), so
# we take care to only free the buffer once.
nvmath.distributed.free_symmetric_memory(a)
```

## License