Pytorch Conditional GAN Tutorial

Aladdin Persson · Beginner ·🧬 Deep Learning ·5y ago

Key Takeaways

This video shows how to implement a Conditional GAN in PyTorch, where the generated output is conditioned on class labels.

Full Transcript

What is going on guys, welcome back for another video on GANs. In this video I want to take a look at how we can build a model where we can also decide what the output should be. For example, if we're training on normal MNIST, we can say that we want the digit 0, and the generator will generate that digit for us.

For this video I'm going to base our implementation on WGAN, specifically WGAN-GP, so if you haven't watched that video in this series on GANs, I recommend you watch that first. That way I don't have to repeat everything I did in that video. Essentially, I'm going to modify the model first. Here we have the discriminator that we implemented in that video; let me just make this full screen. What we want to do here is add something that sends the label to the discriminator and the generator. There are many ways to make a GAN conditional. What we're implementing here is a conditional GAN, where we generate something conditioned on a label; the condition can be that this digit should be a five or a zero, whatever.

One way to do this, which I find the easiest, is to create an embedding. So we do nn.Embedding, and we map from the number of classes to image size times image size. Let me explain why this is going to make sense. First, the discriminator takes the number of classes and also the image size. Then, in our forward, we do self.embed of labels, so of course we need to send the labels in to the forward. We do the embedding on the labels and call the result embedding. Since it has image size times image size elements, you can view it as an additional channel, in a way. We have RGB, where the height and the width is the image size, so what we can do is reshape, or I guess .view, with labels.shape[0], then 1, then self.image_size, self.image_size. And then we have to add that at the top here, so self.image_size is image_size.

So what we're doing here is reshaping it so that it has this one additional channel, and so that the height is image size and the width is image size. What I want to do then is change the input x to be torch.cat of x and embedding, with dimension equals one. For dimension zero we have the number of examples in our batch, then we have channels, then the image size for the height, and then the image size for the width. So all we have to do then is plus one on the input channels in the model, and that's it. How you can view this is that we have our original image, and now we've added just one additional channel, which is sort of a stamp that we're sending in to the discriminator: here's the image, and here's the additional stamp for what that image is. So we're also giving the discriminator the information about what the digit is, the label of it.

Okay, so then we're going to do the same thing for the generator. For the generator we need to add some stuff in our init method as well. We have to add the number of classes, the image size, and then also some embed size. So I guess we can do self.image_size equals image_size, and here we do the same thing and add an embedding. But remember that here the embedding has to be combined with the noise that we're sending in, because here we're sending in some latent vector z that is converted, or generated, into an image. So what we do here is map from the number of classes to some embed size. We're not going to add this as an additional image-shaped channel; we're just going to output it in some dimension, some embed size.

In the forward we add labels and do sort of the same thing: we create an embedding, self.embed of labels. But what we have to do here is add dimensions, so we unsqueeze and then unsqueeze again. Remember, the input to this, the latent vector z, is N times noise dimension times one times one, because it's inputted as noise-dimension channels of normally distributed random values. To put these together, the embedding and the latent vector z (the x that's coming in), we need to unsqueeze so that the embedding also has times one times one at the end. Then we do x is torch.cat of x and the embedding, for dimension equals one, and we just send that through our generator. One thing here is that the first layer of the generator takes channels noise, but now we also have to add the embed size, so that's all we have to do.

So how this works is that now the generator has the information about which label it should produce, and the discriminator has the information about what the image actually should be. In this way, for the generator to be able to actually fool the discriminator, it has to learn to also come up with the right digit for the label that we give it. Hopefully I didn't forget anything about this; otherwise I guess we're going to get an error later on, but I think that should be it.

Then for the gradient penalty, all we're going to add is labels. Because the critic, or the discriminator, now takes the labels in the forward propagation, even if we just want to take the gradient with respect to the interpolated images, we have to send in the labels for the critic to do anything. So here I just send in the labels, and that's the gradient penalty; that's all we have to change on that one.

So let me go to the training file and see what we have to add. We have some image size right here and some channels image. We have to add the number of classes, which is going to be 10, and we have to add the generator embedding. This is a hyperparameter; we can just set it to 100. I'm just going to add those two, so that's it for that. Then we go to the generator and discriminator, which now take additional inputs. Let's see, we have the channels image, we have features gen, and then we have to have the num classes, the image size, and then the generator embedding. That's just what we added to the model: we want to send in the number of classes, the image size, and the generator embedding. Similar thing for the discriminator: I want to send in just the number of classes and the image size.

All right, so now we have those sent in. Before, the target labels were not needed since it was unsupervised. Unfortunately we have to remove that, because for the conditional GAN we need the labels. So let's take the labels and do labels.to(device), to send them to the GPU if we can. Then for the generator here we send in the labels, here we send in the labels, here we send in the labels; everything needs information about the labels. The gradient penalty, we send them in there as well. Same thing here, labels. I think that's it. Oh yeah, we need to change it here too, for the generator right here, and labels. And I guess instead of sending in some fixed noise, let's just send in the noise, to make sure the labels actually match: so noise and then the labels. And you know what, I think that's actually it. Hopefully there's nothing more. Well, we can just run it, right? All right, I've got to activate my environment, like that, and let's rerun it. Okay, the image size argument is named inconsistently. You know what, it actually makes more sense to have it as just image size, so let me change it up here instead, and let's also check if we wrote that somewhere else. Yeah, we did. And that's it, so hopefully it works.

[Music]

All right, so it's been training now for almost three epochs, and what it looks like is this. I guess they're not that good yet; it needs a couple more epochs and they'd look a lot better, I think. But at least we can identify some digits. What's also interesting here, and what's different from what we've done before, is that if you look at, let's say, the top right digit, six here, they're exactly identical, or rather they're supposed to represent the same digit. So for the first ones it's seven, five; here are the real ones: seven, five, eight, five, seven; and the generated ones: seven, five, eight, five, seven. That's because we're sending in the labels to the generator, so it's learning to generate the digits that we're asking it to. And that's what's interesting about conditional GANs.

All right, so that's it for conditional GANs. Hopefully you were able to follow along, implement it yourself and understand the steps. Let me know if you've got any questions in the comment section below. Thank you for watching the video, and I hope to see you in the next one.
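The discriminator change described in the transcript (embed the label, reshape it into one extra image channel, concatenate on dimension 1) might look roughly like the sketch below. The class name `ConditionalCritic` and the two small convolution layers are assumptions made only to keep the example short; they are not the architecture from the video.

```python
import torch
import torch.nn as nn


class ConditionalCritic(nn.Module):
    """Toy critic that receives the label as one extra input channel."""

    def __init__(self, channels_img, num_classes, img_size):
        super().__init__()
        self.img_size = img_size
        # One learned img_size*img_size "stamp" per class.
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        # Stand-in body (an assumption); note the "+ 1" input channel.
        self.net = nn.Sequential(
            nn.Conv2d(channels_img + 1, 8, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, 1, kernel_size=img_size // 2),
        )

    def forward(self, x, labels):
        # (N,) -> (N, img_size*img_size) -> (N, 1, img_size, img_size)
        embedding = self.embed(labels).view(
            labels.shape[0], 1, self.img_size, self.img_size
        )
        # Concatenate along the channel dimension: (N, C + 1, H, W).
        x = torch.cat([x, embedding], dim=1)
        return self.net(x)


critic = ConditionalCritic(channels_img=1, num_classes=10, img_size=64)
images = torch.randn(4, 1, 64, 64)
labels = torch.randint(0, 10, (4,))
scores = critic(images, labels)
print(scores.shape)
```

Because the label embedding has exactly `img_size * img_size` entries, it can be viewed as a single-channel image and stacked onto the input without touching any other layer.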
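For the generator, the transcript describes embedding the label to a vector of length embed size, unsqueezing it twice so it is N x embed_size x 1 x 1, and concatenating it with the noise. A minimal sketch of that idea follows; the name `ConditionalGenerator` and the two transposed-convolution layers are illustrative assumptions, not the network from the video.

```python
import torch
import torch.nn as nn


class ConditionalGenerator(nn.Module):
    """Toy generator that concatenates a label embedding onto the noise."""

    def __init__(self, channels_noise, channels_img, num_classes, embed_size):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_size)
        # Stand-in body (an assumption). The first layer takes
        # channels_noise + embed_size input channels.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(channels_noise + embed_size, 16, 4, 1, 0),  # 1x1 -> 4x4
            nn.ReLU(),
            nn.ConvTranspose2d(16, channels_img, 4, 4, 0),  # 4x4 -> 16x16
            nn.Tanh(),
        )

    def forward(self, x, labels):
        # (N,) -> (N, embed_size) -> (N, embed_size, 1, 1)
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        # Noise is (N, channels_noise, 1, 1); concatenate on channels.
        x = torch.cat([x, embedding], dim=1)
        return self.net(x)


gen = ConditionalGenerator(channels_noise=100, channels_img=1, num_classes=10, embed_size=100)
noise = torch.randn(4, 100, 1, 1)
labels = torch.randint(0, 10, (4,))
fake = gen(noise, labels)
print(fake.shape)
```

Unlike the critic, the embedding here does not need an image shape, so its size is a free hyperparameter.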
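The gradient penalty change mentioned in the transcript is only that the labels are passed through to the critic. The sketch below assumes the usual WGAN-GP penalty on random interpolations between real and fake images; `ToyCritic` is a hypothetical stand-in so the example runs on its own.

```python
import torch
import torch.nn as nn


class ToyCritic(nn.Module):
    """Hypothetical linear critic, only here to make the sketch runnable."""

    def __init__(self, num_classes=10, img_size=8):
        super().__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        self.fc = nn.Linear(2 * img_size * img_size, 1)

    def forward(self, x, labels):
        stamp = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        return self.fc(torch.cat([x, stamp], dim=1).flatten(1))


def gradient_penalty(critic, labels, real, fake, device="cpu"):
    batch_size = real.shape[0]
    # One random mixing coefficient per example.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (real * alpha + fake * (1 - alpha)).requires_grad_(True)

    # The critic needs the labels even though we only differentiate
    # with respect to the interpolated images.
    mixed_scores = critic(interpolated, labels)

    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


critic = ToyCritic()
real = torch.randn(4, 1, 8, 8)
fake = torch.randn(4, 1, 8, 8)
labels = torch.randint(0, 10, (4,))
gp = gradient_penalty(critic, labels, real, fake)
print(gp.item())
```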
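The training changes in the transcript amount to moving the labels to the device and passing them to every generator and critic call. Here is a simplified single iteration under stated assumptions: the toy models, the small sizes, and the random tensors standing in for a data loader batch are all hypothetical, and the gradient penalty term is left out to keep the block short.

```python
import torch
import torch.nn as nn

# Small toy values (assumptions); the transcript uses 10 classes and an embedding size of 100.
NUM_CLASSES, Z_DIM, GEN_EMBEDDING, IMG_SIZE = 10, 16, 8, 8
device = "cuda" if torch.cuda.is_available() else "cpu"


class ToyGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, GEN_EMBEDDING)
        self.fc = nn.Linear(Z_DIM + GEN_EMBEDDING, IMG_SIZE * IMG_SIZE)

    def forward(self, z, labels):
        emb = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([z, emb], dim=1).flatten(1)
        return torch.tanh(self.fc(x)).view(-1, 1, IMG_SIZE, IMG_SIZE)


class ToyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(NUM_CLASSES, IMG_SIZE * IMG_SIZE)
        self.fc = nn.Linear(2 * IMG_SIZE * IMG_SIZE, 1)

    def forward(self, x, labels):
        stamp = self.embed(labels).view(labels.shape[0], 1, IMG_SIZE, IMG_SIZE)
        return self.fc(torch.cat([x, stamp], dim=1).flatten(1))


gen, critic = ToyGen().to(device), ToyCritic().to(device)
opt_gen = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))

# Stand-in for one (real, labels) batch from a DataLoader.
real = torch.randn(4, 1, IMG_SIZE, IMG_SIZE).to(device)
labels = torch.randint(0, NUM_CLASSES, (4,)).to(device)

# Critic step: real and fake images are scored with the same labels.
noise = torch.randn(4, Z_DIM, 1, 1, device=device)
fake = gen(noise, labels)
critic_real = critic(real, labels).reshape(-1)
critic_fake = critic(fake.detach(), labels).reshape(-1)
loss_critic = -(critic_real.mean() - critic_fake.mean())  # gradient penalty omitted
opt_critic.zero_grad()
loss_critic.backward()
opt_critic.step()

# Generator step: try to raise the critic score for the requested labels.
loss_gen = -critic(fake, labels).reshape(-1).mean()
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()

print(loss_critic.item(), loss_gen.item())
```

Reusing the batch's own labels for the fake images is what lets generated samples be compared digit by digit against the real ones, as described at the end of the transcript.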

Original Description

In this video we take a look at a way of also deciding what the output from the GAN should be. Specifically, the output is conditioned on the labels that we send in, and as an example we take a look at training on MNIST (of course) ;) But these ideas extend to any dataset you're working with really!

GitHub Repository: https://github.com/aladdinpersson/Machine-Learning-Collection

OUTLINE:
0:00 - Introduction
0:56 - Modifying Generator and Discriminator
6:58 - Modifying Gradient Penalty
7:35 - Modifying Training
10:43 - Evaluation & Ending

Chapters (5)

0:00 Introduction
0:56 Modifying Generator and Discriminator
6:58 Modifying Gradient Penalty
7:35 Modifying Training
10:43 Evaluation & Ending