DALL-E: Text-to-Image generation - Explained!
Key Takeaways
Explains DALL-E for text-to-image generation and its applications
Full Transcript
Greetings, fellow learners. In this video, we're going to talk about DALL-E: what it is, why it exists, and exactly how it works internally. So, I hope you're ready.

So, what is DALL-E? It's a foundation model for text-to-image generation. By foundation model, I mean a transformer-style model that is trained on a lot of data and can be further tuned for other tasks, so it acts as a foundation.

Now, why do we have DALL-E? To understand this, I'm going to go all the way back to 2017 and the creation of the transformer architecture as a sequence-to-sequence architecture for language translation. By sequence-to-sequence, I mean that the input to the transformer is a sequence of text tokens, and the output is also a sequence of text tokens. From 2018 onwards, we have large language models, or LLMs, that were used for other NLP tasks. Some of the earlier ones were GPT and BERT. In terms of performance, they beat long short-term memory networks (LSTMs), which were the state of the art for sequence-to-sequence modeling. Transformers succeeded here because of their attention mechanism, which made them much better at preserving long-range dependencies.

Around the same time, images were also passed as a sequence of tokens to transformer-style architectures. For example, a vision transformer takes an image, splits it into a grid of patches, embeds each patch as a vector, and processes the resulting sequence of patch tokens with a transformer. This type of network was actually quite successful compared to the state of the art for vision models, convolutional networks, provided that we had massive amounts of data.
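To make the "image as a sequence of patch tokens" idea concrete, here is a minimal sketch, assuming a single-channel image stored as a nested list and non-overlapping p × p patches. The function name `image_to_patch_tokens` is hypothetical, and a real vision transformer would also project each flattened patch with a learned linear layer and add position embeddings, which this sketch leaves out.

```python
from typing import List

Image = List[List[float]]  # H x W grid of pixel values (single channel, for simplicity)


def image_to_patch_tokens(image: Image, patch: int) -> List[List[float]]:
    """Split an H x W image into non-overlapping patch x patch tiles.

    Each tile is flattened row by row into one vector, and tiles are ordered
    left-to-right, top-to-bottom, giving a sequence of "tokens".
    """
    height, width = len(image), len(image[0])
    if height % patch or width % patch:
        raise ValueError("image size must be divisible by the patch size")
    tokens = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            tokens.append([image[r][c]
                           for r in range(top, top + patch)
                           for c in range(left, left + patch)])
    return tokens


# A 4 x 4 toy image split into 2 x 2 patches yields a sequence of 4 tokens.
toy = [[float(4 * r + c) for c in range(4)] for r in range(4)]
seq = image_to_patch_tokens(toy, patch=2)
print(len(seq), seq[0])  # 4 [0.0, 1.0, 4.0, 5.0]
```

At realistic sizes the sequence length is (H / p) × (W / p); for example, a 224 × 224 image with 16 × 16 patches gives 14 × 14 = 196 tokens.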
For more modest amounts of data, convolutional architectures were still dominant, but with lots of data, the transformer architecture (the larger, colorful circles in this plot) became dominant. The gray area shows the performance of convolutional architectures.

So, in 2021, researchers at OpenAI basically asked, "Can we formulate text-to-image generation as a sequence-to-sequence problem?" It turns out we can, and the solution they came up with was DALL-E. I hope the why of DALL-E and how it came about now makes more sense.

Now that we know the what and why of DALL-E, we're going to talk about exactly how it works in three passes. In the first pass, I'll walk through high-level training and inference for the entire system. In the second and third passes, we'll dive into more and more detail about each individual component. So, let's get to it.

Pass one: high-level training. DALL-E has two main components: a discrete variational autoencoder (DVAE) and a generative pre-trained transformer (GPT). Training is done in two stages. We first train the DVAE to generate image tokens from images, and then we train GPT to take text and autoregressively, that is, one after the other, generate image tokens. Let's take a look at these two stages.

In stage one, say we have an image of size 256 × 256 × 3. This image is passed into the DVAE encoder. A characteristic of variational autoencoders is that they have an encoder and a decoder. We pass the image through the encoder, which is a convolutional neural network, and this gives us a grid of logits: unnormalized scores that can take any value between negative and positive infinity. This tensor is then passed into a tokenizer.
Essentially, this grid is 32 × 32 × 8,192. The depth is 8,192 because that is the size of our image codebook vocabulary. What that means is that we have a set of learnable vectors, each D-dimensional, and there are 8,192 of them. The idea is that these learnable vectors should capture general properties of the training images, and they are learned during backpropagation.

So, we have a grid of logits whose depth spans the vocabulary. At each of the 32 × 32 positions, the tokenizer chooses one index along that depth, which corresponds to one of the codebook vectors, and we use that index as a token. I know the numbers on the slide are a little small, but you can see the first one says 1,249. That's saying, "Choose the 1,249th vector from the codebook." These are discrete tokens, and we have 32 × 32 = 1,024 of them.

Next, from these indices, we grab the corresponding D-dimensional vector at every position, which gives us a grid of discrete vectors. We pass this into the decoder of our variational autoencoder, which is also a convolutional neural network, in order to reconstruct the image.

Then we compute our loss, which has two components: a reconstruction loss and a regularization loss. The reconstruction loss ensures that the image is reconstructed correctly, and the regularization loss ensures that the image embeddings don't memorize details of individual images. With this loss, we can learn the entire network via backpropagation, including the codebook vectors. So, I hope the high-level DVAE training makes sense.
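The tokenize-then-look-up step above can be sketched as follows. This is a minimal sketch with toy sizes (a 2 × 2 grid and a 4-entry codebook of 3-dimensional vectors) standing in for DALL-E's 32 × 32 grid and 8,192 codebook vectors. It picks each index with a plain argmax, which only mirrors the forward direction; pass three explains how training keeps this step differentiable. The names `tokenize` and `lookup` are hypothetical.

```python
from typing import List

Grid = List[List[List[float]]]  # rows x cols x vocab_size logits


def tokenize(logits: Grid) -> List[List[int]]:
    """Pick, at every grid position, the index of the largest logit."""
    return [[max(range(len(cell)), key=cell.__getitem__) for cell in row]
            for row in logits]


def lookup(tokens: List[List[int]], codebook: List[List[float]]) -> Grid:
    """Replace every token index with its D-dimensional codebook vector."""
    return [[codebook[t] for t in row] for row in tokens]


codebook = [  # vocab_size = 4 entries, D = 3 (learned during training in the real model)
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
logits = [
    [[0.1, 2.0, -1.0, 0.3], [3.0, 0.0, 0.0, 0.0]],
    [[-2.0, -1.0, -0.5, 4.0], [0.0, 0.0, 5.0, 0.0]],
]
tokens = tokenize(logits)
vectors = lookup(tokens, codebook)
print(tokens, vectors[0][0])  # [[1, 0], [3, 2]] [1.0, 0.0, 0.0]
print(32 * 32)                # 1024 image tokens per image at DALL-E's grid size
```

The resulting grid of vectors is what the decoder consumes; only the indices are kept as the image's tokens.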
Now that we have a trained DVAE, we take its components, the DVAE encoder and the image tokenizer. I've colored them green to indicate that they're already trained and frozen, so gradients won't propagate into them. We take an image, create the grid of logits, and tokenize it. So, image to tokens: we now have our image tokens.

Next, we train GPT. We take our text and pass it through our text tokenizer. The resulting tokens come from a 16,384-entry vocabulary, and each token maps to a text codebook vector that is also learned during backpropagation. The text sequence is 256 tokens long: while a caption might be only 15 tokens, we add padding to make it 256 tokens. We then concatenate these 256 text tokens with the 1,024 image tokens from the DVAE to get 1,280 tokens, which at this point are just indices. Next, we look up their corresponding D-dimensional embedding vectors in the codebooks and pass these 1,280 vectors into GPT.

What's very interesting here is that this sequence consists of both text vectors and image vectors, yet to GPT they are all just tokens; they aren't treated as distinct. Once the sequence is passed to GPT, it simply predicts what the next token should be. During the training phase, all of these predictions happen in parallel, not autoregressively.
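Stage two's input construction can be sketched with the sizes given in the video (256 text positions, 1,024 image tokens, 1,280 in total), together with the shifted-by-one targets used for parallel next-token training. This is a minimal sketch under stated assumptions: `PAD_ID = 0` is a hypothetical padding index, the caption ids are made up, and image token ids are offset by the text vocabulary size so both kinds of token share one index space. The real model may handle padding and the two vocabularies differently.

```python
from typing import List, Tuple

TEXT_LEN = 256       # text positions (captions are padded up to this length)
IMAGE_LEN = 32 * 32  # 1,024 image tokens from the frozen DVAE
TEXT_VOCAB = 16_384
IMAGE_VOCAB = 8_192
PAD_ID = 0           # hypothetical padding token id


def build_sequence(text_ids: List[int], image_ids: List[int]) -> List[int]:
    """Pad the caption to TEXT_LEN, then append image tokens offset past the text vocab."""
    if len(text_ids) > TEXT_LEN or len(image_ids) != IMAGE_LEN:
        raise ValueError("unexpected sequence lengths")
    padded = text_ids + [PAD_ID] * (TEXT_LEN - len(text_ids))
    return padded + [TEXT_VOCAB + i for i in image_ids]


def next_token_pairs(seq: List[int]) -> Tuple[List[int], List[int]]:
    """Inputs and ground-truth targets: the target at position t is the token at t + 1."""
    return seq[:-1], seq[1:]


caption = [17, 942, 3051]                          # a made-up 3-token caption
image = [i % IMAGE_VOCAB for i in range(IMAGE_LEN)]  # made-up image tokens
seq = build_sequence(caption, image)
inputs, targets = next_token_pairs(seq)
print(len(seq), len(inputs), targets[0])  # 1280 1279 942
```

Because every target is known up front, a causally masked transformer can score all positions in one forward pass, which is why training is parallel while generation is not.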
So, we get parallel predictions of probability distributions over all image and text tokens: 1,280 probability distributions, where each distribution covers both the text vocabulary and the image vocabulary, that is, 16,384 + 8,192 = 24,576 entries. That's where this number comes from. We also compute the ground truth by shifting the tokens by one position, so each position's target is the next token. With a prediction and a ground truth, we can compute the cross-entropy loss, and this loss backpropagates to learn the GPT weights as well as the tokenizer's learnable codebook embeddings; they all learn during this phase. That's stage two, and with it the full training of the model.

Now that we have an entire trained pipeline, we can use it for inference. At inference time, we only have text. Let's say we put in the text "a brown and white Charles Spaniel looking sad because they did something wrong," and we want to generate an image. Our text tokenizer (green, because it's already trained) tokenizes this into, say, 20 tokens, and we pad it with extra padding tokens to make 256 tokens. We then index our learned codebook to get 256 vectors of D dimensions. Next, we pass all of this into GPT and autoregressively generate image tokens, making 1,024 predictions with GPT. The first prediction gives us the first image token.
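The training loss at one position is a softmax over the combined 16,384 + 8,192 = 24,576-entry vocabulary followed by the negative log-probability of the ground-truth next token. Here is a minimal sketch in plain Python, with random numbers standing in for one GPT output. `cross_entropy` is a hypothetical helper, not DALL-E's code; in training, this quantity is averaged over all predicted positions.

```python
import math
import random

VOCAB = 16_384 + 8_192  # text vocabulary + image vocabulary


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


print(VOCAB)  # 24576

# An uninformed (uniform) prediction costs log(24,576) nats.
uniform_loss = cross_entropy([0.0] * VOCAB, target=5)

# A confident, correct prediction costs almost nothing.
rng = random.Random(0)
logits = [rng.gauss(0.0, 1.0) for _ in range(VOCAB)]  # stand-in for one GPT output
logits[123] = max(logits) + 20.0
confident_loss = cross_entropy(logits, target=123)
print(round(uniform_loss, 2), confident_loss < 0.01)  # 10.11 True
```

Subtracting the maximum logit before exponentiating keeps `math.exp` from overflowing without changing the result.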
Then we append this token to our token list, look up its embedding, and pass the sequence into GPT again to generate the second token, repeating this until we have all 1,024 image tokens. Once we have the 1,024 image tokens (this blue part right here), we pass them into the DVAE decoder to generate our image, which should hopefully correspond to the text. That's all for pass one, training and inference, and I hope it all makes sense.

All right. For pass two, let's dive into the individual components, starting with the DVAE encoder. What exactly is it? Let's blow it up. It's a ResNet-style architecture: a series of convolution, activation, and normalization blocks, along with skip connections that allow the network to become very deep. The specific architecture looks something like this. We take an image input of 256 × 256 with three channels and perform an initial convolution that makes it 128 channels deep. Then we have a sequence of residual blocks and strided convolutions. The residual blocks consist of normalization, activation, and convolution operations, which preserve the shape. We downsample with strided convolutions: a stride of 2 means this 4 × 4 convolution moves two pixels at a time, skipping every other position, so the output is half the height and width of the input. We do this three times and end up with 32 × 32 × 512. Finally, we perform a pointwise, or 1 × 1, convolution, which is a cheap way to expand or contract the number of channels.
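The inference loop from pass one (predict a token, append it, predict again, 1,024 times) can be sketched as below. This is a minimal sketch under clear assumptions: `toy_next_token_logits` is a hypothetical stand-in for the trained GPT that returns made-up scores, the image vocabulary is shrunk to 8 entries and the output to a handful of tokens so it runs instantly, and each token is sampled from the softmax of the scores. A real run would produce 1,024 tokens over the 8,192-entry image vocabulary and pass them to the DVAE decoder.

```python
import math
import random
from typing import Callable, List

Model = Callable[[List[int]], List[float]]


def sample_from_logits(logits: List[float], rng: random.Random) -> int:
    """Draw one index from softmax(logits)."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def generate_image_tokens(model: Model, text_tokens: List[int],
                          n_image_tokens: int, rng: random.Random) -> List[int]:
    """Autoregressively append one image token at a time to the text prefix."""
    sequence = list(text_tokens)
    image_tokens = []
    for _ in range(n_image_tokens):
        token = sample_from_logits(model(sequence), rng)  # predict the next token
        image_tokens.append(token)
        sequence.append(token)  # feed it back in for the next prediction
    return image_tokens


def toy_next_token_logits(sequence: List[int]) -> List[float]:
    """Hypothetical stand-in for GPT: strongly favors (last token + 1) mod 8."""
    favored = (sequence[-1] + 1) % 8
    return [50.0 if i == favored else 0.0 for i in range(8)]


tokens = generate_image_tokens(toy_next_token_logits, [3, 5, 0, 0], 4, random.Random(0))
print(tokens)  # [1, 2, 3, 4]
```

The loop is inherently sequential: each prediction depends on every token generated before it, which is why inference takes 1,024 model calls while training scores all positions at once.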
Here it's expanded to 8,192 channels, because 8,192 is the number of possible image tokens in our vocabulary. That's the DVAE encoder.

Now, let's talk about the tokenizer. What is it, and how is it implemented? One way to tokenize is to take our grid of logits and apply a softmax at each position, which converts the logits into a probability distribution: the 8,192 entries at each position sum to one. Then we sample from those distributions to get token indices. That's one way to tokenize. The problem is that this is part of a neural network that needs to learn via backpropagation, and the sampling operation is not differentiable, so gradients cannot flow past this point. To resolve this, we use a technique called Gumbel-softmax relaxation, which allows gradients to flow to earlier layers of the network. We'll talk about it in pass three, but I hope you now understand, at least at a high level, why it's here.

Once we have our tokens and the corresponding discrete vectors, we go into the DVAE decoder. It's almost the exact opposite of the encoder: instead of downsampling, we upsample from our codebook vectors. Starting from 32 × 32 × 512, we apply a pointwise convolution followed by stacked residual and transposed-convolution blocks to get our final image back.

The fourth component is our loss. It has two parts: a reconstruction loss and a regularization loss.
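The shape bookkeeping for the encoder and decoder can be checked with the standard output-size formulas for convolutions and transposed convolutions. This is a minimal sketch, assuming 4 × 4 kernels with stride 2 and padding 1 for the resampling layers and 3 × 3 kernels with padding 1 inside residual blocks; the paddings and the 3 × 3 size are assumptions, since the video only gives the 4 × 4 kernel and the downsampling factor.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (n - 1) * stride - 2 * padding + kernel


# Encoder: three strided 4x4 convolutions halve the spatial size each time.
size = 256
encoder_sizes = [size]
for _ in range(3):
    size = conv_out(size, kernel=4, stride=2, padding=1)
    encoder_sizes.append(size)
print(encoder_sizes)  # [256, 128, 64, 32]

# Residual-block convolutions (assumed 3x3, padding 1) and the final 1x1
# convolution keep 32 x 32; the 1x1 only changes channels (512 -> 8,192).
print(conv_out(32, kernel=3, stride=1, padding=1), conv_out(32, kernel=1, stride=1, padding=0))  # 32 32

# Decoder: three 4x4 transposed convolutions double the spatial size each time.
decoder_sizes = [size]
for _ in range(3):
    size = conv_transpose_out(size, kernel=4, stride=2, padding=1)
    decoder_sizes.append(size)
print(decoder_sizes)  # [32, 64, 128, 256]
```

With these settings, 256 / 2³ = 32, so the encoder's output grid holds 32 × 32 = 1,024 token positions, and the decoder mirrors the path back to 256 × 256.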
The reconstruction loss pushes the original image and the reconstruction to be as close to each other as possible, and the loss quantifies the difference between them. The regularization loss is an interesting one. The codebook vectors, these 8,192 vectors of D dimensions, should learn general features shared by all images in the training set. We don't want any single codebook vector to focus too much on a single image; otherwise, it would learn very specific information about that image, which is not what we want. Mathematically, this amounts to saying that the 8,192-dimensional distribution at each position should be smooth rather than peaky. So, the regularization loss is the KL divergence between the softmax distribution over the grid of logits and a uniform categorical distribution.

To give a bit more intuition: the uniform distribution is an 8,192-dimensional vector with every entry equal to 1/8,192. If we had logit values like these, where a few values are much higher than the rest, the distribution is very peaky, and that constitutes a high regularization loss. These other logits are more moderate and even, so they give a low regularization loss. Overall, the resulting objective pushes the discrete variational autoencoder to reconstruct the input image while ensuring each codebook vector learns general information about images, and this makes the DVAE useful for image generation.

If you want to understand the exact architecture, I'll link the code for DALL-E's encoder and decoder in the description below, so you can see it all spelled out. Now that we've covered passes one and two, let's get into Gumbel-softmax relaxation in pass three.
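The peaky-versus-smooth intuition can be illustrated at a single grid position. This is a minimal sketch, assuming the regularizer is KL(q ‖ u) with q = softmax(logits) and u uniform over K classes, which simplifies to Σ q log(q·K). The helper name `kl_to_uniform` is hypothetical, K is shrunk from 8,192 to 4, and the weighting that combines this term with the reconstruction loss over the whole grid is not shown.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_to_uniform(logits):
    """KL(softmax(logits) || Uniform(K)) = sum_i q_i * log(q_i * K)."""
    q = softmax(logits)
    k = len(q)
    return sum(p * math.log(p * k) for p in q if p > 0.0)


peaky = [8.0, 0.0, 0.0, 0.0]    # one class dominates
smooth = [0.2, 0.1, 0.0, 0.15]  # nearly uniform
flat = [1.0, 1.0, 1.0, 1.0]     # exactly uniform

print(round(kl_to_uniform(peaky), 3))   # high, approaching the maximum log(4)
print(round(kl_to_uniform(smooth), 4))  # small
print(kl_to_uniform(flat))              # 0.0
```

The KL is zero only when the distribution is exactly uniform and grows toward log K as one entry takes over, which is the penalty on peaky distributions described above.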
Overall, Gumbel-softmax relaxation is a method to tokenize logits while allowing gradients to flow during backpropagation: we pass a grid of logits through it, get tokens, and gradients can still flow through it. To explain how, we first need the idea of a Gumbel distribution. The Gumbel distribution is a distribution of extreme values, maxima or minima, and its probability density function looks like this.

The best way I can explain it is with a simple example. Say you record temperatures every day for many weeks. In week one, we have daily temperatures in Celsius; week two, the same; week three, and so on, for hundreds or thousands of weeks. Now we take the maximum temperature of each week and plot a histogram of those maxima; that's the purple plot here. I've normalized the histogram so that it's centered at zero. Interestingly, you can fit a Gumbel distribution to the shape of this histogram: we can find a location μ and a scale β that best fit the temperature data. The density is f(x) = (1/β) · exp(-(z + exp(-z))) with z = (x - μ)/β, and when the location is 0 and the scale is 1, it reduces to f(x) = exp(-(x + exp(-x))), which is the standard Gumbel distribution.

The reason we care about this here is that tokenization has to pick one of 8,192 classes in a way that simulates sampling from the softmax distribution, and, as we'll see, that sampling can be written as taking the maximum of noisy logits. That's why extreme value theory shows up here.

Now that we understand the Gumbel distribution, let's look at the Gumbel-max trick. We mentioned that sampling is not a differentiable operation, so we can't just use softmax plus sampling. But softmax plus sampling can be replaced exactly with the following setup.
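Sampling from the Gumbel distribution is straightforward with inverse-transform sampling: its standard CDF is exp(-exp(-x)), so if U is uniform on (0, 1), then g = -log(-log U) is standard Gumbel, and μ + β·g gives the general case. Here is a minimal sketch in plain Python. As a sanity check, the sample mean should be close to μ + β·γ, where γ ≈ 0.5772 is the Euler-Mascheroni constant, the known mean of the standard Gumbel distribution.

```python
import math
import random

EULER_GAMMA = 0.5772156649  # mean of the standard Gumbel distribution


def gumbel_pdf(x, mu=0.0, beta=1.0):
    """Density of Gumbel(mu, beta): (1/beta) * exp(-(z + exp(-z))), z = (x - mu) / beta."""
    z = (x - mu) / beta
    return math.exp(-(z + math.exp(-z))) / beta


def sample_gumbel(rng, mu=0.0, beta=1.0):
    """Inverse-CDF sampling: solve u = exp(-exp(-z)) for z."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0, and log(0) is undefined
        u = rng.random()
    return mu - beta * math.log(-math.log(u))


rng = random.Random(42)
samples = [sample_gumbel(rng) for _ in range(100_000)]
print(sum(samples) / len(samples))  # sample mean, close to 0.5772
print(round(gumbel_pdf(0.0), 4))    # density at the mode, exp(-1): 0.3679
```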
Say you have an image. You pass it into the discrete variational autoencoder and get the grid of logits. Then, from the standard Gumbel distribution (μ = 0, β = 1), you sample a grid of Gumbel noise, so each value is drawn independently from a standard Gumbel distribution. We add this noise to the logits to get a grid of perturbed logits, and then, at each position, over each 8,192-dimensional vector, we take an argmax. Max takes the maximum value; argmax takes the index of that maximum value, and that index gives us the token. This reparameterization is known as the Gumbel-max trick.

I said this reparameterization is exact because there is a mathematical proof that sampling this way gives exactly the softmax distribution. I'll link the math in the description below, and if you go through it, you'll see the equivalence between the two. To show this more concretely, I put together a simulation. Here, instead of 8,192 logits, we have just three, with values 1, 2, and 0.5. Passing them through a softmax gives this distribution; they're all probabilities, so they sum to one. Now, suppose I sample from this softmax distribution about 10,000 times and count how often each token (0, 1, or 2) appears.
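The simulation described here can be reproduced in plain Python. This is a sketch, not the video's actual notebook: it uses the same logits (1, 2, 0.5) and 10,000 draws per method, implements method one with `random.choices` on the softmax probabilities and method two with Gumbel noise plus argmax, and checks that both empirical frequencies land near the softmax probabilities (about 0.23, 0.63, and 0.14).

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_noise(rng):
    """One draw from the standard Gumbel distribution (inverse-CDF method)."""
    u = rng.random()
    while u == 0.0:  # log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def sample_softmax(logits, rng):
    """Method 1: softmax, then sample an index from the resulting distribution."""
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]


def sample_gumbel_max(logits, rng):
    """Method 2 (Gumbel-max trick): perturb logits with Gumbel noise, take the argmax."""
    perturbed = [x + gumbel_noise(rng) for x in logits]
    return max(range(len(perturbed)), key=perturbed.__getitem__)


def frequencies(sampler, logits, n, rng):
    counts = [0] * len(logits)
    for _ in range(n):
        counts[sampler(logits, rng)] += 1
    return [c / n for c in counts]


logits = [1.0, 2.0, 0.5]
rng = random.Random(0)
freq_softmax = frequencies(sample_softmax, logits, 10_000, rng)
freq_gumbel = frequencies(sample_gumbel_max, logits, 10_000, rng)
print([round(p, 3) for p in softmax(logits)])  # [0.231, 0.629, 0.14]
print(freq_softmax)  # close to the exact probabilities
print(freq_gumbel)   # also close: the Gumbel-max trick is exact in distribution
```

Any remaining gap between the two frequency lists is sampling noise (a standard error of about 0.005 at 10,000 draws), not a difference between the methods.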
You get frequencies that match this distribution; that's the gray plot, our softmax distribution. Sampling with the Gumbel-max trick instead gives us the pink distribution, and the two match. As for the implementation: at the top, we have our tensor of logits. We apply a softmax and sample 10,000 times, and counting the results gives the gray plot. For the second method, we draw noise from the standard Gumbel distribution (location 0, scale 1), perturb the logits by adding it, take an argmax, and repeat this the same number of times; the frequency distribution of the results is the pink plot. So you can see the two are equivalent computationally: softmax plus sampling can be replaced with Gumbel-perturbed logits plus argmax. That is the Gumbel-max trick.

The problem is that the argmax operation is still not differentiable, so gradients still cannot pass through it. What we can do is approximate it with a temperature-scaled softmax, softmax((logits + Gumbel noise) / τ), where a smaller temperature τ pushes the output closer to a one-hot vector. On its own, this function is differentiable. But I've replaced the solid arrow with a dashed one here because, to get discrete tokens, we would still have to sample from (or take the argmax of) this softmax, and that step is not differentiable. Do we actually have to do that? In practice, no. One way to get around the problem is the straight-through estimation technique.
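The relaxation and the straight-through idea can be sketched without an autodiff library by writing the backward rule by hand. This is a minimal sketch, assuming the relaxed sample y = softmax((logits + g) / τ) described above. The forward value is the hard one-hot argmax, while the gradient comes from the softmax Jacobian, dy_i/dl_j = y_i (δ_ij - y_j) / τ, which is what a straight-through estimator substitutes for the argmax's zero gradient. The function names are hypothetical; in PyTorch, the same effect is usually written as `y_hard - y_soft.detach() + y_soft`, and `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)` packages it.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_st(logits, tau, rng):
    """Straight-through Gumbel-softmax for one grid position.

    Returns (y_hard, y_soft): y_hard is the one-hot token used in the forward
    pass; y_soft is the relaxed sample whose gradient is used in the backward pass.
    """
    noise = []
    for _ in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        noise.append(-math.log(-math.log(u)))
    y_soft = softmax([(l + g) / tau for l, g in zip(logits, noise)])
    index = max(range(len(y_soft)), key=y_soft.__getitem__)
    y_hard = [1.0 if i == index else 0.0 for i in range(len(y_soft))]
    return y_hard, y_soft


def straight_through_grad(y_soft, tau, upstream):
    """Backward pass: dL/dlogits through the softmax, ignoring the argmax.

    dL/dl_j = sum_i upstream_i * y_i * (delta_ij - y_j) / tau
            = y_j * (upstream_j - sum_i upstream_i * y_i) / tau
    """
    dot = sum(u * y for u, y in zip(upstream, y_soft))
    return [y * (u - dot) / tau for u, y in zip(upstream, y_soft)]


logits = [1.0, 2.0, 0.5]
results = {}
for tau in (5.0, 1.0, 0.1):
    # Reusing the same seed gives the same Gumbel noise, so only tau changes.
    hard, soft = gumbel_softmax_st(logits, tau, random.Random(0))
    results[tau] = (hard, soft)
    print(tau, hard, [round(s, 3) for s in soft])  # smaller tau: soft is closer to one-hot

hard, soft = results[1.0]
grad = straight_through_grad(soft, tau=1.0, upstream=[0.0, 1.0, 0.0])
print(hard, [round(g, 4) for g in grad])  # forward: one-hot token; backward: nonzero gradient
```

The temperature is a trade-off: a small τ makes the relaxed sample closer to the discrete token, but its gradients become larger and noisier, which is why implementations often start warm and anneal τ downward.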
During the forward pass, we use an argmax to get the grid of tokens and compute the loss. In the backward pass, we pretend we used the temperature-scaled softmax all along, so its gradients are the ones used for backpropagation. That's straight-through estimation. One caveat: DALL-E itself may use a slightly different technique (the paper describes training with the relaxed Gumbel-softmax output and annealing the temperature down to 1/16), but to avoid confusion, straight-through estimation is a valid approach that many practitioners use in practice. I hope all of this makes sense and enriches your understanding of the DVAE plus GPT, and hence DALL-E.

Quiz time! Have you been paying attention? Let's find out. What was the key benefit of using a DVAE in DALL-E? A, it increases image resolution. B, it removes the reconstruction loss. C, it creates a learned image vocabulary. Or D, it makes training fully supervised. I'll give you a few seconds to answer this question. The correct option is C. Did you get it right? Please leave your reasoning in the comments below, and let's have a discussion. And at this point, if you think I deserve it, please consider giving this video a like, because it helps me out a lot.

That's it for quiz time, but before we go, let's generate a summary. In this video, we looked at DALL-E, a foundation model for text-to-image generation. We saw how OpenAI researchers formulated text-to-image generation as a sequence-to-sequence problem. We looked at how training is done in two stages: first we train the discrete variational autoencoder, and then we train GPT. Once both are trained, we can use GPT to autoregressively generate image tokens 1,024 times to eventually get the image we need.
We then expanded on each of these components. The DVAE encoder is a convolutional network. The tokenizer uses Gumbel-softmax relaxation so that gradients can flow in the backward direction while it still produces a grid of discrete tokens in the forward direction. The decoder is another convolutional network, similar to the encoder. We also saw the nature of the loss, which pushes the DVAE to reconstruct the input image while ensuring the codebook vectors learn general information about images.

That's all we have today. I hope everything here makes sense. I know this is a lot of information, but I'll be linking to all the code, papers, and other resources in the description below, and I hope this video serves as a supplement that helps you tie these concepts together as you navigate this vast topic. Thank you all so much for watching, and I will see you in another one. Take care.
Original Description
In this video, we take a look at DALL-E for text-to-image generation. What is it? Why do we have it? How does it look?
ABOUT ME
⭕ Subscribe: https://www.youtube.com/c/CodeEmporium?sub_confirmation=1
📚 Medium Blog: https://medium.com/@dataemporium
💻 Github: https://github.com/ajhalthor
👔 LinkedIn: https://www.linkedin.com/in/ajay-halthor-477974bb/
RESOURCES
[1 📚] Slides: https://link.excalidraw.com/p/readonly/NXtiUh19HjH4BuC2IQ6V
[2 📚] DALL-E main paper: https://arxiv.org/pdf/2102.12092
[3 📚] DALL-E blog page: https://openai.com/index/dall-e/
[4 📚] Evolution of auto encoders: https://youtu.be/XyWNmHZi1oA?si=0X5iE2FKfToDaRNM
[5 📚] Colab notebook I put together to understand the Gumbel distribution, the Gumbel-max trick, and Gumbel-softmax relaxation: https://colab.research.google.com/drive/1KSKB3AIUzyMnpym8HeSVZCxOtzS-DI9u#scrollTo=1af4a395
[6 📚] Nice mathematical proof of the Gumbel-max trick: https://github.com/priyammaz/PyTorch-Adventures/blob/main/PyTorch%20for%20Generation/AutoEncoders/Intro%20to%20AutoEncoders/gumbel_softmax_quantizer.ipynb
[7 📚] Attention Is All You Need paper: https://arxiv.org/pdf/1706.03762
[8 📚] An Image is Worth 16x16 Words paper: https://arxiv.org/pdf/2010.11929
[9 📚] Improving Language Understanding by Generative Pre-Training paper: https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf
[10 📚] Learning Bounded Context-Free-Grammar via LSTM and the Transformer: Difference and Explanations paper: https://arxiv.org/pdf/2112.09174
[11 📚] DALL-E architecture code: https://github.com/openai/DALL-E/blob/master/dall_e/encoder.py
PLAYLISTS FROM MY CHANNEL
⭕ Reinforcement Learning: https://youtube.com/playlist?list=PLTl9hO2Oobd9kS--NgVz0EPNyEmygV1Ha&si=AuThDZJwG19cgTA8
Natural Language Processing: https://youtube.com/playlist?list=PLTl9hO2Oobd_bzXUpzKMKA