Keras 3 Distributed Training: Scaling Models with JAX using DataParallel and ModelParallel

Google for Developers · Beginner · ML Fundamentals · 3 weeks ago
Training large deep learning models doesn't have to be complex. In this video, Yufeng Guo walks you through the Keras 3 Distribution API, showing you how it leverages JAX for efficient data and model parallelism. Whether you're scaling across multiple GPUs or a cluster of TPUs, Keras 3 has you covered.

Resources:
- Distributed training with Keras 3 → https://goo.gle/4u8nGo9
- Multi-device distribution → https://goo.gle/46CFOMX
- LayoutMap API → https://goo.gle/3NfJXjd
- Gemma get_layout_map → https://goo.gle/4smwNzM
Watch on YouTube ↗

Chapters (10)

0:00 Intro
0:17 The Keras 3 Distribution API
0:51 The Global Programming Model (SPMD Expansion)
1:26 Using the JAX Backend for Scalability
1:55 Creating a Device Mesh & Tensor Layout
2:46 Implementing Data Parallelism
3:45 Understanding Model Parallelism
4:27 Sharding with LayoutMap
5:43 Tuning Your Device Mesh for Performance
6:14 Conclusion & Next Steps