Unlocking Low-Level Control: Customizing Keras Training Loops with JAX
Do you want the speed and functional power of JAX without losing the high-level convenience of model.fit? In this video, Google ML Developer Advocate Yufeng Guo (@yufengg) explains how Keras implements the principle of Progressive Disclosure of Complexity.
Learn how to take full control of your learning algorithms by overriding the train_step() and test_step() methods while keeping access to built-in callbacks, distribution support, and evaluation tools.
What You’ll Learn:
- Why overriding train_step beats writing a training loop from scratch.
- Understanding how to handle trainable variables, non-trainable variables, and optimizer states in a functional environment.
- Creating a compute_loss_and_updates function to manage forward passes and auxiliary data.
- Using jax.value_and_grad to compute gradients and losses simultaneously.
- Updating evaluation metrics using stateless_update_state.
Chapters:
0:00 - Introduction & The Default model.fit()
0:18 - Customizing Keras Training Loops
0:46 - Overriding train_step()
1:14 - Setting up the JAX Backend
1:26 - The Stateless train_step
2:11 - Stateless Loss Computation
3:04 - Taking Gradients in train_step
4:06 - How to pass around non-trainable variables
4:43 - Updating the Model Weights
5:05 - Handling Metrics
5:21 - Custom Evaluation Loops (overriding test_step)
Resources:
Complete Code Example → https://goo.gle/4eeSvlD
Keras Documentation → https://goo.gle/42Ebpv0
Keras Developer Guides → https://goo.gle/4um97N3
Subscribe to Google for Developers → https://goo.gle/developers
Speaker: Yufeng Guo
Products Mentioned: Google AI