Skip to content

Commit e1836d5

Browse files
authored
More robust detection of time series granularity. (#135)
* More robust detection of time series granularity. Previously, we would detect the granularity of a time series as the GCD of all timedeltas found in the time series (assuming pandas couldn't infer the granularity on its own). However, this behavior fails for time series with missing data that are sampled at granularities that aren't an exact number of seconds, e.g. monthly time series would be resampled to a daily granularity because months are of inconsistent length. This commit uses the most commonly observed timedelta, and it also checks whether a k-month granularity is a better fit for the time series than an n-day granularity.
* Simplify code.
* Introduce more flexible granularities to models.
* Retain old behavior when offsets aren't used.
* More careful handling of resampling w/ offsets.
* Fix off-by-one error.
* Add test coverage for resampling.
* Fix typo.
* Update version matrix on old docs pages.
1 parent f34583c commit e1836d5

File tree

17 files changed

+289
-173
lines changed

17 files changed

+289
-173
lines changed

.github/workflows/tests.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ jobs:
3838
env:
3939
PYTHON_VERSION: ${{ matrix.python-version }}
4040
with:
41-
max_attempts: 3
41+
max_attempts: 1
4242
timeout_minutes: 60
43-
retry-on: error
43+
retry_on: error
4444
command: |
4545
set -euxo pipefail
4646
# Get a comma-separated list of the directories of all python source files

benchmark_forecast.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from merlion.models.forecast.base import ForecasterBase
2828
from merlion.transform.resample import TemporalResample, granularity_str_to_seconds
2929
from merlion.utils import TimeSeries, UnivariateTimeSeries
30-
from merlion.utils.resample import get_gcd_timedelta
30+
from merlion.utils.resample import infer_granularity, to_pd_datetime
3131

3232
from ts_datasets.base import BaseDataset
3333
from ts_datasets.forecast import *
@@ -265,12 +265,7 @@ def train_model(
265265
df = df.resample(dt, closed="right", label="right").mean().interpolate()
266266

267267
vals = TimeSeries.from_pd(df)
268-
# Get time-delta
269-
if not is_multivariate_data:
270-
dt = df.index[1] - df.index[0]
271-
else:
272-
dt = get_gcd_timedelta(vals.time_stamps)
273-
dt = pd.to_timedelta(dt, unit="s")
268+
dt = infer_granularity(vals.time_stamps)
274269

275270
# Get the train/val split
276271
t = trainval.index[np.argmax(~trainval)].value // 1e9
@@ -304,7 +299,11 @@ def train_model(
304299
# loop over horizon conditions
305300
for horizon in horizons:
306301
horizon = granularity_str_to_seconds(horizon)
307-
max_forecast_steps = math.ceil(horizon / dt.total_seconds())
302+
try:
303+
max_forecast_steps = int(math.ceil(horizon / dt.total_seconds()))
304+
except:
305+
window = TimeSeries.from_pd(test_vals.to_pd()[: to_pd_datetime(train_end_timestamp + horizon)])
306+
max_forecast_steps = len(TemporalResample(granularity=dt)(window))
308307
logger.debug(f"horizon is {pd.Timedelta(seconds=horizon)} and max_forecast_steps is {max_forecast_steps}")
309308
if retrain_type == "without_retrain":
310309
retrain_freq = None

docs/build_docs.sh

+11-15
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,6 @@ set -euo pipefail
55
DIRNAME=$(cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
66
cd "${DIRNAME}/.."
77

8-
# Set up virtual environment
9-
pip3 install --upgrade pip setuptools wheel virtualenv
10-
if [ ! -d venv ]; then
11-
rm -f venv
12-
virtualenv venv
13-
fi
14-
source venv/bin/activate
15-
168
# Get current git head & stash unsaved changes
179
GIT_HEAD=$(git rev-parse HEAD)
1810
GIT_BRANCH=$(git branch --show-current)
@@ -37,19 +29,23 @@ function exit_handler {
3729
}
3830
trap exit_handler EXIT
3931

40-
# Install Sphinx requirements. Get old Merlion docs from gh-pages branch, but only keep the version-tagged ones.
32+
# Set up virtual environment & install Sphinx requirements.
33+
pip3 install --upgrade pip setuptools wheel virtualenv
34+
if [ ! -d venv ]; then
35+
rm -f venv
36+
virtualenv venv
37+
fi
38+
source venv/bin/activate
4139
pip3 install -r "${DIRNAME}/requirements.txt"
40+
41+
# Get old Merlion docs from gh-pages branch. Only keep version-tagged ones, and update the version matrix as needed.
4242
git checkout gh-pages && git pull && git checkout --force "${GIT_HEAD}"
4343
sphinx-build -M clean "${DIRNAME}/source" "${DIRNAME}/build"
4444
mkdir -p "${DIRNAME}/build" "${DIRNAME}/build/html"
4545
git --work-tree "${DIRNAME}/build/html" checkout gh-pages . && git reset --hard
46-
python -c \
47-
"import re; import os; import shutil;
48-
for f in [os.path.join('${DIRNAME}/build/html', f) for f in os.listdir('${DIRNAME}/build/html')]:
49-
if not (os.path.isdir(f) and re.search('v([0-9].)+[0-9]$', f)):
50-
shutil.rmtree(f) if os.path.isdir(f) else os.remove(f)"
46+
python3 "${DIRNAME}/process_old_docs.py"
5147

52-
# Install all released versions of Merlion/ts_datasets and use them to build the appropriate API docs.
48+
# Install all released versions of Merlion/ts_datasets _not_ on gh-pages and use them to build the appropriate API docs.
5349
# Uninstall after we're done with each one.
5450
versions=("latest")
5551
for v in $(git tag --list 'v[0-9]*'); do

docs/process_old_docs.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#
2+
# Copyright (c) 2022 salesforce.com, inc.
3+
# All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
"""
8+
Script which removes redirects from the HTML API docs & updates the version matrix on old files.
9+
"""
10+
import os
11+
import re
12+
import shutil
13+
14+
from bs4 import BeautifulSoup as bs
15+
from git import Repo
16+
17+
18+
def create_version_dl(soup, prefix, current_version, all_versions):
    """
    Build a ``<dl>`` element listing every documented version as a link.

    :param soup: the parsed document, used as the factory for all new tags.
    :param prefix: relative path prepended to each ``<version>/index.html`` link target.
    :param current_version: the version whose link is wrapped in ``<strong>``.
    :param all_versions: version names, emitted as ``<dd>`` entries in the given order.

    :return: the new ``<dl>`` tag, headed by a ``<dt>`` reading "Versions".
    """
    header = soup.new_tag("dt")
    header.string = "Versions"
    version_list = soup.new_tag("dl")
    version_list.append(header)

    for name in all_versions:
        # Link to the index page of this version's docs
        link = soup.new_tag("a", href=f"{prefix}/{name}/index.html")
        link.string = name

        # One <dd> per version; the current version's link is emphasized
        entry = soup.new_tag("dd")
        if name == current_version:
            bold = soup.new_tag("strong")
            bold.append(link)
            entry.append(bold)
        else:
            entry.append(link)
        version_list.append(entry)

    return version_list
36+
37+
38+
def main():
    """
    Prune the ``build/html`` tree next to this script and refresh the version matrix in old docs.

    Every entry of ``build/html`` whose name is not a ``v<digit>...`` git tag of the enclosing repo
    (including ``latest``) is deleted. In each remaining version directory, every ``.html`` file
    that contains a ``<dl>`` with a ``<dt>`` reading "Versions" has the first such list replaced by
    a freshly generated one covering ``latest`` plus all release tags; files without one are left
    untouched.
    """
    # Get all the versions
    repo = Repo(search_parent_directories=True)
    # NOTE(review): this is a plain string sort, so e.g. "v1.9.0" is listed before "v1.10.0".
    tagged = sorted((tag.name for tag in repo.tags if re.match(r"v[0-9]", tag.name)), reverse=True)
    versions = ["latest", *tagged]

    dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), "build", "html")
    for version in os.listdir(dirname):
        # If this isn't a directory containing a numbered version's API docs, delete it.
        # "latest" is never a tag name, so it is removed here as well.
        version_root = os.path.join(dirname, version)
        if version not in tagged:
            if os.path.isdir(version_root):
                shutil.rmtree(version_root)
            else:
                os.remove(version_root)
            continue

        # Update version matrix in HTML source versioned files
        for subdir, _, files in os.walk(version_root):
            # One ".." to step out of the version directory, plus one more per level of nesting below it
            rel = os.path.relpath(subdir, version_root)
            depth = 0 if rel == os.curdir else len(rel.split(os.sep))
            prefix = "/".join([".."] * (depth + 1))

            # Create the new description list for the version & write the new file
            for name in files:
                if not name.endswith(".html"):
                    continue
                path = os.path.join(subdir, name)
                # Read with the same encoding we write with, independent of the platform's locale
                with open(path, encoding="utf-8") as f:
                    soup = bs(f, "html.parser")
                version_dl = next((dl for dl in soup.find_all("dl") if dl.find("dt", string="Versions")), None)
                if version_dl is None:
                    continue
                version_dl.replace_with(create_version_dl(soup, prefix, version, versions))
                with open(path, "w", encoding="utf-8") as f:
                    f.write(str(soup))


if __name__ == "__main__":
    main()

docs/requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
GitPython
2+
beautifulsoup4
23
ipykernel
3-
nbsphinx==0.8.7
4+
nbsphinx
45
pandoc
56
sphinx
67
sphinx_autodoc_typehints

docs/source/conf.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@
3636
# Add any Sphinx extension module names here, as strings. They can be
3737
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
3838
# ones.
39-
extensions = ["nbsphinx", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx_autodoc_typehints"]
39+
extensions = [
40+
"nbsphinx",
41+
"IPython.sphinxext.ipython_console_highlighting",
42+
"sphinx.ext.autodoc",
43+
"sphinx.ext.autosummary",
44+
"sphinx_autodoc_typehints",
45+
]
4046

4147
autoclass_content = "both" # include both class docstring and __init__
4248
autodoc_default_options = {
@@ -91,3 +97,4 @@
9197
exclude_patterns = ["examples"]
9298
else:
9399
exclude_patterns = ["tutorials"]
100+
exclude_patterns += ["**.ipynb_checkpoints"]

merlion/models/base.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from merlion.transform.factory import TransformFactory
2525
from merlion.transform.normalize import Rescale, MeanVarNormalize
2626
from merlion.transform.sequence import TransformSequence
27-
from merlion.utils.time_series import assert_equal_timedeltas, to_pd_datetime, TimeSeries
27+
from merlion.utils.time_series import assert_equal_timedeltas, to_pd_datetime, infer_granularity, TimeSeries
2828
from merlion.utils.misc import AutodocABCMeta, ModelConfigMeta
2929

3030
logger = logging.getLogger(__name__)
@@ -169,6 +169,7 @@ def __init__(self, config: Config):
169169
self.config = copy.copy(config)
170170
self.last_train_time = None
171171
self.timedelta = None
172+
self.timedelta_offset = pd.to_timedelta(0)
172173
self.train_data = None
173174

174175
def reset(self):
@@ -304,12 +305,10 @@ def train_pre_process(self, train_data: TimeSeries) -> TimeSeries:
304305

305306
# Make sure timestamps are equally spaced if needed (e.g. for ARIMA)
306307
t = train_data.time_stamps
308+
self.timedelta, self.timedelta_offset = infer_granularity(t, return_offset=True)
307309
if self.require_even_sampling:
308-
assert_equal_timedeltas(train_data.univariates[train_data.names[0]])
310+
assert_equal_timedeltas(train_data.univariates[train_data.names[0]], self.timedelta, self.timedelta_offset)
309311
assert train_data.is_aligned
310-
self.timedelta = pd.infer_freq(to_pd_datetime(t))
311-
else:
312-
self.timedelta = t[1] - t[0]
313312
self.last_train_time = t[-1]
314313
return train_data.align() if self.auto_align else train_data
315314

merlion/models/forecast/base.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_p
114114
)
115115

116116
# Determine timedelta & initial time of forecast
117-
dt = self.timedelta
117+
dt, offset = self.timedelta, self.timedelta_offset
118118
if time_series_prev is not None and not time_series_prev.is_empty():
119119
t0 = to_pd_datetime(time_series_prev.tf)
120120
else:
@@ -124,34 +124,33 @@ def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_p
124124
if isinstance(time_stamps, (int, float)):
125125
n = int(time_stamps)
126126
assert self.max_forecast_steps is None or n <= self.max_forecast_steps
127-
resampled = pd.date_range(start=t0, periods=n + 1, freq=dt)[1:]
128-
tf = resampled[-1]
127+
resampled = pd.date_range(start=t0, periods=n + 1, freq=dt) + offset
128+
resampled = resampled[1:] if resampled[0] == t0 else resampled[:-1]
129129
time_stamps = to_timestamp(resampled)
130130

131131
elif not self.require_even_sampling:
132132
resampled = to_pd_datetime(time_stamps)
133-
tf = resampled[-1]
134133

135134
# Handle the cases where we don't have a max_forecast_steps
136135
elif self.max_forecast_steps is None:
137136
tf = to_pd_datetime(time_stamps[-1])
138-
resampled = pd.date_range(start=t0, end=tf, freq=dt)[1:]
139-
if resampled[-1] < tf:
140-
extra = pd.date_range(start=resampled[-1], periods=2, freq=dt)[1:]
141-
resampled = resampled.union(extra)
137+
resampled = pd.date_range(start=t0, end=tf + 2 * dt, freq=dt) + offset
138+
if resampled[0] == t0:
139+
resampled = resampled[1:]
140+
if len(resampled) > 1 and resampled[-2] >= tf:
141+
resampled = resampled[:-1]
142142

143143
# Handle the case where we do have a max_forecast_steps
144144
else:
145-
resampled = pd.date_range(start=t0, periods=self.max_forecast_steps + 1, freq=dt)[1:]
146-
tf = resampled[-1]
147-
n = sum(t < to_pd_datetime(time_stamps[-1]) for t in resampled)
148-
resampled = resampled[: n + 1]
145+
resampled = pd.date_range(start=t0, periods=self.max_forecast_steps + 1, freq=dt) + offset
146+
resampled = resampled[1:] if resampled[0] == t0 else resampled[:-1]
147+
resampled = resampled[: 1 + sum(resampled < to_pd_datetime(time_stamps[-1]))]
149148

149+
tf = resampled[-1]
150150
assert to_pd_datetime(time_stamps[0]) >= t0 and to_pd_datetime(time_stamps[-1]) <= tf, (
151151
f"Expected `time_stamps` to be between {t0} and {tf}, but `time_stamps` ranges "
152152
f"from {to_pd_datetime(time_stamps[0])} to {to_pd_datetime(time_stamps[-1])}"
153153
)
154-
155154
return to_timestamp(resampled).tolist()
156155

157156
def train_pre_process(

merlion/models/forecast/ets.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(self, config: ETSConfig):
8181
super().__init__(config)
8282
self.model = None
8383
self._last_val = None
84+
self._n_train = None
8485

8586
@property
8687
def require_even_sampling(self) -> bool:
@@ -136,12 +137,13 @@ def _train(self, train_data: pd.DataFrame, train_config=None):
136137
name = self.target_name
137138
train_data = train_data[name]
138139
times = train_data.index
139-
self.model = self._instantiate_model(train_data).fit(disp=False)
140+
self.model = self._instantiate_model(pd.Series(train_data.values)).fit(disp=False)
140141

141142
# get forecast for the training data
142143
self._last_val = train_data[-1]
143-
yhat = pd.DataFrame(self.model.fittedvalues.values.tolist(), index=times, columns=[name])
144-
err = pd.DataFrame(self.model.standardized_forecasts_error.tolist(), index=times, columns=[f"{name}_err"])
144+
self._n_train = len(train_data)
145+
yhat = pd.DataFrame(self.model.fittedvalues.values, index=times, columns=[name])
146+
err = pd.DataFrame(self.model.standardized_forecasts_error, index=times, columns=[f"{name}_err"])
145147
return yhat, err
146148

147149
def _forecast(
@@ -152,10 +154,12 @@ def _forecast(
152154
if time_series_prev is None:
153155
last_val = self._last_val
154156
model = self.model
157+
start = self._n_train
155158
else:
156159
time_series_prev = time_series_prev.iloc[:, self.target_seq_index]
157-
val_prev = time_series_prev[-self._max_lookback :]
158-
last_val = val_prev[-1]
160+
val_prev = pd.Series(time_series_prev[-self._max_lookback :].values)
161+
last_val = val_prev.iloc[-1]
162+
start = len(val_prev)
159163

160164
# the default setting of refit=False is fast and conducts exponential smoothing with given parameters,
161165
# while the setting of refit=True is slow and refits the model on time_series_prev.
@@ -165,16 +169,10 @@ def _forecast(
165169
else:
166170
model = model.smooth(params=self.model.params)
167171

168-
# Run forecasting. Some variants of ETS model does not support prediction interval.
169-
# In this case we use point forecasting and set prediction_interval as None.
170-
try:
171-
forecast_result = model.get_prediction(start=time_stamps[0], end=time_stamps[-1])
172-
forecast = np.asarray(forecast_result.predicted_mean)
173-
err = np.sqrt(np.asarray(forecast_result.var_pred_mean))
174-
except (NotImplementedError, AttributeError):
175-
forecast_result = model.predict(start=time_stamps[0], end=time_stamps[-1])
176-
forecast = np.asarray(forecast_result)
177-
err = None
172+
# Run forecasting.
173+
forecast_result = model.get_prediction(start=start, end=start + len(time_stamps) - 1)
174+
forecast = np.asarray(forecast_result.predicted_mean)
175+
err = np.sqrt(np.asarray(forecast_result.var_pred_mean))
178176

179177
# If return_prev is True, return the forecast and error of last train window instead of time_series_prev
180178
if time_series_prev is not None and return_prev:

merlion/models/forecast/prophet.py

-6
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,6 @@ def set_seasonality(self, theta, train_data: UnivariateTimeSeries):
189189
logger.debug(f"Add seasonality {str(p)} ({p * dt})")
190190
self.model.add_seasonality(name=f"extra_season_{p}", period=period, fourier_order=p)
191191

192-
def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_prev: TimeSeries = None):
193-
if isinstance(time_stamps, (int, float)):
194-
times = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps + 1))[1:]
195-
time_stamps = to_timestamp(times)
196-
return time_stamps
197-
198192
def _add_exog_data(self, data: pd.DataFrame, exog_data: pd.DataFrame):
199193
df = pd.DataFrame(data[self.target_name].rename("y"))
200194
if exog_data is not None:

merlion/models/forecast/smoother.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def update(
293293
)
294294

295295
new_data = TimeSeries.from_pd(new_data).univariates[name]
296-
assert_equal_timedeltas(new_data, self.timedelta)
296+
assert_equal_timedeltas(new_data, self.timedelta, self.timedelta_offset)
297297
next_train_time = self.last_train_time + self.timedelta
298298
if to_pd_datetime(new_data.t0) > next_train_time:
299299
logger.warning(

0 commit comments

Comments
 (0)