1313# Written by Ze Liu
1414# --------------------------------------------------------
1515import math
16- from typing import Callable , Optional , Tuple , Union
16+ from typing import Callable , Optional , Tuple , Union , Set , Dict
1717
1818import torch
1919import torch .nn as nn
3232_int_or_tuple_2_t = Union [int , Tuple [int , int ]]
3333
3434
35- def window_partition (x , window_size : Tuple [int , int ]):
35+ def window_partition (x : torch . Tensor , window_size : Tuple [int , int ]) -> torch . Tensor :
3636 """
3737 Args:
3838 x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
4848
4949
5050@register_notrace_function # reason: int argument is a Proxy
51- def window_reverse (windows , window_size : Tuple [int , int ], img_size : Tuple [int , int ]):
51+ def window_reverse (windows : torch . Tensor , window_size : Tuple [int , int ], img_size : Tuple [int , int ]) -> torch . Tensor :
5252 """
5353 Args:
5454 windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
8181
8282 def __init__ (
8383 self ,
84- dim ,
85- window_size ,
86- num_heads ,
87- qkv_bias = True ,
88- attn_drop = 0. ,
89- proj_drop = 0. ,
90- pretrained_window_size = [ 0 , 0 ] ,
91- ):
84+ dim : int ,
85+ window_size : Tuple [ int , int ] ,
86+ num_heads : int ,
87+ qkv_bias : bool = True ,
88+ attn_drop : float = 0. ,
89+ proj_drop : float = 0. ,
90+ pretrained_window_size : Tuple [ int , int ] = ( 0 , 0 ) ,
91+ ) -> None :
9292 super ().__init__ ()
9393 self .dim = dim
9494 self .window_size = window_size # Wh, Ww
@@ -149,7 +149,7 @@ def __init__(
149149 self .proj_drop = nn .Dropout (proj_drop )
150150 self .softmax = nn .Softmax (dim = - 1 )
151151
152- def forward (self , x , mask : Optional [torch .Tensor ] = None ):
152+ def forward (self , x : torch . Tensor , mask : Optional [torch .Tensor ] = None ) -> torch . Tensor :
153153 """
154154 Args:
155155 x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
197197
198198 def __init__ (
199199 self ,
200- dim ,
201- input_resolution ,
202- num_heads ,
203- window_size = 7 ,
204- shift_size = 0 ,
205- mlp_ratio = 4. ,
206- qkv_bias = True ,
207- proj_drop = 0. ,
208- attn_drop = 0. ,
209- drop_path = 0. ,
210- act_layer = nn .GELU ,
211- norm_layer = nn .LayerNorm ,
212- pretrained_window_size = 0 ,
213- ):
200+ dim : int ,
201+ input_resolution : _int_or_tuple_2_t ,
202+ num_heads : int ,
203+ window_size : _int_or_tuple_2_t = 7 ,
204+ shift_size : _int_or_tuple_2_t = 0 ,
205+ mlp_ratio : float = 4. ,
206+ qkv_bias : bool = True ,
207+ proj_drop : float = 0. ,
208+ attn_drop : float = 0. ,
209+ drop_path : float = 0. ,
210+ act_layer : nn . Module = nn .GELU ,
211+ norm_layer : nn . Module = nn .LayerNorm ,
212+ pretrained_window_size : _int_or_tuple_2_t = 0 ,
213+ ) -> None :
214214 """
215215 Args:
216216 dim: Number of input channels.
@@ -282,14 +282,16 @@ def __init__(
282282
283283 self .register_buffer ("attn_mask" , attn_mask , persistent = False )
284284
285- def _calc_window_shift (self , target_window_size , target_shift_size ) -> Tuple [Tuple [int , int ], Tuple [int , int ]]:
285+ def _calc_window_shift (self ,
286+ target_window_size : _int_or_tuple_2_t ,
287+ target_shift_size : _int_or_tuple_2_t ) -> Tuple [Tuple [int , int ], Tuple [int , int ]]:
286288 target_window_size = to_2tuple (target_window_size )
287289 target_shift_size = to_2tuple (target_shift_size )
288290 window_size = [r if r <= w else w for r , w in zip (self .input_resolution , target_window_size )]
289291 shift_size = [0 if r <= w else s for r , w , s in zip (self .input_resolution , window_size , target_shift_size )]
290292 return tuple (window_size ), tuple (shift_size )
291293
292- def _attn (self , x ) :
294+ def _attn (self , x : torch . Tensor ) -> torch . Tensor :
293295 B , H , W , C = x .shape
294296
295297 # cyclic shift
@@ -317,7 +319,7 @@ def _attn(self, x):
317319 x = shifted_x
318320 return x
319321
320- def forward (self , x ) :
322+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
321323 B , H , W , C = x .shape
322324 x = x + self .drop_path1 (self .norm1 (self ._attn (x )))
323325 x = x .reshape (B , - 1 , C )
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
330332 """ Patch Merging Layer.
331333 """
332334
333- def __init__ (self , dim , out_dim = None , norm_layer = nn .LayerNorm ):
335+ def __init__ (self , dim : int , out_dim : Optional [ int ] = None , norm_layer : nn . Module = nn .LayerNorm ) -> None :
334336 """
335337 Args:
336338 dim (int): Number of input channels.
@@ -343,7 +345,7 @@ def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
343345 self .reduction = nn .Linear (4 * dim , self .out_dim , bias = False )
344346 self .norm = norm_layer (self .out_dim )
345347
346- def forward (self , x ) :
348+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
347349 B , H , W , C = x .shape
348350 _assert (H % 2 == 0 , f"x height ({ H } ) is not even." )
349351 _assert (W % 2 == 0 , f"x width ({ W } ) is not even." )
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
359361
360362 def __init__ (
361363 self ,
362- dim ,
363- out_dim ,
364- input_resolution ,
365- depth ,
366- num_heads ,
367- window_size ,
368- downsample = False ,
369- mlp_ratio = 4. ,
370- qkv_bias = True ,
371- proj_drop = 0. ,
372- attn_drop = 0. ,
373- drop_path = 0. ,
374- norm_layer = nn .LayerNorm ,
375- pretrained_window_size = 0 ,
376- output_nchw = False ,
377- ):
364+ dim : int ,
365+ out_dim : int ,
366+ input_resolution : _int_or_tuple_2_t ,
367+ depth : int ,
368+ num_heads : int ,
369+ window_size : _int_or_tuple_2_t ,
370+ downsample : bool = False ,
371+ mlp_ratio : float = 4. ,
372+ qkv_bias : bool = True ,
373+ proj_drop : float = 0. ,
374+ attn_drop : float = 0. ,
375+ drop_path : float = 0. ,
376+ norm_layer : nn . Module = nn .LayerNorm ,
377+ pretrained_window_size : _int_or_tuple_2_t = 0 ,
378+ output_nchw : bool = False ,
379+ ) -> None :
378380 """
379381 Args:
380382 dim: Number of input channels.
@@ -428,7 +430,7 @@ def __init__(
428430 )
429431 for i in range (depth )])
430432
431- def forward (self , x ) :
433+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
432434 x = self .downsample (x )
433435
434436 for blk in self .blocks :
@@ -438,7 +440,7 @@ def forward(self, x):
438440 x = blk (x )
439441 return x
440442
441- def _init_respostnorm (self ):
443+ def _init_respostnorm (self ) -> None :
442444 for blk in self .blocks :
443445 nn .init .constant_ (blk .norm1 .bias , 0 )
444446 nn .init .constant_ (blk .norm1 .weight , 0 )
0 commit comments