Skip to content

Commit e464d85

Browse files
committed
lint
1 parent 25860ce commit e464d85

File tree

2 files changed

+73
-77
lines changed

2 files changed

+73
-77
lines changed

src/fairseq2/models/transformer/_sdpa/_flex.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99
from typing import Callable, TypeAlias, final
1010

11-
import torch
1211
from torch import Tensor
1312
from torch.nn.attention.flex_attention import flex_attention
1413
from typing_extensions import override
Lines changed: 73 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -1,89 +1,91 @@
1+
from unittest.mock import Mock, patch
2+
13
import pytest
24
import torch
3-
from unittest.mock import Mock, patch
45

5-
from fairseq2.models.transformer._attention_bias import IdentityBias
66
from fairseq2.device import Device
7-
7+
from fairseq2.models.transformer._attention_bias import IdentityBias
88
from fairseq2.models.transformer._block_mask import (
9+
BlockMaskCache,
10+
BlockMaskCacheKey,
911
_causal_mask_fn,
10-
_sliding_window_causal_mask_fn,
11-
_offsets_to_doc_ids_tensor,
12+
_create_composed_mask,
1213
_create_packed_mask_fn,
1314
_create_padding_mask_fn,
14-
_create_composed_mask,
15-
BlockMaskCacheKey,
16-
BlockMaskCache,
15+
_offsets_to_doc_ids_tensor,
16+
_sliding_window_causal_mask_fn,
1717
)
1818

1919

2020
class TestMaskFunctions:
2121
"""Test individual mask functions."""
2222

def test_causal_mask_fn(self) -> None:
    """Probe the causal mask on, below and above the diagonal.

    A two-sequence batch is built, but only batch index 0 (lengths 3/3)
    and head index 0 are queried.
    """
    q_lens = torch.tensor([3, 2])
    kv_lens = torch.tensor([3, 2])
    mask_fn = _causal_mask_fn(q_lens, kv_lens)

    # Batch 0, head 0.
    b = torch.tensor(0)
    h = torch.tensor(0)

    # Arguments are (b, h, q_idx, kv_idx). Positions on or below the
    # diagonal (q_idx >= kv_idx) are expected to be visible.
    assert mask_fn(b, h, torch.tensor(0), torch.tensor(0))
    assert mask_fn(b, h, torch.tensor(1), torch.tensor(0))
    assert mask_fn(b, h, torch.tensor(1), torch.tensor(1))
    # A key ahead of the query (above the diagonal) must be masked out.
    assert not mask_fn(b, h, torch.tensor(0), torch.tensor(1))
    assert mask_fn(b, h, torch.tensor(2), torch.tensor(1))
3939

def test_sliding_window_causal_mask_fn(self) -> None:
    """Probe the sliding-window causal mask with a window of two positions.

    Uses a single sequence of length 4 and queries batch 0, head 0.
    """
    q_lens = torch.tensor([4])
    kv_lens = torch.tensor([4])
    window_size = 2
    mask_fn = _sliding_window_causal_mask_fn(window_size, q_lens, kv_lens)

    b = torch.tensor(0)
    h = torch.tensor(0)

    # Arguments are (b, h, q_idx, kv_idx).
    # One step back is inside the window.
    assert mask_fn(b, h, torch.tensor(2), torch.tensor(1))
    # The diagonal is always inside the window.
    assert mask_fn(b, h, torch.tensor(2), torch.tensor(2))
    # Two steps back falls outside a window of size 2.
    assert not mask_fn(b, h, torch.tensor(3), torch.tensor(1))
    # A future key is never visible.
    assert not mask_fn(b, h, torch.tensor(1), torch.tensor(2))
5557

def test_sliding_window_size_one(self) -> None:
    """Check that a window of size 1 leaves only the diagonal visible.

    Uses a single sequence of length 3 and queries batch 0, head 0.
    """
    q_lens = torch.tensor([3])
    kv_lens = torch.tensor([3])
    mask_fn = _sliding_window_causal_mask_fn(1, q_lens, kv_lens)

    b = torch.tensor(0)
    h = torch.tensor(0)

    # Arguments are (b, h, q_idx, kv_idx). Diagonal entries stay visible.
    assert mask_fn(b, h, torch.tensor(0), torch.tensor(0))
    assert mask_fn(b, h, torch.tensor(1), torch.tensor(1))
    # Anything off the diagonal, past or future, is masked out.
    assert not mask_fn(b, h, torch.tensor(1), torch.tensor(0))
    assert not mask_fn(b, h, torch.tensor(0), torch.tensor(1))
