Add QAT single-device recipe #2716


Merged

merged 5 commits into pytorch:main from qat-single-device on May 14, 2025

Conversation

adheep04
Contributor

@adheep04 adheep04 commented May 10, 2025

Context

  • add a new feature

This PR adds a new quantization-aware training (QAT) recipe for single-device setups. It is currently unfinished but posted as a draft for visibility and early feedback.

Addresses #2696
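
For orientation, here is a minimal, hedged sketch of the QAT flow the recipe is built around, using torchao's Int8DynActInt4WeightQATQuantizer (which torchtune exposes as torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer in the configs below). The tiny model and placeholder loss are illustrative stand-ins, not the recipe's actual code:

import torch
import torch.nn as nn

# Import path varies across torchao releases; older versions export this from
# torchao.quantization.prototype.qat instead.
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Tiny stand-in model; the real recipe instantiates a full torchtune model from config.
model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=64)

# prepare() swaps nn.Linear layers for fake-quantized versions so the forward/backward
# passes are exposed to simulated int8-activation / int4-weight quantization error.
model = quantizer.prepare(model)

# An ordinary single-device fine-tuning loop runs on the prepared model.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
batch = torch.randn(8, 256)
loss = model(batch).pow(2).mean()  # placeholder loss
loss.backward()
optimizer.step()
optimizer.zero_grad()

# After training, convert() replaces the fake-quant modules with actually
# quantized ones for inference/export.
model = quantizer.convert(model)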

Changelog

  • Core QATRecipeSingleDevice class scaffolded
  • Finished training loop

TODO

  • Implement setup() to wire up model, dataloader, loss, optimizer, and profiler
  • Implement checkpointing (save_checkpoint and load_checkpoint); see the sketch after this list
  • Add optional validation loop
  • Test functionality end-to-end
  • Polish code, docstrings, and typing
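
For the checkpointing item above, here is a rough sketch of the state a single-device recipe typically saves and restores, using plain torch.save/torch.load as a stand-in. The actual recipe will delegate to the checkpointer component from the config, so the names and file layout here are illustrative only:

import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int, path: str) -> None:
    # Persist everything needed to resume: weights, optimizer state, and progress.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epochs_run": epoch + 1,
        },
        path,
    )


def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, path: str) -> int:
    # Restore weights and optimizer state; return the number of completed epochs.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epochs_run"]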

Test plan

  • Ran pre-commit hooks and linters

UX

  • I did not change any public API

cc @joecummings let me know if you have any thoughts so far!


pytorch-bot bot commented May 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2716

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a63a306 with merge base cbc1456:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label May 10, 2025
@adheep04 adheep04 force-pushed the qat-single-device branch 2 times, most recently from b3370bb to 09388f9 on May 10, 2025 21:50
Contributor

@joecummings joecummings left a comment

This is looking good! I'll keep checking in as you go through your checklist.

Definitely don't need to focus on docstrings and typing in this recipe lol; I personally think we go a little overboard there.

The focus should be clear, concise code and functionality. If you're unsure which features should compose, just lmk.

@adheep04 adheep04 force-pushed the qat-single-device branch from bff9537 to c505593 on May 13, 2025 19:20
@adheep04
Contributor Author

adheep04 commented May 13, 2025

Closes #2696
@joecummings Just finished the checklist and the acceptance criteria. The recipe works on Qwen2.5-1.5B, and the test file passes with the llama2 small test model. This should be ready for another look when you have a chance. Let me know if I didn't compose the correct features. Thanks!

Changelog

  • Finished implementing the entire qat_single_device recipe (including all functions from the interface and any helpers)
  • Wrote and ran a test for the recipe (logs below); the test only covers the small llama2 test checkpoint (small-ckpt-hf-03082024.pt)
  • Added configs for Qwen2.5-1.5B, Qwen2.5-3B, and Llama2-1B (the last for running test_qat_single_device.py)
  • Updated the recipe registry with qat_single_device entries for Qwen2.5-1.5B, Qwen2.5-3B, and Llama2-1B
  • Jump to Full Output Log for Qwen2.5-1.5B
  • Jump to Full Output Log for Test

Notes

  • The llama2 model is used primarily for testing, and because the qwen2.5 config dependencies weren't fully set up for the test (hope this is okay!)
  • In test_qat_single_device.py, there's a line that overrides the config and adds "intermediate_dim=768" to the command, since it wasn't defined in the test llama config. This is only needed for the llama2 small test case and can be removed; see the example after this list.
  • Just saw your note about docstrings and typing; I didn't focus on them too much, but I replicated the class docstring from qat_distributed.py.
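
For reference, the override described above uses torchtune's key=value config-override syntax. A hedged, hand-run equivalent of what the test appends to the command might look like the following (the config name and key path are inferred from the resolved test config below and may differ slightly from what the test actually passes):

> tune run qat_single_device --config llama2/1B_qat_single_device model.intermediate_dim=768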

Full Output Log on Qwen2.5-1.5B

  • Changed quantization group size -> 64 and max_seq_len -> 128 due to local memory constraints (not changed in the actual config); a hedged override example follows this list
  • Recipe successfully runs with expected loss movement (epoch 0: 3.084172, epoch 1: 2.724879)
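
A hedged example of applying those two changes as command-line overrides instead of editing the YAML (same key=value syntax as above; the keys mirror the resolved config shown in the log below):

> tune run qat_single_device --config recipes/configs/qwen2_5/1.5B_qat_single_device.yaml quantizer.groupsize=64 tokenizer.max_seq_len=128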
> tune run qat_single_device --config recipes/configs/qwen2_5/1.5B_qat_single_device.yaml                           11:40:35 AM
INFO:torchtune.utils._logging:Running QATRecipeSingleDevice with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
  checkpoint_files:
  - model.safetensors
  model_type: QWEN2
  output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 2
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: false
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 5
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device/logs
model:
  _component_: torchtune.models.qwen2_5.qwen2_5_1_5b_instruct
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 5.0e-06
optimizer_in_bwd: true
output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/qwen2_5_1_5B/qat_single_device/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 64
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  max_seq_len: 128
  merges_file: /tmp/Qwen2.5-1.5B-Instruct/merges.txt
  path: /tmp/Qwen2.5-1.5B-Instruct/vocab.json

Writing logs to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/logs/log_1747150878.txt
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint ...
INFO:torchtune.utils._logging:QAT Model (quantizer applied) is initialized with compute precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory active: 2.93 GiB
	GPU peak memory alloc: 2.93 GiB
	GPU peak memory reserved: 3.13 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|5|Loss: 3.084172248840332: 100%|███| 5/5 [00:15<00:00,  3.06s/it]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 2.88 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/epoch_0/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Recipe checkpoint of size 11.50 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/recipe_state/recipe_state.pt
INFO:torchtune.utils._logging:Saving checkpoint took 8.20 secs
1|5|Loss: 3.084172248840332: 100%|███| 5/5 [00:23<00:00,  4.73s/it]
  0%|                                                                                                                                                                                       | 0/5 [00:00<?, ?it/sINFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...██████| 5/5 [00:14<00:00,  3.02s/it]
INFO:torchtune.utils._logging:Model checkpoint of size 2.88 GiB saved to /tmp/torchtune/qwen2_5_1_5B/qat_single_device/epoch_1/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 1.55 secs
2|10|Loss: 2.724879026412964: 100%|█████ 5/5 [00:16<00:00,  3.27s/it]

Full Output Log on Recipe Test

  • Used the small llama2 test checkpoint
  • Main test passed (checks that the loss falls in a reasonable range and changes as expected)
> pytest -s tests/recipes/test_qat_single_device.py --with-integration
Expected artifacts for test run are:
small-ckpt-tune-03082024.pt
small-ckpt-meta-03082024.pt
small-ckpt-hf-03082024.pt
small-ckpt-tune-llama3-05052024.pt
small-ckpt-hf-reward-07122024.pt
small-ckpt-meta-vision-10172024.pt
small-ckpt-hf-vision-10172024.pt
llama3-hf-04232025/config.json
llama3-hf-04232025/generation_config.json
llama3-hf-04232025/model.safetensors
llama3-hf-04232025/model.safetensors.index.json
llama3-hf-04232025/special_tokens_map.json
llama3-hf-04232025/tokenizer.json
llama3-hf-04232025/tokenizer_config.json
tokenizer.model
tokenizer_llama3.model
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/config.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/generation_config.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/model.safetensors
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/model.safetensors.index.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/special_tokens_map.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/tokenizer.json
File already exists locally: /tmp/test-artifacts/llama3-hf-04232025/tokenizer_config.json
File already exists locally: /tmp/test-artifacts/tokenizer.model
File already exists locally: /tmp/test-artifacts/tokenizer_llama3.model
======================================= test session starts ========================================
platform linux -- Python 3.12.10, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/adheep/code/repos/torchtune
configfile: pyproject.toml
plugins: anyio-4.9.0, mock-3.14.0, cov-6.1.1, integration-0.2.3
collected 1 item                                                                                   

tests/recipes/test_qat_single_device.py INFO:torchtune.utils._logging:Running QATRecipeSingleDevice with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/test-artifacts
  checkpoint_files:
  - /tmp/test-artifacts/small-ckpt-hf-03082024.pt
  model_type: LLAMA2
  output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  data_files: /home/adheep/code/repos/torchtune/tests/assets/alpaca_tiny.json
  packed: false
  source: json
  split: train
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 2
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: false
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 2
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  filename: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0tmppytest-of-adheeppytest-26test_loss_llama2_1B_qat_single0.txt
  log_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/logs
model:
  _component_: torchtune.models.llama2.llama2
  attn_dropout: 0.0
  embed_dim: 256
  intermediate_dim: 768
  max_seq_len: 2048
  norm_eps: 1.0e-05
  num_heads: 16
  num_kv_heads: 8
  num_layers: 4
  vocab_size: 32000
optimizer:
  _component_: torch.optim.AdamW
  lr: 2.0e-05
optimizer_in_bwd: true
output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
resume_from_checkpoint: false
seed: 9
shuffle: true
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/test-artifacts/tokenizer.model
  prompt_template: null

INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0tmppytest-of-adheeppytest-26test_loss_llama2_1B_qat_single0.txt
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint ...
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory active: 0.05 GiB
	GPU peak memory alloc: 0.05 GiB
	GPU peak memory reserved: 0.05 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|2|Loss: 10.715113639831543: 100%|███████████████████████████████████| 2/2 [00:00<00:00,  3.22it/s]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 0.04 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/epoch_0/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Recipe checkpoint of size 0.07 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/recipe_state/recipe_state.pt
INFO:torchtune.utils._logging:Saving checkpoint took 0.07 secs
1|2|Loss: 10.715113639831543: 100%|███████████████████████████████████| 2/2 [00:00<00:00,  4.38it/s]
  0%|                                                                         | 0/2 [00:00<?, ?it/sINFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Model checkpoint of size 0.04 GiB saved to /tmp/pytest-of-adheep/pytest-26/test_loss_llama2_1B_qat_single0/epoch_1/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 0.02 secs
2|4|Loss: 10.497347831726074: 100%|███████████████████████████████████| 2/2 [00:00<00:00, 11.49it/s]
.

========================================= warnings summary =========================================
tests/recipes/test_qat_single_device.py::TestQATSingleDeviceRecipe::test_loss[llama2/1B_qat_single_device-llama2-hf-1-1]
  /home/adheep/code/repos/torchtune/torchtune/datasets/_alpaca.py:78: DeprecationWarning: train_on_input is deprecated and will be removed in a future release. Please use masking_strategy instead.You should replace train_on_input=True with masking_strategy='train_on_all', and train_on_input=False with masking_strategy='train_on_assistant'.For backwards compatibility, if you pass both train_on_input and masking_strategy, the value of masking_strategy will be ignored until torchtune 0.7. 
    message_transform = AlpacaToMessages(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================== 1 passed, 1 warning in 1.38s ===================================

log screenshots (images omitted)

@adheep04 adheep04 marked this pull request as ready for review May 13, 2025 20:45
@adheep04 adheep04 changed the title from "WIP: Add QAT single-device recipe" to "Add QAT single-device recipe" May 13, 2025
Contributor

@joecummings joecummings left a comment

This looks great! Last thing you need to do is update the README with a green checkmark showing that we now support QAT for single device :)

@adheep04
Contributor Author

I'll get to that asap!

@adheep04
Contributor Author

@joecummings all set with the README change.

Contributor

@joecummings joecummings left a comment

Excellent attention to detail - made this very easy to review. Thanks!

@joecummings joecummings merged commit fd88633 into pytorch:main May 14, 2025
14 checks passed