From 4908690b836c2dbd2a43502c62ddb227aae90887 Mon Sep 17 00:00:00 2001 From: James Hou Date: Sun, 27 Apr 2025 23:49:03 -0700 Subject: [PATCH] added support for batch eval of llama --- tinychat/benchmark.py | 21 ++++++++++++++++----- tinychat/demo.py | 2 +- tinychat/models/llama.py | 2 ++ tinychat/modules/fused_attn.py | 2 ++ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/tinychat/benchmark.py b/tinychat/benchmark.py index e397f10c..f0c90c27 100644 --- a/tinychat/benchmark.py +++ b/tinychat/benchmark.py @@ -48,6 +48,12 @@ def main(): parser.add_argument( "--max_batch_size", type=int, default=1, help="maximum batch size for kv cache" ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="batch size for inference", + ) parser.add_argument( "--flash_attn", action="store_true", @@ -99,6 +105,9 @@ def main(): from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM from tinychat.models.vila_llama import VilaLlamaForCausalLM + assert args.batch_size == args.max_batch_size + assert (args.max_batch_size == 1) or ("llama" in args.model_type.lower()), "We only support batch eval for Llama for now" + modeling_utils._init_weights = False torch.nn.init.kaiming_uniform_ = skip torch.nn.init.kaiming_normal_ = skip @@ -266,7 +275,7 @@ def main(): # warming up input_ids = [1 for _ in range(2048)] - inputs = torch.as_tensor([input_ids], device=device) + inputs = torch.as_tensor([input_ids for _ in range(args.batch_size)], device=device) out = model( inputs, start_pos=0, @@ -286,7 +295,9 @@ def main(): start_pos = 0 torch.cuda.synchronize() t_st = time.time() - inputs = torch.as_tensor([input_ids], device=device) + inputs = torch.as_tensor( + [input_ids for _ in range(args.batch_size)], + device=device) out = model( inputs, start_pos=start_pos, @@ -296,7 +307,7 @@ def main(): start_pos += inputs.shape[1] torch.cuda.synchronize() t_ed = time.time() - token = torch.argmax(out, keepdim=True)[0] + token = torch.argmax(out, -1, keepdim=True)[:, :, 0] time_lis.append(t_ed - t_st) if args.verbose: print(i, t_ed - t_st) @@ -314,12 +325,12 @@ def main(): quant=args.precision in ["W4A16"], ) start_pos += 1 - token = torch.argmax(token, keepdim=True)[0] + token = torch.argmax(token, -1, keepdim=True)[:, :, 0] torch.cuda.synchronize() t_ed = time.time() time_lis.append(t_ed - t_st) print( - f"Decoding throughput: {token_num/sum(time_lis):.5f} token/s." + f"Decoding throughput: {token_num * args.batch_size / sum(time_lis):.5f} token/s." ) print("-" * 80) else: diff --git a/tinychat/demo.py b/tinychat/demo.py index 318b7ea5..dda4359d 100644 --- a/tinychat/demo.py +++ b/tinychat/demo.py @@ -240,7 +240,7 @@ def skip(*args, **kwargs): make_quant_attn(model, args.device) make_quant_norm(model) model( - torch.randint(0, 1000, (1, 4096), dtype=torch.int, device="cuda:0"), + torch.randint(0, 1000, (1, 2048), dtype=torch.int, device="cuda:0"), 0, quant=args.precision == "W4A16", ) diff --git a/tinychat/models/llama.py b/tinychat/models/llama.py index dff3fc11..6165635a 100644 --- a/tinychat/models/llama.py +++ b/tinychat/models/llama.py @@ -124,6 +124,8 @@ def __init__(self, args): bias=False, ) + max_batch_size = tinychat.utils.constants.max_batch_size + # following fastertransformer definition self.cache_v = ( torch.zeros( diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py index 47f387dd..74d34c74 100644 --- a/tinychat/modules/fused_attn.py +++ b/tinychat/modules/fused_attn.py @@ -190,6 +190,7 @@ def __init__( self.o_proj = o_proj self.kv_max_seq_len = kv_max_seq_len + max_batch_size = tinychat.utils.constants.max_batch_size # following fastertransformer definition self.cache_v = ( @@ -350,6 +351,7 @@ def __init__( self.o_proj = o_proj self.kv_max_seq_len = kv_max_seq_len + max_batch_size = tinychat.utils.constants.max_batch_size # following fastertransformer definition # For short seqlence, we use fused kernel to accelerate decoding. if self.kv_max_seq_len <= 8192: