|
| 1 | +Autoencoders and their implementations in TensorFlow |
| 2 | +---------------------------------------------------- |
| 3 | + |
| 4 | +In this post, you will learn the concept behind Autoencoders as well how |
| 5 | +to implement an autoencoder in TensorFlow. |
| 6 | + |
| 7 | +Introduction |
| 8 | +------------ |
| 9 | + |
| 10 | +Autoencoders are a type of neural networks which copy its input to its |
| 11 | +output. They usually consist of two main parts, namely Encoder and |
| 12 | +Decoder. The encoder map the input into a hidden layer space which we |
| 13 | +refer to as a code. The decoder then reconstructs the input from the |
| 14 | +code. There are different types of Autoencoders: |
| 15 | + |
| 16 | +- **Undercomplete Autoencoders:** An autoencoder whose code |
| 17 | + dimension is less than the input dimension. Learning such an |
| 18 | + autoencoder forces it to capture the most salient features. |
| 19 | + However, using a big encoder and decoder in the lack of enough |
| 20 | + training data allows the network to memorized the task and omits |
| 21 | + learning useful features. In case of having linear decoder, it can |
| 22 | + act as PCA. However, adding nonlinear activation functions to the |
| 23 | + network makes it a nonlinear generalization of PCA. |
| 24 | +- **Regularized Autoencoders:** Rather than limiting the size of |
| 25 | + autoencoder and the code dimension for the sake of feature |
| 26 | + learning, we can add a loss function to prevent it memorizing the |
| 27 | + task and the training data. |
| 28 | + - **Sparse Autoencoders:** An autoencoder which has a sparsity |
| 29 | + penalty in the training loss in addition to the |
| 30 | + reconstruction error. They usually being used for the |
| 31 | + porpuse of other tasks such as classification. The loss is |
| 32 | + not as straightforward as other regularizers, and we will |
| 33 | + discuss it in another post later. |
| 34 | + - **Denoising Autoencoders (DAE):** The input of a DAE is a |
| 35 | + corrupted copy of the real input which is supposed to be |
| 36 | + reconstructed. Therefore, a DAE has to undo the corruption |
| 37 | + (noise) as well as reconstruction. |
| 38 | + - **Contractive Autoencoders (CAE):** The main idea behind |
| 39 | + these type of autoencoders is to learn a representation of |
| 40 | + the data which is robust to small changes in the input. |
| 41 | +- **Variational Autoencoders:** They maximize the probability of the |
| 42 | + training data instead of copying the input to the output and |
| 43 | + therefore does not need regularization to capture useful |
| 44 | + information. |
| 45 | + |
| 46 | +In this post, we are going to create a simple Undercomplete Autoencoder |
| 47 | +in TensorFlow to learn a low dimension representation (code) of the |
| 48 | +MNIST dataset. |
| 49 | + |
| 50 | +Create an Undercomplete Autoencoder |
| 51 | +----------------------------------- |
| 52 | + |
| 53 | +We are going to create an autoencoder with a 3-layer encoder and 3-layer |
| 54 | +decoder. Each layer of encoder downsamples its input along the spatial |
| 55 | +dimensions (width, height) by a factor of two using a stride 2. |
| 56 | +Consequently, the dimension of the code is 2(width) X 2(height) X |
| 57 | +8(depth) = 32 (for an image of 32X32). Similarly, each layer of the |
| 58 | +decoder upsamples its input by a factor of two (using transpose |
| 59 | +convolution with stride 2). |
| 60 | + |
| 61 | +.. code-block:: python |
| 62 | +
|
| 63 | + import tensorflow.contrib.layers as lays |
| 64 | +
|
| 65 | + def autoencoder(inputs): |
| 66 | + # encoder |
| 67 | + # 32 file code blockx 32 x 1 -> 16 x 16 x 32 |
| 68 | + # 16 x 16 x 32 -> 8 x 8 x 16 |
| 69 | + # 8 x 8 x 16 -> 2 x 2 x 8 |
| 70 | + net = lays.conv2d(inputs, 32, [5, 5], stride=2, padding='SAME') |
| 71 | + net = lays.conv2d(net, 16, [5, 5], stride=2, padding='SAME') |
| 72 | + net = lays.conv2d(net, 8, [5, 5], stride=4, padding='SAME') |
| 73 | + # decoder |
| 74 | + # 2 x 2 x 8 -> 8 x 8 x 16 |
| 75 | + # 8 x 8 x 16 -> 16 x 16 x 32 |
| 76 | + # 16 x 16 x 32 -> 32 x 32 x 1 |
| 77 | + net = lays.conv2d_transpose(net, 16, [5, 5], stride=4, padding='SAME') |
| 78 | + net = lays.conv2d_transpose(net, 32, [5, 5], stride=2, padding='SAME') |
| 79 | + net = lays.conv2d_transpose(net, 1, [5, 5], stride=2, padding='SAME', activation_fn=tf.nn.tanh) |
| 80 | + return net |
| 81 | +
|
| 82 | +.. figure:: _img/ae.png |
| 83 | + :scale: 50 % |
| 84 | + :align: center |
| 85 | + |
| 86 | + **Figure 1:** Autoencoder |
| 87 | + |
| 88 | +The MNIST dataset contains vectorized images of 28X28. Therefore we |
| 89 | +define a new function to reshape each batch of MNIST images to 28X28 and |
| 90 | +then resize to 32X32. The reason of resizing to 32X32 is to make it a |
| 91 | +power of two and therefore we can easily use the stride of 2 for |
| 92 | +downsampling and upsampling. |
| 93 | + |
| 94 | +.. code-block:: python |
| 95 | +
|
| 96 | + import numpy as np |
| 97 | + from skimage import transform |
| 98 | +
|
| 99 | + def resize_batch(imgs): |
| 100 | + # A function to resize a batch of MNIST images to (32, 32) |
| 101 | + # Args: |
| 102 | + # imgs: a numpy array of size [batch_size, 28 X 28]. |
| 103 | + # Returns: |
| 104 | + # a numpy array of size [batch_size, 32, 32]. |
| 105 | + imgs = imgs.reshape((-1, 28, 28, 1)) |
| 106 | + resized_imgs = np.zeros((imgs.shape[0], 32, 32, 1)) |
| 107 | + for i in range(imgs.shape[0]): |
| 108 | + resized_imgs[i, ..., 0] = transform.resize(imgs[i, ..., 0], (32, 32)) |
| 109 | + return resized_imgs |
| 110 | +
|
| 111 | +Now we create an autoencoder, define a square error loss and an |
| 112 | +optimizer. |
| 113 | + |
| 114 | + |
| 115 | +.. code-block:: python |
| 116 | +
|
| 117 | + import tensorflow as tf |
| 118 | +
|
| 119 | + ae_inputs = tf.placeholder(tf.float32, (None, 32, 32, 1)) # input to the network (MNIST images) |
| 120 | + ae_outputs = autoencoder(ae_inputs) # create the Autoencoder network |
| 121 | +
|
| 122 | + # calculate the loss and optimize the network |
| 123 | + loss = tf.reduce_mean(tf.square(ae_outputs - ae_inputs)) # claculate the mean square error loss |
| 124 | + train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) |
| 125 | +
|
| 126 | + # initialize the network |
| 127 | + init = tf.global_variables_initializer() |
| 128 | +
|
| 129 | +Now we can read the batches, train the network and finally test the |
| 130 | +network by reconstructing a batch of test images. |
| 131 | + |
| 132 | + |
| 133 | +.. code-block:: python |
| 134 | +
|
| 135 | + from tensorflow.examples.tutorials.mnist import input_data |
| 136 | +
|
| 137 | + batch_size = 500 # Number of samples in each batch |
| 138 | + epoch_num = 5 # Number of epochs to train the network |
| 139 | + lr = 0.001 # Learning rate |
| 140 | +
|
| 141 | + # read MNIST dataset |
| 142 | + mnist = input_data.read_data_sets("MNIST_data", one_hot=True) |
| 143 | +
|
| 144 | + # calculate the number of batches per epoch |
| 145 | + batch_per_ep = mnist.train.num_examples // batch_size |
| 146 | +
|
| 147 | + with tf.Session() as sess: |
| 148 | + sess.run(init) |
| 149 | + for ep in range(epoch_num): # epochs loop |
| 150 | + for batch_n in range(batch_per_ep): # batches loop |
| 151 | + batch_img, batch_label = mnist.train.next_batch(batch_size) # read a batch |
| 152 | + batch_img = batch_img.reshape((-1, 28, 28, 1)) # reshape each sample to an (28, 28) image |
| 153 | + batch_img = resize_batch(batch_img) # reshape the images to (32, 32) |
| 154 | + _, c = sess.run([train_op, loss], feed_dict={ae_inputs: batch_img}) |
| 155 | + print('Epoch: {} - cost= {:.5f}'.format((ep + 1), c)) |
| 156 | +
|
| 157 | + # test the trained network |
| 158 | + batch_img, batch_label = mnist.test.next_batch(50) |
| 159 | + batch_img = resize_batch(batch_img) |
| 160 | + recon_img = sess.run([ae_outputs], feed_dict={ae_inputs: batch_img})[0] |
| 161 | +
|
| 162 | + # plot the reconstructed images and their ground truths (inputs) |
| 163 | + plt.figure(1) |
| 164 | + plt.title('Reconstructed Images') |
| 165 | + for i in range(50): |
| 166 | + plt.subplot(5, 10, i+1) |
| 167 | + plt.imshow(recon_img[i, ..., 0], cmap='gray') |
| 168 | + plt.figure(2) |
| 169 | + plt.title('Input Images') |
| 170 | + for i in range(50): |
| 171 | + plt.subplot(5, 10, i+1) |
| 172 | + plt.imshow(batch_img[i, ..., 0], cmap='gray') |
| 173 | + plt.show() |
0 commit comments