Feature: Implement Llama 3.1 Frequency-Aware RoPE Scaling
Background & Motivation
With the release of Llama 3.1, the context length has been extended to 128k. To maintain model performance over long contexts without degrading local attention, Meta introduced a frequency-aware smooth interpolation mechanism, departing from traditional linear scaling.
When loading Llama 3.1 pretrained weights, the config.json specifies the following rope_scaling configuration:
"rope_scaling": {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
}
Currently, InfiniLM lacks native support for rope_type: "llama3", resulting in an "Unsupported rope_scaling type" exception when attempting to load Llama 3.1 models. We need to implement this logic to enable proper inference for the Llama 3.1 family.
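As a rough illustration of the parsing step, the sketch below extracts and validates the fields shown in the config fragment above. The helper name parse_llama3_rope_scaling and the error messages are assumptions made for this example; they are not existing InfiniLM code.

```python
import json

# Keys the "llama3" rope type needs, per the config.json fragment above.
REQUIRED_KEYS = (
    "factor",
    "low_freq_factor",
    "high_freq_factor",
    "original_max_position_embeddings",
)


def parse_llama3_rope_scaling(config: dict) -> dict:
    """Hypothetical helper: extract and validate the llama3 rope_scaling fields."""
    scaling = config.get("rope_scaling") or {}
    rope_type = scaling.get("rope_type")
    if rope_type != "llama3":
        raise ValueError(f"Unsupported rope_scaling type: {rope_type!r}")
    missing = [key for key in REQUIRED_KEYS if key not in scaling]
    if missing:
        raise ValueError(f"llama3 rope_scaling is missing: {missing}")
    return {
        "factor": float(scaling["factor"]),
        "low_freq_factor": float(scaling["low_freq_factor"]),
        "high_freq_factor": float(scaling["high_freq_factor"]),
        "original_max_position_embeddings": int(
            scaling["original_max_position_embeddings"]
        ),
    }


CONFIG_TEXT = """
{"rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0,
                  "original_max_position_embeddings": 8192, "rope_type": "llama3"}}
"""
params = parse_llama3_rope_scaling(json.loads(CONFIG_TEXT))
```

Rejecting a config with missing keys at parse time keeps the failure close to its cause, rather than surfacing later as a wrong frequency table.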
Principle of Llama 3 RoPE Scaling
In standard RoPE, each pair of dimensions rotates at its own frequency, so position information is tied to wavelength: high-frequency components (short wavelengths) distinguish nearby tokens, while low-frequency components (long wavelengths) carry long-range positional information. Applying a uniform scaling factor (e.g., linear scaling with factor=8) divides every frequency by the same amount, so the rotation between adjacent tokens shrinks by that factor as well. This blurs high-frequency information and impairs the model's ability to distinguish adjacent tokens.
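Llama 3 addresses this by applying piece-wise scaling based on wavelength, where wavelen = 2π / inv_freq, high_freq_wavelen = original_max_position_embeddings / high_freq_factor, and low_freq_wavelen = original_max_position_embeddings / low_freq_factor:
- High-frequency band (wavelen < high_freq_wavelen): No scaling is applied. The original frequency is preserved to maintain local attention precision.
- Low-frequency band (wavelen > low_freq_wavelen): Full scaling is applied. The frequency is divided by factor, stretching these wavelengths to cover the extended context window.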
- Mid-frequency band (in between): A smooth linear interpolation is applied to avoid hard discontinuities at the scaling boundaries.
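The three bands can be sketched on a single scalar frequency in plain Python. This is only an illustration: the function name is made up for this example, and head_dim = 128 with rope_theta = 500000.0 are assumed values rather than something read from a real checkpoint.

```python
import math


def llama3_scaled_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, orig_max_pos=8192):
    """Apply the three-band rule to one inverse frequency (scalar sketch)."""
    wavelen = 2.0 * math.pi / inv_freq
    low_freq_wavelen = orig_max_pos / low_freq_factor    # 8192 with the defaults
    high_freq_wavelen = orig_max_pos / high_freq_factor  # 2048 with the defaults
    if wavelen < high_freq_wavelen:
        return inv_freq                  # high-frequency band: unchanged
    if wavelen > low_freq_wavelen:
        return inv_freq / factor         # low-frequency band: fully scaled
    # Mid band: blend linearly between the scaled and unscaled frequency.
    if high_freq_factor != low_freq_factor:
        smooth = (orig_max_pos / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
    else:
        smooth = 0.0
    return (1.0 - smooth) * inv_freq / factor + smooth * inv_freq


# Assumed values for illustration only.
head_dim, rope_theta = 128, 500000.0
inv_freqs = [rope_theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
scaled = [llama3_scaled_inv_freq(f) for f in inv_freqs]
```

At wavelen = high_freq_wavelen the blend weight is 1 (frequency unchanged), and at wavelen = low_freq_wavelen it is 0 (frequency divided by factor), so the curve has no jump at either boundary.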
Reference Implementations
SGLang Implementation
SGLang implements this logic intuitively using torch.where:
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
    inv_freqs = super()._compute_inv_freq(base)
    # Band boundaries, expressed as wavelengths.
    low_freq_wavelen = self.orig_max_position / self.low_freq_factor
    high_freq_wavelen = self.orig_max_position / self.high_freq_factor
    wave_len = 2 * math.pi / inv_freqs
    # Interpolation weight for the mid band (1 at the high-frequency edge, 0 at the low-frequency edge).
    if self.low_freq_factor != self.high_freq_factor:
        smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
            self.high_freq_factor - self.low_freq_factor
        )
    else:
        smooth = 0
    # High band: unchanged; low band: divided by the scaling factor; mid band: blended.
    new_freqs = torch.where(
        wave_len < high_freq_wavelen,
        inv_freqs,
        torch.where(
            wave_len > low_freq_wavelen,
            inv_freqs / self.scaling_factor,
            (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
        ),
    )
    return new_freqs
HuggingFace Implementation
HuggingFace's implementation is mathematically equivalent to SGLang's, also utilizing torch.where for the three-band frequency computation.
Implementation Strategy for InfiniLM
To reuse the existing RoPE kernels without modifying them, we propose pre-computing per-frequency scaling factors during the config parsing phase and mapping them onto the existing LongRopeConfig structure.
LongRopeConfig interprets the provided factor as a wavelength multiplier (i.e., new_freq = inv_freq / factor). Based on this, we can derive the required smooth_factor for Llama 3's three frequency bands:
- High-frequency branch: No scaling $\rightarrow$ freq_scale = 1.0 $\rightarrow$ smooth_factor = 1.0 / freq_scale = 1.0
- Low-frequency branch: Full scaling $\rightarrow$ freq_scale = 1.0 / factor $\rightarrow$ smooth_factor = 1.0 / freq_scale = factor
- Mid-frequency branch: From the interpolation formula new_freq = (1 - smooth) * inv_freq / factor + smooth * inv_freq, we can extract the frequency scale:
  freq_scale = (1 - smooth) / factor + smooth
  Therefore, the required smooth factor to pass to LongRopeConfig is the inverse:
  smooth_factor = 1.0 / ((1 - smooth) / factor + smooth)
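The derivation above can be sketched as a function that returns one divisor per frequency, so that new_freq = inv_freq / smooth_factor holds in all three bands. The function name and the head_dim = 128 value are assumptions for illustration; Python floats are IEEE doubles, so this also mirrors a double-precision computation.

```python
import math


def llama3_smooth_factors(head_dim, rope_theta, factor, low_freq_factor,
                          high_freq_factor, orig_max_pos):
    """Return per-frequency divisors d_i with new_freq_i = inv_freq_i / d_i."""
    low_freq_wavelen = orig_max_pos / low_freq_factor
    high_freq_wavelen = orig_max_pos / high_freq_factor
    factors = []
    for i in range(head_dim // 2):
        inv_freq = rope_theta ** (-2.0 * i / head_dim)
        wavelen = 2.0 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            factors.append(1.0)          # high band: freq_scale = 1
        elif wavelen > low_freq_wavelen:
            factors.append(factor)       # low band: freq_scale = 1 / factor
        else:
            if high_freq_factor != low_freq_factor:
                smooth = (orig_max_pos / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor
                )
            else:
                smooth = 0.0
            # Mid band: invert freq_scale = (1 - smooth) / factor + smooth.
            factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
    return factors


# head_dim = 128 is an assumed value; the other arguments come from the config above.
smooth_factors = llama3_smooth_factors(128, 500000.0, 8.0, 1.0, 4.0, 8192)
```

With these inputs the divisors start at 1.0 for the highest frequencies, rise through the mid band, and settle at factor for the lowest frequencies.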
C++ Implementation Key Points
- Full Double-Precision Computation: Llama 3's rope_theta is typically very large (e.g., 500000.0). Intermediate computations involving pow and division are highly susceptible to precision truncation if calculated using float, leading to frequency curve misalignment with PyTorch. The entire computation loop must use double, only casting to float at the final step when storing into vector<float>.
- Bypassing Amplitude Scaling: LongRopeConfig inherently applies an amplitude scaling penalty of sqrt(log(...)) for long sequences, which Llama 3 does not use. To bypass this, the outer factor parameter in the LongRopeConfig constructor must be explicitly set to 1.0f.
- Uniform Short/Long Factors: Unlike native LongRoPE models, Llama 3 does not use separate scaling factors for short and long sequences. Thus, passing the identical smooth_factors vector for both short and long factors in LongRopeConfig is sufficient.
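The three points above can be modelled in a few lines of Python, shown here only to make the data flow concrete. LongRopeConfigSketch and its field names are hypothetical stand-ins, not InfiniLM's actual LongRopeConfig, and to_float32 emulates the final cast to float that the C++ code would perform when storing into vector<float>.

```python
import struct
from dataclasses import dataclass
from typing import List


def to_float32(value: float) -> float:
    """Round a Python float (IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


@dataclass
class LongRopeConfigSketch:
    """Hypothetical stand-in; the real LongRopeConfig may look different."""
    short_factor: List[float]
    long_factor: List[float]
    original_max_position_embeddings: int
    factor: float


def build_llama3_config(smooth_factors_f64: List[float],
                        orig_max_pos: int) -> LongRopeConfigSketch:
    # Cast once, at the very end, after all double-precision math is done.
    stored = [to_float32(d) for d in smooth_factors_f64]
    return LongRopeConfigSketch(
        short_factor=stored,
        long_factor=list(stored),  # same values: Llama 3 has no short/long split
        original_max_position_embeddings=orig_max_pos,
        factor=1.0,                # outer factor of 1.0, per the bypass note above
    )


# Arbitrary example divisors, one per band, for illustration.
cfg = build_llama3_config([1.0, 2.7, 8.0], 8192)
```

Casting only at the end means each stored divisor is off by at most one float32 rounding step from the double-precision value, instead of accumulating error through pow and division.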
Acceptance Criteria
- Support for rope_type: "llama3" and its required parameters (factor, low_freq_factor, high_freq_factor, original_max_position_embeddings).
- The smooth_factors logic is mathematically strictly equivalent to the SGLang/HuggingFace implementations.