Commit fd88633

Add QAT single-device recipe (#2716)
1 parent bdc888d · commit fd88633

File tree

7 files changed (+1179, −2 lines changed)


README.md

Lines changed: 2 additions & 2 deletions
@@ -75,11 +75,11 @@ You can also run e.g. ``tune ls full_dpo_distributed`` for a full list of available configs.

  | Type of Weight Update | 1 Device | >1 Device | >1 Node |
  |-----------------------|:--------:|:---------:|:-------:|
- | [Full](https://pytorch.org/torchtune/stable/recipes/qat_distributed.html) | |||
+ | [Full](https://pytorch.org/torchtune/stable/recipes/qat_distributed.html) | |||
  | LoRA/QLoRA ||||

  Example: ``tune run qat_distributed --config llama3_1/8B_qat_lora`` <br />
- You can also run e.g. ``tune ls qat_distributed`` for a full list of available configs.
+ You can also run e.g. ``tune ls qat_distributed`` or ``tune ls qat_single_device`` for a full list of available configs.

  The above configs are just examples to get you started. The full list of recipes can be found [here](recipes/). If you'd like to work on one of the gaps you see, please submit a PR! If there's an entirely new post-training method you'd like to see implemented in torchtune, feel free to open an Issue.
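
For context on what the new recipe does with the configs below: here is a minimal sketch of the quantization-aware-training flow a single-device QAT run drives. It assumes the prepare/convert quantizer API exposed by `torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer` (re-exported from torchao); `model`, `dataloader`, and `loss_fn` are placeholders, and the real recipe additionally handles checkpointing, activation checkpointing/offloading, logging, and so on.

```python
import torch
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer


def qat_finetune(model: torch.nn.Module, dataloader, loss_fn, epochs: int = 1):
    # Swap linear layers for "fake-quantized" versions (int8 dynamic activations,
    # int4 grouped weights) so training sees quantization error.
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
    model = quantizer.prepare(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    for _ in range(epochs):
        for tokens, labels in dataloader:
            logits = model(tokens)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # After training, replace the fake-quant modules with actually quantized ones.
    return quantizer.convert(model)
```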

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Config for single device QAT finetuning in qat_single_device.py
# using a Llama 1B/TinyLlama_v1.1 model.
#
# This config assumes that you've run the following command before launching:
# tune download TinyLlama/TinyLlama_v1.1 --output-dir /tmp/TinyLlama_v1.1/
#
# To launch on a single device, run (from root):
# tune run qat_single_device --config llama2/1B_full_qat_single_device

output_dir: /tmp/torchtune/llama1b/qat_single_device # /tmp may be deleted by your system. Adjust if needed.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/TinyLlama_v1.1/tokenizer.model
  max_seq_len: 2048

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False

seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2
  vocab_size: 32000
  num_layers: 22
  num_heads: 32
  num_kv_heads: 4
  embed_dim: 2048
  max_seq_len: 2048
  intermediate_dim: 5632
  attn_dropout: 0.0
  norm_eps: 1e-5

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/TinyLlama_v1.1/
  checkpoint_files: [pytorch_model.bin]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA2

resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 5e-6
optimizer_in_bwd: True # True saves memory, requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model+loss, can increase speed+decrease memory

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
log_level: INFO

# Profiler
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
  output_dir: ${output_dir}/profiling_outputs
  cpu: True
  cuda: True
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
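
A note on how a recipe consumes a config like the one above: each `_component_` entry is a dotted import path that is instantiated with its sibling keys as keyword arguments. Below is a minimal sketch, assuming torchtune's OmegaConf-based `config.instantiate` utility; the YAML string is an abbreviated excerpt of the config, not the full file, and the tokenizer call expects the referenced file to exist.

```python
from omegaconf import OmegaConf
from torchtune import config

# Abbreviated excerpt of the config above, parsed the same way a recipe
# receives it (as an OmegaConf DictConfig).
cfg = OmegaConf.create(
    """
    tokenizer:
      _component_: torchtune.models.llama2.llama2_tokenizer
      path: /tmp/TinyLlama_v1.1/tokenizer.model
      max_seq_len: 2048
    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
      groupsize: 256
    """
)

# config.instantiate imports the dotted `_component_` path and calls it with
# the remaining keys as keyword arguments.
tokenizer = config.instantiate(cfg.tokenizer)
quantizer = config.instantiate(cfg.quantizer)
```
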
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Config for single device QAT finetuning in qat_single_device.py
# using a Qwen2.5 1.5B model.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download Qwen/Qwen2.5-1.5B-Instruct --output-dir /tmp/Qwen2.5-1.5B-Instruct
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run qat_single_device --config qwen2_5/1.5B_qat_single_device
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run qat_single_device --config qwen2_5/1.5B_qat_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  path: /tmp/Qwen2.5-1.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2.5-1.5B-Instruct/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_1_5b_instruct

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
  checkpoint_files: [model.safetensors]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 5e-6
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
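
The header comments in this config show command-line overrides such as `checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>`. Here is a minimal sketch of how such dotted overrides can be merged onto a loaded config, assuming OmegaConf (which torchtune's config system builds on); the exact parsing inside `tune run` may differ, and `/data/qwen2_5_ckpts` is a hypothetical path used only for illustration.

```python
from omegaconf import OmegaConf

# Base values as they appear in the YAML config.
base = OmegaConf.create(
    {
        "checkpointer": {"checkpoint_dir": "/tmp/Qwen2.5-1.5B-Instruct"},
        "batch_size": 1,
    }
)

# Each trailing `key.path=value` CLI argument becomes one dotlist entry.
overrides = OmegaConf.from_dotlist(
    ["checkpointer.checkpoint_dir=/data/qwen2_5_ckpts", "batch_size=2"]
)

cfg = OmegaConf.merge(base, overrides)
print(cfg.checkpointer.checkpoint_dir)  # /data/qwen2_5_ckpts
print(cfg.batch_size)                   # 2
```
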
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Config for single device QAT finetuning in qat_single_device.py
# using a Qwen2.5 3B model.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download Qwen/Qwen2.5-3B-Instruct --output-dir /tmp/Qwen2.5-3B-Instruct
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run qat_single_device --config qwen2_5/3B_qat_single_device
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
# tune run qat_single_device --config qwen2_5/3B_qat_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

output_dir: /tmp/torchtune/qwen2_5_3B/full_single_device # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  path: /tmp/Qwen2.5-3B-Instruct/vocab.json
  merges_file: /tmp/Qwen2.5-3B-Instruct/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2.5-3B-Instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 5e-6
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
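
The profiler block at the end of each config maps `wait_steps`/`warmup_steps`/`active_steps`/`num_cycles` onto `torch.profiler.schedule`'s `wait`/`warmup`/`active`/`repeat`, as the inline comments note. Below is a standalone sketch of an equivalent profiler setup; `train_step` is a stand-in for one optimizer step, and the trace-handler output path is arbitrary.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule


def train_step() -> None:
    # Stand-in for one forward/backward/optimizer step.
    x = torch.randn(8, 8)
    _ = x @ x


# wait_steps=5, warmup_steps=3, active_steps=2, num_cycles=1 from the config.
prof_schedule = schedule(wait=5, warmup=3, active=2, repeat=1)

activities = [ProfilerActivity.CPU]           # cpu: True
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)  # cuda: True

with profile(
    activities=activities,
    schedule=prof_schedule,
    record_shapes=True,      # record_shapes: True
    profile_memory=False,    # profile_memory: False
    with_stack=False,        # with_stack: False
    with_flops=False,        # with_flops: False
    on_trace_ready=torch.profiler.tensorboard_trace_handler("/tmp/profiling_outputs"),
) as prof:
    for _ in range(10):      # wait (5) + warmup (3) + active (2) steps
        train_step()
        prof.step()          # advance the profiler schedule by one step
```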
