Variational Inference

May 12, 2024

KL Divergence

Use a distribution $q$ to approximate $p$. A common measure of the "distance" or "divergence" between the two distributions is the Kullback-Leibler (KL) Divergence.

$$\text{KL}(p \parallel q)=\int p(x)\log{\frac{p(x)}{q(x)}}\,dx$$

Formally, it measures the information lost when $p$ is approximated with $q$; therefore, the goal is to minimize the KL Divergence. Note that it is not symmetric: $\text{KL}(p \parallel q) \neq \text{KL}(q \parallel p)$. It turns out that minimizing $\text{KL}(p \parallel p_\theta)$, where $p$ is the empirical data distribution, recovers the MLE for $\theta$.
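
As a quick numerical illustration of the asymmetry, the sketch below (assuming SciPy; the two one-dimensional Gaussians standing in for $p$ and $q$ are arbitrary choices) estimates both directions by Monte Carlo:

```python
# A minimal sketch (not from the original notes): estimating
# KL(p || q) by Monte Carlo and checking that the two directions differ.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
p = norm(loc=0.0, scale=1.0)   # p(x)
q = norm(loc=1.0, scale=2.0)   # q(x), the approximating distribution

def kl_mc(a, b, n=200_000):
    """Monte Carlo estimate of KL(a || b) = E_{x~a}[log a(x) - log b(x)]."""
    x = a.rvs(size=n, random_state=rng)
    return np.mean(a.logpdf(x) - b.logpdf(x))

print("KL(p || q) ~", kl_mc(p, q))   # forward KL
print("KL(q || p) ~", kl_mc(q, p))   # reverse KL -- a different number
```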

Reverse KL Divergence

When approximating $p$ with $q$, it is easier to minimize the reverse KL Divergence $\text{KL}(q\parallel p)$ with respect to $q$, because the expectation is taken under $q$ (which we control) and $p$ only needs to be known up to its normalizing constant: write $p(x) = \tilde{p}(x)/Z$.

$$\begin{align*} \text{KL}(q\parallel p) & = \int q(x)\log{q(x)}\,dx - \int q(x) \log{p(x)}\,dx \\ & = \int q(x)\log{q(x)}\,dx - \int q(x) \log{\tilde{p}(x)}\,dx + \int q(x) \log{Z}\,dx \\ & = \underset{x\sim q}{E}(\log{q(x)}) - \underset{x\sim q}{E}(\log{\tilde{p}(x)}) + \log{Z} \\ & = -H(q) - \underset{x\sim q}{E}(\log{\tilde{p}(x)}) + \log{Z} \end{align*}$$
Since $\log Z$ is a constant that does not depend on $q$, it can be dropped from the optimization:

$$\underset{q}{\arg\min} \{\text{KL}(q\parallel p)\} = \underset{q}{\arg\max} \left\{\underset{x\sim q}{E}(\log{\tilde{p}(x)}) + H(q)\right\}$$
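
The objective on the right-hand side only requires samples from $q$ and the unnormalized density $\tilde{p}$. A small sketch (assuming SciPy; the target $\tilde{p}$ and the Gaussian $q$ below are made-up examples) estimates it by sampling:

```python
# A sketch of the reverse-KL objective E_{x~q}[log p_tilde(x)] + H(q),
# estimated by sampling from q. The unnormalized target p_tilde and the
# Gaussian q below are illustrative choices, not from the original notes.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)

def log_p_tilde(x):
    # Unnormalized log-density of an example target.
    return -0.5 * x**2 - 0.1 * x**4

def objective(mu, sigma, n=100_000):
    q = norm(loc=mu, scale=sigma)
    x = q.rvs(size=n, random_state=rng)
    entropy = q.entropy()                     # H(q) in closed form
    return np.mean(log_p_tilde(x)) + entropy  # higher is better

print(objective(0.0, 1.0))   # a reasonable q
print(objective(3.0, 0.5))   # a q far from the target mass -- lower value
```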

It is hard to optimize an expectation whose distribution depends on the parameters being optimized, so use the reparametrization trick.

Example: $q$ is a multivariate Gaussian.

Maximizing over $q$ means maximizing over $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ (the parameters of the distribution).

Then the entropy is $H(q) = \frac{1}{2}\log{|\boldsymbol{\Sigma}|} + \frac{d}{2}\log{2\pi e}$, so (dropping the constant term) the maximization problem becomes

$$\underset{\boldsymbol{\mu},\boldsymbol{\Sigma}}{\arg\max}\left\{\underset{x\sim N(\boldsymbol{\mu},\boldsymbol{\Sigma})}{E}(\log{\tilde{p}(x)}) + \frac{1}{2}\log{|\boldsymbol{\Sigma}|}\right\}$$
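
As a sanity check on the entropy formula above, the snippet below (assuming SciPy; the covariance matrix is an arbitrary positive-definite example) compares it with SciPy's built-in entropy:

```python
# Checking H(q) = 1/2 log|Sigma| + d/2 log(2*pi*e) for a multivariate Gaussian
# against SciPy's built-in entropy. The mean and covariance are arbitrary.
import numpy as np
from scipy.stats import multivariate_normal

d = 3
mu = np.zeros(d)
Sigma = np.array([[2.0, 0.3, 0.0],
                  [0.3, 1.0, 0.2],
                  [0.0, 0.2, 0.5]])   # symmetric positive-definite

formula = 0.5 * np.log(np.linalg.det(Sigma)) + 0.5 * d * np.log(2 * np.pi * np.e)
print(formula)
print(multivariate_normal(mean=mu, cov=Sigma).entropy())  # should match
```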

Use the Cholesky decomposition $\boldsymbol{\Sigma} = LL^\top$ to reparametrize $q$. That is, given a standard normal variable $z\sim N(0, \boldsymbol{I})$, a sample from $q$ can be written as $x = \boldsymbol{\mu} + Lz$. Therefore, the maximization problem can be reparametrized as

$$\underset{\boldsymbol{\mu},L}{\arg\max}\left\{\underset{z\sim N(0,\boldsymbol{I})}{E}(\log{\tilde{p}(\boldsymbol{\mu} + Lz)}) + \frac{1}{2}\log{|LL^\top|}\right\}$$

Note that the expectation can be easily estimated using Monte Carlo because $z$ is easy to sample, and the sampling distribution no longer depends on $\boldsymbol{\mu}$ or $L$, so gradients can be pushed through the samples.
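
Putting the pieces together, here is a minimal sketch of the whole procedure, assuming PyTorch; the 2-D unnormalized target `log_p_tilde` is an illustrative choice, not something prescribed by these notes. It maximizes the reparametrized objective over $\boldsymbol{\mu}$ and $L$ by stochastic gradient ascent:

```python
# A minimal sketch of reparametrized variational inference, assuming PyTorch.
# The unnormalized 2-D target below is a made-up example.
import torch

torch.manual_seed(0)
d = 2

def log_p_tilde(x):
    # Unnormalized log-density of the target: a zero-mean correlated Gaussian.
    a = torch.tensor([[1.0, 0.8], [0.8, 1.5]])
    return -0.5 * torch.sum((x @ torch.linalg.inv(a)) * x, dim=-1)

# Variational parameters: mean mu and a lower-triangular factor L of Sigma.
mu = torch.zeros(d, requires_grad=True)
L_raw = torch.eye(d, requires_grad=True)

opt = torch.optim.Adam([mu, L_raw], lr=0.05)

for step in range(2000):
    L = torch.tril(L_raw)                # keep L lower triangular
    z = torch.randn(256, d)              # z ~ N(0, I), easy to sample
    x = mu + z @ L.T                     # reparametrization: x = mu + L z
    # Objective: E_q[log p_tilde(x)] + 1/2 log|Sigma|, with |Sigma| = |L L^T|,
    # so 1/2 log|Sigma| = sum_i log|L_ii|.
    entropy_term = torch.log(torch.abs(torch.diagonal(L))).sum()
    loss = -(log_p_tilde(x).mean() + entropy_term)
    opt.zero_grad()
    loss.backward()
    opt.step()

print("mu ~", mu.detach())
print("Sigma ~", (torch.tril(L_raw) @ torch.tril(L_raw).T).detach())
```

Because $x = \boldsymbol{\mu} + Lz$ is a deterministic function of the parameters and of $z$, the gradient of the Monte Carlo estimate flows through the samples, which is exactly what the reparametrization buys.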