Diffusion Without Tears

Background context and a primer for our “History of Diffusion” interviews with Jascha Sohl-Dickstein, Yang Song and Sander Dieleman.

We’re excited to release our “History of Diffusion” interviews with Jascha Sohl-Dickstein, Yang Song, and Sander Dieleman. Those interviews have some surprising revelations about the research hunches, near misses, and other behind-the-scenes ups-and-downs that led to landmark diffusion papers.

The revelations are much more interesting with a little background knowledge about diffusion. When Yang Song says, for example, that the simplicity of the closed-form reverse-time SDE was a surprise to him and his coauthors, it helps to know what that equation is and why it was surprising.

That’s the purpose of this article: to give the background context that will help you get the most out of the interviews with Jascha, Yang Song, and Sander. The post has two parts: Part 1 is a self-contained explanation of diffusion models. It covers the SDE interpretation, and works out a 2D example “by hand” to give some geometric intuition. Part 2 is a brief history, focusing on Jascha’s original paper, DDPM, and SDEs.

To read this, it helps to remember a little vector calculus. Here’s an (open-book) quiz. Let $f(x,y) = (x^2 + y^2, x+y)$. What is $\partial_y f$? What is the work integral of $\partial_y f$ along the line segment from $(0,0)$ to $(1,1)$? Using Stokes’/Green’s Theorem, what is $\int_{\partial D} f$, where $D$ is the unit disk? If you can answer these questions, you’re good to go.

Part 1: Diffusion Explained

In Part 1, we’ll explain diffusion three times:

  • Once at a high level,
  • Once at a lower level, quickly
  • Once at that same lower level, slowly, with a worked example in 2D.

The advantage of doing an example in 2D is that it builds some geometric intuition for what’s going on. Throughout all of this, we’ll talk mainly about image generation models, since that’s most intuitive and that’s what many of the original papers were about.

We’re not going to cover diffusion guidance at all. This is a major omission, since classifier-free guidance is what most people have in mind when they think about Midjourney, Ideogram, or other image-generation tools (inputting “ultra realistic iPhone photo of a panda on a water slide” is the classifier-free text guidance you use with Midjourney, Ideogram, DALL·E, etc.) The reason I’m not covering guidance is that you don’t need to understand it to appreciate the papers we cover in the interviews. However, guidance is an interesting and important subject, and I highly recommend Sander’s excellent post as a starting point; it should be very accessible after reading this post.

Once at a High Level

At a high level, image gen diffusion models work like this:

  1. You get a big dataset of images.
  2. You corrupt those images. Specifically, you iteratively add noise to those images over time. Let’s call $x_0$ your original image and $x_t$ your image at time $t$. When $t$ is small, you’ve only added a little noise and $x_t$ looks like a fuzzy / static-y version of $x_0$. When $t$ is large, $x_t$ looks like “pure noise”.
  3. During training, you remove the noise. We train a neural network to take the input $(x_t,t)$ and produce the output $x_0$. The network is rewarded for successfully removing the noise from $x_t$ to reconstruct the original $x_0$. (This step is actually a little more complicated, and I’m fudging the explanation, as you’ll see when you keep reading; but this paragraph is still approximately correct.)
  4. During inference, you input pure noise into your trained model. Your model will think this is $(x_T,T)$ for some very large time $T$, and it’ll try to output the corresponding image $x_0$. And if you trained your model well, $x_0$ will look like the original images from your data set. Another way to think about this is that the model will “hallucinate, on purpose” a clean image $x_0$ when given pure noise $x_T$.

That’s it in a nutshell, and it’s useful to keep the high-level picture in mind. However, this explanation isn’t quite correct (some details have been over-simplified), and it leaves a lot unspecified. Let’s go deeper.

Once at a Lower Level, Briefly

There are several different ways to interpret diffusion models. I’m going to explain the stochastic differential equation interpretation, developed by Song et al. This method works by adding noise through a stochastic differential equation (SDE), and removing noise through a different SDE.

There are a few moving parts in this diffusion interpretation, and I think the best way to understand them is to see them twice:

  • Once briefly, so that you can see all the parts at once and keep the big picture in mind; and
  • Once slowly, so that you can build intuition for each part individually.

This section is the fast overview. I recommend reading it once, then reading the next section (fully-worked 2D example), and then coming back to this section again. For structure, we’ll mirror the high-level overview (the section above) by going deeper into data, adding noise, removing noise, and inference. We’ll do the same thing in the next section.

Our Data

To start, we imagine that our images $x_0$ are drawn from some unknown “data distribution” with density function $p_0$. We’ll never actually calculate these probabilities, but we’re imagining that $p_0$ assigns a probability density to every point in $\mathbb{R}^{512\times512\times3}$, or whatever the dimension of our image data is.

Adding Noise

The way we add noise to our images is through a stochastic differential equation:

$$dx = f(x,t)dt + g(t)dw$$

where $w$ is Brownian motion. If you’ve never seen an SDE before, you can think of this as roughly equivalent to a sequence $x_0, x_1, x_2, \ldots$ where:

$$x_{n+1} = x_n + f(x_n,t_n)\Delta_t + g(t_n)\Delta_w$$

where $\Delta_t = (t_{n+1}-t_n)$ and $\Delta_w \sim N(0,I\Delta_t)$ is a small amount of Gaussian random noise. If you let $\Delta_t \to 0$, you get the infinitesimal $dt$. Something analogous is true for Brownian motion, $\Delta_w \to dw$, but the details are more complicated.

You can roughly think of this whole scheme as the continuous version of adding Gaussian noise over time. We know what it means to go from $x_t$ to $x_t + \epsilon_{\Delta_t}$, where $\epsilon_{\Delta_t} \sim N(0,I\Delta_t)$. What happens when we let $\Delta_t \to 0$? You get an SDE, as above.

Here’s a picture of some sample paths in one dimension when $dx = xdt + t^2dw$:
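If you’d like to generate paths like these yourself, here’s a minimal Euler–Maruyama sketch; the step size, horizon, and starting point are arbitrary choices of mine:

import numpy as np
import matplotlib.pyplot as plt

dt = 0.001
ts = np.arange(0, 2, dt)

for _ in range(5):              # five sample paths
    x = 1.0                     # arbitrary starting point
    path = []
    for t in ts:
        dw = np.random.normal(0, np.sqrt(dt))   # Brownian increment
        x = x + x*dt + (t**2)*dw                # dx = x dt + t^2 dw
        path.append(x)
    plt.plot(ts, path)
plt.show()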

Removing Noise

The way we remove noise is through a reverse-time stochastic differential equation. If we run time forwards with $dx = f(x,t)dt + g(t)dw$, we can run it backwards with:

$$dx = [{\color{red}f(x,t)}-{\color{red}g(t)^2}{\color{green}\nabla_x}\log({\color{orange}p_t(x)})] dt + {\color{red}g(t)}{\color{purple}d\bar{w}}$$

There’s a lot going on in this equation, so I’ve highlighted a few parts with different colors so that we can examine them more closely now. We already know $f(x,t)$ and $g(t)$ from our forward-time equation; these are just functions that we assume as given. The function $p_t(x)$ is the deformation of our data distribution $p_0$ at time $t$, obtained by applying the forward-time SDE. You can obtain samples from $p_t$ by sampling $x_0$ from your data set and then using the forward-time SDE to get a sample $x_t$. The expression $d\bar{w}$ is the differential of reverse-time Brownian motion $\bar{w}$; you can interpret $\bar{w}$ as a Brownian bridge or, simply, view $d\bar{w}$ as a formal expression to be integrated.

The important thing is that $dw$ and $d\bar{w}$ are sort-of-identical, because on an infinitesimal scale, forward- and reverse-time Brownian motion have the same distribution. (Go watch some pollen particles get jostled around; you can’t tell which direction time is running.) Finally, $\nabla_x\log(p_t(x))$ is the score function: the gradient of the log data density. If $x = (x_1,\ldots, x_n)$, then:

$$\nabla_x\log(p_t(x)) = \left(\frac{\partial}{\partial x_1}\log(p_t(x)), \frac{\partial}{\partial x_2}\log(p_t(x)), \ldots, \frac{\partial}{\partial x_n}\log(p_t(x))\right)$$

Most of the pieces here seem easy: we already know $f(x,t)$ and $g(t)$, and we can simulate $d\bar{w}$. The score $\nabla_x\log(p_t(x))$ seems hard, because $p_t$ depends on the unknown data distribution. That’s the part we learn.

During training, we train a neural network with params $\theta$ as an approximator $s_\theta(x,t) \approx s(x,t)$ for the score function $s(x,t) = \nabla_x\log(p_t(x))$. Specifically, our ideal choice $\theta^*$ for $\theta$ is given by:

$$\theta^* := \argmin_\theta {\color{cyan}\mathbb{E}_t}\left[\lambda(t){\color{brown}\mathbb{E}_{x_0}}({\color{magenta}\mathbb{E}_{x_t|x_0}}(||s_\theta(x_t,t) - \nabla_x\log(p_t(x_t|x_0))||^2))\right].$$

This expression for $\theta^*$ looks more complicated than it is. We want to make $||s_\theta(x_t,t) - \nabla_x\log(p_t(x_t|x_0))||^2$ small. Given a starting point $x_0$, we want to do this across all $x_t | x_0$, weighted by the probability of going from $x_0$ to $x_t$. That’s the innermost ${\color{magenta}\mathbb{E}}$. And of course, we want to do this weighted by the probability of sampling that $x_0$ in the first place. That’s the middle ${\color{brown}\mathbb{E}}$. And we should do this across all $t$ (the outermost ${\color{cyan}\mathbb{E}}$). We want some weighting function $\lambda$ because our job is easy for small $t$ and hard for large $t$. The advice from Song et al. is to choose:

$$\lambda(t) \propto \frac{1}{\mathbb{E}\left(||\nabla_{x_t}\log (p_t(x_t|x_0))||^2\right)}$$
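To make this concrete: for Gaussian conditionals like the ones we’ll work out in the 2D example below, where $x_t|x_0 \sim N(\mu(x_0,t), \sigma(t)^2I)$ in $d$ dimensions, the denominator has a closed form, and the recommended weighting reduces to

$$\mathbb{E}\left(\left|\left|\frac{x_t-\mu(x_0,t)}{\sigma(t)^2}\right|\right|^2\right) = \frac{d}{\sigma(t)^2} \quad \Longrightarrow \quad \lambda(t) \propto \sigma(t)^2$$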

Inference

Once we’ve actually trained $s_\theta(x,t) \approx \nabla_x\log(p_t(x))$, we can generate samples by using the reverse-time equation. Here’s how it works. For reasonable choices of $f$ and $g$, $p_T$ is going to be Gaussian distributed for large enough $T$. In practice, the whole point of diffusion is to make $p_T$ as simple as possible, so usually we set things up so that $p_T = N(0,\sigma_T^2I)$ for some scalar $\sigma_T$.

To go from $x_T$ to $x_0$, you first generate $x_T \sim N(\mu_T, \Sigma_T)$, where $\Sigma_T$ is the covariance for $p_T$ and $\mu_T$ is the mean. Then you iteratively use the reverse-time equation $dx = [f(x,t)-g(t)^2s_\theta(x,t)] dt + g(t)d\bar{w}$ to go from $t=T$ to $t=0$. That’s it.

Here’s one way to think about this. For every time $t$, we have a vector field $F_t(x) = f(x,t)-g(t)^2s_\theta(x,t)$. We sample some $x_T \sim N(\mu_T,\Sigma_T)$, nudge it slightly with the vector field $F_T(x_T)dt$, and then add some random noise $g(T)d\bar{w}$. Now do the same thing with some slightly smaller $T' = T - \delta$, using the vector field $F_{T'}(x_{T'})$ and noise $g(T')d\bar{w}$. We do the same thing all the way from $t=T$ down to $t=0$. In theory, this should leave us with a sample $x_0$ from our data distribution $p_0$.

All of this will be more geometrically obvious in the next section.

Summing It Up

Now let’s summarize things, refining our high level description from the last section with some lower-level details from this section.

1. You get a big dataset of images.

You imagine that those images $x_0$ are drawn from some unknown data distribution $p_0$:

2. You corrupt those images. Specifically, you iteratively add noise to those images over time. Let’s call $x_0$ your original image and $x_t$ your image at time $t$. When $t$ is small, you’ve only added a little noise and $x_t$ looks like a fuzzy / static-y version of $x_0$. When $t$ is large, $x_t$ looks like “pure noise”.

The specific way that you add noise to your images is through the forward-time SDE:

$$dx = f(x,t)dt + g(t)dw$$

where ff and gg are functions that you choose.

3. During training, you remove the noise. We train a neural network to take the input $(x_t,t)$ and produce the output $x_0$. The network is rewarded for successfully removing the noise from $x_t$ to reconstruct the original $x_0$.

We’ll actually train the network to predict the score function $s(x,t) = \nabla_x\log(p_t(x))$. Once we have this score function, we can go from $x_t$ to $x_0$ as desired, using the reverse-time SDE:

$$dx = [f(x,t)-g(t)^2s_\theta(x,t)] dt + g(t)d\bar{w}$$

where $s_\theta$ is our learned approximator $s_\theta \approx s$.

4. During inference, you input pure noise into your trained model. Your model will think this is $(x_T,T)$ for some very large time $T$, and it’ll try to output the corresponding image $x_0$. And if you trained your model well, $x_0$ will look like the original images from your data set. Another way to think about this is that the model will “hallucinate, on purpose” a clean image $x_0$ when given pure noise $x_T$.

We sample $x_T \sim N(0,\sigma_T^2I)$ for some large $T$, and then use the reverse-time equation as above.

Here’s some stylized code, for those who prefer that:

import numpy as np
import torch.nn as nn
import torch.optim as optim

T = 10  # max time

## NN to learn the score function
class ScoreNet(nn.Module):
    pass

## Return something from the image data set
def sample_data():
    pass

# See page 5 of https://arxiv.org/pdf/2011.13456
# We can efficiently estimate the score function conditional
# on knowing x_0.
def numeric_conditional_score_estimate(x_0, x_t, t):
    pass

# Adding noise, forward-time: Euler-Maruyama steps of
# dx = f(x,t)dt + g(t)dw
def add_noise(x_0, t, f, g):
    x_last = x_0
    dt = 0.00001
    current_t = 0

    while current_t < t:
        dw = np.random.normal(0, np.sqrt(dt))
        drift = f(x_last, current_t)
        noise = g(current_t)
        dx = drift*dt + noise*dw

        x_last = x_last + dx
        current_t += dt

    return x_last

# Removing noise, reverse-time: Euler-Maruyama steps of
# dx = [f(x,t) - g(t)^2 s_theta(x,t)]dt + g(t)dwbar
def inference(x_T, T, f, g, score_function):
    x_last = x_T
    dt = 0.00001
    current_t = T

    while current_t > 0:
        dwbar = np.random.normal(0, np.sqrt(dt))
        drift = f(x_last, current_t) - (g(current_t)**2)*score_function(x_last, current_t)
        noise = g(current_t)

        # negative sign because dt is negative in the paper's conventions
        dx = -drift*dt + noise*dwbar

        x_last = x_last + dx
        current_t -= dt

    return x_last

# Training setup
model = ScoreNet()
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop (f and g are the functions we chose for the forward SDE)
def training_loop(score_net, f, g):
    x_0 = sample_data()
    t = np.random.uniform(0, T)
    x_t = add_noise(x_0, t, f, g)

    predicted_score = score_net(x_t, t)
    true_score = numeric_conditional_score_estimate(x_0, x_t, t)

    loss = mse_loss(predicted_score, true_score)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Once at a Lower Level, With an Example

Now let’s do things with a simple example in two dimensions. Again, we’ll break things down into data, adding noise, removing noise, and inference.

Our Data

As before, we’re going to let $x_0$ be a sample from some data distribution $p_0$. In this case, we’ll simply let $p_0$ be the uniform distribution on the unit circle. Here are 50 samples $x_0$:

Notice that the “data manifold” is 1-dimensional, a lower dimension than the 2D embedding space. This happens in higher dimensions too. Typically the data set of 512x512x3 (or whatever dimension) images will lie on a lower-dimensional manifold.
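If you’d like to follow along in code, here’s a minimal sketch of sampling this toy dataset (the function name is my own):

import numpy as np

def sample_circle(n):
    # uniform distribution on the unit circle: the 1D "data manifold"
    theta = np.random.uniform(0, 2*np.pi, size=n)
    return np.stack([np.cos(theta), np.sin(theta)], axis=1)  # shape (n, 2)

x_0 = sample_circle(50)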

Adding Noise

Now let’s add noise to our “images” using the SDE $dx = -xtdt + \sqrt{t}dw$. Here’s what those 50 samples look like at $t=0, 0.1, 1.0$.

And here’s an animation of the process:

We’re transforming not just these points $x_0$ but the entire data distribution $p_0$. $p_t$ is what you get when you apply the forward-time SDE to $p_0$ for $t$ units of time. Here’s a kernel density estimate of what $p_t$ looks like at each time step $t$.

If we used more data points for the KDE, we’d see that this distribution looks approximately normal by the time that $t=2$. $p_t$ starts out as the uniform distribution on the circle, but it “forgets” this information as time goes on, and ends up looking like an isotropic Gaussian. In some sense, this is the whole point of the diffusion process. We transition smoothly from a “complicated” distribution $p_0$ to a “simple” distribution $p_T$. We can then sample easily from $p_T$ and, during the reverse process, go backwards in time to get a sample from $p_0$.

(In the image above, the “toy” distribution $p_0$ is a mixture of two Gaussians. $p_T$ is a single Gaussian. The diffusion process takes us from a complex distribution $p_0$ to a simple one, $p_T$, which we can easily sample from. We can then use the reverse-time diffusion process to turn this $p_T$ sample into a sample from $p_0$.)

Removing Noise

So how exactly are we going to reverse the process? We need to know $s(x,t) = \nabla_x\log(p_t(x))$ in order to use the reverse-time SDE $dx = [f(x,t)-g(t)^2s(x,t)] dt + g(t)d\bar{w}$.

Our goal is to train a neural network $s_\theta(x,t)$ to approximate $s(x,t)$. As we said in the last section, the optimal choice of $\theta$ is given by:

$$\theta^* := \argmin_\theta \mathbb{E}_t\left[\lambda(t)\mathbb{E}_{x_0}(\mathbb{E}_{x_t|x_0}(||s_\theta(x_t,t) - \nabla_x\log(p_t(x_t|x_0))||^2))\right].$$

Let’s ignore $\lambda$ (set $\lambda(t)=1$ identically) for now. The key observation is that once $x_0$ is known/fixed, the conditional distribution for $x_t |x_0$ is just a Gaussian. I’ll spare the reader the calculations, but for our specific SDE $dx = -xtdt + \sqrt{t}dw$, we have that:

$$x_t | x_0 \sim N(\mu(x_0,t), \sigma(t)^2I)$$

where

$$\mu(x_0,t) = x_0e^{-\frac{1}{2}t^2}$$

and

$$\sigma(t)^2 = e^{-t^2}\int_0^t e^{s^2}s \, ds.$$

Knowing the equations for $x_t | x_0$, we can write down our conditional score function $\nabla_{x_t}\log(p_t(x_t|x_0))$ explicitly:

$$\nabla_{x_t}\log(p_t(x_t|x_0)) = -\frac{1}{\sigma(t)^2}(x_t-\mu(x_0,t))$$
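These formulas translate directly into code. (One bonus simplification: since $\frac{d}{ds}\frac{1}{2}e^{s^2} = se^{s^2}$, the integral above evaluates to $(e^{t^2}-1)/2$, so $\sigma(t)^2 = (1-e^{-t^2})/2$.) Here’s a sketch of sampling $x_t|x_0$ and computing the conditional score:

import numpy as np

def mu(x_0, t):
    return x_0 * np.exp(-0.5 * t**2)

def sigma2(t):
    # e^{-t^2} * int_0^t e^{s^2} s ds = (1 - e^{-t^2}) / 2
    return 0.5 * (1 - np.exp(-t**2))

def sample_x_t(x_0, t):
    # x_t | x_0 ~ N(mu(x_0, t), sigma(t)^2 I)
    return mu(x_0, t) + np.sqrt(sigma2(t)) * np.random.randn(*np.shape(x_0))

def conditional_score(x_0, x_t, t):
    # grad_{x_t} log p_t(x_t | x_0) for a Gaussian
    return -(x_t - mu(x_0, t)) / sigma2(t)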

You might be reading this and thinking “If we can just explicitly write down the true score function, why are we bothering to learn it with a neural network?” The answer is that the expression above is only the score function conditional on $x_0$, and we want to learn an unconditional score function. Imagine choosing a specific sample $x_t$. There are many possible $x_0$ that it could have come from. That makes calculating $p_t(x_t)$ hard, because we have to integrate over all these possible starting points, weighted by their probabilities. But calculating $p_t(x_t | x_0)$ is easy, because we know that $x_t$ evolves from $x_0$ through an explicitly-defined diffusion process. Here’s a stylized sketch of the situation:

In the sketch above: “the same” $x_t$, or two $x_t$s that are very close to each other, could have come from two very different $x_0$s. That’s why $\nabla_{x_t}\log(p_t(x_t|x_0))$ and $\nabla_{x_t}\log(p_t(x_t))$ can be very different. In general, the latter is difficult to evaluate because it involves integrating over the unknown/complicated distribution $p_0$. In our toy example, $p_0$ is the uniform distribution on a circle, which actually makes calculating $\nabla_{x_t}\log(p_t(x_t))$ tractable. But in general, this isn’t true, so we should try to avoid leaning on this crutch—though I will lean on it once more in a minute.

With this, we know what we’re going to do during our training loop. Sample $x_0$, i.e. choose a point from our data set. Then choose $t$ uniformly on $[0,T]$ (we’ll use $T=1$). Then sample $x_t$ from the Gaussian $p_t(x_t|x_0)$ above, and put $(x_t,t)$ into our neural net to get $s_\theta(x_t,t)$. Finally, backprop to make this closer to $\nabla_{x_t}\log(p_t(x_t|x_0))$.

Let’s go ahead and do that for our points on the circle. I used a simple MLP:

class ScoreNetwork(nn.Module):
    def __init__(self):
        super(ScoreNetwork, self).__init__()
        # Input is (x_t, t): a 2D point plus a scalar time, hence 3 features.
        # Output is the predicted 2D score at that point and time.
        self.net = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    def forward(self, x_t, t):
        t = t.unsqueeze(1)                 # (batch,) -> (batch, 1)
        inp = torch.cat([x_t, t], dim=1)   # concatenate point and time
        return self.net(inp)
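For completeness, here’s a sketch of the training loop I just described, give or take hyperparameters; it reuses the sample_circle, sample_x_t, and conditional_score helpers sketched earlier, which is an assumption of mine about how you’ve set things up:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

model = ScoreNetwork()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
mse = nn.MSELoss()

for epoch in range(1000):
    x_0 = sample_circle(128)                        # batch of data points
    t = np.random.uniform(0.01, 1, size=(128, 1))   # T = 1; avoid t ~ 0 where the score blows up
    x_t = sample_x_t(x_0, t)                        # forward-diffuse each point
    target = conditional_score(x_0, x_t, t)         # closed-form training target

    x_t_ = torch.tensor(x_t, dtype=torch.float32)
    t_ = torch.tensor(t[:, 0], dtype=torch.float32)
    target_ = torch.tensor(target, dtype=torch.float32)

    loss = mse(model(x_t_, t_), target_)            # lambda(t) = 1 here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()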

And here’s what the learned $s_\theta(x_t,t)$ looks like after 1000 epochs of batch size 128:

Notice that for small values of $t$, $s_\theta(x_t,t)$ has learned to move points towards the circle, which makes sense, because this is where all the data density is. But for larger values of $t$, $s_\theta$ is just mapping things towards the origin. This also makes sense, because $p_t$ becomes Gaussian as $t$ increases, and the true score function for the Gaussian distribution $N(0,\sigma^2I)$ is just $-x/\sigma^2$.
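(That last fact is a one-line computation: if $p(x) \propto e^{-||x||^2/2\sigma^2}$, then $\log p(x) = -\frac{||x||^2}{2\sigma^2} + \text{const}$, so $\nabla_x \log p(x) = -x/\sigma^2$.)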

Inference

Now let’s use this score function to simulate some reverse-time trajectories, using the reverse-time SDE:

$$dx = [f(x,t)-g(t)^2s_\theta(x,t)] dt + g(t)d\bar{w}$$

We’ll also illustrate the “drift vector field” $f(x,t)-g(t)^2s_\theta(x,t) = -t(x + s_\theta(x,t))$. Here’s what it looks like when we run time backwards, starting with samples from $p_T$ and using the reverse-time SDE to get samples from $p_0$. In the gif below, I’ve stopped time at $t=0.1$ rather than going all the way to $t=0$, so that you can see what the reverse-time SDE is doing right before all the points freeze.

(Confession: I cheated in making this, because instead of using my learned $s_\theta(x,t)$, I used the “true” score function $s(x,t)$. Normally, we wouldn’t be able to calculate $s(x,t)$, but, in this case, we can because $p_0$ is so simple. The reason I cheated is that in training runs, I struggled to get $s_\theta(x,t)$ to have large-enough magnitude for small $t$, though the trained $s_\theta(x,t)$ vector fields did point in the right directions. In the gif above, you can see that $f(x,t) - g(t)^2s_\theta(x,t)$ explodes as $t \to 0$. This is because $g(t)^2/\sigma(t)^2 \to \infty$. Moral of the story: choose better $f$ and $g$, or else choose a better $\lambda(t)$. Remember, I cavalierly set $\lambda = 1$.)
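For the curious, here’s roughly what I mean by the “true” score: for circle data, $p_t$ is a uniform mixture of the Gaussian conditionals over circle points, so you can approximate the score by discretizing the circle. A sketch, with my own discretization choices:

import numpy as np

def true_score(x, t, n_points=256):
    # p_t(x) is (approximately) a uniform mixture of N(x; mu(u,t), sigma(t)^2 I)
    # over points u on the unit circle.
    angles = np.linspace(0, 2*np.pi, n_points, endpoint=False)
    u = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (n_points, 2)
    centers = u * np.exp(-0.5 * t**2)
    s2 = 0.5 * (1 - np.exp(-t**2))

    diffs = x[None, :] - centers                    # (n_points, 2)
    log_w = -np.sum(diffs**2, axis=1) / (2 * s2)
    w = np.exp(log_w - log_w.max())                 # numerically stable mixture weights
    w = w / w.sum()

    # grad_x log p_t(x) is the weight-averaged conditional score
    return -(w[:, None] * diffs).sum(axis=0) / s2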

Part 2: History

In Part 1, we gave an overview of diffusion models. With that knowledge in hand, now let’s do a historical tour of three key papers:

2015 — Deep Unsupervised Learning using Nonequilibrium Thermodynamics

“Deep Unsupervised Learning using Nonequilibrium Thermodynamics” — by Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli — invented modern diffusion models. The paper did two key things. First, it showed that diffusion models work for image synthesis. In the figure below, (a) contains true samples $x_0$, (b) contains noised samples $x_t$, (c) contains conditional samples — think of these as the model’s predictions for $x_0$ given $x_t$ — and (d) contains “pure noise” samples — think of these as predictions for $x_0$ given “pure noise” $x_T$.

These CIFAR-10 images might not look impressive today, but remember that this paper came out in 2015. This was impressive image gen at the time.

The second thing this paper did was to set up the basic forward-time vs reverse-time framing, where the reverse-time process is actually tractable. The heritage for this idea comes from a very old stochastic processes paper by Feller:

Going a little deeper into the details, the approach in this 2015 paper was a bit different from the SDE strategy described above. The paper used a discrete-time Markov chain, where you go from time step $t-1$ to time step $t$ using the “forward diffusion kernel”

$$q(x_t | x_{t-1}) = N(x_t\space;\space x_{t-1}\sqrt{1-\beta_t} \space,\space I\beta_t)$$

In the limit of many small steps, this is equivalent to the forward-time SDE $dx = -\frac{1}{2}\beta(t)x\,dt + \sqrt{\beta(t)}\,dw$ (we’ll make this precise when we get to the SDE paper below). What’s the probability of following some specific path $x_0 \to x_1 \to \cdots \to x_T$? It’s

$$q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})$$

What’s the probability of going from $x_0$ to $x_2$ following any path? It’s

$$q(x_2 | x_0) = \int q(x_2 | x_1)q(x_1 | x_0)dx_1$$

And similarly,

$$q(x_T|x_0) = \int\int\int\cdots\int q(x_T|x_{T-1})q(x_{T-1}|x_{T-2})\cdots q(x_1|x_0)dx_1 \cdots dx_{T-1}$$

Obviously, this expression is intractable for all but the simplest possible transition kernels, which is why we use one of the simplest possible transition kernels (independent Gaussians).
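The Gaussian choice is what makes this tractable: composing Gaussian kernels gives another Gaussian, in closed form. Writing $\alpha_t = 1-\beta_t$, two steps compose as

$$q(x_2 | x_0) = N\left(x_2 \space;\space x_0\sqrt{\alpha_1\alpha_2} \space,\space I(1-\alpha_1\alpha_2)\right)$$

and in general $q(x_t|x_0) = N(x_t\space;\space x_0\sqrt{\bar{\alpha}_t}\space,\space I(1-\bar{\alpha}_t))$ with $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$. (This is exactly the “straight shot” formula that DDPM leans on below.)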

That’s the forward-time process. For reverse-time, instead of using a reverse-time SDE, the paper used a “reverse diffusion kernel”

$$p(x_{t-1}|x_t) = N(x_{t-1} \space ; \space f_\mu(x_t,t), f_\Sigma(x_t,t))$$

where you train a neural network to learn the mean $f_\mu(x_t,t)$ and covariance $f_\Sigma(x_t,t)$ functions. This paper made the important observation that if $q(x_t|x_{t-1})$ is Gaussian, then (for small enough noise steps) so is $p(x_{t-1}|x_t)$. That’s why we don’t have to learn some super-complicated distribution for $p(x_{t-1}|x_t)$. We just have to learn a mean vector $f_\mu$ and a covariance matrix $f_\Sigma$. This “Gaussian in both directions” observation is sort of equivalent to noticing that our forward- and reverse-time SDEs have the same functional form.

The training objective is to maximize the model log likelihood:

$$L := \mathbb{E}_q(\log(p(x_0))) =\int q(x_0)\log(p(x_0))dx_0$$

where $q$ is the true data distribution, and $p$ is the model’s fitted data distribution. Some math tells us that $L \geq K$ where

$$K =- \sum_{t=2}^T \int D_{KL}(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))\cdot q(x_0, x_t)dx_0dx_t + H_q(X_T|X_0) - H_q(X_1|X_0) - H_p(X_T)$$

(Here we’re capitalizing the $X$s to emphasize that the entropies treat them as random variables, not deterministic input vectors as in the integral above.)

The point is that $q(x_{t-1}|x_t,x_0)$ and $p(x_{t-1}|x_t)$ are both Gaussian, which means that there’s a closed-form analytic expression for their KL divergence. So this whole expression can be easily calculated.

Notice that there’s no score matching in this paper. We’re trying to maximize the log likelihood directly.

2020 — DDPM

GANs dominated image generation from 2015 to 2020. But in 2020, Jonathan Ho, Ajay Jain, and Pieter Abbeel published their “Denoising Diffusion Probabilistic Models”. Historically, the main thing that this paper accomplished was showing that diffusion models could generate high-quality samples:

The approach is very similar to Jascha’s 2015 paper. You have a forward-time kernel

$$q(x_t | x_{t-1}) = N(x_t\space;\space x_{t-1}\sqrt{1-\beta_t} \space,\space I\beta_t)$$

and a reverse-time kernel

$$p(x_{t-1}|x_t) = N(x_{t-1} \space ; \space \mu_\theta(x_t,t), \Sigma_\theta(x_t,t))$$

with learned mean $\mu_\theta$ and covariance $\Sigma_\theta$ functions. The training objective is a bound on the model log-likelihood:

$$\mathbb{E}_q(-\log(p_\theta(x_0))) \leq \mathbb{E}_q \left[ D_{KL}(q(x_T|x_0) \| p(x_T)) + \sum_{t>1} D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t)) - \log p_\theta(x_0|x_1) \right]$$

So what’s different from the 2015 paper? DDPM introduced a simplified training objective which worked well empirically:

$$L_{\text{simple}}(\theta) := \mathbb{E}_{t,x_0,\epsilon}\left[||\epsilon - \epsilon_\theta(x_t,t)||^2\right]$$

where $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$ for $\epsilon \sim N(0,I)$ and some params $\bar{\alpha}_t$ that are deterministic in $t$. (The exact definition isn’t important; the point is that you get from $x_0$ to $x_t$ by summing independent Gaussians, so the “straight shot” from $x_0$ to $x_t$ is also Gaussian. People call this the “reparameterization trick”, but it’s really just noticing that the sum of independent Gaussians is Gaussian.) The actual training process works by sampling $t \in [0,T]$ uniformly, sampling $\epsilon \sim N(0,I)$, sampling a data point $x_0$, and then plugging these into the equation $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$ to get $x_t$. We then backprop to make the predicted error $\epsilon_\theta(x_t,t)$ closer to the true error $\epsilon$.
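In code, a DDPM training step is a sketch like this (the schedule and tensor shapes are my own assumptions, loosely following the paper’s reported values):

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)

def ddpm_training_step(model, optimizer, x_0):
    t = torch.randint(0, T, (x_0.shape[0],))              # a timestep per example
    eps = torch.randn_like(x_0)                           # the true noise
    ab = alpha_bar[t].view(-1, *([1] * (x_0.dim() - 1)))  # broadcast over image dims
    x_t = torch.sqrt(ab) * x_0 + torch.sqrt(1 - ab) * eps # "straight shot" to x_t

    loss = ((eps - model(x_t, t)) ** 2).mean()            # L_simple
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss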

The notation for the loss function is a little confusing. Reading it, you might wonder: “Why not just set $\epsilon_\theta = 0$ identically, since we can’t do better than that to minimize $||\epsilon - \epsilon_\theta(x_t,t)||^2$ if $\epsilon \sim N(0,I)$?” But keep in mind this picture from earlier, which has a DDPM equivalent:

There are lots of different $x_0$s and $\epsilon$s that could have produced $x_t$ through the equation $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$. Because $\epsilon$ is truly just noise, the job of $\epsilon_\theta$ is implicitly to predict the probability-weighted mean across all the $x_0$s that could have produced $x_t$. Or rather, the job of $\epsilon_\theta$ is to predict $\epsilon=(x_t - \sqrt{\bar{\alpha}_t}x_0)/\sqrt{1-\bar{\alpha}_t}$, which amounts to the same thing if $x_t$ and $t$ are given.

Notice the similarity between the DDPM loss and the Score-Based Generative Modeling loss we explained earlier:

$$\text{DDPM:} \quad \quad \mathbb{E}_{t,x_0,\epsilon}\left[||\epsilon - \epsilon_\theta(x_t,t)||^2\right]$$

$$\text{SBGM:} \quad \quad \mathbb{E}_t\left[\lambda(t)\mathbb{E}_{x_0}(\mathbb{E}_{x_t|x_0}(||s_\theta(x_t,t) - \nabla_x\log(p_t(x_t|x_0))||^2))\right]$$

You can think of the DDPM loss as a discretized version of score matching. As we’ll see in a minute, SBGM recovers DDPM as a special case.

2020 — Score-Based Generative Modeling through SDEs

Later in 2020, Yang Song, Jascha Sohl-Dickstein (author of the 2015 paper above), Durk Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole published “Score-Based Generative Modeling through Stochastic Differential Equations”. This was a unification paper. It showed that Song & Ermon’s earlier score matching paper and DDPM could be interpreted as variations of the same idea. This paper should look very familiar, since Part 1 of this blog post was all about SBGM.

This paper introduced the forward-time $dx = f(x,t)dt + g(t)dw$ and reverse-time $dx = (f(x,t) -g(t)^2s(x,t))dt + g(t)d\bar{w}$ diffusion equations, and the SDE score matching objective

$$\mathbb{E}_t\left[\lambda(t)\mathbb{E}_{x_0}(\mathbb{E}_{x_t|x_0}(||s_\theta(x_t,t) - \nabla_x\log(p_t(x_t|x_0))||^2))\right]$$

This paper realizes DDPM as a special case, with values of $f$ and $g$ set to

$$dx = -\frac{1}{2}\beta(t)x \space dt + \sqrt{\beta(t)} \space dw$$

where $\beta(t)$ ranges linearly from $\beta(1) = 0.0001$ to $\beta(1000) = 0.02$.
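You can sanity-check the correspondence by taking a single Euler step of this SDE over a small interval $\Delta t$ and using $\sqrt{1-\beta\Delta t} \approx 1 - \frac{1}{2}\beta\Delta t$:

$$x_{t+\Delta t} \approx x_t\left(1 - \frac{1}{2}\beta(t)\Delta t\right) + \sqrt{\beta(t)\Delta t}\,\epsilon \approx x_t\sqrt{1-\beta(t)\Delta t} + \sqrt{\beta(t)\Delta t}\,\epsilon$$

which is exactly DDPM’s forward kernel $q(x_t|x_{t-1})$ with $\beta_t = \beta(t)\Delta t$.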

We didn’t talk about this earlier, but this paper also describes an ordinary (non-stochastic) differential equation approach to diffusion. The equation for the ODE is

$$dx =\left[f(x,t) - \frac{1}{2}g(t)^2\nabla_x\log(p_t(x))\right]dt$$

and the point is that if we take $p_0$ and evolve it forward in time, using this equation, to get $p_t$, that’s the same $p_t$ that we would have gotten had we used the SDE instead of the ODE. This might be surprising (it was to me), but it’s actually a fairly general result about SDEs. The Fokker-Planck equation tells us that if $dx = f(x,t)dt + g(x,t)dw$, then

$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot [f(x,t)p_t(x)] + \nabla \cdot[D(x,t)\nabla_x p_t(x)]$$

where $D$ is the “diffusion tensor” $D(x,t) = \frac{1}{2}g(x,t)g(x,t)^T$ and $\nabla \cdot h(x) = \sum_i\frac{\partial h_i}{\partial x_i}$ is the divergence. More explicitly,

$$\frac{\partial p(x,t)}{\partial t} = -\sum_{i=1}^N \frac{\partial}{\partial x_i}[f_i(x,t)p(x,t)] + \sum_{i=1}^N\sum_{j=1}^N \frac{\partial^2}{\partial x_i \partial x_j}[D_{ij}(x,t)p(x,t)]$$

Some calculation shows us that if we write

$$\tilde{f}(x,t) = f(x,t) - \frac{1}{2}\nabla \cdot[g(x,t)g(x,t)^T] - \frac{1}{2}g(x,t)g(x,t)^T\nabla_x\log(p_t(x))$$

then

$$\frac{\partial p_t(x)}{\partial t} = -\sum_{i=1}^N \frac{\partial}{\partial x_i}[\tilde{f}_i(x,t)p_t(x)]$$

Notice that there is no diffusion term. In fact, if we apply Fokker-Planck to the ODE

$$dx = \tilde{f}(x,t)dt$$

we get the same expression. The upshot is that the $p_t(x)$ that we get from the ODE $dx = \tilde{f}(x,t)dt$ is the same $p_t(x)$ that we get from the SDE $dx = f(x,t)dt + g(x,t)dw$. So we can recover $p_t$ without resorting to stochastics.
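(The “some calculation” above is mostly the product rule plus the identity $\nabla_x p_t = p_t\nabla_x\log(p_t)$: the second-derivative term unpacks as

$$\sum_{i,j}\frac{\partial^2}{\partial x_i \partial x_j}[D_{ij}(x,t)p_t(x)] = \nabla\cdot\Big[\big((\nabla\cdot D) + D\nabla_x\log(p_t(x))\big)p_t(x)\Big]$$

where $\nabla\cdot D$ is the row-wise divergence of the matrix $D$. Folding this into the drift term, with $D = \frac{1}{2}gg^T$, gives exactly the $\tilde{f}$ above.)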

Notice that if $g(x,t) = g(t)$ is scalar and doesn’t depend on $x$, then the equation for $\tilde{f}$ simplifies to

$$\tilde{f}(x,t) = f(x,t) - \frac{1}{2}g(t)^2\nabla_x\log(p_t(x))$$

This gives us the ODE that we mentioned earlier,

$$dx =\left[f(x,t) - \frac{1}{2}g(t)^2\nabla_x\log(p_t(x))\right]dt$$

Knowing that there’s a deterministic way to go from pTp_T to p0p_0 opens up few- and even single-step reverse diffusion processes, and points the way towards consistency models — a story for another time.

Further Reading

Sander Dieleman’s blog is full of high-quality, deep explanations of diffusion concepts. It’s the single best resource for building intuition about diffusion.

Yang Song has a great post explaining score matching, Langevin sampling, and diffusion SDEs/ODEs. Highly recommended.

The best textbooks on stochastic calculus are J Michael Steele’s Stochastic Calculus and Financial Applications (theory) and Särkkä and Solin’s Applied Stochastic Differential Equations (worked examples). I also have a personal-site post about stochastic integration.

If you’d like to see the math for DDPM worked out explicitly, I recommend Lilian Weng’s post on diffusion. Lilian has great stuff more broadly.

For a cookbook on practical methods to get diffusion models to actually work, look no further than “Elucidating the Design Space of Diffusion-Based Generative Models” or, as Sander Dieleman calls it, “The Diffusion Bible”.

Tiny Diffusion lets you play around with 2D diffusion, following DDPM.