Variational Inference and Reparameterization, Step by Step

I first came across the reparameterization trick when I learned about the Variational Autoencoder (VAE). Shamefully, it took me several years from that first encounter to fully internalize why the trick is needed, and how I can turn an inference problem into one solved by variational inference. Here, I write about my general approaches to inference, with a focus on Variational Inference, and give a detailed explanation of why the reparameterization trick is needed.

Refreshers on some well-known models

In a general framework for inference, we observe some data $\mathbf{X}$, which we assume to be generated by some latent variables and/or parameters $\theta$, such that we can write down a formula for the likelihood $P(\mathbf{X}|\theta)$. $\theta$ itself is generated from a certain distribution, the prior $P(\theta)$. Whether $\theta$ consists of latent variables or parameters, the goal of inference is usually to find the distribution (or the values) of $\theta$ given the observed data, i.e. the posterior $P(\theta|\mathbf{X})$. Example cases of this framework are:

Four general approaches to inference

What these models have in common is that they involve observed data $\mathbf{X}$ and one or more layers of latent variables $\theta$. What I learned over the years of connecting the dots (sometimes very inefficiently, as part of my interdisciplinary training) is that we almost always use one of the following approaches for inference (i.e. finding $P(\theta|\mathbf{X})$).

Note that although these general approaches can act like recipes, our job as practitioners is to refine them based on the data at hand, which requires the art of implementation, trial and error. Identifying when to use which approach, unfortunately, has to be handled on a case-by-case basis. An important observation is that at each iteration, all of these approaches try to reframe the inference problem into one that can be solved by calculus: designing a loss function in which the model's latent variables and parameters $\theta$ are the variables, taking derivatives, and either setting them to $0$ and solving for $\theta$ or, when no closed-form solution exists, following the gradients numerically.
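As a minimal illustration of this calculus recipe, consider a hypothetical toy problem: estimating the mean $\theta$ of Normal data with known unit variance by maximum likelihood. Setting the derivative of the negative log-likelihood to $0$ gives the sample mean in closed form; the sketch below checks that following the gradient numerically with PyTorch's autograd lands on the same value. The data, learning rate and number of steps are arbitrary assumptions.

```python
import torch

# Hypothetical toy data: 500 draws from a Normal with mean 2.0 and unit variance.
torch.manual_seed(0)
x = 2.0 + torch.randn(500)

def nll(theta):
    # Negative log-likelihood of Normal(theta, 1), constants dropped.
    return 0.5 * ((x - theta) ** 2).sum()

# Route 1: calculus by hand. d(nll)/d(theta) = -sum(x - theta) = 0  =>  theta = mean(x).
theta_closed = x.mean()

# Route 2: let autograd compute the derivative and follow it with gradient descent.
theta = torch.zeros((), requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=1e-3)
for _ in range(200):
    optimizer.zero_grad()
    loss = nll(theta)
    loss.backward()
    optimizer.step()

print(theta_closed.item(), theta.item())  # the two estimates agree
```

Here each SGD step closes half of the remaining gap to the sample mean, so 200 steps are far more than enough.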

Variational Inference

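Let’s try to understand Variational Inference (VI), and where the reparameterization trick comes in, through a good old example: Bayesian linear regression, in which we observe a matrix $\mathbf{A}$ ($N \times D$) and a vector $\mathbf{b}$ ($N \times 1$), and want to infer the coefficients $\beta$ in $\mathbf{b} = \mathbf{A}\beta + \text{noise}$.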

As mentioned above, there are generally four approaches to inference, and this problem can be solved by all four of them. Here, we focus on VI.

Step 1: Identify the observed data and latent variables, and write down the posterior

\(\begin{array}{|c|c|c|} \hline & \textbf{General framework} & \textbf{Example} \\ \hline \text{Observed data} & \mathbf{X} & \mathbf{A}, \mathbf{b} \\ \text{Latent} & \theta & \beta \\ \text{Prior} & P(\theta) & P(\beta) = Normal(\mathbf{0}, \sigma^2 \mathbf{I}_D) \\ \text{Likelihood} & P(\mathbf{X} \mid \theta) & P(\mathbf{A}, \mathbf{b} \mid \beta) = P(\mathbf{b} - \mathbf{A}\beta), \; \mathbf{b} - \mathbf{A}\beta \sim Normal(\mathbf{0}, \sigma_0^2 \mathbf{I}_N) \\ \text{Posterior} & P(\theta \mid \mathbf{X}) & P(\beta \mid \mathbf{A}, \mathbf{b}) = \frac{P(\mathbf{A}, \mathbf{b} \mid \beta) P(\beta)}{\int_{\beta} P(\mathbf{A}, \mathbf{b} \mid \beta) P(\beta) \, d\beta} = Normal(\mu, \Sigma) \\ \hline \end{array}\)
In the table above, $\mu$ and $\Sigma$ have a closed-form solution (because we designed both the likelihood and the prior to be Normal). However, in most general cases, $P(\theta|\mathbf{X})$ cannot be derived analytically. This is when I say to myself: ‘Let's try Variational Inference’.
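For this example, the closed form is the standard one for a Normal prior with a Normal likelihood: $\Sigma = (\mathbf{A}^\top\mathbf{A}/\sigma_0^2 + \mathbf{I}_D/\sigma^2)^{-1}$ and $\mu = \Sigma\mathbf{A}^\top\mathbf{b}/\sigma_0^2$. A minimal PyTorch sketch follows; the function name `closed_form_posterior` is hypothetical, and the simulated data (true coefficients $(3, 6)$, unit noise) are assumptions that mirror the data generated in the Step 5 code below.

```python
import torch

def closed_form_posterior(A, b, sigma_prior=1.0, sigma_0=1.0):
    """Exact posterior N(mu, Sigma) of beta, assuming beta ~ N(0, sigma_prior^2 I)
    and b | A, beta ~ N(A beta, sigma_0^2 I)."""
    D = A.shape[1]
    precision = A.T @ A / sigma_0**2 + torch.eye(D) / sigma_prior**2  # Sigma^{-1}
    Sigma = torch.linalg.inv(precision)
    mu = Sigma @ A.T @ b / sigma_0**2
    return mu, Sigma

# Toy data with the same shapes as the example: N observations, D coefficients.
torch.manual_seed(0)
N, D = 1000, 2
true_beta = torch.tensor([[3.0], [6.0]])
A = torch.randn(N, D)
b = A @ true_beta + torch.randn(N, 1)

mu, Sigma = closed_form_posterior(A, b)
print(mu.squeeze(), Sigma.diagonal().sqrt())  # mean near (3, 6), small posterior std
```

Having this exact answer around is handy later, as a reference point for checking the variational approximation.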

Step 2: Write down the ELBO for the general case

If this is your first time hearing about the ELBO (evidence lower bound), just think of it as a fancy name for the objective function of variational inference, which we want to maximize (equivalently, we minimize the negative ELBO as a loss). I have to rederive the ELBO for the general case each time I plan to work with variational inference, just so I can remember what to do next.

Figure 1. Derivation of the ELBO.
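For reference, the standard derivation (which may be laid out differently in Fig. 1) applies Jensen's inequality to the log-evidence, for any distribution $q(\theta)$:

\(\begin{aligned} \log P(\mathbf{X}) &= \log \int P(\mathbf{X} \mid \theta) P(\theta) \, d\theta = \log \mathbb{E}_{q(\theta)}\left[\frac{P(\mathbf{X} \mid \theta) P(\theta)}{q(\theta)}\right] \\ &\geq \mathbb{E}_{q(\theta)}\left[\log \frac{P(\mathbf{X} \mid \theta) P(\theta)}{q(\theta)}\right] = \mathbb{E}_{q(\theta)}\left[\log P(\mathbf{X} \mid \theta)\right] - KL(q(\theta) \| P(\theta)) = \text{ELBO}. \end{aligned}\)

Writing $P(\mathbf{X} \mid \theta) P(\theta) = P(\theta \mid \mathbf{X}) P(\mathbf{X})$ inside the same expectation shows that the gap is exactly a KL divergence to the posterior:

\(\log P(\mathbf{X}) - \text{ELBO} = KL(q(\theta) \| P(\theta \mid \mathbf{X})) \geq 0.\)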

Ideally, we would compute the exact posterior $P(\theta|\mathbf{X})$, but that requires the evidence $P(\mathbf{X}) = \int P(\mathbf{X}|\theta)P(\theta)\,d\theta$, an integral that is usually intractable. What we are saying in variational inference is: since we cannot find the exact posterior, we approximate it with a distribution of our own design, $q(\theta)$ (called the variational distribution), which we can easily sample from (i.e. a known distribution that we can sample with a call to functions in numpy, scipy or pytorch). The distribution $q$ has some parameters (such as the mean $\mu$ of a Normal distribution) that we optimize so that the ELBO is maximized. The ELBO, as shown in Fig. 1, is a lower bound on the log-evidence (log marginal likelihood) $\log P(\mathbf{X})$, and the gap between the two is exactly $KL(q(\theta) \| P(\theta|\mathbf{X}))$, so maximizing the ELBO pushes $q$ toward the true posterior.
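The bound can be checked numerically on the example, where everything is available in closed form. For a Gaussian $q = N(\mathbf{m}, \mathbf{S})$, $\mathbb{E}_q[\log P(\mathbf{A},\mathbf{b}\mid\beta)] = -\frac{N}{2}\log(2\pi\sigma_0^2) - \frac{\|\mathbf{b}-\mathbf{A}\mathbf{m}\|^2 + \mathrm{tr}(\mathbf{A}\mathbf{S}\mathbf{A}^\top)}{2\sigma_0^2}$, and the log-evidence is $\log P(\mathbf{b} \mid \mathbf{A})$ with $\mathbf{b} \sim N(\mathbf{0}, \sigma_0^2\mathbf{I}_N + \sigma^2\mathbf{A}\mathbf{A}^\top)$ marginally. The sketch below (hypothetical helper name, toy data, $\sigma = \sigma_0 = 1$ assumed) shows that the ELBO equals the log-evidence when $q$ is the exact posterior and falls below it otherwise.

```python
import math
import torch
from torch.distributions import MultivariateNormal, kl_divergence

torch.manual_seed(0)
dtype = torch.float64                      # double precision so the equality check is tight
N, D = 50, 2
sigma_prior, sigma_0 = 1.0, 1.0            # assumed prior std (sigma) and noise std (sigma_0)
A = torch.randn(N, D, dtype=dtype)
b = A @ torch.tensor([3.0, 6.0], dtype=dtype) + sigma_0 * torch.randn(N, dtype=dtype)

def elbo_gaussian_q(m, S):
    """Exact ELBO for q(beta) = N(m, S); no sampling needed in this conjugate example."""
    resid = b - A @ m
    exp_loglik = (-0.5 * N * math.log(2 * math.pi * sigma_0**2)
                  - (resid.pow(2).sum() + torch.trace(A @ S @ A.T)) / (2 * sigma_0**2))
    prior = MultivariateNormal(torch.zeros(D, dtype=dtype),
                               sigma_prior**2 * torch.eye(D, dtype=dtype))
    return exp_loglik - kl_divergence(MultivariateNormal(m, S), prior)

# Log-evidence: marginally, b ~ N(0, sigma_0^2 I_N + sigma^2 A A^T).
evidence_cov = sigma_0**2 * torch.eye(N, dtype=dtype) + sigma_prior**2 * A @ A.T
log_ev = MultivariateNormal(torch.zeros(N, dtype=dtype), evidence_cov).log_prob(b)

# Exact posterior N(m_post, S_post).
S_post = torch.linalg.inv(A.T @ A / sigma_0**2 + torch.eye(D, dtype=dtype) / sigma_prior**2)
S_post = 0.5 * (S_post + S_post.T)         # symmetrize away round-off
m_post = S_post @ A.T @ b / sigma_0**2

elbo_at_posterior = elbo_gaussian_q(m_post, S_post)
elbo_other = elbo_gaussian_q(torch.zeros(D, dtype=dtype), 0.5 * torch.eye(D, dtype=dtype))
print(log_ev.item(), elbo_at_posterior.item(), elbo_other.item())
```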

Another great advantage of variational inference is that, when designing $q$, we can ignore the hierarchical dependencies between latent variables. If our model is such that $\theta \rightarrow \mathbf{Z} \rightarrow \cdots \rightarrow \mathbf{X}$, where all variables besides $\mathbf{X}$ are latent and each one generates the next, then mean-field variational inference allows us to assume that $\theta \sim q_{\theta}$, $\mathbf{Z} \sim q_{\mathbf{Z}}$, etc., independently, so that $q(\text{latent vars.}) = q_{\theta}(\theta)\cdot q_{\mathbf{Z}}(\mathbf{Z}) \cdots$. We just have to find the optimal parameters for each $q$ to maximize the ELBO.
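In code, the mean-field assumption means that each latent block is sampled on its own and the log-density of $q$ is a sum of per-block log-densities. The sketch below uses two hypothetical latent blocks ($\theta$ with 3 dimensions and $\mathbf{Z}$ with 4) with arbitrary variational parameters, and checks that the factorized $q$ coincides with a single Gaussian whose covariance is diagonal, which is the form chosen for $q(\beta)$ in Step 3.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
# Hypothetical variational factors for two latent blocks: theta (3-dim) and Z (4-dim).
q_theta = Independent(Normal(torch.randn(3), torch.rand(3) + 0.5), 1)
q_Z = Independent(Normal(torch.randn(4), torch.rand(4) + 0.5), 1)

# Mean-field sampling: each block is drawn independently, even if the model has theta -> Z.
theta_s, Z_s = q_theta.sample(), q_Z.sample()
log_q = q_theta.log_prob(theta_s) + q_Z.log_prob(Z_s)  # log q(theta, Z) = log q(theta) + log q(Z)

# The same factorized q written as one Gaussian with a diagonal covariance matrix.
loc = torch.cat([q_theta.base_dist.loc, q_Z.base_dist.loc])
scale = torch.cat([q_theta.base_dist.scale, q_Z.base_dist.scale])
q_joint = MultivariateNormal(loc, covariance_matrix=torch.diag(scale**2))
print(log_q.item(), q_joint.log_prob(torch.cat([theta_s, Z_s])).item())  # equal up to round-off
```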

Step 3: Write down the ELBO for the case at hand

In the example case, we can choose $q(\beta)$ to be $Normal(\underbrace{\mu}_{D\times 1}, \Sigma)$ with a diagonal covariance matrix, where

\(\begin{aligned} \underbrace{\Sigma}_{D \times D}= \begin{bmatrix} \sigma_1^2 & 0 & \cdots & 0 \\ 0 & \sigma_2^2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \sigma_D^2 \end{bmatrix} \end{aligned}\)
The parameters we optimize, therefore, are $\mu$ and $\sigma_1^2, \sigma_2^2, \cdots, \sigma_D^2$. The ELBO for this case, based on Fig. 1, is:
\(\begin{align} \text{ELBO} &= \mathbb{E}_{q(\beta)} \left[ \log P(\mathbf{A}, \mathbf{b} \mid \beta) \right] - KL(q(\beta) \| P(\beta)) \\ &\approx \frac{1}{T}\sum_{t=1}^{T} \sum_{i=1}^{N} \underbrace{\log P(b_i - \mathbf{A}_i \beta_t)}_{b_i - \mathbf{A}_i\beta_t \,\sim\, N(0,\sigma_0^2)} - \underbrace{KL(q(\beta) \| P(\beta))}_{KL(N(\mu, \Sigma) \,\|\, N(\mathbf{0}, \sigma^2 \mathbf{I}_D))}, \qquad \beta_t \sim q(\beta) \end{align}\)
where the expectation is approximated by averaging over $T$ samples $\beta_t$ drawn from $q(\beta)$, and $\mathbf{A}_i$ is the $i$-th row of $\mathbf{A}$.
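The sampled estimate of this ELBO can be sketched as a function of the variational parameters, averaging the log-likelihood over $T$ samples $\beta_t$. This is a minimal sketch assuming $\sigma = \sigma_0 = 1$ by default and 1-D tensors for $\mu$, $\sigma$ and $\mathbf{b}$; the name `mc_elbo` and the test values are hypothetical. The usage part compares the estimate with the exact expected log-likelihood for a diagonal Gaussian $q$, $-\frac{N}{2}\log(2\pi\sigma_0^2) - \frac{1}{2\sigma_0^2}\left(\|\mathbf{b}-\mathbf{A}\mu\|^2 + \sum_j \sigma_j^2 \sum_i A_{ij}^2\right)$.

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

def mc_elbo(mu, sigma, A, b, T=1, sigma_prior=1.0, sigma_0=1.0):
    """Sampled ELBO for q(beta) = N(mu, diag(sigma^2)); mu, sigma: (D,), A: (N, D), b: (N,)."""
    eps = torch.randn(T, mu.shape[0])                   # eps_t ~ N(0, I), one row per sample
    betas = mu + sigma * eps                            # beta_t = mu + sigma * eps_t, shape (T, D)
    resid = b.unsqueeze(1) - A @ betas.T                # b_i - A_i beta_t, shape (N, T)
    exp_loglik = Normal(0.0, sigma_0).log_prob(resid).sum() / T  # average over the T samples
    prior = Normal(torch.zeros_like(mu), torch.full_like(sigma, sigma_prior))
    kl = kl_divergence(Normal(mu, sigma), prior).sum()  # closed-form KL, summed over dimensions
    return exp_loglik - kl

# Usage on toy data, with arbitrary (not optimized) variational parameters.
torch.manual_seed(0)
N = 200
A = torch.randn(N, 2)
b = A @ torch.tensor([3.0, 6.0]) + torch.randn(N)
mu, sigma = torch.tensor([2.5, 5.5]), torch.tensor([0.1, 0.2])

estimate = mc_elbo(mu, sigma, A, b, T=10000)
# Exact value for comparison (sigma_0 = 1): E_q[log-lik] has a closed form for a Gaussian q.
exact_loglik = (-0.5 * N * math.log(2 * math.pi)
                - 0.5 * ((b - A @ mu).pow(2).sum() + (sigma**2 * A.pow(2).sum(0)).sum()))
exact = exact_loglik - kl_divergence(Normal(mu, sigma), Normal(torch.zeros(2), torch.ones(2))).sum()
print(estimate.item(), exact.item())
```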

Step 4: Reparameterize the sampled latent variables from $q$

Note that the variational parameters are $\mu$ and $\sigma_1^2, \sigma_2^2, \cdots, \sigma_D^2$. The goal of variational inference, again, is to find the values of these parameters that maximize the ELBO. There usually exists an analytical formula for $KL(q(\beta) \| P(\beta))$ (in our example, it is the KL divergence between two multivariate Normal distributions with different means and diagonal covariance matrices). Therefore, we can easily calculate $\frac{\partial KL(q \| P)}{\partial \mu_i}$ and $\frac{\partial KL(q \| P)}{\partial \sigma_i}$ for each $i \in \{1, \dots, D\}$. This gives one part of $\frac{\partial \text{ELBO}}{\partial \mu_i}$ and $\frac{\partial \text{ELBO}}{\partial \sigma_i}$.
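For the diagonal Gaussian $q$ and the prior $N(\mathbf{0}, \sigma^2\mathbf{I}_D)$, the closed form is $KL = \sum_{i=1}^{D}\left[\log\frac{\sigma}{\sigma_i} + \frac{\sigma_i^2 + \mu_i^2}{2\sigma^2} - \frac{1}{2}\right]$, which gives $\frac{\partial KL}{\partial \mu_i} = \frac{\mu_i}{\sigma^2}$ and $\frac{\partial KL}{\partial \sigma_i} = -\frac{1}{\sigma_i} + \frac{\sigma_i}{\sigma^2}$. The sketch below (hypothetical function name and values; the prior std is written as `s`) checks the formula against PyTorch's `kl_divergence` and the hand-derived derivatives against autograd.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_diag_normal_to_prior(mu, sigma, s=1.0):
    """KL( N(mu, diag(sigma^2)) || N(0, s^2 I) ), summed over the D dimensions."""
    return (torch.log(s / sigma) + (sigma**2 + mu**2) / (2 * s**2) - 0.5).sum()

# Arbitrary variational parameters (D = 2) and prior std s, chosen only for illustration.
mu = torch.tensor([0.5, -1.0], requires_grad=True)
sigma = torch.tensor([0.8, 1.5], requires_grad=True)
s = 2.0

kl = kl_diag_normal_to_prior(mu, sigma, s)
kl.backward()  # autograd fills mu.grad and sigma.grad

# Hand-derived derivatives of the closed form.
grad_mu_hand = mu.detach() / s**2
grad_sigma_hand = -1.0 / sigma.detach() + sigma.detach() / s**2

# Cross-check the closed form against PyTorch's built-in Normal-Normal KL.
kl_torch = kl_divergence(Normal(mu, sigma), Normal(torch.zeros(2), torch.full((2,), s))).sum()
print(kl.item(), kl_torch.item())
print(mu.grad, grad_mu_hand)
print(sigma.grad, grad_sigma_hand)
```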

However, because we sample $\beta_t \sim N(\mu, \Sigma)$ ($t$ is the index of the $\beta$ samples) to construct the first part of the ELBO, we cannot easily calculate $\frac{\partial \log P(\mathbf{A}, \mathbf{b} \mid \beta_t)}{\partial \mu_i}$ and $\frac{\partial \log P(\mathbf{A}, \mathbf{b} \mid \beta_t)}{\partial \sigma_i}$ for each $i \in \{1, \dots, D\}$: a sample drawn directly from $N(\mu, \Sigma)$ is not written as a function of $\mu$ and $\sigma_i$, so there is no path for the derivative to flow back to them. This is where the reparameterization trick comes in: instead of directly sampling $\beta_t \sim N(\mu, \Sigma)$, we sample $\epsilon_t \sim N(\mathbf{0}, \mathbf{I})$ and calculate $\beta_t = \mu + \Sigma^{1/2} \epsilon_t$ (with our diagonal $\Sigma$, $\Sigma^{1/2}\epsilon_t$ is simply $(\sigma_1 \epsilon_{t,1}, \dots, \sigma_D \epsilon_{t,D})$). This way, we can easily calculate both derivatives by the chain rule, since $\frac{\partial \beta_{t,i}}{\partial \mu_i} = 1$ and $\frac{\partial \beta_{t,i}}{\partial \sigma_i} = \epsilon_{t,i}$: the randomness in sampling $\beta_t$ is attributed to $\epsilon_t$, while $\mu_i$ and $\sigma_i$ enter the ELBO as a deterministic part.
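The difference can be seen directly in PyTorch. In the sketch below, $q$ is a one-dimensional Normal with arbitrary $\mu$ and $\sigma$, and $f(\beta) = \beta^2$ is a hypothetical stand-in for $\log P(\mathbf{A}, \mathbf{b} \mid \beta)$. A draw from `Normal.sample()` carries no gradient, whereas $\beta = \mu + \sigma\epsilon$ lets autograd differentiate a Monte Carlo estimate of $\mathbb{E}_q[f(\beta)] = \mu^2 + \sigma^2$, and the gradients come out close to the exact values $2\mu$ and $2\sigma$.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)
q = Normal(mu, sigma)

# Direct sampling: draws are produced without gradient tracking,
# so nothing connects them back to mu and sigma.
beta_direct = q.sample((10000,))
print(beta_direct.requires_grad)  # False

# Reparameterized sampling: beta = mu + sigma * eps with eps ~ N(0, 1).
eps = torch.randn(10000)
beta = mu + sigma * eps  # q.rsample((10000,)) builds the same expression internally

# Monte Carlo estimate of E_q[f(beta)] with f(beta) = beta^2 (exact value: mu^2 + sigma^2).
estimate = (beta ** 2).mean()
estimate.backward()
print(mu.grad.item(), sigma.grad.item())  # close to 2*mu = 3.0 and 2*sigma = 1.0
```

In practice, `rsample()` on any `torch.distributions` object whose `has_rsample` is true gives the reparameterized draw without writing out $\mu + \sigma\epsilon$ by hand.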

Step 5: Optimize the ELBO

By this step, we have turned the problem into an optimization problem, tossing the ball into the court of optimization researchers. We, of course, need to know the criteria for framing a solvable optimization task (such as making sure all the functions are differentiable, and that the loss function is convex if possible), but that is a topic for another post. Here, I want to provide a snapshot of the code for constructing the model for the example case, and the training loop.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist

class Model(nn.Module):
    '''
    The model module specifies the generative process that maps input A to a predicted value of b.
    Required functions: __init__ and forward
    '''
    def __init__(self, D):
        super().__init__()
        self.D = D
        # set up the parameters of the variational distribution q(beta) = Normal(mu, diag(sigma^2)).
        # use nn.Parameter to declare model parameters, so pytorch can keep track of the gradients
        self.mu = nn.Parameter(torch.zeros(D, 1))  # vector of mu_1, ..., mu_D
        # optimize log(sigma) rather than sigma itself, so that sigma = exp(log_sigma) stays positive
        self.log_sigma = nn.Parameter(torch.zeros(D, 1))  # vector of log(sigma_1), ..., log(sigma_D)

    @property
    def sigma(self):
        return torch.exp(self.log_sigma)

    @staticmethod
    def reparameterize(mu, sigma):
        # beta = mu + sigma * eps with eps ~ N(0, I): the randomness lives in eps,
        # so gradients can flow back to mu and sigma
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def forward(self, A):
        beta = self.reparameterize(self.mu, self.sigma)  # one sample beta_t ~ q(beta), i.e. T = 1
        return torch.matmul(A, beta)

#### Generate data ####
torch.manual_seed(0)  # for reproducibility
N = 1000
D = 2
true_beta = torch.tensor([3.0, 6.0]).reshape(D, 1)  # floats, so that matmul with A (float) works
A = torch.randn(N, D)
b = torch.matmul(A, true_beta) + torch.randn(N, 1)  # noise ~ N(0, 1), i.e. sigma_0 = 1

#### Loss function ####
def neg_elbo(b_pred, b, model, mu_prior=0.0, sigma_prior=1.0, sigma_0=1.0):
    noise = dist.Normal(0.0, sigma_0)  # likelihood: b - A beta ~ N(0, sigma_0^2)
    log_llh = noise.log_prob(b - b_pred).sum()
    prior = dist.Normal(torch.full_like(model.mu, mu_prior), torch.full_like(model.mu, sigma_prior))
    kl = dist.kl_divergence(dist.Normal(model.mu, model.sigma), prior).sum()  # closed-form KL
    return -log_llh + kl  # minimize the negative of ELBO

#### Training loop ######
model = Model(D)
optimizer = optim.Adam(model.parameters(), lr=0.01)
# the optimizer registers the model parameters that it updates using their gradients
for epoch in range(1000):
    optimizer.zero_grad()
    b_pred = model(A)
    loss = neg_elbo(b_pred, b, model)
    loss.backward()
    optimizer.step()

#### Print the estimated beta ####
print(model.mu, model.sigma)  # model.mu can be used as the estimated beta;
# model.mu and model.sigma together define the estimated posterior P(beta|A, b)
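Because this example also has an exact posterior, the trained variational parameters can be checked against it. The sketch below is a compact, self-contained variant of the training loop above; as assumptions not taken from the original code, it parameterizes $\sigma$ through its logarithm, averages the log-likelihood over $T = 10$ reparameterized samples per step instead of one, and uses an arbitrary learning rate and step count.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
N, D, T = 1000, 2, 10                        # T: reparameterized samples per step (assumed)
sigma_prior, sigma_0 = 1.0, 1.0
A = torch.randn(N, D)
b = A @ torch.tensor([[3.0], [6.0]]) + sigma_0 * torch.randn(N, 1)

mu = nn.Parameter(torch.zeros(D, 1))
log_sigma = nn.Parameter(torch.zeros(D, 1))   # sigma = exp(log_sigma) stays positive
optimizer = torch.optim.Adam([mu, log_sigma], lr=0.02)
prior = Normal(torch.zeros(D, 1), torch.full((D, 1), sigma_prior))

for step in range(2000):
    optimizer.zero_grad()
    sigma = log_sigma.exp()
    betas = mu + sigma * torch.randn(D, T)                            # T samples beta_t, shape (D, T)
    log_llh = Normal(0.0, sigma_0).log_prob(b - A @ betas).sum() / T  # (N, 1) - (N, T) broadcasts
    kl = kl_divergence(Normal(mu, sigma), prior).sum()
    loss = kl - log_llh                                               # negative ELBO
    loss.backward()
    optimizer.step()

# Exact posterior of this conjugate example, for comparison.
Sigma = torch.linalg.inv(A.T @ A / sigma_0**2 + torch.eye(D) / sigma_prior**2)
mu_exact = Sigma @ A.T @ b / sigma_0**2
print(mu.detach().squeeze(), mu_exact.squeeze())
print(log_sigma.exp().detach().squeeze(), Sigma.diagonal().sqrt())
```

The variational mean should land close to the exact posterior mean; the variational standard deviations are of the same small order as the exact ones, though a diagonal $q$ cannot represent posterior correlations between coefficients.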

Generalizing the model for other cases

So far, I have outlined the steps that I usually take when I apply variational inference, with an example in which the variational distribution $q$ is chosen to be Normal (hence the reparameterization trick for the Normal distribution). In many cases, due to the nature of the data (categorical, continuous, non-negative, etc.), the choice of the variational distribution $q$ needs to be adapted. For example:
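Two such adaptations can be sketched in the same reparameterized style, with arbitrary parameter values chosen only for illustration. For a non-negative latent variable, a log-normal $q$ is reparameterized as $w = \exp(\mu + \sigma\epsilon)$. For a categorical latent variable, the Gumbel-Softmax relaxation of Jang et al. (2016), "Categorical Reparameterization with Gumbel-Softmax" (arXiv:1611.01144), replaces a one-hot sample with $\mathrm{softmax}((\text{logits} + g)/\tau)$, where $g$ is Gumbel(0, 1) noise and $\tau$ is a temperature. The manual Gumbel construction below is a sketch; PyTorch also provides `torch.nn.functional.gumbel_softmax`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# (1) Non-negative latent variable: log-normal q, reparameterized as w = exp(mu + sigma * eps).
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(-1.0, requires_grad=True)
eps = torch.randn(1000)                          # eps ~ N(0, 1)
w = torch.exp(mu + log_sigma.exp() * eps)        # every sample is positive

# (2) Categorical latent variable: Gumbel-Softmax relaxation with temperature tau.
logits = torch.tensor([1.0, 0.5, -1.0], requires_grad=True)
tau = 0.5                                        # smaller tau -> samples closer to one-hot vectors
u = torch.rand(1000, 3).clamp(min=1e-10)         # avoid log(0)
g = -torch.log(-torch.log(u))                    # Gumbel(0, 1) noise
y = F.softmax((logits + g) / tau, dim=-1)        # relaxed one-hot samples; each row sums to 1

# Both sample sets depend differentiably on their variational parameters.
(w.mean() + y[:, 0].mean()).backward()
print(mu.grad, log_sigma.grad, logits.grad)
```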

Conclusion
