About the query_mask #182

@bjuthjliu

Source Code:

    padding_num = -2 ** 32 + 1
    if type in ("k", "key", "keys"):
        key_masks = tf.to_float(key_masks) # (N, seqlen)
        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num  # adds padding_num where key_masks == 1
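
For reference, here is a minimal NumPy sketch of what the current code does. It assumes `key_masks` is 1 at padding keys and 0 elsewhere (which is what adding `padding_num` where the mask is 1 implies) and uses `np.tile`/`np.expand_dims` as stand-ins for the `tf` ops. The helper names and the sizes `h, N, T_q, T_k` are hypothetical, chosen only for illustration. Because the mask is expanded to `(h*N, 1, seqlen)`, broadcasting applies it to every query row, so the addition needs no explicit tile over `T_q`.

```python
import numpy as np

PADDING_NUM = -2 ** 32 + 1  # same large negative constant as padding_num above


def additive_key_mask(inputs, key_masks):
    """Hypothetical NumPy mirror of the current additive key masking.

    inputs:    (h*N, T_q, T_k) attention scores
    key_masks: (N, T_k), assumed 1 at padding keys and 0 elsewhere
    """
    key_masks = key_masks.astype(np.float64)
    # Repeat the whole batch once per head: (N, T_k) -> (h*N, T_k)
    key_masks = np.tile(key_masks, (inputs.shape[0] // key_masks.shape[0], 1))
    # Add a query axis: (h*N, 1, T_k); broadcasting covers all T_q rows
    key_masks = np.expand_dims(key_masks, 1)
    return inputs + key_masks * PADDING_NUM


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


h, N, T_q, T_k = 2, 2, 3, 4  # hypothetical sizes
scores = np.zeros((h * N, T_q, T_k))
key_masks = np.array([[0, 0, 1, 1],   # sequence 0: last two keys are padding
                      [0, 0, 0, 1]])  # sequence 1: last key is padding
weights = softmax(additive_key_mask(scores, key_masks))
print(weights[0])  # row 0 of h*N uses sequence 0's mask; shape (T_q, T_k)
```

With all-zero scores, every query row for sequence 0 puts weight 0.5 on each of its two real keys and 0 on the padding keys, and every row for sequence 1 puts 1/3 on each of its three real keys.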

I think the outputs should be computed like this instead, replacing the masked scores with `padding_num` via `tf.where`:

    padding_num = -2 ** 32 + 1
    if type in ("k", "key", "keys"):
        key_masks = tf.to_float(key_masks) # (N, T_k)
        key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(inputs)[1], 1]) # (h*N, T_q, seqlen); T_q comes from inputs
        paddings = tf.ones_like(key_masks) * padding_num
        outputs = tf.where(tf.equal(key_masks, 0), paddings, inputs) # replaces positions where key_masks == 0
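
A NumPy sketch comparing the proposal with the current code, using the same NumPy stand-ins; the helper names, `is_pad`, and the sizes are hypothetical. The two versions use opposite mask conventions: the current code adds `padding_num` where the mask is 1, while the proposal replaces positions where the mask is 0, so the sketch passes the proposal `1 - is_pad`. The explicit tile over `T_q` matters for the `tf.where` version because TF 1.x `tf.where` expects `x` and `y` to have the same shape rather than broadcasting them.

```python
import numpy as np

PADDING_NUM = -2 ** 32 + 1


def where_key_mask(inputs, key_masks):
    """Hypothetical NumPy mirror of the proposed tf.where masking.

    inputs:    (h*N, T_q, T_k) attention scores
    key_masks: (N, T_k); positions where the mask is 0 get PADDING_NUM
    """
    key_masks = key_masks.astype(np.float64)
    key_masks = np.tile(key_masks, (inputs.shape[0] // key_masks.shape[0], 1))  # (h*N, T_k)
    # Tile over the query axis using T_q taken from inputs: (h*N, T_q, T_k)
    key_masks = np.tile(np.expand_dims(key_masks, 1), (1, inputs.shape[1], 1))
    paddings = np.ones_like(key_masks) * PADDING_NUM
    return np.where(key_masks == 0, paddings, inputs)


def additive_key_mask(inputs, key_masks):
    """Mirror of the current code: adds PADDING_NUM where key_masks == 1."""
    key_masks = np.tile(key_masks.astype(np.float64),
                        (inputs.shape[0] // key_masks.shape[0], 1))
    return inputs + np.expand_dims(key_masks, 1) * PADDING_NUM


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
h, N, T_q, T_k = 2, 2, 3, 4  # T_q != T_k, as in encoder-decoder attention
scores = rng.normal(size=(h * N, T_q, T_k))
is_pad = np.array([[0, 0, 1, 1],
                   [0, 0, 0, 1]])  # 1 marks a padding key

w_add = softmax(additive_key_mask(scores, is_pad))
w_where = softmax(where_key_mask(scores, 1 - is_pad))  # proposal expects 0 at padding
print(np.allclose(w_add, w_where))  # True
print(np.allclose(w_add, softmax(where_key_mask(scores, is_pad))))  # False
```

With matching conventions the two give the same attention weights after softmax, because the exponential of either huge negative value underflows to 0. Passing the padding indicator to the proposal unchanged masks the real keys instead.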