7072

def test_offsets_to_doc_ids_tensor(self) -> None:
    """Check that boundary offsets expand into one document ID per position."""
    # Offsets [0, 3, 5, 8] delimit three documents of lengths 3, 2 and 3,
    # so eight int32 IDs are expected in total.
    boundaries = torch.tensor([0, 3, 5, 8])
    want = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32)

    got = _offsets_to_doc_ids_tensor(boundaries)

    assert torch.equal(got, want)
7779

78-
def test_padding_mask_fn(self):
80+
def test_padding_mask_fn(self) -> None:
7981
"""Test padding mask function."""
8082
q_lens = torch.tensor([2, 3])
8183
kv_lens = torch.tensor([3, 2])
8284
mask_fn = _create_padding_mask_fn(q_lens, kv_lens)
83-
85+
8486
b = torch.tensor(0)
8587
h = torch.tensor(0)
86-
88+
8789
# Valid positions
8890
assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True
8991
assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == True
@@ -95,38 +97,38 @@ def test_padding_mask_fn(self):
9597
class TestPackedMaskFunction:
9698
"""Test packed sequence mask function."""
9799

def test_create_packed_mask_fn_basic(self) -> None:
    """Check that a packed mask only lets positions in the same document attend.

    Begin indices [0, 3, 5] describe two packed documents covering
    positions 0-2 and 3-4; no base mask is supplied.
    """
    seq_begin_indices = torch.tensor([0, 3, 5])
    keys_begin_indices = torch.tensor([0, 3, 5])

    mask_fn = _create_packed_mask_fn(seq_begin_indices, keys_begin_indices)

    b = torch.tensor(0)
    h = torch.tensor(0)

    # Arguments are (b, h, q_idx, kv_idx). Query and key in the same
    # document are visible.
    assert mask_fn(b, h, torch.tensor(0), torch.tensor(1))
    assert mask_fn(b, h, torch.tensor(3), torch.tensor(4))
    # Query and key in different documents are masked out.
    assert not mask_fn(b, h, torch.tensor(0), torch.tensor(3))
    assert not mask_fn(b, h, torch.tensor(1), torch.tensor(4))
114116

115-
def test_create_packed_mask_fn_with_base_mask(self):
117+
def test_create_packed_mask_fn_with_base_mask(self) -> None:
116118
"""Test packed mask with base causal mask."""
117119
seq_begin_indices = torch.tensor([0, 2, 4])
118120
keys_begin_indices = torch.tensor([0, 2, 4])
119121
q_lens = torch.tensor([2, 2])
120122
kv_lens = torch.tensor([2, 2])
121-
123+
122124
base_mask_fn = _causal_mask_fn(q_lens, kv_lens)
123125
mask_fn = _create_packed_mask_fn(
124126
seq_begin_indices, keys_begin_indices, base_mask_fn
125127
)
126-
128+
127129
b = torch.tensor(0)
128130
h = torch.tensor(0)
129-
131+
130132
# Same document, causal valid
131133
assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True
132134
# Same document, causal invalid
@@ -138,165 +140,160 @@ def test_create_packed_mask_fn_with_base_mask(self):
class TestBlockMaskCache:
    """Exercise ``BlockMaskCache`` key derivation and hit/miss behavior."""

    def test_cache_key_creation(self) -> None:
        """Derive a cache key from non-packed layouts and check its fields."""
        # Mocks stand in for BatchLayout; only the attributes configured
        # here are given concrete values.
        queries = Mock(packed=False, seq_lens=[3, 4, 2], max_seq_len=4)
        keys = Mock(packed=False, seq_lens=[3, 4, 2], max_seq_len=4)

        cache_key = BlockMaskCache()._create_cache_key(queries, keys)

        # Three sequences with a maximum length of 4 on both sides.
        assert cache_key.batch_size == 3
        assert cache_key.seqs_len == 4
        assert cache_key.keys_len == 4

    def test_cache_key_creation_packed(self) -> None:
        """Derive a cache key from packed layouts and check its fields."""
        # Packed stand-ins expose begin indices instead of per-sequence lengths.
        queries = Mock(packed=True, seq_begin_indices=[0, 3, 7])
        keys = Mock(packed=True, seq_begin_indices=[0, 3, 7])

        cache_key = BlockMaskCache()._create_cache_key(queries, keys)

        # Expected: a batch of one, with lengths matching the last begin
        # index (7) of the test data.
        assert cache_key.batch_size == 1
        assert cache_key.seqs_len == 7
        assert cache_key.keys_len == 7

    def test_cache_key_hash(self) -> None:
        """Check that equal keys hash and compare equal, and unequal keys do not."""
        base = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10)
        twin = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10)
        other = BlockMaskCacheKey(batch_size=3, seqs_len=10, keys_len=10)

        assert hash(base) == hash(twin)
        assert hash(base) != hash(other)
        assert base == twin
        assert base != other

    @patch("fairseq2.models.transformer._block_mask._create_composed_mask")
    def test_cache_hit_and_miss(self, mock_create_mask: Mock) -> None:
        """Check that a repeated lookup reuses the mask built on the first call."""
        built_mask = Mock()
        mock_create_mask.return_value = built_mask

        # Stand-in inputs; both layouts are non-packed with identical lengths.
        bias = Mock(spec=IdentityBias)
        device = Mock(spec=Device)
        queries = Mock(packed=False, seq_lens=[3, 4], max_seq_len=4)
        keys = Mock(packed=False, seq_lens=[3, 4], max_seq_len=4)

        cache = BlockMaskCache()

        # First lookup: nothing cached yet, so the patched builder runs once.
        first = cache.get_or_create_mask(bias, queries, keys, device)
        assert first == built_mask
        assert mock_create_mask.call_count == 1

        # Second lookup with the same inputs: the builder must not run again.
        second = cache.get_or_create_mask(bias, queries, keys, device)
        assert second == built_mask
        assert mock_create_mask.call_count == 1
