Two Ways to Fine-Tune JAX on NVIDIA GPUs: PEFT and SFT with Tunix and MaxText

NVIDIA Developer · Beginner · 🧠 Large Language Models · 2w ago
Want to get hands-on with fine-tuning Llama 3.1-8B on NVIDIA GPUs using JAX? This video shows how JAX runs seamlessly on GPUs and how that unlocks real workflows for training and adapting modern LLMs on NVIDIA hardware. It starts by addressing a common misconception, that JAX is only for TPUs, with a quick demonstration of GPU detection in action (a minimal sketch of that check appears below). From there, it connects the dots between JAX, XLA, and GPU acceleration to show how large-scale model training can run efficiently outside TPU environments.

Two practical demos are included, each built on official Google tooling. The first uses MaxText for full supervised fine-tuning (SFT), designed for multi-GPU setups where the entire model is trained. You'll see how to configure distributed training, prepare gated Hugging Face checkpoints, convert them to MaxText format, and launch a short training run with TensorBoard tracking.

The second demo uses Tunix for parameter-efficient fine-tuning (PEFT) with LoRA and QLoRA. This path targets lower-resource environments, showing how adapter-based methods let you fine-tune effectively on a single GPU (a toy LoRA sketch follows below). It also demonstrates how models are loaded, converted to JAX, and automatically sharded across the available devices.

All examples are provided as Jupyter notebooks, making it easy to replicate the workflows and experiment with your own setups. The video also covers when to choose LoRA or QLoRA for lightweight adaptation versus full SFT for deeper model changes. By the end, you'll have a clear picture of fine-tuning Llama 3.1-8B on NVIDIA GPUs, along with two distinct approaches you can apply depending on your hardware and goals.

SFT with MaxText: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo_gpu.ipynb
PEFT (QLoRA / LoRA) with Tunix: https://github.com/google/tunix/blob/main/examples/qlora_llama3_gpu.ipynb
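To make the "JAX sees your GPUs" claim concrete, here is a minimal sketch, not taken from the video's notebooks, of the two building blocks both demos rely on: device detection and sharding an array across whatever accelerators XLA finds. It assumes a CUDA-enabled JAX install; on a CPU-only machine it still runs, just over a single device.

    # Minimal sketch: GPU detection plus 1-D sharding, the same ideas
    # Tunix and MaxText apply at model scale.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1. Device detection: a CUDA build of JAX lists CudaDevice objects
    #    and reports "gpu" as the default backend.
    print(jax.devices())
    print(jax.default_backend())

    # 2. Sharding: lay all local devices out on a 1-D mesh and split a
    #    tensor along its first axis, one shard per device.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("data",))
    sharding = NamedSharding(mesh, P("data"))

    x = jnp.arange(devices.size * 4.0).reshape(devices.size, 4)
    x_sharded = jax.device_put(x, sharding)
    jax.debug.visualize_array_sharding(x_sharded)  # one row per device

The same mesh-and-partition-spec mechanism is what lets the demo notebooks spread an 8B-parameter model across however many GPUs are present.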
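To illustrate why the Tunix path fits on a single GPU, here is a toy LoRA forward pass. This is not Tunix's actual API; init_lora and lora_linear are hypothetical names and the shapes are stand-ins. The point is that only the small A and B matrices are trained, while the large base weight stays frozen.

    # Toy LoRA sketch: train a low-rank correction B @ A instead of
    # updating the frozen base weight W.
    import jax
    import jax.numpy as jnp

    def init_lora(key, d_in, d_out, rank=8):
        # B starts at zero, so the adapted layer initially matches the
        # frozen base layer exactly.
        k_a, _ = jax.random.split(key)
        A = jax.random.normal(k_a, (rank, d_in)) * 0.01  # trainable
        B = jnp.zeros((d_out, rank))                     # trainable
        return A, B

    def lora_linear(x, W_frozen, A, B, alpha=16.0, rank=8):
        # y = x @ W^T + (alpha/rank) * x @ (B @ A)^T. Gradients flow
        # only to A and B; W_frozen is never touched, which is why
        # adapter training fits in single-GPU memory.
        return x @ W_frozen.T + (alpha / rank) * (x @ (B @ A).T)

    key = jax.random.PRNGKey(0)
    W = jax.random.normal(key, (4096, 4096))  # stand-in frozen weight
    A, B = init_lora(key, 4096, 4096)
    y = lora_linear(jnp.ones((2, 4096)), W, A, B)
    print(y.shape)  # (2, 4096)

QLoRA applies the same adapter idea on top of a quantized base weight, shrinking memory requirements further.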

Chapters (3)

0:00 Breaking the myth of JAX being TPU-only
0:14 Fine-tuning on GPU using JAX
0:42 SFT with MaxText