@@ -57,7 +57,7 @@ def __init__(
         # TODO: Find a better solution.
         self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space))
 
-        if not self._config.transformer.diffusion:
+        if self._config.transformer.diffusion is None:
             if self._use_flash_attention:
                 self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
             else:
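
The condition change above narrows the check: `not self._config.transformer.diffusion` also fires when `diffusion` is set to a falsy value (e.g. `False` or `0`), whereas `is None` only treats a genuinely unset diffusion config as "no diffusion". A minimal sketch of the difference, using a hypothetical stand-in class rather than the real transformer config:

    class _Cfg:
        def __init__(self, diffusion=None):
            self.diffusion = diffusion

    for cfg in (_Cfg(None), _Cfg(False), _Cfg(0.0), _Cfg(0.5)):
        # `not` also treats False and 0.0 as "no diffusion"; `is None` does not.
        print(cfg.diffusion, not cfg.diffusion, cfg.diffusion is None)
    # None True True / False True False / 0.0 True False / 0.5 False False
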
@@ -355,12 +355,21 @@ def preprocess(
 
         batch_size, seq_len = batch.token_ids.shape
         seq_len -= 1  # last token is dropped inputs
+        # attention_mask = torch.ones(
+        #     (batch_size, 1, seq_len, seq_len),
+        #     dtype=torch.bool,
+        #     device=self._tensor_space.distributed.device,
+        # )
+        # kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
         attention_mask = torch.ones(
-            (batch_size, 1, seq_len, seq_len),
+            (seq_len, seq_len),
             dtype=torch.bool,
             device=self._tensor_space.distributed.device,
         )
-        kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
+        kwargs[TransformerKwargs.attention_mask] = attention_mask[
+            None, None, 0:seq_len, None, :seq_len
+        ]
+        print(f"attention_mask: {kwargs[TransformerKwargs.attention_mask]}")
         # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(
         # #     -10000.0, device=self._tensor_space.distributed.device
         # # )