Commit c9f683e

Add support for llama 3.1 8B/70B (#200)

* Add support for llama 3.1 8B/70B
* Update 4 GPU perf numbers

1 parent 8354eba · commit c9f683e

File tree: 3 files changed, +23 -7 lines

README.md (+16 -6)

````diff
@@ -73,6 +73,8 @@ mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-8B
+meta-llama/Meta-Llama-3.1-70B
 meta-llama/Meta-Llama-3.1-405B
 ```
 
@@ -93,8 +95,10 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | OOM ||
 | | 8-bit | 19.13 | 1322.58 |
 | | 4-bit (G=32) | 25.25 | 1097.66 |
-| Llama-3-8B | Base | 94.25 | 1411.95 |
-| | 8-bit | 139.55 | 1047.23 |
+| Llama-3.1-8B | Base | 93.89 | 1410.76 |
+| | 8-bit | 137.64 | 1030.89 |
+| Llama-3.1-70B | Base | OOM ||
+| | 8-bit | 18.04 | 1253.78 |
 
 ### Speculative Sampling
 [Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s
@@ -110,17 +114,23 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | | 2 | 21.32 | 1481.87 |
 | | 4 | 38.01 | 1340.76 |
 | | 8 | 62.50 | 1135.29 |
-| Llama-3-8B | 1 | 94.19 | 1411.76 |
-| | 2 | 150.48 | 1208.80 |
-| | 4 | 219.77 | 991.63 |
-| | 8 | 274.65 | 768.55 |
+| Llama-3.1-8B | 1 | 93.83 | 1408.37 |
+| | 2 | 149.10 | 1197.32 |
+| | 4 | 217.21 | 986.32 |
+| | 8 | 276.01 | 772.60 |
+| Llama-3.1-70B | 1 | OOM | |
+| | 2 | 16.03 | 1130.81 |
+| | 4 | 37.45 | 1360.53 |
+| | 8 | 58.78 | 1129.61 |
 
 ### Tensor Parallelism + Quantization
 | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
 | -------- | ------- | ------ | ------ |
 | Llama-2-70B | Base | 62.50 | 1135.29 |
 | | 8-bit | 80.44 | 752.04 |
 | | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-70B | Base | 58.78 | 1129.61 |
+| | 8-bit | 75.58 | 726.57 |
 | Llama-3.1-405B | 8-bit | 15.60 | 815.87 |
 
 ### AMD
````

model.py (+6)

```diff
@@ -70,6 +70,12 @@ def from_name(cls, name: str):
 
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-8b": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
+    "llama-3.1-70b": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
     "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
         rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
     ),
```
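
The new 3.1 entries differ from the Llama 3 ones mainly in `block_size` (131072 vs. 8192) and in the added `rope_scaling` dict, matching the existing 405B config. As a minimal sketch of how such parameters are conventionally applied to the RoPE inverse frequencies (the Llama 3.1-style "NTK-by-parts" rule; the helper name and example call below are assumptions for illustration, not part of this diff):

```python
import math
import torch

def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    """Sketch: rescale RoPE inverse frequencies using Llama 3.1-style parameters."""
    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies): leave untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths (low frequencies): stretch by `factor`.
            new_freqs.append(freq / factor)
        else:
            # Transition band: interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

# Example with the values from the configs above (llama-3.1-8b: head_dim = 4096 / 32 = 128):
rope_scaling = dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                    original_max_position_embeddings=8192)
inv_freq = 1.0 / (500000 ** (torch.arange(0, 128, 2).float() / 128))  # rope_base=500000
scaled_inv_freq = apply_rope_scaling(inv_freq, rope_scaling)
```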

scripts/convert_hf_checkpoint.py (+1 -1)

```diff
@@ -116,7 +116,7 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        if 'llama-3.1' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
             original_dir = checkpoint_dir / "original" / "mp16"
         else:
             original_dir = checkpoint_dir / "original"
```
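
This one-line change narrows the `mp16` special case to the 405B checkpoint, so the newly supported 3.1 8B/70B downloads resolve to the plain `original` directory like Llama 3 does. A minimal sketch of the resulting selection, written as a hypothetical standalone helper (the real script does this inline and only after checking for a Llama 3-family name):

```python
from pathlib import Path

def original_weights_dir(checkpoint_dir: Path, model_name: str) -> Path:
    # Hypothetical helper mirroring the inline logic above. It assumes the
    # caller has already established that `model_name` is a Llama 3-family
    # checkpoint: only the 3.1 405B download keeps its original files under
    # "original/mp16"; the 3 / 3.1 8B and 70B variants use "original".
    if 'llama-3.1-405b' in model_name.lower():
        return checkpoint_dir / "original" / "mp16"
    return checkpoint_dir / "original"

# e.g. original_weights_dir(Path("checkpoints/meta-llama/Meta-Llama-3.1-8B"),
#                           "Meta-Llama-3.1-8B")  # -> .../original
```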

0 commit comments
