Diffusion Without Tears
Background context and a primer for our “History of Diffusion” interviews with Jascha Sohl-Dickstein, Yang Song and Sander Dieleman.
We’re excited to release our “History of Diffusion” interviews with Jascha Sohl-Dickstein, Yang Song, and Sander Dieleman. Those interviews have some surprising revelations about the research hunches, near misses, and other behind-the-scenes ups-and-downs that led to landmark diffusion papers.
The revelations are much more interesting with a little background knowledge about diffusion. When Yang Song says, for example, that the simplicity of the closed-form reverse-time SDE was a surprise to him and his coauthors, it helps to know what that equation is and why it was surprising.
That’s the purpose of this article: to give the background context that will help you get the most out of the interviews with Jascha, Yang Song, and Sander. The post has two parts: Part 1 is a self-contained explanation of diffusion models. It covers the SDE interpretation, and works out a 2D example “by hand” to give some geometric intuition. Part 2 is a brief history, focusing on Jascha’s original paper, DDPM, and SDEs.
To read this, it helps to remember a little vector calculus. Here’s an (open-book) quiz. Let . What is What is the work integral of along the line segment from to Using Stokes’/Green’s Theorem, what is , where is the unit disk? If you can answer these questions, you’re good to go.
Part 1: Diffusion Explained
In Part 1, we’ll explain diffusion three times:
- Once at a high level,
- Once at a lower level, quickly, and
- Once at that same lower level, slowly, with a worked example in 2D.
The advantage of doing an example in 2D is that it builds some geometric intuition for what’s going on. Throughout all of this, we’ll talk mainly about image generation models, since that’s most intuitive and that’s what many of the original papers were about.
We’re not going to cover diffusion guidance at all. This is a major omission, since classifier-free guidance is what most people have in mind when they think about Midjourney, Ideogram, or other image-generation tools (inputting “ultra realistic iPhone photo of a panda on a water slide” is the classifier-free text guidance you use with Midjourney, Ideogram, DALL-E, etc.). The reason I’m not covering guidance is that you don’t need to understand it to appreciate the papers we cover in the interviews. However, guidance is an interesting and important subject, and I highly recommend Sander’s excellent post as a starting point; it should be very accessible after reading this post.
Once at a High Level
At a high level, image gen diffusion models work like this:
- You get a big dataset of images.
- You corrupt those images. Specifically, you iteratively add noise to those images over time. Let’s call your original image $x_0$ and your image at time $t$, $x_t$. When $t$ is small, you’ve only added a little noise and $x_t$ looks like a fuzzy / static-y version of $x_0$. When $t$ is large, $x_t$ looks like “pure noise”.
- During training, you remove the noise. We train a neural network to take the input $x_t$ and produce the output $x_0$. The network is rewarded for successfully removing the noise from $x_t$ to reconstruct the original $x_0$. (This step is actually a little more complicated, and I’m fudging the explanation, as you’ll see when you keep reading; but this paragraph is still approximately correct.)
- During inference, you input pure noise into your trained model. Your model will think this is $x_T$ for some very large time $T$, and it’ll try to output the corresponding image $x_0$. And if you trained your model well, $x_0$ will look like the original images from your data set. Another way to think about this is that the model will “hallucinate, on purpose” a clean image $x_0$ when given pure noise $x_T$.
That’s it in a nutshell, and it’s useful to keep the high-level picture in mind. However, this explanation isn’t quite correct (some details have been over-simplified), and it leaves a lot unspecified. Let’s go deeper.
Once at a Lower Level, Briefly
There are several different ways to interpret diffusion models. I’m going to explain the stochastic differential equation interpretation, developed by Song et al. This method works by adding noise through a stochastic differential equation (SDE), and removing noise through a different SDE.
There are a few moving parts in this diffusion interpretation, and I think the best way to understand them is to see them twice:
- Once briefly, so that you can see all the parts at once and keep the big picture in mind; and
- Once slowly, so that you can build intuition for each part individually.
This section is the fast overview. I recommend reading it once, then reading the next section (fully-worked 2D example), and then coming back to this section again. For structure, we’ll mirror the high-level overview (the section above) by going deeper into data, adding noise, removing noise, and inference. We’ll do the same thing in the next section.
Our Data
To start, we imagine that our images are drawn from some unknown “data distribution” with density function $p_0$. We’ll never actually calculate these probabilities, but we’re imagining that $p_0$ assigns a probability density to every point in $\mathbb{R}^{512 \times 512 \times 3}$, or whatever the dimension of our image data is.
Adding Noise
The way we add noise to our images is through a stochastic differential equation:
$$dx = f(x, t)\,dt + g(t)\,dw,$$
where $w$ is Brownian motion. If you’ve never seen an SDE before, you can think of this as roughly equivalent to a sequence where:
$$x_{t+\Delta t} = x_t + f(x_t, t)\,\Delta t + g(t)\,\Delta w,$$
and $\Delta w$ is a small amount of Gaussian random noise. If you let $\Delta t \to 0$, you get the infinitesimal $dt$. Something analogous is true for Brownian motion $dw$, but the details are more complicated.
You can roughly think of this whole scheme as the continuous version of adding Gaussian noise over time. We know what it means to go from $x_t$ to $x_{t+\Delta t}$, where $x_{t+\Delta t}$ is $x_t$ plus a little Gaussian noise. What happens when we let $\Delta t \to 0$? You get an SDE, as above.
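To make the discrete picture concrete, here is a minimal NumPy sketch of that update rule (usually called the Euler-Maruyama scheme). The function name `euler_maruyama`, the step count, and the example choice of zero drift with unit noise are illustrative assumptions, not something taken from the paper.

```python
import numpy as np

def euler_maruyama(x0, f, g, t_end, n_steps, rng):
    """Simulate dx = f(x, t) dt + g(t) dw from time 0 to t_end.

    Returns an array of shape (n_steps + 1, *x0.shape) holding the whole path.
    """
    dt = t_end / n_steps
    x = np.array(x0, dtype=float)
    path = [x.copy()]
    for k in range(n_steps):
        t = k * dt
        # Brownian increment over a step of length dt: N(0, dt) per coordinate
        dw = rng.normal(0.0, np.sqrt(dt), size=x.shape)
        x = x + f(x, t) * dt + g(t) * dw
        path.append(x.copy())
    return np.stack(path)

# Illustrative choice (an assumption): no drift and unit noise, i.e. plain Brownian motion.
rng = np.random.default_rng(0)
paths = euler_maruyama(np.zeros(2000), lambda x, t: 0.0 * x, lambda t: 1.0,
                       t_end=1.0, n_steps=100, rng=rng)
```

With these choices each of the 2000 columns of `paths` is one Brownian sample path, so the spread of the endpoints should be close to the elapsed time.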
Here’s a picture of some sample paths in one dimension when :
Removing Noise
The way we remove noise is through a reverse-time stochastic differential equation. If we run time forwards with $dx = f(x, t)\,dt + g(t)\,dw$, we can run it backwards with:
$$dx = \left[f(x, t) - g(t)^2\,\nabla_x \log p_t(x)\right]dt + g(t)\,d\bar{w}$$
There’s a lot going on in this equation, so let’s examine its parts more closely. We already know $f$ and $g$ from our forward-time equation; these are just functions that we assume as given. The function $p_t$ is the deformation of our data distribution $p_0$ at time $t$, obtained by applying the forward-time SDE. You can obtain samples from $p_t$ by sampling $x_0 \sim p_0$ from your data set and then using the forward-time SDE to get a sample $x_t$. The expression $d\bar{w}$ is the differential of reverse-time Brownian motion $\bar{w}$; you can interpret $\bar{w}$ as a Brownian bridge or, simply, view $d\bar{w}$ as a formal expression to be integrated.
The important thing is that $dw$ and $d\bar{w}$ are sort-of-identical, because on an infinitesimal scale, forward- and reverse-time Brownian motion have the same distribution. (Go watch some pollen particles get jostled around; you can’t tell which direction time is running.) Finally, $\nabla_x \log p_t(x)$ is the data gradient, also called the score function. If $x = (x_1, \dots, x_n)$, then:
$$\nabla_x \log p_t(x) = \left(\frac{\partial}{\partial x_1} \log p_t(x),\ \dots,\ \frac{\partial}{\partial x_n} \log p_t(x)\right)$$
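During training, we train a neural network $s_\theta(x, t)$ with params $\theta$ as an approximator for the score function. Specifically, our ideal choice for $\theta$ is given by:
$$\theta^* = \arg\min_\theta\ \mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{x_0}\,\mathbb{E}_{x_t \mid x_0}\Big[\big\|s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0)\big\|_2^2\Big]\Big\}$$
This expression for $\theta^*$ looks more complicated than it is. We want to make $\|s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0)\|_2^2$ small. Given a starting point $x_0$, we want to do this across all $x_t$, weighted by the probability $p_{0t}(x_t \mid x_0)$ of going from $x_0$ to $x_t$. That’s the innermost $\mathbb{E}_{x_t \mid x_0}$. And of course, we want to do this weighted by the probability of the first $x_0$ in the first place. That’s the middle $\mathbb{E}_{x_0}$. And we should do this across all $t$ (the outermost $\mathbb{E}_t$). We want some weighting function $\lambda(t)$ because our job is easy for small $t$ and hard for large $t$. The advice from Song et al is to choose:
$$\lambda(t) \propto 1 \Big/ \mathbb{E}\Big[\big\|\nabla_{x_t} \log p_{0t}(x_t \mid x_0)\big\|_2^2\Big]$$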
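As a quick worked case of that weighting rule, suppose (an assumption for illustration) that the transition density from $x_0$ to $x_t$ is an isotropic Gaussian with mean $\mu_t$ and variance $\sigma_t^2$ in $d$ dimensions. Then the recommended weight works out to be proportional to the variance:

```latex
% Assumption: p_{0t}(x_t \mid x_0) = \mathcal{N}(x_t;\ \mu_t,\ \sigma_t^2 I) in d dimensions.
\begin{aligned}
\nabla_{x_t} \log p_{0t}(x_t \mid x_0) &= -\frac{x_t - \mu_t}{\sigma_t^2}, \\
\mathbb{E}\left[\left\| \nabla_{x_t} \log p_{0t}(x_t \mid x_0) \right\|_2^2\right]
  &= \frac{\mathbb{E}\left[\| x_t - \mu_t \|_2^2\right]}{\sigma_t^4}
   = \frac{d\,\sigma_t^2}{\sigma_t^4}
   = \frac{d}{\sigma_t^2}, \\
\lambda(t) &\propto \sigma_t^2 .
\end{aligned}
```

So under this assumption the weight is small when little noise has been added (where the conditional score is large) and grows as the noise level grows.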
Inference
Once we’ve actually trained $s_\theta$, we can generate samples by using the reverse-time equation. Here’s how it works. For reasonable choices of $f$ and $g$, $p_T$ is going to be Gaussian distributed for large enough $T$. In practice, the whole point of diffusion is to make $p_T$ as simple as possible, so usually we set things up so that $p_T \approx \mathcal{N}(0, \sigma^2 I)$ for some scalar $\sigma$.
To go from $x_T$ to $x_0$, you first generate $x_T \sim \mathcal{N}(\mu, \Sigma)$, where $\Sigma$ is the covariance for $p_T$ and $\mu$ is the mean. Then you iteratively use the reverse time equation to go from $x_T$ to $x_0$. That’s it.
Here’s one way to think about this. For every time $t$, we have a vector field $f(x, t) - g(t)^2 s_\theta(x, t)$. We sample some $x_T$, nudge it slightly with the vector field at time $T$, and then add some random noise $g(T)\,d\bar{w}$. Now do the same thing with some slightly smaller $t$, using the vector field and noise at that time. We do the same thing all the way from $t = T$ down to $t = 0$. In theory, this should leave us with a sample from our data distribution $p_0$.
All of this will be more geometrically obvious in the next section.
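Here is a small sketch of that sampling loop in a case where everything can be checked by hand. The setup is an assumption chosen for illustration, not the article's example: one-dimensional data $x_0 \sim \mathcal{N}(m, s^2)$ and the forward SDE $dx = -\tfrac{1}{2}x\,dt + dw$. For this choice $p_t$ is Gaussian at every $t$, so the exact score stands in for a trained network.

```python
import numpy as np

# Assumed toy setup (not from the article): 1D data x_0 ~ N(m, s^2) and the
# forward SDE dx = -0.5 x dt + dw, i.e. f(x, t) = -0.5 x and g(t) = 1.
m, s = 2.0, 0.5

def f(x, t):
    return -0.5 * x

def g(t):
    return 1.0

def score(x, t):
    """Exact score of p_t, which is Gaussian for this toy setup."""
    mean_t = m * np.exp(-0.5 * t)
    var_t = s**2 * np.exp(-t) + (1.0 - np.exp(-t))
    return -(x - mean_t) / var_t

def reverse_sde_sample(n_samples, t_max, n_steps, rng):
    """Euler-Maruyama steps of the reverse-time SDE, from t = t_max down to t = 0."""
    dt = t_max / n_steps
    x = rng.normal(0.0, 1.0, size=n_samples)  # x_T ~ N(0, 1), approximately p_T
    for k in range(n_steps):
        t = t_max - k * dt
        drift = f(x, t) - g(t) ** 2 * score(x, t)
        noise = g(t) * rng.normal(0.0, np.sqrt(dt), size=n_samples)
        x = x - drift * dt + noise  # minus sign: we are stepping backwards in time
    return x

rng = np.random.default_rng(0)
samples = reverse_sde_sample(n_samples=5000, t_max=5.0, n_steps=500, rng=rng)
```

If the loop is doing its job, the mean and standard deviation of `samples` should land close to `m` and `s`, even though the starting points were drawn from a standard normal.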
Summing It Up
Now let’s summarize things, refining our high level description from the last section with some lower-level details from this section.
1. You get a big dataset of images.
You imagine that those images are drawn from some unknown data distribution $p_0$:
$$x_0 \sim p_0$$
2. You corrupt those images. Specifically, you iteratively add noise to those images over time. Let’s call your original image $x_0$ and your image at time $t$, $x_t$. When $t$ is small, you’ve only added a little noise and $x_t$ looks like a fuzzy / static-y version of $x_0$. When $t$ is large, $x_t$ looks like “pure noise”.
The specific way that you add noise to your images is through the forward-time SDE:
$$dx = f(x, t)\,dt + g(t)\,dw$$
where $f$ and $g$ are functions that you choose.
3. During training, you remove the noise. We train a neural network to take the input $x_t$ and produce the output $x_0$. The network is rewarded for successfully removing the noise from $x_t$ to reconstruct the original $x_0$.
We’ll actually train the network to predict the score function $\nabla_x \log p_t(x)$. Once we have this score function, we can go from $x_t$ to $x_0$ as desired, using the reverse-time SDE:
$$dx = \left[f(x, t) - g(t)^2\,s_\theta(x, t)\right]dt + g(t)\,d\bar{w}$$
where $s_\theta(x, t)$ is our learned approximator, $s_\theta(x, t) \approx \nabla_x \log p_t(x)$.
4. During inference, you input pure noise into your trained model. Your model will think this is $x_T$ for some very large time $T$, and it’ll try to output the corresponding image $x_0$. And if you trained your model well, $x_0$ will look like the original images from your data set. Another way to think about this is that the model will “hallucinate, on purpose” a clean image $x_0$ when given pure noise $x_T$.
We sample $x_T \sim p_T$ for some large $T$, and then use the reverse-time equation as above.
Here’s some stylized code, for those who prefer that:
import numpy as np
import torch.nn as nn
import torch.optim as optim

# Stylized: the stubs below are placeholders, and array/tensor
# conversions between NumPy and PyTorch are omitted.
T = 10  # max time
mse_loss = nn.MSELoss()

## NN to learn the score function
class ScoreNet(nn.Module):
    pass

## Return something from the image data set
def sample_data():
    pass

# See page 5 of https://arxiv.org/pdf/2011.13456
# We can efficiently estimate the score function conditional
# on knowing x_0.
def numeric_conditional_score_estimate(x_0, x_t, t):
    pass

# Adding Noise, forward-time
def add_noise(x_0, t, f, g):
    x_last = x_0
    dt = 0.00001
    current_t = 0
    while current_t < t:
        # Brownian increment: one N(0, dt) draw per coordinate
        dw = np.random.normal(0, np.sqrt(dt), size=np.shape(x_last))
        drift = f(x_last, current_t)
        noise = g(x_last, current_t)
        dx = drift*dt + noise*dw
        x_last = x_last + dx
        current_t += dt
    return x_last

# Removing Noise, reverse-time
def inference(x_T, T, f, g, score_function):
    x_last = x_T
    dt = 0.00001
    current_t = T
    while current_t > 0:
        dwbar = np.random.normal(0, np.sqrt(dt), size=np.shape(x_last))
        # Evaluate the score at the current point and time, not at (x_T, T)
        drift = f(x_last, current_t) - (g(x_last, current_t)**2)*score_function(x_last, current_t)
        noise = g(x_last, current_t)
        # negative sign because dt is negative in the paper's conventions
        dx = -drift*dt + noise*dwbar
        x_last = x_last + dx
        current_t -= dt
    return x_last

# Training Loop (one step)
# optimizer is e.g. optim.Adam(score_net.parameters(), lr=0.01)
def training_loop(score_net, optimizer, f, g):
    x_0 = sample_data()
    t = np.random.uniform(0, T)
    x_t = add_noise(x_0, t, f, g)
    predicted_score = score_net(x_t, t)
    true_score = numeric_conditional_score_estimate(x_0, x_t, t)
    loss = mse_loss(predicted_score, true_score)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Once at a Lower Level, With an Example
Now let’s do things with a simple example in two dimensions. Again, we’ll break things down into data, adding noise, removing noise, and inference.
Our Data
As before, we’re going to let $x_0$ be a sample from some data distribution $p_0$. In this case, we’ll simply let $p_0$ be the uniform distribution on the unit circle. Here are 50 samples $x_0 \sim p_0$:
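Drawing such samples takes only a few lines. This is a minimal sketch; the helper name `sample_unit_circle` and the seed are illustrative choices.

```python
import numpy as np

def sample_unit_circle(n, rng):
    """Draw n points uniformly from the unit circle in R^2."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=n)
    return np.stack([np.cos(theta), np.sin(theta)], axis=1)

rng = np.random.default_rng(0)
x0 = sample_unit_circle(50, rng)  # array of shape (50, 2), one row per "image"
```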
Notice that the “data manifold” is 1-dimensional, a lower dimension than the 2D embedding space. This happens in higher dimensions too. Typically the data set of 512x512x3 (or whatever dimension) images will lie on a lower dimensional manifold.
Adding Noise
Now let’s add noise to our “images” using the SDE . Here’s what those 50 samples look like at .
And here’s an animation of the process:
We’re transforming not just these points but the entire data distribution $p_0$. $p_t$ is what you get when you apply the forward-time SDE to $p_0$ for $t$ units of time. Here’s a kernel density estimation of what $p_t$ looks like at each time step $t$.
If we used more data points for the KDE, we’d see that this distribution looks approximately normal by the final time step. $p_t$ starts out as the uniform distribution on the circle, but it “forgets” this information as time goes on, and ends up looking like an isotropic Gaussian. In some sense, this is the whole point of the diffusion process. We transition smoothly from a “complicated” distribution $p_0$ to a “simple” distribution $p_T$. We can then sample easily from $p_T$ and, during the reverse process, go backwards in time to get a sample from $p_0$.
(In the image above, the “toy” distribution $p_0$ is a mixture of two Gaussians. $p_T$ is a single Gaussian. The diffusion process takes us from a complex distribution, $p_0$, to a simple one, $p_T$, which we can easily sample from. We can then use the reverse-time diffusion process to turn this sample into a sample from $p_0$.)
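The “forgetting” can be checked numerically. The sketch below assumes a specific forward SDE, $dx = -\tfrac{1}{2}x\,dt + dw$, chosen because its transition density is a known Gaussian; it is an illustrative stand-in and may differ from the SDE used for the figures. The helper names are hypothetical.

```python
import numpy as np

def sample_unit_circle(n, rng):
    theta = rng.uniform(0.0, 2.0 * np.pi, size=n)
    return np.stack([np.cos(theta), np.sin(theta)], axis=1)

def noise_to_time(x0, t, rng):
    """Sample x_t given x_0 in one shot for the assumed SDE dx = -0.5 x dt + dw.

    For this SDE, x_t | x_0 ~ N(x_0 * exp(-t / 2), (1 - exp(-t)) * I).
    """
    mean = x0 * np.exp(-0.5 * t)
    std = np.sqrt(1.0 - np.exp(-t))
    return mean + std * rng.normal(size=x0.shape)

rng = np.random.default_rng(0)
x0 = sample_unit_circle(20000, rng)
x_early = noise_to_time(x0, 0.01, rng)  # still hugging the circle
x_late = noise_to_time(x0, 10.0, rng)   # circle structure forgotten

early_radius = np.linalg.norm(x_early, axis=1).mean()
late_cov = np.cov(x_late.T)
```

At the early time the average radius is still close to 1, while at the late time the sample covariance is close to the identity matrix, which is what an isotropic Gaussian would give.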
Removing Noise
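So how exactly are we going to reverse the process? We need to know the score function $\nabla_x \log p_t(x)$ in order to use the reverse-time SDE.
Our goal is to train a neural network $s_\theta(x, t)$ to approximate $\nabla_x \log p_t(x)$. As we said in the last section, the optimal choice of $\theta$ is given by:
$$\theta^* = \arg\min_\theta\ \mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{x_0}\,\mathbb{E}_{x_t \mid x_0}\Big[\big\|s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0)\big\|_2^2\Big]\Big\}$$
Let’s ignore $\lambda$ (set $\lambda = 1$ identically) for now. The key observation is that once $x_0$ is known/fixed, the conditional distribution for $x_t$ is just a Gaussian. I’ll spare the reader the calculations, but for our specific SDE, we have that: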
where
and
Knowing the equations for , we can write down our conditional score function explicitly:
You might be reading this and thinking “If we can just explicitly write down the true score function, why are we bothering to learn it with a neural network?” The answer is that the expression above is only the score function conditional on $x_0$, and we want to learn an unconditional score function. Imagine choosing a specific sample $x_t$. There are many possible $x_0$ that it could have come from. That makes calculating $\nabla_{x_t} \log p_t(x_t)$ hard, because we have to integrate over all these possible starting points, weighted by their probabilities. But calculating $\nabla_{x_t} \log p_{0t}(x_t \mid x_0)$ is easy, because we know that $x_t$ evolves from $x_0$ through an explicitly-defined diffusion process. Here’s a stylized sketch of the situation:
In the sketch above: “the same” $x_t$, or two $x_t$s that are very close to each other, could have come from two very different $x_0$. That’s why $\nabla_{x_t} \log p_{0t}(x_t \mid x_0)$ and $\nabla_{x_t} \log p_t(x_t)$ can be very different. In general, the latter is difficult to evaluate because it involves integrating over the unknown/complicated distribution $p_0$. In our toy example, $p_0$ is the uniform distribution on a circle, which actually makes calculating $\nabla_{x_t} \log p_t(x_t)$ tractable. But in general, this isn’t true, so we should try to avoid leaning on this crutch, though I will lean on it once more in a minute.
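To see the difference between the two scores numerically, here is a sketch for the circle example. It assumes the forward SDE $dx = -\tfrac{1}{2}x\,dt + dw$ (an illustrative choice that may differ from the one used in the figures), for which $x_t \mid x_0$ is Gaussian with mean $x_0 e^{-t/2}$ and variance $1 - e^{-t}$. The unconditional score is then the average of the conditional scores, weighted by how likely each $x_0$ on the circle is to have produced the given point.

```python
import numpy as np

def conditional_score(x, x0, t):
    """Score of p(x_t | x_0) for the assumed SDE dx = -0.5 x dt + dw."""
    a, var = np.exp(-0.5 * t), 1.0 - np.exp(-t)
    return -(x - a * x0) / var

def marginal_score(x, t, n_grid=2000):
    """Score of p_t at one point x in R^2, integrating over x_0 on the unit circle."""
    theta = np.linspace(0.0, 2.0 * np.pi, n_grid, endpoint=False)
    x0 = np.stack([np.cos(theta), np.sin(theta)], axis=1)
    a, var = np.exp(-0.5 * t), 1.0 - np.exp(-t)
    # Posterior weight of each candidate x_0, given that we observed x at time t
    log_w = -np.sum((x - a * x0) ** 2, axis=1) / (2.0 * var)
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    return np.sum(w[:, None] * conditional_score(x, x0, t), axis=0)

outside = np.array([1.5, 0.0])
inside = np.array([0.5, 0.0])
score_outside_early = marginal_score(outside, t=0.05)
score_inside_early = marginal_score(inside, t=0.05)
score_outside_late = marginal_score(outside, t=8.0)
```

At the early time the score points back towards the circle from either side, and at the late time it is close to $-x$, the score of a standard Gaussian.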
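With this, we know what we’re going to do during our training loop. Sample $x_0 \sim p_0$, i.e. choose a point from our data set. Then choose $t$ uniformly on $[0, T]$ and draw $x_t$ from the Gaussian $p_{0t}(x_t \mid x_0)$. Then put $(x_t, t)$ into our neural net to get $s_\theta(x_t, t)$. Finally, backprop to make this closer to $\nabla_{x_t} \log p_{0t}(x_t \mid x_0)$.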
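The data side of that loop can be sketched as follows. As before, the forward SDE $dx = -\tfrac{1}{2}x\,dt + dw$ and the value `t_max=5.0` are assumptions made for illustration, and `make_training_batch` is a hypothetical helper name. The network and optimizer are left out; the function only builds the inputs and regression targets.

```python
import numpy as np

def make_training_batch(batch_size, t_max, rng):
    """Return inputs (x_t, t) and regression targets for denoising score matching."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=batch_size)
    x0 = np.stack([np.cos(theta), np.sin(theta)], axis=1)    # x_0 ~ p_0 (unit circle)
    t = rng.uniform(1e-3, t_max, size=batch_size)            # avoid t = 0, where the variance vanishes
    a = np.exp(-0.5 * t)[:, None]
    var = (1.0 - np.exp(-t))[:, None]
    x_t = a * x0 + np.sqrt(var) * rng.normal(size=x0.shape)  # x_t ~ p(x_t | x_0)
    target = -(x_t - a * x0) / var                           # conditional score
    return x_t, t, target

rng = np.random.default_rng(0)
x_t, t, target = make_training_batch(128, t_max=5.0, rng=rng)
```

A training step would then feed `(x_t, t)` to the network and minimize the squared error against `target`.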
Let’s go ahead and do that for our points on the circle. I used a simple MLP:
import torch
import torch.nn as nn

class ScoreNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Input is (x, y, t); output is the 2D score estimate
        self.net = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    def forward(self, x_t, t):
        # x_t has shape (batch, 2); t has shape (batch,)
        t = t.unsqueeze(1)
        net_input = torch.cat([x_t, t], dim=1)
        return self.net(net_input)
And here’s what the learned $s_\theta$ looks like after 1000 epochs of batch size 128:
Notice that for small values of $t$, $s_\theta$ has learned to move points towards the circle, which makes sense, because this is where all the data density is. But for larger values of $t$, $s_\theta$ is just mapping things towards the origin. This also makes sense, because $p_t$ becomes Gaussian as $t$ increases, and the true score function for a standard Gaussian distribution is just $-x$.
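The Gaussian score is a one-line calculation. For a standard normal in $d$ dimensions (and, more generally, for an isotropic Gaussian with variance $\sigma^2$):

```latex
\begin{aligned}
p(x) &= (2\pi)^{-d/2} \exp\left(-\tfrac{1}{2}\|x\|^2\right)
  &&\Longrightarrow\quad \nabla_x \log p(x) = -x, \\
p(x) &= (2\pi\sigma^2)^{-d/2} \exp\left(-\frac{\|x\|^2}{2\sigma^2}\right)
  &&\Longrightarrow\quad \nabla_x \log p(x) = -\frac{x}{\sigma^2}.
\end{aligned}
```

In both cases the score is a vector field pointing straight at the origin, which matches the behavior described for large times.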
Inference
Now let’s use this score function to simulate some reverse-time trajectories, using the reverse-time SDE:
$$dx = \left[f(x, t) - g(t)^2\,s_\theta(x, t)\right]dt + g(t)\,d\bar{w}$$
We’ll also illustrate the “drift vector field” $f(x, t) - g(t)^2\,s_\theta(x, t)$. Here’s what it looks like when we run time backwards, starting with samples from $p_T$ and using the reverse-time SDE to get samples from $p_0$. In the gif below, I’ve stopped time slightly before $t = 0$ rather than going all the way to $t = 0$, so that you can see what the reverse-time SDE is doing right before all the points freeze.
(Confession: I cheated in making this, because instead of using my learned $s_\theta$, I used the “true” score function $\nabla_x \log p_t(x)$. Normally, we wouldn’t be able to calculate $\nabla_x \log p_t(x)$, but, in this case, we can because $p_0$ is so simple. The reason I cheated is that in training runs, I struggled to get $s_\theta$ to have large-enough magnitude for small $t$, though the trained vector fields did point in the right directions. In the gif above, you can see that $\nabla_x \log p_t(x)$ explodes as $t \to 0$. This is because $p_t$ collapses onto the circle as $t \to 0$. Moral of the story: choose better $f$ and $g$, or else choose a better $\lambda$. Remember, I cavalierly set $\lambda = 1$.)
Part 2: History
In Part 1, we gave an overview of diffusion models. With that knowledge at hand, now let’s do a historical tour of three key papers:
- “Deep Unsupervised Learning using Nonequilibrium Thermodynamics”,
- “Denoising Diffusion Probabilistic Models”, and
- "Score-Based Generative Modeling through Stochastic Differential Equations”.
2015 — Deep Unsupervised Learning using Nonequilibrium Thermodynamics
“Deep Unsupervised Learning using Nonequilibrium Thermodynamics”, by Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli, invented modern diffusion models. The paper did two key things. First, it showed that diffusion models work for image synthesis. In the figure below, (a) contains true samples $x_0$, (b) contains noised samples $x_t$, (c) contains conditional samples (think of these as the model’s predictions for $x_0$ given $x_t$), and (d) contains “pure noise” samples (think of these as predictions for $x_0$ given “pure noise” $x_T$).
These CIFAR-10 images might not look impressive today, but remember that this paper came out in 2015. This was impressive image gen at the time.
The second thing this paper did was to set up the basic forward-time vs reverse-time framing, where the reverse-time process is actually tractable. The heritage for this idea comes from a very old stochastic processes paper by Feller:
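Going a little deeper into details, the approach in this 2015 paper was a bit different from the SDE strategy described in the rest of this blog post above. The paper used a discrete-time Markov chain, where you go from time step $t-1$ to time step $t$ using the “forward diffusion kernel”
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1 - \beta_t}\,x_{t-1},\ \beta_t I\big)$$
In the continuous-time limit, this is equivalent to a forward-time SDE with $f(x, t) = -\tfrac{1}{2}\beta(t)\,x$ and $g(t) = \sqrt{\beta(t)}$. What’s the probability of following some specific path $x_0, x_1, \dots, x_T$? It’s
$$q(x_{0:T}) = q(x_0)\prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
What’s the probability of going from $x_0$ to $x_T$ following any path? It’s
$$q(x_T \mid x_0) = \int \prod_{t=1}^{T} q(x_t \mid x_{t-1})\;dx_1 \cdots dx_{T-1}$$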
And similarly,
Obviously, this expression is intractable for all but the simplest possible transition kernels, which is why we use one of the simplest possible transition kernels (independent Gaussians).
That’s the forward-time process. For reverse-time, instead of using a reverse-time SDE, the paper used a “reverse diffusion kernel”
$$p(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)$$
where you train a neural network to learn the mean and covariance functions. This paper made the important observation that if $q(x_t \mid x_{t-1})$ is Gaussian with small enough noise steps $\beta_t$, then so is $q(x_{t-1} \mid x_t)$. That’s why we don’t have to learn some super complicated distribution for $p(x_{t-1} \mid x_t)$. We just have to learn a mean vector $\mu_\theta$ and a covariance matrix $\Sigma_\theta$. This “Gaussian in both directions” observation is sort of equivalent to noticing that our forward- and reverse-time SDEs have the same functional form.
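The training objective is to maximize the model log likelihood:
$$L = \int dx_0\; q(x_0) \log p(x_0)$$
where $q(x_0)$ is the true data distribution, and $p(x_0)$ is the model’s fitted data distribution. Some math tells us that $L \geq K$ where
$$K = -\sum_{t=2}^{T} \int dx_0\,dx_t\; q(x_0, x_t)\, D_{KL}\big(q(x_{t-1} \mid x_t, x_0)\,\big\|\,p(x_{t-1} \mid x_t)\big) + H_q(X_T \mid X_0) - H_q(X_1 \mid X_0) - H_p(X_T)$$
(Here we’re capitalizing $X$ to emphasize that the entropies treat $X$ as a stochastic process, not a deterministic input vector as in the integral above.)
The point is that $q(x_{t-1} \mid x_t, x_0)$ and $p(x_{t-1} \mid x_t)$ are both Gaussian, which means that there’s a closed-form analytic expression for their KL divergence. So this whole expression can be easily calculated.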
Notice that there’s no score matching in this paper. We’re trying to maximize the log likelihood directly.
2020 — DDPM
GANs dominated image generation from 2015 to 2020. But in 2020, Jonathan Ho, Ajay Jain, and Pieter Abbeel published their “Denoising Diffusion Probabilistic Models”. Historically, the main thing that this paper accomplished was showing that diffusion models could generate high-quality samples:
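The approach is very similar to Jascha’s 2015 paper. You have a forward-time kernel
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1 - \beta_t}\,x_{t-1},\ \beta_t I\big)$$
and a reverse time kernel
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)$$
with learned mean and covariance functions (in practice, DDPM fixes the covariance to $\sigma_t^2 I$ and learns only the mean). The training objective is the model log-likelihood, optimized through the usual variational bound:
$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \leq \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]$$
So what’s different from the 2015 paper? DDPM introduced a simplified training objective which worked well empirically:
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\|\epsilon - \epsilon_\theta(x_t, t)\big\|^2\Big]$$
where $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$ for some params $\bar{\alpha}_t$ that are deterministic in $t$. (The exact definition isn’t important; the point is that you get from $x_0$ to $x_t$ by summing independent Gaussians, so the “straight shot” from $x_0$ to $x_t$ is also Gaussian. People call this the “reparameterization trick”, but it’s really just noticing that the sum of independent Gaussians is Gaussian.) The actual training process works by sampling $t$ uniformly, sampling $\epsilon \sim \mathcal{N}(0, I)$, sampling a data point $x_0$, and then plugging these into the equation to get $x_t$. We then backprop to make the predicted error $\epsilon_\theta(x_t, t)$ closer to the true error $\epsilon$.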
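The “straight shot” can be verified numerically. The sketch below uses the linear schedule reported in the DDPM paper ($\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$) and takes $\bar{\alpha}_t$ to be the running product of $1 - \beta_s$. The helper name `q_sample` and the random stand-in data are illustrative assumptions.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)    # linear schedule reported in the DDPM paper
alpha_bar = np.cumprod(1.0 - betas)   # alpha_bar[t - 1] is \bar{alpha}_t for t = 1..T

def q_sample(x0, t, eps):
    """Jump straight from x_0 to x_t (t is 1-indexed and may be an integer array)."""
    return np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1.0 - alpha_bar[t - 1]) * eps

# Check the closed form against composing the one-step kernels
# x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise.
coef, var = 1.0, 0.0  # x_t = coef * x_0 + N(0, var)
for beta in betas:
    coef = np.sqrt(1.0 - beta) * coef
    var = (1.0 - beta) * var + beta

# One training example per row: the regression target for the network is eps.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 2))               # stand-in "data", purely illustrative
t = rng.integers(1, T + 1, size=4)         # t uniform on {1, ..., T}
eps = rng.normal(size=x0.shape)
x_t = q_sample(x0, t[:, None], eps)
```

After the loop, `coef` and `var` agree with $\sqrt{\bar{\alpha}_T}$ and $1 - \bar{\alpha}_T$, so a thousand small Gaussian steps collapse into a single Gaussian jump.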
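The notation for the loss function is a little confusing. Reading it, you might wonder: “Why not just set $\epsilon_\theta = 0$ identically, since we can’t do better than that to minimize $\mathbb{E}\|\epsilon - \epsilon_\theta\|^2$ if $\epsilon \sim \mathcal{N}(0, I)$?” But keep in mind this picture from earlier, which has a DDPM equivalent:
There are lots of different $x_0$s and $\epsilon$s that could have produced $x_t$ through the equation $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$. Because $\epsilon$ is truly just noise, the job of $\epsilon_\theta$ is implicitly to predict the probability-weighted mean across all the $x_0$s that could have produced $x_t$. Or rather, the job of $\epsilon_\theta$ is to predict $\epsilon$, which amounts to the same thing if $x_t$ and $t$ are given.
Notice the similarity between the DDPM loss and the Score Based Generative Modeling loss we explained earlier:
$$\mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{x_0}\,\mathbb{E}_{x_t \mid x_0}\Big[\big\|s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0)\big\|_2^2\Big]\Big\}$$
You can think of this loss function as a discretized version of score matching. As we’ll see in a minute, SBGM achieves DDPM as a special case.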
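The link between the two losses can be made explicit with a short calculation, assuming the DDPM “straight shot” kernel $q(x_t \mid x_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\,x_0,\ (1 - \bar{\alpha}_t) I\big)$:

```latex
\begin{aligned}
\nabla_{x_t} \log q(x_t \mid x_0)
  &= -\frac{x_t - \sqrt{\bar{\alpha}_t}\,x_0}{1 - \bar{\alpha}_t}
   = -\frac{\sqrt{1 - \bar{\alpha}_t}\,\epsilon}{1 - \bar{\alpha}_t}
   = -\frac{\epsilon}{\sqrt{1 - \bar{\alpha}_t}}, \\
\big\| s_\theta(x_t, t) - \nabla_{x_t} \log q(x_t \mid x_0) \big\|^2
  &= \frac{1}{1 - \bar{\alpha}_t}\,\big\| \epsilon - \epsilon_\theta(x_t, t) \big\|^2
  \quad\text{if we set}\quad
  s_\theta(x_t, t) = -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}} .
\end{aligned}
```

So predicting the noise is the same as predicting the conditional score up to a known rescaling, and the two objectives differ only in how each time step is weighted.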
2020 — Score Based Generative Modeling through SDEs
Later in 2020, Yang Song, Jascha Sohl-Dickstein (author of the 2015 paper above), Durk Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole published "Score-Based Generative Modeling through Stochastic Differential Equations”. This was a unification paper. It showed that Song & Ermon’s earlier score matching paper and DDPM could be interpreted as variations of the same idea. This paper should look very familiar, since Part 1 from this blog post was all about SBGM.
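This paper introduced the forward time and reverse-time diffusion equations, and the SDE score matching objective
$$\theta^* = \arg\min_\theta\ \mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{x_0}\,\mathbb{E}_{x_t \mid x_0}\Big[\big\|s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0)\big\|_2^2\Big]\Big\}$$
This paper realizes DDPM as a special case, with values of $f$ and $g$ set to
$$f(x, t) = -\tfrac{1}{2}\beta(t)\,x, \qquad g(t) = \sqrt{\beta(t)}$$
where $\beta(t)$ is a function that ranges linearly from $\beta_{\min}$ to $\beta_{\max}$.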
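We didn’t talk about this earlier, but this paper also describes an ordinary (non-stochastic) differential equation approach to diffusion. The equation for the ODE is
$$dx = \left[f(x, t) - \tfrac{1}{2}\,g(t)^2\,\nabla_x \log p_t(x)\right]dt$$
and the point is that if we take $p_0$ and evolve it forward in time, using this equation, to get $p_t$, that’s the same $p_t$ that we would have gotten had we used the SDE instead of the ODE. This might be surprising (it was to me), but it’s actually a fairly general result about SDEs. The Fokker-Planck equation tells us that if $dx = f(x, t)\,dt + G(x, t)\,dw$, then
$$\frac{\partial p_t}{\partial t} = -\nabla \cdot (f\,p_t) + \nabla \cdot \big(\nabla \cdot (D\,p_t)\big)$$
where $D = \tfrac{1}{2} G G^\top$ is the “diffusion tensor” and $\nabla \cdot$ is divergence (taken row by row when applied to a matrix). More explicitly,
$$\frac{\partial p_t(x)}{\partial t} = -\sum_i \frac{\partial}{\partial x_i}\big[f_i(x, t)\,p_t(x)\big] + \sum_i \sum_j \frac{\partial^2}{\partial x_i\,\partial x_j}\big[D_{ij}(x, t)\,p_t(x)\big]$$
Some calculation shows us that if we write
$$\tilde{f}(x, t) = f(x, t) - \tfrac{1}{2}\,\nabla \cdot \big[G(x, t)\,G(x, t)^\top\big] - \tfrac{1}{2}\,G(x, t)\,G(x, t)^\top\,\nabla_x \log p_t(x)$$
then
$$\frac{\partial p_t(x)}{\partial t} = -\sum_i \frac{\partial}{\partial x_i}\big[\tilde{f}_i(x, t)\,p_t(x)\big]$$
Notice that there is no diffusion term. In fact, if we apply Fokker-Planck to the ODE
$$dx = \tilde{f}(x, t)\,dt$$
we get the same expression. The upshot is that the $p_t$ that we get from the ODE $dx = \tilde{f}(x, t)\,dt$ is the same $p_t$ that we get from the SDE $dx = f(x, t)\,dt + G(x, t)\,dw$. So we can recover $p_0$ without resorting to stochastics.
Notice that if $G$ is scalar and doesn’t depend on $x$, so that $G(x, t) = g(t)$, then the equation for $\tilde{f}$ simplifies to
$$\tilde{f}(x, t) = f(x, t) - \tfrac{1}{2}\,g(t)^2\,\nabla_x \log p_t(x)$$
This gives us the ODE that we mentioned earlier,
$$dx = \left[f(x, t) - \tfrac{1}{2}\,g(t)^2\,\nabla_x \log p_t(x)\right]dt$$
Knowing that there’s a deterministic way to go from $x_T$ to $x_0$ opens up few- and even single-step reverse diffusion processes, and points the way towards consistency models, a story for another time.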
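Here is a sketch of the ODE route in a case that can be checked by hand. The setup is an assumption chosen for illustration: one-dimensional data $x_0 \sim \mathcal{N}(m, s^2)$ with $f(x, t) = -\tfrac{1}{2}x$ and $g(t) = 1$, so that $p_t$ is Gaussian and its exact score can stand in for a trained network. The function names are hypothetical.

```python
import numpy as np

m, s = 2.0, 0.5  # assumed toy data distribution: x_0 ~ N(m, s^2)

def score(x, t):
    """Exact score of p_t for the forward SDE dx = -0.5 x dt + dw."""
    mean_t = m * np.exp(-0.5 * t)
    var_t = s**2 * np.exp(-t) + (1.0 - np.exp(-t))
    return -(x - mean_t) / var_t

def ode_drift(x, t):
    """f(x, t) - 0.5 * g(t)^2 * score(x, t), with f = -0.5 x and g = 1."""
    return -0.5 * x - 0.5 * score(x, t)

def probability_flow_sample(x_T, t_max, n_steps):
    """Integrate the ODE backwards from t = t_max to t = 0 with plain Euler steps."""
    dt = t_max / n_steps
    x = np.array(x_T, dtype=float)
    for k in range(n_steps):
        t = t_max - k * dt
        x = x - ode_drift(x, t) * dt  # deterministic: no noise is added
    return x

rng = np.random.default_rng(0)
x_T = rng.normal(0.0, 1.0, size=5000)  # x_T ~ N(0, 1), approximately p_T
x_0 = probability_flow_sample(x_T, t_max=10.0, n_steps=1000)
x_0_again = probability_flow_sample(x_T, t_max=10.0, n_steps=1000)
```

The only randomness is in the draw of `x_T`. Running the integration twice on the same `x_T` gives identical outputs, and the resulting samples still have mean and standard deviation close to `m` and `s`.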
Further Reading
Sander Dieleman’s blog is full of high-quality, deep explanations of diffusion concepts. It’s the single best resource for building intuition about diffusion.
Yang Song has a great post explaining score matching, Langevin sampling, and diffusion SDEs/ODEs. Highly recommended.
The best textbooks on stochastic calculus are J Michael Steele’s Stochastic Calculus and Financial Applications (theory) and Särkkä and Solin’s Applied Stochastic Differential Equations (worked examples). I also have a personal-site post about stochastic integration.
If you’d like to see the math for DDPM worked out explicitly, I recommend Lilian Weng’s post on diffusion. Lilian has great stuff more broadly.
For a cookbook on practical methods to get diffusion models to actually work, look no further than “Elucidating the Design Space of Diffusion-Based Generative Models” or, as Sander Dieleman calls it, “The Diffusion Bible”.
Tiny Diffusion lets you play around with 2D diffusion, following DDPM.