fix(cache): compute codebooks on CPU at fp64 for MPS compatibility #5
Open
synapticode-ai wants to merge 1 commit into
MPS framework doesn't support float64 dtype. compute_lloyd_max_codebook
(line 284) and compute_online_codebook (line 326) both hardcode
dtype=torch.float64 for their optimization grids, failing with TypeError
when callers pass device=torch.device('mps').
Fix: force internal computation onto CPU at fp64 in both functions
(preserving the algorithms' literature-standard precision for codebook
centroid optimization), then move the final centroids/boundaries to the
caller's target device when constructing the returned Codebook dataclass.
This fits the existing Codebook architecture: the dataclass's quantize/
dequantize methods (lines 214-223) already handle device migration at
usage time via .to(device=x.device, dtype=x.dtype). The fix sits at the
natural device-firewall: build on CPU, store on caller's device, use on
operand's device. _beta_pdf and _solve_lloyd_max inherit CPU automatically
through tensor argument propagation; no edits needed there.
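The build-on-CPU, store-on-caller's-device, use-on-operand's-device pattern described above can be sketched roughly as follows. This is a simplified stand-in, not the code in src/cache.py: the Codebook fields, the function name build_codebook_sketch, the grid range, and the Gaussian density (used here in place of the project's beta pdf) are all assumptions made for illustration.

```python
from dataclasses import dataclass

import torch


@dataclass
class Codebook:
    """Hypothetical stand-in for the project's Codebook dataclass."""

    centroids: torch.Tensor
    boundaries: torch.Tensor

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        # Use on the operand's device: migrate stored tensors at call time.
        b = self.boundaries.to(device=x.device, dtype=x.dtype)
        return torch.bucketize(x, b)

    def dequantize(self, idx: torch.Tensor) -> torch.Tensor:
        return self.centroids.to(device=idx.device)[idx]


def build_codebook_sketch(bits: int, device: torch.device, iters: int = 50) -> Codebook:
    cpu = torch.device("cpu")
    # Build on CPU at fp64: MPS cannot allocate float64 tensors.
    grid = torch.linspace(-4.0, 4.0, 4001, dtype=torch.float64, device=cpu)
    pdf = torch.exp(-0.5 * grid**2)  # assumption: Gaussian density stand-in
    pdf = pdf / pdf.sum()

    n = 2**bits
    centroids = torch.linspace(-2.0, 2.0, n, dtype=torch.float64, device=cpu)
    for _ in range(iters):
        # Lloyd step 1: decision boundaries are midpoints between centroids.
        boundaries = (centroids[:-1] + centroids[1:]) / 2
        cell = torch.bucketize(grid, boundaries)  # cell index in 0..n-1
        # Lloyd step 2: each centroid becomes the pdf-weighted mean of its cell.
        mass = torch.zeros(n, dtype=torch.float64).index_add_(0, cell, pdf)
        moment = torch.zeros(n, dtype=torch.float64).index_add_(0, cell, pdf * grid)
        centroids = torch.where(mass > 0, moment / mass.clamp_min(1e-300), centroids)
    boundaries = (centroids[:-1] + centroids[1:]) / 2

    # Store on the caller's device: the only place the target device is touched.
    return Codebook(
        centroids=centroids.to(device=device, dtype=torch.float32),
        boundaries=boundaries.to(device=device, dtype=torch.float32),
    )


cb = build_codebook_sketch(bits=2, device=torch.device("cpu"))
print(cb.centroids.dtype, cb.centroids.device)
```

Because every intermediate tensor is derived from the CPU grid, helper functions that only consume those tensors stay on CPU without any explicit device argument, which is the propagation behaviour the commit message relies on.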
Discovered during gamma-seeds tern-core R-track MPS validation
(2026-05-16). Both make_b_mse_hook and make_b_mse_hook_uniform factories
invoke TurboQuantConfig with device='mps' for KV-cache compression hooks;
both fail at TurboQuantConfig.__init__ before any hook is applied.
Verified end-to-end with tern-core's R7-B v1.2 harness on TinyLlama-1.1B
FP16 MPS: β1a hook now produces finite PPL output.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
compute_lloyd_max_codebook and compute_online_codebook both hardcode dtype=torch.float64 for their optimization grids. The MPS framework doesn't support float64, so both functions fail with TypeError: Cannot convert a MPS Tensor to float64 dtype when callers pass device=torch.device('mps').
This PR fixes both functions by forcing the internal computation onto CPU at fp64 (preserving the algorithms' literature-standard precision for codebook centroid optimization), then moving the final centroids/boundaries to the caller's target device when constructing the returned Codebook dataclass.
Repro
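A minimal sketch of the underlying failure mode, assuming only stock PyTorch. It does not call the project's functions; float64_supported is a hypothetical helper that attempts the same kind of fp64 grid allocation on a given device and reports whether it raised the TypeError described in the summary.

```python
import torch


def float64_supported(device: torch.device) -> bool:
    """Return True if a float64 grid can be allocated on `device`."""
    try:
        # Same kind of allocation the codebook functions make for their grids.
        torch.linspace(0.0, 1.0, 8, dtype=torch.float64, device=device)
        return True
    except TypeError:
        # The summary reports the MPS backend rejecting float64 with a TypeError.
        return False


print("cpu:", float64_supported(torch.device("cpu")))
# Only probe MPS where the backend is actually present (Apple Silicon builds).
if torch.backends.mps.is_available():
    print("mps:", float64_supported(torch.device("mps")))
```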
Fix architecture
The Codebook dataclass is already designed to be device-portable: its quantize/dequantize methods (lines 214–223) call .to(device=x.device, dtype=x.dtype) on the stored tensors at usage time. The fix sits at the natural device-firewall:
- Build on CPU at fp64
- Store on the caller's device (.to(device) at Codebook construction)
- Use on the operand's device (the .to(...) calls in quantize/dequantize)
This means _beta_pdf and _solve_lloyd_max need no edits: they inherit CPU automatically through tensor argument propagation once the entry-point functions force CPU on their grid construction.
Verification
Verified end-to-end with a downstream consumer's bench harness:
- MPS call no longer raises (compute_lloyd_max_codebook(d=64, b=4, device='mps') via TurboQuantConfig)
- β1a hook produces finite PPL output on MPS (PPL=557.28 for N=1, L=64 at b_mse=4; sanity result, not a quality measurement)
- compute_lloyd_max_codebook(d=64, b=4, device='cpu') continues to return CPU tensors at fp32
Discovery context
Discovered during gamma-seeds tern-core R-track MPS validation (2026-05-16): both make_b_mse_hook and make_b_mse_hook_uniform factories (KV-cache compression hooks used to evaluate TurboQuant under the R12 KV-cache-compression PPL diagnostic) invoke TurboQuantConfig.__init__ with device='mps'. Both factories failed at construction before any hook could be applied. This fix unblocks downstream MPS-resident KV-cache compression benchmarking.
Diff stats
src/cache.py: +31 / −8 (net +23 lines, mostly added comments + minimal .to(device) plumbing)
🤖 Generated with Claude Code