Skip to content

Commit f84f1d4

Browse files
author
Hoang Phan
committed
Add more gpu tuning for other gpus
1 parent c046dd6 commit f84f1d4

File tree

7 files changed

+393
-39
lines changed

7 files changed

+393
-39
lines changed

lasp/gpu_config.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
"""
2+
GPU configuration utility for tuning block sizes based on architecture-specific shared memory limits.
3+
"""
4+
import torch
5+
6+
7+
# Shared memory limits per thread block (in bytes) by compute capability
8+
# Based on NVIDIA documentation:
9+
# - Compute Capability 6.x (Pascal): 48 KB per thread block
10+
# - Compute Capability 7.0 (Volta): 48 KB per thread block
11+
# - Compute Capability 7.5 (Turing): 48 KB per thread block
12+
# - Compute Capability 8.x (Ampere): 163 KB per thread block (static: 48 KB, dynamic: up to 163 KB)
13+
# - Compute Capability 8.9 (Ada Lovelace/RTX 4090): ~99 KB per thread block (varies by model)
14+
# - Compute Capability 9.0 (Hopper): 227 KB per thread block (static: 48 KB, dynamic: up to 227 KB)
15+
16+
SMEM_LIMITS = {
17+
# Compute capability 6.x (Pascal)
18+
6.0: 48 * 1024,
19+
6.1: 48 * 1024,
20+
6.2: 48 * 1024,
21+
# Compute capability 7.0 (Volta)
22+
7.0: 48 * 1024,
23+
# Compute capability 7.5 (Turing)
24+
7.5: 48 * 1024,
25+
# Compute capability 8.0 (Ampere A100)
26+
8.0: 163 * 1024,
27+
# Compute capability 8.6 (Ampere consumer, RTX 3090, etc.)
28+
8.6: 163 * 1024,
29+
# Compute capability 8.9 (Ada Lovelace, RTX 4090)
30+
# Note: RTX 4090 typically has ~99 KB limit per thread block
31+
8.9: 99 * 1024,
32+
# Compute capability 9.0 (Hopper)
33+
9.0: 227 * 1024,
34+
}
35+
36+
# Default to conservative 48 KB if architecture not found
37+
DEFAULT_SMEM_LIMIT = 48 * 1024
38+
39+
40+
def get_compute_capability(device=None):
41+
"""Get the compute capability of the current or specified GPU."""
42+
if device is None:
43+
device = torch.cuda.current_device()
44+
45+
props = torch.cuda.get_device_properties(device)
46+
major = props.major
47+
minor = props.minor
48+
compute_cap = float(f"{major}.{minor}")
49+
50+
return compute_cap
51+
52+
53+
def get_shared_memory_limit(device=None):
54+
"""Get the shared memory limit per thread block for the current GPU."""
55+
compute_cap = get_compute_capability(device)
56+
57+
# Try exact match first
58+
if compute_cap in SMEM_LIMITS:
59+
return SMEM_LIMITS[compute_cap]
60+
61+
# Try matching by major version
62+
major_version = int(compute_cap)
63+
for cap, limit in SMEM_LIMITS.items():
64+
if int(cap) == major_version:
65+
return limit
66+
67+
# Fall back to default
68+
return DEFAULT_SMEM_LIMIT
69+
70+
71+
def get_optimal_block_sizes(n, d, e, device=None):
72+
"""
73+
Calculate optimal BLOCK and BLOCK_MODEL sizes based on shared memory constraints.
74+
75+
Args:
76+
n: Sequence length
77+
d: Query/key dimension
78+
e: Value dimension
79+
device: CUDA device (optional)
80+
81+
Returns:
82+
tuple: (BLOCK, BLOCK_MODEL) sizes
83+
"""
84+
smem_limit = get_shared_memory_limit(device)
85+
86+
# Estimate shared memory usage per block
87+
# For forward kernel:
88+
# - q: BLOCK * d * 4 bytes (float32)
89+
# - k_trans: BLOCK * d * 4 bytes
90+
# - v: BLOCK * BLOCK_MODEL * 4 bytes
91+
# - kv: d * BLOCK_MODEL * 4 bytes
92+
# - Various temporary arrays: ~BLOCK^2 * 4 bytes for diag_decay
93+
# Total approximation: ~(2 * BLOCK * d + BLOCK * BLOCK_MODEL + d * BLOCK_MODEL + BLOCK^2) * 4
94+
95+
# Start with conservative values
96+
BLOCK = 32
97+
BLOCK_MODEL = 16
98+
99+
# Try to increase block sizes while staying within limit
100+
for block_size in [64, 128, 256]:
101+
for block_model in [16, 32, 64]:
102+
if block_model > e:
103+
continue
104+
105+
# Rough estimate of shared memory usage
106+
qk_mem = 2 * block_size * d * 4 # q and k_trans
107+
v_mem = block_size * block_model * 4
108+
kv_mem = d * block_model * 4
109+
diag_mem = block_size * block_size * 4 # diag_decay matrix
110+
temp_mem = block_size * block_model * 4 # o_intra, o_inter
111+
112+
total_mem = qk_mem + v_mem + kv_mem + diag_mem + temp_mem
113+
114+
# Add 20% overhead for safety
115+
if total_mem * 1.2 <= smem_limit:
116+
BLOCK = block_size
117+
BLOCK_MODEL = block_model
118+
else:
119+
break
120+
if total_mem * 1.2 > smem_limit:
121+
break
122+
123+
# Ensure BLOCK_MODEL doesn't exceed e and is power of 2
124+
try:
125+
import triton
126+
BLOCK_MODEL = min(BLOCK_MODEL, triton.next_power_of_2(e), 64)
127+
except ImportError:
128+
# Fallback: round down to nearest power of 2
129+
import math
130+
max_pow2 = 2 ** int(math.log2(min(BLOCK_MODEL, e, 64)))
131+
BLOCK_MODEL = max_pow2
132+
133+
# Cap BLOCK at reasonable values
134+
BLOCK = min(BLOCK, 128)
135+
136+
return BLOCK, BLOCK_MODEL
137+
138+
139+
def get_optimal_cblock_size(BLOCK, device=None):
140+
"""
141+
Calculate optimal CBLOCK size for backward kernels.
142+
143+
Args:
144+
BLOCK: Main block size
145+
device: CUDA device (optional)
146+
147+
Returns:
148+
int: CBLOCK size
149+
"""
150+
smem_limit = get_shared_memory_limit(device)
151+
152+
# CBLOCK is typically BLOCK // 2 or BLOCK // 4
153+
# For backward kernels, shared memory usage is similar to forward
154+
# but with CBLOCK instead of BLOCK for some operations
155+
156+
# Start conservative
157+
CBLOCK = 16
158+
159+
# Try increasing CBLOCK
160+
for cblock_size in [32, 64]:
161+
if cblock_size <= BLOCK and BLOCK % cblock_size == 0:
162+
# Estimate shared memory (conservative)
163+
# Similar to forward but with CBLOCK
164+
estimated_mem = 4 * cblock_size * cblock_size * 4 # Rough estimate
165+
if estimated_mem * 1.2 <= smem_limit:
166+
CBLOCK = cblock_size
167+
else:
168+
break
169+
170+
return min(CBLOCK, BLOCK // 2)
171+
172+
173+
# Fixed configurations pre-computed for common GPU architectures and dimensions
174+
# Format: (compute_capability, kernel_type, n_range, d_range, e_range): {BLOCK, BLOCK_MODEL, CBLOCK}
175+
# Ranges are (min, max) inclusive
176+
FIXED_CONFIGS = {
177+
# RTX 4090 / Ada Lovelace (8.9) - 99KB shared memory limit
178+
(8.9, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
179+
(8.9, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
180+
(8.9, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
181+
(8.9, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
182+
(8.9, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
183+
(8.9, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
184+
(8.9, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
185+
(8.9, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
186+
(8.9, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
187+
188+
# Ampere A100 / RTX 3090 (8.0, 8.6) - 163KB shared memory limit
189+
(8.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
190+
(8.6, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
191+
(8.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
192+
(8.6, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
193+
(8.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
194+
(8.6, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
195+
(8.0, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
196+
(8.6, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
197+
(8.0, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
198+
(8.6, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
199+
(8.0, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
200+
(8.6, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
201+
(8.0, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
202+
(8.6, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
203+
(8.0, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
204+
(8.6, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
205+
(8.0, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
206+
(8.6, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
207+
208+
# Hopper H100 (9.0) - 227KB shared memory limit
209+
(9.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
210+
(9.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
211+
(9.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32},
212+
(9.0, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
213+
(9.0, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
214+
(9.0, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
215+
(9.0, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
216+
(9.0, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32},
217+
(9.0, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32},
218+
219+
# Pascal/Turing/Volta (6.x, 7.0, 7.5) - 48KB shared memory limit (conservative)
220+
(6.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
221+
(6.1, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
222+
(6.2, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
223+
(7.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
224+
(7.5, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
225+
(6.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
226+
(6.1, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
227+
(6.2, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
228+
(7.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
229+
(7.5, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
230+
(6.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
231+
(6.1, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
232+
(6.2, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
233+
(7.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
234+
(7.5, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16},
235+
(6.0, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
236+
(6.1, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
237+
(6.2, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
238+
(7.0, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
239+
(7.5, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
240+
(6.0, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
241+
(6.1, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
242+
(6.2, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
243+
(7.0, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
244+
(7.5, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
245+
(6.0, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
246+
(6.1, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
247+
(6.2, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
248+
(7.0, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
249+
(7.5, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16},
250+
}
251+
252+
253+
def _match_fixed_config(compute_cap, kernel_type, n, d, e):
254+
"""Match dimensions to fixed configuration ranges."""
255+
# Try exact compute capability match first
256+
for (cap, ktype, (n_min, n_max), (d_min, d_max), (e_min, e_max)), config in FIXED_CONFIGS.items():
257+
if cap == compute_cap and ktype == kernel_type:
258+
if n_min <= n <= n_max and d_min <= d <= d_max and e_min <= e <= e_max:
259+
return config
260+
261+
# Try matching by major version
262+
major_version = int(compute_cap)
263+
for (cap, ktype, (n_min, n_max), (d_min, d_max), (e_min, e_max)), config in FIXED_CONFIGS.items():
264+
if int(cap) == major_version and ktype == kernel_type:
265+
if n_min <= n <= n_max and d_min <= d <= d_max and e_min <= e <= e_max:
266+
return config
267+
268+
return None
269+
270+
271+
# Cache for performance (only caches lookup results, not computation)
272+
_smem_cache = {}
273+
274+
275+
def get_config_for_kernel(kernel_type, n, d, e, device=None):
276+
"""
277+
Get configuration for a specific kernel type using fixed lookup table.
278+
Falls back to dynamic computation if no match found.
279+
280+
Args:
281+
kernel_type: 'lightning', 'lasp_naive', 'lasp_cache', 'lasp_fuse', etc.
282+
n: Sequence length
283+
d: Query/key dimension
284+
e: Value dimension
285+
device: CUDA device (optional)
286+
287+
Returns:
288+
dict: Configuration with BLOCK, BLOCK_MODEL, CBLOCK, etc.
289+
"""
290+
if device is None:
291+
device = torch.cuda.current_device()
292+
293+
cache_key = (kernel_type, device, n, d, e)
294+
if cache_key in _smem_cache:
295+
return _smem_cache[cache_key]
296+
297+
compute_cap = get_compute_capability(device)
298+
299+
# Try fixed configuration first (fast lookup)
300+
config = _match_fixed_config(compute_cap, kernel_type, n, d, e)
301+
302+
if config is not None:
303+
_smem_cache[cache_key] = config
304+
return config
305+
306+
# Fall back to dynamic computation for edge cases
307+
smem_limit = get_shared_memory_limit(device)
308+
309+
if kernel_type == 'lightning':
310+
BLOCK, BLOCK_MODEL = get_optimal_block_sizes(n, d, e, device)
311+
config = {
312+
'BLOCK': BLOCK,
313+
'BLOCK_MODEL': BLOCK_MODEL,
314+
'CBLOCK': get_optimal_cblock_size(BLOCK, device),
315+
}
316+
elif kernel_type in ['lasp_naive', 'lasp_cache']:
317+
BLOCK, BLOCK_MODEL = get_optimal_block_sizes(n, d, e, device)
318+
config = {
319+
'BLOCK': BLOCK,
320+
'BLOCK_MODEL': BLOCK_MODEL,
321+
'CBLOCK': get_optimal_cblock_size(BLOCK, device),
322+
}
323+
elif kernel_type in ['lasp_fuse', 'lasp_fuse_parallel', 'lasp_blelloch']:
324+
if n > 128:
325+
if smem_limit <= 99 * 1024:
326+
BLOCK = 32
327+
CBLOCK = 16
328+
else:
329+
BLOCK = 128
330+
CBLOCK = 32
331+
else:
332+
BLOCK = min(n, 32)
333+
CBLOCK = min(n, 16)
334+
config = {
335+
'BLOCK': BLOCK,
336+
'CBLOCK': CBLOCK,
337+
}
338+
339+
_smem_cache[cache_key] = config
340+
return config
341+

lasp/lasp_blelloch.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.distributed as dist
1414
import triton
1515

16+
from .gpu_config import get_config_for_kernel
1617
from .lasp_fuse_parallel import (
1718
_fwd_diag_kernel,
1819
_fwd_kv_parallel,
@@ -72,13 +73,10 @@ def forward(ctx, q, k, v, s, KV, DKV):
7273
rank = get_sequence_parallel_rank()
7374
world_size = get_sequence_parallel_world_size()
7475

75-
# Determine block sizes (same logic as lasp_fuse_parallel)
76-
if n > 128:
77-
BLOCK = 256
78-
CBLOCK = 64
79-
else:
80-
BLOCK = min(n, 128)
81-
CBLOCK = min(n, 64)
76+
# Determine block sizes based on GPU architecture
77+
config = get_config_for_kernel('lasp_blelloch', n, d, e, q.device)
78+
BLOCK = config['BLOCK']
79+
CBLOCK = config['CBLOCK']
8280

8381
NUM_BLOCK = n // BLOCK
8482
NUM_CBLOCK = BLOCK // CBLOCK

0 commit comments

Comments
 (0)