File tree Expand file tree Collapse file tree 5 files changed +5
-7
lines changed Expand file tree Collapse file tree 5 files changed +5
-7
lines changed Original file line number Diff line number Diff line change @@ -222,10 +222,9 @@ def forward(
222222
223223 batch_size , q_len , _ = hidden_states .shape
224224 # change to inference mode.
225- mode = 'fused_recurrent' if q_len <= 64 else self .mode
225+ mode = 'fused_recurrent' if ( q_len <= 64 and not self . training ) else self .mode
226226 if self .training :
227227 assert mode == 'chunk' , "Only chunk mode is supported in training."
228-
229228 last_state = None
230229 if past_key_values is not None and len (past_key_values ) > self .layer_idx :
231230 last_state = past_key_values [self .layer_idx ]
Original file line number Diff line number Diff line change @@ -217,7 +217,7 @@ def forward(
217217
218218 batch_size , q_len , _ = hidden_states .shape
219219 # change to inference mode.
220- mode = 'fused_recurrent' if q_len <= 64 else self .mode
220+ mode = 'fused_recurrent' if ( q_len <= 64 and not self . training ) else self .mode
221221 if self .training :
222222 assert mode == 'chunk' , "Only chunk mode is supported in training."
223223
Original file line number Diff line number Diff line change @@ -174,8 +174,7 @@ def forward(
174174
175175 batch_size , q_len , _ = hidden_states .shape
176176 # change to inference mode.
177- mode = 'fused_recurrent' if q_len <= 64 else self .mode
178-
177+ mode = 'fused_recurrent' if (q_len <= 64 and not self .training ) else self .mode
179178 if self .training :
180179 assert mode == 'chunk' , "Only chunk mode is supported in training."
181180
Original file line number Diff line number Diff line change @@ -172,7 +172,7 @@ def forward(
172172
173173 batch_size , q_len , _ = hidden_states .shape
174174 # change to inference mode.
175- mode = 'fused_recurrent' if q_len <= 64 and not self .training else self .mode
175+ mode = 'fused_recurrent' if ( q_len <= 64 and not self .training ) else self .mode
176176 if self .training :
177177 assert mode == 'chunk' , "Only chunk mode is supported in training."
178178
Original file line number Diff line number Diff line change @@ -437,7 +437,7 @@ def forward(
437437 if origin_cu_seqlens is not None :
438438 hidden_states , attention_mask = self .cu2pad (hidden_states , origin_cu_seqlens )
439439
440- mode = 'fused_recurrent' if hidden_states .shape [1 ] <= 64 else self .mode
440+ mode = 'fused_recurrent' if ( hidden_states .shape [1 ] <= 64 and not self . training ) else self .mode
441441 if self .training :
442442 assert mode == 'chunk' , "Only chunk mode is supported in training."
443443
You can’t perform that action at this time.
0 commit comments