1+ from unittest .mock import Mock , patch
2+
13import pytest
24import torch
3- from unittest .mock import Mock , patch
45
5- from fairseq2 .models .transformer ._attention_bias import IdentityBias
66from fairseq2 .device import Device
7-
7+ from fairseq2 . models . transformer . _attention_bias import IdentityBias
88from fairseq2 .models .transformer ._block_mask import (
9+ BlockMaskCache ,
10+ BlockMaskCacheKey ,
911 _causal_mask_fn ,
10- _sliding_window_causal_mask_fn ,
11- _offsets_to_doc_ids_tensor ,
12+ _create_composed_mask ,
1213 _create_packed_mask_fn ,
1314 _create_padding_mask_fn ,
14- _create_composed_mask ,
15- BlockMaskCacheKey ,
16- BlockMaskCache ,
15+ _offsets_to_doc_ids_tensor ,
16+ _sliding_window_causal_mask_fn ,
1717)
1818
1919
2020class TestMaskFunctions :
2121 """Test individual mask functions."""
2222
23- def test_causal_mask_fn (self ):
23+ def test_causal_mask_fn (self ) -> None :
2424 """Test causal mask function behavior."""
2525 q_lens = torch .tensor ([3 , 2 ])
2626 kv_lens = torch .tensor ([3 , 2 ])
2727 mask_fn = _causal_mask_fn (q_lens , kv_lens )
28-
28+
2929 # Test for batch 0
3030 b = torch .tensor (0 )
3131 h = torch .tensor (0 )
32-
32+
3333 # Test diagonal and upper triangular positions
3434 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (0 )) == True
3535 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (0 )) == True
3636 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (1 )) == True
3737 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (1 )) == False
3838 assert mask_fn (b , h , torch .tensor (2 ), torch .tensor (1 )) == True
3939
40- def test_sliding_window_causal_mask_fn (self ):
40+ def test_sliding_window_causal_mask_fn (self ) -> None :
4141 """Test sliding window causal mask function."""
4242 q_lens = torch .tensor ([4 ])
4343 kv_lens = torch .tensor ([4 ])
4444 window_size = 2
4545 mask_fn = _sliding_window_causal_mask_fn (window_size , q_lens , kv_lens )
46-
46+
4747 b = torch .tensor (0 )
4848 h = torch .tensor (0 )
49-
49+
5050 # Test window behavior
5151 assert mask_fn (b , h , torch .tensor (2 ), torch .tensor (1 )) == True # Within window
5252 assert mask_fn (b , h , torch .tensor (2 ), torch .tensor (2 )) == True # Diagonal
53- assert mask_fn (b , h , torch .tensor (3 ), torch .tensor (1 )) == False # Outside window
53+ assert (
54+ mask_fn (b , h , torch .tensor (3 ), torch .tensor (1 )) == False
55+ ) # Outside window
5456 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (2 )) == False # Future token
5557
56- def test_sliding_window_size_one (self ):
58+ def test_sliding_window_size_one (self ) -> None :
5759 """Test sliding window with size 1 (diagonal only)."""
5860 q_lens = torch .tensor ([3 ])
5961 kv_lens = torch .tensor ([3 ])
6062 mask_fn = _sliding_window_causal_mask_fn (1 , q_lens , kv_lens )
61-
63+
6264 b = torch .tensor (0 )
6365 h = torch .tensor (0 )
64-
66+
6567 # Only diagonal should be True
6668 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (0 )) == True
6769 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (1 )) == True
6870 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (0 )) == False
6971 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (1 )) == False
7072
71- def test_offsets_to_doc_ids_tensor (self ):
73+ def test_offsets_to_doc_ids_tensor (self ) -> None :
7274 """Test conversion of offsets to document IDs."""
7375 offsets = torch .tensor ([0 , 3 , 5 , 8 ])
7476 doc_ids = _offsets_to_doc_ids_tensor (offsets )
7577 expected = torch .tensor ([0 , 0 , 0 , 1 , 1 , 2 , 2 , 2 ], dtype = torch .int32 )
7678 assert torch .equal (doc_ids , expected )
7779
78- def test_padding_mask_fn (self ):
80+ def test_padding_mask_fn (self ) -> None :
7981 """Test padding mask function."""
8082 q_lens = torch .tensor ([2 , 3 ])
8183 kv_lens = torch .tensor ([3 , 2 ])
8284 mask_fn = _create_padding_mask_fn (q_lens , kv_lens )
83-
85+
8486 b = torch .tensor (0 )
8587 h = torch .tensor (0 )
86-
88+
8789 # Valid positions
8890 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (0 )) == True
8991 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (2 )) == True
@@ -95,38 +97,38 @@ def test_padding_mask_fn(self):
9597class TestPackedMaskFunction :
9698 """Test packed sequence mask function."""
9799
98- def test_create_packed_mask_fn_basic (self ):
100+ def test_create_packed_mask_fn_basic (self ) -> None :
99101 """Test basic packed mask functionality."""
100102 seq_begin_indices = torch .tensor ([0 , 3 , 5 ])
101103 keys_begin_indices = torch .tensor ([0 , 3 , 5 ])
102-
104+
103105 mask_fn = _create_packed_mask_fn (seq_begin_indices , keys_begin_indices )
104-
106+
105107 b = torch .tensor (0 )
106108 h = torch .tensor (0 )
107-
109+
108110 # Same document
109111 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (1 )) == True
110112 assert mask_fn (b , h , torch .tensor (3 ), torch .tensor (4 )) == True
111113 # Different documents
112114 assert mask_fn (b , h , torch .tensor (0 ), torch .tensor (3 )) == False
113115 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (4 )) == False
114116
115- def test_create_packed_mask_fn_with_base_mask (self ):
117+ def test_create_packed_mask_fn_with_base_mask (self ) -> None :
116118 """Test packed mask with base causal mask."""
117119 seq_begin_indices = torch .tensor ([0 , 2 , 4 ])
118120 keys_begin_indices = torch .tensor ([0 , 2 , 4 ])
119121 q_lens = torch .tensor ([2 , 2 ])
120122 kv_lens = torch .tensor ([2 , 2 ])
121-
123+
122124 base_mask_fn = _causal_mask_fn (q_lens , kv_lens )
123125 mask_fn = _create_packed_mask_fn (
124126 seq_begin_indices , keys_begin_indices , base_mask_fn
125127 )
126-
128+
127129 b = torch .tensor (0 )
128130 h = torch .tensor (0 )
129-
131+
130132 # Same document, causal valid
131133 assert mask_fn (b , h , torch .tensor (1 ), torch .tensor (0 )) == True
132134 # Same document, causal invalid
@@ -138,165 +140,160 @@ def test_create_packed_mask_fn_with_base_mask(self):
138140class TestBlockMaskCache :
139141 """Test block mask caching functionality."""
140142
141- def test_cache_key_creation (self ):
143+ def test_cache_key_creation (self ) -> None :
142144 """Test cache key creation for different layouts."""
143145 cache = BlockMaskCache ()
144-
146+
145147 # Mock BatchLayout for non-packed sequences
146148 seqs_layout = Mock ()
147149 seqs_layout .packed = False
148150 seqs_layout .seq_lens = [3 , 4 , 2 ]
149151 seqs_layout .max_seq_len = 4
150-
152+
151153 keys_layout = Mock ()
152154 keys_layout .packed = False
153155 keys_layout .seq_lens = [3 , 4 , 2 ]
154156 keys_layout .max_seq_len = 4
155-
157+
156158 key = cache ._create_cache_key (seqs_layout , keys_layout )
157159 assert key .batch_size == 3
158160 assert key .seqs_len == 4
159161 assert key .keys_len == 4
160162
161- def test_cache_key_creation_packed (self ):
163+ def test_cache_key_creation_packed (self ) -> None :
162164 """Test cache key creation for packed sequences."""
163165 cache = BlockMaskCache ()
164-
166+
165167 # Mock BatchLayout for packed sequences
166168 seqs_layout = Mock ()
167169 seqs_layout .packed = True
168170 seqs_layout .seq_begin_indices = [0 , 3 , 7 ]
169-
171+
170172 keys_layout = Mock ()
171173 keys_layout .packed = True
172174 keys_layout .seq_begin_indices = [0 , 3 , 7 ]
173-
175+
174176 key = cache ._create_cache_key (seqs_layout , keys_layout )
175177 assert key .batch_size == 1
176178 assert key .seqs_len == 7
177179 assert key .keys_len == 7
178180
179- def test_cache_key_hash (self ):
181+ def test_cache_key_hash (self ) -> None :
180182 """Test that cache keys are hashable."""
181183 key1 = BlockMaskCacheKey (batch_size = 2 , seqs_len = 10 , keys_len = 10 )
182184 key2 = BlockMaskCacheKey (batch_size = 2 , seqs_len = 10 , keys_len = 10 )
183185 key3 = BlockMaskCacheKey (batch_size = 3 , seqs_len = 10 , keys_len = 10 )
184-
186+
185187 assert hash (key1 ) == hash (key2 )
186188 assert hash (key1 ) != hash (key3 )
187189 assert key1 == key2
188190 assert key1 != key3
189191
190- @patch (' fairseq2.models.transformer._block_mask._create_composed_mask' )
191- def test_cache_hit_and_miss (self , mock_create_mask ) :
192+ @patch (" fairseq2.models.transformer._block_mask._create_composed_mask" )
193+ def test_cache_hit_and_miss (self , mock_create_mask : Mock ) -> None :
192194 """Test cache hit and miss behavior."""
193195 cache = BlockMaskCache ()
194196 mock_mask = Mock ()
195197 mock_create_mask .return_value = mock_mask
196-
198+
197199 # Mock inputs
198200 bias = Mock (spec = IdentityBias )
199201 seqs_layout = Mock ()
200202 seqs_layout .packed = False
201203 seqs_layout .seq_lens = [3 , 4 ]
202204 seqs_layout .max_seq_len = 4
203-
205+
204206 keys_layout = Mock ()
205207 keys_layout .packed = False
206208 keys_layout .seq_lens = [3 , 4 ]
207209 keys_layout .max_seq_len = 4
208-
210+
209211 device = Mock (spec = Device )
210-
212+
211213 # First call - cache miss
212214 result1 = cache .get_or_create_mask (bias , seqs_layout , keys_layout , device )
213215 assert result1 == mock_mask
214216 assert mock_create_mask .call_count == 1
215-
217+
216218 # Second call - cache hit
217219 result2 = cache .get_or_create_mask (bias , seqs_layout , keys_layout , device )
218220 assert result2 == mock_mask
219221 assert mock_create_mask .call_count == 1 # Should not increase
220222
    def test_cache_clear(self) -> None:
        """Test cache clearing."""
        cache = BlockMaskCache()
        # Seed the private cache directly so clear() has an entry to remove.
        cache._cache["test"] = "value"
        assert len(cache._cache) == 1

        # After clear() the backing dict must be empty.
        cache.clear()
        assert len(cache._cache) == 0
