@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional
 
@@ -29,6 +30,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -68,6 +70,9 @@ def from_name(cls, name: str):
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):
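Not part of the commit: a minimal usage sketch of the new config entry, assuming the patched file is gpt-fast's model.py (imported here as `model`, an assumption) and that `ModelArgs.from_name` returns the config keyed exactly by name in `transformer_configs`.

# Hypothetical driver -- the module name `model` is an assumption.
from model import ModelArgs

args = ModelArgs.from_name("llama-3.1-405b")
print(args.block_size)    # 131072 (128K-token context window)
print(args.rope_scaling)  # {'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0,
                          #  'original_max_position_embeddings': 8192}
print(ModelArgs.from_name("llama-3-70b").rope_scaling)  # None -- existing configs keep unscaled RoPE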
@@ -119,7 +124,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -230,11 +235,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
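Not part of the commit: a short, self-contained check of the scaling rule above, using the llama-3.1-405b numbers (per-head dim 16384 // 128 = 128, rope_base 500000); the import path `model` is an assumption. Components whose wavelength is shorter than original_max_position_embeddings / high_freq_factor keep their frequency, those longer than original_max_position_embeddings / low_freq_factor are divided by factor, and the band in between is interpolated linearly so the scaled frequencies meet both neighbouring branches at the edges.

# Worked example (hypothetical driver, not in the commit).
import math
import torch
from model import apply_rope_scaling  # assumed import path

n_elem, base = 128, 500000
rope_scaling = dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                    original_max_position_embeddings=8192)

# Same frequency ladder that precompute_freqs_cis builds before scaling.
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
scaled = apply_rope_scaling(freqs, rope_scaling)

wavelen = 2 * math.pi / freqs
kept = (wavelen < 8192 / 4.0).sum().item()       # short wavelengths: left unchanged
stretched = (wavelen > 8192 / 1.0).sum().item()  # long wavelengths: divided by 8
print(kept, stretched, n_elem // 2 - kept - stretched)  # how many dims fall in each branch
print(torch.allclose(scaled[:kept], freqs[:kept]))      # True: high-frequency dims are untouched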