Commit 43e9a62

Initial commit

11 files changed: +910, -0 lines

README.md (+10)

````markdown
# Usage

```bash
python setup.py develop
python test.py
python benchmark.py
```

Optionally, do `denoise-gpu.sh python test.py` (or `benchmark.py`) for less
noisy (but slower) results.
````

TARGETS (+75)

```bzl
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")

cpp_library(
    name = "cutlass_kernel",
    srcs = [
        "cutlass_kernel.cu",
    ],
    headers = [
        "cutlass_kernel.h",
    ],
    nvcc_flags = get_nvcc_arch_args(),
    deps = [
        "fbsource//third-party/cutlass-3:cutlass-3",
    ],
)

cpp_library(
    name = "cutlass",
    srcs = [
        "cutlass.cpp",
    ],
    supports_python_dlopen = True,
    deps = [
        ":cutlass_kernel",
        "//caffe2:torch-cpp",  # @manual
        "//caffe2:torch_extension",  # @manual
    ],
)

python_library(
    name = "triton_kernel",
    srcs = [
        "triton_kernel.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)

python_binary(
    name = "test",
    srcs = [
        "test.py",
    ],
    cpp_deps = [
        ":cutlass",
    ],
    main_function = "scripts.bertrand.tf32_gemm.test.main",
    par_style = "xar",
    deps = [
        ":triton_kernel",
        "//caffe2:torch",
    ],
)

python_binary(
    name = "benchmark",
    srcs = [
        "benchmark.py",
    ],
    cpp_deps = [
        ":cutlass",
    ],
    main_function = "scripts.bertrand.tf32_gemm.benchmark.main",
    par_style = "xar",
    deps = [
        "fbsource//third-party/pypi/matplotlib:matplotlib",  # @manual
        "fbsource//third-party/pypi/pandas:pandas",  # @manual
        ":triton_kernel",
        "//caffe2:torch",
    ],
)
```

benchmark.py (+68)

```python
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import torch
import triton  # @manual

from .triton_kernel import matmul as triton_matmul

try:
    torch.ops.load_library("cutlass.so")
except Exception:
    torch.ops.load_library("//scripts/bertrand/tf32_gemm:cutlass")

torch.set_float32_matmul_precision("high")

configs = []
for fp8_inputs in [False]:
    configs.append(
        triton.testing.Benchmark(
            x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 33)],
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`.
            # Don't compare to cublas for fp8 cases, as torch.matmul doesn't support fp8 at the moment.
            line_vals=["cublas", "triton", "cutlass", "precompiled"],
            line_names=["cublas", "triton", "cutlass", "precompiled"],
            ylabel="TFLOPS",  # Label name for the y-axis
            plot_name="matmul-performance-fp32",
            args={"fp8_inputs": fp8_inputs},
        )
    )


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
    a = torch.zeros((M, K), device="cuda", dtype=torch.float32)
    b = torch.zeros((K, N), device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "cublas":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.matmul(a, b), quantiles=quantiles
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_matmul(a, b), quantiles=quantiles
        )
        # print(f"{N}: {matmul_kernel.best_config}")
    if provider == "precompiled":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_matmul(a, b, precompiled=True), quantiles=quantiles
        )
        # print(f"{N}: {matmul_kernel.best_config}")
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.ops.cutlass.gemm(a, b), quantiles=quantiles
        )

    def perf(ms):
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    return perf(ms), perf(max_ms), perf(min_ms)


def main():
    benchmark.run(show_plots=True, print_data=True, save_path=".")


if __name__ == "__main__":
    main()
```
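Note on the TFLOPS math: `perf` uses the standard 2*M*N*K flop count for a GEMM, scales to teraflops, and converts milliseconds to seconds. A quick worked example (illustrative numbers, not part of the commit):

```python
# Worked example of the perf() conversion above: a GEMM does 2*M*N*K
# floating-point operations; 1e-12 scales FLOPS to TFLOPS and 1e-3
# converts milliseconds to seconds.
M = N = K = 4096
ms = 10.0  # hypothetical measured median latency, in milliseconds
tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
print(f"{tflops:.1f} TFLOPS")  # ~13.7 TFLOPS for these illustrative numbers
```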

cutlass.cpp (+22)

```cpp
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include "cutlass_kernel.h"

#include "ATen/ATen.h" // @manual
#include "torch/extension.h" // @manual

at::Tensor gemm(at::Tensor a, at::Tensor b) {
  auto c = a.new_empty({a.size(0), b.size(1)});
  gemm_kernel(
      a.data_ptr<float>(),
      b.data_ptr<float>(),
      c.data_ptr<float>(),
      a.size(0),
      b.size(1),
      a.size(1));
  return c;
}

TORCH_LIBRARY(cutlass, m) {
  m.def("gemm", &gemm);
}
```
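Note: `TORCH_LIBRARY(cutlass, m)` registers the op under the `cutlass` namespace, which is what lets `benchmark.py` call `torch.ops.cutlass.gemm` after `torch.ops.load_library`. A minimal sanity check from Python might look like the sketch below (it assumes the extension was built as `cutlass.so` and that a CUDA device is available; the tolerances are heuristic):

```python
import torch

torch.ops.load_library("cutlass.so")  # build-output path assumed, as in benchmark.py

a = torch.randn(256, 128, device="cuda", dtype=torch.float32)
b = torch.randn(128, 64, device="cuda", dtype=torch.float32)

c = torch.ops.cutlass.gemm(a, b)
assert c.shape == (256, 64)

# The kernel accumulates with TF32 tensor cores, so compare against
# torch.matmul with loose, heuristic tolerances rather than bitwise equality.
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-1)
```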

cutlass_kernel.cu (+139)

```cpp
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"

#include <cstdlib>
#include <iostream>

/**
 * Panic wrapper for unwinding CUTLASS errors
 */
#define CUTLASS_CHECK(status)                                             \
  {                                                                       \
    cutlass::Status error = status;                                       \
    if (error != cutlass::Status::kSuccess) {                             \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \
                << " at: " << __LINE__ << std::endl;                      \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  }

///////////////////////////////////////////////////////////////////////////////

// The section below describes the data types of the input and output matrices
// and of the computation between elements of the input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue =
    ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = float; // <- data type of elements in input matrix A
using ElementInputB = float; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

// Matrix layout of the input and output matrices:
// Row Major for Matrix A, Matrix B, and Matrix C.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Whether to use tensor cores or regular SIMT cores on the GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// Tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 256, 16>; // <- threadblock tile M = 128,
                                            // N = 256, K = 16
// Tile size a warp will compute
using ShapeMMAWarp =
    cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16
// Size of the MMA op
using ShapeMMAOp =
    cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA op tile M = 16, N = 8, K = 8

// How threadblocks are scheduled on the GPU
using SwizzleThreadBlock =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default
                                                                  // identity
                                                                  // swizzle

// The epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, // <- data type of output matrix
    128 /
        cutlass::sizeof_bits<
            ElementOutput>::value, // <- number of elements per vectorized
                                   // memory access (4 for 32-bit floats; 16
                                   // for 8-bit elements). This also becomes
                                   // the vector width of math instructions
                                   // in the epilogue.
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in the linear
                             // combination function

// Number of pipeline stages to use
constexpr int NumStages = 3;

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA,
    LayoutInputA,
    ElementInputB,
    LayoutInputB,
    ElementOutput,
    LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ShapeMMAThreadBlock,
    ShapeMMAWarp,
    ShapeMMAOp,
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages>;

void gemm_kernel(float* a, float* b, float* c, int m, int n, int k) {
  cutlass::gemm::GemmCoord problem_size{m, n, k};
  cutlass::TensorRef tensor_a{a, LayoutInputA{k}};
  cutlass::TensorRef tensor_b{b, LayoutInputB{n}};
  cutlass::TensorRef tensor_c{c, LayoutOutput{n}};
  cutlass::TensorRef tensor_d{c, LayoutOutput{n}};

  // Initialize alpha and beta for the linear combination (D = alpha*A*B + beta*C)
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1.0f);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0.0f);

  // Split the K dimension into 1 partition (i.e., no split-K)
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed to the
  // instantiated CUTLASS kernel at launch.
  typename Gemm::Arguments arguments{
      problem_size, // <- problem size of matrix multiplication
      tensor_a, // <- reference to matrix A on device
      tensor_b, // <- reference to matrix B on device
      tensor_c, // <- reference to matrix C on device
      tensor_d, // <- reference to matrix D on device
      {alpha, beta}, // <- tuple of alpha and beta
      split_k_slices}; // <- k-dimension split factor

  // Using the arguments, query for any extra workspace required for the
  // matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // printf("workspace size: %d\n", workspace_size);
  if (workspace_size != 0) {
    exit(EXIT_FAILURE);
  }
  // Allocate workspace memory
  // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  Gemm gemm_op;

  // Check that the problem is implementable with the instantiated templates
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  status = gemm_op.initialize(arguments, nullptr); // workspace.get());
  CUTLASS_CHECK(status);

  status = gemm_op();
  CUTLASS_CHECK(status);
}
```
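Note: with `float` elements, `OpClassTensorOp`, and `Sm80`, CUTLASS selects the TF32 tensor-core path, which is what this `tf32_gemm` directory is about. The accuracy trade-off can be eyeballed from PyTorch by toggling matmul precision (a sketch; exact error depends on shape and data):

```python
import torch

# Hypothetical shapes; any reasonably large GEMM shows the effect.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.set_float32_matmul_precision("highest")  # full-FP32 reference
ref = a @ b

torch.set_float32_matmul_precision("high")  # allow TF32, as benchmark.py does
tf32 = a @ b

# TF32 keeps ~10 mantissa bits, so a relative error on the order of
# 1e-4..1e-3 is expected here, versus ~1e-7 for a full-FP32 comparison.
print((tf32 - ref).norm() / ref.norm())
```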

cutlass_kernel.h (+5)

```cpp
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

void gemm_kernel(float* a, float* b, float* c, int m, int n, int k);
```

denoise-gpu.sh (+36)

```bash
#!/bin/bash
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# There's a whole presentation about stable benchmarking here:
# https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9956-best-practices-when-benchmarking-cuda-applications_V2.pdf

# Lock GPU clocks
sudo nvidia-smi -i 6 -pm 1 >& /dev/null             # persistence mode
sudo nvidia-smi --power-limit=330 -i 6 >& /dev/null # lock to 330 W
sudo nvidia-smi -lgc 1140 -i 6 >& /dev/null         # lock the graphics clock to 1140 MHz (the max on A100 is 1410 MHz)

# TODO: On my devgpu, device 6 is apparently attached to NUMA node 3. How did
# I discover this?
#
# `nvidia-smi -i 6 -pm 1` prints the PCI bus ID (00000000:C6:00.0)
#
# You can also get this from `nvidia-smi -x -q` by looking for minor_number
# and pci_bus_id.
#
# Then, `cat /sys/bus/pci/devices/0000:c6:00.0/numa_node` prints 3.
# Is it always the case that device N is on NUMA node N/2? :shrug:
#
# Maybe automate this process, or figure out whether it always holds?
#
# ... Or you can just run `nvidia-smi topo -mp` and it will print out exactly
# what you want, like this:
#
#       GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 mlx5_0 mlx5_1 mlx5_2 mlx5_3 CPU Affinity   NUMA Affinity
# GPU0  X    PXB  SYS  SYS  SYS  SYS  SYS  SYS  NODE   SYS    SYS    SYS    0-23,96-119    0
# GPU6  SYS  SYS  SYS  SYS  SYS  SYS  X    PXB  SYS    SYS    SYS    NODE   72-95,168-191  3

export CUDA_VISIBLE_DEVICES=6
numactl -m 3 -c 3 "$@"

# Unlock GPU clocks
sudo nvidia-smi -rgc -i 6 >& /dev/null
```
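Note: the TODO above about mapping a GPU to its NUMA node could plausibly be automated as sketched below. The helper is hypothetical (not part of the commit) and assumes `nvidia-smi --query-gpu=pci.bus_id` output like `00000000:C6:00.0` plus the sysfs layout described in the comments:

```python
import subprocess

def numa_node_of_gpu(index: int) -> int:
    # Query the PCI bus ID for the given GPU,
    # e.g. "00000000:C6:00.0" for device 6 on the devgpu described above.
    bus_id = subprocess.check_output(
        ["nvidia-smi", "-i", str(index), "--query-gpu=pci.bus_id",
         "--format=csv,noheader"],
        text=True,
    ).strip()
    # sysfs uses a lowercase 4-hex-digit PCI domain, e.g. "0000:c6:00.0",
    # so drop the first four digits of the 8-digit domain and lowercase.
    sysfs_id = bus_id.lower()[4:]
    with open(f"/sys/bus/pci/devices/{sysfs_id}/numa_node") as f:
        return int(f.read())

print(numa_node_of_gpu(6))  # expected to print 3 on the machine above
```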
