diff --git a/generate.py b/generate.py
index 8446d115..5b63939d 100644
--- a/generate.py
+++ b/generate.py
@@ -57,12 +57,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
     return idx_next, probs
 
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
-    # input_pos: [B, S]
+    # input_pos: [S]
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)[0]
 
 def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
-    # input_pos: [B, 1]
+    # input_pos: [1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)