220222

221-
def test_cache_clear(self):
222-
"""Test cache clearing."""
223-
cache = BlockMaskCache()
224-
cache._cache["test"] = "value"
225-
assert len(cache._cache) == 1
226-
227-
cache.clear()
228-
assert len(cache._cache) == 0
229-
230223

class TestCreateComposedMask:
    """Exercise ``_create_composed_mask`` with mocked layouts and biases."""

    @patch("fairseq2.models.transformer._block_mask.create_block_mask")
    def test_create_composed_mask_identity_bias(
        self, mock_create_block_mask: Mock
    ) -> None:
        """Identity bias with padded layouts should build exactly one block mask."""
        sentinel_mask = Mock()
        mock_create_block_mask.return_value = sentinel_mask

        # Mocks stand in for BatchLayout: padded, non-packed, lengths 3 and 4.
        queries = Mock(
            packed=False,
            padded=True,
            seq_lens=[3, 4],
            max_seq_len=4,
            seq_lens_pt=torch.tensor([3, 4]),
        )
        keys = Mock(
            packed=False,
            padded=True,
            seq_lens=[3, 4],
            max_seq_len=4,
            seq_lens_pt=torch.tensor([3, 4]),
        )
        bias = Mock(spec=IdentityBias)
        device = Mock(spec=Device)

        composed = _create_composed_mask(bias, queries, keys, device)

        # The mask is expected to be padding-only; this test verifies only
        # that the patched builder ran once and its result is passed through.
        mock_create_block_mask.assert_called_once()
        assert composed == sentinel_mask

    @patch("fairseq2.models.transformer._block_mask.create_block_mask")
    def test_create_composed_mask_no_masks_needed(
        self, mock_create_block_mask: Mock
    ) -> None:
        """Identity bias with unpadded, non-packed layouts should yield no mask."""
        # Layout stand-ins with neither packing nor padding.
        queries = Mock(packed=False, padded=False)
        keys = Mock(packed=False, padded=False)
        bias = Mock(spec=IdentityBias)
        device = Mock(spec=Device)

        composed = _create_composed_mask(bias, queries, keys, device)

        # Nothing to mask: expect None and no call to the block-mask builder.
        assert composed is None
        mock_create_block_mask.assert_not_called()

    def test_unsupported_bias_type(self) -> None:
        """A bias of an unrecognized type should make mask creation raise."""
        unknown_bias = Mock()  # No spec, so it matches no known bias class.
        queries = Mock(packed=False, padded=False)
        keys = Mock(packed=False, padded=False)
        device = Mock(spec=Device)

        with pytest.raises(Exception):  # Should raise NotSupportedError
            _create_composed_mask(unknown_bias, queries, keys, device)

0 commit comments

Comments
 (0)