diff --git a/greykite/framework/templates/pickle_utils.py b/greykite/framework/templates/pickle_utils.py index 9b652cf..8fce40b 100644 --- a/greykite/framework/templates/pickle_utils.py +++ b/greykite/framework/templates/pickle_utils.py @@ -263,6 +263,21 @@ def dump_obj( print(f"I Don't recognize type {type(obj)}") +def load_pickle(path, mode): + """Loads the pickled files and closes the file handle. + + Parameters + ---------- + path : `str` + The path to the pickled file. + mode : `str` + The mode of the open function. + """ + with open(path, mode) as file: + data = dill.load(file) + return data + + def load_obj( dir_name, obj=None, @@ -310,7 +325,7 @@ def load_obj( # Gets the type files if any. # Stores in a dictionary with key being the name and value being the loaded value. - obj_types = {file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb")) + obj_types = {file.split(".")[0]: load_pickle(os.path.join(dir_name, file), "rb") for file in files if ".type" in file} # Gets directories and pickled files. @@ -328,7 +343,7 @@ def load_obj( # The only 1 .pkl file case. if len(files) > 1: raise ValueError("Multiple elements found in top level.") - return dill.load(open(os.path.join(dir_name, files[0]), "rb")) + return load_pickle(os.path.join(dir_name, files[0]), "rb") else: # The .type + dir case. if len(obj_types) > 1: @@ -353,7 +368,7 @@ def load_obj( for element in elements: if ".pkl" in element: result.append( - dill.load(open(os.path.join(dir_name, element), "rb"))) + load_pickle(os.path.join(dir_name, element), "rb")) else: result.append( load_obj( @@ -373,8 +388,7 @@ def load_obj( # Iterates through keys and finds the corresponding values. for element in keys: if ".pkl" in element: - key = dill.load( - open(os.path.join(dir_name, element), "rb")) + key = load_pickle(os.path.join(dir_name, element), "rb") else: key = load_obj( os.path.join(dir_name, element), @@ -392,8 +406,7 @@ def load_obj( value_name = value_name if value_name in values else value_name_alt # Gets the value. if ".pkl" in value_name: - value = dill.load( - open(os.path.join(dir_name, value_name), "rb")) + value = load_pickle(os.path.join(dir_name, value_name), "rb") else: value = load_obj( os.path.join(dir_name, value_name), @@ -413,8 +426,7 @@ def load_obj( # Iterates through keys and finds the corresponding values. for element in keys: if ".pkl" in element: - key = dill.load( - open(os.path.join(dir_name, element), "rb")) + key = load_pickle(os.path.join(dir_name, element), "rb") else: key = load_obj( os.path.join(dir_name, element), @@ -432,8 +444,7 @@ def load_obj( value_name = value_name if value_name in values else value_name_alt # Gets the value. if ".pkl" in value_name: - value = dill.load( - open(os.path.join(dir_name, value_name), "rb")) + value = load_pickle(os.path.join(dir_name, value_name), "rb") else: value = load_obj( os.path.join(dir_name, value_name), @@ -453,8 +464,7 @@ def load_obj( values = {} for element in elements: if ".pkl" in element: - values[element.split(".")[0]] = dill.load( - open(os.path.join(dir_name, element), "rb")) + values[element.split(".")[0]] = load_pickle(os.path.join(dir_name, element), "rb") else: values[element] = load_obj( os.path.join(dir_name, element),