forked from karpathy/llama2.c — compile.py (42 lines, 1.38 KB)
import torch
import torch._export
import argparse
import os
from model import ModelArgs, Transformer
def load_checkpoint(checkpoint):
    """Load a llama2.c training checkpoint and rebuild the model.

    Reads the .pt file produced by train.py, reconstructs the Transformer
    from the saved ModelArgs, loads the weights, and puts the model in
    eval mode.

    Args:
        checkpoint: path to the checkpoint .pt file.

    Returns:
        (model, gptconf): the Transformer in eval mode and its ModelArgs.
    """
    ckpt = torch.load(checkpoint, map_location='cpu')
    config = ModelArgs(**ckpt['model_args'])
    net = Transformer(config)
    # torch.compile() wraps modules and prepends '_orig_mod.' to every
    # parameter name; strip it so the keys match the plain Transformer.
    prefix = '_orig_mod.'
    weights = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in ckpt['model'].items()
    }
    # strict=False tolerates keys absent from the module (e.g. rotary
    # frequency buffers that the model recomputes itself).
    net.load_state_dict(weights, strict=False)
    net.eval()
    return net, config
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"filepath", type=str, default="llama2.so", help="the output filepath"
)
parser.add_argument("--checkpoint", type=str, help="checkpoint .pt")
args = parser.parse_args()
model, config = load_checkpoint(args.checkpoint)
x = torch.randint(0, config.vocab_size, (1, config.max_seq_len // 2))
constraints = [
torch._export.dynamic_dim(x, 1),
torch._export.dynamic_dim(x, 1) <= config.max_seq_len,
torch._export.dynamic_dim(x, 1) >= 1,
]
so_path = torch._export.aot_compile(
model,
(x,),
constraints=constraints,
options={"aot_inductor.output_path": args.filepath},
)