FlexAttention: PyTorch Compiler Series
Key Takeaways
Demonstrates FlexAttention, a novel compiler-driven programming model for implementing attention variants in PyTorch
Full Transcript
All right, hi everybody. You might have heard of FlexAttention, a pretty popular feature that came out of the PyTorch team recently. Today Boyuan, one of the authors of FlexAttention, is going to introduce it to us. Take it away.

Okay, cool. Thanks for the introduction. I'm Boyuan, and I work on performance in the PyTorch compiler, especially FlexAttention. Let's start.

We all know attention is very important. Attention takes the query, key, and value: we compute a matmul between Q and K, apply a softmax, and finally multiply by V. Attention accounts for a significant portion of transformer latency, which is why we have FlashAttention. FlashAttention is the most important optimization for attention: it fuses everything into a single kernel, and on the right-hand side, comparing the latency of the PyTorch implementation against FlashAttention, we can see FlashAttention is much faster.

People want different attention variants for many reasons. Start with the causal mask: a query token can only attend to past key tokens, just as, when we read a sentence, we rely on the past words to infer the next word. After that, people introduced the sliding window mask: I don't need to refer to such a long context, because some recent context might be sufficient. We also have the document mask: we pack multiple documents together and only attend to tokens within the same document. So there are many different attention variants. People may also want to encode the relative position of two tokens into the attention, or use the ALiBi bias, which multiplies this relative distance by some other trainable or fixed bias.
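For reference, the computation the talk starts from (QK^T, softmax, multiply by V) can be written as a short pure-Python loop for one head. This is a minimal sketch, assuming the usual 1/sqrt(d) scaling that the spoken summary leaves out; it computes every score explicitly, which is the work FlashAttention fuses into a single kernel and FlexAttention later lets you modify or skip.

```python
import math

def softmax(row):
    # Subtract the row max for numerical stability; exp(-inf) == 0.0 handles masked entries.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Reference single-head attention: softmax(Q K^T / sqrt(d)) V.
    q: [Lq][d], k: [Lk][d], v: [Lk][dv], all nested lists of floats."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        # One row of the score matrix: scaled dot(q_i, k_j) for every key j.
        scores = [scale * sum(x * y for x, y in zip(qi, kj)) for kj in k]
        probs = softmax(scores)
        # Probability-weighted sum of the value rows.
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out
```

With two identical keys the probabilities split evenly, so the output is the mean of the value rows.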
So we have lots of different attention variants, but it's very hard to hand-tune kernels for all of them. Often, when machine learning researchers propose something new, they don't have the expertise to tune the kernel manually. So they find a GPU expert and ask them to spend one or two weeks writing a good kernel. Then they try the new variant with this kernel, and sometimes it doesn't work, so the GPU expert's two weeks of work are wasted. It's very hard to work that way, and that's why we want FlexAttention. FlexAttention adds a score_mod function that modifies the attention scores before the softmax, so you can use attention in very different ways. What is a score modification? It takes the score, an entry of the original score matrix, and, given the batch index, the head index, the query index, and the KV index, decides how to modify that score. It's very flexible. Cool. So what can we implement with score_mod? Let's look at relative position.
We can just return the score plus the query index minus the KV index, which is the relative distance. Now we have relative position information. The good thing is that we don't need to materialize the relative position bias as a full tensor, so we avoid quadratic memory usage and get better performance. Compared with scaled dot-product attention (SDPA) in PyTorch, FlexAttention performs much better for this reason. Another example is the ALiBi bias: we have an ALiBi bias tensor of shape [num_heads], and in the score_mod we index this bias by head and multiply it by the relative distance. So with only two lines of code we can implement ALiBi.

Let me ask a question. It seems to me there are two interesting things here. One is that you started with a handwritten FlashAttention kernel and then talked about many variants of it, so one thing you're trying to address is that people should not need to write hand-tuned kernels for all of these variants, especially since more variants will come in the future. But the other, arguably equally interesting, observation is that modifying one part of the attention mechanism is general enough not only to model many different variants, but you actually get performance improvements over some bad ways of doing the equivalent thing. Is that fair?

Yeah, that's true. In many cases an attention variant is just an element-wise modification, so with FlexAttention we just modify the score in registers.
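The score_mod contract and the two examples above can be modeled in plain Python. This is a minimal sketch, not FlexAttention itself: `attention_with_score_mod` is a hypothetical single-head reference loop, the callback signature `(score, b, h, q_idx, kv_idx)` mirrors the one FlexAttention uses, and the ALiBi slope values are made up for illustration.

```python
import math

def attention_with_score_mod(q, k, v, score_mod, b=0, h=0):
    """Hypothetical single-head reference attention with a FlexAttention-style hook.
    score_mod(score, b, h, q_idx, kv_idx) rewrites each entry of the score
    matrix after the scaled QK^T and before the softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, qi in enumerate(q):
        scores = [score_mod(scale * sum(x * y for x, y in zip(qi, kj)), b, h, q_idx, kv_idx)
                  for kv_idx, kj in enumerate(k)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z for c in range(len(v[0]))])
    return out

def relative_positional(score, b, h, q_idx, kv_idx):
    # Add the signed query-key distance; no L x L bias tensor is ever stored.
    return score + (q_idx - kv_idx)

# Hypothetical per-head ALiBi slopes (illustrative values only, one per head).
ALIBI_SLOPES = [0.5, 0.25, 0.125, 0.0625]

def alibi(score, b, h, q_idx, kv_idx):
    # Head-specific slope times relative distance: keys farther in the past get a larger penalty.
    return score + ALIBI_SLOPES[h] * (kv_idx - q_idx)
```

Because each score_mod only sees one scalar plus its indices, a compiler can inline it into the fused kernel instead of reading a precomputed L × L bias from memory.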
We don't need to materialize a huge tensor for the modification, and that's why it's faster. Okay, cool.

Another example is soft-capping, a technique introduced in Gemma 2 and Grok-1 that prevents logits from growing too large; we can also implement it in about four lines of code.

Now back to our causal mask example. How do we implement it? We can say: if the query index is greater than or equal to the KV index, keep the score; otherwise, mask it off as minus infinity. This gives the correct result, but it's not so fast. Why? For the upper triangle we know the value is minus infinity anyway, so why should we still compute it? We can just skip it and make it faster. That's why we want another function called mask_mod. A mask_mod takes the batch index, head index, query index, and KV index and returns a boolean: whether I want this position or not. The difference from score_mod is that it doesn't take the score anymore, because it only decides true or false, keep it or not. Once we have the mask_mod, we can automatically compute a block mask to exploit the sparsity. To implement a causal mask, we just return q_idx >= kv_idx. That's it. We pass it into create_block_mask, an API from FlexAttention that creates the block mask, and later, when we call flex_attention, we provide the query, key, and value tensors and additionally this block mask.

So in this case you do not implement the score_mod function?

Right. For this causal mask example we don't need a score_mod, because every kept element uses its score unchanged.

But if you wanted to combine causal masking with something else, then you could?

Yes, you just additionally provide the score_mod.

Nice.
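The difference between expressing causality as a score_mod and as a mask_mod, plus the soft-capping example, can be sketched with plain Python callbacks. The function names and the `SOFTCAP` value are illustrative assumptions. In PyTorch itself the mask_mod would be passed to `create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN)` and the result given to `flex_attention(query, key, value, block_mask=...)` from `torch.nn.attention.flex_attention`; the sketch below only models the callbacks.

```python
import math

SOFTCAP = 20.0  # hypothetical cap value; each model picks its own

def soft_cap(score, b, h, q_idx, kv_idx):
    # Smoothly squash logits into (-SOFTCAP, SOFTCAP); nearly the identity for small scores.
    return SOFTCAP * math.tanh(score / SOFTCAP)

def causal_via_score_mod(score, b, h, q_idx, kv_idx):
    # Correct, but every masked entry is still computed before being set to -inf.
    return score if q_idx >= kv_idx else -math.inf

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod: no score argument, only "keep this position or not".
    return q_idx >= kv_idx

def as_score_mod(mask_mod):
    # A mask_mod is the special case of a score_mod that only keeps or drops scores.
    def score_mod(score, b, h, q_idx, kv_idx):
        return score if mask_mod(b, h, q_idx, kv_idx) else -math.inf
    return score_mod
```

The key property is that `causal_mask` never reads the score, so it can be evaluated on indices alone before any QK^T work, which is what makes skipping whole blocks possible.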
So how about the performance? In this comparison, if we implement the causal mask with score_mod, that's the latency we get, and if we use mask_mod, it's much faster because much of the computation can be skipped.

A few more examples. We can implement sliding window attention: the window mask takes the difference between the query index and the KV index and compares it with the sliding window size; if the key is too far away, we don't use it. We can also compose the causal mask with the window mask using a simple element-wise AND. Here is some performance data for the sliding window mask: we compare causal FA2, SDPA with a mask, and FlexAttention, and FlexAttention performs quite well. Another example is the document mask: we have a document ID for every query and KV index, and a single line of code implements the document mask. We have more examples as well.

Okay, sounds good. Then how does it work?

Here is a detailed example of the block mask. We have a score_mod, which is the relative position, and a mask_mod, which is a sliding window mask; you can plug in whatever score_mod or mask_mod you want. Q multiplied by K gives us the score matrix. There are two types of blocks: one is a full block and the other is a partial block.
A full block means all the data within the block will be used. A partial block lies on the boundary: some elements in the block are used and some are not, because we're using a sliding window. There's one boundary line here and another here; beyond them we don't need the data, and between them we do. That's why we have partial blocks and full blocks. For the full blocks, we record which block indices to use for each row. For example, the first row has nothing, so we record nothing; the second row only uses the first block, so we record 0; the third row uses 0 and 1; and so on. We store the same kind of data for the partial blocks. This way we encode the sparsity of the mask_mod and use it at runtime.

Okay, cool. A little more detail on how the compilation works. We've defined the causal mask and the relative-position score_mod. We use torch.compile to trace these two functions and lower them to generate Triton code for the mask and the score. At the same time, we've built a very efficient attention template that leaves two slots for the generated mask_mod and score_mod Triton code, which modify the score matrix at runtime. This way we can generate very efficient code.

Quick question. You showed a document mask example where the functions rely on something outside, the document ID, right? What happens in that scenario when you trace through it? Where is the document ID stored?

That one is a captured tensor. It's captured within, say, the mask_mod graph.

I see. Okay.

Here is an example of the Triton attention template.
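How a mask_mod turns into full and partial blocks can be reproduced on a toy grid. This is a minimal sketch under stated assumptions: `build_block_mask` is a hypothetical helper (the real `create_block_mask` builds tensors on the device and keeps roughly analogous count and index tensors for partial and full blocks), `and_masks` here is a hand-rolled stand-in for the library function of the same name, and the window size, block size, and document IDs are made-up values.

```python
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

WINDOW = 3  # hypothetical sliding-window size

def window(b, h, q_idx, kv_idx):
    # Drop keys more than WINDOW positions behind the query.
    return q_idx - kv_idx <= WINDOW

def and_masks(*mask_mods):
    # Compose masks with an element-wise AND.
    return lambda b, h, q_idx, kv_idx: all(m(b, h, q_idx, kv_idx) for m in mask_mods)

# Hypothetical packed sequence: tokens 0-2 are doc 0, 3-4 doc 1, 5-7 doc 2.
DOCUMENT_ID = [0, 0, 0, 1, 1, 2, 2, 2]

def document(b, h, q_idx, kv_idx):
    # Reads a captured lookup table: a query only sees keys from its own document.
    return DOCUMENT_ID[q_idx] == DOCUMENT_ID[kv_idx]

def build_block_mask(mask_mod, q_len, kv_len, block):
    """For each query block (row), list the KV blocks that are FULL (every
    element kept) and PARTIAL (some kept); empty tiles are never recorded."""
    full, partial = [], []
    for qb in range(0, q_len, block):
        row_full, row_partial = [], []
        for kb in range(0, kv_len, block):
            kept = [mask_mod(0, 0, q, kv)
                    for q in range(qb, min(qb + block, q_len))
                    for kv in range(kb, min(kb + block, kv_len))]
            if all(kept):
                row_full.append(kb // block)     # mask_mod can be skipped for this tile
            elif any(kept):
                row_partial.append(kb // block)  # mask_mod applied element-wise
        full.append(row_full)
        partial.append(row_partial)
    return full, partial
```

For the sliding-window-plus-causal mask on 8 tokens with blocks of 2, every query block after the first gets exactly one full block on the band, partial blocks sit on the band's edges, and all remaining tiles are skipped.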
FlexAttention can actually support different backends: there's a CPU backend built by the Intel folks, and we can also support, for example, a CUDA backend or other kernel templates. As long as people build the template, we can run FlexAttention in different languages or on other hardware.

Okay, sounds good.

So we have forward, backward, and inference, which covers many different use cases for attention. I want to talk a bit more about inference because it differs from the forward pass. During inference there's a unique attention pattern: a very short query attends to a super long KV cache, whose length could be 128K or even longer. Why is the query sequence length one? Because during inference we generate tokens one by one: at each step, a single query token attends to all the past data, and we predict just one token. So what's the issue? Because there is only a single query token, we can no longer parallelize along the query dimension. But the KV length is very long, so there's a parallelism opportunity there. That's why we have flash decoding. In this example, we split along the KV dimension into five splits, do the computation for every split in parallel, accumulate into five different partial outputs, and finally have a single reduction step to get the final output. This way we exploit parallelism along the KV length dimension.

How does this compare with... I was recently hearing some vLLM talks, and they mentioned batched inference but also a bunch of pipelining optimizations that feel similar. Is there overlap between these optimizations?
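The split-KV scheme described just before the question can be illustrated for a single query token. This pure-Python sketch stands in for a GPU kernel under obvious simplifying assumptions: each chunk's partial result would be computed by a separate thread block, and the final loop plays the role of the single reduction step. The detail that makes it exact is the log-sum-exp style rescaling, which merges partial softmaxes computed with different local maxima.

```python
import math

def decode_attention_split_kv(q, k, v, num_splits):
    """Flash-decoding-style attention for ONE query vector q.
    The KV sequence is cut into up to `num_splits` chunks; each chunk yields
    (local max, local sum of exps, unnormalized output) independently."""
    scale = 1.0 / math.sqrt(len(q))
    chunk = math.ceil(len(k) / num_splits)
    partials = []
    for start in range(0, len(k), chunk):
        ks, vs = k[start:start + chunk], v[start:start + chunk]
        scores = [scale * sum(x * y for x, y in zip(q, kj)) for kj in ks]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        acc = [sum(e * vj[c] for e, vj in zip(exps, vs)) for c in range(len(v[0]))]
        partials.append((m, sum(exps), acc))
    # Reduction step: rescale every chunk to the global max, then normalize once.
    g = max(m for m, _, _ in partials)
    total = sum(s * math.exp(m - g) for m, s, _ in partials)
    return [sum(acc[c] * math.exp(m - g) for m, _, acc in partials) / total
            for c in range(len(v[0]))]
```

With `num_splits=1` this reduces to ordinary attention, so any other split count should agree with it up to floating-point rounding.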
I think there are many different types of parallelism, like context parallelism and pipeline parallelism, and they work at different levels. This one is more like a kernel for a single decoding step of the attention op. There can be different kinds of parallelism at the model level.

Okay, so here you're talking about what kinds of optimizations go into generating the kernel.

Exactly. FlexAttention is at the kernel level; it does not handle the model.

Makes sense.

So that's flash decoding. We have FlexDecoding, which uses the same optimization but still lets you apply a score_mod for flexibility. A bit on the benchmark: we compare with FAKV, which is flash decoding, and FlexAttention and flash decoding have similar performance. We also compare with scaled dot-product attention, which is much slower than FlexAttention. We also integrated it into gpt-fast for model-level benchmarks. We tried Llama 3.1 8B and 70B and see a significant speedup across various lengths.

Okay, after that I want to talk about paged attention, because paged attention is very important for inference. It was first done by the vLLM team. During inference we have the KV cache, and we have multiple requests, say A and B. We do inference for every request; some requests have long sequences, some have shorter ones, and so on. Naively, we allocate the logical KV cache as a regular tensor, so we waste lots of memory when a sequence is short. That's why we can have a better KV cache: a physical KV cache that is, in some sense, a single one-dimensional tensor where everything is compacted, so no memory is wasted. This is good, right?
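The logical-versus-physical layout above, together with the page table introduced next, can be sketched as a toy allocator. `PagedKVCache`, `PAGE_SIZE`, and the method names are hypothetical; this is not vLLM's or PyTorch's implementation, only the bookkeeping idea: pages come from a shared pool on demand, so a short request holds only the pages it actually uses.

```python
PAGE_SIZE = 4  # tokens per physical page (hypothetical)

class PagedKVCache:
    """Toy paged KV cache: one shared pool of physical pages plus a
    per-request page table mapping logical page -> physical page."""

    def __init__(self, num_pages):
        self.free_pages = list(range(num_pages))
        self.page_table = {}  # request id -> list of physical page ids
        self.lengths = {}     # request id -> number of tokens stored

    def append_tokens(self, req, n):
        length = self.lengths.get(req, 0) + n
        pages = self.page_table.setdefault(req, [])
        # Grab a new physical page only when the request outgrows its current ones.
        while len(pages) * PAGE_SIZE < length:
            pages.append(self.free_pages.pop(0))
        self.lengths[req] = length

    def physical_slot(self, req, logical_pos):
        # Translate a logical token position into (physical page, offset in page).
        page = self.page_table[req][logical_pos // PAGE_SIZE]
        return page, logical_pos % PAGE_SIZE
```

Because pages are handed out as requests grow, request A's pages need not be contiguous; the page table is what lets a kernel find them.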
To manage the mapping from the logical KV cache to the physical KV cache, we use a page table that says, for request A's first block, where to find it in the physical KV cache. This is paged attention, which is very useful for saving memory and enabling larger batch sizes.

We want FlexAttention to work with paged attention. The first thing we need is block mask conversion. Remember that in FlexAttention, the block mask says which block indices to use for every row and batch index. We also want a counterpart for the physical KV cache; then we can apply it as indirect indexing and make it work with FlexAttention. That's why the first step is block mask conversion. Let me give an example. For request A, we only have values for blocks 1, 2, and 4; blocks 0 and 3 are masked off, so we don't need them. That's why the KV index, which is part of the block mask, lists 1, 2, and 4 for request A, and we don't use 0 and 3. Because we have the page table mapping the logical KV cache to the physical KV cache, we can map this KV index onto the physical KV cache as a converted KV index. Then we use the converted KV index for FlexAttention.

So this is another primitive that you provide in the library, so that if you wanted to convert the block mask you could?

Yes, exactly. People can just call convert_logical_block_mask to do this transformation; they don't need to care about the details. So that's the first step, block mask conversion. After that we also need mask_mod conversion. Remember, we have the mask_mod.
It takes the batch index, head index, query index, and KV index and says whether I want to keep that position or not. After the conversion, we're computing in the converted physical space, not the original space, so we also need to convert the mask_mod. We take the original mask_mod and return a converted mask_mod. The converted mask_mod takes the same batch index, head index, and query index, but a physical KV index. We then do some index computation to recover the original logical KV index, and finally call the original mask_mod on this logical KV index. This automatic conversion makes sure the result is still correct, and users don't need to worry about handling this complexity. Similarly, we have score_mod conversion to convert the score_mod.

Okay, cool. A bit on the benchmark side: this compares FlexAttention, and FlexAttention with paged attention, against FlashAttention v2. FlexAttention with paged attention performs similarly to FlexAttention and to FlashAttention, which means there isn't much latency overhead; we observed less than 5% latency overhead across these sequence lengths. I also have a small serving example showing that we can get a 76× higher batch size with the paged attention technique.

So that means if you're not expecting a lot of memory overhead, you would use FlexAttention, but if you're constrained on memory or want to support larger batch sizes, you would use the paged attention variant, and you wouldn't pay too much for it.

Yes. In a real-world serving scenario, different requests have very different sequence lengths.
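Both conversions described above (block-mask indices and the mask_mod) come down to index arithmetic through the page table. The talk names a library helper for the block-mask step; the functions below are hypothetical stand-ins in plain Python, assuming the page size equals the block-mask block size and using nested lists instead of tensors.

```python
def convert_kv_indices(kv_indices, page_table):
    """kv_indices[b][row]: LOGICAL KV block indices that query-block `row` of
    request b attends to. page_table[b][logical_block] -> physical block.
    Returns the same nested structure expressed in PHYSICAL block indices."""
    return [[[page_table[b][blk] for blk in row] for row in rows]
            for b, rows in enumerate(kv_indices)]

def convert_mask_mod(mask_mod, page_table, block_size):
    """Wrap a logical mask_mod so it accepts a PHYSICAL kv index instead."""
    # Invert each request's page table: physical block -> logical block.
    physical_to_logical = [{phys: log for log, phys in enumerate(row)} for row in page_table]

    def converted(b, h, q_idx, physical_kv_idx):
        phys_block, offset = divmod(physical_kv_idx, block_size)
        logical_block = physical_to_logical[b].get(phys_block)
        if logical_block is None:
            return False  # page belongs to another request or is unused
        # Rebuild the logical kv index and defer to the user's original mask_mod.
        return mask_mod(b, h, q_idx, logical_block * block_size + offset)

    return converted
```

A score_mod can be wrapped the same way; only the final call changes, forwarding the score along with the recovered logical index.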
So that's why we combine paged attention and FlexAttention.

Okay, we've seen some early adoption and fortunately gotten some feedback. Several teams are using FlexAttention; one project used FlexAttention to reproduce Llama 2 7B overnight, and the NanoGPT training speedrun uses FlexAttention. FlexAttention has also been integrated into Hugging Face, we've seen it used for diffusion model training, and the FlashInfer folks were also inspired by FlexAttention to build some super cool things. So we fortunately have some early adoption; if you are interested, definitely try it.

That's all for my talk. Thanks for your attention. We have online blog posts on FlexAttention and on FlexAttention for inference, and many examples in Attention Gym. We also recently have an MLSys 2025 paper discussing some of the technical details. Thanks.

Cool, very nice talk. So what's next for this project? Do you think you're done with FlexAttention, in the sense that the primitives are general enough to capture future variants, or are you working on other abstractions in other areas inspired by FlexAttention's success?

I would say FlexAttention is pretty flexible now. One major thing to do is performance: we can incorporate new GPU features into FlexAttention to make it even faster. Another is supporting more use cases; as more and more people use it, we may find new optimization opportunities.

Okay, cool. All right, thanks again, Boyuan. That was a lovely talk.

Thank you. All right, bye-bye.
Original Description
FlexAttention is a novel compiler-driven programming model that allows implementing the majority of attention variants in a few lines of idiomatic PyTorch code. In this talk, we demonstrate that many existing attention variants can be implemented via FlexAttention and that it achieves competitive performance compared to handwritten kernels.
Speaker:
Boyuan Feng
Boyuan is a PyTorch core compiler developer working on Inductor, CUDA Graphs, and FlexAttention.
PyTorch Compiler Series
In this video series, watch the PyTorch Compiler team share tips and tricks that help you get the max out of torch.compile, torch.export, and related technologies, while enjoying a glimpse into all the cool engineering work that goes on behind the scenes.