Reparameterization trick


July 5th 2024

The reparameterization trick is a key technique in variational inference that enables the optimization of variational autoencoders (VAEs) and other models with continuous latent variables. It addresses the challenge of backpropagating gradients through stochastic sampling operations, which are non-differentiable. By reformulating the sampling process as a deterministic function of the parameters and a separate source of randomness, the reparameterization trick allows gradients to flow through the sampling operation and enables end-to-end training of VAEs using standard gradient-based optimization methods.

In a VAE, the objective is to maximize the evidence lower bound (ELBO), which involves an expectation over the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$. The ELBO can be written as:

$$\mathcal{L}(\phi, \theta; \mathbf{x}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}|\mathbf{x})\right]$$
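
For intuition, this objective can be rearranged (a standard identity, using $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})$) into a reconstruction term minus a KL divergence between the variational posterior and the prior:

$$\mathcal{L}(\phi, \theta; \mathbf{x}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log p_{\theta}(\mathbf{x}|\mathbf{z})\right] - D_{\mathrm{KL}}\big(q_{\phi}(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big)$$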

To optimize the ELBO with respect to the variational parameters $\phi$, we need to compute the gradient:

$$\nabla_{\phi} \mathcal{L}(\phi, \theta; \mathbf{x}) = \nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}|\mathbf{x})\right]$$

However, the expectation is taken with respect to the variational distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$, which itself depends on the parameters $\phi$. This makes the gradient computation challenging because the sampling operation $\mathbf{z} \sim q_{\phi}(\mathbf{z}|\mathbf{x})$ is non-differentiable: drawing a sample from a probability distribution is a stochastic operation with no well-defined gradient with respect to the parameters of that distribution. In a VAE, $q_{\phi}(\mathbf{z}|\mathbf{x})$ is typically parameterized by a neural network that outputs the mean and variance of a Gaussian, and the goal is to optimize these parameters to maximize the ELBO and learn a meaningful latent representation. We therefore need a reparameterization that lets us differentiate with respect to the mean vector and (diagonal) covariance, i.e., with respect to deterministic parameters rather than through a random variable.
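
To make this concrete, here is a minimal PyTorch sketch (not part of the original post) contrasting plain sampling with reparameterized sampling via torch.distributions: a plain sample() is detached from the distribution's parameters, so no gradient can reach them, while rsample() keeps a differentiable path.

import torch
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)
std = torch.exp(0.5 * log_var)

dist = Normal(mu, std)

# Plain sampling: the result is detached from mu and std,
# so backpropagation cannot reach the variational parameters.
z_plain = dist.sample()
print(z_plain.grad_fn)     # None -> no gradient path

# Reparameterized sampling (z = mu + std * eps): gradients flow.
z_reparam = dist.rsample()
print(z_reparam.grad_fn)   # not None -> differentiable w.r.t. mu and std

z_reparam.sum().backward()
print(mu.grad)             # gradient of the sum w.r.t. mu (all ones)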

The reparameterization trick overcomes this challenge by expressing the sampling operation as a deterministic function of the variational parameters and a separate source of randomness. Instead of sampling directly from $q_{\phi}(\mathbf{z}|\mathbf{x})$, we introduce a differentiable transformation $g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x})$ and a noise variable $\boldsymbol{\epsilon}$ such that:

$$\mathbf{z} = g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x}) \quad \text{with} \quad \boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})$$

The transformation $g_{\phi}$ is chosen such that the resulting $\mathbf{z}$ has the desired distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$. For example, if $q_{\phi}(\mathbf{z}|\mathbf{x})$ is a Gaussian distribution with mean $\boldsymbol{\mu}_{\phi}(\mathbf{x})$ and diagonal covariance $\boldsymbol{\sigma}^2_{\phi}(\mathbf{x})$, the reparameterization trick can be applied as follows:

$$\mathbf{z} = \boldsymbol{\mu}_{\phi}(\mathbf{x}) + \boldsymbol{\sigma}_{\phi}(\mathbf{x}) \odot \boldsymbol{\epsilon} \quad \text{with} \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

Here, $\odot$ denotes element-wise multiplication. By expressing $\mathbf{z}$ in this way, the sampling operation becomes deterministic with respect to $\phi$, and the randomness is isolated in the noise variable $\boldsymbol{\epsilon}$.
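
As a quick sanity check (a sketch with arbitrarily chosen example values, not from the original post), we can verify empirically that $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ indeed has the intended mean and standard deviation:

import torch

torch.manual_seed(0)

mu = torch.tensor([1.0, -2.0, 0.5])      # example mean (assumed values)
sigma = torch.tensor([0.5, 1.5, 2.0])    # example standard deviation (assumed values)

# Draw many reparameterized samples: z = mu + sigma * eps, eps ~ N(0, I)
eps = torch.randn(100_000, 3)
z = mu + sigma * eps

print(z.mean(dim=0))  # approximately [1.0, -2.0, 0.5]
print(z.std(dim=0))   # approximately [0.5, 1.5, 2.0]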

With the reparameterization trick, the gradient of the ELBO with respect to $\phi$ can be computed as:

$$\nabla_{\phi} \mathcal{L}(\phi, \theta; \mathbf{x}) = \mathbb{E}_{p(\boldsymbol{\epsilon})}\left[\nabla_{\phi} \left(\log p_{\theta}(\mathbf{x}, g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x})) - \log q_{\phi}(g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x})|\mathbf{x})\right)\right]$$

Now, the expectation is taken with respect to the distribution of the noise variable $p(\boldsymbol{\epsilon})$, which is independent of $\phi$. This allows the gradient to be estimated using Monte Carlo samples of $\boldsymbol{\epsilon}$ and backpropagated through the deterministic function $g_{\phi}$.
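
Below is a minimal sketch of this single-sample Monte Carlo estimator in PyTorch. The model is purely illustrative: a standard-normal prior $p(\mathbf{z})$ and a toy likelihood $\mathbf{x}|\mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$ are assumptions made here so that $\log p_{\theta}(\mathbf{x}, \mathbf{z})$ can be written explicitly; the point is that autograd differentiates through the reparameterized sample.

import torch
from torch.distributions import Normal

x = torch.tensor([0.8])                            # a single observed value (assumed)
mu = torch.tensor([0.0], requires_grad=True)       # variational mean
log_var = torch.tensor([0.0], requires_grad=True)  # variational log-variance
std = torch.exp(0.5 * log_var)

# Reparameterized sample: z = g_phi(eps, x) = mu + std * eps
eps = torch.randn_like(std)
z = mu + std * eps

# log p(x, z) = log p(x | z) + log p(z), with the toy likelihood x | z ~ N(z, 1)
log_p_xz = Normal(z, 1.0).log_prob(x) + Normal(0.0, 1.0).log_prob(z)
# log q_phi(z | x) for the Gaussian variational posterior
log_q = Normal(mu, std).log_prob(z)

# Single-sample Monte Carlo estimate of the negative ELBO
loss = -(log_p_xz - log_q)
loss.sum().backward()

print(mu.grad, log_var.grad)  # gradients flow back through the sampling step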

Here's a simple Python code snippet that demonstrates the reparameterization trick for a Gaussian distribution:

import torch

def reparameterize(mu, log_var):
    # Convert log-variance to standard deviation: sigma = exp(0.5 * log(sigma^2))
    std = torch.exp(0.5 * log_var)
    # Sample auxiliary noise eps ~ N(0, I) with the same shape as std
    eps = torch.randn_like(std)
    # Deterministic transformation: z = mu + sigma * eps
    return mu + eps * std

In this code, mu represents the mean $\boldsymbol{\mu}_{\phi}(\mathbf{x})$ and log_var represents the logarithm of the variance $\log \boldsymbol{\sigma}^2_{\phi}(\mathbf{x})$. The reparameterize function converts the log-variance to a standard deviation, samples a noise variable eps from a standard Gaussian, and returns the reparameterized latent variable z.
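
To illustrate how the reparameterize function defined above is typically used (a hypothetical minimal encoder with assumed layer sizes, not code from the original post): the encoder network outputs mu and log_var, reparameterize draws z, and the Gaussian KL term of the ELBO has a closed form in mu and log_var.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = torch.relu(self.hidden(x))
        return self.mu(h), self.log_var(h)

encoder = Encoder()
x = torch.randn(16, 784)             # a dummy batch (assumed shape)
mu, log_var = encoder(x)
z = reparameterize(mu, log_var)      # differentiable w.r.t. the encoder parameters

# Closed-form KL divergence between N(mu, sigma^2) and the standard normal prior
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
print(z.shape, kl.mean())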

The reparameterization trick is crucial for enabling gradient-based optimization of VAEs and other models with continuous latent variables. By expressing the sampling operation as a deterministic function of the parameters and a separate source of randomness, it allows gradients to flow through the sampling process and enables end-to-end training using standard backpropagation. This trick has been widely adopted in variational inference and has contributed to the success of VAEs and related models in various applications, including image generation, representation learning, and unsupervised learning tasks.