Flash Attention derived and coded from first principles with Triton (Python)
In this video, I'll be deriving and coding Flash Attention from scratch.
I'll be deriving every operation we do in Flash Attention using only pen and "paper". Moreover, I'll explain CUDA and Triton from zero, so no prior knowledge of CUDA is required. To code the backwards pass, I'll first explain how the autograd system works in PyTorch and then derive the Jacobian of the matrix multiplication and the Softmax operation and use it to code the backwards pass.
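The softmax Jacobian that the backwards pass relies on has the closed form J = diag(s) - s sᵀ. As a companion to that derivation, the sketch below (pure Python, illustrative only — not code from the video) checks the closed form against a finite-difference approximation:

```python
import math

def softmax(x):
    # Numerically stable ("safe") softmax: subtract the max before exponentiating.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(s):
    # Closed form: J[i][j] = s[i] * (delta_ij - s[j]).
    n = len(s)
    return [[s[i] * ((1.0 if i == j else 0.0) - s[j]) for j in range(n)]
            for i in range(n)]

def finite_diff_jacobian(x, eps=1e-6):
    # Central differences, one input coordinate (column) at a time.
    n = len(x)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        sp, sm = softmax(xp), softmax(xm)
        for i in range(n):
            J[i][j] = (sp[i] - sm[i]) / (2 * eps)
    return J

x = [0.5, -1.0, 2.0]
J_exact = softmax_jacobian(softmax(x))
J_approx = finite_diff_jacobian(x)
err = max(abs(J_exact[i][j] - J_approx[i][j])
          for i in range(len(x)) for j in range(len(x)))
assert err < 1e-6
```

A useful sanity check on the closed form: each column of J sums to zero, because softmax outputs always sum to one, so any perturbation of the input leaves that sum unchanged.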
All the code will be written in Python with Triton, but no prior knowledge of Triton is required. I'll also explain the CUDA programming…
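As a taste of one core idea (see the "Online Softmax" chapter), here is the running-max/running-denominator recurrence in plain Python. This is only an illustrative sketch of the math, not the Triton kernel from the video:

```python
import math

def naive_softmax(x):
    # Reference implementation: safe softmax, shifted by the global max.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax(x):
    # Single streaming pass: maintain the running max m and the running
    # denominator d, rescaling d whenever the max changes.
    m, d = float("-inf"), 0.0
    for v in x:
        m_new = max(m, v)
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    # The second pass just normalizes; Flash Attention instead fuses this
    # rescaling into the output accumulation so no extra pass is needed.
    return [math.exp(v - m) / d for v in x]

x = [1.0, -2.0, 3.0, 0.5]
assert max(abs(a - b) for a, b in zip(online_softmax(x), naive_softmax(x))) < 1e-12
```

The point of the recurrence is that it never needs to see the whole row at once, which is what lets Flash Attention process the attention scores block by block.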
Watch on YouTube ↗
Chapters (23)
Introduction - 3:10
Multi-Head Attention - 9:06
Why Flash Attention - 12:50
Safe Softmax - 27:03
Online Softmax - 39:44
Online Softmax (Proof) - 47:26
Block Matrix Multiplication - 1:28:38
Flash Attention forward (by hand) - 1:44:01
Flash Attention forward (paper) - 1:50:53
Intro to CUDA with examples - 2:26:28
Tensor Layouts - 2:40:48
Intro to Triton with examples - 2:54:26
Flash Attention forward (coding) - 4:22:11
LogSumExp trick in Flash Attention 2 - 4:32:53
Derivatives, gradients, Jacobians - 4:45:54
Autograd - 5:00:00
Jacobian of the MatMul operation - 5:16:14
Jacobian through the Softmax - 5:47:33
Flash Attention backwards (paper) - 6:13:11
Flash Attention backwards (coding) - 7:21:10
Triton Autotuning - 7:23:29
Triton tricks: software pipelining - 7:33:38
Running the code
DeepCamp AI