Commit 8354eba

Support Llama-3.1-405B (#199)

* Support Llama 3.1 405B
* Update readme

1 parent 61c193d

3 files changed: +38 -3 lines changed

README.md (+2)

@@ -73,6 +73,7 @@ mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-405B
 ```

 For example, to convert Llama-2-7b-chat-hf

@@ -120,6 +121,7 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | 62.50 | 1135.29 |
 | | 8-bit | 80.44 | 752.04 |
 | | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-405B | 8-bit | 15.60 | 815.87 |

 ### AMD
 Benchmarks run on one GCD of a MI-250x.

model.py (+32 -2)

@@ -3,6 +3,7 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional

@@ -29,6 +30,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None

     def __post_init__(self):
         if self.n_local_heads == -1:

@@ -68,6 +70,9 @@ def from_name(cls, name: str):

     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }

 class KVCache(nn.Module):
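
The new "llama-3.1-405b" entry is the first config with rope_scaling set and a 131072-token block_size. As a quick sanity check (a back-of-the-envelope sketch, not part of the commit, assuming the usual Llama projection shapes with untied embeddings and ignoring norm weights), the hyperparameters above do imply roughly 405B parameters:

```python
# Rough parameter count implied by the "llama-3.1-405b" entry above.
# Assumes wq/wk/wv/wo attention projections with grouped-query KV heads,
# a w1/w2/w3 SwiGLU MLP, and untied input/output embeddings; norms ignored.
dim, n_layer, n_head, n_local_heads = 16384, 126, 128, 8
intermediate_size, vocab_size = 53248, 128256

head_dim = dim // n_head                 # 128
kv_dim = n_local_heads * head_dim        # 1024

attn = 2 * dim * dim + 2 * dim * kv_dim  # wq + wo, wk + wv
mlp = 3 * dim * intermediate_size        # w1, w2, w3
emb = 2 * vocab_size * dim               # tok_embeddings + output head

total = n_layer * (attn + mlp) + emb
print(f"{total / 1e9:.1f}B")             # ~405.8B
```

Because n_local_heads=8 against n_head=128, the KV projections are small; the attention term is dominated by wq and wo.
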
@@ -119,7 +124,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
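
setup_caches now forwards self.config.rope_scaling so precompute_freqs_cis can remap the RoPE frequencies before building freqs_cis. A side note, not part of the diff: since caches are allocated for the full block_size, running the 405B config at its full 131072-token context makes the KV cache itself sizable. A rough estimate, assuming bfloat16 cache entries and max_batch_size=1:

```python
# Approximate KV-cache footprint for the llama-3.1-405b config at full context.
# Assumes bfloat16 caches and max_batch_size=1.
n_layer, n_local_heads, head_dim = 126, 8, 128
max_seq_length = 131072      # block_size of the new config
bytes_per_elem = 2           # bfloat16

kv_bytes = 2 * n_layer * max_seq_length * n_local_heads * head_dim * bytes_per_elem
print(f"{kv_bytes / 2**30:.1f} GiB")   # ~63.0 GiB for the k and v caches combined
```
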
@@ -230,11 +235,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
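
apply_rope_scaling is the Llama 3.1 frequency remap: inverse frequencies whose wavelength fits within the original 8192-token context (wavelen < high_freq_wavelen) are kept as-is, wavelengths longer than low_freq_wavelen are divided by factor (8.0), and the band in between is linearly interpolated via smooth. Since this runs once at cache-setup time, the Python loop is harmless; purely as an illustration of the same rule (not part of the commit), a loop-free version could look like this:

```python
import math
import torch

def apply_rope_scaling_vectorized(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    # Loop-free restatement of apply_rope_scaling above (sketch; same dict keys assumed).
    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    interpolated = (1 - smooth) * freqs / factor + smooth * freqs

    # Short wavelengths (high frequencies) are untouched; the middle band is interpolated.
    out = torch.where(wavelen < old_context_len / high_freq_factor, freqs, interpolated)
    # Long wavelengths (low frequencies) are scaled down by `factor`.
    out = torch.where(wavelen > old_context_len / low_freq_factor, freqs / factor, out)
    return out
```

Both forms agree elementwise for the Llama 3.1 settings, where high_freq_wavelen (2048) is strictly smaller than low_freq_wavelen (8192), so the two torch.where branches never overlap.
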

scripts/convert_hf_checkpoint.py (+4 -1)

@@ -116,7 +116,10 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        original_dir = checkpoint_dir / "original"
+        if 'llama-3.1' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
