Two Ways to Fine-Tune JAX on NVIDIA GPUs: PEFT and SFT with Tunix and MaxText
Want to get hands-on with fine-tuning Llama 3.1-8B on NVIDIA GPUs using JAX? This video shows how JAX runs seamlessly on GPUs and how that unlocks real workflows for training and adapting modern LLMs on NVIDIA hardware.
The video starts by addressing a common misconception — that JAX is only for TPUs — and quickly demonstrates GPU detection in action. From there, it connects the dots between JAX, XLA, and GPU acceleration to show how large-scale model training can run efficiently outside TPU environments.
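The GPU-detection step shown in the video boils down to asking JAX which accelerators XLA can see. A minimal check (assuming a CUDA-enabled jaxlib install; on a machine without one, JAX silently falls back to CPU) looks like this:

```python
import jax

# List the accelerators XLA can see; with the CUDA-enabled jaxlib
# installed this returns CudaDevice entries, otherwise CPU devices.
devices = jax.devices()
print(devices)

# Reports which backend JAX selected: "gpu", "tpu", or "cpu".
print("default backend:", jax.default_backend())
```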
Two practical demos are included, each built on official Google tooling. The first uses MaxText for full supervised fine-tuning (SFT), designed for multi-GPU setups where the entire model is trained. You’ll see how to configure distributed training, prepare gated Hugging Face checkpoints, convert them into MaxText format, and launch a short training run with TensorBoard tracking.
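The SFT flow above can be sketched as shell commands. Note that the script paths, config file name, and flags below are assumptions for illustration only; the linked MaxText notebook contains the authoritative invocations.

```shell
# Illustrative sketch — exact script names and flags may differ;
# see the linked sft_llama3_demo_gpu.ipynb notebook for the real commands.

# 1. Download the gated Hugging Face checkpoint (requires an approved token)
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama3.1-8b-hf

# 2. Convert the Hugging Face checkpoint into MaxText format
python3 -m MaxText.llama_or_mistral_ckpt \
  --base-model-path ./llama3.1-8b-hf \
  --maxtext-model-path ./llama3.1-8b-maxtext \
  --model-size llama3.1-8b

# 3. Launch a short SFT run; TensorBoard logs land under base_output_directory
python3 -m MaxText.train MaxText/configs/sft.yml \
  run_name=sft_demo \
  base_output_directory=./output \
  model_name=llama3.1-8b \
  steps=100
```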
The second demo focuses on Tunix for parameter-efficient fine-tuning using LoRA and QLoRA. This path is designed for lower-resource environments, showing how adapter-based methods allow you to fine-tune effectively on a single GPU. It also demonstrates how models are loaded, converted to JAX, and automatically sharded across available devices.
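The idea behind the LoRA path can be shown in a few lines of plain JAX. This is not Tunix's actual API, just a sketch of the math: the frozen base weight `W` stays fixed while only the small low-rank factors `A` and `B` are trained, which is why the method fits on a single GPU.

```python
import jax
import jax.numpy as jnp

def lora_linear(x, W, A, B, alpha=16.0):
    """LoRA forward pass: frozen W plus a scaled low-rank update x @ A @ B."""
    rank = A.shape[1]
    return x @ W + (x @ A @ B) * (alpha / rank)

key = jax.random.PRNGKey(0)
d_in, d_out, rank = 64, 64, 8

W = jax.random.normal(key, (d_in, d_out))        # frozen base weight
A = jax.random.normal(key, (d_in, rank)) * 0.01  # trainable low-rank factor
B = jnp.zeros((rank, d_out))                     # zero-init: update starts at zero

x = jnp.ones((2, d_in))
out = lora_linear(x, W, A, B)
print(out.shape)  # (2, 64)
```

Because `B` is zero-initialized, the adapted layer starts out identical to the frozen layer and the update grows only as `A` and `B` train; QLoRA applies the same idea on top of a quantized `W` to cut memory further.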
All examples are provided as Jupyter notebooks, making it easy to replicate the workflows and experiment with your own setups. The video also highlights when to choose LoRA or QLoRA for lightweight adaptation versus full SFT for deeper model changes.
By the end, you’ll have a clear understanding of fine-tuning Llama 3.1-8B on NVIDIA GPUs, along with two distinct approaches you can apply depending on your hardware and goals.
SFT with MaxText: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo_gpu.ipynb
PEFT QLoRA / LoRA with Tunix: https://github.com/google/tunix/blob/main/examples/qlora_llama3_gpu.ipynb
00:00 – Breaking the myth of JAX being TPU-only
00:14 – Fine-tuning on GPU using JAX
00:42 – SFT w
DeepCamp AI