🧨 Stable Diffusion in JAX / Flax !

📰 Hugging Face Blog

Run Stable Diffusion in JAX/Flax for fast inference on Google TPUs

Intermediate · Published 13 Oct 2022
Action Steps
  1. Install the Hugging Face Diffusers library
  2. Import the necessary libraries and load the Stable Diffusion model
  3. Run inference using JAX/Flax on a TPU backend
  4. Experiment with different prompts and parallelization techniques
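Step 4's "parallelization" is data parallelism. The model weights are copied to every TPU core, the batch of prompts is split so that each core gets one slice, and `jax.pmap` runs the same compiled function on all cores at once. In Diffusers this is done with Flax's `replicate` and `shard` helpers together with the pipeline's `jit=True` option. The toy sketch below reproduces the same pattern with plain JAX so it runs on any backend (a plain CPU usually exposes a single device). `denoise_step` is a hypothetical stand-in, not the Stable Diffusion UNet.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for one model step; NOT the real Stable Diffusion UNet.
def denoise_step(params, latents, rng):
    noise = jax.random.normal(rng, latents.shape)
    return latents * params["scale"] + 0.01 * noise


n_devices = jax.local_device_count()  # e.g. 8 on a TPU v2-8, usually 1 on a CPU

# Replicate: give every device its own copy of the parameters (leading axis = device).
params = {"scale": jnp.float32(0.5)}
replicated = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (n_devices,) + jnp.shape(x)), params
)

# Shard: split a batch of inputs into n_devices equal chunks of 2 items each.
batch = jnp.ones((n_devices * 2, 4, 4))
sharded = batch.reshape((n_devices, -1) + batch.shape[1:])

# One PRNG key per device, so each shard draws different noise.
rngs = jax.random.split(jax.random.PRNGKey(0), n_devices)

# pmap compiles denoise_step once and runs one copy per device in parallel.
p_step = jax.pmap(denoise_step)
out = p_step(replicated, sharded, rngs)

# Merge the device axis back into the batch axis.
images = out.reshape((-1,) + out.shape[2:])
print(images.shape)  # (n_devices * 2, 4, 4)
```

Keeping the device axis explicit (instead of letting JAX place the data automatically) mirrors what the Flax pipeline expects when called with `jit=True`: replicated parameters, sharded inputs and one key per device.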
Who Needs to Know This

AI engineers and researchers can use this tutorial to speed up model inference, while data scientists and ML engineers can apply the same techniques in their own projects.

Key Insight

💡 Hugging Face Diffusers supports Flax for fast inference on Google TPUs



Full Article

Published Time: 2022-10-13T00:00:00.136Z

# 🧨 Stable Diffusion in JAX / Flax !

Published October 13, 2022

By [Pedro Cuenca](https://huggingface.co/pcuenq) and [Patrick von Platen](https://huggingface.co/patrickvonplaten)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_jax_how_to.ipynb)

* [Setup](https://huggingface.co/blog/stable_diffusion_jax#setup "Setup")
* [Model Loading](https://huggingface.co/blog/stable_diffusion_jax#model-loading "Model Loading")
* [Inference](https://huggingface.co/blog/stable_diffusion_jax#inference "Inference")
    * [Replication and parallelization](https://huggingface.co/blog/stable_diffusion_jax#replication-and-parallelization "Replication and parallelization")
* [Visualization](https://huggingface.co/blog/stable_diffusion_jax#visualization "Visualization")
* [Using different prompts](https://huggingface.co/blog/stable_diffusion_jax#using-different-prompts "Using different prompts")
* [How does parallelization work?](https://huggingface.co/blog/stable_diffusion_jax#how-does-parallelization-work "How does parallelization work?")

🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) has supported Flax since version `0.5.1`! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.
This post shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works, or want to run it on a GPU, please refer to [this Colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb).
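The inference path this post walks through (load, replicate, shard, generate) can be condensed into a single function. This is a minimal sketch, assuming a Diffusers release with Flax support plus `flax` and `transformers` are installed; the model ID `CompVis/stable-diffusion-v1-4` and `revision="bf16"` reflect the era of the original post and are assumptions, so check the current Diffusers Flax documentation before relying on them. The function name `generate` and its parameters are hypothetical. The heavy imports live inside the function, so the snippet can be read or imported without a TPU.

```python
import numpy as np


def generate(prompt: str, seed: int = 0, steps: int = 50):
    """Sketch: text-to-image with Diffusers' Flax pipeline, one prompt copy per device."""
    # Imported lazily: these need diffusers, flax and a working JAX install.
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    # Assumed checkpoint and revision; bfloat16 weights are well suited to TPUs.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
    )

    n_devices = jax.device_count()
    # Tokenize the same prompt once per device.
    prompt_ids = pipeline.prepare_inputs([prompt] * n_devices)

    # Copy the weights to every device; split the inputs and RNG keys across devices.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    rngs = jax.random.split(jax.random.PRNGKey(seed), n_devices)

    # jit=True makes the pipeline pmap its work across all devices.
    output = pipeline(prompt_ids, p_params, rngs, num_inference_steps=steps, jit=True)

    # (devices, per_device_batch, H, W, C) -> (N, H, W, C), then convert to PIL images.
    images = np.asarray(output.images)
    images = images.reshape((-1,) + images.shape[-3:])
    return pipeline.numpy_to_pil(images)
```

On a TPU runtime, `generate("a photo of an astronaut riding a horse")` (an example prompt) would return one PIL image per device; the first call is slow because it includes compilation.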

If you want to follow along, click the button above to open this post as a Colab notebook.

First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting.
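You can also confirm the backend from Python. A minimal sketch, assuming JAX is installed; the helper names `describe_backend` and `on_tpu` are hypothetical. `jax.devices()` lists the devices of the default backend, and each device's `platform` is `"tpu"` when a TPU runtime is active.

```python
import jax


def describe_backend() -> str:
    """Report how many JAX devices are visible and what kind they are."""
    devices = jax.devices()
    return f"Found {len(devices)} JAX devices of type {devices[0].device_kind}."


def on_tpu() -> bool:
    # The first device of the default backend tells us which platform JAX picked.
    return jax.devices()[0].platform == "tpu"


print(describe_backend())
if not on_tpu():
    print("No TPU found: change the runtime type to TPU before running the rest of the notebook.")
```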