
Dynamic Batching #250

Open
@srush

Description


(Posting a couple of issues to get features upstreamed from OpenNMT-py, cc @da03)

For the transformer model, we need some improvements to dynamic batching. In particular the batch_size_fn interface has some issues. Let me give an example.

We need batches of 4096 tokens (including padding).

  • We can't really do this with batch_size_fn because, while it lets us count the total number of tokens, it doesn't let us account for padding (the max length in the batch). A single bad example either causes tons of padding or produces a huge batch and an OOM.

Our current terrible hack:

        # Running maxima shared across calls; reset at the start of each batch.
        global max_src_in_batch, max_tgt_in_batch

        def batch_size_fn(new, count, sofar):
            "Effective batch size in tokens, padding included."
            global max_src_in_batch, max_tgt_in_batch
            if count == 1:
                # First example of a new batch: reset the running maxima.
                max_src_in_batch = 0
                max_tgt_in_batch = 0
            # +2 / +1 account for the special tokens added to each side later.
            max_src_in_batch = max(max_src_in_batch, len(new.src) + 2)
            max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt) + 1)
            # Padded size = number of examples * longest sequence so far.
            src_elements = count * max_src_in_batch
            tgt_elements = count * max_tgt_in_batch
            return max(src_elements, tgt_elements)
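To put numbers on the failure mode: here is a self-contained run of the same function (plain Python, no torchtext; `Example` is a stand-in for a torchtext example, not a real API) showing how one long sentence inflates the padded size of everything batched with it:

```python
from collections import namedtuple

# Stand-in for a torchtext example; src/tgt are token lists (assumption).
Example = namedtuple("Example", ["src", "tgt"])

max_src_in_batch = max_tgt_in_batch = 0

def batch_size_fn(new, count, sofar):
    """Effective batch size in tokens, padding included."""
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src) + 2)
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt) + 1)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)

# Ten 10-token pairs: 10 examples * (10 + 2) = 120 padded tokens.
size = 0
for i in range(1, 11):
    size = batch_size_fn(Example(["w"] * 10, ["w"] * 10), i, size)
print(size)  # 120

# One 500-token example joins: 11 examples * (500 + 2) = 5522 padded tokens,
# blowing past a 4096-token budget even though only one sentence is long.
size = batch_size_fn(Example(["w"] * 500, ["w"] * 500), 11, size)
print(size)  # 5522
```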
  • The Iterator uses this line to buffer data for batching (https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L271):

        for p in batch(data, batch_size * 100, batch_size_fn):

Unfortunately, even though there is a factor of 100 here, it doesn't help: if we are counting padding in batch_size_fn, then one long example makes every other sentence in the buffer take a ton of space. I think we need control over how the buffer is sized separately from how batches are cut.

Our current hack (don't use batch_size_fn for buffering):

        for p in torchtext.data.batch(data, self.batch_size * 100):
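A sketch of the control we'd like (pure Python, not a proposed torchtext API; `pool_batches`, `padded_size_fn`, and all parameter names are illustrative): buffer by example *count*, sort each buffer so similar lengths land together, then cut batches with a padding-aware size function:

```python
from collections import namedtuple

Example = namedtuple("Example", ["src", "tgt"])  # stand-in for torchtext

def padded_size_fn():
    """Return a fresh padding-aware size function (count * max length)."""
    state = {"src": 0, "tgt": 0}
    def fn(new, count, sofar):
        if count == 1:
            state["src"] = state["tgt"] = 0
        state["src"] = max(state["src"], len(new.src))
        state["tgt"] = max(state["tgt"], len(new.tgt))
        return max(count * state["src"], count * state["tgt"])
    return fn

def pool_batches(data, token_budget, size_fn, sort_key, buffer_size=100):
    """Buffer examples by count (not tokens), sort each buffer by length so
    similar sizes are adjacent, then cut batches under a token budget."""
    buffer = []
    for ex in data:
        buffer.append(ex)
        if len(buffer) == buffer_size:
            yield from _cut(sorted(buffer, key=sort_key), token_budget, size_fn)
            buffer = []
    if buffer:
        yield from _cut(sorted(buffer, key=sort_key), token_budget, size_fn)

def _cut(examples, token_budget, size_fn):
    batch, size = [], 0
    for ex in examples:
        candidate = size_fn(ex, len(batch) + 1, size)
        if batch and candidate > token_budget:
            yield batch
            batch, size = [ex], size_fn(ex, 1, 0)  # reset for the new batch
        else:
            batch.append(ex)
            size = candidate
    if batch:
        yield batch

# Four short (8-token) and two long (50-token) pairs in mixed order.
data = [Example(["w"] * n, ["w"] * n) for n in [8, 50, 8, 8, 50, 8]]
batches = list(pool_batches(data, token_budget=100, size_fn=padded_size_fn(),
                            sort_key=lambda ex: len(ex.src), buffer_size=6))
# Sorting the buffer groups the four short examples (4 * 8 = 32 tokens) and
# the two long ones (2 * 50 = 100), instead of padding everything to 50.
print([len(b) for b in batches])  # [4, 2]
```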

  • Minor: Batching uses sort for two different purposes: one to find the batches themselves, and the other for the order within each batch. I would like a batch_construction_sort to find sentences of similar length, and then a batch_sort applied within each batch. For example, in MT I would like to sort by a weighted src x tgt length during batch construction (to minimize padding), but then have each batch itself sorted by src length to make cudnn work.
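The two-stage sort above can be sketched in a few lines (pure Python; `make_batches` and its parameter names are mine, not an existing API — only the two sort roles come from the proposal):

```python
def make_batches(examples, batch_size, construction_key, in_batch_key):
    """Two-stage sort: `construction_key` groups similar lengths so batches
    carry little padding; `in_batch_key` orders examples inside each batch
    (e.g. by descending src length, as cudnn packed sequences expect)."""
    pool = sorted(examples, key=construction_key)
    for i in range(0, len(pool), batch_size):
        yield sorted(pool[i:i + batch_size], key=in_batch_key, reverse=True)

# Toy (src_len, tgt_len) pairs standing in for real examples.
pairs = [(3, 7), (10, 2), (4, 5), (9, 9)]
batches = list(make_batches(
    pairs, batch_size=2,
    construction_key=lambda p: p[0] * p[1],   # weighted src x tgt length
    in_batch_key=lambda p: p[0]))             # src length, descending
print(batches)  # [[(10, 2), (4, 5)], [(9, 9), (3, 7)]]
```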
