Pytorch Conditional GAN Tutorial
Key Takeaways
This video shows how to implement a conditional GAN in PyTorch, where the generated output is conditioned on class labels.
Full Transcript
What is going on, guys, welcome back for another video on GANs. In this video I want to take a look at how we can build a model where we can also decide what the output should be. For example, if we're training on just normal MNIST, we can say that we want the digit 0, and then the generator will generate that digit for us.
For this video I'm going to base our implementation on WGAN, specifically WGAN-GP, so if you haven't watched that video in this series on GANs, I recommend you watch that first. That way I don't have to repeat everything I did in that video.
Essentially, I'm going to modify the model first. Here we have the discriminator that we implemented in that video. What we want to do here is add something so that we're sending the label to the discriminator and the generator. There are many ways to make the GAN conditional. What we're implementing here is a conditional GAN, where we generate something conditioned on a label; the condition can be that this digit should be a five or a zero, whatever.
One way to do this, which I find the easiest, is to create an embedding. We do nn.Embedding, and we take some number of classes to image size times image size. Let me explain why this makes sense. First, the discriminator is going to take some number of classes and also some image size. Then, in our forward, we do self.embed of labels, so of course we need to send in the labels to the forward. We do the embedding on the labels and call the result embedding. Since it will be image size times image size, you can view it as an additional channel, in a way. We have the image channels (RGB, for example), where the height and the width are the image size, so we can just do a reshape, or I guess .view, with labels.shape[0], then 1, then self.image size, self.image size. And then we have to add that to the top here, so self.image size is image size.
What we're doing here is reshaping it to have this one additional channel, so that the height is image size and the width is image size. Then I want to change the input x to be torch.cat of x and embedding, with dimension equals one. For dimension zero we have the number of examples in our batch, then we have channels, then the height and then the width. So all we have to do then is plus one on the input channels in the model, and that's it. How you can view this is that we have our original image and we've just added one additional channel, which is sort of a stamp that we're sending into the discriminator: here's the image, and here's the additional stamp for what that image is. So we're also giving the discriminator the information of what the digit is, the label of it.
Then we're going to do the same thing for the generator. For the generator we need to add some stuff in our init method as well: the number of classes, the image size, and also some embed size. We can do self.image size equals image size, and then we add an embedding. But remember, here the embedding has to be combined with the noise that we're sending in, because here we're sending in some latent vector z that is generated into an image. So what we do here is take some number of classes to some embed size. We're not going to add this as an additional image channel; we're just going to output it in some dimension, some embed size.
For the forward we add the labels and do sort of the same thing: we create an embedding, self.embed of labels, but here we have to add dimensions, so we unsqueeze and then unsqueeze again. Remember, the latent vector z is N times noise dimension times one times one, because it's input as noise-dimension channels of normally distributed random values. To join the embedding and the latent vector z (the x that's coming in), we need to unsqueeze so that the embedding also has times one times one at the end. Then x is torch.cat of x and the embedding, for dimension equals one, and we just send that through our generator. One thing here is that the first layer of the generator took channels noise as input, but now we also have to add the embed size. That's all we have to do.
How this works is that the generator now has the information of which label it should produce, and the discriminator has the information of what the image actually should be. So for the generator to be able to fool the discriminator, it also has to learn to come up with the right digit for the label that we give it. Hopefully I didn't forget anything about this, but otherwise I guess we're going to get an error later on.
Then for the gradient penalty, all we're going to add is the labels, because the critic (the discriminator) now takes the labels in its forward propagation. Even if we just want to take the gradient with respect to the interpolated images, we have to send in the labels for the critic to do anything. So here I just send in the labels, and that's all we have to change for the gradient penalty.
Let me go to the training file and see what we have to add. We have some image size right here and the image channels. We have to add the number of classes, which is going to be 10, and we have to add the generator embedding; this is a hyperparameter, and we can just set it to 100. I'm just going to add those two. Then the generator and discriminator are now going to take additional inputs. For the generator we have the channels image, the features gen, and then the num classes, the image size and the generator embedding; that's just what we added to the model. Similar thing for the discriminator: I want to send in the number of classes and the image size.
Before, the comment here said target labels were not needed, since it was unsupervised. Unfortunately we have to remove that, because for the conditional GAN we need the labels. So let's take the labels and do labels.to(device), to send them to the GPU if we can. Then for the generator here we send in the labels, and here, and here; everything needs information about the labels. For the gradient penalty we send them in as well. I think that's it. Oh yeah, we need to change it here too, for the generator right here: add the labels. And instead of sending in some fixed noise, let's just send in the noise, to make sure the labels actually match: so noise and then the labels. I think that's actually it.
Let's just run it. I have to activate my environment, and let's rerun it. Okay, the error is about image size: the name should be image size. Actually, it makes more sense to change it up here instead, and then let's also check whether we wrote it somewhere else. Yes, we did. So hopefully it works.
All right, so it's been training now for almost three epochs, and this is what it looks like. They're not that good yet; it needs a couple more epochs and would look a lot better, I think, but at least we can identify some digits. What's different from what we've done before is that if you look at, let's say, the top-right digit, a six here, the real and the generated image are supposed to represent the same digit. For the first row, here are the real ones: seven, five, eight, five, seven, and then again seven, five, eight, five, seven. That's because we're sending in the labels to the generator, so it's learning to generate the digits that we're asking it for, and that's what's interesting about conditional GANs.
So that's it for conditional GANs. Hopefully you were able to follow along, implement it yourself and understand the steps. Let me know if you have any questions in the comment section below. Thank you for watching the video, and I hope to see you in the next one.
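The conditioning described in the transcript can be sketched roughly as follows. This is a minimal sketch, not the video's full WGAN-GP model: the class and argument names (`features_d`, `features_g`, `embed_size`) and the single-layer network bodies are assumptions used only to show where the label embedding enters each network and how the shapes line up.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.img_size = img_size
        # One learned img_size * img_size "stamp" per class label.
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        # Placeholder body (assumption): one conv layer instead of the full critic.
        # Note the "+ 1" input channel for the label stamp.
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x, labels):
        # N x (H*W) -> N x 1 x H x W, then stack it onto the image channels.
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, embedding], dim=1)  # N x (C + 1) x H x W
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g, num_classes, img_size, embed_size):
        super().__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, embed_size)
        # Placeholder body (assumption): only the first upsampling layer is shown.
        # Its input is the noise channels plus the embedding channels.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(channels_noise + embed_size, features_g, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
        )

    def forward(self, x, labels):
        # N x embed_size -> N x embed_size x 1 x 1 so it matches the latent vector z.
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([x, embedding], dim=1)  # N x (noise + embed_size) x 1 x 1
        return self.net(x)


# Shape check with MNIST-like sizes (batch of 8, 1 channel, 64 x 64 images).
labels = torch.randint(0, 10, (8,))
disc_out = Discriminator(1, 16, 10, 64)(torch.randn(8, 1, 64, 64), labels)
gen_out = Generator(100, 1, 16, 10, 64, 100)(torch.randn(8, 100, 1, 1), labels)
```

Both networks receive the same integer labels; only the way the embedding is shaped differs, because the discriminator sees an image-sized input while the generator sees an N x channels x 1 x 1 latent vector.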
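For the gradient penalty, the only change the transcript describes is passing the labels through to the critic. A hedged sketch of what that could look like is below; the function signature, the argument order and the `TinyCritic` stand-in are assumptions for illustration and are not taken from the video's code.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, labels, real, fake, device="cpu"):
    batch_size, c, h, w = real.shape
    # One random mixing weight per example, broadcast over the whole image.
    alpha = torch.rand((batch_size, 1, 1, 1), device=device).repeat(1, c, h, w)
    interpolated = real * alpha + fake * (1 - alpha)
    interpolated.requires_grad_(True)

    # The critic's forward now needs the labels, even though the gradient
    # is only taken with respect to the interpolated images.
    mixed_scores = critic(interpolated, labels)

    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    # Penalize deviation of the per-example gradient norm from 1.
    return torch.mean((gradient_norm - 1) ** 2)


class TinyCritic(nn.Module):
    """Hypothetical stand-in critic, only here to make the sketch runnable."""

    def __init__(self, num_classes=10, img_size=8):
        super().__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        self.score = nn.Linear(2 * img_size * img_size, 1)

    def forward(self, x, labels):
        stamp = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, stamp], dim=1)  # image channel + label channel
        return self.score(x.flatten(1))


critic = TinyCritic()
real = torch.randn(4, 1, 8, 8)
fake = torch.randn(4, 1, 8, 8)
labels = torch.randint(0, 10, (4,))
gp = gradient_penalty(critic, labels, real, fake)
```

In a training loop, the same `labels` tensor that came with the real batch would be passed to the generator, the critic and this penalty, so that real, fake and interpolated images are all scored against the same condition.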
Original Description
In this video we take a look at a way of also deciding what the output from the GAN should be. Specifically the output is conditioned on the labels that we send in and as an example we take a look at training on MNIST (of course) ;) But these ideas extend to any dataset you're working with really!
❤️ Support the channel ❤️
https://www.youtube.com/channel/UCkzW5JSFwvKRjXABI-UTAkQ/join
Paid Courses I recommend for learning (affiliate links, no extra cost for you):
⭐ Machine Learning Specialization https://bit.ly/3hjTBBt
⭐ Deep Learning Specialization https://bit.ly/3YcUkoI
📘 MLOps Specialization http://bit.ly/3wibaWy
📘 GAN Specialization https://bit.ly/3FmnZDl
📘 NLP Specialization http://bit.ly/3GXoQuP
✨ Free Resources that are great:
NLP: https://web.stanford.edu/class/cs224n/
CV: http://cs231n.stanford.edu/
Deployment: https://fullstackdeeplearning.com/
FastAI: https://www.fast.ai/
💻 My Deep Learning Setup and Recording Setup:
https://www.amazon.com/shop/aladdinpersson
GitHub Repository:
https://github.com/aladdinpersson/Machine-Learning-Collection
✅ One-Time Donations:
Paypal: https://bit.ly/3buoRYH
▶️ You Can Connect with me on:
Twitter - https://twitter.com/aladdinpersson
LinkedIn - https://www.linkedin.com/in/aladdin-persson-a95384153/
Github - https://github.com/aladdinpersson
OUTLINE:
0:00 - Introduction
0:56 - Modifying Generator and Discriminator
6:58 - Modifying Gradient Penalty
7:35 - Modifying Training
10:43 - Evaluation & Ending