Proof of the Gumbel Max Trick

Statement

Assume that $\alpha_1, \alpha_2, \ldots, \alpha_n$ are positive and satisfy $\sum_k\alpha_k=1$. Define

$$Z=\arg\max_k\{\log\alpha_k+G_k\}$$

where $G_1,\ldots,G_n$ are i.i.d. $Gumbel(0,1)$ random variables, whose PDF and CDF are

$$\begin{align} f(x)&=e^{-(x+e^{-x})}, \\ F(x)&=e^{-e^{-x}}.\end{align}$$

Then $\mathbb{P}(Z=k)=\alpha_k$ for every $k$.
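Before going through the proof, the statement can be sanity-checked numerically. The snippet below is a minimal sketch, assuming NumPy is available; the helper name `gumbel_max_sample`, the example weights, and the sample size are illustrative choices, not part of the original statement.

```python
import numpy as np

def gumbel_max_sample(alpha, size, rng):
    """Draw `size` samples of Z = argmax_k (log alpha_k + G_k)."""
    alpha = np.asarray(alpha, dtype=float)
    # One independent Gumbel(0, 1) draw per sample and per category.
    g = rng.gumbel(loc=0.0, scale=1.0, size=(size, alpha.size))
    # argmax over the category axis gives a 0-based index k.
    return np.argmax(np.log(alpha) + g, axis=1)

rng = np.random.default_rng(0)
alpha = np.array([0.1, 0.2, 0.3, 0.4])  # hypothetical example weights
z = gumbel_max_sample(alpha, 200_000, rng)

# Empirical frequency of each index; should be close to alpha.
freq = np.bincount(z, minlength=alpha.size) / z.size
print(freq)
```

With a few hundred thousand samples, the printed frequencies should agree with `alpha` to about two decimal places.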

Proof

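Set $u_k=\log{\alpha_k}+G_k$. Since the $u_k$ are continuous random variables, ties occur with probability zero, so we can compute $\mathbb{P}(Z=k)$ directly by conditioning on $u_k$.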

$$\begin{align} \mathbb{P}(Z=k)&=\mathbb{P}(u_k \geq u_j,\forall j \neq k) \\ &=\int_{-\infty}^\infty \mathbb{P}(u_k \geq u_j, \forall j \neq k \mid u_k)\,p(u_k)\, du_k \\ &=\int_{-\infty}^\infty \prod_{j\neq k}\mathbb{P}(u_k \geq u_j \mid u_k)\,p(u_k)\, du_k \\ &=\int_{-\infty}^\infty \prod_{j\neq k}e^{-e^{-u_k+\log \alpha_j}}\, e^{-(u_k-\log\alpha_k+e^{-(u_k-\log\alpha_k)})}\, du_k \\ &=\int_{-\infty}^\infty e^{-\sum_{j\neq k}\alpha_je^{-u_k}}\, \alpha_k e^{-(u_k+\alpha_k e^{-u_k})}\, du_k \\ &=\alpha_k \int_{-\infty}^\infty e^{-u_k-(\alpha_k+\sum_{j\neq k}\alpha_j)e^{-u_k}}\, du_k \\ &= \alpha_k. \end{align}$$

Here $p(u_k)=f(u_k-\log\alpha_k)$ is the density of $u_k$. The third line uses the independence of the $G_j$, the fourth uses $\mathbb{P}(u_j \leq u_k \mid u_k)=F(u_k-\log\alpha_j)$, and the last uses $\sum_k\alpha_k=1$, which turns the integrand into the $Gumbel(0,1)$ density $f$.
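The final integral can also be evaluated explicitly. As a sketch, write $c=\alpha_k+\sum_{j\neq k}\alpha_j$ (the symbol $c$ is introduced here only for illustration) and substitute $t=e^{-u_k}$, so that $dt=-e^{-u_k}\,du_k$:

$$\begin{align} \int_{-\infty}^\infty e^{-u_k-c\,e^{-u_k}}\, du_k &= \int_0^\infty e^{-ct}\, dt \\ &= \frac{1}{c}. \end{align}$$

With $c=\sum_j\alpha_j=1$ the integral equals $1$, which gives $\mathbb{P}(Z=k)=\alpha_k$. The same computation suggests that for positive weights that do not sum to one, the result would be $\alpha_k/\sum_j\alpha_j$.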

Application

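The trick is commonly used in deep learning to make sampling from a discrete distribution differentiable. The $\arg\max$ itself is not differentiable, so in practice it is usually replaced by a softmax with a temperature parameter, a relaxation known as Gumbel-softmax.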
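To illustrate how the trick is typically made differentiable, the sketch below replaces the $\arg\max$ with a temperature-scaled softmax, a relaxation usually called Gumbel-softmax. It is written in NumPy only to show the forward computation; the function name `gumbel_softmax`, the temperature values, and the example weights are assumptions for illustration, and an autodiff framework would be needed to actually take gradients with respect to `log_alpha`.

```python
import numpy as np

def gumbel_softmax(log_alpha, tau, rng):
    """Relaxed one-hot sample: softmax((log_alpha + G) / tau)."""
    # Same perturbation as in the Gumbel max trick.
    g = rng.gumbel(size=log_alpha.shape)
    y = (log_alpha + g) / tau
    # Subtract the max before exponentiating for numerical stability.
    e = np.exp(y - y.max())
    return e / e.sum()

rng = np.random.default_rng(0)
log_alpha = np.log(np.array([0.1, 0.2, 0.3, 0.4]))  # hypothetical weights

# Smaller temperatures push the output closer to a one-hot vector.
for tau in (5.0, 1.0, 0.1):
    print(tau, gumbel_softmax(log_alpha, tau, rng))
```

For any temperature, the largest entry of the relaxed sample sits at the same index that the $\arg\max$ would pick for the same noise, so the relaxation agrees with the original trick in the limit of small `tau`.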

Author: hsfzxjy.
Link: .
License: CC BY-NC-ND 4.0.
All rights reserved by the author.
Commercial use of this post in any form is NOT permitted.
Non-commercial use of this post should be attributed with this block of text.
