-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathbasic_trmm.cu
550 lines (434 loc) · 15.5 KB
/
basic_trmm.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example demonstrates how to call a CUTLASS TRMM kernel and provides a naive reference
matrix multiply kernel to verify its correctness.
The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. This is kernel computes
the triangular matrix product (TRMM) using double-precision floating-point arithmetic and assumes
all matrices have column-major layout.
The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices.
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
in CUTLASS.
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
Aside from defining and launching the STRMM kernel, this example does not use any other components
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
prevalent in the CUTLASS unit tests.
*/
// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>
// Helper methods to check for errors
#include "helper.h"
//
// CUTLASS includes needed for double-precision TRMM kernel
//
// Defines cutlass::gemm::device::Trmm, the generic Trmm computation template class.
#include "cutlass/gemm/device/trmm.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// This function defines a CUTLASS TRMM kernel instantiation, constructs its parameters object,
// and launches it on the CUDA device.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Define a CUTLASS TRMM template and launch a TRMM kernel.
cudaError_t CutlassStrmmNN(
int M,
int N,
double alpha,
double const *A,
int lda,
double const *B,
int ldb,
double *C,
int ldc) {
// Define type definition for double-precision CUTLASS TRMM with column-major
// input matrices and 64x64x16 threadblock tile size (chosen by default).
//
// To keep the interface manageable, several helpers are defined for plausible compositions
// including the following example for double-precision TRMM. Typical values are used as
// default template arguments.
//
// To view the full trmm device API interface, see `cutlass/gemm/device/trmm.h`
using ColumnMajor = cutlass::layout::ColumnMajor;
using CutlassTrmm = cutlass::gemm::device::Trmm<
double,
ColumnMajor,
cutlass::SideMode::kLeft,
cutlass::FillMode::kLower,
cutlass::DiagType::kNonUnit,
double,
ColumnMajor,
double,
ColumnMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<
double,
1,
double,
double,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
5,
1,
1,
false,
cutlass::arch::OpMultiplyAdd
>;
// Define a CUTLASS TRMM type
CutlassTrmm trmm_operator;
// Construct the CUTLASS TRMM arguments object.
//
// One of CUTLASS's design patterns is to define trmm argument objects that are constructible
// in host code and passed to kernels by value. These may include pointers, strides, scalars,
// and other arguments needed by Trmm and its components.
//
// The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
// arguments to kernels and (2.) minimized initialization overhead on kernel entry.
//
CutlassTrmm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, M}, // Trmm Problem dimensions in Left-Side Mode
1, // batch_count,
{alpha}, // Scalars used in the Epilogue
reinterpret_cast<void const *>(A),
reinterpret_cast<void const *>(B),
reinterpret_cast<void *>(C), // destination matrix D (may be different memory than source C matrix)
(int64_t)M*M, // Batch strides
(int64_t)M*N,
(int64_t)M*N,
lda,
ldb,
ldc);
//
// Launch the CUTLASS TRMM kernel.
//
cutlass::Status status = trmm_operator(args);
//
// Return a cudaError_t if the CUTLASS TRMM operator returned an error code.
//
if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}
// Return success, if no errors were encountered.
return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// The source code after this point in the file is generic CUDA using the CUDA Runtime API
// and simple CUDA kernels to initialize matrices and compute the general matrix product.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Kernel to initialize a matrix with small integers.
__global__ void InitializeMatrix_kernel(
double *matrix,
int ldm,
int rows,
int columns,
int seed = 0,
cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
int j = threadIdx.y + blockIdx.y * blockDim.y;
if (i < rows && j < columns) {
if (fill_mode == cutlass::FillMode::kLower && i < j) return;
else if (fill_mode == cutlass::FillMode::kUpper && i > j) return;
int offset = i + j * ldm;
// Generate arbitrary elements.
int const k = 16807;
int const m = 16;
double value = double(((offset + seed) * k % m) - m / 2);
matrix[offset] = value;
}
}
/// Simple function to initialize a matrix to arbitrary small integers.
cudaError_t InitializeMatrix(double *matrix, int ldm, int rows, int columns, int seed = 0,
cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) {
dim3 block(16, 16);
dim3 grid(
(rows + block.x - 1) / block.x,
(columns + block.y - 1) / block.y
);
InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed, fill_mode);
return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocates device memory for a matrix then fills with arbitrary small integers.
cudaError_t AllocateMatrix(double **matrix, int ldm, int rows, int columns, int seed = 0,
cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) {
cudaError_t result;
size_t sizeof_matrix = sizeof(double) * ldm * columns;
// Allocate device memory.
result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
if (result != cudaSuccess) {
std::cerr << "Failed to allocate matrix: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
// Clear the allocation.
result = cudaMemset(*matrix, 0, sizeof_matrix);
if (result != cudaSuccess) {
std::cerr << "Failed to clear matrix device memory: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
// Initialize matrix elements to arbitrary small integers.
result = InitializeMatrix(*matrix, ldm, rows, columns, seed, fill_mode);
if (result != cudaSuccess) {
std::cerr << "Failed to initialize matrix: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return result;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Naive reference TRMM computation.
__global__ void ReferenceTrmm_kernel(
int M,
int N,
double alpha,
double const *A,
int lda,
double const *B,
int ldb,
double *C,
int ldc) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
int j = threadIdx.y + blockIdx.y * blockDim.y;
if (i < M && j < N) {
double accumulator = 0;
for (int k = 0; k < M; ++k) {
accumulator += A[i + k * lda] * B[k + j * ldb]; // Since A is in Left-Side Mode
}
C[i + j * ldc] = alpha * accumulator;
}
}
/// Reference TRMM computation.
cudaError_t ReferenceTrmm(
int M,
int N,
double alpha,
double const *A,
int lda,
double const *B,
int ldb,
double *C,
int ldc) {
dim3 block(16, 16);
dim3 grid(
(M + block.x - 1) / block.x,
(N + block.y - 1) / block.y
);
ReferenceTrmm_kernel<<< grid, block >>>(M, N, alpha, A, lda, B, ldb, C, ldc);
return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocate several matrices in GPU device memory and call a double-precision
/// CUTLASS TRMM kernel.
cudaError_t TestCutlassTrmm(int M, int N, double alpha) {
cudaError_t result;
//
// Define several matrices to be used as operands to TRMM kernels.
//
// Compute leading dimensions for each matrix.
int lda = M;
int ldb = M;
int ldc = M;
// Compute size in bytes of the C matrix.
size_t sizeof_C = sizeof(double) * ldc * N;
// Define pointers to matrices in GPU device memory.
double *A;
double *B;
double *C_cutlass;
double *C_reference;
//
// Allocate matrices in GPU device memory with arbitrary seeds.
//
result = AllocateMatrix(&A, lda, M, M, 0, cutlass::FillMode::kLower);
if (result != cudaSuccess) {
return result;
}
result = AllocateMatrix(&B, ldb, M, N, 17);
if (result != cudaSuccess) {
cudaFree(A);
return result;
}
result = AllocateMatrix(&C_cutlass, ldc, M, N, 101);
if (result != cudaSuccess) {
cudaFree(A);
cudaFree(B);
return result;
}
result = AllocateMatrix(&C_reference, ldc, M, N, 101);
if (result != cudaSuccess) {
cudaFree(A);
cudaFree(B);
cudaFree(C_cutlass);
return result;
}
result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);
if (result != cudaSuccess) {
std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Launch CUTLASS TRMM.
//
result = CutlassStrmmNN(M, N, alpha, A, lda, B, ldb, C_cutlass, ldc);
if (result != cudaSuccess) {
std::cerr << "CUTLASS TRMM kernel failed: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Verify.
//
// Launch reference TRMM
result = ReferenceTrmm(M, N, alpha, A, lda, B, ldb, C_reference, ldc);
if (result != cudaSuccess) {
std::cerr << "Reference TRMM kernel failed: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
// Copy to host and verify equivalence.
std::vector<double> host_cutlass(ldc * N, 0);
std::vector<double> host_reference(ldc * N, 0);
result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
std::cerr << "Failed to copy CUTLASS TRMM results: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
std::cerr << "Failed to copy Reference TRMM results: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Free device memory allocations.
//
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
//
// Test for bit equivalence of results.
//
if (host_cutlass != host_reference) {
std::cerr << "CUTLASS results incorrect." << std::endl;
return cudaErrorUnknown;
}
return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Entry point to basic_trmm example.
//
// usage:
//
// 00_basic_trmm <M> <N> <alpha>
//
int main(int argc, const char *arg[]) {
bool notSupported = false;
// CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!((props.major * 10 + props.minor) >= 80)) {
std::cerr << "This example requires compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
return 0;
}
//
// Parse the command line to obtain TRMM dimensions and scalar values.
//
// TRMM problem dimensions.
int problem[2] = { 128, 128 };
for (int i = 1; i < argc && i < 3; ++i) {
std::stringstream ss(arg[i]);
ss >> problem[i - 1];
}
// Scalars used for linear scaling the result of the matrix product.
double scalars[1] = { 1 };
for (int i = 3; i < argc && i < 4; ++i) {
std::stringstream ss(arg[i]);
ss >> scalars[i - 3];
}
//
// Run the CUTLASS TRMM test.
//
cudaError_t result = TestCutlassTrmm(
problem[0], // TRMM M dimension
problem[1], // TRMM N dimension
scalars[0] // alpha
);
if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
}
// Exit.
return result == cudaSuccess ? 0 : -1;
}
///////////////////////////////////////////////////////////////////////////////////////////////////