This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit d0aa8b6

Clean up comments
1 parent 9aa46d2 commit d0aa8b6

File tree

1 file changed: +0 -6 lines changed


metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py

Lines changed: 0 additions & 6 deletions
@@ -197,7 +197,6 @@ def forward(
                 .transpose(0, 1)
                 .reshape(seq_len, bsz, num_heads * head_dim)
             )
-            # TODO: Reshape q/k/v back to original?
         else:
             q = q.view(seq_len, -1, head_dim)
             k = k.view(seq_len, -1, head_dim)
@@ -396,11 +395,6 @@ def backward(ctx, grad_output):
 
         # recalculate attention
         if xf_eff_attn:
-            # TODO: reshape q/k/v?
-            # q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
-            # k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
-            # v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
-
             num_heads = embed_dim_per_partition // head_dim
 
             attn, lse = xops.memory_efficient_attention_forward_requires_grad(
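For context on the TODOs being removed: the commented-out views sketched how q/k/v could be reshaped from the layer's [seq_len, bsz, embed_dim] layout into the [batch, seq, heads, head_dim] layout consumed by xformers' memory-efficient attention kernels. The snippet below is a minimal, hedged sketch of that pattern, not the metaseq implementation; the tensor sizes, device, and dtype are illustrative assumptions, while the xops.memory_efficient_attention_forward_requires_grad call is the one visible in the hunk above.

# Minimal sketch of the reshape pattern the removed comments refer to.
# Assumptions: toy sizes, fp16 tensors, and a CUDA device (the xformers
# kernels generally require a GPU). This is illustrative, not metaseq code.
import torch
import xformers.ops as xops

seq_len, bsz, num_heads, head_dim = 16, 2, 4, 32
embed_dim = num_heads * head_dim

# Projected q/k/v as they appear inside the layer: [seq_len, bsz, embed_dim].
q = torch.randn(seq_len, bsz, embed_dim, device="cuda", dtype=torch.float16)
k = torch.randn(seq_len, bsz, embed_dim, device="cuda", dtype=torch.float16)
v = torch.randn(seq_len, bsz, embed_dim, device="cuda", dtype=torch.float16)

# Reshape to [bsz, seq_len, num_heads, head_dim], mirroring the commented-out
# views removed in this commit.
q = q.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
k = k.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
v = v.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)

# Forward pass that also returns the log-sum-exp needed to redo the
# softmax in the backward pass.
attn, lse = xops.memory_efficient_attention_forward_requires_grad(q, k, v)

# Undo the reshape to get back to [seq_len, bsz, embed_dim].
out = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads * head_dim)

The forward_requires_grad variant returns the log-sum-exp alongside the attention output, which appears to be why the backward hunk above can recompute attention ("# recalculate attention") instead of storing it from the forward pass.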

0 commit comments