Skip to content

Commit 77c22d7

Browse files
tsunghsienlee authored and facebook-github-bot committed
Improve the type hints of _get_params_or_grads() (#170)
Summary: Pull Request resolved: #170. Figure out how to type hint when there is a default value of `get_grad=False` in `_get_params_or_grads()`. Reviewed By: gajjanag. Differential Revision: D74560386. fbshipit-source-id: 691f4dd5a2fa4ed8875d1942801cdaa24afe669a
1 parent b7e3da9 commit 77c22d7

File tree

3 files changed

+9
-15
lines changed

3 files changed

+9
-15
lines changed

distributed_shampoo/utils/shampoo_distributor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,6 @@ def _construct_composable_block_ids(
109109
"""
110110
return (param_index, f"block_{block_index}")
111111

112-
@overload
113-
@torch.no_grad()
114-
def _get_params_or_grads(self) -> Iterable[Tensor]: ...
115-
116112
@overload
117113
@torch.no_grad()
118114
def _get_params_or_grads(
@@ -121,7 +117,9 @@ def _get_params_or_grads(
121117

122118
@overload
123119
@torch.no_grad()
124-
def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
120+
def _get_params_or_grads(
121+
self, get_grad: Literal[False] = False
122+
) -> Iterable[Tensor]: ...
125123

126124
@torch.no_grad()
127125
def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:

distributed_shampoo/utils/shampoo_fully_shard_distributor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ class FullyShardDistributor(Distributor):
2929
3030
"""
3131

32-
@overload
33-
@torch.no_grad()
34-
def _get_params_or_grads(self) -> Iterable[Tensor]: ...
35-
3632
@overload
3733
@torch.no_grad()
3834
def _get_params_or_grads(
@@ -41,7 +37,9 @@ def _get_params_or_grads(
4137

4238
@overload
4339
@torch.no_grad()
44-
def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
40+
def _get_params_or_grads(
41+
self, get_grad: Literal[False] = False
42+
) -> Iterable[Tensor]: ...
4543

4644
@torch.no_grad()
4745
def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:

distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,6 @@ def __init__(
203203
comms_group_rank=comms_group_rank,
204204
)
205205

206-
@overload
207-
@torch.no_grad()
208-
def _get_params_or_grads(self) -> Iterable[Tensor]: ...
209-
210206
@overload
211207
@torch.no_grad()
212208
def _get_params_or_grads(
@@ -215,7 +211,9 @@ def _get_params_or_grads(
215211

216212
@overload
217213
@torch.no_grad()
218-
def _get_params_or_grads(self, get_grad: Literal[False]) -> Iterable[Tensor]: ...
214+
def _get_params_or_grads(
215+
self, get_grad: Literal[False] = False
216+
) -> Iterable[Tensor]: ...
219217

220218
@torch.no_grad()
221219
def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:

0 commit comments

Comments (0)