Variational Inference and the Evidence Lower Bound (ELBO)

July 5th 2024

This is a detailed tutorial on Variational Inference and the Evidence Lower Bound (ELBO), including mathematical derivations, intuitive explanations, and documented Python code to illustrate the key concepts.

Introduction

Variational Inference (VI) is a powerful technique in Bayesian machine learning used to approximate intractable posterior distributions. The main idea behind VI is to transform the complex problem of posterior inference into an optimization problem by introducing a family of simpler, tractable distributions and finding the one that best approximates the true posterior. This is achieved by maximizing a lower bound on the log evidence, known as the Evidence Lower Bound (ELBO).

Variational Inference Setup

Let's consider a generative model with latent variables $\mathbf{z}$ and observed variables $\mathbf{x}$. The joint probability distribution can be factorized as:

p(\mathbf{x}, \mathbf{z}) = p(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})

Our goal is to infer the posterior distribution $p(\mathbf{z} \mid \mathbf{x})$, which represents the probability of the latent variables given the observed data. However, computing the exact posterior is often intractable due to the normalization constant:

p(\mathbf{z} \mid \mathbf{x}) = \frac{p(\mathbf{x}, \mathbf{z})}{p(\mathbf{x})} = \frac{p(\mathbf{x}, \mathbf{z})}{\int p(\mathbf{x}, \mathbf{z})\, d\mathbf{z}}
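
For intuition (an illustrative aside, not part of the original derivation): if $\mathbf{z}$ were a vector of $N$ binary latent variables, the evidence would be a sum over all $2^N$ configurations,

p(\mathbf{x}) = \sum_{\mathbf{z} \in \{0,1\}^N} p(\mathbf{x}, \mathbf{z}),

which is already infeasible for moderately large $N$; for continuous $\mathbf{z}$, the integral generally has no closed form once $p(\mathbf{x} \mid \mathbf{z})$ is parameterized by a neural network.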

Variational Inference introduces a variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$ to approximate the true posterior. The goal is to find the optimal variational distribution $q_{\phi}^*(\mathbf{z}|\mathbf{x})$ that minimizes the Kullback-Leibler (KL) divergence between $q_{\phi}(\mathbf{z}|\mathbf{x})$ and $p(\mathbf{z} \mid \mathbf{x})$:

q_{\phi}^*(\mathbf{z}|\mathbf{x}) = \arg\min_{q_{\phi}(\mathbf{z}|\mathbf{x})} \mathrm{KL}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z} \mid \mathbf{x})\big)
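
As a concrete illustration (a minimal sketch, not part of the original tutorial code), a common choice for $q_{\phi}$ is a diagonal (mean-field) Gaussian whose parameters $\phi$ are a mean vector and a log-variance vector. The sketch below assumes a 2-dimensional latent space; in this simple, non-amortized form the parameters do not depend on $\mathbf{x}$, whereas the VAE shown later amortizes them by predicting them with an encoder network.

import torch
from torch.distributions import Normal

latent_dim = 2  # illustrative

# Variational parameters phi: one mean and one log-variance per latent dimension
mu = torch.zeros(latent_dim, requires_grad=True)
log_var = torch.zeros(latent_dim, requires_grad=True)

def q_phi():
    # q_phi(z) = N(mu, diag(exp(log_var)))
    return Normal(mu, torch.exp(0.5 * log_var))

# We can draw differentiable samples and evaluate log q_phi(z),
# which is exactly what appears in the ELBO derived in the next section.
z = q_phi().rsample()
log_q = q_phi().log_prob(z).sum()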

Evidence Lower Bound (ELBO)

The KL divergence between $q_{\phi}(\mathbf{z}|\mathbf{x})$ and $p(\mathbf{z} \mid \mathbf{x})$ can be written as:

\mathrm{KL}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z} \mid \mathbf{x})\big) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log q_{\phi}(\mathbf{z}|\mathbf{x}) - \log p(\mathbf{z} \mid \mathbf{x})\big]

However, this expression is still intractable: since $\log p(\mathbf{z} \mid \mathbf{x}) = \log p(\mathbf{x}, \mathbf{z}) - \log p(\mathbf{x})$, it contains the log evidence term $\log p(\mathbf{x})$, which is precisely the quantity we cannot compute. To circumvent this issue, we can derive the Evidence Lower Bound (ELBO) by applying Jensen's inequality:

\log p(\mathbf{x}) = \log \int p(\mathbf{x}, \mathbf{z})\, d\mathbf{z} = \log \int \frac{p(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z}|\mathbf{x})}\, q_{\phi}(\mathbf{z}|\mathbf{x})\, d\mathbf{z}

\geq \int q_{\phi}(\mathbf{z}|\mathbf{x}) \log \frac{p(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z}|\mathbf{x})}\, d\mathbf{z} = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log p(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}|\mathbf{x})\big] = \mathrm{ELBO}(q_{\phi})

The ELBO is a lower bound on the log evidence and can be written as:

\mathrm{ELBO}(q_{\phi}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log p(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}|\mathbf{x})\big]

By maximizing the ELBO with respect to the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$, we minimize the KL divergence between $q_{\phi}(\mathbf{z}|\mathbf{x})$ and $p(\mathbf{z} \mid \mathbf{x})$, thus finding the best approximation to the true posterior.
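
To make this statement precise, combining the definitions above gives the exact identity

\log p(\mathbf{x}) = \mathrm{ELBO}(q_{\phi}) + \mathrm{KL}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z} \mid \mathbf{x})\big)

Since $\log p(\mathbf{x})$ does not depend on $q_{\phi}$ and the KL divergence is non-negative, raising the ELBO lowers the KL divergence by exactly the same amount, and the bound is tight precisely when $q_{\phi}(\mathbf{z}|\mathbf{x}) = p(\mathbf{z} \mid \mathbf{x})$.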

In a Variational Autoencoder (VAE), the ELBO can be decomposed into a reconstruction term and a prior matching term:

\mathrm{ELBO}(q_{\phi}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log p_{\theta}(\mathbf{x}|\mathbf{z})\big] - D_{\mathrm{KL}}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big)
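
This decomposition follows in one step from the factorization $p(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})$:

\mathrm{ELBO}(q_{\phi}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log p_{\theta}(\mathbf{x}|\mathbf{z}) + \log p(\mathbf{z}) - \log q_{\phi}(\mathbf{z}|\mathbf{x})\big] = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\big[\log p_{\theta}(\mathbf{x}|\mathbf{z})\big] - D_{\mathrm{KL}}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big)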

The reconstruction term ensures that the learned latent variables can effectively reconstruct the original data, while the prior matching term encourages the variational distribution to stay close to a prior distribution, typically chosen to be a standard Gaussian.

Intuition and Interpretation

The ELBO consists of two terms:

  1. The expected log-likelihood of the data under the variational distribution: $\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})]$
    • This term encourages the variational distribution to place probability mass on latent variable configurations that explain the observed data well.
  2. The negative KL divergence between the variational distribution and the prior: $-D_{\mathrm{KL}}(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))$
    • This term acts as a regularizer, preventing the variational distribution from deviating too far from the prior distribution.

Maximizing the ELBO balances these two terms, finding a variational distribution that explains the data well while staying close to the prior.
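
To make this trade-off concrete, here is a minimal sketch (not part of the original tutorial code) using a toy conjugate model: prior $z \sim \mathcal{N}(0, 1)$, likelihood $x \mid z \sim \mathcal{N}(z, 0.5^2)$, and a single observation $x = 2$. All distributions and numbers are illustrative assumptions. The snippet compares the two ELBO terms for three candidate variational distributions; the exact posterior achieves the highest ELBO, balancing fit against the KL penalty.

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Toy model (illustrative): prior z ~ N(0, 1), likelihood x | z ~ N(z, 0.5^2), observation x = 2.0
prior = Normal(0.0, 1.0)
sigma_x = 0.5
x = torch.tensor(2.0)

def toy_elbo(q, num_samples=100_000):
    # Monte Carlo estimate of E_q[log p(x | z)] minus the closed-form KL(q || p(z))
    z = q.sample((num_samples,))
    recon = Normal(z, sigma_x).log_prob(x).mean()
    kl = kl_divergence(q, prior)
    return (recon - kl).item(), recon.item(), kl.item()

# Exact posterior of this conjugate model: N(x / (1 + sigma_x^2), sigma_x^2 / (1 + sigma_x^2))
post_mean = 2.0 / (1 + sigma_x**2)
post_std = (sigma_x**2 / (1 + sigma_x**2)) ** 0.5

candidates = {
    "q = prior (zero KL, poor fit)":      Normal(0.0, 1.0),
    "q peaked at x (good fit, large KL)": Normal(2.0, 0.1),
    "q = exact posterior":                Normal(post_mean, post_std),
}
for name, q in candidates.items():
    elbo_val, recon, kl = toy_elbo(q)
    print(f"{name}: ELBO = {elbo_val:.3f}, reconstruction = {recon:.3f}, KL = {kl:.3f}")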

For a more detailed explanation of the math behind the ELBO, VAEs, and HVAEs, please refer to the paper "Understanding Diffusion Models: A Unified Perspective" by Calvin Luo (pages 2 to 5).

Python Implementation

Here's a simple Python implementation illustrating the key concepts of Variational Inference and the ELBO using a Variational Autoencoder (VAE) as an example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * latent_dim)
        )
        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        # Reparameterization trick to sample from latent space
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode input to latent space
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        # Sample from latent space using reparameterization trick
        z = self.reparameterize(mu, logvar)
        # Decode latent vector to reconstruct input
        recon_x = self.decoder(z)
        return recon_x, mu, logvar

    def elbo(self, recon_x, x, mu, logvar):
        # Reconstruction loss (binary cross-entropy)
        recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL divergence between variational distribution and prior (standard Gaussian)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Evidence Lower Bound (ELBO)
        elbo = -recon_loss - kl_div
        return elbo

def train(model, optimizer, train_loader, epochs):
    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            elbo = model.elbo(recon_batch, data, mu, logvar)
            loss = -elbo  # Negative ELBO as loss
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}], ELBO: {elbo.item():.4f}")

        # Generate and display images after specific epochs
        if epoch + 1 in [1, 5, 10, 15, 20]:
            generated_images = generate_images(model)
            fig, axes = plt.subplots(2, 4, figsize=(8, 4))
            fig.suptitle(f"Generated Images - Epoch {epoch + 1}")
            for i, ax in enumerate(axes.flat):
                ax.imshow(generated_images[i].numpy(), cmap='gray')
                ax.axis('off')
            plt.tight_layout()
            plt.show()

def generate_images(model, num_images=8):
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        samples = model.decoder(z)
        samples = samples.view(num_images, 28, 28)
        return samples

# Hyperparameters
input_dim = 784
latent_dim = 20
epochs = 20
batch_size = 128
lr = 1e-3

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model and optimizer
model = VAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train the model
train(model, optimizer, train_loader, epochs)

Explanation:

The code implements a Variational Autoencoder (VAE) using PyTorch. The VAE consists of an encoder network that maps the input data to a latent space and a decoder network that reconstructs the input data from the latent space.

The VAE class defines the architecture of the encoder and decoder networks. The encoder network takes the input data (input_dim) and maps it to a latent space of dimension latent_dim. The decoder network takes a latent vector from the latent space and reconstructs the input data.

The reparameterize function implements the reparameterization trick, which allows for backpropagation through the sampling process. It takes the mean (mu) and log-variance (logvar) of the variational distribution and samples a latent vector z from this distribution.
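
In symbols, the trick writes the sample as a deterministic, differentiable function of the parameters, with the randomness isolated in an auxiliary noise variable:

\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \qquad \boldsymbol{\sigma} = \exp\!\big(\tfrac{1}{2}\,\texttt{logvar}\big)

so gradients of the ELBO can propagate to mu and logvar through $\mathbf{z}$, while eps carries no learnable parameters.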

The forward function performs the encoding and decoding process. It encodes the input data to the latent space, samples a latent vector using the reparameterization trick, and then decodes the latent vector to reconstruct the input data.

The elbo function computes the Evidence Lower Bound (ELBO), which consists of two terms: the reconstruction loss and the KL divergence between the variational distribution and the prior distribution (assumed to be a standard Gaussian). The reconstruction loss is calculated using binary cross-entropy, measuring the difference between the reconstructed data and the original input. The KL divergence acts as a regularization term, encouraging the variational distribution to be close to the prior distribution.
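
As a sanity check (a small sketch that is not part of the tutorial code), the analytic KL term used in elbo can be cross-checked against torch.distributions.kl_divergence for a diagonal Gaussian $q_{\phi}(\mathbf{z}|\mathbf{x})$ and a standard Gaussian prior; the tensor shapes below are illustrative:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 20)      # illustrative: batch of 4, latent_dim = 20
logvar = torch.randn(4, 20)

# Closed-form term exactly as written in VAE.elbo
kl_analytic = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# The same quantity computed with torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_library = kl_divergence(q, p).sum()

print(torch.allclose(kl_analytic, kl_library))  # expected: True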

The train function performs the training process. It iterates over the training data in batches, computes the ELBO for each batch, and updates the model parameters with the Adam optimizer using the negative ELBO as the loss; minimizing this loss is equivalent to maximizing the ELBO (stochastic gradient ascent on the ELBO). After specific epochs (1, 5, 10, 15, 20), the function generates and displays a grid of images using the generate_images function to visualize the progress of the VAE.

The generate_images function generates new images using the trained VAE. It samples latent vectors from a standard Gaussian distribution and passes them through the decoder network to generate images.

The code loads the MNIST dataset, which consists of handwritten digit images, using PyTorch's datasets and transforms modules. The dataset is divided into batches using a DataLoader.

During training, the code generates and displays a grid of images after epochs 1, 5, 10, 15, and 20. These images showcase the VAE's ability to generate new samples that resemble the training data.

The negation of the ELBO in the code (loss = -elbo) comes from the optimization setup. In variational inference the goal is to maximize the ELBO, while standard optimizers minimize a loss; minimizing the negative ELBO is therefore equivalent to maximizing the ELBO and finding the best approximation to the true posterior distribution.

The code connects to the math in the tutorial by implementing the key concepts of variational inference and the ELBO. The encoder network represents the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$, which approximates the true posterior distribution $p(\mathbf{z} \mid \mathbf{x})$. The decoder network represents the likelihood function $p_{\theta}(\mathbf{x} \mid \mathbf{z})$. The ELBO is computed as the expected log-likelihood (the negative of the reconstruction loss) minus the KL divergence between the variational distribution and the prior distribution. By maximizing the ELBO, the VAE learns to generate samples that resemble the training data while keeping the variational distribution close to the prior distribution.

Detailed Explanation

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * latent_dim)
        )
        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()
        )
  • The VAE class defines the architecture of the Variational Autoencoder (VAE). The encoder network maps the input data to the parameters of the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$, while the decoder network maps the latent variables $\mathbf{z}$ to the reconstructed input data.
    def reparameterize(self, mu, logvar):
        # Reparameterization trick to sample from latent space
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
  • The reparameterize function implements the reparameterization trick, which allows for backpropagation through the sampling process. It takes the mean mu and log-variance logvar of the variational distribution and samples a latent vector $\mathbf{z}$ from this distribution.
    def forward(self, x):
        # Encode input to latent space
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        # Sample from latent space using reparameterization trick
        z = self.reparameterize(mu, logvar)
        # Decode latent vector to reconstruct input
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
  • The forward function performs the encoding and decoding process. It encodes the input data x to the latent space using the encoder network, samples a latent vector z using the reparameterize function, and then decodes the latent vector to reconstruct the input data using the decoder network.
    def elbo(self, recon_x, x, mu, logvar):
        # Reconstruction loss (binary cross-entropy)
        recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL divergence between variational distribution and prior (standard Gaussian)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Evidence Lower Bound (ELBO)
        elbo = -recon_loss - kl_div
        return elbo
  • The elbo function computes the Evidence Lower Bound (ELBO) for the VAE. It consists of two terms:
    1. The reconstruction loss, which is the binary cross-entropy between the reconstructed data recon_x and the original input x. Its negative corresponds to the expected log-likelihood term $\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})]$ in the ELBO equation (binary cross-entropy is the negative log-likelihood under a Bernoulli decoder).
    2. The KL divergence between the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$ and the prior distribution $p(\mathbf{z})$, which is assumed to be a standard Gaussian. This enters the ELBO as the negative KL term $-D_{\mathrm{KL}}(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z}))$; the closed-form expression used in the code is given right after this list.
  • The ELBO is then computed as the negative sum of the reconstruction loss and the KL divergence.
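
For reference, the line kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) implements the closed-form KL divergence between a diagonal Gaussian and the standard Gaussian prior, summed over latent dimensions (and over the batch):

D_{\mathrm{KL}}\big(\mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2)) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{I})\big) = -\frac{1}{2} \sum_{j=1}^{d} \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)

where $\log \sigma_j^2$ corresponds to logvar in the code.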
def train(model, optimizer, train_loader, epochs):
    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            elbo = model.elbo(recon_batch, data, mu, logvar)
            loss = -elbo  # Negative ELBO as loss
            loss.backward()
            optimizer.step()
  • The train function performs the training process. It iterates over the training data in batches, computes the ELBO for each batch using the elbo function, and updates the model parameters with the optimizer using the negative ELBO as the loss, which is equivalent to maximizing the ELBO.
def generate_images(model, num_images=8):
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        samples = model.decoder(z)
        samples = samples.view(num_images, 28, 28)
        return samples
  • The generate_images function generates new images using the trained VAE. It samples latent vectors z from a standard Gaussian distribution and passes them through the decoder network to generate images.

The rest of the code sets up the hyperparameters, loads the MNIST dataset, instantiates the VAE model and optimizer, and trains the model using the train function. It also generates and displays images at specific epochs to visualize the progress of the VAE.

Overall, the code implements the key equations and concepts discussed in the math section, including the ELBO, the reconstruction term, the KL divergence term, and the reparameterization trick.

Output:

Epoch [1/20], ELBO: -17162.4336
[Figure: Generated Images - Epoch 1]
Epoch [2/20], ELBO: -14177.7568
Epoch [3/20], ELBO: -12621.7002
Epoch [4/20], ELBO: -11669.5605
Epoch [5/20], ELBO: -12150.1875
[Figure: Generated Images - Epoch 5]
Epoch [6/20], ELBO: -11669.9355
Epoch [7/20], ELBO: -11127.0879
Epoch [8/20], ELBO: -11047.6523
Epoch [9/20], ELBO: -11274.8857
Epoch [10/20], ELBO: -11104.8076
[Figure: Generated Images - Epoch 10]
Epoch [11/20], ELBO: -11278.0664
Epoch [12/20], ELBO: -11670.6182
Epoch [13/20], ELBO: -11034.0752
Epoch [14/20], ELBO: -10838.0742
Epoch [15/20], ELBO: -10617.1680
[Figure: Generated Images - Epoch 15]
Epoch [16/20], ELBO: -10899.2812
Epoch [17/20], ELBO: -10723.3271
Epoch [18/20], ELBO: -10925.9639
Epoch [19/20], ELBO: -10370.6025
Epoch [20/20], ELBO: -10978.4980
[Figure: Generated Images - Epoch 20]

Conclusion

Variational Inference and the Evidence Lower Bound (ELBO) provide a principled framework for approximating intractable posterior distributions in Bayesian models. By introducing a variational distribution and maximizing the ELBO, we can find the best approximation to the true posterior while balancing the trade-off between explaining the observed data and staying close to the prior distribution.

The provided Python code demonstrates a simple implementation of Variational Inference using a Variational Autoencoder as an example. This code can be adapted to other models and datasets to perform variational inference in various settings.

I hope this tutorial has provided you with a solid understanding of Variational Inference and the ELBO, both from a mathematical and intuitive perspective, along with practical code examples to illustrate the concepts. Let me know if you have any further questions!