Commit feb153a

[Layers] update mode assignment for GDN family layers
1 parent 1a40446 commit feb153a

5 files changed: +5 −7 lines changed


fla/layers/comba.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -222,10 +222,9 @@ def forward(
 
         batch_size, q_len, _ = hidden_states.shape
         # change to inference mode.
-        mode = 'fused_recurrent' if q_len <= 64 else self.mode
+        mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
         if self.training:
             assert mode == 'chunk', "Only chunk mode is supported in training."
-
         last_state = None
         if past_key_values is not None and len(past_key_values) > self.layer_idx:
             last_state = past_key_values[self.layer_idx]
```

fla/layers/gated_deltanet.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -217,7 +217,7 @@ def forward(
 
         batch_size, q_len, _ = hidden_states.shape
         # change to inference mode.
-        mode = 'fused_recurrent' if q_len <= 64 else self.mode
+        mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
         if self.training:
             assert mode == 'chunk', "Only chunk mode is supported in training."
 
```

fla/layers/gated_deltaproduct.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -174,8 +174,7 @@ def forward(
 
         batch_size, q_len, _ = hidden_states.shape
         # change to inference mode.
-        mode = 'fused_recurrent' if q_len <= 64 else self.mode
-
+        mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
         if self.training:
             assert mode == 'chunk', "Only chunk mode is supported in training."
 
```

fla/layers/kda.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -172,7 +172,7 @@ def forward(
 
         batch_size, q_len, _ = hidden_states.shape
         # change to inference mode.
-        mode = 'fused_recurrent' if q_len <= 64 and not self.training else self.mode
+        mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
         if self.training:
             assert mode == 'chunk', "Only chunk mode is supported in training."
 
```

fla/layers/mom.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -437,7 +437,7 @@ def forward(
         if origin_cu_seqlens is not None:
             hidden_states, attention_mask = self.cu2pad(hidden_states, origin_cu_seqlens)
 
-        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
+        mode = 'fused_recurrent' if (hidden_states.shape[1] <= 64 and not self.training) else self.mode
         if self.training:
             assert mode == 'chunk', "Only chunk mode is supported in training."
```

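For context, here is a minimal, self-contained sketch of the mode-selection guard this commit applies uniformly across the GDN-family layers. The `TinyGDNLayer` class and its `select_mode` helper are hypothetical names used only for illustration; only the `mode = ...` expression and the training-time assertion mirror the diff. With the old expression, a training step with `q_len <= 64` would select `'fused_recurrent'` and trip the chunk-only assertion; the added `not self.training` check keeps training in `'chunk'` mode and reserves the fused-recurrent fallback for inference.

```python
# Hypothetical sketch of the mode-selection pattern; not the actual fla layer code.

class TinyGDNLayer:
    def __init__(self, mode: str = 'chunk', training: bool = True):
        self.mode = mode
        self.training = training

    def select_mode(self, q_len: int) -> str:
        # Fall back to the fused recurrent kernel only for short sequences at
        # inference time; training always stays in 'chunk' mode.
        mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."
        return mode


layer = TinyGDNLayer(training=True)
assert layer.select_mode(q_len=32) == 'chunk'             # training: no fallback, assert passes

layer.training = False
assert layer.select_mode(q_len=32) == 'fused_recurrent'   # inference, short sequence
assert layer.select_mode(q_len=512) == 'chunk'            # inference, long sequence
```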