Skip to content

Commit ac935c7

Browse files
assumption of torch.initial_seed function accepting seed arg in DeepSpeedAccelerator abstract class is incorrect (#5569)
pytorch API reference - https://pytorch.org/docs/stable/generated/torch.initial_seed.html fix return value of manual_seed api for hpu --------- Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent b6e24ad commit ac935c7

7 files changed

+13
-13
lines changed

accelerator/abstract_accelerator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def manual_seed_all(self, seed):
8181
...
8282

8383
@abc.abstractmethod
84-
def initial_seed(self, seed):
84+
def initial_seed(self):
8585
...
8686

8787
@abc.abstractmethod

accelerator/cpu_accelerator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def manual_seed(self, seed):
100100
def manual_seed_all(self, seed):
101101
return torch.manual_seed(seed)
102102

103-
def initial_seed(self, seed):
104-
return torch.initial_seed(seed)
103+
def initial_seed(self):
104+
return torch.initial_seed()
105105

106106
def default_generator(self, device_index):
107107
return torch.default_generator

accelerator/cuda_accelerator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def manual_seed(self, seed):
9999
def manual_seed_all(self, seed):
100100
return torch.cuda.manual_seed_all(seed)
101101

102-
def initial_seed(self, seed):
103-
return torch.cuda.initial_seed(seed)
102+
def initial_seed(self):
103+
return torch.cuda.initial_seed()
104104

105105
def default_generator(self, device_index):
106106
return torch.cuda.default_generators[device_index]

accelerator/hpu_accelerator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ def get_rng_state(self, device_index=None):
7474
return self.hpu.random.get_rng_state()
7575

7676
def manual_seed(self, seed):
77-
self.hpu.random.manual_seed(seed)
77+
return self.hpu.random.manual_seed(seed)
7878

7979
def manual_seed_all(self, seed):
8080
self.hpu.random.manual_seed_all(seed)
8181

82-
def initial_seed(self, seed):
83-
self.hpu.random.initial_seed(seed)
82+
def initial_seed(self):
83+
return self.hpu.random.initial_seed()
8484

8585
def default_generator(self, device_index):
8686
return self.hpu.random.default_generators[device_index]

accelerator/mps_accelerator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def manual_seed_all(self, seed):
7777
def seed(self):
7878
return torch.mps.seed()
7979

80-
def initial_seed(self, seed):
80+
def initial_seed(self):
8181
return
8282

8383
def default_generator(self, device_index):

accelerator/npu_accelerator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def manual_seed(self, seed):
8484
def manual_seed_all(self, seed):
8585
return torch.npu.manual_seed_all(seed)
8686

87-
def initial_seed(self, seed):
88-
return torch.npu.initial_seed(seed)
87+
def initial_seed(self):
88+
return torch.npu.initial_seed()
8989

9090
def default_generator(self, device_index):
9191
return torch.npu.default_generators[device_index]

accelerator/xpu_accelerator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def manual_seed(self, seed):
7474
def manual_seed_all(self, seed):
7575
return torch.xpu.manual_seed_all(seed)
7676

77-
def initial_seed(self, seed):
78-
return torch.xpu.initial_seed(seed)
77+
def initial_seed(self):
78+
return torch.xpu.initial_seed()
7979

8080
def default_generator(self, device_index):
8181
return torch.xpu.default_generators[device_index]

0 commit comments

Comments
 (0)