Demystifying Neural Networks: Variational AutoEncoders

Dagang Wei on 2024-03-27


This article is part of the series Demystifying Neural Networks.

Introduction

In the realm of deep learning, autoencoders are well known for their ability to compress and reconstruct data. Their goal is to learn efficient representations of input data, usually for purposes like dimensionality reduction or feature extraction. Enter the Variational Autoencoder (VAE), a close relative of the traditional autoencoder, but one infused with a touch of probabilistic magic. This special twist turns VAEs into powerful generative models that can create new data, not just reconstruct existing inputs. Let’s dive in!

What is a Variational Autoencoder (VAE)?

A variational autoencoder is a type of generative neural network architecture. At its heart, a VAE still has the same structural components as a traditional autoencoder: an encoder and a decoder.

Why Use a VAE?

VAE vs. Autoencoder: Differences

Variational Autoencoder (VAE)

Autoencoder (AE)

How Does a VAE Work?

  1. The Distribution Trick: Instead of mapping each input to a single point in the latent space, the encoder of a VAE outputs the parameters of a probability distribution (usually a mean and a variance, or the logarithm of the variance). During training, we sample a point from this distribution and feed it into the decoder.
  2. The Reparameterization Trick: This is the clever part. Gradients cannot be backpropagated directly through a random sampling step. The reparameterization trick expresses the sampled latent point as a deterministic function of the distribution parameters and an external random variable, so gradients can flow from the decoder back into the encoder during training.
  3. Loss Function: Beyond Reconstruction: The VAE’s loss function has two parts: a reconstruction loss, which measures how closely the decoder’s output matches the original input, and a KL divergence term, which pushes each encoded distribution toward a standard normal prior N(0, I) and keeps the latent space organized.
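As a rough sketch of how the two parts combine for one batch, the NumPy function below mirrors the loss used in the Keras example later in this article: a summed squared reconstruction error per example plus the closed-form KL divergence between a diagonal Gaussian N(μ, σ²) and the standard normal prior. The function name `vae_loss` and the log-variance inputs are assumptions chosen to match that example, not a fixed API.

```python
import numpy as np

def vae_loss(x, x_recon, z_mean, z_log_var):
    """Batch-averaged VAE loss: reconstruction error + KL divergence.

    x, x_recon:        arrays of shape (batch, n_features)
    z_mean, z_log_var: arrays of shape (batch, latent_dim), where
                       z_log_var = log(sigma^2)
    """
    # Reconstruction term: summed squared error per example
    reconstruction = np.sum((x - x_recon) ** 2, axis=-1)
    # KL(N(mu, sigma^2) || N(0, 1)) in closed form, summed over latent dims
    kl = -0.5 * np.sum(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1)
    # Average the per-example total over the batch
    return np.mean(reconstruction + kl)

# Usage: a perfect reconstruction whose posterior equals the prior costs nothing
x = np.random.default_rng(0).random((4, 784))
print(vae_loss(x, x, np.zeros((4, 2)), np.zeros((4, 2))))  # 0.0
```

Moving the encoded mean away from zero, or shrinking or inflating the variance away from one, makes the KL term positive even when the reconstruction is perfect; that is the pressure that keeps the latent space close to the prior.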

The Reparameterization Trick

The core difficulty lies in backpropagation, the technique used to train neural networks. Backpropagation computes the gradient of the loss with respect to each of the network’s parameters, so the parameters can be adjusted to reduce the difference between the predicted output and the actual output. This requires every step between the parameters and the loss to be differentiable, and a random sampling step, which is inherent in VAEs, is not.

The Traditional (Broken) Approach

Imagine the encoder outputs the mean (μ) and standard deviation (σ) of a Gaussian distribution. We then sample a random variable from the distribution as:

z = random_sample(μ, σ)

This seems straightforward, but there’s a catch. Backpropagation needs gradients of the loss with respect to the encoder’s weights, and those gradients must pass through μ and σ. The sampling step, however, is not a deterministic function of μ and σ: calling it twice with the same inputs gives different results, so the derivative of z with respect to μ or σ is not defined. The gradient chain breaks at z, and the encoder cannot be trained this way.

The reparameterization trick offers an elegant solution. Instead of sampling z directly from its distribution N(μ, σ²), we decompose z into a deterministic component and a stochastic component that is independent of the parameters we want to optimize.

Here’s the magic:

  1. Introduce a new random variable: We introduce an independent random variable ε, typically drawn from a standard normal distribution N(0, 1). ε serves as a source of randomness that is entirely separate from the encoder and its parameters.
  2. Reparameterize the latent variable: Instead of sampling z directly, we scale ε by the standard deviation σ predicted by the encoder and shift it by the mean μ, using element-wise operations:

z = μ + σ * ε

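This looks similar to sampling from N(μ, σ²) directly, but there’s a crucial difference. Both μ and σ are deterministic outputs of the encoder, and all the randomness lives in ε, which does not depend on any trainable parameter. For a fixed draw of ε, z is an ordinary differentiable function of μ and σ, with ∂z/∂μ = 1 and ∂z/∂σ = ε, so backpropagation can pass through it.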
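To make this concrete, the short NumPy sketch below draws ε once, holds it fixed, and checks by central finite differences that ∂z/∂μ = 1 and ∂z/∂σ = ε. The helper names `reparameterize` and `finite_diff` are hypothetical, used only for illustration; frameworks such as TensorFlow compute these derivatives automatically.

```python
import numpy as np

def reparameterize(mu, sigma, eps):
    # z = mu + sigma * eps: deterministic in (mu, sigma) once eps is drawn
    return mu + sigma * eps

def finite_diff(f, x, h=1e-6):
    # Central difference approximation of df/dx
    return (f(x + h) - f(x - h)) / (2 * h)

rng = np.random.default_rng(42)
eps = rng.standard_normal()  # drawn once, independently of mu and sigma
mu, sigma = 0.5, 2.0

dz_dmu = finite_diff(lambda m: reparameterize(m, sigma, eps), mu)
dz_dsigma = finite_diff(lambda s: reparameterize(mu, s, eps), sigma)

print(dz_dmu)          # approximately 1.0
print(dz_dsigma, eps)  # the two values approximately agree
```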

Why Does This Work?

Even though the new equation introduces a new variable (ε), it still produces samples from the original Gaussian distribution: if ε ~ N(0, 1), then μ + σε ~ N(μ, σ²). Scaling ε by σ and shifting it by μ injects randomness while preserving exactly the mean (μ) and standard deviation (σ) defined by the encoder.
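The claim can be checked with the linearity of expectation and the scaling rule for variance (for a vector-valued latent, apply it to each dimension):

```latex
\begin{aligned}
\mathbb{E}[z] &= \mu + \sigma\,\mathbb{E}[\varepsilon] = \mu, \\
\operatorname{Var}[z] &= \sigma^2 \operatorname{Var}[\varepsilon] = \sigma^2,
\end{aligned}
\qquad \varepsilon \sim \mathcal{N}(0, 1).
```

Since an affine function of a Gaussian variable is itself Gaussian, matching the mean and variance is enough to conclude that z ~ N(μ, σ²).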

Benefits:

The reparameterization trick is a brilliant illustration of how a seemingly simple mathematical transformation can have a profound impact on the training process of complex neural network architectures like VAEs.
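One way to see the payoff numerically: because z is a differentiable function of (μ, σ), the gradient of an expectation such as E[z²] can be estimated by averaging ordinary chain-rule derivatives over samples of ε. For this toy objective the exact answer is known, E[z²] = μ² + σ², so the true gradients are 2μ and 2σ. The objective, parameter values, and sample size below are arbitrary choices for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
eps = rng.standard_normal(1_000_000)  # noise drawn independently of mu, sigma

z = mu + sigma * eps                  # reparameterized samples
# Chain rule on f(z) = z**2: df/dmu = 2z * 1, df/dsigma = 2z * eps
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)

print(grad_mu, 2 * mu)        # Monte Carlo estimate vs exact value
print(grad_sigma, 2 * sigma)  # Monte Carlo estimate vs exact value
```

In a VAE the same idea applies with the decoder and loss in place of z²: automatic differentiation computes the per-sample derivatives, and averaging over a mini-batch gives the gradient estimate used for training.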

Example

The following is an example implementation of a VAE on the MNIST dataset. The code is available in this Colab notebook.

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import layers, models, losses
from keras.datasets import mnist

# Load and prepare the MNIST dataset
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255. 
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

# VAE Architecture 
latent_dim = 2  # Dimension of the latent space

# Encoder 
encoder_inputs = layers.Input(shape=(784,)) 
x = layers.Dense(256, activation='relu')(encoder_inputs)
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x) 

# Reparameterization trick, implemented as a layer that also registers the KL term
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I):
        # summed over the latent dimensions, averaged over the batch
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
        self.add_loss(tf.reduce_mean(kl_loss))
        # z = mu + sigma * epsilon, with sigma = exp(log_var / 2) and epsilon ~ N(0, I)
        epsilon = tf.random.normal(shape=tf.shape(z_mean), mean=0., stddev=1.)
        return z_mean + tf.exp(z_log_var / 2) * epsilon

z = Sampling()([z_mean, z_log_var])
encoder = models.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

# Decoder
latent_inputs = layers.Input(shape=(latent_dim,))
x = layers.Dense(256, activation='relu')(latent_inputs)
decoder_outputs = layers.Dense(784, activation='sigmoid')(x)
decoder = models.Model(latent_inputs, decoder_outputs, name='decoder')

# VAE model (combining encoder and decoder)
vae_outputs = decoder(encoder(encoder_inputs)[2])
vae = models.Model(encoder_inputs, vae_outputs, name='vae')

# Reconstruction loss: per-pixel MSE rescaled by 784 (the number of pixels),
# i.e. the summed squared error per image. The KL term is added by Sampling.
def reconstruction_loss(y_true, y_pred):
    return 784 * losses.mean_squared_error(y_true, y_pred)

# Training
vae.compile(optimizer='adam', loss=reconstruction_loss)
vae.fit(x_train, x_train, epochs=30, batch_size=128)

# Test and Visualization
n_to_visualize = 3
digit_size = 28
figure = np.zeros((digit_size * 2, digit_size * n_to_visualize))  # Two rows for original and reconstructed

# Choose specific test images
test_image_indices = [1, 3, 7]  # Feel free to change these indices
images = x_test[test_image_indices]

# Generate reconstructions
_, _, encoded_images = encoder.predict(images)  
decoded_images = decoder.predict(encoded_images) 

for i, idx in enumerate(test_image_indices):
    # Original image
    figure[0:digit_size, i * digit_size:(i + 1) * digit_size] = images[i].reshape(28, 28)

    # Reconstructed image
    figure[digit_size:, i * digit_size:(i + 1) * digit_size] = decoded_images[i].reshape(28, 28)

plt.figure(figsize=(10, 4))  # Adjust figure size
plt.suptitle('Original vs Reconstructed')  # Add a title
for i in range(n_to_visualize):
    plt.subplot(2, n_to_visualize, i + 1) 
    plt.imshow(figure[0:digit_size, i*digit_size :(i+1)*digit_size], cmap='gray')
    plt.xticks([])  # Remove ticks
    plt.yticks([])

    plt.subplot(2, n_to_visualize, n_to_visualize + i + 1)  # Subplot for reconstruction
    plt.imshow(figure[digit_size:, i*digit_size :(i+1)*digit_size], cmap='gray')
    plt.xticks([])
    plt.yticks([])

plt.show()

Conclusion

Variational autoencoders provide a fascinating blend of autoencoders and probabilistic modeling. Their ability to learn organized latent spaces and their generative prowess make them a remarkably versatile tool in the vast landscape of deep learning.
