How to save and load models in Pytorch

Aladdin Persson · Intermediate ·🧬 Deep Learning ·6y ago

Key Takeaways

This video demonstrates how to save and load models in PyTorch, specifically using the state dictionary and optimizer to create checkpoints.

Full Transcript

[Music] Let's say we're at the point where we've got a model to work and it's training, and we want to be able to save the model and continue the training at another point. For this we need to figure out how to save and load a checkpoint of the model. In this case I've got a very simple CNN that's training on the MNIST dataset.

What we want to do is create, let's say, checkpoint. We're going to store it as a dictionary with a key called state_dict, set to model.state_dict(). And not only the model parameters: we're also going to store the optimizer. You can store more things as well, like the epoch, the best accuracy, and whatever information you want. Let's just keep it simple and store only the model parameters and the optimizer in this case. So then we're going to do optimizer.state_dict(). Then, let's say, if epoch is equal to 2 (you can do whatever here, maybe the fifth epoch or something like that), we want to call some function save_checkpoint with the checkpoint, the dictionary we created.

So let's create that function: define save_checkpoint. It's going to take some state, so the dictionary we created here, and it's also going to output to some file. Let's call it my_checkpoint and use the standard convention of ending it with .pth.tar. Then let's print "saving checkpoint", and then we're going to use torch.save(state, filename). Okay, right, so save_checkpoint of checkpoint. Let's try to run this. Undefined... yeah, okay, it's supposed to be save_checkpoint. This might take a while, so I'll just continue when it's done. Yeah, so now we've trained two epochs and we see "saving checkpoint". And if I open up that folder, it's going to show my_checkpoint, and that's the file.

Next, let's say we want to actually load it. We want to create another function: define load_checkpoint, taking a checkpoint. We're going to print "loading checkpoint", and then we're just going to do model.load_state_dict of checkpoint["state_dict"], and then pretty much the same thing for the optimizer. And again, if you saved more things in the checkpoint, like accuracy or epoch or whatever, you're going to have to take those from the dictionary as well. So for example, if we had the best current accuracy in the checkpoint, you would read it out like this. But we only have the state dictionary and the optimizer.

Then let's add another hyperparameter: let's say load_model is True. After we've initialized the model and the optimizer, we're going to do load_checkpoint of torch.load of my_checkpoint.pth.tar, the file that we created, whatever you called that one. And that's all. So we can do: if load_model, then load_checkpoint, that's what we called that function. So now it should load the model, and if the epoch is 2 it's also going to save a checkpoint.

But let's say that we want to, I don't know, store it every third epoch. We can do epoch modulo 3 equals 0; then it's going to create another checkpoint and save the checkpoint. You could also check the accuracy, see if it's better than some best accuracy, and then save the model. There are multiple ways of doing it. Let's say we just want to do it in this simple way. I'm going to let it run for a while just so we can see all of it.

All right, so it trained for 10 epochs, and we can see that in the beginning it loaded the checkpoint and then it also saved a checkpoint, because epoch 0 modulo 3 is 0. Then we trained for three epochs and saved a checkpoint, another three, save, another three, save. So it seems to be working. The mean loss for that last epoch is now around 0.4. Let's say that we now rerun it: we can see that it restarted with about the same loss value as the previous one, right? So this means that it's continuing from this point rather than restarting. If, for example, we set load_model to False, then we see that it restarts, right? The loss is much higher now.

One thing to be careful of: now that we've set load_model to False, when it saved this checkpoint it actually overwrote the previous file. So you have to be cautious of that, not to train for a long time and then overwrite your checkpoint file.

If you have any questions about this, leave them in the comment section. Thank you for watching the video, and I hope to see you in the next one.
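The two functions described in the transcript can be sketched roughly as follows. This is a minimal reconstruction, not the code shown in the video: the single `nn.Linear` layer stands in for the CNN, `load_checkpoint` takes the model and optimizer as explicit arguments (the transcript only mentions a `checkpoint` argument), and the file is written to a temporary directory so nothing in the working directory is overwritten. All of these are assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write the checkpoint dictionary to disk."""
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a loaded checkpoint dictionary."""
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])


# Hypothetical stand-in for the CNN in the video: a single linear layer.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One optimization step so the optimizer has state worth saving.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
path = os.path.join(tempfile.mkdtemp(), "my_checkpoint.pth.tar")
save_checkpoint(checkpoint, path)

# A fresh model and optimizer, as after restarting the script.
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.Adam(restored_model.parameters(), lr=1e-3)
load_checkpoint(torch.load(path), restored_model, restored_optimizer)
```

After the last line, `restored_model` holds the same weights as `model`, and `restored_optimizer` carries the saved Adam state, which is what lets training pick up with about the same loss instead of starting over.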

Original Description

Let's say you have a model that is working but now you want to be able to save a checkpoint and load it to continue training at a later point. In this video I walk through an example of how to do it!

❤️ Support the channel ❤️
https://www.youtube.com/channel/UCkzW5JSFwvKRjXABI-UTAkQ/join

Paid Courses I recommend for learning (affiliate links, no extra cost for you):
⭐ Machine Learning Specialization https://bit.ly/3hjTBBt
⭐ Deep Learning Specialization https://bit.ly/3YcUkoI
📘 MLOps Specialization http://bit.ly/3wibaWy
📘 GAN Specialization https://bit.ly/3FmnZDl
📘 NLP Specialization http://bit.ly/3GXoQuP

✨ Free Resources that are great:
NLP: https://web.stanford.edu/class/cs224n/
CV: http://cs231n.stanford.edu/
Deployment: https://fullstackdeeplearning.com/
FastAI: https://www.fast.ai/

💻 My Deep Learning Setup and Recording Setup:
https://www.amazon.com/shop/aladdinpersson

GitHub Repository:
https://github.com/aladdinpersson/Machine-Learning-Collection

✅ One-Time Donations:
Paypal: https://bit.ly/3buoRYH

▶️ You Can Connect with me on:
Twitter - https://twitter.com/aladdinpersson
LinkedIn - https://www.linkedin.com/in/aladdin-persson-a95384153/
Github - https://github.com/aladdinpersson

This video teaches how to save and load models in PyTorch, allowing for continuation of training at a later point. It covers creating checkpoints using the state dictionary and optimizer, and loading these checkpoints to resume training.
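The transcript notes that a checkpoint can hold more than the two state dictionaries, such as the epoch or the best accuracy, and that anything saved this way has to be read back out of the dictionary when loading. A small sketch of that idea follows. The key names `epoch` and `best_accuracy`, the placeholder values, the `nn.Linear` stand-in model, and the in-memory buffer used instead of a file are all assumptions for illustration, not details from the video.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # hypothetical stand-in for the real network
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Extra bookkeeping stored next to the two state dictionaries.
# Key names and values here are placeholders chosen for this sketch.
checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": 4,
    "best_accuracy": 0.91,
}

buffer = io.BytesIO()  # in-memory file, so the sketch leaves nothing on disk
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer)
model.load_state_dict(loaded["state_dict"])
optimizer.load_state_dict(loaded["optimizer"])
start_epoch = loaded["epoch"] + 1  # resume with the epoch after the saved one
best_accuracy = loaded["best_accuracy"]
```

Storing the epoch makes it possible to start the training loop at `start_epoch` rather than at zero, which the simple version in the video does not do.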

Key Takeaways
  1. Create a checkpoint dictionary holding the model's state dictionary and the optimizer's state dictionary
  2. Define a function to save checkpoints
  3. Call the save function at desired intervals (e.g. every few epochs)
  4. Define a function to load checkpoints
  5. Load the checkpoint before training
  6. Resume training from the loaded checkpoint
💡 Saving and loading models in PyTorch allows for efficient continuation of training, but be cautious of overwriting previous checkpoint files.
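
One possible way to act on the overwrite warning, while keeping the "every third epoch" rule from the transcript, is to give each checkpoint its own filename. The sketch below only builds the paths and does not write any files; the helper names `should_save` and `checkpoint_path`, the `checkpoints` directory, and the per-epoch naming scheme are hypothetical and do not come from the video, which reuses a single file.

```python
from pathlib import Path


def should_save(epoch, every=3):
    """True on epochs 0, 3, 6, ... (the epoch % 3 == 0 rule from the video)."""
    return epoch % every == 0


def checkpoint_path(directory, epoch, prefix="my_checkpoint"):
    """Hypothetical naming scheme: one file per epoch, so earlier saves are kept."""
    return Path(directory) / f"{prefix}_epoch{epoch:03d}.pth.tar"


# Paths that a 10-epoch run would write to under this scheme.
saved = [checkpoint_path("checkpoints", e) for e in range(10) if should_save(e)]
for path in saved:
    print(path.name)
```

The trade-off is disk usage: keeping every checkpoint avoids accidental overwrites, but long runs may need a rule for deleting older files.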
