Skip to content

Commit 1700f8d

Browse files
committed
add input size assertions. fix kda doc
1 parent feb153a commit 1700f8d

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

fla/ops/gla/chunk.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,5 +1316,9 @@ def chunk_gla(
13161316
)
13171317
if scale is None:
13181318
scale = q.shape[-1] ** -0.5
1319+
if initial_state is not None:
1320+
assert initial_state.dtype == torch.float32, "initial_state must be in float32."
1321+
assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
1322+
assert v.shape == (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
13191323
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
13201324
return o, final_state

fla/ops/kda/chunk.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ def chunk_kda(
334334
f"The number of initial states is expected to be equal to the number of input sequences, "
335335
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
336336
)
337+
if initial_state is not None:
338+
assert initial_state.dtype == torch.float32, "initial_state must be in float32."
339+
assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
340+
assert beta.shape == (q.shape[0], q.shape[1], q.shape[2]), "beta must be of shape (batch size, seq len, num of head)."
341+
assert v.shape == (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
337342
if scale is None:
338343
scale = k.shape[-1] ** -0.5
339344
o, final_state = ChunkKDAFunction.apply(

0 commit comments

Comments
 (0)