import argparse
import os
from typing import Any, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from tokenizers import ByteLevelBPETokenizer, Tokenizer

try:
    torch.classes.load_library(os.environ.get("FT_PATH"))
except Exception as e:
    raise ImportError(
        "Please install FasterTransformer and provide a path to the binary "
        "`libth_transformer.so` via the environment variable `FT_PATH`."
    ) from e

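# Globals populated in main(): the FT GPT op, the tokenizer, and this rank's device.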
model = None
tokenizer = None
device = None

BOS_TOKEN = 0
PAD_TOKEN = 1
EOS_TOKEN = 2
UNK_TOKEN = 3


@torch.inference_mode()
def generate(
    inputs: List[List[int]],
    output_length: int,
    beam_width: int = 1,
    top_k: Optional[int] = 0,
    top_p: Optional[float] = 1.0,
    diversity_rate: Optional[float] = None,
    temperature: Optional[float] = 1.0,
    len_penalty: Optional[float] = None,
    repetition_penalty: Optional[float] = 1.0,
    presence_penalty: Optional[float] = None,
    random_seed: Optional[int] = 0,
    min_length: Optional[int] = None,
    bad_words_list: Optional[torch.Tensor] = None,
    return_cum_log_probs: Optional[int] = 0,
) -> List[Any]:
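    """Generate continuations for a batch of token-id prompts.

    Scalar sampling options are wrapped into 1-element tensors for the FT op;
    passing None leaves that knob at the kernel default. Returns a
    single-element list holding one {"text": ...} dict per prompt
    (only beam 0 is decoded).
    """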
    # Prepend the BOS token (EOS id 2 doubles as BOS in this OPT-style vocab).
    inputs = [[EOS_TOKEN] + toks for toks in inputs]
    inputs = [torch.tensor(toks, dtype=torch.int32, device=device) for toks in inputs]
    lengths = torch.tensor([len(t) for t in inputs], dtype=torch.int32, device=device)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=PAD_TOKEN)

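    # The FT op expects each optional sampling/penalty scalar as a 1-element
    # tensor of the dtype it was compiled against.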
    if top_k is not None:
        top_k = torch.tensor([top_k], dtype=torch.int32)
    if top_p is not None:
        top_p = torch.tensor([top_p], dtype=torch.float32)
    if diversity_rate is not None:
        diversity_rate = torch.tensor([diversity_rate], dtype=torch.float32)
    if temperature is not None:
        temperature = torch.tensor([temperature], dtype=torch.float32)
    if len_penalty is not None:
        len_penalty = torch.tensor([len_penalty], dtype=torch.float32)
    if repetition_penalty is not None:
        repetition_penalty = torch.tensor([repetition_penalty], dtype=torch.float32)
    if presence_penalty is not None:
        presence_penalty = torch.tensor([presence_penalty], dtype=torch.float32)
    if random_seed is not None:
        random_seed = torch.tensor([random_seed], dtype=torch.int64)
    if min_length is not None:
        min_length = torch.tensor([min_length], dtype=torch.int64)

    outputs, output_lengths = model.forward(
        inputs,
        lengths,
        output_length,
        beam_width,
        top_k,
        top_p,
        diversity_rate,
        temperature,
        len_penalty,
        repetition_penalty,
        presence_penalty,
        min_length,
        random_seed,
        bad_words_list,
        return_cum_log_probs,
    )

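    # Decode beam 0 for each batch element, dropping the prepended BOS and
    # truncating at the first special token.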
    results = []
    beam_idx = 0
    special = outputs.new_tensor([BOS_TOKEN, PAD_TOKEN, EOS_TOKEN, UNK_TOKEN])
    for output, output_len in zip(outputs, output_lengths):
        # mask is True until the first special token after position 0;
        # cummin propagates the first False to all later positions.
        mask = ~torch.isin(output[beam_idx], special)
        mask[1:] = mask[1:].cummin(dim=0)[0]

        tokens = output[beam_idx][1 : output_len[beam_idx]]
        tokens = tokens[mask[1 : output_len[beam_idx]]]
        results.append({"text": tokenizer.decode(tokens.tolist())})
    return [results]


def main(args: argparse.Namespace) -> None:
    global model, tokenizer, device
    dist.init_process_group(backend="mpi")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)

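    # Prefer a serialized tokenizer.json; otherwise fall back to the
    # vocab/merges pair of a byte-level BPE tokenizer.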
    if args.tokenizer_file is not None:
        tokenizer = Tokenizer.from_file(args.tokenizer_file)
    else:
        tokenizer = ByteLevelBPETokenizer(args.vocab_file, args.merges_file)

    torch_dtypes = {"fp16": torch.half, "bf16": torch.bfloat16, "fp32": torch.float}
    dtype = torch_dtypes[args.dtype]

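    # Each rank loads only its own tensor-parallel shard of the weights.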
    state_dict = torch.load(f"{args.weight_path}/part-{rank}.pt")
    weights = [w.to(device, dtype) for w in state_dict["weights"]]
    int8_weights, int8_scales = [], []
    if args.int8_mode != 0 and {"int8_weights", "int8_scales"} <= state_dict.keys():
        int8_weights = [w.to(device=device) for w in state_dict["int8_weights"]]
        int8_scales = [w.to(device=device) for w in state_dict["int8_scales"]]

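    # ParallelGptOp takes positional arguments, so the insertion order of this
    # dict must match the op's constructor signature exactly.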
    kwargs = {
        "head_num": args.num_heads,
        "size_per_head": args.embed_size // args.num_heads,
        "inter_size": 4 * args.embed_size,
        "layer_num": args.num_layers,
        "expert_num": 0,
        "moe_k": 0,
        "moe_layer_index": [],
        "vocab_size": args.vocab_size,
        "start_id": EOS_TOKEN,
        "end_id": EOS_TOKEN,
        "tensor_para_size": world_size,
        "pipeline_para_size": 1,
        "int8_mode": args.int8_mode,
        "layernorm_eps": 1e-5,
        "layernorm_type": "pre_layernorm",
        "activation_type": "Relu",
        "has_positional_encoding": True,
        "has_pre_decoder_layernorm": False,
        "has_post_decoder_layernorm": True,
        "has_adapters": False,
        "adapter_inter_size": 0,
        "use_attention_linear_bias": False,
        "weights": weights,
        "int8_weights": int8_weights,
        "scale": int8_scales,
        "shared_contexts_ratio": 1.0,
    }
    model = torch.classes.FasterTransformer.ParallelGptOp(*kwargs.values())

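    # Interactive loop: rank 0 reads a prompt and broadcasts the token ids so
    # that every rank calls generate() with the same inputs.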
    prompts = [None]
    while True:
        if dist.get_rank() == 0:
            prompt = input("\033[32mPrompt: \033[0;1m").rstrip()
            if not prompt:
                continue
            prompts = [[tokenizer.encode(prompt).ids]]

        dist.broadcast_object_list(prompts, src=0)
        output = generate(
            prompts[0],
            output_length=args.output_length,
            beam_width=args.beam_width,
            top_k=args.top_k,
            top_p=args.top_p,
            diversity_rate=args.diversity_rate,
            temperature=args.temperature,
            len_penalty=args.len_penalty,
            repetition_penalty=args.repetition_penalty,
            random_seed=0,
        )
        if dist.get_rank() == 0:
            print(f"Output: {output[0][0]['text']}")


def measure_time(func, *args, **kwargs):
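    """Time `func` on the GPU via CUDA events; returns elapsed milliseconds.

    Currently unused; kept as a helper for benchmarking generate().
    """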
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    func(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-layers", type=int, default=12)
    parser.add_argument("--num-heads", type=int, default=12)
    parser.add_argument("--embed-size", type=int, default=768)
    parser.add_argument("--vocab-size", type=int, default=50272)

    parser.add_argument("--vocab-file", type=str)
    parser.add_argument("--merges-file", type=str)
    parser.add_argument("--tokenizer-file", type=str, default=None)
    parser.add_argument("--weight-path", type=str)
    parser.add_argument("--dtype", choices=["fp32", "fp16", "bf16"], default="fp16")
    parser.add_argument("--int8-mode", type=int, default=0)

    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-length", type=int, default=256)
    parser.add_argument("--beam-width", type=int, default=1)
    parser.add_argument("--top-k", type=int, default=20)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--len-penalty", type=float, default=0.0)
    parser.add_argument("--diversity-rate", type=float, default=0.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.2)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
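
# Example launch (a sketch; the script name and paths are placeholders):
#   FT_PATH=/path/to/libth_transformer.so mpirun -n <tensor_para_size> \
#       python demo.py --weight-path /path/to/shards \
#       --vocab-file vocab.json --merges-file merges.txt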