🧨 Stable Diffusion in JAX / Flax !
📰 Hugging Face Blog
Run Stable Diffusion in JAX/Flax for fast inference on Google TPUs
Action Steps
- Install the Hugging Face Diffusers library
- Import the necessary libraries and load the Stable Diffusion model
- Run inference using JAX/Flax on a TPU backend
- Experiment with different prompts and parallelization techniques
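The parallelization step comes down to splitting a batch of inputs so that each TPU core gets an equal slice. Flax provides `flax.training.common_utils.shard` for this. The NumPy sketch below only illustrates the reshape such a helper performs. It takes the device count as an explicit argument so it runs without an accelerator. The helper name `shard_batch` and the all-zero token ids are hypothetical placeholders.

```python
import numpy as np


def shard_batch(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Reshape a leading batch axis into (num_devices, per_device, ...).

    Hypothetical helper: it mimics the reshape done by Flax's shard utility,
    but takes the device count explicitly instead of querying JAX.
    """
    if batch.shape[0] % num_devices != 0:
        raise ValueError(
            f"batch size {batch.shape[0]} is not divisible by {num_devices} devices"
        )
    # -1 lets NumPy infer how many examples each device receives.
    return batch.reshape((num_devices, -1) + batch.shape[1:])


# Placeholder tokenized prompts: 8 prompts of 77 token ids each
# (77 is the context length of the CLIP text encoder used by Stable Diffusion).
prompt_ids = np.zeros((8, 77), dtype=np.int32)
sharded = shard_batch(prompt_ids, num_devices=8)
print(sharded.shape)  # (8, 1, 77)
```

With 8 cores and 8 prompts, each core receives one prompt. With 16 prompts, each core would receive two.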
Who Needs to Know This
AI engineers and researchers can use this tutorial to speed up model inference, and data scientists and ML engineers can apply the same techniques in their own projects.
Key Insight
💡 Hugging Face Diffusers supports Flax for fast inference on Google TPUs
Share This
Run Stable Diffusion in JAX/Flax for fast inference on Google TPUs! 🤗
Key Takeaways
Run Stable Diffusion in JAX/Flax for fast inference on Google TPUs
Full Article
Published Time: 2022-10-13T00:00:00.136Z
# 🧨 Stable Diffusion in JAX / Flax !
Published October 13, 2022
By [Pedro Cuenca](https://huggingface.co/pcuenq) and [Patrick von Platen](https://huggingface.co/patrickvonplaten)
[Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_jax_how_to.ipynb)
* [Setup](https://huggingface.co/blog/stable_diffusion_jax#setup "Setup")
* [Model Loading](https://huggingface.co/blog/stable_diffusion_jax#model-loading "Model Loading")
* [Inference](https://huggingface.co/blog/stable_diffusion_jax#inference "Inference")
* [Replication and parallelization](https://huggingface.co/blog/stable_diffusion_jax#replication-and-parallelization "Replication and parallelization")
* [Visualization](https://huggingface.co/blog/stable_diffusion_jax#visualization "Visualization")
* [Using different prompts](https://huggingface.co/blog/stable_diffusion_jax#using-different-prompts "Using different prompts")
* [How does parallelization work?](https://huggingface.co/blog/stable_diffusion_jax#how-does-parallelization-work "How does parallelization work?")
🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers) has supported Flax since version `0.5.1`! This enables fast inference on Google TPUs, such as those available in Colab, Kaggle, or Google Cloud Platform.
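Before loading any weights, it can help to confirm which accelerator JAX actually sees. The following is a minimal sketch that assumes only `jax` is installed. `describe_accelerators` is a hypothetical helper. On a CPU-only machine it simply reports the `cpu` backend, while a Colab TPU runtime typically exposes several TPU cores.

```python
import jax


def describe_accelerators() -> dict:
    """Hypothetical helper: summarize the devices visible to JAX."""
    devices = jax.devices()
    return {
        # Platform name of the default backend, e.g. "tpu", "gpu" or "cpu".
        "backend": jax.default_backend(),
        # Total number of devices JAX can use in this process.
        "device_count": jax.device_count(),
        # Distinct hardware kinds, useful to tell TPU generations apart.
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


info = describe_accelerators()
print(info)
if info["backend"] != "tpu":
    print("No TPU detected; JAX will run on:", info["backend"])
```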
This post shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works, or want to run it on a GPU, please refer to [this Colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb).
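Multi-core inference in JAX/Flax usually follows one pattern. You compile the model call with `jax.pmap`, replicate the parameters across devices, and shard the inputs along a leading device axis. The sketch below shows that pattern on a toy dense layer rather than on the Stable Diffusion pipeline. `toy_forward` and its parameters are hypothetical stand-ins. The stacking step only reproduces the array shapes that `flax.jax_utils.replicate` would produce.

```python
import jax
import jax.numpy as jnp


def toy_forward(params, x):
    """Hypothetical stand-in for a model call: a single dense layer."""
    return x @ params["w"] + params["b"]


# pmap compiles toy_forward once and runs one copy on each local device,
# mapping over the leading axis of every argument.
p_forward = jax.pmap(toy_forward)

n = jax.local_device_count()
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

# "Replicate" the parameters: give every leaf a leading device axis of size n.
replicated = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n), params)

# "Shard" the inputs: each device gets its own slice of 3 examples.
x = jnp.ones((n, 3, 4))

out = p_forward(replicated, x)
print(out.shape)  # (local_device_count, 3, 2)
```

On a single CPU device the leading axis has size 1. On a multi-core TPU the same code runs one slice per core without changes.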
If you want to follow along, click the button above to open this post as a Colab notebook.
First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware