Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pytorch_forecasting/data/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def preprocess(
else:
# convert first to tensor, then transform and then convert to numpy array
if isinstance(y, (pd.Series, pd.DataFrame)):
y = y.to_numpy()
# PyTorch wants writeable arrays
y = y.to_numpy(copy=True)
y = torch.as_tensor(y)
y = self.get_transform(self.transformation)["forward"](y)
y = np.asarray(y)
Expand Down Expand Up @@ -713,7 +714,8 @@ def transform(
if isinstance(y, (pd.Series)):
index = y.index
pandas_dtype = y.dtype
y = y.values
# PyTorch wants writeable arrays
y = y.to_numpy(copy=True)
y_was = "pandas"
y = torch.as_tensor(y)
elif isinstance(y, np.ndarray):
Expand Down
20 changes: 16 additions & 4 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,13 +1456,25 @@ def _to_tensor(cols, long=True, real=False) -> torch.Tensor:
else:
dtypekind = data.dtypes[cols].kind
if real:
return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
# PyTorch wants writeable arrays
return torch.tensor(
data[cols].to_numpy(np.float64, copy=True), dtype=torch.float
)
elif not long:
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.int64)
# PyTorch wants writeable arrays
return torch.tensor(
data[cols].to_numpy(np.int64, copy=True), dtype=torch.int64
)
elif dtypekind in "bi":
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.long)
# PyTorch wants writeable arrays
return torch.tensor(
data[cols].to_numpy(np.int64, copy=True), dtype=torch.long
)
else:
return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
# PyTorch wants writeable arrays
return torch.tensor(
data[cols].to_numpy(np.float64, copy=True), dtype=torch.float
)

index = _to_tensor(self._group_ids, long=False)
time = _to_tensor("__time_idx__", long=False)
Expand Down
27 changes: 19 additions & 8 deletions pytorch_forecasting/data/timeseries/_timeseries_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
if col not in [self.time] + self._group + [self.weight] + self._target
]
if self._group:
self._groups = self.data.groupby(self._group).groups

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.10)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.11)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.13)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.13)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.13)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 134 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (windows-latest, 3.13)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
self._group_ids = list(self._groups.keys())
else:
self._groups = {"_single_group": self.data.index}
Expand Down Expand Up @@ -236,16 +236,20 @@

cutoff_time = data[time].max()

data_vals = data[time].values
data_tgt_vals = data[_target].values
data_feat_vals = data[feature_cols].values
# PyTorch wants writeable arrays
data_vals = data[time].to_numpy(copy=True)
data_tgt_vals = data[_target].to_numpy(copy=True)
data_feat_vals = data[feature_cols].to_numpy(copy=True)

result = {
"t": data_vals,
"y": torch.tensor(data_tgt_vals),
"x": torch.tensor(data_feat_vals),
"group": torch.tensor([hash(str(group_id))]),
"st": torch.tensor(data[_static].iloc[0].values if _static else []),
# PyTorch wants writeable arrays
"st": torch.tensor(
data[_static].iloc[0].to_numpy(copy=True) if _static else []
),
"cutoff_time": cutoff_time,
}

Expand Down Expand Up @@ -278,7 +282,10 @@
for j, col in enumerate(_known):
if col in feature_cols:
feature_idx = feature_cols.index(col)
x_merged[idx, feature_idx] = future_data[col].values[i]
# PyTorch wants writeable arrays
x_merged[idx, feature_idx] = future_data[col].to_numpy(
copy=True
)[i]

result.update(
{
Expand All @@ -293,17 +300,21 @@
weights_merged = np.full(num_timepoints, np.nan)
for i, t in enumerate(data_vals):
idx = current_time_indices[t]
weights_merged[idx] = data[weight].values[i]
# PyTorch wants writeable arrays
weights_merged[idx] = data[weight].to_numpy(copy=True)[i]

for i, t in enumerate(data_fut_vals):
if t in current_time_indices and self.weight in future_data.columns:
idx = current_time_indices[t]
weights_merged[idx] = future_data[weight].values[i]
# PyTorch wants writeable arrays
weights_merged[idx] = future_data[weight].to_numpy(copy=True)[i]

result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
else:
result["weights"] = torch.tensor(
data[self.weight].values, dtype=torch.float32
# PyTorch wants writeable arrays
data[self.weight].to_numpy(copy=True),
dtype=torch.float32,
)

return result
Expand Down
54 changes: 54 additions & 0 deletions tests/test_data/test_timeseries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
import pickle
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -716,3 +717,56 @@ def test_correct_dtype_inference():
x, y = next(iter(dataloader))
# real features must be real
assert x["encoder_cont"].dtype is torch.float


def test_pytorch_unwriteable_data():
"""
-- Ensures that PyTorch doesn't throw a warning on non-writeable
arrays extracted from pandas objects.
This is a weak test, since the warning is only issued once and might
already have been issued.
"""
# save current mode
copy_on_write = pd.options.mode.copy_on_write
pd.options.mode.copy_on_write = True

# Create a small dataset
data = pd.DataFrame(
{
"time_idx": np.arange(30),
"value": np.sin(np.arange(30) / 5) + np.random.normal(scale=0.1, size=30),
"feature": np.cos(np.arange(30) / 5) + np.random.normal(scale=0.1, size=30),
"group": ["A"] * 30,
}
)

with warnings.catch_warnings(record=True) as w:
# catch all warnings
warnings.simplefilter("always")

# Define the dataset
dataset = TimeSeriesDataSet(
data,
time_idx="time_idx",
target="value",
group_ids=["group"],
static_categoricals=["group"],
max_encoder_length=4,
max_prediction_length=2,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["value", "feature"],
target_normalizer=None,
scalers={"feature": StandardScaler()},
)

next(iter(dataset))

# reset original mode
pd.options.mode.copy_on_write = copy_on_write

# Check if the specific warning was triggered
to_catch = "The given NumPy array is not writable, and PyTorch"
to_catch += " does not support non-writable tensors."
for warning in w:
if to_catch in str(warning.message):
assert False, "Non-writable NumPy array passed to torch.as_tensor"
Loading