This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit a13f5a3

Generation should stop after two newlines if that is the stop criterion
Summary: This addresses Issue 642. When the stop token is \n\n, generation should stop after generating two newlines.
1 parent 2c8fbd9 commit a13f5a3
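
Why the change is needed: a byte-level BPE encodes "\n\n" as one fused token id, distinct from the id of a single "\n", so a stop check that only compares sampled tokens against the fused id never fires when the model emits two single newlines back to back. A minimal sketch of the naive check versus the grouped check introduced here (the token ids are hypothetical, GPT-2-style stand-ins):

```python
# Minimal repro of the failure mode this commit fixes. The ids are
# hypothetical, GPT-2-style stand-ins: "\n" -> 198, "\n\n" -> 628.
SINGLE_NEWLINE = 198
DOUBLE_NEWLINE = 628

stop_tokens = {DOUBLE_NEWLINE}

# A sampled sequence that emits two single-newline tokens back to back
# instead of the fused "\n\n" token.
generated = [101, 102, 103, SINGLE_NEWLINE, SINGLE_NEWLINE, 104]

def naive_should_stop(tok):
    # Only matches the fused id, so (198, 198) slips through even though
    # the decoded text contains "\n\n".
    return tok in stop_tokens

def grouped_should_stop(tok, prev_tok):
    # The commit's idea: the repeated "single" form of a punctuation
    # token counts as its "multiple" form.
    return tok in stop_tokens or (
        tok == SINGLE_NEWLINE and prev_tok == SINGLE_NEWLINE
    )

prev = None
for i, tok in enumerate(generated):
    assert not naive_should_stop(tok)  # the naive check never fires here
    if grouped_should_stop(tok, prev):
        print(f"grouped check stops at step {i}")  # fires at the second "\n"
        break
    prev = tok
```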

File tree: 3 files changed (+38 −4 lines)


metaseq/hub_utils.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -82,6 +82,15 @@ def decode(self, sentence: str) -> str:
         return self.tokenizer.decode(sentence)
 
 
+class RecurringPunctuation(object):
+    """Class for grouping tokens of similar type. For example \n and \n\n"""
+
+    def __init__(self, single_token, multiple_token):
+        super().__init__()
+        self.single_token = single_token
+        self.multiple_token = multiple_token
+
+
 class GeneratorInterface:
     """
     PyTorch Hub interface for generating sequences from a pre-trained
@@ -323,14 +332,21 @@ def generate(
             self.cfg.generation,
             extra_gen_cls_kwargs=extra_gen_cls_kwargs,
         )
-
         # okay actually generate
         logger.info(f"Executing generation on input tensor size {src_tokens.shape}")
         if use_cuda:
             batch = utils.move_to_cuda(batch)
 
         translate_start_time = time.time()
-        translations = self.task.inference_step(generator, self.models, batch)
+        recurring_punctuation = RecurringPunctuation(
+            self.bpe.bpe.encode("\n").ids[0], self.bpe.bpe.encode("\n\n").ids[0]
+        )
+        translations = self.task.inference_step(
+            generator,
+            self.models,
+            batch,
+            recurring_punctuation=recurring_punctuation,
+        )
         translate_time = time.time() - translate_start_time
         total_generation_time += translate_time
 
```
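
The `self.bpe.bpe.encode("\n").ids[0]` expressions above assume the wrapped tokenizer is a Hugging Face `tokenizers` object whose `encode` returns an `Encoding` with an `.ids` list. A standalone sketch of deriving the two ids (the vocab/merges paths are hypothetical, and `.ids[0]` relies on each string encoding to exactly one token):

```python
# Sketch of deriving RecurringPunctuation's ids with a Hugging Face
# `tokenizers` byte-level BPE, as metaseq's self.bpe.bpe appears to be.
# The vocab/merges paths are hypothetical.
from tokenizers import ByteLevelBPETokenizer

from metaseq.hub_utils import RecurringPunctuation

bpe = ByteLevelBPETokenizer("gpt2-vocab.json", "gpt2-merges.txt")

single_token = bpe.encode("\n").ids[0]      # id of the lone "\n" token
multiple_token = bpe.encode("\n\n").ids[0]  # id of the fused "\n\n" token

recurring_punctuation = RecurringPunctuation(single_token, multiple_token)
```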

metaseq/sequence_generator.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 from metaseq import utils
 from metaseq.data import data_utils
 from metaseq.models import BaseDecoder
+from metaseq.hub_utils import RecurringPunctuation
 
 logger = logging.getLogger(__name__)
 
@@ -130,6 +131,7 @@ def forward(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """Generate a batch of translations."""
         return self._generate(sample, prefix_tokens, bos_token=bos_token)
@@ -144,6 +146,7 @@ def _generate(
         sample: Dict[str, Dict[str, Tensor]],
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
+        recurring_punctuation: Optional[RecurringPunctuation] = None,
     ):
         """
         Args:
@@ -268,6 +271,7 @@ def _generate(
 
         eos_mask = torch.zeros(lprobs.size(0), dtype=torch.bool, device=lprobs.device)
 
+        prev_token = None
         for step in range(start_step, max_len):
             if step < min_len:
                 # minimum length constraint (does not apply if using prefix_tokens)
@@ -303,13 +307,20 @@ def _generate(
             all_lprobs[:, step] = lprobs
 
             eos_mask |= next_toks == self.eos
+
             for stop_token in self.stop:
                 # if there are other early stopping tokens, allow those to trigger stop
                 eos_mask |= next_toks == stop_token
+                eos_mask |= (
+                    recurring_punctuation
+                    and recurring_punctuation.multiple_token == stop_token
+                    and recurring_punctuation.single_token == next_toks == prev_token
+                )
 
             if torch.all(eos_mask):
                 break
 
+            prev_token = next_toks
             # forward through the next pass
             model_out = self.model.decoder(
                 tokens[:, : step + 1],
```
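
Two caveats are worth flagging. The `forward()` wrapper above gains the new keyword, but its unchanged body still calls `self._generate` without forwarding it. And the chained comparison `single_token == next_toks == prev_token` desugars to `(single_token == next_toks) and (next_toks == prev_token)`; applying `and` to a multi-element tensor raises "Boolean value of Tensor with more than one element is ambiguous", so the committed check only works at batch size 1. A tensor-safe sketch of the intended batched logic, which also guards the `recurring_punctuation is None` case (the helper name and variables are mine, not from the commit):

```python
import torch

def update_eos_mask(eos_mask, next_toks, prev_toks, stop_tokens, rp=None):
    # Per-sequence stop check: a repeated single-newline token counts as
    # the fused double-newline stop token.
    for stop_token in stop_tokens:
        eos_mask |= next_toks == stop_token
        if rp is not None and rp.multiple_token == stop_token and prev_toks is not None:
            eos_mask |= (next_toks == rp.single_token) & (prev_toks == rp.single_token)
    return eos_mask

# Tiny check with a batch of 3 (hypothetical ids: "\n" -> 198, "\n\n" -> 628).
class RP:  # stand-in for RecurringPunctuation
    single_token, multiple_token = 198, 628

eos = torch.zeros(3, dtype=torch.bool)
prev = torch.tensor([198, 50, 198])
nxt = torch.tensor([198, 628, 7])
print(update_eos_mask(eos, nxt, prev, stop_tokens=[628], rp=RP()))
# tensor([ True,  True, False])
```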

metaseq/tasks/base_task.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -420,9 +420,16 @@ def build_dataset_for_inference(
     ) -> torch.utils.data.Dataset:
         raise NotImplementedError
 
-    def inference_step(self, generator, models, sample, prefix_tokens=None):
+    def inference_step(
+        self, generator, models, sample, prefix_tokens=None, recurring_punctuation=None
+    ):
         with torch.no_grad():
-            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
+            return generator.generate(
+                models,
+                sample,
+                prefix_tokens=prefix_tokens,
+                recurring_punctuation=recurring_punctuation,
+            )
 
     def begin_epoch(self, epoch, model):
         """Hook function called before the start of each epoch."""
```
