
Commit 9aa46d2

formatting

1 parent 523f4e1

File tree

1 file changed (+6, −2 lines)


metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py

Lines changed: 6 additions & 2 deletions
@@ -195,7 +195,7 @@ def forward(
                     op=xf_op[0],
                 )
                 .transpose(0, 1)
-                .reshape(seq_len, bsz, num_heads*head_dim)
+                .reshape(seq_len, bsz, num_heads * head_dim)
             )
             # TODO: Reshape q/k/v back to original?
         else:
@@ -413,7 +413,11 @@ def backward(ctx, grad_output):
                 op=xf_op[0],
             )
             out = attn
-            attn = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads*head_dim).contiguous()
+            attn = (
+                attn.transpose(0, 1)
+                .reshape(seq_len, bsz, num_heads * head_dim)
+                .contiguous()
+            )
         else:
             attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
                 q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
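The reformatted chain, `attn.transpose(0, 1).reshape(seq_len, bsz, num_heads * head_dim)`, moves the attention output from batch-first to sequence-first layout and folds the per-head dimensions back into a single embedding dimension. The following is a minimal sketch of that head-merge pattern; the concrete shapes and the dummy tensor are illustrative assumptions, not taken from the metaseq code.

import torch

# Illustrative shapes (assumptions for this sketch, not from the commit)
bsz, seq_len, num_heads, head_dim = 2, 8, 4, 16

# Attention output in batch-first, per-head layout: (bsz, seq_len, num_heads, head_dim)
attn = torch.randn(bsz, seq_len, num_heads, head_dim)

# Same chain as in the diff: swap batch and sequence dims, then merge the
# head dims into one embedding dim of size num_heads * head_dim.
merged = (
    attn.transpose(0, 1)                          # (seq_len, bsz, num_heads, head_dim)
    .reshape(seq_len, bsz, num_heads * head_dim)  # (seq_len, bsz, embed_dim)
    .contiguous()                                 # force a compact memory layout
)

assert merged.shape == (seq_len, bsz, num_heads * head_dim)

Note that `.reshape` (rather than `.view`) is needed here because `.transpose` produces a non-contiguous tensor; the trailing `.contiguous()` in the backward-pass version of the chain additionally guarantees a compact memory layout for whatever kernel consumes `attn` next.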
