Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions greykite/framework/templates/pickle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,21 @@ def dump_obj(
print(f"I Don't recognize type {type(obj)}")


def load_pickle(path, mode):
"""Loads the pickled files and closes the file handle.

Parameters
----------
path : `str`
The path to the pickled file.
mode : `str`
The mode of the open function.
"""
with open(path, mode) as file:
data = dill.load(file)
return data


def load_obj(
dir_name,
obj=None,
Expand Down Expand Up @@ -310,7 +325,7 @@ def load_obj(

# Gets the type files if any.
# Stores in a dictionary with key being the name and value being the loaded value.
obj_types = {file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb"))
obj_types = {file.split(".")[0]: load_pickle(os.path.join(dir_name, file), "rb")
for file in files if ".type" in file}

# Gets directories and pickled files.
Expand All @@ -328,7 +343,7 @@ def load_obj(
# The only 1 .pkl file case.
if len(files) > 1:
raise ValueError("Multiple elements found in top level.")
return dill.load(open(os.path.join(dir_name, files[0]), "rb"))
return load_pickle(os.path.join(dir_name, files[0]), "rb")
else:
# The .type + dir case.
if len(obj_types) > 1:
Expand All @@ -353,7 +368,7 @@ def load_obj(
for element in elements:
if ".pkl" in element:
result.append(
dill.load(open(os.path.join(dir_name, element), "rb")))
load_pickle(os.path.join(dir_name, element), "rb"))
else:
result.append(
load_obj(
Expand All @@ -373,8 +388,7 @@ def load_obj(
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(
open(os.path.join(dir_name, element), "rb"))
key = load_pickle(os.path.join(dir_name, element), "rb")
else:
key = load_obj(
os.path.join(dir_name, element),
Expand All @@ -392,8 +406,7 @@ def load_obj(
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(
open(os.path.join(dir_name, value_name), "rb"))
value = load_pickle(os.path.join(dir_name, value_name), "rb")
else:
value = load_obj(
os.path.join(dir_name, value_name),
Expand All @@ -413,8 +426,7 @@ def load_obj(
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(
open(os.path.join(dir_name, element), "rb"))
key = load_pickle(os.path.join(dir_name, element), "rb")
else:
key = load_obj(
os.path.join(dir_name, element),
Expand All @@ -432,8 +444,7 @@ def load_obj(
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(
open(os.path.join(dir_name, value_name), "rb"))
value = load_pickle(os.path.join(dir_name, value_name), "rb")
else:
value = load_obj(
os.path.join(dir_name, value_name),
Expand All @@ -453,8 +464,7 @@ def load_obj(
values = {}
for element in elements:
if ".pkl" in element:
values[element.split(".")[0]] = dill.load(
open(os.path.join(dir_name, element), "rb"))
values[element.split(".")[0]] = load_pickle(os.path.join(dir_name, element), "rb")
else:
values[element] = load_obj(
os.path.join(dir_name, element),
Expand Down