The Gumbel-Max Trick for Discrete Distributions

Ryan Adams · Computation, Probability

In neural networks, generalized linear models, topic models, and many other probabilistic models, it often comes up that one wishes to parameterize a discrete distribution in terms of an unconstrained vector of numbers, i.e., a vector that is not confined to the simplex, might be negative, etc. A very common way to address this is to use the “softmax” transformation:

    \begin{align*} \pi_k &= \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}} \end{align*}

where the x_k are unconstrained in \mathbb{R}, but the \pi_k live on the simplex, i.e., \pi_k \geq 0 and \sum_{k}\pi_k=1. The x_k parameterize a discrete distribution (not uniquely), and we can generate data by performing the softmax transformation and then drawing from the resulting discrete distribution in the usual way, e.g., by inverting its CDF with a uniform variate. Interestingly, it turns out that there is an alternative way to arrive at such discrete samples that doesn’t actually require constructing the discrete distribution.
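
Before getting to that, here is a minimal NumPy sketch of the standard two-step procedure, to make the baseline concrete; the particular x and the seed are just illustrative:

    import numpy as np

    def softmax(x):
        # Shift by the max for numerical stability; the probabilities are unchanged.
        e = np.exp(x - np.max(x))
        return e / e.sum()

    def sample_discrete(pi, rng):
        # Inverse-CDF draw: locate a uniform variate within the cumulative sums.
        return int(np.searchsorted(np.cumsum(pi), rng.uniform()))

    rng = np.random.default_rng(0)
    x = np.array([0.5, -1.2, 2.0, 0.0])   # unconstrained parameters
    k = sample_discrete(softmax(x), rng)  # a draw from the discrete distribution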

It turns out that the following trick is equivalent to the softmax-discrete procedure: add independent Gumbel noise to each x_k and then take the argmax. This doesn’t change the asymptotic complexity of the algorithm, but it opens the door to some interesting implementation possibilities. How does this work? The Gumbel distribution with unit scale and location parameter \mu has the following PDF:

    \begin{align*} f(z\,;\,\mu) &= \exp\{-(z-\mu) - \exp\{-(z-\mu)\}\}. \end{align*}

The CDF of the Gumbel is

    \begin{align*} F(z\,;\,\mu) &= \exp\{-\exp\{-(z-\mu)\}\}. \end{align*}
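
Inverting this CDF gives a one-line sampler: if U \sim \text{Uniform}(0,1), then \mu - \log(-\log U) is Gumbel-distributed with location \mu. That makes the whole trick a couple of lines of NumPy; a sketch (the function name is mine):

    import numpy as np

    def gumbel_max_sample(x, rng):
        # Solving exp{-exp{-(z - mu)}} = u for z gives z = mu - log(-log(u)),
        # so each x_k - log(-log(u_k)) is a Gumbel draw with location x_k.
        u = rng.uniform(size=np.shape(x))
        return int(np.argmax(x - np.log(-np.log(u))))

    rng = np.random.default_rng(0)
    k = gumbel_max_sample(np.array([0.5, -1.2, 2.0, 0.0]), rng)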

Now, imagine that the kth of our Gumbels, with location x_k, resulted in an outcome z_k. The probability that all of the other z_{k'\neq k} are less than this is

    \begin{align*} \Pr(z_k \text{ is largest}\,|\, z_k, \{x_{k'}\}^K_{k'=1}) &= \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\} \end{align*}

We know the marginal distribution over z_k and we need to integrate it out to find the overall probability:

    \begin{align*} \Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) &= \int \exp\{-(z_k-x_k)-\exp\{-(z_k-x_k)\}\} \\ &\qquad\times \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\} \,\mathrm{d}z_k \end{align*}

With a little bit of algebra (fold the k' = k factor into the product, and write \exp\{-(z_k - x_{k'})\} = \exp\{-z_k\}\exp\{x_{k'}\}), we get:

    \begin{align*} \Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) &= \int \exp\Big\{-z_k + x_k -\exp\{-z_k\} \sum_{k'=1}^K \exp\{x_{k'}\}\Big\}\,\mathrm{d}z_k \end{align*}

It turns out that this integral has a closed form:

    \begin{align*} \Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) = \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}} \end{align*}
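
To see why, write Z = \sum_{k'=1}^K\exp\{x_{k'}\} and substitute u = \exp\{-z_k\}, so that \mathrm{d}u = -\exp\{-z_k\}\,\mathrm{d}z_k and the limits map from (-\infty,\infty) to (\infty,0):

    \begin{align*} \int_{-\infty}^{\infty} \exp\{x_k\}\exp\{-z_k\}\exp\{-Z\exp\{-z_k\}\}\,\mathrm{d}z_k &= \exp\{x_k\}\int_{0}^{\infty} \exp\{-Zu\}\,\mathrm{d}u = \frac{\exp\{x_k\}}{Z}. \end{align*}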

We can see that this is exactly the softmax probability!
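
As a quick sanity check, here is a sketch that compares Monte Carlo frequencies of the Gumbel-max draw against the softmax probabilities; the sample size, seed, and x are arbitrary choices:

    import numpy as np

    rng = np.random.default_rng(0)
    x = np.array([0.5, -1.2, 2.0, 0.0])

    # Softmax probabilities, computed directly.
    pi = np.exp(x - x.max())
    pi /= pi.sum()

    # Empirical frequencies of argmax(x + standard Gumbel noise).
    n = 100_000
    g = -np.log(-np.log(rng.uniform(size=(n, x.size))))
    freqs = np.bincount(np.argmax(x + g, axis=1), minlength=x.size) / n

    print(np.round(pi, 3))     # approximately [0.159 0.029 0.715 0.097]
    print(np.round(freqs, 3))  # should match to about two decimal places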