[BUG] Chronos2Pipeline.from_pretrained silently breaks layer-norm weights and RoPE inv_freq #491

@amiracoder

Description

Summary

Chronos2Pipeline.from_pretrained silently returns a model whose weights do not match its checkpoint. Two distinct mutations happen at load time:

  1. Layer-norm weights are overwritten with initializer_factor = 0.05 (the T5-style init value) instead of the checkpoint values. Predictions become nearly constant (std ~4e-4) — I'll call this degenerate.
  2. RoPE inv_freq buffers are zeroed, which makes cos = 1, sin = 0 at every position, silently disabling rotary position encoding. The model still "works" because the input patch embedding carries explicit time encoding, but attention heads that were trained with RoPE receive positional-identity keys.
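To make the second mutation concrete, here is a minimal, self-contained sketch (plain torch, not the library's RoPE class) showing that an all-zero inv_freq turns the standard rotate-half RoPE into the identity at every position:

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, dim = 8, 8
positions = torch.arange(seq_len, dtype=torch.float32)

# Zeroed buffer, as observed after from_pretrained:
inv_freq = torch.zeros(dim // 2)
angles = torch.outer(positions, inv_freq)              # (seq_len, dim/2), all zeros
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)  # all ones
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)  # all zeros

q = torch.randn(seq_len, dim)
q_rot = q * cos + rotate_half(q) * sin

print(torch.equal(q_rot, q))  # True: the "rotation" is the identity everywhere
```

Since cos == 1 and sin == 0, queries and keys pass through untouched, which is exactly the "positional-identity keys" behavior described above.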

Neither bug raises an error. from_pretrained prints Loading weights: 100%|...| and returns.

Both reproduce against the official autogluon/chronos-2-small on HF Hub (and amazon/chronos-2) — not specific to any local checkpoint.

Environment

  • chronos-forecasting==2.2.2
  • torch==2.10.0, transformers==5.5.3
  • Python 3.12, macOS arm64 / CPU, device_map="cpu"

Minimal reproducer

import torch
from chronos import Chronos2Pipeline
from chronos.chronos2.layers import RoPE
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

pipe = Chronos2Pipeline.from_pretrained("autogluon/chronos-2-small", device_map="cpu")
model = pipe.model

# Bug 1: layer-norm weights are at initializer_factor default, not the checkpoint.
ln = model.encoder.block[0].layer[0].layer_norm.weight
print("layer_norm.weight[:3]:", ln.data[:3].tolist())
# -> [0.05, 0.05, 0.05]           (initializer_factor * 1.0)
# Expected (from checkpoint safetensors):
# -> [0.11422..., 0.42890..., 0.27481...]

# Bug 2: RoPE inv_freq buffer is zeroed.
rope = model.encoder.block[0].layer[0].self_attention.rope_embed
print("rope.inv_freq[:4]:", rope.inv_freq[:4].tolist())
# -> [0.0, 0.0, 0.0, 0.0]

# Compare to a freshly constructed RoPE with the same config:
fresh = RoPE(dim=rope.dim, base=rope.base)
print("fresh.inv_freq[:4]:", fresh.inv_freq[:4].tolist())
# -> [1.0, 0.7498942, 0.5623413, 0.4216965]   (correct)

Functional impact

On an ETTh1-style context (last-512 → 8-step forecast), median prediction values:

| Load path | pred. std | first 3 forecast steps |
|---|---|---|
| Pure `from_pretrained` (layer-norm broken + RoPE off) | 4e-4 | 9.346, 9.346, 9.346 |
| + manual `load_state_dict(strict=True)` (RoPE still off) | 3e-2 | 9.117, 9.103, 9.098 |
| + both fixes (below) | 2e-1 | (measured 30%+ better MAE vs ground truth on multiple datasets) |

The compounded effect is large enough that predictions without the fixes are visibly inferior on real-world benchmarks (ETTh1/ETTh2/ETTm1/Electricity).

Suspected root cause

  1. For the layer-norm bug, Chronos2Model._init_weights unconditionally fills Chronos2LayerNorm.weight.data with initializer_factor * 1.0. If from_pretrained runs _init_weights on modules after materializing them from the checkpoint (contrary to HF's usual "init first, then load" order), that explains the overwrite. This could be tied to the meta-device init path.
  2. For the RoPE bug, RoPE.__init__ computes inv_freq via torch.arange(...).float() / dim and registers it with persistent=False. Non-persistent buffers don't appear in state_dict, so load_state_dict cannot restore them. If the buffer is zero-initialized during meta-device materialization (rather than running __init__ again after materialization), the inv_freq from __init__ is lost and replaced with zeros.

Both fit an "apply is running in the wrong order / on a materialized module tree" pattern.
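For reference, a minimal sketch of the suspected RoPE mechanism using generic torch.nn (not Chronos code): a persistent=False buffer never appears in state_dict, so a checkpoint cannot restore it, and materializing a meta-initialized module with to_empty() discards the values computed in __init__:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False excludes the buffer from state_dict, so a checkpoint
        # can never restore it -- its values exist only via __init__.
        self.register_buffer("inv_freq", torch.linspace(1.0, 0.1, 4), persistent=False)

m = Toy()
print("inv_freq" in m.state_dict())  # False: load_state_dict cannot touch it

# Meta-device init allocates no real storage; to_empty() then materializes the
# buffer as uninitialized memory, losing the linspace values from __init__.
with torch.device("meta"):
    m_meta = Toy()
m_real = m_meta.to_empty(device="cpu")
print(m_real.inv_freq.shape)  # torch.Size([4]) -- allocated, but values are arbitrary
```

Whether the zeros come from to_empty-style materialization or an explicit zero-fill in the loading path is an assumption; either way the buffer cannot be recovered from the checkpoint.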

Workaround

Until fixed, applying both of the following immediately after from_pretrained produces a model whose predictions match ground truth well and reproduce across runs:

# Fix 1: re-apply the checkpoint so layer-norm values land correctly.
sf_path = hf_hub_download("autogluon/chronos-2-small", "model.safetensors")
model.load_state_dict(load_file(sf_path), strict=True)  # reports missing=[], unexpected=[]

# Fix 2: recompute inv_freq on every RoPE (it is not in state_dict because persistent=False).
for m in model.modules():
    if isinstance(m, RoPE):
        m.inv_freq = 1.0 / (
            m.base ** (torch.arange(0, m.dim, 2, dtype=torch.int64).float() / m.dim)
        ).to(m.inv_freq.device)

After both fixes, an independent C implementation of the Chronos-2 forward pass agrees with this PyTorch pipeline to max|Δ| ~5×10⁻⁶ across 5 datasets × 96-step forecasts — i.e., pure BF16 accumulation noise.

Diagnosed by a layer-by-layer trace: attention scores at position 0 matched bit-exactly, positions 1+ diverged, which pinpointed RoPE as the failing step. Inspecting inv_freq directly revealed the zeroed buffer.
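A cheap guard against this class of bug is scanning a freshly loaded model for buffers that are exactly all-zero; for buffers like inv_freq that should never be all-zero, that catches the silent materialization failure immediately. A sketch of such a helper (find_zeroed_buffers is my name for it, not part of chronos or transformers):

```python
import torch
import torch.nn as nn

def find_zeroed_buffers(model: nn.Module) -> list[str]:
    """Return names of buffers whose every element is zero.

    All-zero is legitimate for some buffers (e.g. step counters), so treat the
    result as a list of suspects to inspect, not confirmed bugs.
    """
    return [
        name
        for name, buf in model.named_buffers()
        if buf.numel() > 0 and not buf.is_meta and torch.count_nonzero(buf) == 0
    ]

# Toy demonstration: one healthy buffer, one zeroed non-persistent one.
class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("ok", torch.ones(3))
        self.register_buffer("bad", torch.zeros(3), persistent=False)

print(find_zeroed_buffers(Demo()))  # ['bad']
```

Running this over pipe.model right after from_pretrained would have flagged every RoPE inv_freq buffer without a layer-by-layer trace.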

Happy to open a PR if the maintainers want one.
