Variational Inference (part 1)

[latexpage]

I will dedicate the next few posts to variational inference methods as a way to organize my own understanding; this first one will be pretty basic.

The goal of variational inference is to approximate an intractable probability distribution, $p$, with a tractable one, $q$, in a way that makes them as ‘close’ as possible. Let’s unpack that statement a bit.

  1. Intractable $p$: a motivating example is the posterior distribution of a Bayesian model, i.e. given some observations $x = (x_1, x_2, \dots, x_n)$ and some model $p(x | \theta)$ parameterized by $\theta = (\theta_1, \dots, \theta_d)$, we often want to evaluate the distribution over parameters
    \begin{align*}
    p(\theta | x) = \frac{ p(x | \theta) p(\theta) }{ \int p(x | \theta) p(\theta) d \theta }
    \end{align*}
    For many interesting models this distribution is intractable because of the integral in the denominator (the evidence $p(x)$), which typically has no closed form. We can evaluate the posterior up to a constant, but we can’t compute the normalization constant. Applying variational inference to posterior distributions is sometimes called variational Bayes.

  2. Tractable $q$: a tractable distribution is one for which we can evaluate the integral (and therefore take expectations with respect to it). One way to achieve this is to make the $\theta_i$ independent under $q$:
    \begin{align*}
    q(\theta | \lambda) = \prod_{i=1}^d q_i(\theta_i | \lambda_i)
    \end{align*}
    for some marginal distributions $q_i$ parameterized by $\lambda_i$. This is called the ‘mean field’ approximation.

  3. We can make two distributions ‘close’ in the Kullback-Leibler sense:
    \begin{align*}
    KL(q || p) &\equiv \int q(\theta) \log \frac{ q(\theta) }{ p(\theta | x) } d\theta \\
    &= E_q \left(\log \frac{ q(\theta) }{ p(\theta | x) } \right)
    \end{align*}
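    Notice that we chose $KL(q||p)$ and not $KL(p||q)$ because of the intractability: we’re assuming we cannot evaluate expectations $E_p(\cdot)$ under the posterior, whereas $KL(q||p)$ only requires expectations $E_q(\cdot)$ under the tractable $q$.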
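
To make items 2 and 3 concrete, here is a minimal Monte Carlo sketch of $KL(q||p) = E_q[\log q(\theta) - \log p(\theta|x)]$ for a mean-field Gaussian $q$. It assumes, purely for illustration, that $p(\theta|x)$ is a known correlated 2-D Gaussian, so the estimate can be compared with the closed-form Gaussian KL; all names are hypothetical. Note that it only ever samples from $q$.

```python
import numpy as np

# Stand-in "posterior" p(theta | x): a correlated 2-D Gaussian (assumption for
# illustration; in real problems it is only known up to a constant).
mu_p = np.array([1.0, -1.0])
cov_p = np.array([[1.0, 0.8],
                  [0.8, 1.0]])

# Mean-field q: independent Gaussians q_i(theta_i | lambda_i) with lambda_i = (m_i, s_i).
m_q = np.array([1.0, -1.0])
s_q = np.array([0.6, 0.6])


def log_gauss(theta, mu, cov):
    """Log density of N(mu, cov) evaluated at each row of theta."""
    d = mu.shape[0]
    diff = theta - mu
    sol = np.linalg.solve(cov, diff.T).T
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (np.sum(diff * sol, axis=1) + logdet + d * np.log(2 * np.pi))


def kl_mc(n_samples, rng):
    """Monte Carlo estimate of KL(q || p) = E_q[log q - log p], sampling only from q."""
    theta = m_q + s_q * rng.standard_normal((n_samples, 2))
    log_q = log_gauss(theta, m_q, np.diag(s_q**2))
    log_p = log_gauss(theta, mu_p, cov_p)
    return np.mean(log_q - log_p)


def kl_exact():
    """Closed-form KL between two Gaussians, used only as a reference value."""
    cov_q = np.diag(s_q**2)
    prec_p = np.linalg.inv(cov_p)
    diff = mu_p - m_q
    d = len(m_q)
    return 0.5 * (np.trace(prec_p @ cov_q) + diff @ prec_p @ diff - d
                  + np.linalg.slogdet(cov_p)[1] - np.linalg.slogdet(cov_q)[1])


rng = np.random.default_rng(0)
print("Monte Carlo:", kl_mc(200_000, rng), " exact:", kl_exact())
```

If $\log p$ were only known up to an additive constant, the estimate would shift by that same constant, which does not change which $q$ is closest.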

In order to minimize the KL divergence $KL(q||p)$, note the following decomposition
\begin{align*}
\log p(x) &= \log \frac{ p(\theta | x) p(x) }{ p(\theta | x)} \\
&= \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \\
E_q(\log p(x)) &= E_q \left( \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \right) \\
\log p(x) &= KL(q || p) + E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right)
\end{align*}

where $\log p(x)$ falls out of the expectation with respect to $q(\theta)$ because it does not depend on $\theta$. This is convenient because $\log p(x)$ (the log evidence) is constant with respect to the variational distribution $q$.

And because the KL divergence is nonnegative (it is zero only when $q$ equals the posterior), the second term $E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right)$ is a lower bound for $\log p(x)$, also known as the evidence lower bound (ELBO). Since the two terms sum to the constant $\log p(x)$, minimizing the first term, $KL(q || p)$, is equivalent to maximizing the second.
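
A quick numerical check of the decomposition, assuming a hypothetical conjugate model $\theta \sim N(0, \tau^2)$, $x_i | \theta \sim N(\theta, \sigma^2)$, chosen because $\log p(x)$ and the exact posterior are available in closed form. For any Gaussian $q = N(m, s^2)$, the ELBO (computed by brute-force integration on a grid) plus the closed-form $KL(q||p)$ should reproduce $\log p(x)$, and the ELBO alone should never exceed it.

```python
import numpy as np

# Hypothetical conjugate model, chosen so every term has an exact reference value:
#   theta ~ N(0, tau^2),   x_i | theta ~ N(theta, sigma^2)  i.i.d.
tau, sigma = 2.0, 1.0
x = np.array([0.8, 1.3, 0.2, 1.9, 1.1])
n = len(x)
GRID = np.linspace(-20.0, 20.0, 200_001)  # integration grid for theta


def log_normal(z, mean, var):
    """Elementwise log density of N(mean, var) at z."""
    return -0.5 * (np.log(2 * np.pi * var) + (z - mean) ** 2 / var)


def log_joint(theta):
    """log p(x, theta) = log p(theta) + sum_i log p(x_i | theta), vectorized over theta."""
    return log_normal(theta, 0.0, tau**2) + np.sum(
        log_normal(x[:, None], theta[None, :], sigma**2), axis=0)


# Exact log evidence: marginally, x ~ N(0, sigma^2 I + tau^2 * 1 1^T).
cov_x = sigma**2 * np.eye(n) + tau**2 * np.ones((n, n))
_, logdet_x = np.linalg.slogdet(cov_x)
log_evidence = -0.5 * (x @ np.linalg.solve(cov_x, x) + logdet_x + n * np.log(2 * np.pi))

# Exact posterior p(theta | x) = N(m_post, v_post).
v_post = 1.0 / (1.0 / tau**2 + n / sigma**2)
m_post = v_post * x.sum() / sigma**2


def elbo(m, s):
    """E_q[log p(x, theta) - log q(theta)] for q = N(m, s^2), via a Riemann sum on GRID."""
    dx = GRID[1] - GRID[0]
    log_q = log_normal(GRID, m, s**2)
    return np.sum(np.exp(log_q) * (log_joint(GRID) - log_q)) * dx


def kl_to_posterior(m, s):
    """Closed-form KL(N(m, s^2) || N(m_post, v_post))."""
    return 0.5 * (np.log(v_post / s**2) + (s**2 + (m - m_post) ** 2) / v_post - 1.0)


for m, s in [(0.0, 1.0), (1.0, 0.5), (m_post, np.sqrt(v_post))]:
    e, kl = elbo(m, s), kl_to_posterior(m, s)
    print(f"ELBO={e:.4f}  KL={kl:.4f}  ELBO+KL={e + kl:.4f}  log p(x)={log_evidence:.4f}")
```

The last row uses the exact posterior as $q$, so the KL term vanishes and the ELBO meets $\log p(x)$.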

One way to optimize over the choice of $q(\theta) = \prod_i q_i(\theta_i)$ is to consider the ELBO as a function of a single factor $q_j(\theta_j)$, separating out the expectation with respect to all other variables, $E_{-q_j}(\cdot)$ (note that $\int_{-\theta_j}$ is the integral over all $\theta_i$ with $i \neq j$):
\begin{align*}
ELBO(q_j) &= \int \prod_i q_i(\theta_i) \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right)\\
&= \int_{\theta_j} q_j(\theta_j) \int_{-\theta_j} \prod_{i \neq j} q_i(\theta_i) \times \\
&~~ \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right) \\
&= \int_{\theta_j} q_j(\theta_j) E_{-q_j} [ \log p(x, \theta) ] \\
&~~ - \int_{\theta_j} q_j(\theta_j) \log q_j(\theta_j) + const. \\
&= -KL(q_j || \tilde{q}_j) + const.
\end{align*}
where
\begin{align*}
\tilde{q}_j(\theta_j) \propto \exp\left( E_{-q_j}[\log p(x,\theta)] \right)
\end{align*}
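Since a KL divergence is minimized, at zero, exactly when its two arguments are equal, the ELBO is maximized with respect to $q_j$ (holding the other factors fixed) by setting $q_j = \tilde{q}_j$. This motivates a particular updating scheme, often called coordinate ascent variational inference: iterate over the marginals $q_j$, setting each one to $\tilde{q}_j$ in turn, and repeat until the ELBO stops improving. No update can decrease the ELBO, which is bounded above by $\log p(x)$, so the sequence of ELBO values converges.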
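
The update scheme can be sketched on a toy discrete model where every expectation is a finite sum. This assumes two hypothetical discrete latent variables whose joint $\log p(x, \theta_1, \theta_2)$ (for the fixed observed $x$) is just a random table; each step sets $q_j \propto \exp(E_{-q_j}[\log p(x,\theta)])$, and the recorded ELBO should never decrease or exceed $\log p(x)$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical model: two discrete latent variables, theta_1 in {0,...,3} and
# theta_2 in {0,...,4}. log_joint[a, b] plays the role of
# log p(x, theta_1 = a, theta_2 = b) for the (fixed) observed x.
log_joint = rng.normal(scale=2.0, size=(4, 5))
# log p(x) by brute-force summation: cheap here, intractable in realistic models.
log_evidence = np.log(np.exp(log_joint).sum())


def normalize_log(v):
    """Exponentiate and normalize unnormalized log-probabilities (stable softmax)."""
    w = np.exp(v - v.max())
    return w / w.sum()


def elbo(q1, q2):
    """E_q[log p(x, theta)] - E_q[log q1(theta_1)] - E_q[log q2(theta_2)] for q = q1 * q2."""
    return q1 @ log_joint @ q2 - q1 @ np.log(q1) - q2 @ np.log(q2)


def cavi(n_iters=50):
    """Alternate q_j <- q~_j, recording the ELBO after each full sweep."""
    q1 = np.full(4, 1 / 4)  # start from uniform marginals
    q2 = np.full(5, 1 / 5)
    history = [elbo(q1, q2)]
    for _ in range(n_iters):
        # q1(theta_1) proportional to exp(E_{q2}[log p(x, theta_1, theta_2)])
        q1 = normalize_log(log_joint @ q2)
        # q2(theta_2) proportional to exp(E_{q1}[log p(x, theta_1, theta_2)])
        q2 = normalize_log(q1 @ log_joint)
        history.append(elbo(q1, q2))
    return q1, q2, np.array(history)


q1, q2, history = cavi()
print("first ELBOs:", history[:4])
print("final ELBO:", history[-1], " log p(x):", log_evidence)
```

The gap between the final ELBO and $\log p(x)$ is the KL divergence that the factorized family cannot close, since the table couples $\theta_1$ and $\theta_2$.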

These statements remain pretty general. We haven’t specified the functional form of $q_j(\theta_j)$, but for specific models it falls out of $E_{-q_j}[\log p(x,\theta)]$ (a good, simple example is the normal-gamma model from Murphy’s MLAPP or Bishop’s PRML). This form then defines the variational parameters $\lambda_i$, and the iterative algorithm provides a way to compute them.
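
For the normal-gamma model, writing $\theta = (\mu, \tau)$, the factors come out as $q(\mu) = N(\mu_N, \lambda_N^{-1})$ and $q(\tau) = \text{Gamma}(a_N, b_N)$, so the variational parameters are $(\mu_N, \lambda_N, a_N, b_N)$. Below is a minimal sketch of the resulting iterations, following the parameterization of Bishop’s PRML (Section 10.1.3, Gamma with a rate parameter); the function name, prior hyperparameter defaults, and simulated data are arbitrary choices for illustration.

```python
import numpy as np


def normal_gamma_cavi(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, n_iters=100):
    """Mean-field updates for x_i ~ N(mu, 1/tau), mu | tau ~ N(mu0, 1/(lambda0 tau)),
    tau ~ Gamma(a0, rate=b0), with q(mu, tau) = N(mu; mu_N, 1/lambda_N) Gamma(tau; a_N, b_N).
    """
    x = np.asarray(x, dtype=float)
    n, xbar = len(x), x.mean()
    # mu_N and a_N do not depend on the other factor, so they stay fixed.
    mu_n = (lambda0 * mu0 + n * xbar) / (lambda0 + n)
    a_n = a0 + (n + 1) / 2
    e_tau = a0 / b0  # initial guess for E_q[tau]
    for _ in range(n_iters):
        # Update q(mu): only its precision depends on E_q[tau].
        lambda_n = (lambda0 + n) * e_tau
        # Update q(tau): needs E_q(mu)[sum_i (x_i - mu)^2 + lambda0 (mu - mu0)^2].
        e_sq = np.sum((x - mu_n) ** 2) + n / lambda_n
        e_prior = lambda0 * ((mu_n - mu0) ** 2 + 1 / lambda_n)
        b_n = b0 + 0.5 * (e_sq + e_prior)
        e_tau = a_n / b_n  # mean of Gamma(a_n, rate b_n)
    return mu_n, lambda_n, a_n, b_n


rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=0.5, size=500)  # simulated; true precision is 1 / 0.5**2 = 4
mu_n, lambda_n, a_n, b_n = normal_gamma_cavi(data)
print("E_q[mu] =", mu_n, " E_q[tau] =", a_n / b_n)
```

Only $\lambda_N$ and $b_N$ actually change between iterations, because each depends on an expectation under the other factor.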

Some fun properties of the mean field approximation:

  • The optimization problem is not convex in general, so the iterations can get stuck in local optima (Wainwright and Jordan)
  • The optimization procedure is pretty much out of the box for exponential family models (maybe a future blog post will be dedicated to exploring this)
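  • This variational approximation tends to underestimate posterior uncertainty, a consequence of the KL divergence ordering, $KL(q || p)$ as opposed to $KL(p || q)$: $KL(q || p)$ heavily penalizes $q$ for putting mass where $p$ has little, so $q$ tends to be more compact than $p$.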
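
The uncertainty point shows up already in the simplest case: a zero-mean 2-D Gaussian target with unit marginal variances and correlation $\rho$ (a hypothetical stand-in for a posterior). By symmetry the best factorized Gaussian here is $q = N(0, vI)$, so the sketch below just grid-searches $v$ under each KL direction using the closed-form Gaussian KL. $KL(q||p)$ should pick $v \approx 1 - \rho^2$, shrinking as the correlation grows, while $KL(p||q)$ should pick $v \approx 1$, the true marginal variance.

```python
import numpy as np

V_GRID = np.linspace(0.01, 2.0, 2000)  # candidate variances v for q = N(0, v I)


def optimal_isotropic_variances(rho):
    """Grid-search the variance of a factorized q = N(0, v I) under both KL directions.

    Target (hypothetical stand-in for a posterior): p = N(0, [[1, rho], [rho, 1]]).
    """
    cov = np.array([[1.0, rho], [rho, 1.0]])
    prec = np.linalg.inv(cov)
    _, logdet_cov = np.linalg.slogdet(cov)
    v = V_GRID
    # Closed-form Gaussian KL divergences in d = 2 dimensions (both means are zero).
    kl_qp = 0.5 * (v * np.trace(prec) - 2 + logdet_cov - 2 * np.log(v))  # KL(q || p)
    kl_pq = 0.5 * (np.trace(cov) / v - 2 + 2 * np.log(v) - logdet_cov)   # KL(p || q)
    return V_GRID[np.argmin(kl_qp)], V_GRID[np.argmin(kl_pq)]


for rho in [0.0, 0.5, 0.9]:
    v_rev, v_fwd = optimal_isotropic_variances(rho)
    # Analytically, v_rev = 1 - rho**2 and v_fwd = 1 (the true marginal variance).
    print(f"rho={rho}: KL(q||p) -> v={v_rev:.3f}, KL(p||q) -> v={v_fwd:.3f}")
```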

This raises the question: can you do better with richer $q$ distributions? Yes! In future posts, I will take a look at some more complicated variational distributions and their performance.

References:

  1. Kevin P. Murphy. Machine Learning: A Probabilistic Perspective.
  2. Martin J. Wainwright and Michael I. Jordan. Graphical Models, Exponential Families, and Variational Inference.