How to Train a Custom Resnet34 Image Classification Model

Roboflow · Beginner ·👁️ Computer Vision ·5y ago

Key Takeaways

This video demonstrates how to train a custom ResNet34 image classification model using the fastai and PyTorch frameworks, with data pre-processing and augmentation in Roboflow, and model training with early stopping and fine-tuning.

Full Transcript

Hey guys, it's Jacob from Roboflow. Have you ever wondered: can a computer tell the difference between a daisy and a dandelion? Well, today we put those philosophical musings to the side and get hands-on with training a custom image classification model to tell the difference between a daisy and a dandelion. To do that, we'll be training our custom image classification model in fastai, with the underlying PyTorch framework, and we'll be using a ResNet34 model.

Before we get into the fun stuff, which will be when we dive into the modeling and actually write the code to train our model, the important thing is to nail down our dataset. Here you can see I have pictures of dandelions and daisies, and we have about 1,700 of these images. Before we port these into the Colab notebook, which we'll be using for training, we will be using Roboflow for data management, pre-processing, and augmentation, so we can make sure that our dataset is in its best state before sending it through the model.

Here we have our flowers classification dataset. This has already been loaded into Roboflow. You can use it by going to Roboflow's public datasets and copying the same one if you want to follow along directly with our tutorial. Otherwise, you can go ahead and upload your own data. This is very quick and easy to do if you have your images in a folder structure where each folder holds one class of images, but there are more details on that in the blog below.

Diving into our dataset here, we have over 1,800 images. These are all daisies and dandelions, and we've split them automatically into train, valid, and test sets of 70/20/10. You can see here that we have some options where we can add pre-processing and augmentation options to our dataset to make it larger. So here we have some augmentation options that we can apply before we make our dataset version. All of these might be of use,
but certainly things like rotation, cropping, and blurring. This is all to simulate the ways that your images might manifest themselves differently than they do in the training set: they might be slightly rotated, they might be slightly zoomed in or out, they might have a different exposure. All of these are things you can simulate with augmentation. In this augmented version I've made a number of augmentations. You can see there's crop, rotate, brightness, all these things, and we're generating five images per base training image. This gives us a larger dataset, which will help prevent our model from overfitting.

You can go ahead and take your data in here, make some steps, and then when you're ready, hit Generate to generate a new version. This will send your images to our back end, where all the pre-processing and augmentations happen, and a new version will be created for you. We'll give that a second to run. When it's done, you'll receive a link, which you can port into the Colab notebook that I have linked below. You can just import your data in there and everything will automatically match up, so you can get right started with training. You don't have to worry too much about the exact way you're going to parse data structures to get your data into the model.

Okay, here we go. We've now got our dataset version created. You can go ahead and hit Show Download Code, and this will give us a curl link once the files are zipped. The curl link is what we can then import into the Colab notebook for the dataset load. Here's an example curl link. The keys are blacked out, but I've already gone ahead and put that into the notebook and downloaded my data. So now we'll jump into the notebook and the training routine.

Okay, so now diving in. Here you
can see we have this notebook for you, Roboflow image classification, and in order to train our image classification model we're going to be using fastai. fastai is essentially a deep learning framework that sits on top of PyTorch and makes things really quick and easy to use. If you want to learn more, we've got blogs on that too, so you can go ahead and check out fastai.

Okay, we went ahead and installed that. Then we import everything from fastai.vision, which gives us all the libraries and code we'll need to train our model. Here's the place where you import your link: this is where you'll put your dataset link to train your own custom model for your own classes. I've gone ahead and put mine in there.

First we have to set up the dataset. This loads it into fastai and resizes it to 224, which is the input size for the ResNet model that we're going to be using, and it normalizes it according to ImageNet stats. ResNet is pre-trained on ImageNet, so you want to make sure that you normalize your images to the same mean and standard deviation that's present there. Now we can go ahead and take a look at data.classes. Here daisy and dandelion are the two classes, and we can use show_batch to look and make sure that our dataset has indeed loaded in correctly into the fastai training framework. We can see we've got dandelion, daisy, daisy, dandelion. Okay, looks good; looks like it loaded in fine.

Now we'll go ahead and set up our fastai ResNet model. Here you use the create_cnn function to load in a model from the torchvision library. Actually, we'll jump over here real quick and take a look at that. This is torchvision's classification model library. There are all kinds of models here, so this tutorial is really
about more than just a ResNet34 model. You can use any of these. You've got ResNet34 here, which is what we're using, but you've also got bigger ResNet models and slightly smaller ones too. All these models are pre-trained on ImageNet, so they already have a broad view of how to form features from pixels. You can take these pre-trained weights off the shelf from torchvision, bring them through fastai, and then use the training procedure that I'm going to walk through in the video. ResNet34 is the one we're using today, but you can also branch off if you want to try a DenseNet, a GoogLeNet, or a MobileNet. These are all options. There are also object detection models in here, but that's outside the scope of today's video. We're just doing classification here.

Okay, so we'll go ahead and download that. Now we're actually going to torchvision and pulling that model down. We can also visualize our network by looking at learn, which prints out all the different layers in our model. Pretty nice, right? You don't have to build that on your own. The researchers have already done that for you and posted it to torchvision. So stand on the shoulders of giants, download it, and let's get going.

Now we're going to jump into training. There are two things I'm adding to the training that I think are quite useful, rather than just the plain off-the-shelf routine. First, we've got an early stopping callback. This means that if your model hasn't improved on validation loss for at least 20 epochs, it's just going to cut off training, so you don't have to worry about running it for too long to make sure you're getting the best model. The other one is save best model: based on validation loss, the best model that comes out is the one that gets saved, so you can load that back later.

Okay, so now
moving along, we're going to make sure that our device is hitting the CUDA drivers. You can make sure that's the case by going up here, hitting Change runtime type, and making sure you've selected the GPU. But yeah, we're hitting the CUDA drivers, and we're going to call this method to start training. I have 50 epochs coded into the notebook, but we're just going to do one here for YouTube.

Okay, so now we've kicked off training. Our dataset has been loaded in, and it's going to go passing through the model. It's important to note here that we're training on the frozen ResNet model. We're not actually backpropagating through all of the layers; we're just backpropagating through the last layer, and that last layer will tune in based on the features coming out of the base model. Then we'll move forward and train on the whole network.

After one epoch we've gotten a 7.6% error rate. That's pretty good for just one epoch, and it should continue to go down if you train for more epochs. Here we just loaded in that best model, and we're going to unfreeze. Unfreezing means we're now opening up the network so that backpropagation goes through all the parameters, not just the last layer.

Next we're going to find the appropriate learning rate. The way to think about the learning rate is that it's the amount by which you actually update the weights after you get the backpropagation signal. You don't want to set this too high, otherwise you might branch way off the loss function into some unknown land, but you don't want it too low either, otherwise your model won't learn with any speed. I have this function here to get your optimal learning rate for you, so you don't have to worry about setting that yourself. Then again, I'm just going to do one epoch here. So we're doing one epoch on the unfrozen model; now we're fine-tuning through the ResNet, and there
we go. Looks like we got our error rate down even further; now we're at six and a half percent. So the model's getting even better, and like I said, that should improve with more epochs. I've seen this one go as low as a four percent error rate or so.

Okay, so now we've loaded back in the best model from all our training. Our model has been trained, and here we can start to do a little introspection on it. This plots the confusion matrix, which shows how the model is getting confused. It looks like it's confusing dandelions for daisies a little more often than it confuses daisies for dandelions. That might be for whatever reason in how the data is manifesting itself, but it can be useful, especially if you have a lot more classes, to see where the confusion is coming from.

Another nice thing fastai gives us is that it plots the images in our validation set that are the hardest for our model to understand. This can be really useful to get an idea of whether your dataset is healthy in quality, and to give you an idea of what the model is struggling with. Here you can see that we predicted daisy, but this is really a dandelion, and this one is really tough, right? I mean, that hardly looks like a dandelion. It really could be either one; it's hard to tell. So maybe you'd want to actually throw that out of your dataset.

Finally, to wrap up, we're going to use our trained model to do inference on test images. This is something that you could deploy into production in an application. So now you have your trained model and you're ready to go. That's all it took: just a few minutes and a little bit of data management with Roboflow to get a trained classification model with state-of-the-art technology. Okay, there we go, we're seeing inference here. Looks like it's doing all the dandelions first
and getting most of them as dandelions. Paging through here, oh, there we go, we saw an error: it predicted a daisy, but it was supposed to be a dandelion. Going a little further down, we'll get into the daisy area. Yep, here are the daisies. Looks like the model is doing well on inference across all the images in our test set. So now you have a model, and we've trained the computer to understand the difference between daisies and dandelions. If we can do that, what can't we teach it to tell the difference between?

Lastly, to wrap up the notebook, we're going to take a look at our saved model. That's in bestresnet34.pth, which is a PyTorch model. You can go ahead and download that out of Google Colab and import it wherever you want. You can put it in your Google Drive, put it in Amazon S3, or keep it local. You can export it to ONNX or to TorchScript. You can take this model and deploy it wherever you like. The rest, where the inference happens and how the application is built, is up to you.

I hope you learned a lot in our tutorial on how to train a custom image classification model with ResNet34. You can take this model wherever you like and make classifications with it. If you enjoyed today's video, I would be very appreciative if you would subscribe below, and go ahead and check out Roboflow for your dataset management, pre-processing, organization, and augmentation. Hopefully these techniques will allow you to form some of the best models in the game. Stay tuned for future tutorials. Thanks so much for watching today, and we'll see you again soon.
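
The early stopping and save-best-model behaviour described in the transcript can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the notebook's code: the function name train_with_early_stopping and the sample loss values are invented for the example, and a real run would get its validation losses from the training loop (fastai v1 provides this behaviour through its tracker callbacks, which this sketch does not use).

```python
def train_with_early_stopping(val_losses, patience=20):
    """Simulate early stopping over a list of per-epoch validation losses.

    Returns (best_epoch, best_loss, stopped_epoch), with 0-indexed epochs.
    `patience` is the number of consecutive non-improving epochs tolerated.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    stopped_epoch = len(val_losses) - 1  # default: ran every epoch

    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: a real loop would save a checkpoint here.
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # No improvement for `patience` epochs: cut training off.
                stopped_epoch = epoch
                break

    return best_epoch, best_loss, stopped_epoch


# Hypothetical validation losses, with a short patience so the stop is visible.
losses = [0.30, 0.22, 0.19, 0.21, 0.20, 0.23, 0.25]
print(train_with_early_stopping(losses, patience=3))
```

With the patience of 20 mentioned in the video, the same function would only stop after 20 consecutive epochs without a new best validation loss; the checkpoint from the best epoch is the one you would load back afterwards.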

Original Description

In this video, we train a custom classification model using ResNet34 implemented in the fastai and PyTorch frameworks. You can use this tutorial with any of the classification models in the torchvision model library, so really this is a general tutorial for:
* How to train a custom Resnet18 image classification model
* How to train a custom Resnet50 image classification model
* How to train a custom Resnet101 image classification model
* How to train a custom Resnet152 image classification model
* How to train a custom Squeezenet image classification model
* How to train a custom VGG image classification model

Corresponding Blog
* https://blog.roboflow.com/how-to-train-a-custom-resnet34-model

Corresponding Dataset
* https://public.roboflow.com/classification/flowers_classification

Corresponding Colab Notebook
* https://colab.research.google.com/drive/1mVISoBYTDk3Q9D5VmviTkYD28WlFW-Jk#scrollTo=miSJjgU6PZ5Z

Happy training!


This video teaches how to train a custom ResNet34 image classification model using fastai and PyTorch, with data pre-processing and augmentation in Roboflow. The model is trained with early stopping and fine-tuning, and the results are inspected using a confusion matrix and a plot of the hardest validation images.
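
To make the confusion matrix and error rate concrete, here is a small sketch in plain Python. The label lists are invented for illustration and are not results from the video; the helper names confusion_counts and error_rate are likewise assumptions for this example.

```python
from collections import Counter


def confusion_counts(actual, predicted):
    """Count how often each (actual, predicted) label pair occurs."""
    return Counter(zip(actual, predicted))


def error_rate(actual, predicted):
    """Fraction of predictions that do not match the actual label."""
    wrong = sum(a != p for a, p in zip(actual, predicted))
    return wrong / len(actual)


# Made-up labels for eight validation images.
actual = ["daisy", "daisy", "daisy", "dandelion",
          "dandelion", "dandelion", "dandelion", "dandelion"]
predicted = ["daisy", "daisy", "dandelion", "dandelion",
             "dandelion", "daisy", "daisy", "dandelion"]

counts = confusion_counts(actual, predicted)
classes = ["daisy", "dandelion"]

# One row per actual class, one column per predicted class.
for a in classes:
    row = [counts[(a, p)] for p in classes]
    print(f"actual {a:<10}", row)

print("error rate:", error_rate(actual, predicted))
```

The off-diagonal cells are the mistakes. In this made-up data, dandelions are predicted as daisies more often than the reverse, which is the kind of asymmetry the video points out in its own confusion matrix.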

Key Takeaways
  1. Load images into Roboflow
  2. Apply pre-processing and augmentation options
  3. Generate augmented dataset
  4. Split data into training, validation, and testing sets
  5. Import data into Colab notebook
  6. Set up dataset in fastai
  7. Resize and normalize dataset
  8. Load ResNet34 model from torchvision library
  9. Inspect the model's layers by printing the learn object
  10. Add early stopping callback to training
💡 Using pre-trained models like ResNet34 and fine-tuning them for custom datasets can achieve high accuracy in image classification tasks, and data augmentation techniques can significantly improve model performance.
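
The "resize and normalize" step can be illustrated for a single pixel. This is a sketch, not the notebook's code: the function name normalize_pixel is an assumption for the example, while the mean and standard deviation values are the per-channel ImageNet statistics commonly used with torchvision's pre-trained models.

```python
# Commonly used per-channel ImageNet statistics (RGB, for inputs scaled to 0..1).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to 0..1, then subtract the mean and divide by the std."""
    return tuple(
        ((value / 255.0) - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# A pure white pixel and a mid-grey pixel, as examples.
print(normalize_pixel((255, 255, 255)))
print(normalize_pixel((128, 128, 128)))
```

A pre-trained network saw inputs on this scale during its ImageNet training, which is why new images are shifted and scaled by the same statistics before fine-tuning.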
