Commit 7fc96a9

set weights_only parameter of torch.load to False
- #48
1 parent: e17a9c0

16 files changed, 68 insertions(+), 29 deletions(-)
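
Context: as of PyTorch 2.6, torch.load defaults to weights_only=True, which restricts unpickling to tensor-only data. The checkpoints and processed .pt files loaded here also contain plain Python objects (dicts, lists of records), so every call site now passes weights_only=False explicitly. A minimal sketch of the pattern applied throughout this commit; the file path is hypothetical, not taken from the repository:

import torch

# Hypothetical path to a file written with torch.save; the checkpoints in
# this commit store a dict with a "state_dict" entry alongside other
# Python objects, so the full unpickler is required.
checkpoint_path = "example_checkpoint.ckpt"

with open(checkpoint_path, "rb") as fin:
    # weights_only=False restores the pre-2.6 behaviour and unpickles
    # arbitrary objects; only use it on files from trusted sources.
    model_dict = torch.load(fin, map_location="cpu", weights_only=False)

state_dict = model_dict["state_dict"]
print(f"{len(state_dict)} entries in state_dict")

Because weights_only=False runs the full pickle machinery on load, it should only be applied to files produced by this project or another trusted source.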

chebai/models/electra.py

Lines changed: 6 additions & 2 deletions

@@ -256,7 +256,9 @@ def __init__(
         # Load pretrained checkpoint if provided
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if load_prefix:
                     state_dict = filter_dict(model_dict["state_dict"], load_prefix)
                 else:

@@ -414,7 +416,9 @@ def __init__(self, cone_dimensions=20, **kwargs):
         model_prefix = kwargs.get("load_prefix", None)
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
+                model_dict = torch.load(
+                    fin, map_location=self.device, weights_only=False
+                )
                 if model_prefix:
                     state_dict = {
                         str(k)[len(model_prefix) :]: v

chebai/preprocessing/datasets/base.py

Lines changed: 10 additions & 4 deletions

@@ -200,7 +200,9 @@ def load_processed_data(
             filename = self.processed_file_names_dict[kind]
         except NotImplementedError:
             filename = f"{kind}.pt"
-        return torch.load(os.path.join(self.processed_dir, filename))
+        return torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
 
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """

@@ -519,7 +521,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader object for the specified subset.
         """
         subdatasets = [
-            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"))
+            torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False)
             for s in self.subsets
         ]
         dataset = [

@@ -1022,7 +1024,9 @@ def _retrieve_splits_from_csv(self) -> None:
         splits_df = pd.read_csv(self.splits_file_path)
 
         filename = self.processed_file_names_dict["data"]
-        data = torch.load(os.path.join(self.processed_dir, filename))
+        data = torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
         df_data = pd.DataFrame(data)
 
         train_ids = splits_df[splits_df["split"] == "train"]["id"]

@@ -1081,7 +1085,9 @@ def load_processed_data(
 
         # If filename is provided
         try:
-            return torch.load(os.path.join(self.processed_dir, filename))
+            return torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(f"File {filename} doesn't exist")

chebai/preprocessing/datasets/chebi.py

Lines changed: 5 additions & 2 deletions

@@ -407,7 +407,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
+            data_chebi_version = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exists. "

@@ -428,7 +430,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
             data_chebi_train_version = torch.load(
                 os.path.join(
                     self._chebi_version_train_obj.processed_dir, filename_train
-                )
+                ),
+                weights_only=False,
             )
         except FileNotFoundError:
             raise FileNotFoundError(

chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 3 additions & 1 deletion

@@ -508,7 +508,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         try:
             filename = self.processed_file_names_dict["data"]
-            data_go = torch.load(os.path.join(self.processed_dir, filename))
+            data_go = torch.load(
+                os.path.join(self.processed_dir, filename), weights_only=False
+            )
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"File data.pt doesn't exists. "

chebai/preprocessing/datasets/pubchem.py

Lines changed: 2 additions & 2 deletions

@@ -891,10 +891,10 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
             DataLoader: DataLoader instance.
         """
         labeled_data = torch.load(
-            os.path.join(self.labeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.labeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         unlabeled_data = torch.load(
-            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt")
+            os.path.join(self.unlabeled.processed_dir, f"{kind}.pt"), weights_only=False
         )
         if self.data_limit is not None:
             labeled_data = labeled_data[: self.data_limit]

chebai/preprocessing/migration/chebi_data_migration.py

Lines changed: 1 addition & 1 deletion

@@ -168,7 +168,7 @@ def _combine_pt_splits(
         df_list: List[pd.DataFrame] = []
         for split, file_name in old_splits_file_names.items():
             file_path = os.path.join(old_dir, file_name)
-            file_df = pd.DataFrame(torch.load(file_path))
+            file_df = pd.DataFrame(torch.load(file_path, weights_only=False))
             df_list.append(file_df)
 
         return pd.concat(df_list, ignore_index=True)

chebai/result/analyse_sem.py

Lines changed: 3 additions & 1 deletion

@@ -427,7 +427,9 @@ def run_all(
         os.path.join(buffer_dir_smoothed, "preds000.pt")
     ):
         preds = torch.load(
-            os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE
+            os.path.join(buffer_dir_smoothed, "preds000.pt"),
+            DEVICE,
+            weights_only=False,
         )
         labels = None
     else:

chebai/result/base.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def _generate_predictions(self, data_path, raw=False, **kwargs):
         else:
             data_tuples = [
                 (x.get("raw_features", x["ident"]), x["ident"], x)
-                for x in torch.load(data_path)
+                for x in torch.load(data_path, weights_only=False)
             ]
 
         for raw_features, ident, row in tqdm.tqdm(data_tuples):

chebai/result/pretraining.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def evaluate_model(logs_base_path, model_filename, data_module):
     collate = data_module.reader.COLLATOR()
     test_file = "test.pt"
     data_path = os.path.join(data_module.processed_dir, test_file)
-    data_list = torch.load(data_path)
+    data_list = torch.load(data_path, weights_only=False)
     preds_list = []
     labels_list = []

chebai/result/utils.py

Lines changed: 2 additions & 0 deletions

@@ -182,6 +182,7 @@ def load_results_from_buffer(
             torch.load(
                 os.path.join(buffer_dir, filename),
                 map_location=torch.device(device),
+                weights_only=False,
             )
         )
         i += 1

@@ -194,6 +195,7 @@ def load_results_from_buffer(
             torch.load(
                 os.path.join(buffer_dir, filename),
                 map_location=torch.device(device),
+                weights_only=False,
             )
         )
         i += 1
