[Feature] Implement Llama 3.1 Frequency-Aware RoPE Scaling #392

@rubik-hua

Feature: Implement Llama 3.1 Frequency-Aware RoPE Scaling

Background & Motivation

With the release of Llama 3.1, the context length has been extended to 128k. To maintain model performance over long contexts without degrading local attention, Meta introduced a frequency-aware smooth interpolation mechanism, departing from traditional linear scaling.
When loading Llama 3.1 pretrained weights, the config.json specifies the following rope_scaling configuration:

"rope_scaling": {
  "factor": 8.0,
  "low_freq_factor": 1.0,
  "high_freq_factor": 4.0,
  "original_max_position_embeddings": 8192,
  "rope_type": "llama3"
}

Currently, InfiniLM lacks native support for rope_type: "llama3", resulting in an "Unsupported rope_scaling type" exception when attempting to load Llama 3.1 models. We need to implement this logic to enable proper inference for the Llama 3.1 family.
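
As a rough illustration of the parsing step, the sketch below validates a rope_scaling block like the one above. It is written in Python for brevity and is not InfiniLM code: the function name parse_llama3_rope_scaling and the specific range checks (factor >= 1, high_freq_factor >= low_freq_factor > 0) are assumptions made for this example.

```python
REQUIRED_KEYS = (
    "factor",
    "low_freq_factor",
    "high_freq_factor",
    "original_max_position_embeddings",
)


def parse_llama3_rope_scaling(rope_scaling: dict) -> dict:
    """Validate a rope_scaling dict of rope_type "llama3" and return typed values.

    Hypothetical helper: the name and the range checks are assumptions.
    """
    rope_type = rope_scaling.get("rope_type")
    if rope_type != "llama3":
        raise ValueError(f"Unsupported rope_scaling type: {rope_type!r}")

    missing = [key for key in REQUIRED_KEYS if key not in rope_scaling]
    if missing:
        raise ValueError(f"rope_scaling is missing required keys: {missing}")

    cfg = {
        "factor": float(rope_scaling["factor"]),
        "low_freq_factor": float(rope_scaling["low_freq_factor"]),
        "high_freq_factor": float(rope_scaling["high_freq_factor"]),
        "original_max_position_embeddings": int(
            rope_scaling["original_max_position_embeddings"]
        ),
    }

    # Assumed sanity checks; the exact constraints are a design decision.
    if cfg["factor"] < 1.0:
        raise ValueError("factor must be >= 1.0")
    if cfg["low_freq_factor"] <= 0.0:
        raise ValueError("low_freq_factor must be positive")
    if cfg["high_freq_factor"] < cfg["low_freq_factor"]:
        raise ValueError("high_freq_factor must be >= low_freq_factor")
    if cfg["original_max_position_embeddings"] <= 0:
        raise ValueError("original_max_position_embeddings must be positive")
    return cfg
```

Rejecting incomplete configs at parse time keeps the failure close to its cause, rather than surfacing later as silently wrong frequencies.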

Principle of Llama 3 RoPE Scaling

In standard RoPE, each pair of dimensions rotates at its own frequency: high-frequency components (short wavelengths) distinguish nearby positions, while low-frequency components (long wavelengths) vary slowly and encode long-range position. Applying a uniform scaling factor (e.g., linear scaling with factor=8) across all frequencies also compresses the high-frequency components, impairing the model's ability to distinguish adjacent tokens.
Llama 3 addresses this by applying piece-wise scaling based on wavelength:

  1. High-frequency band (wavelen < high_freq_wavelen, where high_freq_wavelen = original_max_position_embeddings / high_freq_factor): No scaling is applied. The original frequency is preserved to maintain local attention precision.
  2. Low-frequency band (wavelen > low_freq_wavelen, where low_freq_wavelen = original_max_position_embeddings / low_freq_factor): Full scaling is applied. The frequency is divided by factor to extend the context window.
  3. Mid-frequency band (in between): A smooth linear interpolation between the scaled and unscaled frequency is applied to avoid hard discontinuities at the band boundaries.
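
To make the three bands concrete, here is a small sketch that classifies each rotary frequency for the config shown earlier. It assumes head_dim = 128 and rope_theta = 500000.0, which are not part of rope_scaling and must come from the rest of the model config. The function name classify_bands is made up for this example.

```python
import math


def classify_bands(head_dim, rope_theta, original_max_position,
                   low_freq_factor, high_freq_factor):
    """Label each rotary frequency index as "high", "mid" or "low"."""
    low_freq_wavelen = original_max_position / low_freq_factor
    high_freq_wavelen = original_max_position / high_freq_factor
    bands = []
    for i in range(head_dim // 2):
        # Standard RoPE inverse frequency for the i-th dimension pair.
        inv_freq = rope_theta ** (-2.0 * i / head_dim)
        wavelen = 2.0 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            bands.append("high")   # left unscaled
        elif wavelen > low_freq_wavelen:
            bands.append("low")    # divided by factor
        else:
            bands.append("mid")    # smoothly interpolated
    return bands


bands = classify_bands(
    head_dim=128,            # assumed head dimension
    rope_theta=500000.0,     # assumed RoPE base
    original_max_position=8192,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
)
counts = {name: bands.count(name) for name in ("high", "mid", "low")}
print(counts)
```

With these values the thresholds are high_freq_wavelen = 2048 and low_freq_wavelen = 8192, so only a handful of the 64 frequencies fall into the interpolated mid band; the rest are either untouched or fully scaled.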

Reference Implementations

SGLang Implementation

SGLang implements this logic intuitively using torch.where:

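import math
from typing import Union

import torch

# Excerpt: a method of the rotary embedding class; self.* attributes come from the enclosing class.
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
    # Unscaled inverse frequencies from the base RoPE implementation.
    inv_freqs = super()._compute_inv_freq(base)
    # Wavelength thresholds separating the low, mid and high frequency bands.
    low_freq_wavelen = self.orig_max_position / self.low_freq_factor
    high_freq_wavelen = self.orig_max_position / self.high_freq_factor
    wave_len = 2 * math.pi / inv_freqs
    # Interpolation weight for the mid band; guard against division by zero.
    if self.low_freq_factor != self.high_freq_factor:
        smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
            self.high_freq_factor - self.low_freq_factor
        )
    else:
        smooth = 0
    new_freqs = torch.where(
        wave_len < high_freq_wavelen,
        inv_freqs,  # high-frequency band: unchanged
        torch.where(
            wave_len > low_freq_wavelen,
            inv_freqs / self.scaling_factor,  # low-frequency band: fully scaled
            # mid band: blend of scaled and unscaled frequency
            (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
        ),
    )
    return new_freqs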

HuggingFace Implementation

HuggingFace's implementation is mathematically equivalent to SGLang's, also utilizing torch.where for the three-band frequency computation.

Implementation Strategy for InfiniLM

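To maximize the reuse of existing RoPE kernels (avoiding kernel modifications), we propose pre-computing per-dimension scaling factors during the config parsing phase and mapping them onto the existing LongRopeConfig structure.
LongRopeConfig interprets each provided per-dimension factor as a wavelength multiplier (i.e., new_freq = inv_freq / factor). Based on this, we can derive the required smooth_factor for each of Llama 3's three frequency bands. Note that smooth_factor is the per-dimension value passed to LongRopeConfig; it is distinct from the interpolation weight smooth in the reference code and from the global factor in rope_scaling: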

  1. High-frequency branch: No scaling $\rightarrow$ freq_scale = 1.0 $\rightarrow$ smooth_factor = 1.0 / freq_scale = 1.0
  2. Low-frequency branch: Full scaling $\rightarrow$ freq_scale = 1.0 / factor $\rightarrow$ smooth_factor = 1.0 / freq_scale = factor
  3. Mid-frequency branch:
    From the interpolation formula: new_freq = (1 - smooth) * inv_freq / factor + smooth * inv_freq, we can extract the frequency scale:
    freq_scale = (1 - smooth) / factor + smooth
    Therefore, the required smooth factor to pass to LongRopeConfig is the inverse:
    smooth_factor = 1.0 / ((1 - smooth) / factor + smooth)
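
The derivation above can be sketched as follows. This is a plain-Python illustration, not InfiniLM code: the function names are made up for this example, and head_dim = 128 and rope_theta = 500000.0 are assumed values. The second function restates the reference three-band formula so that the two can be compared directly.

```python
import math


def _smooth(wavelen, original_max_position, low_freq_factor, high_freq_factor):
    """Mid-band interpolation weight (0 when both factors are equal)."""
    if high_freq_factor == low_freq_factor:
        return 0.0
    return (original_max_position / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )


def llama3_smooth_factors(head_dim, rope_theta, factor, low_freq_factor,
                          high_freq_factor, original_max_position):
    """Per-dimension divisors such that new_freq = inv_freq / smooth_factor."""
    low_freq_wavelen = original_max_position / low_freq_factor
    high_freq_wavelen = original_max_position / high_freq_factor
    factors = []
    for i in range(head_dim // 2):
        inv_freq = rope_theta ** (-2.0 * i / head_dim)
        wavelen = 2.0 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            factors.append(1.0)              # high band: no scaling
        elif wavelen > low_freq_wavelen:
            factors.append(factor)           # low band: full scaling
        else:
            s = _smooth(wavelen, original_max_position,
                        low_freq_factor, high_freq_factor)
            factors.append(1.0 / ((1.0 - s) / factor + s))  # mid band
    return factors


def reference_inv_freq(head_dim, rope_theta, factor, low_freq_factor,
                       high_freq_factor, original_max_position):
    """Scaled frequencies written the way the reference formula states them."""
    low_freq_wavelen = original_max_position / low_freq_factor
    high_freq_wavelen = original_max_position / high_freq_factor
    out = []
    for i in range(head_dim // 2):
        inv_freq = rope_theta ** (-2.0 * i / head_dim)
        wavelen = 2.0 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            out.append(inv_freq)
        elif wavelen > low_freq_wavelen:
            out.append(inv_freq / factor)
        else:
            s = _smooth(wavelen, original_max_position,
                        low_freq_factor, high_freq_factor)
            out.append((1.0 - s) * inv_freq / factor + s * inv_freq)
    return out


params = dict(head_dim=128, rope_theta=500000.0, factor=8.0,
              low_freq_factor=1.0, high_freq_factor=4.0,
              original_max_position=8192)
factors = llama3_smooth_factors(**params)
reference = reference_inv_freq(**params)
for i, (f, ref) in enumerate(zip(factors, reference)):
    inv_freq = params["rope_theta"] ** (-2.0 * i / params["head_dim"])
    # Dividing by the pre-computed factor reproduces the reference frequency.
    assert math.isclose(inv_freq / f, ref, rel_tol=1e-12)
```

Because smooth lies in [0, 1] inside the mid band, every resulting smooth_factor falls between 1.0 and factor, and the values approach the neighbouring bands at both boundaries.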

C++ Implementation Key Points

  1. Full Double-Precision Computation: Llama 3's rope_theta is typically very large (e.g., 500000.0). Intermediate computations involving pow and division are highly susceptible to precision truncation if calculated using float, leading to frequency curve misalignment with PyTorch. The entire computation loop must use double, only casting to float at the final step when storing into vector<float>.
  2. Bypassing Amplitude Scaling: LongRopeConfig inherently applies an amplitude scaling penalty of sqrt(log(...)) for long sequences, which Llama 3 does not use. To bypass this, the outer factor parameter in the LongRopeConfig constructor must be explicitly set to 1.0f.
  3. Uniform Short/Long Factors: Unlike native LongRoPE models, Llama 3 does not use separate scaling factors for short and long sequences. Thus, passing the identical smooth_factors vector for both short and long factors in LongRopeConfig is sufficient.
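
A minimal sketch of how points 1 to 3 combine when handing the result to LongRopeConfig. It is written in Python rather than C++, and the dictionary keys (short_factor, long_factor, factor) are hypothetical stand-ins for the real constructor arguments, which are not spelled out here. The input is assumed to be a list of smooth factors already computed in double precision.

```python
import struct


def to_float32(x: float) -> float:
    """Round a double to the nearest float32 value (mimics the final C++ cast)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def pack_longrope_args(smooth_factors, original_max_position):
    """Assemble hypothetical LongRopeConfig arguments from double-precision factors."""
    # Cast only here, after all arithmetic has been done in double precision.
    factors_f32 = [to_float32(x) for x in smooth_factors]
    return {
        "short_factor": factors_f32,        # hypothetical field name
        "long_factor": list(factors_f32),   # identical to short_factor
        "factor": 1.0,                      # outer factor pinned to 1.0
        "original_max_position_embeddings": original_max_position,
    }


args = pack_longrope_args([1.0, 1.7, 8.0], original_max_position=8192)
```

Deferring the cast to the last step bounds the error of each stored factor by a single float32 rounding, instead of letting rounding errors accumulate through pow and the divisions.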

Acceptance Criteria

  • Correctly parse and validate model configs containing rope_type: "llama3" and its required parameters (factor, low_freq_factor, high_freq_factor, original_max_position_embeddings).
  • The frequencies implied by the pre-computed smooth_factors (inv_freq / smooth_factor) are mathematically equivalent to those produced by the SGLang/HuggingFace implementations.
  • Inference of Llama 3.1 (8B / 70B) models produces coherent outputs without gibberish, aligning with SGLang inference results.
