Unlocking Low-Level Control: Customizing Keras Training Loops with JAX

Google for Developers · Beginner · ML Fundamentals · 1w ago
Do you want the speed and functional power of JAX without losing the high-level convenience of model.fit? In this video, Google ML Developer Advocate Yufeng Guo (@yufengg) explains how Keras implements the principle of Progressive Disclosure of Complexity. Learn how to take full control of your learning algorithms by overriding the train_step() and test_step() methods while keeping access to built-in callbacks, distribution support, and evaluation tools.

What You’ll Learn:
- Why override train_step instead of writing a loop from scratch?
- Understanding how to handle trainable variables, non-trainable variables, and optimizer states in a functional environment.
- Creating a compute_loss_and_updates function to manage forward passes and auxiliary data.
- Using jax.value_and_grad to compute gradients and losses simultaneously.
- Updating evaluation metrics using stateless_update_state.

Chapters:
0:00 - Introduction & The Default model.fit()
0:18 - Customizing Keras Training Loops
0:46 - Overriding train_step()
1:14 - Setting up the JAX Backend
1:26 - The Stateless train_step
2:11 - Stateless Loss Computation
3:04 - Taking Gradients in train_step
4:06 - How to pass around non-trainable variables
4:43 - Updating the Model Weights
5:05 - Handling Metrics
5:21 - Custom Evaluation Loops (overriding test_step)

Resources:
Complete Code Example → https://goo.gle/4eeSvlD
Keras Documentation → https://goo.gle/42Ebpv0
Keras Developer Guides → https://goo.gle/4um97N3
Subscribe to Google for Developers → https://goo.gle/developers

Speaker: Yufeng Guo
Products Mentioned: Google AI
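The stateless pattern described above (a compute_loss_and_updates function that returns the loss plus auxiliary data, differentiated with jax.value_and_grad) can be sketched in plain JAX. This is only an illustration under stated assumptions, not the video's code: the linear model, the hand-written SGD update, the synthetic data, and the "steps" counter standing in for non-trainable state are all hypothetical. In a real Keras model, the forward pass and optimizer update would go through Keras's own stateless methods instead.

```python
import jax
import jax.numpy as jnp


def compute_loss_and_updates(trainable, non_trainable, x, y):
    """Forward pass that returns (loss, aux) so it can be used with has_aux=True."""
    # Hypothetical linear model standing in for a real Keras forward pass.
    y_pred = x @ trainable["w"] + trainable["b"]
    loss = jnp.mean((y_pred - y) ** 2)
    # Non-trainable state is not differentiated; its updated value is
    # handed back as auxiliary data (here, an illustrative step counter).
    new_non_trainable = {"steps": non_trainable["steps"] + 1}
    return loss, (y_pred, new_non_trainable)


@jax.jit
def train_step(state, data, learning_rate=0.1):
    """Pure function: takes the current state, returns the loss and the new state."""
    trainable, non_trainable = state
    x, y = data
    # value_and_grad differentiates with respect to the first argument only
    # and returns ((loss, aux), grads) when has_aux=True.
    grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
    (loss, (y_pred, non_trainable)), grads = grad_fn(trainable, non_trainable, x, y)
    # Plain SGD update (an assumption; any optimizer could be swapped in).
    trainable = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, trainable, grads
    )
    return loss, (trainable, non_trainable)


# Synthetic regression data: y = 2*x0 - 1*x1 + 0.5
x = jax.random.normal(jax.random.PRNGKey(0), (32, 2))
y = x @ jnp.array([2.0, -1.0]) + 0.5

state = (
    {"w": jnp.zeros(2), "b": jnp.zeros(())},  # trainable variables
    {"steps": jnp.array(0)},                  # non-trainable variables
)
for _ in range(200):
    loss, state = train_step(state, (x, y))
```

Because train_step has no side effects, every piece of state that changes (weights and non-trainable variables here, plus optimizer and metric variables in the full Keras version) has to be passed in and returned explicitly.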
