Optimize sparse 2:4 compression performance #358
Draft
Optimize Sparse 2:4 Compression Performance (2.64x Speedup)
Summary
This PR optimizes the sparse 2:4 compression pipeline, achieving a 2.64x speedup (30.18s → 11.42s) for Llama-3-8B sparse models by eliminating unnecessary CPU transfers and implementing vectorized GPU-accelerated bit packing operations.
Motivation
Profiling revealed that sparse compression was significantly slower than expected, due in part to .cpu() calls before compression.
Changes
1. Vectorized Bit Packing (src/compressed_tensors/utils/helpers.py): new pack_bitmasks() implementation that efficiently packs bits using tensor operations.
2. Deferred CPU Transfer (src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py): deferred the .cpu() call in Sparse24BitMaskTensor.from_dense().
3. Optimized Top-k Selection: added the sorted=False parameter to the topk() operation.
Performance Results
Real Model Benchmark (Llama-3-8B-Instruct 2:4 Sparse)
Testing with /home/rahul/llm-compressor/Meta-Llama-3-8B-Instruct2of4-sparse on an NVIDIA A100-80GB.
Model details:
Bit Packing Microbenchmark
Performance of the optimized pack_bitmasks function:
Verification
Quick Verification Script
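The verification script itself was not captured in this page. A minimal stand-in is sketched below: shapes and names are illustrative, and numpy's packbits/unpackbits pair serves as the reference for the little-endian packing convention. It also exercises the topk(..., sorted=False) call from change 3, which is valid here because only membership in the top-2, not their order, determines the mask.

```python
import numpy as np
import torch

# Stand-in verification (the PR's actual script was not captured).
# Builds a 2:4 sparse mask, then checks that little-endian bit packing
# round-trips losslessly.

torch.manual_seed(0)
weight = torch.randn(64, 256)

# 2:4 mask: keep the 2 largest-magnitude values in each group of 4.
# sorted=False is safe: we only need membership, not ranking.
groups = weight.abs().reshape(-1, 4)
topk_idx = groups.topk(2, dim=-1, sorted=False).indices
mask = torch.zeros_like(groups, dtype=torch.bool)
mask.scatter_(-1, topk_idx, True)
mask = mask.reshape(weight.shape)

packed = np.packbits(mask.numpy(), axis=-1, bitorder="little")
unpacked = np.unpackbits(packed, axis=-1, count=mask.shape[-1], bitorder="little")
assert (unpacked.astype(bool) == mask.numpy()).all()
assert mask.sum().item() == mask.numel() // 2  # exactly 2 of every 4 kept
print("round-trip OK")
```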
Benchmark Script
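The benchmark script was likewise not captured. A hedged sketch of what such a microbenchmark looks like is below; the tensor shape, iteration count, and the inline packing routine are illustrative, not the PR's exact code.

```python
import time
import torch

# Stand-in microbenchmark (the PR's actual script was not captured).
# Times vectorized bit packing on GPU if available, else CPU.

def bench(fn, *args, iters=10):
    for _ in range(3):  # warmup
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

def pack(m):
    # Vectorized packing along the last dim, little-endian bit order
    r, c = m.shape
    pc = (c + 7) // 8
    bits = torch.nn.functional.pad(m.to(torch.uint8), (0, pc * 8 - c)).reshape(r, pc, 8)
    w = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=m.device)
    return (bits * w).sum(dim=-1, dtype=torch.uint8)

device = "cuda" if torch.cuda.is_available() else "cpu"
mask = torch.rand(2048, 4096, device=device) > 0.5  # illustrative layer size
print(f"{device}: {bench(pack, mask) * 1e3:.2f} ms/iter")
```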
Run Tests
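The exact test invocation was not captured; assuming the repository's standard pytest layout, something along these lines applies:

```shell
# From the repository root (paths and -k filter are assumed, not taken from the PR)
pip install -e .
pytest tests/ -k "bitmask or sparse_24" -v
```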
Technical Details
Vectorized Algorithm
The new implementation uses vectorized operations instead of Python loops:
Memory Efficiency
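A packed bitmask stores one bit per element versus one byte for a torch.bool mask, an 8x reduction. A back-of-the-envelope sketch (the layer shape is illustrative, roughly a Llama-3-8B MLP projection):

```python
# Illustrative layer shape; not taken from the PR's numbers
rows, cols = 4096, 14336

bool_mask_bytes = rows * cols             # torch.bool: 1 byte per element
bitmask_bytes = rows * ((cols + 7) // 8)  # packed: 1 bit per element

print(f"bool mask:      {bool_mask_bytes / 2**20:.1f} MiB")  # 56.0 MiB
print(f"packed bitmask: {bitmask_bytes / 2**20:.1f} MiB")    # 7.0 MiB
```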
Backward Compatibility
Code Quality
Refactoring
Test Coverage
Checklist