Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_forecasting/data/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2513,7 +2513,7 @@ def to_dataloader(

Parameters
----------
train : bool, optional, default=Trze
train : bool, optional, default=True
whether dataloader is used for training (True) or prediction (False).
Will shuffle and drop last batch if True. Defaults to True.
batch_size : int, optional, default=64
Expand Down
28 changes: 27 additions & 1 deletion pytorch_forecasting/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,33 @@ def concat_sequences(
if isinstance(sequences[0], rnn.PackedSequence):
return rnn.pack_sequence(sequences, enforce_sorted=False)
elif isinstance(sequences[0], torch.Tensor):
return torch.cat(sequences, dim=1)
if sequences[0].ndim > 1:
first_lens = [xi.shape[1] for xi in sequences]
max_first_len = max(first_lens)
if max_first_len > min(first_lens):
sequences = [
(
xi
if xi.shape[1] == max_first_len
else torch.cat(
[
xi,
torch.full(
(
xi.shape[0],
max_first_len - xi.shape[1],
*xi.shape[2:],
),
float("nan"),
device=xi.device,
),
],
dim=1,
)
)
for xi in sequences
]
return torch.cat(sequences, dim=0)
elif isinstance(sequences[0], (tuple, list)):
return tuple(
concat_sequences([sequences[ii][i] for ii in range(len(sequences))])
Expand Down
Loading