July 5th 2024
This is a detailed tutorial on Variational Inference and the Evidence Lower Bound (ELBO), including mathematical derivations, intuitive explanations, and documented Python code to illustrate the key concepts.
Introduction
Variational Inference (VI) is a powerful technique in Bayesian machine learning used to approximate intractable posterior distributions. The main idea behind VI is to transform the complex problem of posterior inference into an optimization problem by introducing a family of simpler, tractable distributions and finding the one that best approximates the true posterior. This is achieved by maximizing a lower bound on the log evidence, known as the Evidence Lower Bound (ELBO).
Variational Inference Setup
Let's consider a generative model with latent variables $z$ and observed variables $x$. The joint probability distribution can be factorized as:

$$p(x, z) = p(x \mid z)\, p(z)$$

Our goal is to infer the posterior distribution $p(z \mid x)$, which represents the probability of the latent variables given the observed data. However, computing the exact posterior is often intractable due to the normalization constant $p(x)$:

$$p(z \mid x) = \frac{p(x, z)}{p(x)}, \qquad p(x) = \int p(x, z)\, dz$$

Variational Inference introduces a variational distribution $q(z)$ to approximate the true posterior. The goal is to find the variational distribution that minimizes the Kullback-Leibler (KL) divergence between $q(z)$ and $p(z \mid x)$:

$$q^{*}(z) = \arg\min_{q} \ \mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big)$$
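As a concrete, purely illustrative example of the quantity being minimized, the short PyTorch snippet below computes the KL divergence between one-dimensional Gaussians with torch.distributions (the parameters are made up); the closer a candidate q is to p, the smaller the divergence:

import torch
from torch.distributions import Normal, kl_divergence

# A "true" distribution p and two candidate approximations q1, q2 (made-up parameters)
p = Normal(loc=2.0, scale=0.5)
q1 = Normal(loc=0.0, scale=1.0)   # far from p
q2 = Normal(loc=1.9, scale=0.6)   # close to p

print(kl_divergence(q1, p))  # large KL: poor approximation
print(kl_divergence(q2, p))  # small KL: good approximation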
Evidence Lower Bound (ELBO)
The KL divergence between $q(z)$ and $p(z \mid x)$ can be written as:

$$\mathrm{KL}\big(q(z)\,\|\,p(z \mid x)\big) = \mathbb{E}_{q(z)}\big[\log q(z) - \log p(x, z)\big] + \log p(x)$$

However, this expression is still intractable due to the presence of the log evidence term $\log p(x)$. To circumvent this issue, we can derive the Evidence Lower Bound (ELBO) by applying Jensen's inequality:

$$\log p(x) = \log \int p(x, z)\, dz = \log \mathbb{E}_{q(z)}\!\left[\frac{p(x, z)}{q(z)}\right] \ge \mathbb{E}_{q(z)}\!\left[\log \frac{p(x, z)}{q(z)}\right]$$

The ELBO is a lower bound on the log evidence and can be written as:

$$\mathrm{ELBO}(q) = \mathbb{E}_{q(z)}\big[\log p(x, z) - \log q(z)\big] = \log p(x) - \mathrm{KL}\big(q(z)\,\|\,p(z \mid x)\big)$$

By maximizing the ELBO with respect to the variational distribution $q(z)$, we minimize the KL divergence between $q(z)$ and $p(z \mid x)$, since the gap between $\log p(x)$ and the ELBO is exactly this KL divergence; we thus find the best approximation to the true posterior.

In a Variational Autoencoder (VAE), the ELBO can be decomposed into a reconstruction term and a prior matching term:

$$\mathrm{ELBO}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)$$
The reconstruction term ensures that the learned latent variables can effectively reconstruct the original data, while the prior matching term encourages the variational distribution to stay close to a prior distribution, typically chosen to be a standard Gaussian.
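For the common case where $q_\phi(z \mid x)$ is a diagonal Gaussian $\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and the prior $p(z)$ is a standard Gaussian $\mathcal{N}(0, I)$, the prior matching term has a well-known closed form, which is exactly the expression computed in the code later in this tutorial:

$$\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

where $d$ is the dimensionality of the latent space.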
Intuition and Interpretation
The ELBO consists of two terms:
- The expected log likelihood of the data under the variational distribution: $\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]$
- This term encourages the variational distribution to place probability mass on latent variable configurations that explain the observed data well.
- The negative KL divergence between the variational distribution and the prior: $-\mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)$
- This term acts as a regularizer, preventing the variational distribution from deviating too far from the prior distribution.
Maximizing the ELBO balances these two terms, finding a variational distribution that explains the data well while staying close to the prior.
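To make the optimization view concrete, here is a minimal, self-contained sketch of variational inference on a toy model (the model, sample size, and optimizer settings are illustrative choices, separate from the VAE code below). A Gaussian q(z) is fitted by gradient-based maximization of a Monte Carlo estimate of the ELBO, and because the toy model is conjugate, the result can be compared to the exact posterior:

import torch

torch.manual_seed(0)

# Toy model: prior z ~ N(0, 1), likelihood x_i | z ~ N(z, 1), both variances known
N = 20
true_z = 2.0
x = true_z + torch.randn(N)

# Exact posterior for this conjugate model: N(sum(x) / (N + 1), 1 / (N + 1))
post_mean = x.sum() / (N + 1)
post_std = (1.0 / (N + 1)) ** 0.5

# Variational distribution q(z) = N(m, s^2), parameterized by m and log s
m = torch.zeros(1, requires_grad=True)
log_s = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([m, log_s], lr=0.05)

prior = torch.distributions.Normal(0.0, 1.0)

for step in range(2000):
    opt.zero_grad()
    q = torch.distributions.Normal(m, log_s.exp())
    z = q.rsample((64,))                                        # reparameterized samples
    log_lik = torch.distributions.Normal(z, 1.0).log_prob(x).sum(dim=-1)
    # Monte Carlo ELBO: E_q[log p(x | z)] - KL(q || p)
    elbo = log_lik.mean() - torch.distributions.kl_divergence(q, prior).sum()
    (-elbo).backward()
    opt.step()

print(f"variational: mean={m.item():.3f}, std={log_s.exp().item():.3f}")
print(f"exact:       mean={post_mean.item():.3f}, std={post_std:.3f}")

After optimization, the fitted mean and standard deviation should land close to the exact posterior values, illustrating that maximizing the ELBO drives q(z) toward the true posterior.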
For a more detailed explanation of the math behind the ELBO, VAEs, and HVAEs, please refer to the paper "Understanding Diffusion Models: A Unified Perspective" by Calvin Luo (pages 2 to 5).
Python Implementation
Here's a simple Python implementation illustrating the key concepts of Variational Inference and the ELBO using a Variational Autoencoder (VAE) as an example:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * latent_dim)
        )
        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        # Reparameterization trick to sample from latent space
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode input to latent space
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        # Sample from latent space using reparameterization trick
        z = self.reparameterize(mu, logvar)
        # Decode latent vector to reconstruct input
        recon_x = self.decoder(z)
        return recon_x, mu, logvar

    def elbo(self, recon_x, x, mu, logvar):
        # Reconstruction loss (binary cross-entropy)
        recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL divergence between variational distribution and prior (standard Gaussian)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Evidence Lower Bound (ELBO)
        elbo = -recon_loss - kl_div
        return elbo


def train(model, optimizer, train_loader, epochs):
    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            elbo = model.elbo(recon_batch, data, mu, logvar)
            loss = -elbo  # Negative ELBO as loss
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch+1}/{epochs}], ELBO: {elbo.item():.4f}")

        # Generate and display images after specific epochs
        if epoch + 1 in [1, 5, 10, 15, 20]:
            generated_images = generate_images(model)
            fig, axes = plt.subplots(2, 4, figsize=(8, 4))
            fig.suptitle(f"Generated Images - Epoch {epoch + 1}")
            for i, ax in enumerate(axes.flat):
                ax.imshow(generated_images[i].numpy(), cmap='gray')
                ax.axis('off')
            plt.tight_layout()
            plt.show()


def generate_images(model, num_images=8):
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        samples = model.decoder(z)
        samples = samples.view(num_images, 28, 28)
        return samples


# Hyperparameters
input_dim = 784
latent_dim = 20
epochs = 20
batch_size = 128
lr = 1e-3

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model and optimizer
model = VAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train the model
train(model, optimizer, train_loader, epochs)
Explanation:
The code implements a Variational Autoencoder (VAE) using PyTorch. The VAE consists of an encoder network that maps the input data to a latent space and a decoder network that reconstructs the input data from the latent space.
The VAE class defines the architecture of the encoder and decoder networks. The encoder network takes the input data (of dimension input_dim) and maps it to a latent space of dimension latent_dim. The decoder network takes a latent vector from the latent space and reconstructs the input data.

The reparameterize function implements the reparameterization trick, which allows for backpropagation through the sampling process. It takes the mean (mu) and log-variance (logvar) of the variational distribution and samples a latent vector z from this distribution.

The forward function performs the encoding and decoding process. It encodes the input data to the latent space, samples a latent vector using the reparameterization trick, and then decodes the latent vector to reconstruct the input data.

The elbo function computes the Evidence Lower Bound (ELBO), which consists of two terms: the reconstruction loss and the KL divergence between the variational distribution and the prior distribution (assumed to be a standard Gaussian). The reconstruction loss is calculated using binary cross-entropy, measuring the difference between the reconstructed data and the original input. The KL divergence acts as a regularization term, encouraging the variational distribution to stay close to the prior distribution.
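A brief note on why binary cross-entropy plays the role of the (negative) expected log-likelihood: if the decoder is interpreted as parameterizing a Bernoulli distribution over each pixel, the binary cross-entropy between the decoder output and the input is exactly the negative Bernoulli log-likelihood. The tiny sketch below (with made-up tensors, not part of the VAE code) checks this numerically; note that MNIST pixels produced by ToTensor lie in [0, 1] rather than being strictly binary, so applying BCE to them is a common practical relaxation:

import torch
import torch.nn.functional as F

# Made-up decoder outputs (pixel probabilities) and binary targets
recon_x = torch.tensor([[0.9, 0.2, 0.7]])
x = torch.tensor([[1.0, 0.0, 1.0]])

bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
neg_log_lik = -torch.distributions.Bernoulli(probs=recon_x).log_prob(x).sum()
print(bce, neg_log_lik)  # the two values agree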
The train function performs the training process. It iterates over the training data in batches, computes the ELBO for each batch, and updates the model parameters with the Adam optimizer, using the negative ELBO as the loss; minimizing this loss is equivalent to maximizing the ELBO. After specific epochs (1, 5, 10, 15, 20), the function generates and displays a grid of images using the generate_images function to visualize the progress of the VAE.

The generate_images function generates new images using the trained VAE. It samples latent vectors from a standard Gaussian distribution and passes them through the decoder network to generate images.

The code loads the MNIST dataset, which consists of handwritten digit images, using PyTorch's datasets and transforms modules. The dataset is divided into batches using a DataLoader.
During training, the code generates and displays a grid of images after epochs 1, 5, 10, 15, and 20. These images showcase the VAE's ability to generate new samples that resemble the training data.
The sign flip in the code (loss = -elbo) comes from the optimization setup. In variational inference, the goal is to maximize the ELBO, which is equivalent to minimizing the negative ELBO. By minimizing the negative ELBO with a standard gradient-based optimizer, we are effectively maximizing the ELBO and finding the best approximation to the true posterior distribution.

The code connects to the math in the tutorial by implementing the key concepts of variational inference and the ELBO. The encoder network represents the variational distribution $q_\phi(z \mid x)$, which approximates the true posterior distribution $p(z \mid x)$. The decoder network represents the likelihood function $p_\theta(x \mid z)$. The ELBO is computed as the expected log-likelihood (the negative reconstruction loss) minus the KL divergence between the variational distribution and the prior distribution. By maximizing the ELBO, the VAE learns to generate samples that resemble the training data while keeping the variational distribution close to the prior distribution.
Detailed Explanation
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * latent_dim)
        )
        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()
        )
- The VAE class defines the architecture of the Variational Autoencoder (VAE). The encoder network maps the input data $x$ to the parameters of the variational distribution $q_\phi(z \mid x)$, while the decoder network maps the latent variables $z$ to the reconstructed input data. A quick shape check is shown below.
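As a quick sanity check (a hypothetical snippet that assumes the VAE class and imports from the full script above), passing a dummy batch through the model shows how the encoder output of size 2 * latent_dim is split into mu and logvar:

import torch

# Hypothetical shape check with a fake batch of 4 flattened 28x28 images
model = VAE(input_dim=784, latent_dim=20)
dummy = torch.rand(4, 784)
recon, mu, logvar = model(dummy)
print(recon.shape, mu.shape, logvar.shape)
# torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])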
def reparameterize(self, mu, logvar):
    # Reparameterization trick to sample from latent space
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
- The reparameterize function implements the reparameterization trick, which allows for backpropagation through the sampling process. It takes the mean mu and log-variance logvar of the variational distribution and samples a latent vector from this distribution, as written out in the equation below.
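In equation form, rather than sampling $z$ directly from $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$, the trick expresses the sample as a deterministic function of $\mu$, $\sigma$, and auxiliary noise $\epsilon$, so that gradients can flow back through $\mu$ and $\sigma$ (and hence through the encoder):

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^2\right)$$

The last identity is why the code computes std = torch.exp(0.5 * logvar).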
def forward(self, x):
    # Encode input to latent space
    h = self.encoder(x)
    mu, logvar = h.chunk(2, dim=1)
    # Sample from latent space using reparameterization trick
    z = self.reparameterize(mu, logvar)
    # Decode latent vector to reconstruct input
    recon_x = self.decoder(z)
    return recon_x, mu, logvar
- The forward function performs the encoding and decoding process. It encodes the input data x to the latent space using the encoder network, samples a latent vector z using the reparameterize function, and then decodes the latent vector to reconstruct the input data using the decoder network.
def elbo(self, recon_x, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy)
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between variational distribution and prior (standard Gaussian)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Evidence Lower Bound (ELBO)
    elbo = -recon_loss - kl_div
    return elbo
- The elbo function computes the Evidence Lower Bound (ELBO) for the VAE. It consists of two terms:
- The reconstruction loss, which is the binary cross-entropy between the reconstructed data recon_x and the original input x. This corresponds to the expected log-likelihood term $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ in the ELBO equation.
- The KL divergence between the variational distribution $q_\phi(z \mid x)$ and the prior distribution $p(z)$, which is assumed to be a standard Gaussian. This corresponds to the KL divergence term (entering with a negative sign) in the ELBO equation.
- The ELBO is then computed as the negative of the sum of the reconstruction loss and the KL divergence; a numerical check of the KL term follows below.
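As a numerical check (an illustrative snippet with made-up mu and logvar, not part of the original code), the closed-form KL expression used in elbo matches PyTorch's kl_divergence for diagonal Gaussians against a standard normal prior:

import torch
from torch.distributions import Normal, kl_divergence

# Made-up parameters of q(z|x) for a batch of 3 examples with latent dimension 2
mu = torch.tensor([[0.5, -1.0], [0.0, 0.3], [2.0, 0.1]])
logvar = torch.tensor([[0.0, -0.5], [0.2, 0.0], [-1.0, 0.4]])

# Closed-form expression used in elbo()
kl_manual = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_lib = kl_divergence(q, p).sum()

print(kl_manual, kl_lib)  # the two values agree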
def train(model, optimizer, train_loader, epochs):
    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            elbo = model.elbo(recon_batch, data, mu, logvar)
            loss = -elbo  # Negative ELBO as loss
            loss.backward()
            optimizer.step()
- The train function performs the training process. It iterates over the training data in batches, computes the ELBO for each batch using the elbo function, and updates the model parameters with the Adam optimizer. The negative ELBO is used as the loss function, so minimizing the loss is equivalent to maximizing the ELBO.
def generate_images(model, num_images=8):
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        samples = model.decoder(z)
        samples = samples.view(num_images, 28, 28)
        return samples
- The generate_images function generates new images using the trained VAE. It samples latent vectors z from a standard Gaussian distribution (matching the prior) and passes them through the decoder network to generate images. Note that it reads latent_dim from the script's global hyperparameters.
The rest of the code sets up the hyperparameters, loads the MNIST dataset, instantiates the VAE model and optimizer, and trains the model using the train function. It also generates and displays images at specific epochs to visualize the progress of the VAE.
Overall, the code implements the key equations and concepts discussed in the math section, including the ELBO, the reconstruction term, the KL divergence term, and the reparameterization trick.
Output:
Epoch [1/20], ELBO: -17162.4336
Epoch [2/20], ELBO: -14177.7568
Epoch [3/20], ELBO: -12621.7002
Epoch [4/20], ELBO: -11669.5605
Epoch [5/20], ELBO: -12150.1875
Epoch [6/20], ELBO: -11669.9355
Epoch [7/20], ELBO: -11127.0879
Epoch [8/20], ELBO: -11047.6523
Epoch [9/20], ELBO: -11274.8857
Epoch [10/20], ELBO: -11104.8076
Epoch [11/20], ELBO: -11278.0664
Epoch [12/20], ELBO: -11670.6182
Epoch [13/20], ELBO: -11034.0752
Epoch [14/20], ELBO: -10838.0742
Epoch [15/20], ELBO: -10617.1680
Epoch [16/20], ELBO: -10899.2812
Epoch [17/20], ELBO: -10723.3271
Epoch [18/20], ELBO: -10925.9639
Epoch [19/20], ELBO: -10370.6025
Epoch [20/20], ELBO: -10978.4980
Conclusion
Variational Inference and the Evidence Lower Bound (ELBO) provide a principled framework for approximating intractable posterior distributions in Bayesian models. By introducing a variational distribution and maximizing the ELBO, we can find the best approximation to the true posterior while balancing the trade-off between explaining the observed data and staying close to the prior distribution.
The provided Python code demonstrates a simple implementation of Variational Inference using a Variational Autoencoder as an example. This code can be adapted to other models and datasets to perform variational inference in various settings.
I hope this tutorial has provided you with a solid understanding of Variational Inference and the ELBO, both from a mathematical and intuitive perspective, along with practical code examples to illustrate the concepts. Let me know if you have any further questions!