-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsampling.py
65 lines (48 loc) · 1.94 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging
import torch
import tiktoken
# Module logger; _sample emits every decoded completion through log.info.
log = logging.getLogger(__name__)
# tiktoken "r50k_base" BPE encoding, used by _sample both to encode prompts and
# to decode generated token ids. NOTE(review): presumably this must match the
# vocabulary the model was trained with — confirm.
TOKENIZER = tiktoken.get_encoding("r50k_base")
def sample(model, ddp, prompt, tParams):
    """Sample a single ``prompt`` several times and log each completion.

    The model is put into eval mode for generation and switched back to
    train mode afterwards. The switch back happens in a ``finally`` clause,
    so it also runs when generation raises.

    Args:
        model: Model to sample from; passed straight to ``_sample``.
        ddp: Run/device descriptor; ``_sample`` reads ``assigned_device`` and
            ``device_type`` from it.
        prompt: Text prompt to complete.
        tParams: Config object; ``_sample`` reads ``sampling_tokens``,
            ``sampling_batch`` and ``sampling_top_k`` from it.
    """
    model.eval()
    try:
        _sample(model, ddp, prompt, tParams)
    finally:
        # Without this, an exception during sampling would leave the model
        # in eval mode (e.g. dropout disabled) for the rest of training.
        model.train()
def multi_sample(model, ddp, prompts, tParams):
    """Sample this process's share of ``prompts`` several times each.

    Prompts are sharded round-robin across processes: the process with rank
    ``r`` handles ``prompts[r::ddp.world_size]``. When a ``torch.distributed``
    process group is initialized, ``r`` is the global rank. Otherwise it falls
    back to ``ddp.local_rank``.

    Using the global rank matters on multi-node runs. There ``local_rank``
    repeats on every node, so sharding by it would sample some prompts on
    several nodes and never sample others.

    The model is put into eval mode for generation and switched back to
    train mode afterwards, even if generation raises.

    Args:
        model: Model to sample from; passed straight to ``_sample``.
        ddp: Run/device descriptor; ``local_rank`` and ``world_size`` are
            read here, ``assigned_device`` and ``device_type`` by ``_sample``.
        prompts: Sequence of text prompts (must support slicing).
        tParams: Config object; ``_sample`` reads ``sampling_tokens``,
            ``sampling_batch`` and ``sampling_top_k`` from it.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        # Global rank, consistent with striding by the global world size.
        rank = dist.get_rank()
    else:
        rank = ddp.local_rank

    # NOTE(review): if `model` is wrapped in DistributedDataParallel, its
    # forward may run collectives (e.g. buffer broadcast); uneven shard sizes
    # would then give ranks different numbers of forward calls — confirm.
    model.eval()
    try:
        # A slice is already a fresh list; no extra copy needed.
        for prompt in prompts[rank::ddp.world_size]:
            _sample(model, ddp, prompt, tParams)
    finally:
        # Restore training mode even if sampling fails part-way.
        model.train()
def _sample(model, ddp, prompt, tParams):
    """Text-complete ``prompt`` with top-k sampling and log every completion.

    The encoded prompt is replicated ``tParams.sampling_batch`` times. Each
    copy is extended by exactly ``tParams.sampling_tokens`` sampled tokens;
    generation does not stop early on an end-of-text token. At every step,
    only the ``tParams.sampling_top_k`` highest-scoring tokens are candidates,
    and one is drawn in proportion to its softmax probability.

    TODO: Add temperature.
    TODO: Add top-p sampling following top-k.

    Args:
        model: Called as ``model(sequence)``; element ``[0]`` of its output
            is used as the logits.
        ddp: Provides ``assigned_device`` (where the token tensor lives) and
            ``device_type`` (passed to ``torch.autocast``).
        prompt: Text to complete. An empty prompt is seeded with the
            tokenizer's end-of-text token.
        tParams: Provides ``sampling_tokens``, ``sampling_batch`` and
            ``sampling_top_k``.
    """
    sampling_tokens = tParams.sampling_tokens
    batch_size = tParams.sampling_batch
    top_k = tParams.sampling_top_k

    input_ids = TOKENIZER.encode(prompt)
    if not input_ids:
        # An empty prompt would produce a (batch, 0) tensor the model cannot
        # run on; start generation from the end-of-text token instead.
        input_ids = [TOKENIZER.eot_token]
    sequence = torch.tensor(
        [input_ids] * batch_size, dtype=torch.long, device=ddp.assigned_device
    )

    # NOTE(review): the sequence grows by one token per step and is never
    # cropped; if the model has a fixed context length, prompt length plus
    # sampling_tokens must fit within it — confirm.
    with torch.no_grad():
        for _ in range(sampling_tokens):
            with torch.autocast(device_type=ddp.device_type, dtype=torch.bfloat16):
                # assumes logits are (batch, seq, vocab) — keep only the
                # distribution for the next position.
                last_tokens_logits = model(sequence)[0][:, -1, :]
            # torch.topk raises if k exceeds the size of the last dimension.
            k = min(top_k, last_tokens_logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(last_tokens_logits, k, dim=-1)
            # Autocast may hand back bfloat16 logits; normalize in float32.
            probs = torch.nn.functional.softmax(top_k_logits.float(), dim=-1)
            sampled_indices = torch.multinomial(probs, num_samples=1)
            # Map positions within the top-k set back to vocabulary ids.
            next_tokens = top_k_indices.gather(-1, sampled_indices)
            sequence = torch.cat([sequence, next_tokens], dim=-1)

    # Decode and log each generated sample (prompt included).
    for output in sequence:
        log.info(TOKENIZER.decode(output.tolist()))