diff --git a/.gitignore b/.gitignore
index 166fdb12..c6a2d897 100644
--- a/.gitignore
+++ b/.gitignore
@@ -373,4 +373,4 @@ DerivedData/
 
 # User created
 VERSION
-version.py
\ No newline at end of file
+version.py
diff --git a/examples/timeseries/tscwgan_example.py b/examples/timeseries/tscwgan_example.py
new file mode 100644
index 00000000..cb1bf1a7
--- /dev/null
+++ b/examples/timeseries/tscwgan_example.py
@@ -0,0 +1,59 @@
+from numpy import reshape
+
+from ydata_synthetic.preprocessing.timeseries import processed_stock
+from ydata_synthetic.synthesizers.timeseries import TSCWGAN
+from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
+from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform
+
+model = TSCWGAN
+
+#Define the GAN and training parameters
+noise_dim = 32
+dim = 128
+seq_len = 48
+cond_dim = 24
+batch_size = 128
+
+log_step = 100
+epochs = 300+1
+learning_rate = 5e-4
+beta_1 = 0.5
+beta_2 = 0.9
+models_dir = './cache'
+critic_iter = 5
+
+# Get transformed data stock - Univariate
+data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = ['Open'])
+data_sample = processed_data[0]
+
+model_parameters = ModelParameters(batch_size=batch_size,
+                                   lr=learning_rate,
+                                   betas=(beta_1, beta_2),
+                                   noise_dim=noise_dim,
+                                   n_cols=seq_len,
+                                   layers_dim=dim,
+                                   condition = cond_dim)
+
+train_args = TrainParameters(epochs=epochs,
+                             sample_interval=log_step,
+                             critic_iter=critic_iter)
+
+#Training the TSCWGAN model
+synthesizer = model(model_parameters, gradient_penalty_weight=10)
+synthesizer.train(processed_data, train_args)
+
+#Saving the synthesizer to later generate new events
+synthesizer.save(path='./tscwgan_stock.pkl')
+
+#Loading the synthesizer
+synth = model.load(path='./tscwgan_stock.pkl')
+
+#Sampling the data
+#Note that the data returned is not inverse processed.
+cond_index = 100 # Arbitrary sequence for conditioning
+cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1))
+
+data_sample = synth.sample(cond_array, 1000, 100)
+
+# Inverting the scaling of the synthetic samples
+inv_data_sample = inverse_transform(data_sample, scaler)
diff --git a/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py b/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py
index 9b9a0b50..b99f4bc3 100644
--- a/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py
+++ b/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py
@@ -1,45 +1,46 @@
 # Inverts all preprocessing pipelines provided in the preprocessing examples
 from typing import Union
 
-import pandas as pd
+from pandas import DataFrame, concat
 
 from sklearn.pipeline import Pipeline
 from sklearn.compose import ColumnTransformer
-from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler
+from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler
 
 
-def inverse_transform(data: pd.DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, StandardScaler]) -> pd.DataFrame:
+def inverse_transform(data: DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer,
+                                                         OneHotEncoder, StandardScaler, MinMaxScaler]) -> DataFrame:
     """Inverts data transformations taking place in a standard sklearn processor.
     Supported processes are sklearn pipelines, column transformers or base estimators like standard scalers.
 
     Args:
-        data (pd.DataFrame): The data object that needs inversion of preprocessing
+        data (DataFrame): The data object that needs inversion of preprocessing
         processor (Union[Pipeline, ColumnTransformer, BaseEstimator]): The processor applied on the original data
 
     Returns:
-        inv_data (pd.DataFrame): The data object after inverting preprocessing"""
+        inv_data (DataFrame): The data object after inverting preprocessing"""
     inv_data = data.copy()
-    if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, Pipeline)):
-        inv_data = pd.DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_)
+    if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler, Pipeline)):
+        inv_data = DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_ if hasattr(processor, "feature_names_in_") else None)
    elif isinstance(processor, ColumnTransformer):
         output_indices = processor.output_indices_
-        assert isinstance(data, pd.DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame."
+        assert isinstance(data, DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame."
         for t_name, t, t_cols in processor.transformers_[::-1]:
             slice_ = output_indices[t_name]
             t_indices = list(range(slice_.start, slice_.stop, 1 if slice_.step is None else slice_.step))
             if t == 'drop':
                 continue
             elif t == 'passthrough':
-                inv_cols = pd.DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index)
+                inv_cols = DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index)
                 inv_col_names = inv_cols.columns
             else:
-                inv_cols = pd.DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index)
+                inv_cols = DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index)
                 inv_col_names = inv_cols.columns
             if set(inv_col_names).issubset(set(inv_data.columns)):
                 inv_data[inv_col_names] = inv_cols[inv_col_names]
             else:
-                inv_data = pd.concat([inv_data, inv_cols], axis=1)
+                inv_data = concat([inv_data, inv_cols], axis=1)
     else:
         print('The provided data processor is not supported and cannot be inverted with this method.')
         return None
-    return inv_data[processor.feature_names_in_]
+    return inv_data[processor.feature_names_in_] if hasattr(processor, "feature_names_in_") else inv_data
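For reference, a minimal usage sketch of the extended inverse_transform (illustrative only, not part of the patch), assuming a MinMaxScaler fitted on a plain numpy array as in the TSCWGAN example above:

import numpy as np
from sklearn.preprocessing import MinMaxScaler
from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform

scaler = MinMaxScaler().fit(np.random.rand(100, 1))
scaled = scaler.transform(np.random.rand(10, 1))
# The scaler was fitted on an array, so it has no feature_names_in_ attribute;
# inverse_transform then falls back to a positionally indexed DataFrame.
recovered = inverse_transform(scaled, scaler)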
diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock.py b/src/ydata_synthetic/preprocessing/timeseries/stock.py
index f10367cc..26866adb 100644
--- a/src/ydata_synthetic/preprocessing/timeseries/stock.py
+++ b/src/ydata_synthetic/preprocessing/timeseries/stock.py
@@ -2,17 +2,30 @@
 Get the stock data from Yahoo finance data
 Data from the period 01 January 2017 - 24 January 2021
 """
+from typing import Optional, List
+
 import pandas as pd
+from typeguard import typechecked
 
 from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading
 
-def transformations(path, seq_len: int):
-    stock_df = pd.read_csv(path)
+@typechecked
+def transformations(path, seq_len: int, cols: Optional[List] = None):
+    """Apply min-max scaling and roll windows of a temporal dataset.
+
+    Args:
+        path (str): Path to a CSV file with the temporal data
+        seq_len (int): Length of the rolled sequences
+        cols (Optional[List]): List of columns to use; defaults to all columns"""
+    if isinstance(cols, list):
+        stock_df = pd.read_csv(path)[cols]
+    else:
+        stock_df = pd.read_csv(path)
     try:
         stock_df = stock_df.set_index('Date').sort_index()
     except:
         stock_df=stock_df
     #Data transformations to be applied prior to be used with the synthesizer model
-    processed_data = real_data_loading(stock_df.values, seq_len=seq_len)
+    data, processed_data, scaler = real_data_loading(stock_df.values, seq_len=seq_len)
 
-    return processed_data
+    return data, processed_data, scaler
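For context, real_data_loading (whose fitted scaler transformations now returns alongside the windows) follows the TimeGAN data-loading recipe. A rough sketch of the assumed behaviour, illustrative rather than the patched code (the real function also shuffles the windows):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

values = np.arange(100, dtype=float).reshape(-1, 1)   # univariate series, e.g. the 'Open' column
scaler = MinMaxScaler().fit(values)
scaled = scaler.transform(values)
seq_len = 48
windows = [scaled[i:i + seq_len] for i in range(len(scaled) - seq_len)]
print(len(windows), windows[0].shape)                 # 52 overlapping windows of shape (48, 1)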
diff --git a/src/ydata_synthetic/preprocessing/timeseries/utils.py b/src/ydata_synthetic/preprocessing/timeseries/utils.py
index c77c67b2..c8404899 100644
--- a/src/ydata_synthetic/preprocessing/timeseries/utils.py
+++ b/src/ydata_synthetic/preprocessing/timeseries/utils.py
@@ -4,7 +4,7 @@
 import numpy as np
 from sklearn.preprocessing import MinMaxScaler
 
-# Method implemented here: https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py
+# Method adapted from here: https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py
 # Originally used in TimeGAN research
 def real_data_loading(data: np.array, seq_len):
     """Load and preprocess real-world datasets.
@@ -30,7 +30,7 @@ def real_data_loading(data: np.array, seq_len):
     # Mix the datasets (to make it similar to i.i.d)
     idx = np.random.permutation(len(temp_data))
 
-    data = []
+    processed_data = []
     for i in range(len(temp_data)):
-        data.append(temp_data[idx[i]])
-    return data
+        processed_data.append(temp_data[idx[i]])
+    return data, processed_data, scaler
diff --git a/src/ydata_synthetic/synthesizers/gan.py b/src/ydata_synthetic/synthesizers/gan.py
index 6f7d5684..d1f658cc 100644
--- a/src/ydata_synthetic/synthesizers/gan.py
+++ b/src/ydata_synthetic/synthesizers/gan.py
@@ -21,10 +21,10 @@
 _model_parameters_df = [128, 1e-4, (None, None), 128, 264,
                         None, None, None, 1, None]
 
-_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval', 'labels']
+_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval', 'labels', 'critic_iter']
 
 ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
-TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None))
+TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, None))
 
 
 # pylint: disable=R0902
diff --git a/src/ydata_synthetic/synthesizers/timeseries/__init__.py b/src/ydata_synthetic/synthesizers/timeseries/__init__.py
index a3523536..3984a68b 100644
--- a/src/ydata_synthetic/synthesizers/timeseries/__init__.py
+++ b/src/ydata_synthetic/synthesizers/timeseries/__init__.py
@@ -1,5 +1,7 @@
 from ydata_synthetic.synthesizers.timeseries.timegan.model import TimeGAN
+from ydata_synthetic.synthesizers.timeseries.tscwgan.model import TSCWGAN
 
 __all__ = [
     'TimeGAN',
+    'TSCWGAN',
 ]
diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/__init__.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/__init__.py
new file mode 100644
index 00000000..e69de29b
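The widened TrainParameters stays backwards compatible, since the new critic_iter field defaults to None. A short sketch of the intended call pattern (values mirror the example script, not the defaults):

from ydata_synthetic.synthesizers import TrainParameters

train_args = TrainParameters(epochs=301, sample_interval=100, critic_iter=5)
assert train_args.critic_iter == 5   # omitted fields keep their defaults, e.g. labels stays None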
diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py
new file mode 100644
index 00000000..ee90c927
--- /dev/null
+++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py
@@ -0,0 +1,257 @@
+"""
+Conditional time-series Wasserstein GAN.
+Based on: https://www.naun.org/main/NAUN/neural/2020/a082016-004(2020).pdf
+And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
+"""
+from tqdm import trange
+from numpy import array, vstack, hstack
+from numpy.random import normal
+
+from tensorflow import concat, float32, convert_to_tensor, GradientTape, reduce_mean, tile, squeeze
+from tensorflow import data as tfdata
+from tensorflow.keras import Model, Sequential
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add
+
+from ydata_synthetic.synthesizers.gan import BaseModel
+from ydata_synthetic.synthesizers import TrainParameters
+from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
+
+class TSCWGAN(BaseModel):
+
+    __MODEL__='TSCWGAN'
+
+    def __init__(self, model_parameters, gradient_penalty_weight=10):
+        """Create a base TSCWGAN."""
+        self.gradient_penalty_weight = gradient_penalty_weight
+        self.cond_dim = model_parameters.condition
+        super().__init__(model_parameters)
+        self.data_dim = model_parameters.n_cols
+
+    def define_gan(self):
+        self.generator = Generator(self.batch_size). \
+            build_model(input_shape=(self.noise_dim + self.cond_dim, 1), dim=self.layers_dim, data_dim=self.data_dim)
+        self.critic = Critic(self.batch_size). \
+            build_model(input_shape=(self.data_dim + self.cond_dim, 1), dim=self.layers_dim)
+
+        self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
+
+        # The generator takes noise as input and generates records
+        noise = Input(shape=self.noise_dim, batch_size=self.batch_size)
+        cond = Input(shape=self.cond_dim, batch_size=self.batch_size)
+        gen = concat([cond, noise], axis=1)
+        gen = self.generator(gen)
+        score = concat([cond, gen], axis=1)
+        score = self.critic(score)
+
+    def train(self, data, train_arguments: TrainParameters):
+        self.define_gan()
+        real_batches = self.get_batch_data(data)
+        noise_batches = self.get_batch_noise()
+
+        for epoch in trange(train_arguments.epochs):
+            for i in range(train_arguments.critic_iter):
+                real_batch = next(real_batches)
+                noise_batch = next(noise_batches)[:len(real_batch)]  # Truncate the noise batch to the length of the real batch
+
+                c_loss = self.update_critic(real_batch, noise_batch)
+
+            real_batch = next(real_batches)
+            noise_batch = next(noise_batches)[:len(real_batch)]
+
+            g_loss = self.update_generator(real_batch, noise_batch)
+
+            print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}")
+
+        self.g_optimizer = self.g_optimizer.get_config()
+        self.c_optimizer = self.c_optimizer.get_config()
+
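+    # The two update methods below follow the conditional Wasserstein training scheme used here:
+    # the critic maximises the score gap between real and generated (condition-prefixed) sequences
+    # under a DRAGAN-style gradient penalty (see gradient_penalty below), while the generator
+    # minimises the negated critic score of its samples.
+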
+    def update_critic(self, real_batch, noise_batch):
+        with GradientTape() as c_tape:
+            fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch)
+
+            # Real and fake records with conditions
+            real_batch_ = concat([cond_batch, real_batch], axis=1)
+            fake_batch_ = concat([cond_batch, fake_batch], axis=1)
+
+            c_loss = self.c_lossfn(real_batch_, fake_batch_)
+
+        c_gradient = c_tape.gradient(c_loss, self.critic.trainable_variables)
+
+        # Update the weights of the critic using the optimizer
+        self.c_optimizer.apply_gradients(
+            zip(c_gradient, self.critic.trainable_variables)
+        )
+        return c_loss
+
+    def update_generator(self, real_batch, noise_batch):
+        with GradientTape() as g_tape:
+            fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch)
+
+            # Fake records with conditions
+            fake_batch_ = concat([cond_batch, fake_batch], axis=1)
+
+            g_loss = self.g_lossfn(fake_batch_)
+
+        g_gradient = g_tape.gradient(g_loss, self.generator.trainable_variables)
+
+        # Update the weights of the generator using the optimizer
+        self.g_optimizer.apply_gradients(
+            zip(g_gradient, self.generator.trainable_variables)
+        )
+        return g_loss
+
+    def c_lossfn(self, real_batch_, fake_batch_):
+        score_fake = self.critic(fake_batch_)
+        score_real = self.critic(real_batch_)
+        grad_penalty = self.gradient_penalty(real_batch_, fake_batch_)
+        c_loss = reduce_mean(score_fake) - reduce_mean(score_real) + grad_penalty
+        return c_loss
+
+    def g_lossfn(self, fake_batch_):
+        score_fake = self.critic(fake_batch_)
+        g_loss = - reduce_mean(score_fake)
+        return g_loss
+
+    def _make_fake_batch(self, real_batch, noise_batch):
+        """Generate a batch of fake records and return it with the batch of used conditions.
+        Conditions are the first elements of records in the real batch."""
+        cond_batch = real_batch[:, :self.cond_dim]
+        gen_input = concat([cond_batch, noise_batch], axis=1)
+        return self.generator(gen_input, training=True), cond_batch
+
+    def gradient_penalty(self, real, fake):
+        gp = gradient_penalty(self.critic, real, fake, mode=Mode.DRAGAN)
+        return gp
+
+    def _generate_noise(self):
+        "Gaussian noise for the generator input."
+        while True:
+            yield normal(size=self.noise_dim)
+
+    def get_batch_noise(self):
+        "Create a batch iterator for the generator gaussian noise input."
+        return iter(tfdata.Dataset.from_generator(self._generate_noise, output_types=float32)
+                    .batch(self.batch_size)
+                    .repeat())
+
+    def get_batch_data(self, data, n_windows=None):
+        if not n_windows:
+            n_windows = len(data)
+        data = squeeze(convert_to_tensor(data, dtype=float32))
+        return iter(tfdata.Dataset.from_tensor_slices(data)
+                    .shuffle(buffer_size=n_windows)
+                    .batch(self.batch_size).repeat())
+
+    def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
+        """For a given condition, produce n_samples of length seq_len.
+
+        Args:
+            condition (numpy.array): Condition for the generated samples; its length must equal cond_dim.
+            n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
+            seq_len (int): Length of the generated samples.
+
+        Returns:
+            data (numpy.array): An array of data of shape [n_samples, seq_len]"""
+        assert len(condition.shape) == 2, "Condition array should be two-dimensional."
+        assert condition.shape[1] == self.cond_dim, \
+            f"The condition sequence should have length {self.cond_dim}."
+        batches = n_samples // self.batch_size + 1
+        ar_steps = seq_len // self.data_dim + 1
+        data = []
+        z_dist = self.get_batch_noise()
+        for batch in trange(batches, desc='Synthetic data generation'):
+            data_ = []
+            cond_seq = convert_to_tensor(condition, float32)
+            gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
+            for step in range(ar_steps):
+                records = self.generator(gen_input, training=False)
+                gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
+                data_.append(records)
+            data_ = hstack(data_)[:, :seq_len]
+            data.append(data_)
+        return array(vstack(data))
+
+
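+# Generator and Critic are built functionally via build_model(); the class instances mainly carry
+# batch_size. The generator maps the concatenated [condition | noise] input through Conv1D and Dense
+# blocks with additive skip connections (Add) to an output of length data_dim; the critic applies the
+# same skip-connection pattern with Dense blocks to the [condition | sequence] input to produce a score.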
+class Generator(Model):
+    """Conditional generator with skip connections."""
+    def __init__(self, batch_size):
+        self.batch_size = batch_size
+
+    def build_model(self, input_shape, dim, data_dim):
+        # Define input - Expected input shape is (batch_size, seq_len, noise_dim), where noise_dim = Z + cond
+        noise_input = Input(shape = input_shape, batch_size = self.batch_size)
+
+        # Compose model
+        proc_input = Sequential(layers=[
+            Conv1D(filters=dim, kernel_size=1, input_shape = input_shape),
+            LeakyReLU(),
+            Conv1D(dim, kernel_size=5, dilation_rate=2, padding="same"),
+            LeakyReLU()
+        ], name='input_to_latent')(noise_input)
+
+        block_cnn = Sequential(layers=[
+            Conv1D(filters=dim, kernel_size=3, dilation_rate=2, padding="same"),
+            LeakyReLU()
+        ], name='block_cnn')
+        for i in range(3):
+            if i == 0:
+                cnn_block_i = proc_input
+                cnn_block_o = block_cnn(proc_input)
+            else:
+                cnn_block_o = block_cnn(cnn_block_i)
+            cnn_block_i = Add()([cnn_block_i, cnn_block_o])
+
+        shift = Sequential(layers=[
+            Conv1D(filters=10, kernel_size=3, dilation_rate=2, padding="same"),
+            LeakyReLU(),
+            Flatten(),
+            Dense(dim*2),
+            LeakyReLU()
+        ], name='block_shift')(cnn_block_i)
+
+        block = Sequential(layers=[
+            Dense(dim*2),
+            LeakyReLU()
+        ], name='block')
+        for i in range(3):
+            if i == 0:
+                block_i = shift
+                block_o = block(shift)
+            else:
+                block_o = block(block_i)
+            block_i = Add()([block_i, block_o])
+
+        output = Dense(data_dim, name='latent_to_output')(block_i)
+        return Model(inputs = noise_input, outputs = output, name='SkipConnectionGenerator')
+
+class Critic(Model):
+    """Conditional Wasserstein Critic with skip connections."""
+    def __init__(self, batch_size):
+        self.batch_size = batch_size
+
+    def build_model(self, input_shape, dim):
+        # Define input - Expected input shape is X + condition
+        record_input = Input(shape = input_shape, batch_size = self.batch_size)
+
+        # Compose model
+        proc_record = Sequential(layers=[
+            Dense(dim*2),
+            LeakyReLU()
+        ], name='ts_to_latent')(record_input)
+
+        block = Sequential(layers=[
+            Dense(dim*2),
+            LeakyReLU()
+        ], name='block')
+        for i in range(7):
+            if i == 0:
+                block_i = proc_record
+                block_o = block(proc_record)
+            else:
+                block_o = block(block_i)
+            block_i = Add()([block_i, block_o])
+
+        output = Dense(1, name = 'latent_to_score')(block_i)
+        return Model(inputs=record_input, outputs=output, name='SkipConnectionCritic')
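A closing note on TSCWGAN.sample(): generation is autoregressive, so the requested seq_len need not equal the generator's data_dim, and the number of returned sequences is rounded up to a multiple of batch_size. With the example script's settings the shape arithmetic works out as follows (illustrative values only):

batch_size, data_dim = 128, 48          # from ModelParameters in the example
n_samples, seq_len = 1000, 100          # arguments passed to synth.sample()
batches = n_samples // batch_size + 1   # 8 batches -> 1024 generated sequences
ar_steps = seq_len // data_dim + 1      # 3 autoregressive passes of 48 steps, truncated to 100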