Skip to content

Conversation

@kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Oct 22, 2025

Summary

To allow for arbitrary heterogeneous quantization schemes, this PR switches several helpers from AutoAWQ to the observer and QDQ logic. AWQ no longer constrains that the quantization config needs to have the same settings for group_size, symmetric, and num_bits for each config_group.

Resolves #1657

Prerequisites:

Test plan

  • When running llm-compressor/examples/awq/llama_example.py with this (with duo_scaling="both") and logging the best configuration of (ratio, duo_scaling), I see a good mix of Falses and Trues. i.e. a good percentage of best_scales were found with duo_scaling=False and a good percentage were found with duo_scaling=True. Generated model output looks good.
  • When using awq_one_shot.py (pasted below), Wikitext PPL is consistent for w4a16 and w4a16_asym on this branch when compared to main, and better than what was reported in a previous AWQ PR, but those might have been differently configured. For W4A16_ASYM, the results are both 13.41 for main and this branch. This is what we've been historically using to test regressions.
Scheme Wikitext PPL RTN AWQ main AWQ this branch
W4A16 13.784 13.477 13.426
W4A16_ASYM 13.606 13.346 13.377
  • I see a small regression in recovery when running CADENCE=weekly TEST_DATA_FILE=~/projects/llm-compressor/tests/lmeval/configs/w4a16_awq_sym.yaml pytest -s ~/projects/llm-compressor/tests/lmeval/test_lmeval.py on this branch, which causes the test to fail. This persists even when using pseudo_quantize_tensor instead of call_observer/forward_quantize, as shown in this diff. I get the same result in this diff, so at least that means quantization logic in CT is consistent with AutoAWQ
    Output:
<main>
2025-11-17T18:26:04.682699+0000 | _validate_recovery | INFO - ✓ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.7090 | Recovery: 92.68% ↑ | Threshold: ≥92.00%
2025-11-17T18:26:04.682811+0000 | _validate_recovery | INFO - ✓ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.7100 | Recovery: 93.05% ↑ | Threshold: ≥93.00%
<this branch>
2025-11-17T17:55:00.648672+0000 | _validate_recovery | ERROR - ✗ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.6950 | Recovery: 90.85% ↑ | Threshold: ≥92.00%
2025-11-17T17:55:00.648967+0000 | _validate_recovery | ERROR - ✗ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.6960 | Recovery: 91.22% ↑ | Threshold: ≥93.00%

This is already a pretty high drop in recovery, should we revisit this test?

awq_oneshot.py script ```python import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from llmcompressor import oneshot, active_session
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
)

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"

Configure the quantization algorithm to run.

recipe = [
AWQModifier(
ignore=[
"lm_head",
"re:.*mlp.gate$",
"re:.mlp.shared_expert_gate$",
"re:visual.
",
],
scheme="W4A16_ASYM",
duo_scaling="both",
targets=["Linear"],
# offload_device=torch.device("cpu"),
),
]

Select calibration dataset.

DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

Select number of samples. 256 samples is a good place to start.

Increasing the number of samples can improve accuracy.

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

def get_calib_dataset(tokenizer):
from datasets import load_dataset

ds = load_dataset(
    DATASET_ID,
    split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
)

def preprocess(example):
    return {"input_ids": tokenizer.encode(example["text"].strip())}

ds = (
    ds.shuffle(seed=42)
    .map(preprocess, remove_columns=ds.column_names)
    .select(range(NUM_CALIBRATION_SAMPLES))
)

return ds

if name == "main":
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

###
### Apply algorithms.
###
oneshot(
    model=model,
    dataset=get_calib_dataset(tokenizer),
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    log_dir=None,
    trust_remote_code_model=True,
)

# Confirm generations of the quantized model look sane.
dispatch_for_generation(model)
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

##
### Apply algorithms.
##

## LM EVAL

active_session().reset()
del model
del tokenizer
torch.cuda.empty_cache()

import lm_eval
from lm_eval.utils import make_table

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": SAVE_DIR,
        "add_bos_token": True,
        "dtype": "bfloat16",
        "gpu_memory_utilization": 0.7,
        "max_model_len": 4096,
        # "max_num_batched_tokens": 128,
        # "max_num_seqs": 128,
    },
    tasks=["wikitext"],
    batch_size=128,
)
print(make_table(results))
</details>

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
@github-actions
Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Signed-off-by: Brian Dellabetta <[email protected]>
@brian-dellabetta brian-dellabetta changed the title [WIP] Generalize AWQ quantization [AWQ] Generalize AWQ quantization Nov 13, 2025
Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that so long as you feel confident that _compute_layer_means is going to work as expected for all the supported strategies, then I think this looks good to me!

module=balance_layer,
)
for balance_layer in mapping.balance_layers
if hasattr(balance_layer, "quantization_scheme")
Copy link
Collaborator Author

@kylesayrs kylesayrs Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this function error if balance layer doesn't have a quantization scheme?

Copy link
Collaborator

@brian-dellabetta brian-dellabetta Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated this to skip balance_layers that don't have a quant schema, if that does arise. So this should be robust enough to still work when someone wants to update a mapping like input_layernorm -> q/k/v proj, but does NOT want to quantize all q/k/v proj layers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, that seems like the most robust solution.


return w, scales, zeros
for layer in layers:
if not hasattr(layer, "weight"):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like a lot of this algorithm assumes that layers have weights? Should we be silently skipping here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To date, all the balance_layers are linear layers. I could just change the AWQMapping type to linear and avoid this checking logic. WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything's good with me, I prefer loud errors if assumptions are violated in this case.

weight = layer.weight
org_shape = weight.shape

# If group-wise, calculate abs max based on group
Copy link
Collaborator Author

@kylesayrs kylesayrs Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this generalize to block or tensor quant? I think I personally need to get a better understanding of what "group normalization" is supposed to do and how it applies to other quant strategies.

Writing this function with torch native vectorized ops might help, I think it might be reducible to frobenius norm and mean/sum ops, but that doesn't have to be done now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for tensor, i don't think using channel-wise means makes sense. We could validate that if need be, i.e. throw validation error if strategy==TENSOR and duo_scaling != False.

For block, we'd have to update this logic, yeah. I'll leave as a todo for now

Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Copy link
Collaborator Author

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve from my side

Copy link
Collaborator

@fynnsu fynnsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, added a couple comments below!

Comment on lines +706 to +709
if weight_total_sum is None:
weight_total_sum = weight_sum
else:
weight_total_sum += weight_sum
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a little strange to me. Can't we just initialize weight_total_sum = 0.0?

Comment on lines +609 to +614
history.append(loss)
if loss < best_error:
best_error = loss
best_duo_scaling = use_duo_scaling
best_ratio = ratio
best_scales = scales.clone()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like history is currently just used for debugging when no best_ratio is found. I wonder if we could instead be recording saving the hyperparameter states and losses and printing them everytime (when in DEBUG logging mode.

e.g.

for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
    ratio = grid_idx / n_grid
    ...
    history.append({"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss})
...
logger.debug(history)

This might be useful in the future as we look into improving the hyperparameter search / to get a sense of what parameters are most often selected. I think including the ratio/duo_scaling in some way is important now that we've switched from a simple linear search to a grid search, so that's easy to tell which arguments are being set.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

W4fp8 AWQ

5 participants