Day 1 Talks: JAX, Flax & Transformers 🤗
Key Takeaways
This video covers the basics of JAX, Flax, and Transformers, with a focus on scalable research, compiler-based systems, and high-speed interconnects. The speakers discuss the use of JAX with Cloud TPU VMs and the XLA compiler, and Flax for RL with the Dopamine library.
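As a rough companion to the first talk, here is a minimal sketch of three of the JAX function transformations it demonstrates (`grad`, `jit`, and `vmap`), plus `make_jaxpr` for inspecting what JAX traces. The toy functions `f` and `square` are illustrative assumptions, not the functions used in the demo notebook, and the snippet should run on CPU, GPU, or TPU without changes.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy function, chosen only so the derivatives are easy to check.
    return x ** 3 + 2.0 * x


# grad takes a function and returns a new function that computes its gradient.
df = jax.grad(f)               # 3 * x**2 + 2
# Transformations compose: grad of grad gives the second derivative.
d2f = jax.grad(jax.grad(f))    # 6 * x

# jit changes how the function is executed (compiled with XLA), not what it computes.
fast_f = jax.jit(f)


def square(x):
    # Written for a single scalar; vmap adds the batch dimension.
    return x * x


batched_square = jax.vmap(square)

xs = jnp.arange(4.0)
print(df(2.0))             # 3 * 2**2 + 2
print(d2f(2.0))            # 6 * 2
print(fast_f(2.0))         # same value as f(2.0)
print(batched_square(xs))  # square applied to every element of xs

# make_jaxpr shows the traced program; here, a dot mapped over a leading batch axis.
a = jnp.ones((3, 4))
print(jax.make_jaxpr(jax.vmap(jnp.dot))(a, a))
```

Each transformation takes a function and returns a new function, which is why they can be stacked in any order, as the talk describes.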
Full Transcript
thanks to suzana and thanks everyone for coming um like suzana said i'm a software developer on jax and i'm super excited to kick off these talks for the community week and get you started with jax on cloud tpus so this talk is mostly going to consist of a demo going over some jax basics running on a cloud tpu vm and this might be a terrible idea but i'm gonna try to set up this demo completely from scratch to hopefully show you uh just how easy it is to get started with jax on tpu so wish me luck um so with that i'm actually going to dive right in and create a demo vm that we can use so we can have that starting while we go over some slides so i'm going to switch over to my gcp dashboard and go into the compute engine tpus tab and then i already have some tpus running but i'm going to create a new one for this demo by hitting create tpu node um that failed to load great this is how you know it's a live demo great so i am going to call this demo vm pick my favorite zone us-central1-a we are going to opt into the new tpu vm architecture i'll talk more about this in a minute um we are going to stick to a modest eight uh tpu cores v3 tpu cores and pick v2-alpha this is always what you choose when creating a jax tpu vm so let's get that creating okay sometimes this takes a minute or two so i am gonna in the meanwhile switch over to some slides and give you some basics on tpus um and how you can use them with jax so to start um just in case you're not familiar what are cloud tpus so pictured here is a single tpu v3 board and the board has four individual chips on it and each tpu chip has two cores for a total of eight tpu cores and then this board is just plugged into a regular cpu host very similarly to how you might plug a gpu into a machine um so what makes the cloud tpu system special and more than just a box with some chips attached to it are the two key features you can see on this slide so the first key feature is the high-speed interconnects which are
basically high-bandwidth communication links between the cores so the cores can talk directly to each other without going through the host um and the second is basically the power of compilers so tpus were designed to be targeted by a compiler specifically by the xla compiler so if you compare this to other platforms that rely more on uh pre-compiled kernels that are invoked in something more like an interpreter loop um when you run on a tpu everything is compiled and this lets us drive the tpu cores more effectively okay so what makes cloud tpus or tpus in general really cool is that they're also designed to be scaled out into a tpu pod supercomputer so now instead of just four chips on a single board a cloud tpu pod is 1024 chips for a total of 2048 cores attached to hundreds of cpu hosts um and this is really scalable because we have those same two key features so the same high-speed interconnects that connected those four chips now connect all the chips in the entire pod uh and similarly the xla compiler can also scale to cover this entire pod so you can use xla to compile for a single core a single board or the entire pod or anything in between okay so what do these two key features actually get you um so with the high-speed interconnects you basically get easy communication scaling uh you don't need to worry about the data center network topology everything kind of just works as you scale it out um and with the compiler magic you get a ton of cool compiler optimizations uh a few of which are listed here but there are many more i'm sure um uh but another cool thing about using a compiler-based system for research in particular is that often when you're doing research you're writing new code or new algorithms that no one has ever tried before that's what makes it research so there's probably not an efficient handwritten kernel for all of the things you're trying to do so instead of writing your own kernel or maybe waiting for someone else to catch on and write
a kernel you can let the compiler do that work for you and all the code you write can be optimized and super efficient okay so we think that jax is a particularly good fit for tpus because jax itself is compiler oriented um and this means that jax was designed from the ground up around the xla compiler and then you know xla was designed around tpus so jax and tpus are a good fit um the implications of being compiler oriented are a few things listed here the first is that all operations uh run through the compiler so you always get that compilation benefit for everything you do in jax um so besides the nice performance benefit of this another nice feature which is the second bullet point is that um all jax code can be run on cpu or tpu with no code changes just by targeting different platforms using xla uh so there are a few exceptions to this like edge-case compiler features that just haven't been implemented on all platforms yet but so this is more of an implementation issue than a fundamental design or api issue and then finally being compiler oriented means that you have full control over the compiler so we'll see more of this in the demo but basically jax gives you a direct interface to what's compiled all wrapped up in a nice easy to use python package okay so just to give a little quantitative evidence for jax on tpus being a good match uh here are some benchmark results from last year so these are from the mlperf benchmark which is an industry standard benchmark competition and it's designed to fairly compare the performance of different ml hardware and software systems uh with everyone on equal footing because everyone's running the same models uh on the same data sets to the same accuracy so you can see in these charts that jax on tpus is very competitive i believe we set a few world records which was really exciting um and i don't want to get into too much detail here but one thing you might note is that we're not comparing equal number of chips in every column
so for example down here in the in the bert section uh jax is running on about four thousand chips and pytorch is running on about two thousand a100s so what's going on here is that we're comparing the largest and fastest submissions from each platform so these results actually demonstrate the two advantages of tpus that we mentioned earlier uh basically the specialized high-speed interconnect allows the tpu submissions to scale usefully to more chips and then the xla compiler let us squeeze as much performance as possible out of each chip um even though the tpu v3 was actually released i think like two years before the a100 so that's how we were able to get running so fast on so many chips um oh something i forgot to mention so these numbers were produced on internal google tpus not on cloud tpus so this is like the same chips but with different software on top um and until recently there wasn't a way to access cloud tpus in a way that you could reasonably run a high performance workload like this but luckily now you can um enter cloud tpu vms so this is basically a new way to access cloud tpus where you have a vm running directly on the tpu host so this makes things pretty simple because you just have ssh access to the tpu host vm and you can run anything you want like jax code you can also run pytorch or tensorflow uh random python code c plus plus whatever it's really just a regular vm that happens to have some tpus attached to it um and then here's just a picture showing that of course this scales up to pods um in this case when you're running on a pod or a pod slice which is a subset of a pod you just have more hosts and so that's more vms that you connect to but again you can run whatever you want on these pod vms um okay so this is the final slide before we uh hopefully get the demo working and i just wanted to give a very quick overview of what jax looks like before diving into the demo so here's an example of some jax code and if you've used
numpy before this might look familiar so here in the middle we're defining a kind of little fully connected neural net and we have a matching loss function and so the difference is that instead of using numpy we're using jax.numpy uh which we're going to call jnp and so what this gets you is that um like i said jax was built directly on top of the xla compiler and so all of this numpy code or jax numpy code runs directly on an accelerator so a gpu a tpu or on cpu if you don't have access to an accelerator or you're prototyping on your laptop or something like that um so out of the box we get accelerated numpy but at the bottom here you can also see we have some function transformations uh so i'm going to go over these in detail in the demo but just to give you a taste we're going to do some things like taking gradients doing some just-in-time compilation uh doing a vmap whatever that is uh okay so with that let's jump back and see how our demo vm is doing it looks like it's up okay now i have to do a complicated operation where i switch to a terminal window so let me stop sharing this let me create a new terminal bear with me okay um someone yell if you cannot see this terminal so uh i'm gonna i have some of these prewritten but hopefully these commands will make sense um hopefully everyone can read this so first we're going to use the gcloud command to ssh into the new demo vm we created and the interesting thing here is i'm going to add this is a standard ssh flag to set up port forwarding so that um i'm going to start a jupyter notebook on the tpu vm and i want to be able to load that notebook on my laptop so i found the easiest way to do this is with port forwarding so gcloud is nice because it automatically sets up all your keys for you okay so now we are running on the tpu vm next i'm going to install jax so if you've used jax before we've actually recently updated our install command so this is pulling in all the dependencies you need including jaxlib
so once this finishes we should be all set up to use jax next i'm going to set up the demo notebook so this uh is still in my jax fork on github i promise i will check this in um but for now we're going to be looking at my fork and then let's install a few dependencies to actually run a jupyter notebook and we'll do some plotting [Music] okay it's complaining that these new notebook binaries are not on our path so i'm going to put that on my path and then finally hopefully this works let's start our jupyter notebook great i'm gonna open that okay let me switch back to that window one second i think it's this one okay hopefully everyone can see i'm in my jupyter notebook now um and so i am going to open the notebook i prepared awesome okay so hopefully this works i think i'm gonna have to move pretty fast but to start we're just gonna go over um jax as accelerated numpy so i'm gonna import jax.numpy as jnp just like the slide and i'm gonna create a 5000 by 5000 matrix x so this thing should act just like regular numpy i can run operations i can slice i can print values but if we take a closer look we see this is not in fact a regular numpy array but a device array so what this means is that this value is stored in device memory so in this case uh in tpu core memory on my tpu vm and all the operations we run run on tpu so the dot the slice everything happens on the tpu and we only bring the value back to the host if we need to print it or write it to disk or something like that but other than that definitely thank you thank you how's that looks good thank you okay so other uh we have a device array but we can treat it just like a regular numpy array so we can plot it uh matplotlib is a little bit slow uh we can run more operations we can do fancy numpy indexing uh so to demonstrate that this truly is running on my tpu vm i'm going to do a quick micro benchmark where i'm going to create the same array as a regular numpy array uh and we can time a dot
product and we see it runs about 153 milliseconds per dot operation not bad for numpy but if we run the same benchmark using jnp.dot with our device array we see that it runs much faster just under six milliseconds per dot product so turns out tpus are pretty good at this um okay so this gives us numpy running on our cloud tpu vm so that's already pretty cool but let's dive in to those jax uh function transformations that we saw in the slide so the first transformation i'm gonna go over is automatic differentiation uh jax supports a lot of different types of ad but i'm just gonna focus on your kind of bread and butter reverse mode ad which in jax is called grad so here is a little toy function f that we can play with and differentiate so we have some control flow doing some operations and this is how we take the gradient of f in jax so all function transformations in jax have a similar api where the transform takes a function as input and then returns a new function as output so in this case we pass in f as input and then this returns the gradient function of f as output so if we apply the gradient function to some value we get the gradient value out and what's cool about this kind of uh function transform api is that it's composable meaning that you can take the output of one function transform and pass it right back in to another function transform as input so this is how we're going to achieve higher order uh gradients in jax we just take the output of that first grad call which is the first order gradient and pass it right back into grad to get the second order gradient which we can then apply to a value and so on and so forth for higher order gradients uh okay so this is all i'm going to go over for grad but we're really just scratching the surface here jax supports a number of other types of ad and you can learn all about this in the jax autodiff cookbook in our documentation so the next transform i'd like to go over is called jit for just in time
compilation so remember that everything so far has been running on the tpu and running with xla but we've only been uh running a single operation at a time that's kind of how jax works by default uh we call this op-by-op mode also called eager mode um jit lets us take full advantage of xla which is like this really awesome full program optimizing compiler by staging out entire functions to compile with xla instead of just one tiny operation at a time so let's see how this looks um so here's another toy function that we can try to compile uh so we're gonna call this f of course we can run f uh just as is we get some values out uh but this is running in op-by-op mode so now let's make a new function g um that's gonna be the jitted version of f and so uh jit has a similar api to grad where it takes a function in and returns a new function out but this time instead of changing what the function computes it's going to change how the function is executed so you can see that g of x if you kind of match up these numbers produces the same value as f of x but what's different about it is that g of x is going to stage out this whole function and compile the entire function using xla so in order to see this let's do another micro benchmark so first we're going to time f of x which is the op-by-op non-compiled mode get about 21 milliseconds per iteration so if we look at g of x i think it's running it like a million times okay yeah so order uh over an order of magnitude faster when running the compiled version so that's a really nice speed up in this case um just to give you an idea of the kinds of things that xla can do in order to get this speed up um one thing is it can fuse operations together so instead of each of these being a separate op that we dispatch it's all one big operator and then we're only returning a slice of the output in this example and so xla is smart enough to not even do the work it doesn't need to do in order to produce this output plus a bunch of other
optimizations um so uh jit is also composable uh so here we have jit and grad composed many times it works just fine um this is kind of a silly example in real life you wouldn't actually call jit twice like this on the same function um like you don't get twice the speed up by calling jit twice but this kind of composability does come in handy uh for instance when you're writing library code you can jit functions in your library and then users can just call those functions however they want they can call grad they can call jit uh on top of it it's all fine and it also becomes more interesting when you have more bits of code in between these different uh composed transforms okay hopefully so far so good um i'm going to move on to the next transform which is called vmap so this one might be less familiar um but it's actually pretty simple hopefully um vmap is very similar to a python map so you can see we pass a function as input and we're going to get a function as output as usual and how this works is that the output function so say that this lambda here takes a scalar we'll call this a scalar function the output function is going to take a whole vector of inputs and then run this square function on each element of the input so it's basically mapping over the input and applying this function uh that's what the vmapped function does um okay so we could have done this you know with like a python map or probably a for loop why use vmap and what's interesting is the way that vmap achieves this mapped result and it does it via vectorization thus vmap meaning that it's going to take all the operations in the input function and turn them into vectorized versions of the same operations another way to think about this is that vmap adds a batch dimension to your function and so it turns this single example function into a batched function um so in order to try to see this i am going to use another transform called make_jaxpr and make_jaxpr basically gives us a view of
the input function as jax sees it so the details aren't super important here but the idea is that if we call make_jaxpr on dot and then pass it two vectors so this is going to be a vector vector dot you can see we get this one dot_general operation out and the details aren't super important but like basically this tells you that it's a vector vector dot but this kind of operation can do lots of different dot like things um okay so now let's look at the jaxpr for vmap of dot and since we're adding that new mapped dimension we're going to pass in two matrices instead of two vectors uh but we can see that the vmapped operation is still just a single dot_general so this is basically going to be a matrix matrix dot instead of a vector vector dot and we efficiently do this in a single operation and so on and so forth we can keep calling vmap each time it's going to add a new dimension to our input but we still just get this single dot_general operation um so that's all i'm going to say about vmap but uh this is a very toy example you can imagine um you know when you plug this into a bigger function uh it can save you a lot of time and simplify your code by uh automatically adding all these vectorized operations okay um the next transform i'm going to go over is called pmap um and so like i said everything so far has been running on the tpu but we've only been running on a single tpu core so by default all jax code runs on a single device in this case a single core um but like we saw in the slides a tpu vm actually comes with an entire board with eight cores total so if we look at jax.devices we can actually see we have eight tpu devices available to us on our vm so how do we take full advantage of all eight tpu cores here um that's where we can use pmap so pmap as the name implies is very similar to vmap in fact it has exactly the same api so we pass the function in and we get a mapped or batched version of the function out the difference is how pmap computes
that mapped result so if we take a look at the result um what's going on here is that pmap instead of vectorizing is going to parallelize so we take this function that we want to map and it's going to split up the input and send each element to a different tpu core and each tpu core will compute its result over uh its slice of the input and then so our output is this thing called a sharded device array so what this means is that each of these output elements is stored on a different tpu core um but this thing otherwise acts just like a regular numpy array or just like a device array so we can run some operations on it it works just fine um one thing i want to go over that trips people up a lot is that we have this sharded device array y that came out of our pmap um and we can just run normal numpy operations on it but as soon as we start running operations on it outside of a pmap we're gonna go back to that single core execution mode so you can see the output of this operation is a device array again because basically what happened is it collected all these values back onto a single core in order to um do the division operation so basically the rule is that any operations by default um are happening on a single core unless they're happening in a pmap and then that's how you know if it's running parallelized or single core okay um but going back to the sharded device array this kind of is similar to the device array in that because we're storing these values uh sharded across multiple cores we can avoid communication and only do the minimum amount of communication and work needed and we do this by stringing together many pmapped operations or alternatively you can just pmap a larger function but just for the purposes of demonstration let's take a look at this example so what we're doing here is using pmap to create a 5000 by 5000 matrix on each tpu core so we're going to get eight matrices total stored as a sharded device array we can then um perform a parallelized
dot to dot each matrix with itself so this is each core is going to individually do a dot uh we can then take a mean of the result on each core again using pmap and then print the total result of this mean so the only communication that happens here is sending the final mean back to the host for printing otherwise everything happens independently on each core and we can try to see this via a micro benchmark so here i'm just going to do the single device dot this is similar to what we did earlier so yeah just under 6 milliseconds per dot and so now if we do eight dots in parallel using pmap um you can see that there's only a slight overhead to doing the eight dots but certainly much much better than eight times the single device version okay so that tells us how to do work on each tpu core but like i mentioned we have these special high-speed interconnects that allow the cores to talk directly to each other so how do we take advantage of that in jax um that's where these collective communication operations come in so we model this in jax with collectives and what collectives basically do is let you break the map abstraction and do an operation across the mapped axis that you're adding with pmap so the simplest example of this is psum um as the name implies this is going to do a sum across all of the cores across all that mapped axis that pmap is adding so to give an example of this here's a function where um remember so this x is the input to each uh each core each slice of the map and we are going to compute a normalized version of x by um summing together all the x's across the mapped axis and then dividing each individual input by this total and i'm also returning the total so that we can see what this looks like um so you can see here that we get all our normalized values out and then the total itself is going to be um each core gets its own copy of the total so that's why we get uh eight copies of the total here um okay uh so you might notice this
axis name um this is basically when we like pmap is adding a new batch axis is one way to think about it we can give a name to that axis and then when we do a collective operation we specify that name to say which axis we want to do that operation over and so this is not very interesting when there's only one pmap and one axis name but it gets more interesting when you have multiple pmaps with multiple axis names so i want to make sure we have time for questions now so i'm not going to go over this in detail but basically you can specify different axis names in order to do different um individual sums like we can sum across rows or sum across columns or sum everything all right i have one more transform that i wanted to go over before ending this talk so i really don't have time to go over this in detail i think but i'm gonna try to go over it really fast and still leave time for questions so this is a new experimental transform meaning that we're still working on it it might change as we develop it but i thought it might be interesting to take a look and see what it can do so this is called pjit and pjit is another way to take advantage of multiple devices and parallelize across devices um but instead of it basically takes another view of the computation than pmap does so pmap kind of takes a per device view of the computation like we specify the function that we want to run on each device and then manually include collective operations when we want there to be communication between devices pjit you give it a global view of the computation like the computation you want to run across all devices and it automatically splits it up across the devices and adds communication as necessary so for an example we can do a convolution operator um so this might be useful say you're doing a convolutional neural net and you want to feed it image examples that are so large that a single example doesn't even fit on a single device you can use pjit to automatically split up that
convolutional net uh and run across multiple devices so here we're just going to focus on the conv operator here i give an example that is very small and does fit on one device so we can see that that works just fine with pjit it works somewhat similarly to jit so like instead of doing a map it's more like jit where we pass in our function that we want to pjit and we get a parallel version of the same function out so um this should ideally get the same response or sorry the same result just like jit but um it's actually running a parallel version that runs across multiple devices so you can see in our micro benchmark this is running the single core conv in about 45 milliseconds and then when we use the pjitted parallel version of conv we get about a 5x speed up so that's pretty good for running across eight devices and uh you get 5x instead of 8x because there is some communication overhead so you can't get perfect scaling um okay yeah i'm not going to go over the details of how to actually invoke pjit but if you are interested to learn more about that um yeah maybe ask a question in the slack channel or hit me up later i'd love to talk more about it um so with that that's all i have prepared and i would love to stop now for questions thanks very much for listening thanks so much skye super interesting we had a lot of good questions and discussion in the chat if you want to you can keep your slides on maybe we'll go back and forth we have just a few minutes um i just want to let everybody know in case and we're probably not going to go through all of the questions um due to time constraints but feel free to also send slack uh send questions um to slack um okay let's start with maybe from the beginning while you're working on gcp can the same be run can the same project be run on kaggle tpu or google tpus or does it require some different uh installation of libraries okay that's a good question sorry i'm like way behind in the chat so i can't actually know myself because
i can't um uh so like i mentioned tpu vms are like a new tpu architecture that was recently released um some platforms in particular colab and i think also kaggle which i think uses a similar notebook setup to colab um are still on the old architecture so you can still use jax with these um but you're not gonna get um as like there's some performance overheads and usability issues uh i didn't go over the details between the two architectures but so just to be aware that unless you do use a gcp tpu vm you are probably on the old architecture and it won't be exactly the same experience but it's it's still jax on tpus right and first for some of the questions avital also answered i'll just anyway uh maybe read them out to you uh what is the difference between grad and jit is that a question uh the difference between grad and jit i mean they do different things so grad will return the gradient function of its input and jit doesn't take the gradient at all it just um returns a compiled version of the function and when would you not use jit and avital answered that you will probably always want to use jit other than if you're debugging okay yeah that's such a good question something i skipped in this talk but if you check out our documentation um i think in the jax 101 tutorial there is a section on jit that goes over this um there are some constraints to using um like you can't jit literally all of your python code and have it run faster so yeah i think the answer is you would always use jit when you can but if you look at that tutorial you'll see that sometimes you have to break your logic up into smaller jit blocks um yeah one constraint is around control flow um some control flow can be jitted like if and while loops but sometimes it can't so you can learn more about that in the docs that's a good question cool maybe related to that if you if you would only apply it to uh primitive functions or end to end so matrix multiplication versus full policy gradient
updates and um yeah probably to the entire training step um oh yeah so that's a good point too i guess another situation where you wouldn't jit is if you're just running a single operation anyway so like in that conv benchmark i just showed at the end i didn't use jit because i'm just running one operation so yeah usually you want to jit as big a block as you can to give xla as much to work with as it can so yeah it's very common for people to jit their entire training step and then to like drive that from a python while or for loop um occasionally people even jit the entire training loop but that can sometimes get more complicated uh due to the control flow constraints i mentioned earlier okay we have a lot of questions i i moved down i moved down a bit but yeah let me just say that again you feel free to to also post your questions on slack um does jax have a named tensor feature pmap seems to be something similar but it seems exclusive to map and does jax have a tensor annotation feature yeah i think avital mentioned this in the chat we are working on the new xmap transform which i didn't have time to demo this time uh that's even more under development than pjit but yeah the big idea there is kind of introducing named axes in a named axis programming model um yeah i think we have an xmap tutorial online if you want to learn more about that but that is very likely to change moving forward is there any tool akin to nvidia-smi to see what the tpu cores are doing and if i'm making good use of the available memory while training that's a great question um so there are profiling tools for tpu uh they're built as a tensorboard plug-in um again if you look at the jax documentation there's a page on profiling that shows you how to get set up with this um so it's kind of a little heavier weight than nvidia-smi like it's not just the command line tool it's more like a full profiling tool um but that gives you like a view of all the operations that are
happening and how long they're taking. I think that you can get a view of the memory usage; we definitely have that internally, and I think it works in the Cloud TPU version as well. So I would definitely check out the TensorBoard profiling plugin, and you can look at that on the JAX [docs].

Awesome. I think we also linked to the JAX documentation, and I'll make sure I share the relevant links as well later on. I think that's all for questions, for now at least, and we can move to Slack. Thank you so much, Skye, this was super interesting. And yeah, thank you.

Thanks, everyone. And yeah, please, I think it's probably easiest, if I didn't answer your question, to cc me in Slack, because I'm not sure I'm gonna make it through this whole chat window. But I'll try. It's a lot.

Thank you so much, Skye. Okay, then I think we can maybe stop sharing and move on to the next talk, and whenever Marc is ready we can maybe start sharing.

Yes, just let me share my screen. All right, can you see this?

Yes, looks good. Though we can't see your video, but it's up to you if you want.

Oh, okay. I think it stopped. Okay, can you see it now?

Looks good, thanks.

Okay. So let's get started. My name is Marc van Zee and I'm a software engineer on the Flax team, and today I'm going to give an introduction to Flax. The audience for this talk is a general machine-learning-interested audience. It's a beginner's talk, but it does require some background knowledge of basic machine learning terminology like feed-forward neural networks, batch normalization, or automatic differentiation. The content of this talk consists of three parts. First I'm going to give some background about Flax and JAX, then talk a bit about our ecosystem philosophy, and then we're going to go into the Flax code and look at our main abstraction, which is called the Flax module. Unfortunately I'm not going to give all the cool demos that Skye gave, so this for me is gonna be just slides. Okay,
background. So Flax is a neural network library and ecosystem built on JAX, and we specifically call ourselves an ecosystem because we care about the ecosystem in two ways. First, we want to make sure that our library has good open source support: we have a lot of examples, we make sure our examples run on Cloud TPU, we have a lot of how-tos, and we develop with an open-source-first approach, so we develop based on GitHub issues and our discussions are very active. On the other hand, we also curate recommendations for other libraries built on top of JAX. It's our view that we don't want Flax to be a monolithic library built on top of JAX; rather, there should be a collection of libraries that work well together. That's how we serve our users, and we want to guarantee that users can use other libraries to do their projects. I'll talk about this a bit more later.

These are just some Flax numbers. Within Alphabet, we're now the most commonly used neural network library for JAX, and we have 187 dependent projects and 97 collaborators, and Hugging Face imported really a lot of Flax models, more than four thousand, so that's also nice.

So Skye just gave a really great introduction to JAX, so please take a look at that as well, but for people that just tuned in and haven't seen it, I'll just have one slide explaining JAX again. People say JAX is Autograd and XLA; those are immediately two new terms. Autograd is basically saying that JAX can automatically differentiate Python code, so your Python functions. You can write a Python function, and there's this transformation called grad, which is a JAX transformation, and you can pass this Python function to grad and JAX will work great and it will differentiate this function, and so differentiate through loops and all kinds of things. And then JAX is also XLA, so you can actually jit-compile your own
Python functions into XLA-optimized kernels. So you do this, and then, because you generally jit large computations, like Skye also explained in her talk, you give XLA all the information it needs to find the optimal way to execute it, and then XLA will optimize it to get the best performance depending on the accelerator.

Another way of looking at this is that JAX is an extensible system for composable function transformations. So these functions, jit, pmap, and grad, are all function transformations, and that means you input a function and the function transformation will return a new function. An important part is that JAX uses a functional API, which means that it only guarantees correct behavior when you are inputting functions without side effects. Typically side effects are the result of mutating an object that lives outside the function. This functional API is a reason why XLA can do such a good job optimizing it. Here are just two example compositions: you can jit grad, which is what you usually do when you do standard gradient descent, or you can vmap grad, and that is a way to get per-example gradients. And you get them very efficiently, because if you want to do it naively you have to use a batch size of one, which is pretty slow.

Okay. Now, the world already has TensorFlow and PyTorch, and there's little need to build a clone of either of them. But we believe that the composable function transformation approach that JAX takes opens up new frontiers for making neural net code more maintainable, more scalable, and also more performant than existing libraries. I just listed some examples here, but there are many more. Using JAX you can write models as single-example code and introduce batching automatically with vmap, so you just input all your examples in one array and it will do the batching automatically. You can also automatically handle ragged batches using masking, and you
can create efficient compile-time and runtime models, or remove memory headaches by using easy rematerialization or reversibility; for instance, the Reformer was also based on JAX. And last but not least, it's really fast. Hugging Face recently created a number of really amazing scripts where they train some Flax models and then do some comparisons, and they are really fast. For instance, you can pre-train BERT in under 18 hours, and using the Cloud TPUs that's less than 150 dollars, which is quite cheap.

Okay, so now let's take a preview of Flax. Flax builds on top of JAX and it contains everything you need to do your deep learning research. Here on the left you see an example, but I will talk about the code in more detail later, so it doesn't matter if you don't get all the details. Our main abstraction is the module abstraction, and it's pretty similar to the abstraction that you might know from TensorFlow and PyTorch if you've used those. So we strive to offer an API familiar to those experienced with those kinds of libraries, but Flax is fundamentally a functional system for defining neural networks. What I mean is that you write your modules in a stateful or object-oriented way, as you can see in this example; this is a class with state. In order to operate with JAX transformations we must expose pure functions. Like I said before, JAX uses pure functions, and we cannot just construct this MLP and pass it to a JAX transformation; that wouldn't work. So instead we should create a stateless or pure function from our modules that we can use with JAX transformations, and this also allows you to compose them like you can in JAX. Therefore modules come with two functions that return state rather than maintaining state: init and apply. Here you see in the example we have an MLP, a multi-layer perceptron, that's created, and then
first you construct your model and then you call model.init, and that is actually how you initialize your parameters. There you see the parameters are returned and they are not stored in the model, and when you call apply, the parameters are also passed in. So this is the way we build Flax. Alternatively you can of course also expose stateful functions, but then you cannot use the JAX transformations directly. This is what other libraries, for instance Objax, do, and it also works, but it's not the philosophy we follow.

Okay, so that's about the module abstraction. Also, our RNG handling is an important feature of Flax. JAX handles RNGs differently from NumPy, because NumPy's RNG design for generating pseudo-random numbers makes it hard to guarantee a number of desirable properties for doing machine learning research. I won't go into the details here, but basically in NumPy the RNGs are based on a global state, and with JAX you explicitly provide the state to your random number generation. So in Flax the user also explicitly provides the RNGs when initializing or applying a module. However, you only have to do this at the top level. As you see here, as an argument to the init function we have jax.random.PRNGKey(0), which is used for initializing the parameters of your model, but it only has to be done at the top level, and then Flax makes sure that this seed is split deterministically for all the submodules; like in this MLP you have a few Dense layers that are then automatically also initialized.

We also have a number of utility functions which are thin and decoupled. In this sense Flax comes with what we call batteries included, and we have a number of those to simplify your workflow. But at the same time we believe it's also a good feature that these can be decoupled and split off into separate libraries. For instance, we recently switched our
optimizers to use another library, Optax, and I'll talk about that a bit more later. We also maintain a number of examples from different domains, which I'll also talk about a bit more later, and then we have some how-to guides and patterns explaining some of our design choices.

Okay, so let me now talk a bit about the Flax ecosystem philosophy. As anyone who's ever worked with JAX is probably well aware, there have been many libraries built on top of JAX, and they usually end with x: Flax, Objax, Trax, and there's just really a lot of them. While it can be a bit confusing, we believe that this is essentially not a bug but a feature, and we believe it's actually very healthy to have a number of libraries built on JAX that interoperate. We believe that this really plays well with the compositionality of JAX transformations that we talked about before. So what we envision as Flax is the JAX ecosystem as a set of decoupled libraries that are each individually well maintained and that follow these points that we suggest.

The first one is minimize indirection. We want to make it easy for users to navigate to the low-level operations in their machine learning code, so we want to keep the control flow relatively simple. You don't have to pass a function that calls a function that calls a function, or do very deep subclassing; those are the things we try to avoid.

Secondly, we have a bias against inversion of control. We prefer duplicating code over abstractions that require many options. Simply put, we prefer to avoid abstractions like a trainer that capture common use cases and hyperparameters in the constructor arguments. We actually believe that it's virtually impossible to consider all use cases, and if you keep extending such abstractions with new options, that's a road that we don't want to go down. I also think that the past has shown that this can lead to quite
confusing APIs. Of course we don't want to discourage anyone who would like to implement something like that on top of Flax, but we as the Flax team refrain from doing so.

The third: loosely coupled libraries. This is comparable to the Unix philosophy. Just like in Unix processes operate through streams and functions can be piped, in JAX libraries interoperate by passing pure functions to each other, so pure functions are the lingua franca.

Four: we care about error messages. We use slugs, which are well documented.

Five: code should read out the math in the paper. It follows somewhat from the other points, but we just want to stress that the math is what matters; there should be few distractions from this math. And then, we prefer separate libraries for separate needs, but it's more important that they are actively maintained and supported.

Okay, so this is an example of a recent change in our ecosystem. As Flax we use these FLIPs, which is a Flax improvement process, so if someone wants to propose a large change, they can create a FLIP and then we discuss it. We think this is a good example of healthy development. We as Flax had our own optimizer module, which is flax.optim, but we always felt that it did not work right; it didn't really have a functional API. But then we found that DeepMind actually built a dedicated library for optimizers called Optax, which is based on composable gradient transformations, and it uses a functional API, so we actually believe that it works nicer than our own optimizer. So during this FLIP we decided to switch to Optax, and now we recommend using Flax with Optax, and most of our examples are already using Optax now. I'll also talk a bit more about how Optax works later in this talk.

So here's a visualization of the JAX ecosystem that we recommend as the
Flax team. Of course there are libraries we don't mention here, so it is opinionated, and we also only mention libraries that fit our philosophy; the most important thing is that they have to work well together with each other and they are well maintained. At the bottom you see JAX, and then all the verticals that are built on top of that. We have Flax for neural networks, Optax for optimizers and losses, and then if you want to work on specialized domains we have a number of libraries: OTT for optimal transport, RLax for reinforcement learning, and Jraph for graph neural networks. Then there's also jax2tf, which is a tool that is actually built by the JAX team, and it allows you to take a JAX graph and directly convert it to a TensorFlow graph. That basically allows you to export it to a SavedModel, and then you can use everything that the TensorFlow ecosystem has to offer, like for serving or for running your models on the web or in the browser. There's also Chex for testing, and on the Flax team we also have the checkpoints that we use, and we also recommend using CLU, which is called Common Loop Utils, and that is to simplify your training loop.

This is a list of some of our examples, and if you just want to get to know this, it's nice to look through it. We started with MNIST; it's a very simple example and it's just one file, I think. Then we also have seq2seq, which uses a recurrent neural network, an LSTM. With ImageNet, of course, convnets; it's very popular. And then we have actually a few transformer implementations. WMT is, I think, our most advanced one; it is also, I think, based on the MLPerf results, so it uses some of those tricks. Then we have PixelCNN for generating images and we have PPO for reinforcement learning, and we have a bunch more examples.

Okay, so that was about the Flax ecosystem. Let's now talk about Flax code, and specifically Flax Linen modules. So in the past one
and a half years we worked on what we call the Linen API. We had an API before that, and this is kind of our second iteration on our API. What we did is we studied all the existing neural network libraries built on top of JAX at the time, so for instance Trax, Objax, Haiku, and then we tried to see what would be the best of all these libraries and what we think is best. Then we did a number of focus groups with different groups of users, for instance some users that really like Flax or people that really didn't like it, and then we iterated from that further. What we ended up with is the Linen API, and we believe it's quite ergonomic. We try to give a lot of flexibility while avoiding footguns, and especially what we call the silent footguns. That is, for instance, that things happen and you find out really late: some of the parameters in your model were actually changed while you didn't want them to change, they weren't mutable, or you get a NaN after training really long. APIs can also be a cause of that, and we try to avoid that as much as possible.

Okay, so let's look at an overview of how these modules are used. In the first step you instantiate the module. This module is called model, and we're going to look at plenty of examples later. So now you instantiate it, but this will only construct the module and not initialize any variables, because remember that we don't want to store the state of the parameters in the module. Then, to create the initial variables, you call model.init, and as a first argument you pass the PRNG key for initializing all the parameters. In this presentation I'll probably often call parameters params as well, but it's the same. And you pass the inputs of your model, but when you initialize, these inputs are only used for the shapes. Then the parameters are created using the initialization function
that is provided when you define them; I'll show it later. Okay, so now ther
Original Description
Day 1 Talks: JAX, Flax & Transformers 🤗
0:00:00 Skye Wanderman-Milne (Google Brain): Intro to JAX on Cloud TPUs
0:42:49 Marc van Zee (Google Brain): Introduction to Flax
1:28:26 Pablo Castro (Google Brain): Using Jax & Flax for RL with the Dopamine library
Find more information about the speakers and the talks here https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md#wednesday-june-30th