Skip to content

Commit 6d6f18d

Browse files
[WIP] pingpong gemm v1
1 parent 375eba4 commit 6d6f18d

File tree

11 files changed

+2050
-319
lines changed

11 files changed

+2050
-319
lines changed

Diff for: cutlass.py/fast_math.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
def round_up(a, b):
    """Round ``a`` up to a multiple of ``b``.

    Intended for integer ``a`` and a positive integer ``b``: the result is
    ``b`` times the ceiling of ``a / b``, computed with integer arithmetic.
    """
    # Adding b - 1 before the floor division turns it into a ceiling division.
    num_chunks = (a + b - 1) // b
    return num_chunks * b
3+
4+
5+
def integer_log2(x):
    """Return ``floor(log2(x))`` for a positive integer ``x``.

    ``integer_log2(0)`` returns 0, matching the historical behavior that
    callers constructing a zero-divisor placeholder rely on.

    Raises:
        ValueError: if ``x`` is negative. A negative value never reaches 0
            under arithmetic right shift, so the shift loop would not
            terminate.
    """
    if x < 0:
        raise ValueError(f"integer_log2() requires a non-negative value, got {x}")

    n = 0
    # Count how many times x can be halved before it drops to 1 (or 0).
    while x > 1:
        x >>= 1
        n += 1
    return n
12+
13+
14+
class FastDivmodU64Pow2:
    """Division and modulus by a power-of-two divisor using a shift and a mask.

    The quotient is ``dividend >> log2(divisor)`` and the remainder is
    ``dividend & (divisor - 1)``; both are only correct when the divisor is
    a power of two, so the constructor enforces that.

    A divisor of 0 (the default) is accepted as an "unset" placeholder. In
    that state ``divide`` and ``modulus`` both return the dividend unchanged
    (shift by 0, mask of all ones).
    """

    def __init__(self, divisor=0) -> None:
        """Store the divisor and precompute the right-shift amount.

        Raises:
            ValueError: if ``divisor`` is negative or is neither 0 nor a
                power of two.
        """
        # x & (x - 1) clears the lowest set bit, so it is zero only for 0 and
        # for powers of two. Negative values are rejected up front.
        if divisor < 0 or divisor & (divisor - 1):
            raise ValueError(
                f"divisor must be 0 or a power of two, got {divisor}"
            )
        self.divisor = divisor
        self.shift_right = integer_log2(divisor)

    def divide(self, dividend):
        """Return ``dividend // divisor`` computed as a right shift."""
        return dividend >> self.shift_right

    def modulus(self, dividend):
        """Return ``dividend % divisor`` computed as a bit mask."""
        return dividend & (self.divisor - 1)

    def divmod(self, dividend):
        """Return the ``(quotient, remainder)`` pair for ``dividend``."""
        quotient = self.divide(dividend)
        remainder = self.modulus(dividend)
        return quotient, remainder

    def __call__(self, dividend):
        """Shorthand for :meth:`divmod`."""
        return self.divmod(dividend)
32+
33+
34+
class FastDivmodU64:
    """Divide-and-remainder helper bound to a fixed divisor.

    Unlike the power-of-two variant, this uses ordinary floor division and
    modulo, so any non-zero divisor works.
    """

    def __init__(self, divisor):
        """Remember the divisor used by every later call."""
        self.divisor = divisor

    def divmod(self, x):
        """Return ``(x // divisor, x % divisor)`` as a tuple."""
        quotient = x // self.divisor
        remainder = x % self.divisor
        return (quotient, remainder)

    def __call__(self, x):
        """Shorthand for :meth:`divmod`."""
        return self.divmod(x)

Diff for: cutlass.py/hw_info.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
import subprocess
import tempfile
from dataclasses import dataclass, field
4+
5+
6+
@dataclass
class dim3:
    """Triple of integer components ``x``, ``y`` and ``z``."""

    x: int
    y: int
    z: int


@dataclass
class DeviceCoord:
    """Grid, block and cluster extents plus the current block/thread indices.

    The three index fields default to ``dim3(0, 0, 0)``. Each instance gets
    its own ``dim3`` objects via ``default_factory``: ``dim3`` is a mutable,
    unhashable dataclass, so a plain instance default would be shared
    between every ``DeviceCoord`` and is rejected outright by the dataclass
    machinery on Python 3.11+.
    """

    gridDim: dim3
    blockDim: dim3
    clusterDim: dim3
    blockIdx: dim3 = field(default_factory=lambda: dim3(0, 0, 0))
    threadIdx: dim3 = field(default_factory=lambda: dim3(0, 0, 0))
    blockIdx_in_cluster: dim3 = field(default_factory=lambda: dim3(0, 0, 0))

    def block_id_in_cluster(self):
        """Return the block's index within its cluster as an ``(x, y, z)`` tuple."""
        return (
            self.blockIdx_in_cluster.x,
            self.blockIdx_in_cluster.y,
            self.blockIdx_in_cluster.z,
        )

    def set_blockIdx(self, x, y, z):
        """Set the block index and derive the matching in-cluster index.

        The in-cluster index is the component-wise remainder of the block
        index modulo ``clusterDim``.
        """
        self.blockIdx = dim3(x, y, z)
        self.blockIdx_in_cluster = dim3(
            x % self.clusterDim.x,
            y % self.clusterDim.y,
            z % self.clusterDim.z,
        )

    def set_threadIdx(self, x, y, z):
        """Set the thread index."""
        self.threadIdx = dim3(x, y, z)
