I will dedicate the next few posts to variational inference methods as a way to organize my own understanding - this first one will be pretty basic.

The goal of variational inference is to approximate an intractable probability distribution, $p$, with a tractable one, $q$, in a way that makes them as 'close' as possible. Let's unpack that statement a bit.

- Intractable $p$: a motivating example is the posterior distribution of a Bayesian model, i.e. given some observations $x = (x_1, x_2, \dots, x_n)$ and some model $p(x | \theta)$ parameterized by $\theta = (\theta_1, \dots, \theta_d)$, we often want to evaluate the distribution over parameters

\begin{align*}

p(\theta | x) = \frac{ p(x | \theta) p(\theta) }{ \int p(x | \theta) p(\theta) d \theta }

\end{align*}

For a lot of interesting models this distribution is intractable because of the integral in the denominator. We can evaluate the posterior up to a constant, but we can't compute the normalization constant. Applying variational inference to posterior distributions is sometimes called variational Bayes.

- Tractable $q$: a tractable approximating distribution is one for which we can evaluate the relevant integrals (and therefore take expectations with respect to it). One way to achieve this is by making each $\theta_i$ independent:

\begin{align*}

q(\theta | \lambda) = \prod_{i=1}^d q_i(\theta_i | \lambda_i)

\end{align*}

for some marginal distributions $q_i$ parameterized by $\lambda_i$. This is called the 'mean field' approximation.

- We can make two distributions 'close' in the Kullback-Leibler sense:

\begin{align*}

KL(q || p) &\equiv \int q(\theta) \log \frac{ q(\theta) }{ p(\theta | x) } d\theta \\

&= E_q \left(\log \frac{ q(\theta) }{ p(\theta | x) } \right)

\end{align*}

Notice that we chose $KL(q||p)$ and not $KL(p||q)$ because $p$ is intractable: $KL(p||q)$ requires expectations under $p$, and we're assuming you cannot evaluate $E_p(\cdot)$.
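To make the asymmetry of the KL divergence concrete, here is a minimal sketch for discrete distributions. The helper `kl_divergence` and the two probability vectors are made up for illustration; the integral in the definition becomes a sum.

```python
import numpy as np

def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as probability vectors."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    mask = q > 0  # terms with q = 0 contribute nothing, by convention
    return float(np.sum(q[mask] * np.log(q[mask] / p[mask])))

# Two hypothetical distributions over three outcomes.
p = np.array([0.5, 0.4, 0.1])
q = np.array([0.6, 0.3, 0.1])

kl_qp = kl_divergence(q, p)  # expectation taken under q
kl_pq = kl_divergence(p, q)  # expectation taken under p

print("KL(q || p) =", kl_qp)
print("KL(p || q) =", kl_pq)
```

The two numbers differ, which is why the choice of ordering matters. Only $KL(q||p)$ takes its expectation under the distribution we control.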

In order to minimize the KL divergence $KL(q||p)$, note the following decomposition

\begin{align*}

\log p(x) &= \log \frac{ p(\theta | x) p(x) }{ p(\theta | x)} \\

&= \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \\

E_q(\log p(x)) &= E_q \left( \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \right) \\

\log p(x) &= KL(q || p) + E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right)

\end{align*}

where $\log p(x)$ falls out of the expectation with respect to $q(\theta)$ because it does not depend on $\theta$. This is convenient because $\log p(x)$ (the log evidence) is constant with respect to the variational distribution $q$, so the two terms on the right-hand side always sum to the same fixed value.

And because KL divergence is nonnegative, the second term $E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right)$ is a lower bound for $\log p(x)$, also known as the evidence lower bound (ELBO). Since the two terms sum to a constant, minimizing the first term, $KL(q || p)$, is equivalent to maximizing the second.
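The decomposition can be checked numerically on a toy problem where everything is tractable. The sketch below assumes a hypothetical model in which $\theta$ takes only three values, so the evidence is a plain sum; the prior, likelihood, and $q$ are arbitrary numbers chosen for illustration.

```python
import numpy as np

# Hypothetical toy model: theta takes 3 values, with one fixed observation x.
prior = np.array([0.2, 0.5, 0.3])        # p(theta)
likelihood = np.array([0.9, 0.1, 0.4])   # p(x | theta) at the observed x

joint = likelihood * prior               # p(x, theta)
evidence = joint.sum()                   # p(x), tractable only because theta is tiny
posterior = joint / evidence             # p(theta | x)

q = np.array([0.5, 0.2, 0.3])            # an arbitrary variational distribution

kl = float(np.sum(q * np.log(q / posterior)))   # KL(q || p(theta | x))
elbo = float(np.sum(q * np.log(joint / q)))     # E_q[log p(x, theta) / q(theta)]
log_evidence = float(np.log(evidence))

print("log p(x)  =", log_evidence)
print("KL + ELBO =", kl + elbo)
```

Note that the ELBO only needs the joint $p(x, \theta)$, never the normalized posterior, which is what makes it usable when the evidence is out of reach.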

One way to optimize over the choice of $q(\theta) = \prod_i q_i(\theta_i)$ is to consider the ELBO as a function of a single factor $q_j(\theta_j)$, separating it from the expectation with respect to all other variables, $E_{-q_j}(\cdot)$ (note that $\int_{-\theta_j}$ is the integral over all $\theta_i$ with $i \neq j$):

\begin{align*}

ELBO(q_j) &= \int \prod_i q_i(\theta_i) \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right)\\

&= \int_{\theta_j} q_j(\theta_j) \int_{-\theta_j} \prod_{i \neq j} q_i(\theta_i) \times \\

&~~ \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right) \\

&= \int_{\theta_j} q_j(\theta_j) E_{-q_j} [ \log p(x, \theta) ] \\

&~~ - \int_{\theta_j} q_j(\theta_j) \log q_j(\theta_j) + const. \\

&= -KL(q_j || \tilde{q}_j) + const.

\end{align*}

where

\begin{align*}

\tilde{q}_j(\theta_j) \propto \exp\left( E_{-q_j}[\log p(x,\theta)] \right)

\end{align*}

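This motivates a particular updating scheme (coordinate ascent): iterate over the marginals $q_j$, setting $q_j = \tilde{q}_j$ at each step, which maximizes the ELBO with respect to $q_j$ while the other factors are held fixed, and repeat until the ELBO stops improving.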

These statements remain pretty general. We haven't specified the functional form of $q_j(\theta_j)$, but it will fall out of the $E_{-q_j}[\log p(x,\theta)]$ for specific models (a good, simple example is the normal-gamma model from Murphy's MLAPP or Bishop's PRML). This form will then define the variational parameters $\lambda_i$, and the iterative algorithm will provide a way to compute them.
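As a rough sketch of the updating scheme, consider a hypothetical case where $\theta = (\theta_1, \theta_2)$ is discrete, so that $p(x, \theta)$ for the observed $x$ is just a table of positive numbers. The table below is random and purely illustrative; the helper names `elbo` and `normalize_exp` are my own. Each update sets one factor proportional to $\exp(E_{-q_j}[\log p(x,\theta)])$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical joint p(x, theta_1, theta_2) on a 3 x 4 grid, for a fixed x.
joint = rng.random((3, 4)) + 0.1
log_joint = np.log(joint)

def elbo(q1, q2):
    """E_q[log p(x, theta) - log q(theta)] for q = q1 * q2."""
    q = np.outer(q1, q2)
    return float(np.sum(q * (log_joint - np.log(q))))

def normalize_exp(log_w):
    """Return exp(log_w) normalized to sum to one (shifted for stability)."""
    w = np.exp(log_w - log_w.max())
    return w / w.sum()

# Start from uniform factors.
q1 = np.full(3, 1.0 / 3.0)
q2 = np.full(4, 1.0 / 4.0)

history = [elbo(q1, q2)]
for _ in range(50):
    q1 = normalize_exp(log_joint @ q2)  # E_{q2}[log p] as a function of theta_1
    q2 = normalize_exp(q1 @ log_joint)  # E_{q1}[log p] as a function of theta_2
    history.append(elbo(q1, q2))

log_evidence = float(np.log(joint.sum()))
print("final ELBO:", history[-1])
print("log p(x):  ", log_evidence)
```

The ELBO should never decrease from one sweep to the next, and it stays below $\log p(x)$. The gap that remains is the KL divergence between the best factorized $q$ found and the true posterior.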

Some fun properties of the mean field approximation:

- The objective function is not guaranteed to be convex (Wainwright and Jordan)
- The optimization procedure is pretty much out of the box for exponential family models (maybe a future blog post will be dedicated to exploring this)
- This variational approximation underestimates uncertainty (a consequence that pops out as a result of the KL divergence ordering, $KL(q || p)$ as opposed to $KL(p || q)$).
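The last point can be illustrated with a correlated bivariate Gaussian target, a standard example (it appears in Bishop's PRML). This is a sketch under the assumption that the target is $N(\mu, \Lambda^{-1})$ with made-up values for $\mu$ and the correlation `rho`. For this target the mean field update has a closed form: each factor is Gaussian with variance $1/\Lambda_{jj}$ and a mean that depends on the other factor's current mean.

```python
import numpy as np

# Hypothetical target: bivariate Gaussian with unit variances and correlation rho.
rho = 0.9
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Lambda = np.linalg.inv(Sigma)  # precision matrix

# Mean field updates: q_j = N(m_j, 1 / Lambda_jj), with
# m_j = mu_j - Lambda_ji * (m_i - mu_i) / Lambda_jj for i != j.
m = np.zeros(2)
for _ in range(200):
    m[0] = mu[0] - Lambda[0, 1] * (m[1] - mu[1]) / Lambda[0, 0]
    m[1] = mu[1] - Lambda[1, 0] * (m[0] - mu[0]) / Lambda[1, 1]

q_var = 1.0 / np.diag(Lambda)  # variances of the factorized approximation
p_var = np.diag(Sigma)         # true marginal variances

print("means:              ", m)
print("mean field variance:", q_var)
print("true variance:      ", p_var)
```

The means converge to the true means, but each factor's variance is $1 - \rho^2$ rather than 1, so the stronger the correlation, the more the approximation understates the marginal uncertainty.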

This raises the question: can you do better with richer $q$ distributions? Yes! In future posts, I will take a look at some more complicated variational distributions and their performance.

references:

- Kevin Murphy. Machine learning: A probabilistic perspective
- Wainwright and Jordan. Graphical Models, Exponential Families, and Variational Inference