CuTe DSL for JAX Developers: Writing Custom GPU Kernels in Python

NVIDIA Developer · Beginner · 🛠️ AI Tools & Apps · 1mo ago

Key Takeaways

Write custom GPU kernels in Python with CuTe DSL and call them from JAX

Original Description

CuTe DSL for JAX is a practical way to write custom high-performance GPU kernels in Python while keeping your workflow inside the JAX ecosystem. In this video, we’ll explore how CuTe DSL for JAX lets you build custom NVIDIA GPU kernels with CUTLASS CuTe DSL, then call them from JAX as if they were native JAX operations.

JAX already delivers excellent performance on NVIDIA GPUs, but there are times when you need more control than XLA can provide automatically. You might need a fused operation, a custom memory layout, a non-standard primitive, or a kernel designed for a very specific performance bottleneck. This tutorial walks through a hands-on workflow for writing those kernels in Python and integrating them cleanly into compiled JAX programs.

We’ll start with the mental model behind CUTLASS and CuTe, including how tensors, layouts, threads, blocks, and launch shapes fit together. Then we’ll move through practical examples, beginning with vector add and SAXPY, before building ReLU and a fused bias plus ReLU kernel to show why fusion can reduce kernel launches and memory traffic. We’ll also look at a tiled GEMM example to demonstrate tiling, launch configuration, and the limits of a simple educational implementation compared with highly optimized library calls.

The video also covers how custom CuTe kernels can compose with JAX features such as @jax.jit and multi-GPU sharding. You’ll see how cutlass.jax.cutlass_call bridges a CuTe launcher with JAX’s compiled execution model, allowing your custom kernel to run as part of a larger JAX program.

Finally, we’ll introduce Ahead-of-Time compilation with jax.export. You’ll learn how to export, serialize, deserialize, and run a JAX function that includes a CUTLASS custom call, including how symbolic shapes can make exported artifacts more reusable.

By the end, you’ll understand where CuTe DSL for JAX fits, when it is useful, and how to start building custom GPU kernels for fusion, special layouts, custom data movement, an…
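A plain-JAX reference is a useful thing to have before writing any CuTe DSL kernel, because it gives you something to check the kernel's output against. The sketch below is not CuTe DSL code and does not call `cutlass.jax.cutlass_call`. It only spells out, in ordinary `jax.numpy`, what the vector add, SAXPY, ReLU, and fused bias plus ReLU examples compute. All function names are hypothetical, and `traffic_bytes` is a rough back-of-the-envelope estimate. It assumes each unfused step runs as its own kernel, and it ignores caches and the small bias vector.

```python
import jax
import jax.numpy as jnp


def vector_add(x, y):
    # Elementwise sum of two equal-length vectors.
    return x + y


def saxpy(a, x, y):
    # SAXPY: out = a * x + y, with a scalar a and vectors x, y.
    return a * x + y


def relu(x):
    # Clamp negative values to zero.
    return jnp.maximum(x, 0)


def bias_relu_unfused(x, bias):
    # Called eagerly (no jax.jit), JAX dispatches each op separately,
    # so the intermediate h is materialized in device memory.
    h = x + bias
    return relu(h)


@jax.jit
def bias_relu_fused(x, bias):
    # Under jax.jit, XLA can fuse the add and the max into one kernel
    # that reads x and bias once and writes the result once.
    return jnp.maximum(x + bias, 0)


def traffic_bytes(rows, cols, itemsize, fused):
    # Hypothetical helper: rough global-memory traffic for an elementwise
    # (rows, cols) problem, ignoring caches and the small bias vector.
    n = rows * cols * itemsize
    # Unfused: read x, write h, read h, write out -> 4 * n bytes.
    # Fused:   read x, write out                  -> 2 * n bytes.
    return 2 * n if fused else 4 * n


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024, 256), dtype=jnp.float32)
bias = jnp.linspace(-1.0, 1.0, 256, dtype=jnp.float32)

out_ref = bias_relu_unfused(x, bias)
out_fused = bias_relu_fused(x, bias)
print(bool(jnp.allclose(out_ref, out_fused)))     # True
print(traffic_bytes(1024, 256, 4, fused=False))   # 4194304
print(traffic_bytes(1024, 256, 4, fused=True))    # 2097152
print(saxpy(2.0, jnp.ones(3), jnp.arange(3.0)))   # [2. 3. 4.]
```

For a simple chain like this, wrapping `bias_relu_unfused` in `jax.jit` would normally let XLA fuse the two elementwise ops by itself. A hand-written fused kernel pays off for patterns that XLA does not fuse automatically.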
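The tiled GEMM example rests on a blocking idea. C = A·B is split into output tiles, and each tile is accumulated over slices of the shared K dimension. Below is a minimal sketch of that loop structure in plain JAX. It is not the video's CuTe kernel. It assumes tile sizes that divide the matrix dimensions evenly, and `tiled_matmul` and its tile-size parameters are hypothetical names. It is meant for understanding only and is far slower than simply calling `jnp.dot`.

```python
import jax
import jax.numpy as jnp

# Request full-fp32 matmuls so the tiled and reference results are comparable.
HIGHEST = jax.lax.Precision.HIGHEST


def tiled_matmul(a, b, tile_m=32, tile_n=32, tile_k=32):
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    # Assumption: tile sizes divide the matrix sizes evenly (no edge handling).
    assert m % tile_m == 0 and n % tile_n == 0 and k % tile_k == 0
    c = jnp.zeros((m, n), dtype=jnp.float32)
    # In a typical tiled GPU GEMM, each (i, j) output tile is one block's work.
    for i in range(0, m, tile_m):
        for j in range(0, n, tile_n):
            acc = jnp.zeros((tile_m, tile_n), dtype=jnp.float32)
            # Walk the shared K dimension one tile at a time, accumulating.
            for kk in range(0, k, tile_k):
                a_tile = a[i:i + tile_m, kk:kk + tile_k]
                b_tile = b[kk:kk + tile_k, j:j + tile_n]
                acc = acc + jnp.dot(a_tile, b_tile, precision=HIGHEST)
            # JAX arrays are immutable, so write the finished tile functionally.
            c = c.at[i:i + tile_m, j:j + tile_n].set(acc)
    return c


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 64), dtype=jnp.float32)
b = jax.random.normal(key_b, (64, 96), dtype=jnp.float32)

c_tiled = tiled_matmul(a, b)
c_ref = jnp.dot(a, b, precision=HIGHEST)
print(c_tiled.shape)                                              # (128, 96)
print(bool(jnp.allclose(c_tiled, c_ref, rtol=1e-4, atol=1e-4)))   # True
```

Real GPU GEMM kernels also stage tiles through shared memory and divide each output tile among the threads of a block. This sketch shows only the loop nesting.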
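The ahead-of-time steps mentioned above (export, serialize, deserialize, run, with symbolic shapes) follow the standard `jax.export` workflow. The description does not give the signature of `cutlass.jax.cutlass_call`, so the sketch below uses a plain-JAX fused bias plus ReLU as a stand-in for a function that would contain the CUTLASS custom call. It assumes a recent JAX release that ships `jax.export`. The function name and shapes are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export


def bias_relu(x, bias):
    # Plain-JAX stand-in for a function that would wrap a CuTe kernel.
    return jnp.maximum(x + bias, 0)


# Leave the batch dimension symbolic; keep the feature dimension fixed at 128.
x_shape = export.symbolic_shape("batch, 128")
x_spec = jax.ShapeDtypeStruct(x_shape, jnp.float32)
bias_spec = jax.ShapeDtypeStruct((128,), jnp.float32)

# Trace and lower once, ahead of time, for any batch size.
exported = export.export(jax.jit(bias_relu))(x_spec, bias_spec)

# Serialize to a bytearray (e.g. to save to disk), then restore it.
blob = exported.serialize()
restored = export.deserialize(blob)

# Run the restored artifact with batch sizes that were not fixed at export time.
bias = jnp.linspace(-1.0, 1.0, 128, dtype=jnp.float32)
outputs = {}
for n in (4, 32):
    x = jnp.arange(n * 128, dtype=jnp.float32).reshape(n, 128) / 100.0 - 1.0
    outputs[n] = restored.call(x, bias)
    print(n, outputs[n].shape)
```

Because `blob` is a `bytearray`, it can be written to a file and later read back with `bytearray(f.read())` before it is passed to `export.deserialize`.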
Watch on YouTube
