|
4 | 4 | And on: https://github.com/CasperHogenboom/WGAN_financial_time-series |
5 | 5 | """ |
6 | 6 | from tqdm import trange |
7 | | -from numpy import array, vstack |
| 7 | +from numpy import array, vstack, hstack |
8 | 8 | from numpy.random import normal |
9 | 9 |
|
10 | | -from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims |
| 10 | +from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile |
11 | 11 | from tensorflow import data as tfdata |
12 | 12 | from tensorflow.keras import Model, Sequential |
13 | 13 | from tensorflow.keras.optimizers import Adam |
@@ -142,20 +142,33 @@ def get_batch_data(self, data, n_windows= None): |
142 | 142 | .shuffle(buffer_size=n_windows) |
143 | 143 | .batch(self.batch_size).repeat()) |
144 | 144 |
|
145 | | - def sample(self, cond_array, n_samples): |
146 | | - """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.""" |
147 | | - assert len(cond_array.shape) == 1, "Condition array should be one-dimensional." |
148 | | - assert cond_array.shape[0] == self.cond_dim, \ |
149 | | - f"The condition sequence should have a {self.cond_dim} length." |
150 | | - steps = n_samples // self.batch_size + 1 |
| 145 | + def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24): |
| 146 | + """For a given condition, produce n_samples of length seq_len. |
| 147 | +
|
| 148 | + Args: |
| 149 | + condition (numpy.array): Condition for the generated samples, must have the same length. |
| 150 | + n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size). |
| 151 | + seq_len (int): Length of the generated samples. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + data (numpy.array): An array of data of shape [n_samples, seq_len]""" |
| 155 | + assert len(condition.shape) == 2, "Condition array should be two-dimensional." |
| 156 | + assert condition.shape[1] == self.cond_dim, \ |
| 157 | + f"The condition sequence should have {self.cond_dim} length." |
| 158 | + batches = n_samples // self.batch_size + 1 |
| 159 | + ar_steps = seq_len // self.data_dim + 1 |
151 | 160 | data = [] |
152 | 161 | z_dist = self.get_batch_noise() |
153 | | - cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0) |
154 | | - cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) |
155 | | - for step in trange(steps, desc=f'Synthetic data generation'): |
156 | | - gen_input = concat([cond_seq, next(z_dist)], axis=1) |
157 | | - records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) |
158 | | - data.append(records) |
| 162 | + for batch in trange(batches, desc=f'Synthetic data generation'): |
| 163 | + data_ = [] |
| 164 | + cond_seq = convert_to_tensor(condition, float32) |
| 165 | + gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1) |
| 166 | + for step in range(ar_steps): |
| 167 | + records = self.generator(gen_input, training=False) |
| 168 | + gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1) |
| 169 | + data_.append(records) |
| 170 | + data_ = hstack(data_)[:, :seq_len] |
| 171 | + data.append(data_) |
159 | 172 | return array(vstack(data)) |
160 | 173 |
|
161 | 174 |
|
|
0 commit comments