Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions greykite/framework/templates/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# original author: Albert Chen
# updated: McKenzie Quinn August 2021
"""Main entry point to create a forecast.
Generates a forecast from input data and config and stores the result.
"""
Expand All @@ -45,7 +46,8 @@
from greykite.framework.templates.simple_silverkite_template import SimpleSilverkiteTemplate
from greykite.framework.templates.template_interface import TemplateInterface
from greykite.sklearn.estimator.one_by_one_estimator import OneByOneEstimator

from greykite.framework.templates.gcp_utils import dump_obj_cloud
from greykite.framework.templates.gcp_utils import load_obj_cloud

class Forecaster:
"""The main entry point to creates a forecast.
Expand Down Expand Up @@ -362,6 +364,7 @@ def run_forecast_json(
def dump_forecast_result(
self,
destination_dir,
bucket_name=None,
object_name="object",
dump_design_info=True,
overwrite_exist_dir=False):
Expand All @@ -387,17 +390,29 @@ def dump_forecast_result(
"""
if self.forecast_result is None:
raise ValueError("self.forecast_result is None, nothing to dump.")
dump_obj(
obj=self.forecast_result,
dir_name=destination_dir,
obj_name=object_name,
dump_design_info=dump_design_info,
overwrite_exist_dir=overwrite_exist_dir
)
if not bucket_name:
dump_obj(
obj=self.forecast_result,
dir_name=destination_dir,
obj_name=object_name,
dump_design_info=dump_design_info,
overwrite_exist_dir=overwrite_exist_dir
)
else:
dump_obj_cloud(
obj=self.forecast_result,
dir_name=destination_dir,
bucket_name = bucket_name,
obj_name=object_name,
dump_design_info=dump_design_info,
overwrite_exist_dir=overwrite_exist_dir)



def load_forecast_result(
self,
source_dir,
bucket_name=None,
load_design_info=True):
"""Loads ``self.forecast_result`` from local files created by ``self.dump_result``.

Expand All @@ -412,8 +427,16 @@ def load_forecast_result(
"""
if self.forecast_result is not None:
raise ValueError("self.forecast_result is not None, please create a new instance.")
self.forecast_result = load_obj(
dir_name=source_dir,
obj=None,
load_design_info=load_design_info
)
if not bukcet_name:
self.forecast_result = load_obj(
dir_name=source_dir,
obj=None,
load_design_info=load_design_info
)
else:
self.forecast_result = load_obj_cloud(
dir_name=source_dir,
bucket_name = bucket_name,
obj=None,
load_design_info=load_design_info
)
Loading