229-
230223
231224class TestCreateComposedMask :
232225 """Test the main composed mask creation function."""
233226
234- @patch ('fairseq2.models.transformer._block_mask.create_block_mask' )
235- def test_create_composed_mask_identity_bias (self , mock_create_block_mask ):
227+ @patch ("fairseq2.models.transformer._block_mask.create_block_mask" )
228+ def test_create_composed_mask_identity_bias (
229+ self , mock_create_block_mask : Mock
230+ ) -> None :
236231 """Test composed mask creation with identity bias."""
237232 mock_block_mask = Mock ()
238233 mock_create_block_mask .return_value = mock_block_mask
239-
234+
240235 bias = Mock (spec = IdentityBias )
241-
236+
242237 # Mock BatchLayout
243238 seqs_layout = Mock ()
244239 seqs_layout .packed = False
245240 seqs_layout .padded = True
246241 seqs_layout .seq_lens = [3 , 4 ]
247242 seqs_layout .max_seq_len = 4
248243 seqs_layout .seq_lens_pt = torch .tensor ([3 , 4 ])
249-
244+
250245 keys_layout = Mock ()
251246 keys_layout .packed = False
252247 keys_layout .padded = True
253248 keys_layout .seq_lens = [3 , 4 ]
254249 keys_layout .max_seq_len = 4
255250 keys_layout .seq_lens_pt = torch .tensor ([3 , 4 ])
256-
251+
257252 device = Mock (spec = Device )
258-
253+
259254 result = _create_composed_mask (bias , seqs_layout , keys_layout , device )
260-
255+
261256 # Should create block mask with padding mask only
262257 mock_create_block_mask .assert_called_once ()
263258 assert result == mock_block_mask
264259
265- @patch ('fairseq2.models.transformer._block_mask.create_block_mask' )
266- def test_create_composed_mask_no_masks_needed (self , mock_create_block_mask ):
260+ @patch ("fairseq2.models.transformer._block_mask.create_block_mask" )
261+ def test_create_composed_mask_no_masks_needed (
262+ self , mock_create_block_mask : Mock
263+ ) -> None :
267264 """Test when no masks are needed."""
268265 bias = Mock (spec = IdentityBias )
269-
266+
270267 # Mock BatchLayout with no padding
271268 seqs_layout = Mock ()
272269 seqs_layout .packed = False
273270 seqs_layout .padded = False
274-
271+
275272 keys_layout = Mock ()
276273 keys_layout .packed = False
277274 keys_layout .padded = False
278-
275+
279276 device = Mock (spec = Device )
280-
277+
281278 result = _create_composed_mask (bias , seqs_layout , keys_layout , device )
282-
279+
283280 # Should return None when no masks are needed
284281 assert result is None
285282 mock_create_block_mask .assert_not_called ()
286283
287- def test_unsupported_bias_type (self ):
284+ def test_unsupported_bias_type (self ) -> None :
288285 """Test that unsupported bias types raise an error."""
289286 bias = Mock () # Unknown bias type
290-
287+
291288 seqs_layout = Mock ()
292289 seqs_layout .packed = False
293290 seqs_layout .padded = False
294-
291+
295292 keys_layout = Mock ()
296293 keys_layout .packed = False
297294 keys_layout .padded = False
298-
295+
299296 device = Mock (spec = Device )
300-
297+
301298 with pytest .raises (Exception ): # Should raise NotSupportedError
302299 _create_composed_mask (bias , seqs_layout , keys_layout , device )
0 commit comments