Hi. I may be misunderstanding this. In lines 178~180 of run_pplm.py, a window mask is constructed so that only the recent past of the hidden states gets updated:
window_mask = torch.cat(
    (ones_mask, torch.zeros(zeros_key_val_shape)),
    dim=-2
)
Shouldn't we concatenate in the order (zeros; ones) instead, since the aim is to keep the most recent latents for the update and mask out the very beginning, not the other way around?
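To make the question concrete, here is a toy sketch of the two orderings. It is not the code from run_pplm.py: it uses a 1-D mask instead of the real key/value shape, the names `window_masks`, `curr_length` and `window_length` are just for illustration, and it assumes the sequence axis is ordered oldest position first.

```python
import torch


def window_masks(curr_length: int, window_length: int):
    """Build two 1-D masks over `curr_length` positions (assumed oldest first)."""
    ones = torch.ones(window_length)
    zeros = torch.zeros(curr_length - window_length)
    # Order used in the quoted snippet: ones, then zeros.
    ones_first = torch.cat((ones, zeros), dim=-1)
    # Order proposed in this question: zeros, then ones.
    zeros_first = torch.cat((zeros, ones), dim=-1)
    return ones_first, zeros_first


ones_first, zeros_first = window_masks(curr_length=5, window_length=2)
print(ones_first.tolist())   # [1.0, 1.0, 0.0, 0.0, 0.0]
print(zeros_first.tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0]
```

Under that ordering assumption, the (ones; zeros) version selects the two oldest positions, while (zeros; ones) selects the two most recent ones, which is what I would expect a "recent past" window to do.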
Any response to this would be greatly appreciated!