35+
36+
37+
@dataclass
class KernelHardwareInfo:
    """Device id and multiprocessor (SM) count describing the target device."""

    device_id: int = 0
    sm_count: int = 0

    @staticmethod
    def query_device_multiprocessor_count(device_id: int = 0, arch: str = "90a"):
        """Return the multiprocessor count of a CUDA device, or -1 on failure.

        A small CUDA program is generated, compiled with ``nvcc`` and run;
        the count it prints on stdout is parsed and returned.

        Args:
            device_id: device ordinal baked into the generated program.
            arch: architecture suffix passed to nvcc as ``-arch=sm_<arch>``.

        Returns:
            The multiprocessor count, or -1 if compilation or execution
            fails (a diagnostic is printed in that case).
        """
        cuda_header_code = f"""
#include <cuda_runtime.h>
#include <iostream>
static constexpr int device_id = {device_id};
"""
        cuda_code = """
int main() {
  cudaError_t result = cudaSetDevice(device_id);
  if (result != cudaSuccess) {
    std::cerr << "cudaSetDevice() returned error "
              << cudaGetErrorString(result) << std::endl;
    return 1;
  }
  int multiprocessor_count;
  result = cudaDeviceGetAttribute(&multiprocessor_count,
                                  cudaDevAttrMultiProcessorCount, device_id);
  if (result != cudaSuccess) {
    std::cerr << "cudaDeviceGetAttribute() returned error "
              << cudaGetErrorString(result) << std::endl;
    return 1;
  }
  std::cout << multiprocessor_count << std::endl;
  return 0;
}
"""
        full_cuda_code = cuda_header_code + cuda_code

        # Build inside a private temporary directory: concurrent callers do
        # not clobber each other's files, the current directory need not be
        # writable, and everything is removed on every exit path, including
        # a failed compilation.
        with tempfile.TemporaryDirectory() as workdir:
            source_path = os.path.join(workdir, "temp_query_device.cu")
            binary_path = os.path.join(workdir, "temp_query_device")

            with open(source_path, "w") as file:
                file.write(full_cuda_code)

            # Argument list with the default shell=False, so `arch` is passed
            # to nvcc verbatim and is never interpreted by a shell.
            compile_command = [
                "nvcc",
                f"-arch=sm_{arch}",
                source_path,
                "-o",
                binary_path,
            ]
            try:
                subprocess.run(
                    compile_command,
                    check=True,
                    text=True,
                    stderr=subprocess.PIPE,
                )
            except OSError as e:
                # nvcc missing or not executable; previously the shell
                # reported this as a non-zero exit status.
                print(f"Compilation failed: {e}")
                return -1
            except subprocess.CalledProcessError as e:
                print(f"Compilation failed: {e.stderr}")
                return -1

            # Run the compiled binary and parse the count it prints.
            try:
                result = subprocess.run(
                    [binary_path], capture_output=True, text=True, check=True
                )
                return int(result.stdout.strip())
            except subprocess.CalledProcessError as e:
                print(f"Execution failed: {e.stderr}")
                return -1
105+
106+
107+
if __name__ == "__main__":
    # Script entry point: query the default device and print the result
    # (the multiprocessor count, or -1 if the query failed).
    sm_count = KernelHardwareInfo.query_device_multiprocessor_count()
    print(sm_count)

Diff for: cutlass.py/mapping.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
1-
class Mapping:
2-
pass
1+
from typing import List, Optional
2+
from dataclasses import dataclass
3+
from tiling import HyperCube, HyperPoint
4+
from functools import reduce
35

4-
class
6+
class Function:
    """Base interface with a ``forward`` and a ``backward`` operation.

    Both methods are placeholders; subclasses are expected to override them.
    """

    def forward(self, *args):
        """Forward operation; not implemented on the base class."""
        raise NotImplementedError

    def backward(self, *args):
        """Backward operation; not implemented on the base class."""
        raise NotImplementedError
12+
13+
class Mapping(Function):
    """A ``Function`` that holds a list of other ``Function`` objects."""

    def __init__(self, functions: Optional[List[Function]] = None) -> None:
        """Store ``functions`` (an empty list when omitted).

        Every element must be a ``Function`` instance; this is checked with
        an assertion.
        """
        self.functions = [] if functions is None else functions
        assert all(
            isinstance(func, Function) for func in self.functions
        ), "Should put Function type in Mapping."
