Knowledge Distillation - Keras Code Examples

Connor Shorten · Intermediate ·🧠 Large Language Models ·5y ago

Key Takeaways

This video demonstrates Knowledge Distillation using Keras, where a teacher network is used to label data for a student network, and provides code examples for implementing this technique. The video covers topics such as model compression, self-training, and transformer neural networks.

Full Transcript

Welcome to the Henry AI Labs walkthrough of Keras code examples. Keras has provided 56 code examples implementing popular ideas in deep learning. These range from the basics, such as simple MNIST and IMDB text classification, all the way to cutting-edge research ideas such as knowledge distillation, supervised contrastive learning, and transformers. We'll also explore fun generative examples like variational autoencoders and CycleGAN. My contribution to these code examples is to explain every single line of code in each of them, walking through each of the individual Keras examples. I'm not the author of these code examples, so please consider starring the GitHub repositories to show support to the original authors. I'm really excited to present this code example implementing knowledge distillation in Keras. Knowledge distillation is the idea of having a teacher network label data for the learning of a student network. Instead of the student network fitting one-hot encoded class label vectors, it's fitting the soft distribution that comes from the logits of the teacher network. This has been really useful for model compression in papers like DistilBERT. The idea is that you take a large, high-capacity model, like the BERT transformer neural network with over 100 million parameters, train it, and then use it to label data for a lower-capacity model, DistilBERT, to learn from. This transfers the knowledge learned by the higher-capacity model into the lower-capacity model. That is useful for reducing storage size, improving inference speed, and generally making these models more accessible to people who are using a single GPU, or who don't have the computing resources that BERT, or say GPT-3, was trained with. Most of these large language models are massive models that require a lot of computational resources. Knowledge distillation has also been useful for achieving state-of-the-art performance in
Self-Training with Noisy Student Improves ImageNet Classification. There they actually scale this up, so the student network is as large as or larger than the teacher network, and they iterate through rounds of this. It achieved the state of the art on ImageNet when the paper was released. Knowledge distillation has also played an important role in bringing the transformer neural network to computer vision. In Data-efficient Image Transformers from Facebook AI Research, they use knowledge distillation to transfer the soft target distribution from a convolutional neural network into the transformer. They have a teacher-student loop in which the teacher is a convolutional neural network and the student is the vision transformer. This has produced state-of-the-art performance in adapting the vision transformer, or the transformer neural network in general, to image classification and the other tasks that are frequently benchmarked in deep learning and computer vision. This code example is going to show you how to implement knowledge distillation for yourself. It's one of the most exciting ideas in deep learning, with plenty of research opportunities to see what can be achieved with it. We begin with the standard imports: TensorFlow, Keras, keras.layers, and NumPy. Now we're going to construct a Distiller class. As stated in the documentation describing the class, we're going to have a train step, a test step, and compile. We'll also have a teacher model that's already been trained, a student model to train, and the loss function. There we implement a custom loss that measures the difference between the teacher predictions and the student predictions. Then we'll have the alpha factor to weight the loss. In knowledge distillation you don't only train on the teacher's labels; you also train on the ground-truth labels, and you weight these two loss
functions with alpha. If you've never implemented a custom loss function, or a deep learning training framework with two loss functions, this is a great place to start. We'll be doing an implementation where loss equals self.alpha times one loss plus (1 - self.alpha) times the other loss. Many papers do this. To take a random example, GameGAN, the NVIDIA research that puts Pac-Man inside a GAN, learns with several of these loss functions. It's also common in CycleGAN and its recent extensions. In papers like these, the multiple loss functions are weighted by hyperparameters like alpha, so alpha is an example of a hyperparameter of this knowledge distillation system. We define our optimizer and then the metrics to evaluate performance. Our Distiller class inherits from keras.Model. We've seen many examples in the Keras examples playlist of overriding Keras layers; in this case we're overriding keras.Model itself. This is useful because we embed all of our training logic within the model. It's a different way of structuring this kind of customized loss function and training curriculum, and it will become clearer as we see the whole example. In the initializer we define our teacher and our student. The teacher is a pre-trained model, and the student is trained from scratch with this new loss function. Then comes the compile method. With a keras.Model you normally call model.compile and pass in the optimizer, the metrics, and so on. Here we override compile to take our optimizer, something like Adam or SGD, and our metrics, which are things like
categorical cross-entropy or whatever we're measuring performance with. Then we have our student loss function, and then the customized distillation loss function, which is the reason we're doing all this. Alpha is one hyperparameter, and temperature is another. Temperature smooths out the distribution that comes from the teacher. A well-trained teacher may put massive density on one class and almost none on the others, because it was trained to fit one-hot encoded class vectors. If you're doing CIFAR-10, it might put 99 percent on cat and around 0.1 percent on each of the other class labels. Temperature smooths that out, so we get more of the prior knowledge about how much density the teacher model puts on each of the other classes. Next we call super().compile, since we're inheriting from keras.Model, and pass in the optimizer and the metrics. Then we add our custom logic: the student loss function, the distillation loss function, and the alpha and temperature hyperparameters. Now we override the train step. Its argument is the data. Whether the data loader is a keras.utils.Sequence object, plain NumPy arrays passed to fit, or a tf.data object, it passes us a batch of data, which we unpack into x and y variables. Suppose we had a data loader like the ones in contrastive learning, which give us x, x prime, and other tensors. We could override the custom training step to unpack that data and do customized things with it, implementing research papers in code that overrides the generic framework. Now we do the forward pass of the teacher. We have this pre-trained teacher network. This
could be something like a ResNet from keras.applications, or any other pre-trained model. We get its predictions by calling self.teacher on x, the batch of data from the data loader, with training=False. That runs the teacher's layers in inference mode. No gradients go back through the teacher network; it's only used to label the data, and those labels serve as targets for training our student network. Now we get to the more interesting part, where we apply our training. We start with the block with tf.GradientTape() as tape. This is one way of getting the gradients: later in the code we call tape.gradient(...). We can access the gradients recorded in this tape, inspect them, and then apply them to our trainable variables. This gives you more control over the gradients. You can print them or log them to a file, which gives you more flexibility for whatever you're adapting this to. Inside the block, we first do the forward pass of the student on the x's from the unpacked data. Then come the loss functions. The y's are our ground-truth class labels, so we first apply, say, sparse categorical cross-entropy to the y's and the student predictions. We also compute the distillation loss between the teacher predictions and the student predictions. We do this from the raw logits. Something I learned recently: if you compute this loss with sparse categorical cross-entropy and from_logits=True, it applies the softmax for you. Softmax is this function
where you take e to each activation divided by the sum of e to all of the activations, which gives a probability for each index of the class label vector. So you don't have to pass the final output through a softmax layer if you use sparse categorical cross-entropy with from_logits=True. That was a small aside; I hope it didn't distract too much. What we do now is apply the softmax to the teacher predictions, and likewise to the student predictions, after dividing them by the temperature. We don't just use a from_logits=True cross-entropy between the two raw predictions, because we want to add in our customized temperature. The reason for applying softmax here in the distillation loss is to bring in the influence of the hyperparameter self.temperature. Temperature and alpha are the two big hyperparameters controlling knowledge distillation. Alpha controls how much we weight each loss function. Temperature controls how much we smooth the teacher's distribution, so we get more of the prior knowledge from the previously trained network. That teacher may be a higher-capacity model or, as in Self-Training with Noisy Student, any kind of teacher-student loop. With the two loss functions defined, we tie them together with alpha, which weights each loss, and store the result in a loss variable. Then we define trainable_vars = self.student.trainable_variables, and compute gradients = tape.gradient(loss, trainable_vars), using the tape from with tf.GradientTape() as tape. We pass in the loss the tape recorded and the variables we want to update in the backpropagation step. Now we update the
weights by calling apply_gradients on the optimizer, say SGD or Adam, zipping together the gradients and the trainable variables. This is how you use a custom tf.GradientTape and apply the gradients during training. You could imagine holding on to these gradients and applying them a second time. tf.GradientTape and this kind of functionality give you flexibility for all sorts of creative research ideas in deep neural network training. Next comes the metrics object, since we're overriding keras.Model. When you call history = model.fit(...) and index the history, it holds an updated dictionary of all these metrics. This code just updates the metric state from the y labels and the student predictions. The metric is probably something like accuracy, the most common thing to track. When you call model.fit, it progressively reports the accuracy, and this is the code you need to override to see that, since we're overriding keras.Model. Now for the test step. This runs when you call something like model.evaluate with x_test and y_test. It unpacks the data, runs a forward pass on the student network, and computes the student loss. That's not the distillation loss; it's something like sparse categorical cross-entropy from logits. Then it updates the accuracy metric by comparing y with the predictions. Finally it returns a dictionary with the loss for reporting. That's why model.evaluate prints a line with something like student loss or training
accuracy, validation accuracy, something like that. This custom Distiller class is really the meat of what we're learning in this Keras example; the rest ties everything together. The next step is to define the teacher network and the student network. These can be any kinds of neural networks. Here the higher-capacity teacher has 256 and 512 filters, the number of feature channels in its convolutional layers. The student has 16 and 32. This is model compression: knowledge distillation compresses the teacher into the student. We train the student to predict the outputs of the teacher's final dense layer, matching the two networks. That's the idea of knowledge distillation. Then we have student_scratch = keras.models.clone_model(student). We clone the model because later we want to compare the accuracy from knowledge distillation with just training the student network from scratch. Note that clone_model copies the architecture but creates freshly initialized weights. If you want both students to start from exactly the same weights, copy them over with set_weights, or save the weights and load them from a checkpoint. The next step in training any deep neural network is to define a data pipeline. Here we use the MNIST dataset, which you can load easily with keras.datasets.mnist.load_data(). We normalize the x's by dividing them by 255.
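The normalization just described, plus the reshape covered next, can be sketched in NumPy. This is a minimal sketch. It assumes the uint8 arrays of shape (N, 28, 28) that keras.datasets.mnist.load_data() returns. The helper name preprocess_mnist is hypothetical, and a small synthetic array stands in for the real download so the snippet runs offline.

```python
import numpy as np


def preprocess_mnist(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels to [0, 1] and add a trailing channel axis.

    Hypothetical helper: `images` is assumed to have shape (N, 28, 28),
    as returned by keras.datasets.mnist.load_data().
    """
    # Divide by 255 so every pixel value lands in the interval [0, 1].
    scaled = images.astype("float32") / 255.0
    # Conv2D layers expect (batch, height, width, channels); MNIST is
    # grayscale, so append a single channel.
    return np.reshape(scaled, (-1, 28, 28, 1))


# Synthetic stand-in for x_train: 4 images, left half black, right half white.
fake_batch = np.zeros((4, 28, 28), dtype=np.uint8)
fake_batch[:, :, 14:] = 255

x = preprocess_mnist(fake_batch)
print(x.shape, x.dtype, x.min(), x.max())  # 4-D float32 array with values in [0, 1]
```

The same function would be applied to both x_train and x_test, so the teacher, the student, and student_scratch all see identically shaped inputs.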
This keeps them in the interval from zero to one. Then we reshape them to be compatible with the convolutional layers. The raw MNIST arrays from keras.datasets are 28 by 28 images with no channel axis, so the reshape adds the single grayscale channel that Conv2D expects. The next step is to train the higher-capacity teacher network. We first compile the teacher model with the Adam optimizer and, as I mentioned earlier, sparse categorical cross-entropy with from_logits=True. That means we don't need a softmax layer at the end of the network; the loss automatically applies it to the logits from the Dense(10) output layer. Then we pass in the metrics, which ties back to our discussion of overriding keras.Model in the Distiller. Because we pass in this metric, the output reports sparse categorical accuracy rather than just accuracy, alongside the loss. Then we call teacher.fit and teacher.evaluate and step through the data. We train for five epochs and end up at about 98 percent accuracy; MNIST is pretty easy to model with these convolutional networks. The evaluation on x_test and y_test returns the raw loss and the sparse categorical accuracy from the test step. As we saw when overriding the Distiller class, the test step is what runs during model.evaluate, or during model.fit when you also pass in validation data. Now that we've trained the teacher network, we have a pre-trained teacher to pass as an argument to our Distiller keras.Model. So we define the distiller
with the student network we defined previously, with 16 and 32 filters in its convolutional layers, and our pre-trained teacher network. We compile it with the Adam optimizer, sparse categorical accuracy as the metric, and sparse categorical cross-entropy as the student loss. That loss is computed on the labeled MNIST dataset, with labels for the handwritten digits zero through nine. Then comes the distillation loss, and here's a really interesting thing: it's the KL divergence between the teacher's probability distribution and the student's probability distribution. Kullback-Leibler divergence, described on this Wikipedia page, is one of the measures used to compare probability distributions. It's interesting that it's asymmetric: the divergence of P from Q is generally not the same as the divergence of Q from P. You calculate it by looping over the classes, 10 in the case of MNIST. For each digit from zero through nine, you take the probability P puts on that digit times the log of P over Q, then sum these terms. The result is the Kullback-Leibler divergence between the two distributions, here the teacher's and the student's. So the hyperparameters we're exploring are alpha and temperature. Alpha weights the sparse categorical loss on the ground-truth labels against the KL divergence between the teacher and student distributions. Temperature is set to 10 to smooth out the distribution produced by the teacher network. Now we call fit with our training data and then evaluate with our test data. After three epochs we reach 97.5 percent accuracy, similar to the teacher model on the test set. MNIST is a bad proxy for the metrics you actually care about, so I wouldn't read much into the performance. I'd really
just try to understand this code for designing your own custom experiments; all of these models will fit MNIST basically perfectly. The tutorial concludes with the student_scratch model from our initial cloning, student_scratch = keras.models.clone_model(student), which gives the student architecture with fresh weights. It trains that model from scratch, without any knowledge distillation, to compare performance. But again, in my opinion these metrics are kind of useless, because all these models can fit MNIST almost perfectly. To summarize, this tutorial should give you a good sense of how to implement a custom class with a custom loss function, custom training steps, and two loss functions weighted by alpha. If there's one thing to take away, it's those four lines of code, or more generally this Distiller class. Hopefully you now have all the tools you need for any kind of experimentation with knowledge distillation. Thanks for watching; please subscribe to Henry AI Labs for more deep learning and AI videos, and check out the rest of the Keras code examples playlist.
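The alpha-weighted objective the walkthrough singles out can be sketched in plain NumPy. The sketch mirrors the arithmetic described above: student cross-entropy on the integer labels, plus the KL divergence between temperature-softened teacher and student distributions, combined as alpha * one + (1 - alpha) * the other. It is not the Keras example's exact code. The function names are hypothetical. temperature=10 matches the value used in the walkthrough, while the alpha default of 0.1 and the eps clipping are assumptions.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Row-wise softmax of logits / temperature, computed stably."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # subtracting the row max avoids overflow
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sparse_ce_from_logits(labels: np.ndarray, logits: np.ndarray) -> float:
    """Mean cross-entropy for integer labels; the softmax is applied here,
    which is the job from_logits=True does in Keras."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]  # probability of the true class
    return float(-np.mean(np.log(picked)))


def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-7) -> float:
    """Batch mean of sum_k p_k * log(p_k / q_k), i.e. KL(p || q)."""
    p = np.clip(p, eps, 1.0)  # clipping (an assumption) keeps log() finite
    q = np.clip(q, eps, 1.0)
    return float(np.mean(np.sum(p * np.log(p / q), axis=1)))


def distillation_objective(labels, teacher_logits, student_logits,
                           alpha=0.1, temperature=10.0):
    """alpha * student loss + (1 - alpha) * distillation loss."""
    # Hard-label term: student logits vs. ground-truth class indices.
    student_loss = sparse_ce_from_logits(labels, student_logits)
    # Soft-label term: both distributions softened by the same temperature.
    distill_loss = kl_divergence(
        softmax(teacher_logits, temperature),
        softmax(student_logits, temperature),
    )
    return alpha * student_loss + (1 - alpha) * distill_loss


# Usage with random stand-in logits for a batch of two 10-class examples.
rng = np.random.default_rng(0)
labels = np.array([3, 7])
teacher_logits = rng.normal(size=(2, 10))
student_logits = rng.normal(size=(2, 10))
print(distillation_objective(labels, teacher_logits, student_logits))
```

If the student's logits exactly match the teacher's, the KL term vanishes and only alpha times the hard-label loss remains. The argument order of kl_divergence also matters, reflecting the asymmetry mentioned above. Hinton et al.'s original formulation additionally multiplies the soft-target term by temperature squared, to keep its gradient scale comparable as the temperature changes. The sketch leaves that factor out and follows the formula as described in the walkthrough.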

Original Description

This Keras Code Example shows you how to implement Knowledge Distillation! Knowledge Distillation has led to new advances in compression, training state-of-the-art models, and stabilizing Transformers for Computer Vision. All you need to do to build on this is swap out the Teacher and Student architectures. I think the example of how to override keras.Model and integrate two loss functions weighted by an alpha hyperparameter is very useful as well.

Content Links
Knowledge Distillation (Keras Code Examples): https://keras.io/examples/vision/knowledge_distillation/
DistilBERT: https://arxiv.org/pdf/1910.01108.pdf
Self-Training with Noisy Student: https://arxiv.org/pdf/1911.04252.pdf
Data-efficient Image Transformers: https://ai.facebook.com/blog/data-efficient-image-transformers-a-promising-new-technique-for-image-classification/
KL Divergence: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

0:00 Beginning
0:44 Motivation, Success Stories
2:47 Custom keras.Model
11:18 Teacher and Student models
12:17 Data Loading, Train the Teacher
14:05 Distill Teacher to Student

This video teaches how to implement Knowledge Distillation using Keras, which is a technique used for model compression and self-training. The video provides code examples and demonstrates how to use a teacher network to label data for a student network.

Key Takeaways
  1. Define a custom loss function for Knowledge Distillation
  2. Compile the model with a custom optimizer and metrics
  3. Train the model with the custom loss function
  4. Use temperature to smooth out the teacher's distribution
  5. Override the train step to customize data unpacking and processing
  6. Use a pre-trained teacher network for inference and labeling
  7. Apply gradients to trainable variables using tf.GradientTape
💡 Knowledge Distillation can be used to compress models and improve their performance by using a teacher network to label data for a student network.
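Takeaway 4 can be made concrete with a small NumPy experiment. This is a sketch with made-up 10-class teacher logits (hypothetical, not from the video). It shows that dividing the logits by a larger temperature before the softmax keeps the same predicted class but spreads probability onto the other classes. That spread is the extra prior knowledge the student gets to learn from.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax of a 1-D logit vector after dividing by `temperature`."""
    z = logits / temperature
    z = z - z.max()  # numerical stability; does not change the result
    e = np.exp(z)
    return e / e.sum()


def entropy(p: np.ndarray) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-np.sum(p * np.log(p)))


# Hypothetical teacher logits that strongly favor class 3.
teacher_logits = np.array([1.0, 0.5, 2.0, 9.0, 0.0, 1.5, 0.2, 3.0, 0.8, 1.1])

for t in (1.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(f"T={t:>4}: top class {probs.argmax()}, "
          f"top prob {probs.max():.3f}, entropy {entropy(probs):.3f}")
```

At T=1 nearly all of the mass sits on class 3. At T=10 class 3 is still the argmax, but the runner-up classes receive visibly more probability, and the entropy rises accordingly.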
