
Conversation

@43758726 (Collaborator) commented Jan 5, 2026

Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and easier to review. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

Motivation

Reworked calibration data loading: shared processor-based preprocessing, safer C4 loading, pileval flow adjusted, and new loaders for ultrachat_200k, gsm8k, neuralmagic_calibration, open-platypus, and openwebtext; removed ptb.

Modification

  • calib_dataloader.py: added process_dataset, new dataset loaders (ultrachat_200k, gsm8k, neuralmagic_calibration, open-platypus, openwebtext), reworked pileval, simplified C4 loading, dropped the PTB path, and updated the get_calib_loaders signature to take a processor (see the usage sketch after this list).
  • calibrate.py: default dataset wikitext2, broader allowed dataset list, load AutoProcessor, pass to get_calib_loaders.
  • auto_awq.py: default dataset wikitext2, docstring updated.
  • gptq.py: default dataset wikitext2, load AutoProcessor, pass processor to get_calib_loaders, add device param and use .to(device).
  • smooth_quant.py / lmdeploy/lite/quantization/*: minor updates aligning defaults/params.
  • utils.py: updated --calib-dataset default and help text listing supported datasets.
  • Docs w4a16.md: examples now use --calib-dataset wikitext2.
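
As a usage sketch for the reworked calib_dataloader.py entry point: the snippet below is illustrative, not the PR's code. Only the new processor argument is confirmed by this PR; the nsamples/seqlen keywords and the return shape follow the pre-existing signature, and the exact layout is an assumption.

```python
from transformers import AutoProcessor, AutoTokenizer

from lmdeploy.lite.utils.calib_dataloader import get_calib_loaders

model_path = 'internlm/internlm2_5-7b-chat'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# The loaders can now apply chat templates via the processor during
# preprocessing (e.g. for ultrachat_200k or open-platypus).
calib_loader, _ = get_calib_loaders('wikitext2',
                                    tokenizer,
                                    nsamples=128,
                                    seqlen=2048,
                                    processor=processor)
```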

Use cases (Optional)

calibrate: lmdeploy lite calibrate internlm/internlm2_5-7b-chat
auto_awq: lmdeploy lite auto_awq internlm/internlm2_5-7b-chat
auto_gptq: lmdeploy lite auto_gptq internlm/internlm2_5-7b-chat
smooth_quant: lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat
help messages: lmdeploy lite calibrate -h
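
Combined with the new options, a fuller invocation might look like this (the flag values are illustrative; both flags are introduced or retargeted by this PR):

```shell
lmdeploy lite calibrate internlm/internlm2_5-7b-chat --calib-dataset wikitext2 --device cuda:0
```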

Checklist

  1. Pre-commit or other linting tools are used to fix potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification depends on a newer version of a downstream project, this PR should be tested with all supported versions of that project.
  4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

```python
    type=int,
    default=128,
    help='Group size for weight quantization statistics')
parser.add_argument('--device', type=str, default='cuda', help='Device for calibrate. (cpu, cuda:0,1,2...)')
```
Collaborator:

Any reason to bring in this argument?

Comment on lines 17 to 19
```python
:param ds: language dataset to preprocess and tokenize
:param tokenizer: tokenizer to be used for tokenization
:param max_seq_length: maximum sequence length of samples
```
Collaborator:

Please refine the docstring to comply with lmdeploy's standards.
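
For reference, lmdeploy largely follows Google-style docstrings; a possible rewrite of the quoted snippet could look like the sketch below (the signature shown is an assumption, and the parameter descriptions are illustrative):

```python
def process_dataset(ds, processor, max_seq_length):
    """Preprocess and tokenize a language dataset.

    Args:
        ds: Language dataset to preprocess and tokenize.
        processor: Processor (or tokenizer) used for tokenization.
        max_seq_length (int): Maximum sequence length of samples.
    """
```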

Copilot AI (Contributor) left a comment:

Pull request overview

This PR refactors the quantization calibration dataset infrastructure by introducing a unified preprocessing pipeline, adding support for multiple new datasets, and improving error handling. The default calibration dataset changes from PTB to wikitext2, and the PTB dataset loader is removed entirely.

  • Introduced a shared process_dataset helper function that handles dataset-specific preprocessing and tokenization
  • Added support for 5 new calibration datasets: ultrachat_200k, gsm8k, neuralmagic_calibration, open-platypus, and openwebtext
  • Updated API functions to load and pass AutoProcessor alongside AutoTokenizer for improved preprocessing capabilities

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 26 comments.

Summary per file:

| File | Description |
| --- | --- |
| lmdeploy/lite/utils/calib_dataloader.py | Core refactor: added the process_dataset helper, implemented 5 new dataset loaders with chat-template support, simplified the C4 loader, reworked the pileval logic, removed the PTB loaders, and updated get_calib_loaders to accept a processor parameter |
| lmdeploy/lite/apis/calibrate.py | Changed the default dataset to wikitext2, added AutoProcessor loading, expanded the supported dataset list, and updated function calls to pass the processor |
| lmdeploy/lite/apis/gptq.py | Changed the default dataset to wikitext2, added AutoProcessor import and loading, added a device parameter, and updated model initialization to use the device |
| lmdeploy/lite/apis/auto_awq.py | Changed the default calibration dataset from ptb to wikitext2 and updated the docstring |
| lmdeploy/lite/apis/smooth_quant.py | Changed the default calibration dataset from ptb to wikitext2 |
| lmdeploy/lite/quantization/awq.py | Added a device parameter to max_memory_allocated calls for better multi-GPU support |
| lmdeploy/lite/quantization/calibration.py | Added a device parameter to max_memory_allocated calls for better multi-GPU support |
| lmdeploy/cli/utils.py | Changed the default --calib-dataset to wikitext2 and updated the help text to list all supported datasets |
| lmdeploy/cli/lite.py | Added a --device argument for the auto_gptq and calibrate commands |
| docs/zh_cn/quantization/w4a16.md | Updated the example to use wikitext2 instead of ptb |
| docs/en/quantization/w4a16.md | Updated the example to use wikitext2 instead of ptb |
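
To illustrate the device-aware memory reporting noted above for awq.py and calibration.py, a small sketch (the function name and surrounding code are illustrative, not the PR's):

```python
import torch


def peak_memory_gib(device: str = 'cuda:0') -> float:
    """Return the peak allocated memory on `device`, in GiB."""
    # Passing the device explicitly keeps the statistic correct when
    # calibration runs somewhere other than the default GPU.
    return torch.cuda.max_memory_allocated(device=device) / (1 << 30)
```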


"""Load openwebtext train and test datasets and tokenize.
Args:
processor: Processor to apply chatplate encoding and encode text.
Copilot AI commented Jan 5, 2026:

The docstring says "chatplate" which appears to be a typo. This should likely be "chat template" to properly describe the functionality of applying a chat template for encoding.

```python
lengths.append(len(ids))
if len(samples_encode) >= max_keep:
    break
```
Copilot AI commented Jan 5, 2026:

Potential division by zero error. If all samples in the dataset are filtered out (empty or too long), lengths will be empty and this will cause a ZeroDivisionError. Consider adding a check to ensure lengths is not empty before calculating the average.

Suggested change:

```python
if not lengths:
    raise ValueError(
        'No valid samples found in pileval dataset after filtering '
        '(empty or >512 tokens). Please check the dataset or adjust '
        'filtering parameters.'
    )
```

```python
# open-platypus samples have far fewer tokens than seqlen; recompute how many
# train items to select so it can still yield enough samples after concatenation.
lengths = torch.tensor([len(sample['input_ids']) for sample in train_data], dtype=torch.long)
avg_tokens = lengths.sum().item() / len(train_data)
```
Copilot AI commented Jan 5, 2026:

Potential division by zero error. If train_data is empty after filtering, this will cause a ZeroDivisionError. Consider adding a check to ensure train_data is not empty before calculating the average.
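
One possible guard, as an editorial sketch (not a suggestion from the review itself):

```python
# Sketch: fail fast if filtering removed every open-platypus sample,
# instead of dividing by len(train_data) == 0 below.
if not train_data:
    raise ValueError('No open-platypus samples left after filtering; '
                     'check the dataset or relax the filters.')
avg_tokens = lengths.sum().item() / len(train_data)
```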

Comment on lines 243 to 244
```python
avg_tokens = sum(lengths) / len(lengths)
needed_samples = (seqlen * nsamples) // avg_tokens
```
Copilot AI commented Jan 5, 2026:

Potential issue: if avg_tokens is greater than seqlen * nsamples, the integer division will result in needed_samples = 0, which means no samples will be collected. This could lead to an empty samples list and a subsequent error when trying to concatenate. Consider adding a check to ensure needed_samples is at least 1.

Suggested change:

```diff
-avg_tokens = sum(lengths) / len(lengths)
-needed_samples = (seqlen * nsamples) // avg_tokens
+if lengths:
+    avg_tokens = sum(lengths) / len(lengths)
+    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))
+else:
+    # Fallback: if no valid lengths were collected, use the original nsamples.
+    needed_samples = max(1, int(nsamples))
```

```python
# train items to select so it can still yield enough samples after concatenation.
lengths = torch.tensor([len(sample['input_ids']) for sample in train_data], dtype=torch.long)
avg_tokens = lengths.sum().item() / len(train_data)
needed_samples = (seqlen * nsamples) // avg_tokens
```
Copilot AI commented Jan 5, 2026:

Potential issue: if avg_tokens is greater than seqlen * nsamples, the integer division will result in needed_samples = 0, which means no samples will be collected. This could lead to an empty samples list and a subsequent error when trying to concatenate. Consider adding a check to ensure needed_samples is at least 1.

Suggested change:

```diff
-needed_samples = (seqlen * nsamples) // avg_tokens
+if avg_tokens <= 0:
+    needed_samples = 1
+else:
+    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))
```

```python
presets.
:param ds: language dataset to preprocess and tokenize
:param tokenizer: tokenizer to be used for tokenization
```
Copilot AI commented Jan 5, 2026:

The docstring parameter name is incorrect. The parameter is named 'processor' but the docstring refers to 'tokenizer'. This should be updated to match the actual parameter name.

Suggested change:

```diff
-:param tokenizer: tokenizer to be used for tokenization
+:param processor: tokenizer to be used for tokenization
```

```python
    )

else:
    raise NotImplementedError(f'Cannot preprocess dataset {ds.info.dataset_name}')
```
Copilot AI commented Jan 5, 2026:

The error message for unsupported datasets should include all newly supported datasets. Currently, it only mentions the dataset name from ds.info.dataset_name without listing what datasets are actually supported by process_dataset.
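
One way to surface that list, as an editorial sketch (SUPPORTED_DATASETS is a hypothetical constant; the names come from this PR's description):

```python
SUPPORTED_DATASETS = ('c4', 'pileval', 'wikitext2', 'ultrachat_200k', 'gsm8k',
                      'neuralmagic_calibration', 'open-platypus', 'openwebtext')
raise NotImplementedError(
    f'Cannot preprocess dataset {ds.info.dataset_name}; '
    f"supported datasets: {', '.join(SUPPORTED_DATASETS)}")
```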

```diff
 n_run += 1
-if n_run == nsamples:
+if n_run == needed_samples:
     break
```
Copilot AI commented Jan 5, 2026:

Potential error when samples is empty. If all samples are filtered out or needed_samples is 0, torch.cat(samples, dim=1) will raise a RuntimeError. Consider adding a check to ensure samples is not empty before concatenation.

Suggested change:

```diff
 break
+if not samples:
+    return [], None
```


```python
def process(sample):
    return processor(
        processor.apply_chat_template(
```
Copilot AI commented Jan 5, 2026:

The code assumes that all processors have an apply_chat_template method, but this method is typically available on tokenizers, not on all processors. If a processor doesn't have this method, it will raise an AttributeError. Consider checking if the processor has this method or falling back to the tokenizer's method.
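
A possible fallback, as an editorial sketch (variable names are illustrative):

```python
# Sketch: prefer the processor's chat template and fall back to the
# tokenizer's when the processor does not expose one.
apply_template = getattr(processor, 'apply_chat_template', None)
if apply_template is None:
    apply_template = tokenizer.apply_chat_template
text = apply_template(messages, tokenize=False)
```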

```python
    'content': sample['output']
}]
return processor(
    processor.apply_chat_template(
```
Copilot AI commented Jan 5, 2026:

The code assumes that all processors have an apply_chat_template method, but this method is typically available on tokenizers, not on all processors. If a processor doesn't have this method, it will raise an AttributeError. Consider checking if the processor has this method or falling back to the tokenizer's method.

@lvhan028 merged commit 5dcbdc7 into InternLM:main on Jan 8, 2026 (5 of 6 checks passed).