18+
19+
@dataclass
class Layout(Function):
    """A shape paired with a stride; the stride is derived when omitted."""

    shape: HyperCube
    stride: Optional[HyperPoint] = None

    def __post_init__(self):
        """Record the rank and fill in a default stride if none was given."""
        self.ndim = len(self.shape)
        if self.stride is None:
            # The original called the undefined name `Hyperreduce`, which
            # raised NameError whenever the stride was omitted. The fold is
            # functools.reduce, already imported by this module.
            #
            # Starting from [1], each step appends the previous entry times
            # the next extent taken from reversed(shape[:-1]), giving
            # [1, s[-2], s[-2] * s[-3], ...].
            #
            # NOTE(review): the result is a plain list; it presumably needs
            # wrapping in HyperPoint to match the annotation - TODO confirm
            # the HyperPoint constructor.
            # NOTE(review): confirm the intended stride ordering and that
            # dropping the *last* extent (rather than the first) is correct.
            self.stride = reduce(
                lambda a, b: a + [a[-1] * b], reversed(self.shape[:-1]), [1]
            )

Diff for: cutlass.py/mma.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self, mma_op: MMA_OP) -> None:
6969
self.B_frag_type = SmemDesc(B_major)
7070

7171
self.MNK_shape = HyperCube(3, [mma_op.M_tile, mma_op.N_tile, mma_op.K_tile])
72-
self.
72+
self.thread_id =
7373

7474

7575
def gmma_selector(

Diff for: cutlass.py/swizzle.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
class Swizzle:
    """XOR swizzle of an integer offset, configured by three parameters.

    Args:
        num_bits: width of the bit field that takes part in the swizzle.
        num_base: number of least-significant bits left untouched.
        num_shft: distance between the source field (``yyy_msk``) and the
            destination field (``zzz_msk``); the sign picks the direction.

    ``yyy_msk`` selects ``num_bits`` bits starting at bit
    ``num_base + max(0, num_shft)``; ``zzz_msk`` selects ``num_bits`` bits
    starting at bit ``num_base - min(0, num_shft)``. ``apply`` XORs the
    ``yyy`` field, shifted by ``num_shft``, into the ``zzz`` position.
    """

    def __init__(self, num_bits: int, num_base: int, num_shft: int):
        self.num_bits = num_bits
        self.num_base = num_base
        self.num_shft = num_shft

        # The first check used to test num_bits a second time, so a negative
        # num_base was never caught here and only surfaced later as a
        # "negative shift count" error from the mask computation.
        assert self.num_base >= 0, "MBase must be non-negative."
        assert self.num_bits >= 0, "BBits must be non-negative."
        assert (
            abs(self.num_shft) >= self.num_bits
        ), "abs(SShift) must be at least BBits."

        self.bit_msk = (1 << self.num_bits) - 1
        self.yyy_msk = self.bit_msk << (self.num_base + max(0, self.num_shft))
        self.zzz_msk = self.bit_msk << (self.num_base - min(0, self.num_shft))
        self.msk_sft = self.num_shft

        # Union of the source and destination fields.
        self.swizzle_code = self.yyy_msk | self.zzz_msk

    def apply(self, offset):
        """Return ``offset`` with its ``yyy`` field XORed into the ``zzz`` field."""
        if self.msk_sft >= 0:
            return offset ^ ((offset & self.yyy_msk) >> self.msk_sft)
        else:
            # Negative shift: the source field sits below the destination.
            return offset ^ ((offset & self.yyy_msk) << -self.msk_sft)

    def __call__(self, offset):
        """Shorthand for :meth:`apply`."""
        return self.apply(offset)
28+
29+
30+
def test_swizzle():
    """Print index matrices before and after swizzling for two configurations.

    For each configuration a ``rows x cols`` matrix of ``(row, col)`` tuples
    is printed as-is, then printed again with every entry converted to a
    row-major offset, swizzled, and converted back to a ``(row, col)`` tuple.
    """

    def row_major_index(x, y, cols):
        # Flatten (x, y) into a row-major linear offset.
        return x * cols + y

    def row_major_coord(index, cols):
        # Inverse of row_major_index.
        return (index // cols, index % cols)

    def show(grid, transform=lambda cell: cell, prompt=""):
        # Print the prompt, then one line per row with space-terminated items.
        print(prompt)
        for row in grid:
            for cell in row:
                print(transform(cell), end=" ")
            print()

    def demo(title, params, rows, cols):
        print(title)
        grid = [[(x, y) for y in range(cols)] for x in range(rows)]
        show(grid, prompt="Original")
        print()
        swizzle = Swizzle(*params)
        show(
            grid,
            lambda tp: row_major_coord(
                swizzle(row_major_index(tp[0], tp[1], cols)), cols
            ),
            prompt="After swizzle",
        )
        print()

    demo("Swizzle<3,4,3>", (3, 4, 3), 128, 64)
    demo("Swizzle<2,0,-2>", (2, 0, -2), 4, 4)
93+
94+
95+
if __name__ == "__main__":
    # Running this module as a script currently does nothing: test_swizzle()
    # is defined in this file but is not invoked here.
    pass

0 commit comments

Comments
 (0)