distributed_shampoo/utils: 3 files changed (+9, -15 lines)

Each of the three distributor files applies the same change: the redundant zero-argument @overload of _get_params_or_grads is removed, and the remaining Literal[False] overload gains a default value of False, so the no-argument call now resolves to it.
@@ -109,10 +109,6 @@ def _construct_composable_block_ids(
         """
         return (param_index, f"block_{block_index}")
 
-    @overload
-    @torch.no_grad()
-    def _get_params_or_grads(self) -> Iterable[Tensor]: ...
-
     @overload
     @torch.no_grad()
     def _get_params_or_grads(
@@ -121,7 +117,9 @@ def _get_params_or_grads(
 
     @overload
     @torch.no_grad()
-    def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
+    def _get_params_or_grads(
+        self, get_grad: Literal[False] = False
+    ) -> Iterable[Tensor]: ...
 
     @torch.no_grad()
     def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
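Why the zero-argument overload can be dropped: once the Literal[False] overload declares a default (get_grad: Literal[False] = False), a call with no arguments matches that overload and still types as Iterable[Tensor], making the separate zero-argument stub redundant. Below is a minimal self-contained sketch of the pattern; the Literal[True] overload and the toy implementation are plausible assumptions for illustration, not code shown in this diff.

# Minimal sketch of the overload pattern after this change. Names mirror the
# diff, but the Literal[True] overload and this toy implementation are
# assumptions for illustration, not code from this PR.
from collections.abc import Iterable
from typing import Literal, overload

import torch
from torch import Tensor


class ToyDistributor:
    def __init__(self, params: list[Tensor]) -> None:
        self._params = params

    @overload
    def _get_params_or_grads(
        self, get_grad: Literal[True]
    ) -> Iterable[Tensor | None]: ...

    @overload
    def _get_params_or_grads(
        self, get_grad: Literal[False] = False
    ) -> Iterable[Tensor]: ...

    @torch.no_grad()
    def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
        # Yield gradients (possibly None) when get_grad=True, else the params.
        return (p.grad for p in self._params) if get_grad else iter(self._params)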
@@ -29,10 +29,6 @@ class FullyShardDistributor(Distributor):
 
     """
 
-    @overload
-    @torch.no_grad()
-    def _get_params_or_grads(self) -> Iterable[Tensor]: ...
-
     @overload
     @torch.no_grad()
     def _get_params_or_grads(
@@ -41,7 +37,9 @@ def _get_params_or_grads(
 
     @overload
     @torch.no_grad()
-    def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
+    def _get_params_or_grads(
+        self, get_grad: Literal[False] = False
+    ) -> Iterable[Tensor]: ...
 
     @torch.no_grad()
     def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
@@ -203,10 +203,6 @@ def __init__(
             comms_group_rank=comms_group_rank,
         )
 
-    @overload
-    @torch.no_grad()
-    def _get_params_or_grads(self) -> Iterable[Tensor]: ...
-
     @overload
     @torch.no_grad()
     def _get_params_or_grads(
@@ -215,7 +211,9 @@ def _get_params_or_grads(
 
     @overload
     @torch.no_grad()
-    def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
+    def _get_params_or_grads(
+        self, get_grad: Literal[False] = False
+    ) -> Iterable[Tensor]: ...
 
     @torch.no_grad()
     def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
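For reference, this is how a type checker resolves calls against the sketch above (a hypothetical caller, not from this PR):

# Hypothetical caller showing overload resolution after the change:
t = ToyDistributor([torch.zeros(2, requires_grad=True)])
params = t._get_params_or_grads()              # Iterable[Tensor] (matches the
                                               # overload with the new default)
grads = t._get_params_or_grads(get_grad=True)  # Iterable[Tensor | None]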