Add QAT single-device recipe #2716
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2716
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit a63a306 with merge base cbc1456.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from b3370bb to 09388f9.
This is looking good! I'll keep checking in as you go through your checklist.
Definitely don't need to focus on docstrings and typing in this recipe lol, I personally think we go a little overboard there.
The focus should be clear, concise code and functionality. If you're unsure which features should compose, just lmk.
Force-pushed from bff9537 to c505593.
Closes #2696

Changelog
Notes
Full Output Log on Qwen2.5-1.5B
> tune run qat_single_device --config recipes/configs/qwen2_5/1.5B_qat_single_device.yaml 11:40:35 AM
INFO:torchtune.utils._logging:Running QATRecipeSingleDevice with resolved config:
batch_size: 1
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
checkpoint_files:
- model.safetensors
model_type: QWEN2
output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device
recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: false
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 2
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: false
loss:
_component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 5
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device/logs
model:
_component_: torchtune.models.qwen2_5.qwen2_5_1_5b_instruct
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 5.0e-06
optimizer_in_bwd: true
output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device
profiler:
_component_: torchtune.training.setup_torch_profiler
active_steps: 2
cpu: true
cuda: true
enabled: false
num_cycles: 1
output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device/profiling_outputs
profile_memory: false
record_shapes: true
wait_steps: 5
warmup_steps: 3
with_flops: false
with_stack: false
quantizer:
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 64
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
max_seq_len: 128
merges_file: /tmp/Qwen2.5-1.5B-Instruct/merges.txt
path: /tmp/Qwen2.5-1.5B-Instruct/vocab.json
Writing logs to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/logs/log_1747150878.txt
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint ...
INFO:torchtune.utils._logging:QAT Model (quantizer applied) is initialized with compute precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
GPU peak memory active: 2.93 GiB
GPU peak memory alloc: 2.93 GiB
GPU peak memory reserved: 3.13 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|5|Loss: 3.084172248840332: 100%|███| 5/5 [00:15<00:00, 3.06s/it]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 2.88 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/epoch_0/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Recipe checkpoint of size 11.50 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/recipe_state/recipe_state.pt
INFO:torchtune.utils._logging:Saving checkpoint took 8.20 secs
1|5|Loss: 3.084172248840332: 100%|███| 5/5 [00:23<00:00, 4.73s/it]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 2.88 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/epoch_1/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 1.55 secs
2|10|Loss: 2.724879026412964: 100%|█████| 5/5 [00:16<00:00, 3.27s/it]
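
For context, the "QAT Model (quantizer applied)" line above corresponds roughly to the torchao-style prepare/convert flow. A minimal sketch, assuming the quantizer component and Qwen2.5 model builder named in the config above (illustrative only, not the recipe's exact code):

```python
# Illustrative sketch of the QAT prepare/convert flow behind the
# "QAT Model (quantizer applied)" log line; not the recipe's actual code.
from torchtune.models.qwen2_5 import qwen2_5_1_5b_instruct
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

model = qwen2_5_1_5b_instruct()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=64)

# Swap linear layers for fake-quantized versions so the fine-tune "sees"
# int8 dynamic-activation / int4 weight quantization error during training.
model = quantizer.prepare(model)

# ... run the normal fine-tuning loop (this is what the recipe does) ...

# After training, convert the fake-quantized modules into quantized ones
# before saving or evaluating the low-precision model.
model = quantizer.convert(model)
```

Training against the fake-quantized modules is what lets the converted int8/int4 model retain accuracy after quantization.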
Full Output Log on Recipe Test

> pytest -s tests/recipes/test_qat_single_device.py --with-integration
Expected artifacts for test run are:
small-ckpt-tune-03082024.pt
small-ckpt-meta-03082024.pt
small-ckpt-hf-03082024.pt
small-ckpt-tune-llama3-05052024.pt
small-ckpt-hf-reward-07122024.pt
small-ckpt-meta-vision-10172024.pt
small-ckpt-hf-vision-10172024.pt
llama3-hf-04232025/config.json
llama3-hf-04232025/generation_config.json
llama3-hf-04232025/model.safetensors
llama3-hf-04232025/model.safetensors.index.json
llama3-hf-04232025/special_tokens_map.json
llama3-hf-04232025/tokenizer.json
llama3-hf-04232025/tokenizer_config.json
tokenizer.model
tokenizer_llama3.model
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/config.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/generation_config.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/model.safetensors
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/model.safetensors.index.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/special_tokens_map.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/tokenizer.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/tokenizer_config.json
File already exists locally: /tmp/test-artifacts/tokenizer.model
File already exists locally: /tmp/test-artifacts/tokenizer_llama3.model
======================================= test session starts ========================================
platform linux -- Python 3.12.10, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/adheep/code/repos/torchtune
configfile: pyproject.toml
plugins: anyio-4.9.0, mock-3.14.0, cov-6.1.1, integration-0.2.3
collected 1 item
tests/recipes/test_qat_single_device.py INFO:torchtune.utils._logging:Running QATRecipeSingleDevice with resolved config:
batch_size: 1
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/test-artifacts
checkpoint_files:
- /tmp/test-artifacts/small-ckpt-hf-03082024.pt
model_type: LLAMA2
output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0
recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
_component_: torchtune.datasets.alpaca_dataset
data_files: /home/adheep/code/repos/torchtune/tests/assets/alpaca_tiny.json
packed: false
source: json
split: train
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 2
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: false
loss:
_component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 2
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
filename: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0tmppytest-of-adheeppytest-26test_loss_llama2_1B_qat_single0.txt
log_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/logs
model:
_component_: torchtune.models.llama2.llama2
attn_dropout: 0.0
embed_dim: 256
intermediate_dim: 768
max_seq_len: 2048
norm_eps: 1.0e-05
num_heads: 16
num_kv_heads: 8
num_layers: 4
vocab_size: 32000
optimizer:
_component_: torch.optim.AdamW
lr: 2.0e-05
optimizer_in_bwd: true
output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0
profiler:
_component_: torchtune.training.setup_torch_profiler
active_steps: 2
cpu: true
cuda: true
enabled: false
num_cycles: 1
output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/profiling_outputs
profile_memory: false
record_shapes: true
wait_steps: 5
warmup_steps: 3
with_flops: false
with_stack: false
quantizer:
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 256
resume_from_checkpoint: false
seed: 9
shuffle: true
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
max_seq_len: 2048
path: /tmp/test-artifacts/tokenizer.model
prompt_template: null
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0tmppytest-of-adheeppytest-26test_loss_llama2_1B_qat_single0.txt
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint ...
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
GPU peak memory active: 0.05 GiB
GPU peak memory alloc: 0.05 GiB
GPU peak memory reserved: 0.05 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|2|Loss: 10.715113639831543: 100%|███████████████████████████████████| 2/2 [00:00<00:00, 3.22it/s]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 0.04 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/epoch_0/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Recipe checkpoint of size 0.07 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/recipe_state/recipe_state.pt
INFO:torchtune.utils._logging:Saving checkpoint took 0.07 secs
1|2|Loss: 10.715113639831543: 100%|███████████████████████████████████| 2/2 [00:00<00:00, 4.38it/s]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 0.04 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/epoch_1/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 0.02 secs
2|4|Loss: 10.497347831726074: 100%|███████████████████████████████████| 2/2 [00:00<00:00, 11.49it/s]
.
========================================= warnings summary =========================================
tests/recipes/test_qat_single_device.py::TestQATSingleDeviceRecipe::test_loss[llama2/1B_qat_single_device-llama2-hf-1-1]
/home/adheep/code/repos/torchtune/torchtune/datasets/_alpaca.py:78: DeprecationWarning: train_on_input is deprecated and will be removed in a future release. Please use masking_strategy instead.You should replace train_on_input=True with masking_strategy='train_on_all', and train_on_input=False with masking_strategy='train_on_assistant'.For backwards compatibility, if you pass both train_on_input and masking_strategy, the value of masking_strategy will be ignored until torchtune 0.7.
message_transform = AlpacaToMessages(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================== 1 passed, 1 warning in 1.38s ===================================
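
For anyone poking at this locally, the gist of what the integration test checks can be sketched as below. This is a hypothetical illustration, not the contents of tests/recipes/test_qat_single_device.py: the overrides, log-parsing regex, and step counts are assumptions, and the real test also points the checkpointer and tokenizer at the tiny /tmp/test-artifacts files listed above.

```python
# Hypothetical sketch of a recipe loss smoke test: run the recipe end-to-end
# with CLI key=value overrides, then parse losses out of the DiskLogger file.
# Config overrides, regex, and counts here are assumptions, not the real test.
import re
import subprocess
from pathlib import Path


def get_loss_values(log_file: Path) -> list[float]:
    # Assumes DiskLogger step lines contain "loss:<float>".
    return [float(x) for x in re.findall(r"loss:(\d+\.\d+)", log_file.read_text())]


def test_qat_single_device_loss(tmp_path: Path) -> None:
    log_dir = tmp_path / "logs"
    cmd = [
        "tune", "run", "qat_single_device",
        "--config", "llama2/1B_qat_single_device",
        f"output_dir={tmp_path}",
        f"metric_logger.log_dir={log_dir}",
        "epochs=2", "max_steps_per_epoch=2",
        # the real test additionally overrides checkpointer/tokenizer paths
        # to the small /tmp/test-artifacts checkpoints listed above
    ]
    subprocess.run(cmd, check=True)
    log_files = sorted(log_dir.glob("log_*.txt"))
    assert log_files, "expected DiskLogger output"
    losses = get_loss_values(log_files[-1])
    assert len(losses) == 4  # 2 epochs x 2 steps per epoch
```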
log screenshots
This looks great! Last thing you need to do is update the README with a green checkmark showing that we now support QAT for single device :)
I'll get to that asap!
@joecummings all set with the readme change.
Excellent attention to detail - made this very easy to review. Thanks!
Context
This PR adds a new quantization-aware training (QAT) recipe for single-device setups. It is currently unfinished but posted as a draft for visibility and early feedback.
Addresses #2696
Changelog
- QATRecipeSingleDevice class scaffolded (a rough sketch of the intended structure is at the end of this description)
- TODO: setup() to wire up model, dataloader, loss, optimizer, and profiler
- TODO: checkpointing (save_checkpoint and load_checkpoint)

Test plan
UX
cc @joecummings let me know if you have any thoughts so far!
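
A rough sketch of the intended recipe structure, for reference. Only the class name, the setup()/train()/save_checkpoint()/load_checkpoint() methods, and the config components visible in the logs (model, quantizer, optimizer, loss) come from this PR; everything else below is assumed, and this is a sketch of the plan rather than the final implementation:

```python
# Hypothetical skeleton of QATRecipeSingleDevice; illustrative only.
import torch
from torchtune import config


class QATRecipeSingleDevice:
    def __init__(self, cfg) -> None:
        self._cfg = cfg
        self._device = torch.device(cfg.device)
        self._epochs = cfg.epochs

    def setup(self) -> None:
        # Instantiate the model, then apply the QAT quantizer before training
        # so linear layers are swapped for fake-quantized versions.
        self._model = config.instantiate(self._cfg.model).to(self._device)
        quantizer = config.instantiate(self._cfg.quantizer)
        self._model = quantizer.prepare(self._model)

        self._optimizer = config.instantiate(
            self._cfg.optimizer, self._model.parameters()
        )
        self._loss_fn = config.instantiate(self._cfg.loss)
        # tokenizer, dataloader, checkpointer, and profiler setup elided

    def load_checkpoint(self) -> None:
        # Restore model (and, when resuming, optimizer/epoch) state from the
        # configured checkpointer.
        ...

    def train(self) -> None:
        for curr_epoch in range(self._epochs):
            # standard forward / loss / backward / optimizer-step loop over
            # the dataloader goes here (elided)
            self.save_checkpoint(epoch=curr_epoch)

    def save_checkpoint(self, epoch: int) -> None:
        # Persist full model weights, plus recipe state (optimizer, epoch)
        # for intermediate epochs so training can be resumed.
        ...
```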