Skip to content

Commit 8b91ce6

Browse files
committed
fix seq_lengths
1 parent 9ba2eac commit 8b91ce6

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

generate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def causal_mask(b, h, q, kv):
 
 
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
+    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]

@@ -77,11 +77,12 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs):
     block_index = input_pos // block_mask.BLOCK_SIZE[0]
     mask = block_mask[:, :, block_index]
     mask.mask_mod = block_mask.mask_mod
+    mask.seq_lengths = (1, model.max_seq_length)
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)
 
 
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
-    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
+    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device)
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
         next_token, next_prob = decode_one_token(

0 commit comments

Comments
 (0)