Reparameterization trick


July 5th 2024

The reparameterization trick is a key technique in variational inference that enables the optimization of variational autoencoders (VAEs) and other models with continuous latent variables. It addresses the challenge of backpropagating gradients through stochastic sampling operations, which are non-differentiable. By reformulating the sampling process as a deterministic function of the parameters and a separate source of randomness, the reparameterization trick allows gradients to flow through the sampling operation and enables end-to-end training of VAEs using standard gradient-based optimization methods.

In a VAE, the objective is to maximize the evidence lower bound (ELBO), which involves an expectation over the variational distribution $q_\phi(z|x)$. The ELBO can be written as:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\left(q_\phi(z|x) \,\|\, p(z)\right)$$
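For the common choice of a standard normal prior $p(z) = \mathcal{N}(0, I)$ and a diagonal Gaussian $q_\phi(z|x)$ (the setting assumed later in this post), the KL term has a closed form, so only the expectation term requires sampling. Here is a minimal sketch under that assumption (the helper name gaussian_kl is ours, not part of any library):

import torch

def gaussian_kl(mu, log_var):
    # KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over the latent dimensions
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)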

To optimize the ELBO with respect to the variational parameters $\phi$, we need to compute the gradient of the expectation term:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]$$

However, the expectation is taken with respect to the variational distribution $q_\phi(z|x)$, which itself depends on the parameters $\phi$. This makes the gradient computation challenging because the sampling operation is non-differentiable: drawing a sample from a probability distribution is a stochastic operation that has no well-defined gradient with respect to the parameters of that distribution. In the context of VAEs, the variational distribution $q_\phi(z|x)$ is parameterized by $\phi$, typically the weights of a neural network that outputs the mean and variance of a Gaussian distribution. The goal is to optimize these parameters to maximize the ELBO and learn a meaningful latent representation. We therefore need a reparameterization that lets us differentiate through the mean vector and (diagonal) covariance matrix, i.e. through deterministic parameters rather than through a random variable.
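To see the issue concretely, here is a small PyTorch sketch (our own illustration, not part of the original derivation) contrasting plain sampling, which cuts the computation graph, with PyTorch's built-in reparameterized sampling rsample, which is exactly the trick described next:

import torch
from torch.distributions import Normal

mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)

z = Normal(mu, sigma).sample()    # plain sampling: no gradient path back to mu, sigma
print(z.requires_grad)            # False

z = Normal(mu, sigma).rsample()   # reparameterized sampling: z = mu + sigma * eps
print(z.requires_grad)            # True -- gradients can flow to mu and sigma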

The reparameterization trick overcomes this challenge by expressing the sampling operation as a deterministic function of the variational parameters and a separate source of randomness. Instead of sampling $z$ directly from $q_\phi(z|x)$, we introduce a differentiable transformation $g_\phi$ and a noise variable $\epsilon \sim p(\epsilon)$ such that:

$$z = g_\phi(\epsilon, x)$$

The transformation $g_\phi$ is chosen such that the resulting $z$ has the desired distribution $q_\phi(z|x)$. For example, if $q_\phi(z|x)$ is a Gaussian distribution with mean $\mu_\phi(x)$ and diagonal covariance $\sigma^2_\phi(x)$, the reparameterization trick can be applied as follows:

$$z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Here, $\odot$ denotes element-wise multiplication. By expressing $z$ in this way, the sampling operation becomes deterministic with respect to $\phi$, and the randomness is isolated in the noise variable $\epsilon$.

With the reparameterization trick, the gradient of the ELBO with respect to $\phi$ can be computed as:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\left[\nabla_\phi \log p_\theta\big(x \mid g_\phi(\epsilon, x)\big)\right]$$

Now, the expectation is taken with respect to the distribution of the noise variable $\epsilon$, which is independent of $\phi$. This allows the gradient to be estimated using Monte Carlo samples of $\epsilon$ and backpropagated through the deterministic function $g_\phi$.
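As a small self-contained check (our own example, not from the post): for $z = \mu + \sigma\epsilon$ we have $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the exact gradients with respect to $\mu$ and $\sigma$ are $2\mu$ and $2\sigma$. A Monte Carlo estimate through the reparameterized samples recovers them:

import torch

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.8, requires_grad=True)

eps = torch.randn(100_000)          # eps ~ N(0, 1), independent of mu and sigma
z = mu + sigma * eps                # reparameterized samples
loss = (z ** 2).mean()              # Monte Carlo estimate of E[z^2]
loss.backward()

print(mu.grad)     # ~ 3.0  (= 2 * mu)
print(sigma.grad)  # ~ 1.6  (= 2 * sigma)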

Here's a simple Python code snippet that demonstrates the reparameterization trick for a Gaussian distribution:

import torch

def reparameterize(mu, log_var):
    # z ~ N(mu, diag(exp(log_var))), expressed as a deterministic function of mu and log_var
    std = torch.exp(0.5 * log_var)   # sigma = exp(log_var / 2)
    eps = torch.randn_like(std)      # eps ~ N(0, I), independent of the parameters
    return mu + eps * std            # z = mu + sigma * eps

In this code, mu represents the mean $\mu$ and log_var represents the logarithm of the variance $\log \sigma^2$. The reparameterize function takes these parameters, samples a noise variable eps from a standard Gaussian distribution, and returns the reparameterized latent variable z.
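As a quick usage sketch (the shapes and the stand-in loss below are ours; in a real VAE, mu and log_var would be the encoder's outputs rather than leaf tensors), gradients flow from z all the way back to mu and log_var, which is what makes end-to-end training possible:

mu = torch.zeros(4, 8, requires_grad=True)        # e.g. a batch of 4 latent means, latent dim 8
log_var = torch.zeros(4, 8, requires_grad=True)   # corresponding log-variances

z = reparameterize(mu, log_var)
loss = (z ** 2).sum()    # stand-in for a real VAE loss (reconstruction + KL)
loss.backward()

print(mu.grad.shape, log_var.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])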

The reparameterization trick is crucial for enabling gradient-based optimization of VAEs and other models with continuous latent variables. By expressing the sampling operation as a deterministic function of the parameters and a separate source of randomness, it allows gradients to flow through the sampling process and enables end-to-end training using standard backpropagation. This trick has been widely adopted in variational inference and has contributed to the success of VAEs and related models in various applications, including image generation, representation learning, and unsupervised learning tasks.