Skip to content

Commit a3c9b34

Browse files
author
Francisco Santos
committed
integrate TSDataProcessor, revise sample method
Auto-regressive timeseries sampling method; revert TS data processor integration
1 parent 2c2b720 commit a3c9b34

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

examples/timeseries/tscwgan_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from numpy import squeeze
1+
from numpy import reshape
22

33
from ydata_synthetic.preprocessing.timeseries import processed_stock
44
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
@@ -51,9 +51,9 @@
5151
#Sampling the data
5252
#Note that the data returned is not inverse processed.
5353
cond_index = 100 # Arbitrary sequence for conditioning
54-
cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1)
54+
cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1))
5555

56-
data_sample = synth.sample(cond_array, 1000)
56+
data_sample = synth.sample(cond_array, 1000, 100)
5757

5858
# Inverting the scaling of the synthetic samples
5959
data_sample = inverse_transform(data_sample, scaler)

src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
55
"""
66
from tqdm import trange
7-
from numpy import array, vstack
7+
from numpy import array, vstack, hstack
88
from numpy.random import normal
99

10-
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims
10+
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
1111
from tensorflow import data as tfdata
1212
from tensorflow.keras import Model, Sequential
1313
from tensorflow.keras.optimizers import Adam
@@ -142,20 +142,33 @@ def get_batch_data(self, data, n_windows= None):
142142
.shuffle(buffer_size=n_windows)
143143
.batch(self.batch_size).repeat())
144144

145-
def sample(self, cond_array, n_samples):
146-
"""Provided that cond_array is passed, produce n_samples for each condition vector in cond_array."""
147-
assert len(cond_array.shape) == 1, "Condition array should be one-dimensional."
148-
assert cond_array.shape[0] == self.cond_dim, \
149-
f"The condition sequence should have a {self.cond_dim} length."
150-
steps = n_samples // self.batch_size + 1
145+
def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
146+
"""For a given condition, produce n_samples of length seq_len.
147+
148+
Args:
149+
condition (numpy.array): Condition for the generated samples; must have length equal to the model's cond_dim.
150+
n_samples (int): Minimum number of generated samples (the method always returns a multiple of batch_size).
151+
seq_len (int): Length of the generated samples.
152+
153+
Returns:
154+
data (numpy.array): An array of data of shape [n_samples, seq_len]"""
155+
assert len(condition.shape) == 2, "Condition array should be two-dimensional."
156+
assert condition.shape[1] == self.cond_dim, \
157+
f"The condition sequence should have {self.cond_dim} length."
158+
batches = n_samples // self.batch_size + 1
159+
ar_steps = seq_len // self.data_dim + 1
151160
data = []
152161
z_dist = self.get_batch_noise()
153-
cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0)
154-
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
155-
for step in trange(steps, desc=f'Synthetic data generation'):
156-
gen_input = concat([cond_seq, next(z_dist)], axis=1)
157-
records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False)))
158-
data.append(records)
162+
for batch in trange(batches, desc=f'Synthetic data generation'):
163+
data_ = []
164+
cond_seq = convert_to_tensor(condition, float32)
165+
gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
166+
for step in range(ar_steps):
167+
records = self.generator(gen_input, training=False)
168+
gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
169+
data_.append(records)
170+
data_ = hstack(data_)[:, :seq_len]
171+
data.append(data_)
159172
return array(vstack(data))
160173

161174

0 commit comments

Comments (0)