## Summary
`Chronos2Pipeline.from_pretrained` silently returns a model whose weights do not match its checkpoint. Two distinct mutations happen at load:

- Layer-norm weights are overwritten with `initializer_factor = 0.05` (the T5-style init value) instead of the checkpoint values. Predictions become nearly constant (std ~4e-4) — I'll call this degenerate.
- RoPE `inv_freq` buffers are zeroed, which makes `cos = 1, sin = 0` at every position, silently disabling rotary position encoding (a minimal illustration follows below). The model still "works" because the input patch embedding carries explicit time encoding, but attention heads that were trained with RoPE receive positional-identity keys.

Neither bug raises an error: `from_pretrained` prints `Loading weights: 100%|...|` and returns. Both reproduce against the official `autogluon/chronos-2-small` on HF Hub (and `amazon/chronos-2`) — not specific to any local checkpoint.
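For intuition on the RoPE half: with `inv_freq` all zeros, every rotation angle `position * inv_freq` is zero, so the rotary transform degenerates to the identity. A minimal, self-contained sketch of a generic rotate-half RoPE (not Chronos's exact implementation) makes this concrete:

```python
import torch

def rotate_half(x):
    # standard RoPE helper: split the last dim in two, swap the halves, negate the second
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, inv_freq, positions):
    # angles: (seq, dim/2), duplicated so cos/sin cover the full head dim
    angles = positions[:, None].float() * inv_freq[None, :]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return x * cos + rotate_half(x) * sin

dim, seq = 8, 5
x = torch.randn(seq, dim)
positions = torch.arange(seq)

healthy = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
zeroed = torch.zeros(dim // 2)

print(torch.allclose(apply_rope(x, zeroed, positions), x))   # True: zeroed inv_freq -> identity, positions ignored
print(torch.allclose(apply_rope(x, healthy, positions), x))  # False: non-zero inv_freq actually rotates positions 1+
```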
## Environment

- `chronos-forecasting==2.2.2`
- `torch==2.10.0`, `transformers==5.5.3`
- Python 3.12, macOS arm64 / CPU, `device_map="cpu"`
## Minimal reproducer

```python
import torch
from chronos import Chronos2Pipeline
from chronos.chronos2.layers import RoPE
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

pipe = Chronos2Pipeline.from_pretrained("autogluon/chronos-2-small", device_map="cpu")
model = pipe.model

# Bug 1: layer-norm weights are at the initializer_factor default, not the checkpoint values.
ln = model.encoder.block[0].layer[0].layer_norm.weight
print("layer_norm.weight[:3]:", ln.data[:3].tolist())
# -> [0.05, 0.05, 0.05] (initializer_factor * 1.0)
# Expected (from checkpoint safetensors):
# -> [0.11422..., 0.42890..., 0.27481...]

# Bug 2: RoPE inv_freq buffer is zeroed.
rope = model.encoder.block[0].layer[0].self_attention.rope_embed
print("rope.inv_freq[:4]:", rope.inv_freq[:4].tolist())
# -> [0.0, 0.0, 0.0, 0.0]

# Compare to a freshly constructed RoPE with the same config:
fresh = RoPE(dim=rope.dim, base=rope.base)
print("fresh.inv_freq[:4]:", fresh.inv_freq[:4].tolist())
# -> [1.0, 0.7498942, 0.5623413, 0.4216965] (correct)
```
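To rule out stale expectations, the checkpoint can also be read directly and compared against what the pipeline holds in memory. The tensor key below is an assumption (it mirrors the module path); adjust it to whatever `load_file(sf_path).keys()` actually reports:

```python
# Assumed key name -- mirrors the module path; check load_file(sf_path).keys() if it differs.
sf_path = hf_hub_download("autogluon/chronos-2-small", "model.safetensors")
state = load_file(sf_path)
key = "encoder.block.0.layer.0.layer_norm.weight"
print("checkpoint :", state[key][:3].tolist())  # what from_pretrained should have loaded
print("in memory  :", ln.data[:3].tolist())     # what the pipeline actually holds
```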
## Functional impact
On an ETTh1-style context (last-512 → 8-step forecast), median prediction values:
| Load path | pred std | first 3 steps |
|---|---|---|
| Pure `from_pretrained` (layer-norm broken + RoPE off) | 4e-4 | 9.346, 9.346, 9.346 |
| + manual `load_state_dict(strict=True)` (RoPE still off) | 3e-2 | 9.117, 9.103, 9.098 |
| + both fixes (below) | 2e-1 | measured 30%+ better MAE vs ground truth on multiple datasets |
The compounded effect is large enough that predictions without the fixes are visibly inferior on real-world benchmarks (ETTh1/ETTh2/ETTm1/Electricity).
## Suspected root cause
- For the layer-norm bug: `Chronos2Model._init_weights` unconditionally fills `Chronos2LayerNorm.weight.data` with `initializer_factor * 1.0`. If `from_pretrained` runs `_init_weights` on modules after materializing them from the checkpoint (contrary to HF's usual "init first, then load" order), that explains the overwrite. This could be tied to the meta-device init path.
- For the RoPE bug: `RoPE.__init__` computes `inv_freq` as `1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))` and registers it with `persistent=False`. Non-persistent buffers don't appear in `state_dict`, so `load_state_dict` cannot restore them. If the buffer is zero-initialized during meta-device materialization (rather than re-running `__init__` after materialization), the `inv_freq` computed in `__init__` is lost and replaced with zeros (see the sketch below).

Both fit the same pattern: an `apply`-style init pass running in the wrong order, i.e. over a module tree that has already been materialized from the checkpoint.
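The `persistent=False` half of this hypothesis is easy to reproduce with plain PyTorch, independent of Chronos. A minimal sketch, assuming only stock `torch.nn` behaviour: the buffer never appears in `state_dict()`, so `load_state_dict` cannot repair it, and a meta-device construction followed by `to_empty()` discards the value computed in `__init__` (the explicit `zero_()` below stands in for a loader that zero-fills untouched storage):

```python
import torch
import torch.nn as nn

class WithRoPEBuffer(nn.Module):
    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False -> excluded from state_dict(), invisible to load_state_dict()
        self.register_buffer("inv_freq", inv_freq, persistent=False)

print("inv_freq" in WithRoPEBuffer().state_dict())        # False

# Meta-device construction: __init__ runs, but no real storage is ever written.
with torch.device("meta"):
    m = WithRoPEBuffer()
m = m.to_empty(device="cpu")                               # materialize with uninitialized memory
m.inv_freq.zero_()                                         # stand-in for a loader that zero-fills it

m.load_state_dict(WithRoPEBuffer().state_dict(), strict=True)  # "succeeds": buffer is not in the dict
print(m.inv_freq)                                          # still all zeros, as seen in the pipeline
```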
## Workaround

Until this is fixed, applying both of the following immediately after `from_pretrained` produces a model whose predictions track ground truth well and are reproducible across runs:
```python
# Fix 1: re-apply the checkpoint so the layer-norm values land correctly.
sf_path = hf_hub_download("autogluon/chronos-2-small", "model.safetensors")
model.load_state_dict(load_file(sf_path), strict=True)  # reports missing=[], unexpected=[]

# Fix 2: recompute inv_freq on every RoPE (not in state_dict because persistent=False).
for m in model.modules():
    if isinstance(m, RoPE):
        m.inv_freq = 1.0 / (
            m.base ** (torch.arange(0, m.dim, 2, dtype=torch.int64).float() / m.dim)
        ).to(m.inv_freq.device)
```
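A quick sanity check with the objects from the reproducer above (the `load_state_dict` call updates the existing parameter tensors in place, so `ln` and `rope` reflect the fixes):

```python
print("layer_norm.weight[:3]:", ln.data[:3].tolist())     # checkpoint values, not [0.05, 0.05, 0.05]
print("rope.inv_freq[:4]:", rope.inv_freq[:4].tolist())   # non-zero, matches a fresh RoPE
assert not torch.all(rope.inv_freq == 0)
```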
After both fixes, an independent C implementation of the Chronos-2 forward pass agrees with this PyTorch pipeline to max|Δ| ~5×10⁻⁶ across 5 datasets × 96-step forecasts — i.e., pure BF16 accumulation noise.
Diagnosed by a layer-by-layer trace: attention scores at position 0 matched bit-exactly, while positions 1+ diverged, which pinpointed RoPE as the failing step. Inspecting `inv_freq` directly revealed the zeroed buffer.
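For anyone who wants to reproduce that kind of diff, a generic forward-hook harness is enough; nothing here is Chronos-specific and the names are illustrative:

```python
import torch

def attach_trace(model, store):
    """Record each module's tensor output so two implementations can be diffed layer by layer."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, args, output, name=name):
            if isinstance(output, torch.Tensor):
                store[name] = output.detach().float().cpu()
        handles.append(module.register_forward_hook(hook))
    return handles

# Usage sketch:
#   store = {}
#   handles = attach_trace(model, store)
#   ...run one forecast through the pipeline...
#   for h in handles: h.remove()
#   # then compare store[name] against the other implementation's dump for the same layer
```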
Happy to open a PR if the maintainers want one.