Diffusion Models

Original paper: Denoising Diffusion Probabilistic Models
Survey: Understanding Diffusion Models: A Unified Perspective

Intuition

A diffusion model is trained through a diffusion process that progressively adds noise to the original data. The model then learns how to reconstruct the original data from this noisy input.

graph LR;
A[Original Data] -->|Forward Diffusion| B[Noisy Data];
B -->|Reverse Denoising| A

Once the model has learned to reconstruct the data distribution from a (typically Gaussian) noise distribution, it can generate novel data by denoising samples of pure noise.

Forward Process

At each time step $t$, the transition probability from input $\mathbf{x}_{t-1}$ to output $\mathbf{x}_t$ is defined as a Gaussian:

$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\left(\sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\, \beta_t \mathbf{I}\right),$$

where $\beta_t \in (0, 1)$.

Reparameterization

We can generate $\mathbf{x}_t$ using the reparameterization trick:

$$\mathbf{x}_t = \sqrt{1-\beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol\epsilon_t, \quad \boldsymbol\epsilon_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$

Furthermore, we can generate $\mathbf{x}_t$ from $\mathbf{x}_0$ in a single reparameterization step:

$$\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon,$$

where

$$\bar\alpha_t := \prod_{i=1}^t \alpha_i, \quad \alpha_i = 1 - \beta_i, \quad \boldsymbol\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$
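
As a concrete illustration, here is a minimal NumPy sketch of this closed-form forward step, assuming a linear $\beta_t$ schedule (a common choice; the text does not fix one) and 0-indexed schedule arrays:

```python
import numpy as np

# Assumed linear beta schedule with T = 1000 steps (illustrative values only).
T = 1000
betas = np.linspace(1e-4, 0.02, T)      # beta_1, ..., beta_T (stored 0-indexed)
alphas = 1.0 - betas                    # alpha_t = 1 - beta_t
alpha_bars = np.cumprod(alphas)         # bar{alpha}_t = prod_{i <= t} alpha_i

def q_sample(x0, t, rng=np.random.default_rng()):
    """Draw x_t ~ q(x_t | x_0) in a single reparameterization step.

    `t` is a 0-indexed time step into the schedule arrays above.
    Returns both x_t and the noise eps, since training uses eps as the target.
    """
    eps = rng.standard_normal(x0.shape)                 # eps ~ N(0, I)
    x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    return x_t, eps
```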

Asymptotics

Since $\alpha_i \in (0, 1)$, as $t \to \infty$,

$$\bar\alpha_t \to 0.$$

Therefore, regardless of $\mathbf{x}_0$,

$$\mathbf{x}_\infty \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).$$
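
A quick numeric sanity check of this limit, using the assumed schedule from the sketch above:

```python
# bar{alpha}_t shrinks monotonically; for the linear schedule above,
# bar{alpha}_T is on the order of 1e-5, so x_T is essentially a standard
# Gaussian sample regardless of x_0.
print(alpha_bars[0], alpha_bars[T // 2], alpha_bars[-1])
```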

Reverse Process

The posterior distribution, $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$, is generally intractable:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\, q(\mathbf{x}_{t-1})}{q(\mathbf{x}_t)},$$

where

$$q(\mathbf{x}_t) = \int_{\mathbf{x}_0 \in \mathcal{X}} q(\mathbf{x}_t \mid \mathbf{x}_0)\, q(\mathbf{x}_0)\, d\mathbf{x}_0.$$

Therefore, we learn a network parameterized by $\theta$ to approximate $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ with a Gaussian distribution:

$$p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) := \mathcal{N}\left(\boldsymbol\mu_\theta(\mathbf{x}_t, t),\, \boldsymbol\Sigma_\theta(\mathbf{x}_t, t)\right).$$

We approximate the log-likelihood using the ELBO:

$$\mathbb{E}_{q(\mathbf{x}_0)}\left[-\log p_\theta(\mathbf{x}_0)\right] \le \mathbb{E}_q\left[-\log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\right] = \mathbb{E}_q\bigg[\underbrace{D_{\mathrm{KL}}\big(q(\mathbf{x}_T \mid \mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_T)\big)}_{L_T} + \sum_{t>1} \underbrace{D_{\mathrm{KL}}\big(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\big)}_{L_{1:T-1}} \underbrace{-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}_{L_0}\bigg].$$

The $L_T$ Term

Since $\mathbf{x}_T$ is directly sampled from a standard Gaussian distribution and $q$ has no learnable parameters for fixed $\beta_{1:T}$, $L_T$ is a constant and is therefore dropped from the training objective.

The $L_{1:T-1}$ Term

The posterior conditioned on $\mathbf{x}_0$, $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, is a tractable Gaussian:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\left(\tilde{\boldsymbol\mu}_t(\mathbf{x}_t, \mathbf{x}_0),\, \tilde\beta_t \mathbf{I}\right),$$

where

$$\tilde{\boldsymbol\mu}_t(\mathbf{x}_t, \mathbf{x}_0) := \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol\epsilon\right); \quad \tilde\beta_t := \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t.$$

For simplicity, $\boldsymbol\Sigma_\theta(\mathbf{x}_t, t)$ is set to untrained time-dependent constants $\sigma_t^2 \mathbf{I}$, where $\sigma_t^2 = \tilde\beta_t$. Since $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ and $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ are both tractable Gaussians, the KL divergence has a closed-form solution (see KL divergence between Gaussians):

$$L_{t-1} = \frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol\mu}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol\mu_\theta(\mathbf{x}_t, t)\right\|^2 + C.$$
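
The following sketch (reusing the schedule arrays from the forward-process sketch) computes the posterior parameters $\tilde{\boldsymbol\mu}_t$, $\tilde\beta_t$ and the resulting mean-matching term; `mu_theta` stands in for the network's predicted mean:

```python
def posterior_params(x_t, eps, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0), with x_0 expressed through eps.

    `t` is 0-indexed into the schedule arrays; for the first step,
    bar{alpha}_{t-1} is taken to be 1.
    """
    alpha_bar_prev = alpha_bars[t - 1] if t > 0 else 1.0
    mean = (x_t - (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
    var = (1.0 - alpha_bar_prev) / (1.0 - alpha_bars[t]) * betas[t]   # tilde{beta}_t
    return mean, var

def l_t_minus_1(mu_theta, x_t, eps, t):
    """L_{t-1} up to the additive constant C, with sigma_t^2 = tilde{beta}_t."""
    mean, var = posterior_params(x_t, eps, t)
    return np.sum((mean - mu_theta) ** 2) / (2.0 * var)
```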

The $L_0$ Term

To obtain a discrete log-likelihood, this term is set to an independent discrete decoder derived from the Gaussian $\mathcal{N}(\mathbf{x}_0; \boldsymbol\mu_\theta(\mathbf{x}_1, 1), \sigma_1^2 \mathbf{I})$. For discrete $\mathbf{x}_t$ of dimension $D$,

$$p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) = \prod_{i=1}^D \int_{F(x) = (\mathbf{x}_0)_i} \mathcal{N}\left(x; (\boldsymbol\mu_\theta)_i, \sigma_1^2\right) dx,$$

where $F(x)$ is a discretization function that maps $x \in \mathbb{R}$ to the discrete value $(\mathbf{x}_0)_i$.
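
As an illustration, here is a minimal sketch of this per-dimension decoder likelihood, assuming (as is common for images, though not stated above) 8-bit data rescaled to $[-1, 1]$, so that $F$ maps each real value to the nearest of 256 evenly spaced levels, with the outermost bins extended to $\pm\infty$:

```python
import math
import numpy as np

def gaussian_cdf(x, mu, sigma):
    """Gaussian CDF with mean mu and standard deviation sigma, evaluated at x."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def discrete_log_likelihood(x0, mu, sigma, bin_width=2.0 / 255.0):
    """log p_theta(x_0 | x_1) for 8-bit data rescaled to [-1, 1] (assumed convention).

    Each dimension i contributes the Gaussian probability mass of the bin that F
    maps to (x_0)_i; the edge bins are extended to -inf and +inf.
    """
    total = 0.0
    for xi, mi in zip(np.ravel(x0), np.ravel(mu)):
        lo = -math.inf if xi <= -1.0 + 1e-8 else xi - bin_width / 2.0
        hi = math.inf if xi >= 1.0 - 1e-8 else xi + bin_width / 2.0
        p_lo = 0.0 if lo == -math.inf else gaussian_cdf(lo, mi, sigma)
        p_hi = 1.0 if hi == math.inf else gaussian_cdf(hi, mi, sigma)
        total += math.log(max(p_hi - p_lo, 1e-12))
    return total
```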

Training

The formulation of $\tilde{\boldsymbol\mu}_t(\mathbf{x}_t, \mathbf{x}_0)$ indicates that instead of directly predicting the mean $\boldsymbol\mu_\theta(\mathbf{x}_t, t)$, it is better to predict the noise $\boldsymbol\epsilon_\theta(\mathbf{x}_t, t)$, such that

$$L_{t-1} = C\left\|\boldsymbol\epsilon - \boldsymbol\epsilon_\theta(\mathbf{x}_t, t)\right\|^2.$$

Combining this with the reparameterization step, we have the training objective:

$$\left\|\boldsymbol\epsilon - \boldsymbol\epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon,\, t\right)\right\|^2.$$

We thus have the training algorithm:

\begin{algorithm}
\caption{Training}
\begin{algorithmic}
\Repeat
    \State $\mathbf{x}_0 \sim q(\mathbf{x}_0)$
    \State $t \sim \text{Uniform}(\{1, \ldots, T\})$
    \State $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
    \State Take gradient descent step on $\nabla_\theta \| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta (\sqrt{\bar\alpha_t} \mathbf{x}_0 + \sqrt{1-\bar\alpha_t} \boldsymbol{\epsilon}, t) \|^2$
\Until{converged}
\end{algorithmic}
\end{algorithm}
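
Under the same assumptions as the earlier sketches, one iteration of this algorithm can be written as a loss computation; `eps_theta(x_t, t)` is a placeholder for the noise-prediction network, and the gradient step on $\theta$ is left to whatever autodiff framework is used:

```python
def training_loss(eps_theta, x0_batch, rng=np.random.default_rng()):
    """Monte Carlo estimate of the simple training objective for one batch.

    Follows the training algorithm above: sample t uniformly, forward-diffuse
    x_0 to x_t in one step, and regress the network output onto the noise.
    """
    t = int(rng.integers(1, T + 1))              # t ~ Uniform({1, ..., T})
    x_t, eps = q_sample(x0_batch, t - 1, rng)    # x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps
    return np.mean((eps - eps_theta(x_t, t)) ** 2)
```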

Sampling

The sampling algorithm also uses the reparameterization trick to sample $\mathbf{x}_{t-1}$ given $\mathbf{x}_t$:

$$\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol\epsilon_\theta(\mathbf{x}_t, t)\right) + \sigma_t \mathbf{z},$$

where $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ if $t > 1$, and $\mathbf{z}$ is set to zero when $t = 1$ so that the final denoising step is deterministic.


\begin{algorithm}
\caption{Sampling}
\begin{algorithmic}
\State $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
\For{$t = T, \cdots, 1$}
    \State $\mathbf{z} \sim \begin{cases} 
    \mathcal{N}(\mathbf{0}, \mathbf{I}) & \text{if } t > 1 \\
    \mathbf{0} & \text{otherwise}
    \end{cases}$
    \State $\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}} \boldsymbol\epsilon_\theta(\mathbf{x}_t, t)\right) + \sigma_t \mathbf{z}$
\EndFor
\State \Return $\mathbf{x}_0$
\end{algorithmic}
\end{algorithm}
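
A corresponding NumPy sketch of the sampling loop, again with `eps_theta` as a placeholder for the trained network and $\sigma_t^2 = \tilde\beta_t$ as above:

```python
def sample(eps_theta, shape, rng=np.random.default_rng()):
    """Ancestral sampling: start from x_T ~ N(0, I) and denoise down to x_0."""
    x = rng.standard_normal(shape)                           # x_T ~ N(0, I)
    for t in range(T, 0, -1):                                # t = T, ..., 1
        a, a_bar = alphas[t - 1], alpha_bars[t - 1]
        a_bar_prev = alpha_bars[t - 2] if t > 1 else 1.0
        sigma = np.sqrt((1.0 - a_bar_prev) / (1.0 - a_bar) * betas[t - 1])  # sqrt(tilde{beta}_t)
        z = rng.standard_normal(shape) if t > 1 else np.zeros(shape)
        x = (x - (1.0 - a) / np.sqrt(1.0 - a_bar) * eps_theta(x, t)) / np.sqrt(a) + sigma * z
    return x
```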

Network Architecture

The network for estimating $\boldsymbol\epsilon_\theta(\mathbf{x}_t, t)$ is typically a U-Net architecture, where parameters are shared across time steps and the time step $t$ is taken as an additional input.
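
One common way to feed $t$ into such a shared network is a sinusoidal time-step embedding that is added to (or concatenated with) intermediate features; this particular choice is an assumption here, following common practice rather than anything stated above:

```python
def timestep_embedding(t, dim=128):
    """Sinusoidal embedding of the integer time step t (assumed conditioning scheme)."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)   # geometric frequency ladder
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])     # vector of length `dim`
```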

Extensions and Applications

Conditional Generation

Guidance Methods

Latent Diffusion Models