diff --git a/pytorch_forecasting/data/encoders.py b/pytorch_forecasting/data/encoders.py
index 20c7be098..6586547a7 100644
--- a/pytorch_forecasting/data/encoders.py
+++ b/pytorch_forecasting/data/encoders.py
@@ -227,7 +227,8 @@ def preprocess(
         else:
             # convert first to tensor, then transform and then convert to numpy array
             if isinstance(y, (pd.Series, pd.DataFrame)):
-                y = y.to_numpy()
+                # PyTorch wants writeable arrays
+                y = y.to_numpy(copy=True)
             y = torch.as_tensor(y)
             y = self.get_transform(self.transformation)["forward"](y)
             y = np.asarray(y)
@@ -713,7 +714,8 @@ def transform(
         if isinstance(y, (pd.Series)):
             index = y.index
             pandas_dtype = y.dtype
-            y = y.values
+            # PyTorch wants writeable arrays
+            y = y.to_numpy(copy=True)
             y_was = "pandas"
             y = torch.as_tensor(y)
         elif isinstance(y, np.ndarray):
diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py
index 7739dcc2a..2270190ac 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries.py
@@ -1456,13 +1456,25 @@ def _to_tensor(cols, long=True, real=False) -> torch.Tensor:
             else:
                 dtypekind = data.dtypes[cols].kind
             if real:
-                return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
+                # PyTorch wants writeable arrays
+                return torch.tensor(
+                    data[cols].to_numpy(np.float64, copy=True), dtype=torch.float
+                )
             elif not long:
-                return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.int64)
+                # PyTorch wants writeable arrays
+                return torch.tensor(
+                    data[cols].to_numpy(np.int64, copy=True), dtype=torch.int64
+                )
             elif dtypekind in "bi":
-                return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.long)
+                # PyTorch wants writeable arrays
+                return torch.tensor(
+                    data[cols].to_numpy(np.int64, copy=True), dtype=torch.long
+                )
             else:
-                return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
+                # PyTorch wants writeable arrays
+                return torch.tensor(
+                    data[cols].to_numpy(np.float64, copy=True), dtype=torch.float
+                )
 
         index = _to_tensor(self._group_ids, long=False)
         time = _to_tensor("__time_idx__", long=False)
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index 9a9e02c25..8853c10f5 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -236,16 +236,20 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
 
         cutoff_time = data[time].max()
 
-        data_vals = data[time].values
-        data_tgt_vals = data[_target].values
-        data_feat_vals = data[feature_cols].values
+        # PyTorch wants writeable arrays
+        data_vals = data[time].to_numpy(copy=True)
+        data_tgt_vals = data[_target].to_numpy(copy=True)
+        data_feat_vals = data[feature_cols].to_numpy(copy=True)
 
         result = {
             "t": data_vals,
             "y": torch.tensor(data_tgt_vals),
             "x": torch.tensor(data_feat_vals),
             "group": torch.tensor([hash(str(group_id))]),
-            "st": torch.tensor(data[_static].iloc[0].values if _static else []),
+            # PyTorch wants writeable arrays
+            "st": torch.tensor(
+                data[_static].iloc[0].to_numpy(copy=True) if _static else []
+            ),
             "cutoff_time": cutoff_time,
         }
 
@@ -278,7 +282,10 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
                    for j, col in enumerate(_known):
                        if col in feature_cols:
                            feature_idx = feature_cols.index(col)
-                            x_merged[idx, feature_idx] = future_data[col].values[i]
+                            # PyTorch wants writeable arrays
+                            x_merged[idx, feature_idx] = future_data[col].to_numpy(
+                                copy=True
+                            )[i]
 
            result.update(
                {
@@ -293,17 +300,21 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
            weights_merged = np.full(num_timepoints, np.nan)
            for i, t in enumerate(data_vals):
                idx = current_time_indices[t]
-                weights_merged[idx] = data[weight].values[i]
+                # PyTorch wants writeable arrays
+                weights_merged[idx] = data[weight].to_numpy(copy=True)[i]
 
            for i, t in enumerate(data_fut_vals):
                if t in current_time_indices and self.weight in future_data.columns:
                    idx = current_time_indices[t]
-                    weights_merged[idx] = future_data[weight].values[i]
+                    # PyTorch wants writeable arrays
+                    weights_merged[idx] = future_data[weight].to_numpy(copy=True)[i]
 
            result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
        else:
            result["weights"] = torch.tensor(
-                data[self.weight].values, dtype=torch.float32
+                # PyTorch wants writeable arrays
+                data[self.weight].to_numpy(copy=True),
+                dtype=torch.float32,
            )
 
        return result
diff --git a/tests/test_data/test_timeseries.py b/tests/test_data/test_timeseries.py
index f681c23f1..6d4dd2ff8 100644
--- a/tests/test_data/test_timeseries.py
+++ b/tests/test_data/test_timeseries.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 import pickle
+import warnings
 
 import numpy as np
 import pandas as pd
@@ -716,3 +717,56 @@ def test_correct_dtype_inference():
     x, y = next(iter(dataloader))
     # real features must be real
     assert x["encoder_cont"].dtype is torch.float
+
+
+def test_pytorch_unwriteable_data():
+    """
+    Ensure that PyTorch does not warn about non-writeable arrays
+    extracted from pandas objects.
+    This is a weak test, since the warning is only issued once per
+    process and might already have been triggered.
+    """
+    # save the current copy-on-write mode
+    copy_on_write = pd.options.mode.copy_on_write
+    pd.options.mode.copy_on_write = True
+
+    # create a small synthetic dataset
+    data = pd.DataFrame(
+        {
+            "time_idx": np.arange(30),
+            "value": np.sin(np.arange(30) / 5) + np.random.normal(scale=0.1, size=30),
+            "feature": np.cos(np.arange(30) / 5) + np.random.normal(scale=0.1, size=30),
+            "group": ["A"] * 30,
+        }
+    )
+
+    with warnings.catch_warnings(record=True) as w:
+        # record all warnings, even those that would otherwise be filtered
+        warnings.simplefilter("always")
+
+        # define the dataset
+        dataset = TimeSeriesDataSet(
+            data,
+            time_idx="time_idx",
+            target="value",
+            group_ids=["group"],
+            static_categoricals=["group"],
+            max_encoder_length=4,
+            max_prediction_length=2,
+            time_varying_known_reals=["time_idx"],
+            time_varying_unknown_reals=["value", "feature"],
+            target_normalizer=None,
+            scalers={"feature": StandardScaler()},
+        )
+
+        next(iter(dataset))
+
+    # restore the original copy-on-write mode
+    pd.options.mode.copy_on_write = copy_on_write
+
+    # check whether the specific warning was triggered
+    to_catch = "The given NumPy array is not writable, and PyTorch"
+    to_catch += " does not support non-writable tensors."
+    assert not any(
+        to_catch in str(warning.message) for warning in w
+    ), "Non-writable NumPy array passed to torch.as_tensor"
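
For context, here is a minimal standalone sketch (not part of the diff above) of the behaviour this change works around, assuming pandas >= 2.0 with copy-on-write enabled and a recent PyTorch: Series.to_numpy() can return a read-only view of the underlying data, and converting that view with torch.as_tensor triggers the "not writable" warning the test above checks for, while to_numpy(copy=True) returns a writable copy.

import warnings

import numpy as np
import pandas as pd
import torch

pd.options.mode.copy_on_write = True  # pandas >= 2.0 copy-on-write mode

s = pd.Series(np.arange(5, dtype=np.float64))

arr_view = s.to_numpy()              # zero-copy view; read-only under copy-on-write
print(arr_view.flags.writeable)      # expected: False

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.as_tensor(arr_view)        # shares memory, so PyTorch warns that the
                                     # array is not writable (only once per process)
print([str(w.message) for w in caught])

arr_copy = s.to_numpy(copy=True)     # writable copy, the pattern used in this PR
print(arr_copy.flags.writeable)      # expected: True
torch.as_tensor(arr_copy)            # no warning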