
Commit 6a420ed

Autotuning not on A100, instructions, single line API for model optimizations. (#67)
1 parent 7cd6ba3 commit 6a420ed

6 files changed: +117, -21 lines


README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -33,6 +33,19 @@ The package acts like a drop-in replacement for segment-anything.
 
 So, for example, if you're currently doing `from segment_anything import sam_model_registry` you should be able to do `from segment_anything_fast import sam_model_registry`.
 
+However, you're likely here because you want to try a fast, inference version. So we also created a `sam_model_fast_registry` that automatically applies
+- Sets `eval` mode
+- Uses `bfloat16`
+- Enables torch.compile with max-autotune
+- Uses a custom Triton kernel that implements SDPA for relative positional encodings for long sequence lengths
+
+The custom Triton kernel in particular was written for A100. If you're not using an A100, we will try to rerun autotuning on your device and locally save the best configs.
+You might still run into performance issues, so you can disable the kernel by setting the environment variable `SEGMENT_ANYTHING_FAST_USE_FLASH_4=0`
+
+Please also note that the first time you're running this model you'll likely need to wait a bit for it to compile.
+
+If you'd like to see the details on how to reproduce all results, please see the README in the experiments folder above.
+
 Please don't be shy to open a Github issue if you're missing functionality or find an issue. Thank you.
 
 ## Results
```
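To make the new README prose concrete, here is a minimal usage sketch of the `sam_model_fast_registry` added in this commit. The checkpoint path, device, image, and prompt point are hypothetical placeholders; the predictor calls follow the standard segment-anything API that the package mirrors.

```python
import numpy as np
from segment_anything_fast import sam_model_fast_registry, SamPredictor

# Hypothetical checkpoint path; download the official SAM weights separately.
checkpoint = "checkpoints/sam_vit_h_4b8939.pth"

# Single-line API: builds the ViT-H model with the inference optimizations
# the README above describes applied for you.
sam = sam_model_fast_registry["vit_h"](checkpoint=checkpoint)
sam.to(device="cuda")

predictor = SamPredictor(sam)
predictor.set_image(np.zeros((1024, 1024, 3), dtype=np.uint8))  # placeholder RGB image
masks, scores, _ = predictor.predict(
    point_coords=np.array([[512, 512]]),  # hypothetical prompt point
    point_labels=np.array([1]),
)
```

As noted above, the first call triggers compilation (and, on non-A100 GPUs, autotuning), so expect a warm-up delay.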

experiments/README.md

Lines changed: 5 additions & 6 deletions
````diff
@@ -46,16 +46,15 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product
 ### Installation instructions
 
 ```
-$ conda create -n nightly20231023py310
-$ conda activate nightly20231023py310
+$ conda create -n nightly20231117py310
+$ conda activate nightly20231117py310
 $ conda install python=3.10
-$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231023%2Bcu121-cp310-cp310-linux_x86_64.whl
-$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231023%2Bcu121-cp310-cp310-linux_x86_64.whl
-$ cd /scratch/cpuhrsch/dev
+$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
+$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
 $ git clone https://github.com/cpuhrsch/segment-anything.git
 $ cd segment-anything
 $ pip install -e .
-$ cd /scratch/cpuhrsch/dev
+$ cd ..
 $ git clone https://github.com/pytorch-labs/segment-anything-fast.git
 $ cd segment-anything-fast
 $ pip install -e .
````

experiments/run_experiments.py

Lines changed: 28 additions & 12 deletions
```diff
@@ -144,7 +144,8 @@ def run(batch_size,
         traces_dir=None,
         num_workers=32,
         print_header=True,
-        capture_output=True):
+        capture_output=True,
+        local_fork_only=False):
 
     assert model == "vit_b" or model == "vit_h"
 
@@ -161,23 +162,38 @@ def run(batch_size,
         assert traces_dir is not None
         rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)
 
-        rt("fp32", "default", print_header=print_header)
-        rt("fp16", "codesign", use_half="bfloat16")
-        rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
-        rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
-        rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+        if local_fork_only:
+            rt("fp32", "local-fork", print_header=print_header)
+            rt("fp16", "local-fork", use_half="bfloat16")
+            rt("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+            # The local fork already uses SDPA + Triton for all of the above experiments.
+            # local_fork_only mainly exists to ablate the order in which we apply
+            # techniques and cannot be used to reproduce the experimental results
+        else:
+            rt("fp32", "default", print_header=print_header)
+            rt("fp16", "codesign", use_half="bfloat16")
+            rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
+            rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
+            rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
         if batch_size > 1:
             rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
             rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
             rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
 
     if run_experiments:
-        rexp("fp32", "default", print_header=print_header)
-        print_header = False
-        rexp("bf16", "codesign", use_half="bfloat16")
-        rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
-        rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
-        rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+        if local_fork_only:
+            rexp("fp32", "local-fork", print_header=print_header)
+            rexp("bf16", "local-fork", use_half="bfloat16")
+            rexp("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+            # The local fork already uses SDPA + Triton for all of the above experiments.
+            # local_fork_only mainly exists to ablate the order in which we apply
+            # techniques and cannot be used to reproduce the experimental results
+        else:
+            rexp("fp32", "default", print_header=print_header)
+            rexp("bf16", "codesign", use_half="bfloat16")
+            rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
+            rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
+            rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
         if batch_size > 1:
             rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1))
             rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="dynamic_quant")
```

segment_anything_fast/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -10,6 +10,11 @@
     build_sam_vit_l,
     build_sam_vit_b,
     sam_model_registry,
+    build_sam_fast,
+    build_sam_fast_vit_h,
+    build_sam_fast_vit_l,
+    build_sam_fast_vit_b,
+    sam_model_fast_registry,
 )
 from .predictor import SamPredictor
 from .automatic_mask_generator import SamAutomaticMaskGenerator
```
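For reference, a short sketch of the import surface this hunk adds at package level; the comments are editorial and based on the registry defined in `build_sam.py` below.

```python
from segment_anything_fast import (
    build_sam_fast,           # alias for build_sam_fast_vit_h (see build_sam.py below)
    sam_model_fast_registry,  # keys: "default", "vit_h", "vit_l", "vit_b"
)
```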

segment_anything_fast/build_sam.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -51,6 +51,46 @@ def build_sam_vit_b(checkpoint=None):
     "vit_b": build_sam_vit_b,
 }
 
+def _apply_eval_dtype_sam(model, dtype=None):
+
+    def prep_model(model, dtype):
+        if dtype is not None:
+            return model.eval().to(dtype)
+        return model.eval()
+
+    model.image_encoder = prep_model(model.image_encoder, dtype)
+    model.prompt_encoder = prep_model(model.prompt_encoder, dtype)
+    model.mask_decoder = prep_model(model.mask_decoder, dtype)
+
+    return model
+
+def build_sam_fast_vit_h(checkpoint=None):
+    sam = build_sam_vit_h(checkpoint)
+    sam = _apply_eval_dtype_sam(sam)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    return sam
+
+build_sam_fast = build_sam_fast_vit_h
+
+def build_sam_fast_vit_l(checkpoint=None):
+    sam = build_sam_vit_l(checkpoint)
+    sam = _apply_eval_dtype_sam(sam)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    return sam
+
+def build_sam_fast_vit_b(checkpoint=None):
+    sam = build_sam_vit_b(checkpoint)
+    sam = _apply_eval_dtype_sam(sam)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    return sam
+
+sam_model_fast_registry = {
+    "default": build_sam_fast_vit_h,
+    "vit_h": build_sam_fast_vit_h,
+    "vit_l": build_sam_fast_vit_l,
+    "vit_b": build_sam_fast_vit_b,
+}
+
 
 def _build_sam(
     encoder_embed_dim,
```
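As a usage note, the sketch below applies the same steps by hand via the new helper, this time passing an explicit dtype (the builders above call `_apply_eval_dtype_sam` without one). The checkpoint path is a hypothetical placeholder and `_apply_eval_dtype_sam` is a private helper, so treat this as an illustration rather than a supported API.

```python
import torch
from segment_anything_fast.build_sam import build_sam_vit_b, _apply_eval_dtype_sam

sam = build_sam_vit_b(checkpoint="checkpoints/sam_vit_b_01ec64.pth")  # hypothetical path
sam = _apply_eval_dtype_sam(sam, dtype=torch.bfloat16)  # eval() + cast encoders/decoder
sam.image_encoder = torch.compile(sam.image_encoder, mode="max-autotune")
```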

segment_anything_fast/flash_4.py

Lines changed: 26 additions & 3 deletions
```diff
@@ -23,6 +23,9 @@
 import triton
 import triton.language as tl
 
+import os
+import pathlib
+
 
 @triton.jit
 def _fwd_kernel_aligned(
@@ -220,9 +223,18 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
 
 
 def _load_best_configs():
+    device_name = torch.cuda.get_device_name()
+    if not device_name.startswith('NVIDIA A100'):
+        print("Warning: Custom flash attention kernels were written specifically for A100.")
     import importlib
     saved_configs = importlib.resources.files("segment_anything_fast")
     saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p"
+    if not device_name.startswith('NVIDIA A100'):
+        cwd = pathlib.Path.cwd()
+        saved_configs = cwd / "flash_4_configs.p"
+        print(f"We will try to read previously created kernel configurations from {saved_configs}.")
+        print("You can disable this kernel by setting SEGMENT_ANYTHING_FAST_USE_FLASH_4=0")
+        return None
     if saved_configs.is_file():
         import pickle
         with open(saved_configs, 'rb') as f:
@@ -234,6 +246,11 @@ def _save_best_configs(best_configs):
     import importlib
     saved_configs = importlib.resources.files("segment_anything_fast")
     saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p"
+    device_name = torch.cuda.get_device_name()
+    if not device_name.startswith('NVIDIA A100'):
+        saved_configs = pathlib.Path.cwd() / "flash_4_configs.p"
+        print("Warning: Custom flash attention kernels were written specifically for A100.")
+        print(f"Storing configs for {device_name} locally under {saved_configs}")
     with open(saved_configs, 'wb') as f:
         import pickle
         print(f"Saving best configs to file {saved_configs}")
@@ -277,7 +294,7 @@ def _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale):
         BEST_CONFIGS = _load_best_configs()
     key = _create_best_configs_key(q, k, v, rel_h_w, o)
     if key not in BEST_CONFIGS:
-        print("key ", key, " not found. Running autotune")
+        print("key ", key, " not found. Running autotune. This might take a while.")
         import functools
         import itertools
         configs = []
@@ -309,6 +326,9 @@ def _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale):
     return o
 
 
+USE_CUSTOM_KERNEL = bool(int(os.environ.get('SEGMENT_ANYTHING_FAST_USE_FLASH_4', 1)))
+
+
 def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
     """
     Writing this as a composite allows torch.compile to fuse
@@ -320,15 +340,18 @@ def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
     sm_scale = 1. / math.sqrt(q_.size(-1))
     # Check if second last dimension is multiple of 256
    q_size_2_padded = (((q_.size(-2) + 256 - 1) // 256) * 256) - q_.size(-2)
+
+    def kernel_guards(q_, k_, v_):
+        return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL
     # vit_b and vit_l
-    if q_size_2_padded == 0 and q_.size(-1) == 64:
+    if q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
         rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
         o = torch.ops.customflash.custom_flash_aligned(
             q_, k_, v_, rel_h_w, sm_scale)
         if o.numel() > 0:
             return o
     # vit_h
-    if q_size_2_padded == 0 and q_.size(-1) == 80:
+    if q_size_2_padded == 0 and q_.size(-1) == 80 and kernel_guards(q_, k_, v_):
         # Only support multiples of 64, so need to pad
         q = torch.nn.functional.pad(q_, (0, 128 - 80, 0, 0), "constant", 0)
         k = torch.nn.functional.pad(k_, (0, 128 - 80, 0, 0), "constant", 0)
```
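Because `USE_CUSTOM_KERNEL` is read once, at module import time, the environment variable has to be set before the package (and hence `flash_4`) is first imported. A minimal sketch of disabling the kernel from Python, for cases where exporting the variable in the shell is not convenient:

```python
import os

# Must happen before the first import of segment_anything_fast, since
# flash_4.py reads SEGMENT_ANYTHING_FAST_USE_FLASH_4 at import time.
os.environ["SEGMENT_ANYTHING_FAST_USE_FLASH_4"] = "0"

from segment_anything_fast import sam_model_fast_registry  # kernel guard now evaluates to False
```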
