The Gumbel-Max Trick for Discrete Distributions

Ryan Adams · April 6, 2013

It often comes up in neural networks, generalized linear models, topic models and many other probabilistic models that one wishes to parameterize a discrete distribution in terms of an unconstrained vector of numbers, i.e., a vector that is not confined to the simplex, might be negative, etc. A very common way to address this is to use the "softmax" transformation:
\begin{align*}
\pi_k &= \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}}
\end{align*}
where the $x_k$ are unconstrained in $\mathbb{R}$, but the $\pi_k$ live on the simplex, i.e., $\pi_k \geq 0$ and $\sum_{k}\pi_k=1$. The $x_k$ parameterize a discrete distribution (not uniquely: adding the same constant to every $x_k$ leaves the $\pi_k$ unchanged), and we can generate data by performing the softmax transformation and then drawing from the resulting discrete distribution in the usual way. Interestingly, there is an alternative way to arrive at such discrete samples that doesn't require constructing the discrete distribution at all.
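For reference, here is a minimal NumPy sketch of the softmax-then-sample procedure. The helper names `softmax` and `sample_softmax_discrete` are our own, and we assume "the usual way" means inverse-CDF sampling (find the first index whose cumulative probability exceeds a uniform draw); any other categorical sampler would do. The max-subtraction exploits the non-uniqueness noted above.

```python
import numpy as np


def softmax(x):
    """Map an unconstrained vector x in R^K to a point pi on the simplex."""
    x = np.asarray(x, dtype=float)
    # Softmax is invariant to adding a constant to every x_k, so subtracting
    # the max changes nothing mathematically but keeps exp() from overflowing.
    e = np.exp(x - np.max(x))
    return e / e.sum()


def sample_softmax_discrete(x, rng):
    """Draw one index k with probability softmax(x)[k] by inverting the discrete CDF."""
    pi = softmax(x)
    cdf = np.cumsum(pi)
    u = rng.uniform()  # u ~ Uniform[0, 1)
    # Smallest k with cdf[k] > u; min() guards against cdf[-1] rounding to just below 1.
    k = np.searchsorted(cdf, u, side="right")
    return int(min(k, pi.size - 1))


rng = np.random.default_rng(0)
x = np.array([1.0, -0.5, 2.0])  # an arbitrary example vector
print(softmax(x))
print(sample_softmax_discrete(x, rng))
```

Note that this route has to build all of $\pi$ (including the normalizing sum) before it can draw anything.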

It turns out that the following trick is equivalent to the softmax-discrete procedure: add independent Gumbel noise to each $x_k$ and then take the argmax. Both procedures take $O(K)$ work per sample, so this doesn't change the asymptotic complexity of the algorithm, but it opens the door to some interesting implementation possibilities. How does this work? The Gumbel distribution with unit scale and location parameter $\mu$ has the following PDF:
\begin{align*}
f(z\,;\,\mu) &= \exp\{-(z-\mu) - \exp\{-(z-\mu)\}\}.
\end{align*}
The CDF of the Gumbel is
\begin{align*}
F(z\,;\,\mu) &= \exp\{-\exp\{-(z-\mu)\}\}
\end{align*}
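Because the CDF is available in closed form, Gumbel noise is easy to generate by inversion: solving $u = F(z\,;\,\mu)$ gives $z = \mu - \log(-\log u)$ for $u$ uniform on $(0,1)$. The sketch below uses that to implement the trick itself; `sample_gumbel` and `gumbel_max_sample` are hypothetical helper names, not from any library.

```python
import numpy as np


def sample_gumbel(mu, rng):
    """Draw unit-scale Gumbel variates with locations mu by inverting the CDF.

    Solving u = exp(-exp(-(z - mu))) for z gives z = mu - log(-log(u)).
    """
    mu = np.asarray(mu, dtype=float)
    # Draw u from (0, 1): the lower bound is the smallest positive float,
    # so log(u) is always finite.
    u = rng.uniform(low=np.nextafter(0.0, 1.0), high=1.0, size=mu.shape)
    return mu - np.log(-np.log(u))


def gumbel_max_sample(x, rng):
    """Gumbel-max trick: add independent Gumbel noise to each x_k, return the argmax."""
    return int(np.argmax(sample_gumbel(x, rng)))


rng = np.random.default_rng(0)
x = np.array([1.0, -0.5, 2.0])  # an arbitrary example vector
print(gumbel_max_sample(x, rng))
```

No normalizing sum is computed anywhere. NumPy's `Generator.gumbel(loc, scale, size)` draws from the same distribution directly; the explicit inversion here just makes the link to $F$ visible.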
Now, imagine that the $k$th of our Gumbels, with location $x_k$, resulted in an outcome $z_k$. Since the Gumbels are independent, the probability that all of the other $z_{k'\neq k}$ are less than this is the product of their CDFs evaluated at $z_k$:
\begin{align*}
\Pr(z_k \text{ is largest}\,|\, z_k, \{x_{k'}\}^K_{k'=1}) &= \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\}
\end{align*}
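As a sanity check on this conditional probability, the sketch below evaluates the product of CDFs for a fixed $z_k$ and compares it with a Monte Carlo estimate that draws the other Gumbels via NumPy's `Generator.gumbel`. The example values of `x`, `k`, and `z_k` are arbitrary choices, and `prob_k_largest_given_zk` is a hypothetical name.

```python
import numpy as np


def gumbel_cdf(z, mu):
    """Unit-scale Gumbel CDF F(z; mu) = exp(-exp(-(z - mu)))."""
    return np.exp(-np.exp(-(z - mu)))


def prob_k_largest_given_zk(z_k, k, x):
    """Pr(z_k is largest | z_k, x): product over k' != k of F(z_k; x_{k'})."""
    x = np.asarray(x, dtype=float)
    others = np.delete(x, k)  # locations x_{k'} for k' != k
    return float(np.prod(gumbel_cdf(z_k, others)))


# Monte Carlo check: hold z_k fixed, draw the other Gumbels many times,
# and count how often all of them land below z_k.
rng = np.random.default_rng(0)
x = np.array([1.0, -0.5, 2.0])
k, z_k = 0, 2.5
others = np.delete(x, k)
draws = rng.gumbel(loc=others, size=(100_000, others.size))
mc = np.mean(np.all(draws < z_k, axis=1))
print(prob_k_largest_given_zk(z_k, k, x), mc)
```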
We know the marginal density of $z_k$ (the Gumbel PDF with location $x_k$), so we can integrate $z_k$ out to find the overall probability:
\begin{align*}
\Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) =
\int \exp\{-(z_k-x_k)-\exp\{-(z_k-x_k)\}\}\\
\times \prod_{k'\neq k}\exp\{-\exp\{-(z_k-x_{k'})\}\} \,\mathrm{d}z_k\quad
\end{align*}
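This is a one-dimensional integral, so before doing any algebra it can be evaluated by simple quadrature, which gives a numerical reference to compare against whatever closed form we derive. The sketch below uses a hand-written trapezoid rule; the integration window `[lo, hi]` and grid size are our own assumptions, suitable for moderate $x_k$ (roughly $|x_k| < 10$).

```python
import numpy as np


def prob_k_largest_quadrature(k, x, lo=-20.0, hi=40.0, n=60_001):
    """Integrate f(z; x_k) * prod_{k' != k} F(z; x_{k'}) over z with the trapezoid rule.

    The window [lo, hi] is an assumption that suits moderate x; the
    integrand is negligible outside it in that case.
    """
    x = np.asarray(x, dtype=float)
    z = np.linspace(lo, hi, n)
    # Work in log space: log f(z; x_k) = -(z - x_k) - exp(-(z - x_k)).
    log_integrand = -(z - x[k]) - np.exp(-(z - x[k]))
    for j in range(x.size):
        if j != k:
            # log F(z; x_j) = -exp(-(z - x_j))
            log_integrand -= np.exp(-(z - x[j]))
    y = np.exp(log_integrand)
    dz = z[1] - z[0]
    return float(np.sum(y[1:] + y[:-1]) * dz / 2.0)


x = np.array([1.0, -0.5, 2.0])  # an arbitrary example vector
numeric = np.array([prob_k_largest_quadrature(k, x) for k in range(x.size)])
# Exactly one component is largest, so these probabilities should sum to 1.
print(numeric, numeric.sum())
```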
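Combining the exponents and writing $\exp\{-(z_k-x_{k'})\} = \exp\{-z_k\}\exp\{x_{k'}\}$, so that the $k$th term joins the others in a single sum over all $K$ components, we get: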
\begin{align*}
\Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) =&
\int \exp\{-z_k + x_k \\
&\qquad -\exp\{-z_k\} \sum_{k'=1}^K \exp\{x_{k'}\}\}\,\mathrm{d}z_k\quad
\end{align*}
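One way to double-check this simplification is to evaluate the original and simplified integrands pointwise on a grid and confirm they agree; note that the sum in the simplified form runs over $\exp\{x_{k'}\}$ for every $k'$, including $k$ itself. The function names and the grid below are illustrative choices.

```python
import numpy as np


def integrand_original(z, k, x):
    """f(z; x_k) * prod_{k' != k} F(z; x_{k'}): the integrand before simplification."""
    out = np.exp(-(z - x[k]) - np.exp(-(z - x[k])))
    for j in range(len(x)):
        if j != k:
            out = out * np.exp(-np.exp(-(z - x[j])))
    return out


def integrand_simplified(z, k, x):
    """exp{-z + x_k - exp(-z) * sum_{k'} exp(x_{k'})}, with the sum over all K terms."""
    return np.exp(-z + x[k] - np.exp(-z) * np.sum(np.exp(x)))


x = np.array([1.0, -0.5, 2.0])  # an arbitrary example vector
z = np.linspace(-5.0, 15.0, 2001)
for k in range(x.size):
    gap = np.max(np.abs(integrand_original(z, k, x) - integrand_simplified(z, k, x)))
    print(k, gap)  # should be at the level of floating-point rounding
```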
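This integral has a closed form. Writing $Z = \sum_{k'=1}^K \exp\{x_{k'}\}$ and substituting $u = \exp\{-z_k\}$ (so $\mathrm{d}z_k = -\mathrm{d}u/u$, and $z_k \in (-\infty, \infty)$ maps to $u \in (0, \infty)$), the integral becomes $\exp\{x_k\}\int_0^\infty \exp\{-Zu\}\,\mathrm{d}u = \exp\{x_k\}/Z$, that is: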
\begin{align*}
\Pr(\text{$k$ is largest}\,|\,\{x_{k'}\}) = \frac{\exp\{x_k\}}{\sum_{k'=1}^K\exp\{x_{k'}\}}
\end{align*}
We can see that this is exactly the softmax probability!
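A quick empirical confirmation: draw many Gumbel-max samples at once and compare the observed frequencies with the softmax probabilities. This minimal sketch uses NumPy's `Generator.gumbel` for the noise; the vector `x` and the sample count `n` are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([1.0, -0.5, 2.0])  # an arbitrary example vector
n = 200_000

# Gumbel-max: perturb every x_k with independent standard Gumbel noise,
# then take the argmax of each row (one row per sample).
samples = np.argmax(x + rng.gumbel(size=(n, x.size)), axis=1)
freq = np.bincount(samples, minlength=x.size) / n

# Softmax probabilities, with the max subtracted for numerical stability.
pi = np.exp(x - x.max())
pi /= pi.sum()
print(freq)
print(pi)
```

The two printed vectors should agree up to Monte Carlo error, which shrinks like $1/\sqrt{n}$.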
