"""
GPU configuration utility for tuning block sizes based on architecture-specific shared memory limits.
"""
import torch


# Maximum shared memory per thread block (in bytes) by compute capability.
# Based on the NVIDIA CUDA C Programming Guide:
# - Compute capability 6.x (Pascal): 48 KB per thread block
# - Compute capability 7.0 (Volta): 48 KB static (up to 96 KB with the dynamic opt-in)
# - Compute capability 7.5 (Turing): 48 KB static (up to 64 KB with the dynamic opt-in)
# - Compute capability 8.0 (Ampere A100): up to 163 KB with the dynamic opt-in
# - Compute capability 8.6 (Ampere consumer): up to 99 KB with the dynamic opt-in
# - Compute capability 8.9 (Ada Lovelace, e.g. RTX 4090): up to 99 KB with the dynamic opt-in
# - Compute capability 9.0 (Hopper): up to 227 KB with the dynamic opt-in

SMEM_LIMITS = {
    # Compute capability 6.x (Pascal)
    6.0: 48 * 1024,
    6.1: 48 * 1024,
    6.2: 48 * 1024,
    # Compute capability 7.0 (Volta); kept at the conservative 48 KB static limit
    7.0: 48 * 1024,
    # Compute capability 7.5 (Turing); kept at the conservative 48 KB static limit
    7.5: 48 * 1024,
    # Compute capability 8.0 (Ampere A100)
    8.0: 163 * 1024,
    # Compute capability 8.6 (Ampere consumer, RTX 3090, etc.)
    # Note: the per-block limit on 8.6 is 99 KB, not the A100's 163 KB.
    8.6: 99 * 1024,
    # Compute capability 8.9 (Ada Lovelace, RTX 4090)
    8.9: 99 * 1024,
    # Compute capability 9.0 (Hopper)
    9.0: 227 * 1024,
}

# Default to a conservative 48 KB if the architecture is not found.
DEFAULT_SMEM_LIMIT = 48 * 1024


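# Optional runtime check (a sketch, not used by the lookups in this module):
# newer PyTorch builds may expose the opt-in per-block limit directly on the
# device properties object. The attribute name below is an assumption that
# varies by version, so the helper degrades to None when it is absent.
def query_smem_limit_from_torch(device=None):
    """Best-effort query of the opt-in shared memory limit, or None."""
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    # getattr keeps this safe on PyTorch versions without the field.
    return getattr(props, "shared_memory_per_block_optin", None)

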
def get_compute_capability(device=None):
    """Get the compute capability of the current or specified GPU."""
    if device is None:
        device = torch.cuda.current_device()

    props = torch.cuda.get_device_properties(device)
    # Encode e.g. major=8, minor=6 as the float 8.6 to match the SMEM_LIMITS
    # keys (minor versions are single digits on all current architectures).
    return float(f"{props.major}.{props.minor}")


def get_shared_memory_limit(device=None):
    """Get the shared memory limit per thread block for the current GPU."""
    compute_cap = get_compute_capability(device)

    # Try an exact match first.
    if compute_cap in SMEM_LIMITS:
        return SMEM_LIMITS[compute_cap]

    # Fall back to the first entry with the same major version.
    major_version = int(compute_cap)
    for cap, limit in SMEM_LIMITS.items():
        if int(cap) == major_version:
            return limit

    # Fall back to the conservative default.
    return DEFAULT_SMEM_LIMIT
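
# Illustrative fallback: compute capability 8.7 (Jetson Orin) has no exact
# entry above, so the major-version scan returns the first 8.x entry
# (8.0, 163 KB). That happens to match Orin's documented per-block limit,
# but verify the returned value on any unlisted hardware before relying on it.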


def get_optimal_block_sizes(n, d, e, device=None):
    """
    Calculate optimal BLOCK and BLOCK_MODEL sizes based on shared memory constraints.

    Args:
        n: Sequence length (reserved; not used by the current estimate)
        d: Query/key dimension
        e: Value dimension
        device: CUDA device (optional)

    Returns:
        tuple: (BLOCK, BLOCK_MODEL) sizes
    """
    smem_limit = get_shared_memory_limit(device)

    # Estimate the shared memory usage per block for the forward kernel:
    # - q: BLOCK * d * 4 bytes (float32)
    # - k_trans: BLOCK * d * 4 bytes
    # - v: BLOCK * BLOCK_MODEL * 4 bytes
    # - kv: d * BLOCK_MODEL * 4 bytes
    # - o_intra / o_inter temporaries: BLOCK * BLOCK_MODEL * 4 bytes
    # - diag_decay matrix: BLOCK^2 * 4 bytes
    # Total: ~(2*BLOCK*d + 2*BLOCK*BLOCK_MODEL + d*BLOCK_MODEL + BLOCK^2) * 4
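    # Worked example (illustrative): BLOCK=64, BLOCK_MODEL=32, d=64 gives
    # (2*64*64 + 2*64*32 + 64*32 + 64*64) * 4 = 73,728 bytes (72 KB), or
    # ~86 KB with the 20% safety margin applied below -- within Ada's 99 KB
    # limit but well over the conservative 48 KB default.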

    # Start with conservative values.
    BLOCK = 32
    BLOCK_MODEL = 16

    # Try to increase the block sizes while staying within the limit.
    total_mem = 0  # Guards the outer check below when e < 16 and no candidate is evaluated.
    for block_size in [64, 128, 256]:
        for block_model in [16, 32, 64]:
            if block_model > e:
                continue

            # Rough estimate of shared memory usage (bytes).
            qk_mem = 2 * block_size * d * 4  # q and k_trans
            v_mem = block_size * block_model * 4
            kv_mem = d * block_model * 4
            diag_mem = block_size * block_size * 4  # diag_decay matrix
            temp_mem = block_size * block_model * 4  # o_intra, o_inter

            total_mem = qk_mem + v_mem + kv_mem + diag_mem + temp_mem

            # Add 20% overhead for safety.
            if total_mem * 1.2 <= smem_limit:
                BLOCK = block_size
                BLOCK_MODEL = block_model
            else:
                break
        if total_mem * 1.2 > smem_limit:
            # The last candidate overflowed, so larger block sizes will too.
            break

    # Ensure BLOCK_MODEL does not exceed e and is a power of 2.
    try:
        import triton
        BLOCK_MODEL = min(BLOCK_MODEL, triton.next_power_of_2(e), 64)
    except ImportError:
        # Fallback: round down to the nearest power of 2.
        import math
        BLOCK_MODEL = 2 ** int(math.log2(min(BLOCK_MODEL, e, 64)))

    # Cap BLOCK at a reasonable value.
    BLOCK = min(BLOCK, 128)

    return BLOCK, BLOCK_MODEL


def get_optimal_cblock_size(BLOCK, device=None):
    """
    Calculate optimal CBLOCK size for backward kernels.

    Args:
        BLOCK: Main block size
        device: CUDA device (optional)

    Returns:
        int: CBLOCK size
    """
    smem_limit = get_shared_memory_limit(device)

    # CBLOCK is typically BLOCK // 2 or BLOCK // 4. The backward kernels use
    # shared memory much like the forward pass, but with CBLOCK in place of
    # BLOCK for some operations.

    # Start conservative.
    CBLOCK = 16

    # Try increasing CBLOCK; candidates must evenly divide BLOCK.
    for cblock_size in [32, 64]:
        if cblock_size <= BLOCK and BLOCK % cblock_size == 0:
            # Conservative shared memory estimate, analogous to the forward
            # pass but with CBLOCK-sized tiles.
            estimated_mem = 4 * cblock_size * cblock_size * 4
            if estimated_mem * 1.2 <= smem_limit:
                CBLOCK = cblock_size
            else:
                break

    return min(CBLOCK, BLOCK // 2)
| 171 | + |
| 172 | + |
| 173 | +# Fixed configurations pre-computed for common GPU architectures and dimensions |
| 174 | +# Format: (compute_capability, kernel_type, n_range, d_range, e_range): {BLOCK, BLOCK_MODEL, CBLOCK} |
| 175 | +# Ranges are (min, max) inclusive |
| 176 | +FIXED_CONFIGS = { |
| 177 | + # RTX 4090 / Ada Lovelace (8.9) - 99KB shared memory limit |
| 178 | + (8.9, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 179 | + (8.9, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 180 | + (8.9, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 181 | + (8.9, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 182 | + (8.9, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 183 | + (8.9, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 184 | + (8.9, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 185 | + (8.9, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 186 | + (8.9, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 187 | + |
| 188 | + # Ampere A100 / RTX 3090 (8.0, 8.6) - 163KB shared memory limit |
| 189 | + (8.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 190 | + (8.6, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 191 | + (8.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 192 | + (8.6, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 193 | + (8.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 194 | + (8.6, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 195 | + (8.0, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 196 | + (8.6, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 197 | + (8.0, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 198 | + (8.6, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 199 | + (8.0, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 200 | + (8.6, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 201 | + (8.0, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 202 | + (8.6, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 203 | + (8.0, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 204 | + (8.6, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 205 | + (8.0, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 206 | + (8.6, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 207 | + |
| 208 | + # Hopper H100 (9.0) - 227KB shared memory limit |
| 209 | + (9.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 210 | + (9.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 211 | + (9.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 64, 'BLOCK_MODEL': 32, 'CBLOCK': 32}, |
| 212 | + (9.0, 'lasp_fuse', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 213 | + (9.0, 'lasp_fuse', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 214 | + (9.0, 'lasp_fuse_parallel', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 215 | + (9.0, 'lasp_fuse_parallel', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 216 | + (9.0, 'lasp_blelloch', (128, 2048), (64, 256), (32, 128)): {'BLOCK': 64, 'CBLOCK': 32}, |
| 217 | + (9.0, 'lasp_blelloch', (2049, 8192), (64, 256), (32, 128)): {'BLOCK': 128, 'CBLOCK': 32}, |
| 218 | + |
| 219 | + # Pascal/Turing/Volta (6.x, 7.0, 7.5) - 48KB shared memory limit (conservative) |
| 220 | + (6.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 221 | + (6.1, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 222 | + (6.2, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 223 | + (7.0, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 224 | + (7.5, 'lightning', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 225 | + (6.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 226 | + (6.1, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 227 | + (6.2, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 228 | + (7.0, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 229 | + (7.5, 'lasp_naive', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 230 | + (6.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 231 | + (6.1, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 232 | + (6.2, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 233 | + (7.0, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 234 | + (7.5, 'lasp_cache', (512, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}, |
| 235 | + (6.0, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 236 | + (6.1, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 237 | + (6.2, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 238 | + (7.0, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 239 | + (7.5, 'lasp_fuse', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 240 | + (6.0, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 241 | + (6.1, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 242 | + (6.2, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 243 | + (7.0, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 244 | + (7.5, 'lasp_fuse_parallel', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 245 | + (6.0, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 246 | + (6.1, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 247 | + (6.2, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 248 | + (7.0, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 249 | + (7.5, 'lasp_blelloch', (128, 8192), (64, 256), (32, 128)): {'BLOCK': 32, 'CBLOCK': 16}, |
| 250 | +} |


def _match_fixed_config(compute_cap, kernel_type, n, d, e):
    """Match dimensions to fixed configuration ranges."""
    # Try an exact compute capability match first.
    for (cap, ktype, (n_min, n_max), (d_min, d_max), (e_min, e_max)), config in FIXED_CONFIGS.items():
        if cap == compute_cap and ktype == kernel_type:
            if n_min <= n <= n_max and d_min <= d <= d_max and e_min <= e <= e_max:
                return config

    # Fall back to matching by major version.
    major_version = int(compute_cap)
    for (cap, ktype, (n_min, n_max), (d_min, d_max), (e_min, e_max)), config in FIXED_CONFIGS.items():
        if int(cap) == major_version and ktype == kernel_type:
            if n_min <= n <= n_max and d_min <= d <= d_max and e_min <= e <= e_max:
                return config

    return None
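
# Example (illustrative): on an RTX 4090 (compute capability 8.9), a
# 4096-token 'lightning' kernel with d=128, e=64 falls inside the ranges of
# the fixed 8.9 entry above:
#   _match_fixed_config(8.9, 'lightning', 4096, 128, 64)
#   -> {'BLOCK': 32, 'BLOCK_MODEL': 16, 'CBLOCK': 16}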


# Cache of resolved configurations, keyed by (kernel_type, device, n, d, e).
# Both fixed-table hits and dynamically computed fallbacks are stored here.
_smem_cache = {}


def get_config_for_kernel(kernel_type, n, d, e, device=None):
    """
    Get the configuration for a specific kernel type from the fixed lookup
    table, falling back to dynamic computation if no match is found.

    Args:
        kernel_type: 'lightning', 'lasp_naive', 'lasp_cache', 'lasp_fuse', etc.
        n: Sequence length
        d: Query/key dimension
        e: Value dimension
        device: CUDA device (optional)

    Returns:
        dict: Configuration with BLOCK, BLOCK_MODEL, CBLOCK, etc.
    """
    if device is None:
        device = torch.cuda.current_device()

    cache_key = (kernel_type, device, n, d, e)
    if cache_key in _smem_cache:
        return _smem_cache[cache_key]

    compute_cap = get_compute_capability(device)

    # Try the fixed configuration table first (fast lookup).
    config = _match_fixed_config(compute_cap, kernel_type, n, d, e)

    if config is not None:
        _smem_cache[cache_key] = config
        return config

    # Fall back to dynamic computation for edge cases.
    smem_limit = get_shared_memory_limit(device)

    if kernel_type in ['lightning', 'lasp_naive', 'lasp_cache']:
        BLOCK, BLOCK_MODEL = get_optimal_block_sizes(n, d, e, device)
        config = {
            'BLOCK': BLOCK,
            'BLOCK_MODEL': BLOCK_MODEL,
            'CBLOCK': get_optimal_cblock_size(BLOCK, device),
        }
    elif kernel_type in ['lasp_fuse', 'lasp_fuse_parallel', 'lasp_blelloch']:
        if n > 128:
            if smem_limit <= 99 * 1024:
                BLOCK = 32
                CBLOCK = 16
            else:
                BLOCK = 128
                CBLOCK = 32
        else:
            BLOCK = min(n, 32)
            CBLOCK = min(n, 16)
        config = {
            'BLOCK': BLOCK,
            'CBLOCK': CBLOCK,
        }
    else:
        raise ValueError(f"Unknown kernel_type: {kernel_type!r}")

    _smem_cache[cache_key] = config
    return config
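

# Minimal smoke test (a sketch: the dimensions below are illustrative
# assumptions, and a CUDA device is required to run it).
if __name__ == "__main__":
    if torch.cuda.is_available():
        print(f"compute capability: {get_compute_capability()}")
        print(f"shared memory limit: {get_shared_memory_limit()} bytes")
        print(f"queried opt-in limit: {query_smem_limit_from_torch()}")
        # Example: a 4096-token sequence with d=128, e=64.
        print(get_config_for_kernel('lightning', 4096, 128, 64))
        print(get_config_for_kernel('lasp_fuse', 4096, 128, 64))
    else:
        print("No CUDA device available; nothing to query.")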