Variational Inference (part 1)

Andy Miller

I will dedicate the next few posts to variational inference methods as a way to organize my own understanding – this first one will be pretty basic.

The goal of variational inference is to approximate an intractable probability distribution, p, with a tractable one, q, in a way that makes them as ‘close’ as possible. Let’s unpack that statement a bit.

  1. Intractable p: a motivating example is the posterior distribution of a Bayesian model, i.e. given some observations x = (x_1, x_2, \dots, x_n) and some model p(x | \theta) parameterized by \theta = (\theta_1, \dots, \theta_d), we often want to evaluate the distribution over parameters

        \begin{align*}   p(\theta | x) = \frac{ p(x | \theta) p(\theta) }{ \int p(x | \theta) p(\theta) d \theta } \end{align*}

    For a lot of interesting models this distribution is intractable because of the integral in the denominator. We can evaluate the posterior up to a constant, but we can’t compute the normalization constant. Applying variational inference to posterior distributions is sometimes called variational Bayes.

  2. Tractable q: a distribution is tractable when we can evaluate its normalizing integral (and therefore take expectations with respect to it). One way to achieve this is to make each \theta_i independent under q:

        \begin{align*}   q(\theta | \lambda) = \prod_{i=1}^d q_i(\theta_i | \lambda_i)  \end{align*}

    for some marginal distributions q_i parameterized by \lambda_i. This is called the ‘mean field’ approximation (a small code sketch after this list shows what such a factorized q looks like).

  3. We can make two distributions ‘close’ in the Kullback-Leibler sense:

        \begin{align*}
        KL(q || p) &\equiv \int q(\theta) \log \frac{ q(\theta) }{ p(\theta | x) } d\theta \\
                   &= E_q \left(\log \frac{ q(\theta) }{ p(\theta | x) } \right)
        \end{align*}

    Notice that we chose KL(q||p) and not KL(p||q) because of the intractability – we’re assuming we cannot evaluate expectations with respect to p, E_p(\cdot), but we can with respect to q.
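
To make the mean field idea in item 2 concrete, here is a minimal sketch (my own, not taken from any library or from the derivation in this post) of a fully factorized q in Python/NumPy: each \theta_i gets an independent Gaussian factor with its own variational parameters \lambda_i = (m_i, \log s_i), so both \log q and sampling from q decompose coordinate-wise.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mean-field family: q(theta | lambda) = prod_i N(theta_i | m_i, s_i^2),
# with variational parameters lambda_i = (m_i, log s_i) for each coordinate.
d = 3
m = np.zeros(d)        # factor means
log_s = np.zeros(d)    # factor log standard deviations

def log_q(theta):
    """log q(theta) = sum_i log q_i(theta_i); a sum because the factors are independent."""
    s = np.exp(log_s)
    return np.sum(-0.5 * np.log(2 * np.pi) - log_s - 0.5 * ((theta - m) / s) ** 2)

def sample_q(size=1):
    """Sampling also factorizes: draw each theta_i from its own marginal."""
    return m + np.exp(log_s) * rng.normal(size=(size, d))

theta = sample_q(size=5)        # 5 draws from q, shape (5, d)
print(log_q(theta[0]))          # scalar log density of the first draw
```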

In order to minimize the KL divergence KL(q||p), note the following decomposition

    \begin{align*}
    \log p(x) &= \log \frac{ p(\theta | x) p(x) }{ p(\theta | x)} \\
              &= \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \\
    E_q(\log p(x)) &= E_q \left( \log \frac{ q(\theta) }{ p(\theta | x) } + \log \frac{p(x, \theta)}{q(\theta)} \right) \\
    \log p(x) &= KL(q || p) + E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right)
    \end{align*}

where \log p(x) comes out of the expectation because it does not depend on \theta (the second line above multiplies and divides by q(\theta) and uses p(\theta | x) p(x) = p(x, \theta)). This is convenient because \log p(x) (the log evidence) is constant with respect to the variational distribution q – it depends only on the data and the model, not on q or its parameters.

And because the KL divergence is always nonnegative, the second term E_q \left( \log \frac{p(x, \theta)}{q(\theta)} \right) is a lower bound on \log p(x), known as the evidence lower bound (ELBO). Since \log p(x) is fixed, minimizing the first term, KL(q || p), is equivalent to maximizing the second.
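
Since the ELBO is just an expectation under q, it can be estimated by simple Monte Carlo whenever we can sample from q and evaluate the joint density p(x, \theta) (which, unlike the posterior, does not require the normalization constant). A hedged sketch, where log_joint, sample_q, and log_q are hypothetical stand-ins for a specific model and variational family:

```python
import numpy as np

def elbo_estimate(log_joint, sample_q, log_q, num_samples=1000):
    """Monte Carlo estimate of the ELBO, E_q[ log p(x, theta) - log q(theta) ].

    log_joint(theta) -- evaluates log p(x, theta); the unnormalized posterior is enough
    sample_q()       -- draws one theta from the variational distribution q
    log_q(theta)     -- evaluates log q(theta)
    All three callables are placeholders for whatever model and family you plug in.
    """
    draws = [sample_q() for _ in range(num_samples)]
    return np.mean([log_joint(th) - log_q(th) for th in draws])
```

Because \log p(x) = KL(q || p) + ELBO, a larger estimate means the current q is closer to the posterior in the KL(q || p) sense.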

One way to optimize over the choice of q(\theta) = \prod_i q_i(\theta_i) is to consider the ELBO as a function of a single factor q_j(\theta_j), separating it from the expectation with respect to all of the other factors, E_{-q_j}(\cdot) (here \int_{-\theta_j} denotes the integral over all \theta_i with i \neq j):

    \begin{align*}
    ELBO(q_j) &= \int \prod_i q_i(\theta_i) \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right) \\
              &= \int_{\theta_j} q_j(\theta_j) \int_{-\theta_j} \prod_{i \neq j} q_i(\theta_i) \left( \log p(x, \theta) - \sum_k \log q_k(\theta_k) \right) \\
              &= \int_{\theta_j} q_j(\theta_j) E_{-q_j} [ \log p(x, \theta) ] - \int_{\theta_j} q_j(\theta_j) \log q_j(\theta_j) + const. \\
              &= -KL(q_j || \tilde{q}_j) + const.
    \end{align*}

where

    \begin{align*}   \tilde{q_j} \propto \exp\left( E_{-q_j}[\log p(x,\theta)] \right) \end{align*}

This motivates a particular updating scheme: iterate over the marginals q_j, maximizing the ELBO at each step with \tilde{q}_j, and repeat.
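
In sketch form (my own pseudocode-ish Python, where update_factor is a hypothetical stand-in for whatever closed-form \tilde{q}_j a particular model yields), the scheme is just a coordinate-ascent loop:

```python
def coordinate_ascent_vi(factors, update_factor, iters=100):
    """Cycle through the mean-field factors, replacing each q_j with tilde-q_j.

    factors       -- current variational factors q_1, ..., q_d (any representation)
    update_factor -- hypothetical callable returning tilde-q_j, i.e. the factor
                     proportional to exp(E_{-q_j}[ log p(x, theta) ]), given the others
    """
    for _ in range(iters):
        for j in range(len(factors)):
            factors[j] = update_factor(j, factors)
    return factors
```

Each individual update maximizes the ELBO with respect to that one factor, so the bound never decreases, and convergence can be monitored by tracking the ELBO.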

These statements remain pretty general. We haven’t specified the functional form of q_j(\theta_j), but it falls out of E_{-q_j}[\log p(x,\theta)] for specific models (a good, simple example is the normal-gamma model from Murphy’s MLAPP or Bishop’s PRML; a rough sketch of it follows below). This form then defines the variational parameters \lambda_i, and the iterative algorithm provides a way to compute them.
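
As a concrete (if rough) illustration, here is a sketch of those coordinate updates for the univariate Gaussian with unknown mean \mu and precision \tau: data x_n \sim N(\mu, \tau^{-1}) with priors \mu | \tau \sim N(\mu_0, (\lambda_0 \tau)^{-1}) and \tau \sim Gamma(a_0, b_0), and a factorized q(\mu)\, q(\tau). The prior settings and variable names here are my own choices; treat this as a sketch of the standard updates rather than a definitive implementation.

```python
import numpy as np

def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0, iters=50):
    """Mean-field VB for x_n ~ N(mu, 1/tau) with a normal-gamma prior.

    Variational factors: q(mu) = N(mu | mu_n, 1/lam_n), q(tau) = Gamma(tau | a_n, b_n).
    Each step is the tilde-q update: q_j propto exp(E_{-q_j}[ log p(x, theta) ]).
    """
    N, xbar = len(x), np.mean(x)
    mu_n = (lam0 * mu0 + N * xbar) / (lam0 + N)   # does not depend on q(tau)
    a_n = a0 + (N + 1) / 2.0                      # also fixed across iterations
    b_n = b0                                      # crude initialization for E_q[tau]
    for _ in range(iters):
        e_tau = a_n / b_n                         # E_q[tau] under the current q(tau)
        lam_n = (lam0 + N) * e_tau                # update q(mu)
        # update q(tau): expectations of the squared deviations are taken under q(mu)
        e_sq_data = np.sum((x - mu_n) ** 2) + N / lam_n
        e_sq_prior = lam0 * ((mu_n - mu0) ** 2 + 1.0 / lam_n)
        b_n = b0 + 0.5 * (e_sq_data + e_sq_prior)
    return mu_n, lam_n, a_n, b_n

x = np.random.default_rng(1).normal(loc=2.0, scale=0.5, size=200)
mu_n, lam_n, a_n, b_n = cavi_normal_gamma(x)
print(mu_n, a_n / b_n)   # approximate posterior means of mu and tau
```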

Some fun properties of the mean field approximation:

  • The objective is not guaranteed to be convex, so the coordinate updates generally find only a local optimum (Wainwright and Jordan)
  • The optimization procedure is pretty much out of the box for exponential family models (maybe a future blog post will be dedicated to exploring this)
  • This variational approximation tends to underestimate uncertainty – a consequence of the direction of the KL divergence, KL(q || p) rather than KL(p || q) (the small numerical check after this list illustrates it).
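
That last point can be seen in a tiny numerical check (again my own sketch, not from the post): for a zero-mean bivariate Gaussian target with correlation \rho, the mean-field fixed point sets each factor’s variance to 1 / \Lambda_{ii}, the inverse of the precision-matrix diagonal, which works out to 1 - \rho^2 – smaller than the true marginal variance of 1 whenever \rho \neq 0.

```python
import numpy as np

rho = 0.9
Sigma = np.array([[1.0, rho],
                  [rho, 1.0]])          # covariance of the target p = N(0, Sigma)
Lam = np.linalg.inv(Sigma)              # precision matrix

# Mean-field fixed point for a zero-mean Gaussian target: q_i = N(0, 1 / Lam_ii),
# i.e. the factor variances come from the diagonal of the precision matrix.
q_var = 1.0 / np.diag(Lam)

print("true marginal variances:", np.diag(Sigma))   # [1.  1.]
print("mean-field variances:   ", q_var)            # [0.19  0.19] = 1 - rho^2
```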

This raises the question: can you do better with richer q distributions? Yes! In future posts, I will take a look at some more complicated variational distributions and their performance.

References:

  1. Kevin Murphy. Machine Learning: A Probabilistic Perspective.
  2. Martin Wainwright and Michael Jordan. Graphical Models, Exponential Families, and Variational Inference.
  3. Christopher Bishop. Pattern Recognition and Machine Learning.