{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Kzqlx7fpXRnJ" }, "source": [ "# Part 3: Train a diffusion model for image generation\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n", "\n", "This tutorial guides you through developing and training a simple diffusion model for image generation, built on the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net), with JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io). It builds upon the previous tutorial, [Variational autoencoder (VAE) and debugging in JAX](https://jax-ai-stack.readthedocs.io/en/latest/digits_vae.html), which focuses on training a simpler generative model, the VAE.\n", "\n", "In this tutorial, you'll learn how to:\n", "\n", "- Load and preprocess the dataset\n", "- Define the diffusion model with Flax\n", "- Create the loss and training functions\n", "- Train the model (with Google Colab’s Cloud TPU v2)\n", "- Visualize and track the model’s progress\n", "\n", "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX." ] }, { "cell_type": "markdown", "metadata": { "id": "gwaaMmjXt7n7" }, "source": [ "## Setup\n", "\n", "JAX for AI (the stack) installation is covered [here](https://docs.jaxstack.ai/en/latest/install.html). 
And JAX (the library) installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site.\n", "\n", "Start with importing JAX, JAX NumPy, Flax NNX, Optax, matplotlib and scikit-learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dVVACvmDuDCM" }, "outputs": [], "source": [ "import jax\n", "import optax\n", "from flax import nnx\n", "import jax.numpy as jnp\n", "from functools import partial\n", "import matplotlib.pyplot as plt\n", "from sklearn.datasets import load_digits\n", "from typing import Tuple, Callable, List, Optional\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": { "id": "tQ5KGMyrYG2H" }, "source": [ "**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n", "\n", "Check the available JAX devices, or [`jax.Device`s](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). 
The output of the cell below will show a list of 8 (eight) devices:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ldmtemzPBO5z", "outputId": "d21720a2-65cd-4a5c-ef86-3a0912e36c34" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the available JAX devices.\n", "jax.devices()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading and preprocessing the data\n", "\n", "We'll use the small, self-contained [scikit-learn `digits` dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation to demonstrate diffusion model training. For simplicity, we'll focus on generating only the digit '1' (one).\n", "\n", "This involves several steps, such as:\n", "\n", "1. Loading the dataset\n", "2. Filtering the images of '1' (one)\n", "3. Normalizing pixel values\n", "4. Converting the data into `jax.Array`s\n", "5. 
Reshaping the data, and splitting it into training and test sets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 303 }, "id": "jNizSH6uuXY4", "outputId": "112723a1-fd36-46b2-946d-6d789f5a33ed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set size: 172\n", "Test set size: 10\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Load and preprocess the `digits` dataset.\n", "digits = load_digits()\n", "# Filter for digit '1' (one) images.\n", "images = digits.images[digits.target == 1]\n", "# Normalize pixel values into floating-point arrays in the `[0, 1]` interval.\n", "images = images / 16.0\n", "# Convert to `jax.Array`s.\n", "images = jnp.asarray(images)\n", "# Reshape to `(num_images, height, width, channels)` for convolutional layers.\n", "images = images.reshape(-1, 8, 8, 1)\n", "\n", "# Split the dataset into training and test sets (5% for testing).\n", "images_train, images_test = train_test_split(images, test_size=0.05, random_state=42)\n", "print(f\"Training set size: {images_train.shape[0]}\")\n", "print(f\"Test set size: {images_test.shape[0]}\")\n", "\n", "# Visualize sample images.\n", "fig, axes = plt.subplots(3, 3, figsize=(3, 3))\n", "for i, ax in enumerate(axes.flat):\n", " if i < len(images_train):\n", " ax.imshow(images_train[i, ..., 0], cmap='gray', interpolation='gaussian')\n", " ax.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "exKxj9OcG0yk" }, "source": [ "## Defining the diffusion model with Flax\n", "\n", "In this section, we’ll develop various parts of the [diffusion model](https://en.wikipedia.org/wiki/Diffusion_model) and then put them all together.\n", "\n", "### The U-Net architecture\n", "\n", "For this example, we’ll use the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net), a convolutional neural network architecture, as the backbone of the diffusion model. 
The U-Net consists of the following:\n", "\n", "- An [encoder](https://en.wikipedia.org/wiki/Autoencoder) path that [downsamples](https://en.wikipedia.org/wiki/Downsampling_(signal_processing)) the input image, extracting features.\n", "- A bridge with a (self-)[attention mechanism](https://en.wikipedia.org/wiki/Attention_(machine_learning)) that connects the encoder with the decoder.\n", "- A [decoder](https://en.wikipedia.org/wiki/Autoencoder) path that [upsamples](https://en.wikipedia.org/wiki/Upsampling) the feature representations learned by the encoder, reconstructing the output image.\n", "- [Skip connections](https://en.wikipedia.org/wiki/Residual_neural_network#Residual_connection) between the encoder and the decoder.\n", "\n", "Let's define a class called `UNet` by subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) and using, among other things, [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) (linear or dense layers for time embedding and time projection layers, as well as the self-attention layers), [`flax.nnx.LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm) (layer normalization), and [`flax.nnx.Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv) (convolution layers for the residual blocks and the output layer)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F4pxdITOuk79" }, "outputs": [], "source": [ "class UNet(nnx.Module):\n", " def __init__(self,\n", " in_channels: int,\n", " out_channels: int,\n", " features: int,\n", " time_emb_dim: int = 128,\n", " *,\n", " rngs: nnx.Rngs):\n", " \"\"\"\n", " Initialize the U-Net architecture with time embedding.\n", " \"\"\"\n", " self.features = features\n", "\n", " # Time embedding layers for diffusion timestep conditioning.\n", " self.time_mlp_1 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)\n", " self.time_mlp_2 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)\n", "\n", " # Time projection layers for different scales.\n", " self.time_proj1 = nnx.Linear(in_features=time_emb_dim, out_features=features, rngs=rngs)\n", " self.time_proj2 = nnx.Linear(in_features=time_emb_dim, out_features=features * 2, rngs=rngs)\n", " self.time_proj3 = nnx.Linear(in_features=time_emb_dim, out_features=features * 4, rngs=rngs)\n", " self.time_proj4 = nnx.Linear(in_features=time_emb_dim, out_features=features * 8, rngs=rngs)\n", "\n", " # The encoder path.\n", " self.down_conv1 = self._create_residual_block(in_channels, features, rngs)\n", " self.down_conv2 = self._create_residual_block(features, features * 2, rngs)\n", " self.down_conv3 = self._create_residual_block(features * 2, features * 4, rngs)\n", " self.down_conv4 = self._create_residual_block(features * 4, features * 8, rngs)\n", "\n", " # Multi-head self-attention blocks.\n", " self.attention1 = self._create_attention_block(features * 4, rngs)\n", " self.attention2 = self._create_attention_block(features * 8, rngs)\n", "\n", " # The bridge connecting the encoder and the decoder.\n", " self.bridge_down = self._create_residual_block(features * 8, features * 16, rngs)\n", " self.bridge_attention = self._create_attention_block(features * 16, rngs)\n", " self.bridge_up = self._create_residual_block(features * 
16, features * 16, rngs)\n", "\n", " # Decoder path with skip connections.\n", " self.up_conv4 = self._create_residual_block(features * 24, features * 8, rngs)\n", " self.up_conv3 = self._create_residual_block(features * 12, features * 4, rngs)\n", " self.up_conv2 = self._create_residual_block(features * 6, features * 2, rngs)\n", " self.up_conv1 = self._create_residual_block(features * 3, features, rngs)\n", "\n", " # Output layers.\n", " self.final_norm = nnx.LayerNorm(features, rngs=rngs)\n", " self.final_conv = nnx.Conv(in_features=features,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", "\n", " def _create_attention_block(self, channels: int, rngs: nnx.Rngs) -> Callable:\n", " \"\"\"Creates a self-attention block with learned query, key, value projections.\n", "\n", " Args:\n", " channels (int): The number of channels in the input feature maps.\n", " rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys.\n", "\n", " Returns:\n", " Callable: A function representing a forward pass through the attention block.\n", "\n", " \"\"\"\n", " query_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", " key_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", " value_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", "\n", " def forward(x: jax.Array) -> jax.Array:\n", " \"\"\"Applies self-attention to the input.\n", "\n", " Args:\n", " x (jax.Array): The input tensor with the shape `[batch, height, width, channels]` (or `B, H, W, C`).\n", "\n", " Returns:\n", " jax.Array: The output tensor after applying self-attention.\n", " \"\"\"\n", "\n", " # Shape: batch, height, width, channels.\n", " B, H, W, C = x.shape\n", " scale = jnp.sqrt(C).astype(x.dtype)\n", "\n", " # Project the input into query, key, value projections.\n", " q 
= query_proj(x)\n", " k = key_proj(x)\n", " v = value_proj(x)\n", "\n", " # Reshape for the attention computation.\n", " q = q.reshape(B, H * W, C)\n", " k = k.reshape(B, H * W, C)\n", " v = v.reshape(B, H * W, C)\n", "\n", " # Compute the scaled dot-product attention.\n", " attention = jnp.einsum('bic,bjc->bij', q, k) / scale # Scaled dot-product.\n", " attention = jax.nn.softmax(attention, axis=-1) # Softmax.\n", "\n", " # The output tensor.\n", " out = jnp.einsum('bij,bjc->bic', attention, v)\n", " out = out.reshape(B, H, W, C)\n", "\n", " return x + out # A ResNet-style residual connection.\n", "\n", " return forward\n", "\n", " def _create_residual_block(self,\n", " in_channels: int,\n", " out_channels: int,\n", " rngs: nnx.Rngs) -> Callable:\n", " \"\"\"Creates a residual block with two convolutions and normalization.\n", "\n", " Args:\n", " in_channels (int): Number of input channels.\n", " out_channels (int): Number of output channels.\n", " rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX PRNG keys.\n", "\n", " Returns:\n", " Callable: A function that represents the forward pass through the residual block.\n", " \"\"\"\n", "\n", " # Convolutional layers with layer normalization.\n", " conv1 = nnx.Conv(in_features=in_channels,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", " norm1 = nnx.LayerNorm(out_channels, rngs=rngs)\n", " conv2 = nnx.Conv(in_features=out_channels,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", " norm2 = nnx.LayerNorm(out_channels, rngs=rngs)\n", "\n", " # Projection shortcut if dimensions change.\n", " shortcut = nnx.Conv(in_features=in_channels,\n", " out_features=out_channels,\n", " kernel_size=(1, 1),\n", " strides=(1, 1),\n", " rngs=rngs)\n", "\n", " # The forward pass through the residual block.\n", " def 
forward(x: jax.Array) -> jax.Array:\n", " identity = shortcut(x)\n", "\n", " x = conv1(x)\n", " x = norm1(x)\n", " x = nnx.gelu(x)\n", "\n", " x = conv2(x)\n", " x = norm2(x)\n", " x = nnx.gelu(x)\n", "\n", " return x + identity\n", "\n", " return forward\n", "\n", " def _pos_encoding(self, t: jax.Array, dim: int) -> jax.Array:\n", " \"\"\"Applies sinusoidal positional encoding for time embedding.\n", "\n", " Args:\n", " t (jax.Array): The diffusion timestep(s) to embed, with the shape `[batch]`.\n", " dim (int): The dimension of the output positional encoding.\n", "\n", " Returns:\n", " jax.Array: The sinusoidal positional embedding per timestep.\n", "\n", " \"\"\"\n", " # Calculate half the embedding dimension.\n", " half_dim = dim // 2\n", " # Compute the logarithmic scaling factor for sinusoidal frequencies.\n", " emb = jnp.log(10000.0) / (half_dim - 1)\n", " # Generate a range of sinusoidal frequencies.\n", " emb = jnp.exp(jnp.arange(half_dim) * -emb)\n", " # Create the positional encoding by multiplying the timesteps with the frequencies.\n", " emb = t[:, None] * emb[None, :]\n", " # Concatenate sine and cosine components for richer representation.\n", " emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)\n", " return emb\n", "\n", " def _downsample(self, x: jax.Array) -> jax.Array:\n", " \"\"\"Downsamples the input feature map with max pooling.\"\"\"\n", " return nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')\n", "\n", " def _upsample(self, x: jax.Array, target_size: int) -> jax.Array:\n", " \"\"\"Upsamples the input feature map using nearest neighbor interpolation.\"\"\"\n", " return jax.image.resize(x,\n", " (x.shape[0], target_size, target_size, x.shape[3]),\n", " method='nearest')\n", "\n", " def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:\n", " \"\"\"Perform the forward pass through the U-Net using time embeddings.\"\"\"\n", "\n", " # Time embedding and projection.\n", " t_emb = self._pos_encoding(t, 128) # Sinusoidal positional encoding 
for time.\n", " t_emb = self.time_mlp_1(t_emb) # Project the time embedding.\n", " t_emb = nnx.gelu(t_emb) # Activation function: `flax.nnx.gelu` (GeLU).\n", " t_emb = self.time_mlp_2(t_emb)\n", "\n", " # Project time embeddings for each scale.\n", " # Project to the correct dimensions for each encoder block.\n", " t_emb1 = self.time_proj1(t_emb)[:, None, None, :]\n", " t_emb2 = self.time_proj2(t_emb)[:, None, None, :]\n", " t_emb3 = self.time_proj3(t_emb)[:, None, None, :]\n", " t_emb4 = self.time_proj4(t_emb)[:, None, None, :]\n", "\n", " # The encoder path with time injection.\n", " d1 = self.down_conv1(x)\n", " t_emb1 = jnp.broadcast_to(t_emb1, d1.shape) # Broadcast the time embedding to match feature map shape.\n", " d1 = d1 + t_emb1 # Add the time embedding to the feature map.\n", "\n", " d2 = self.down_conv2(self._downsample(d1))\n", " t_emb2 = jnp.broadcast_to(t_emb2, d2.shape)\n", " d2 = d2 + t_emb2\n", "\n", " d3 = self.down_conv3(self._downsample(d2))\n", " d3 = self.attention1(d3) # Apply self-attention.\n", " t_emb3 = jnp.broadcast_to(t_emb3, d3.shape)\n", " d3 = d3 + t_emb3\n", "\n", " d4 = self.down_conv4(self._downsample(d3))\n", " d4 = self.attention2(d4)\n", " t_emb4 = jnp.broadcast_to(t_emb4, d4.shape)\n", " d4 = d4 + t_emb4\n", "\n", " # The bridge.\n", " b = self._downsample(d4)\n", " b = self.bridge_down(b)\n", " b = self.bridge_attention(b)\n", " b = self.bridge_up(b)\n", "\n", " # The decoder path with skip connections.\n", " u4 = self.up_conv4(jnp.concatenate([self._upsample(b, d4.shape[1]), d4], axis=-1))\n", " u3 = self.up_conv3(jnp.concatenate([self._upsample(u4, d3.shape[1]), d3], axis=-1))\n", " u2 = self.up_conv2(jnp.concatenate([self._upsample(u3, d2.shape[1]), d2], axis=-1))\n", " u1 = self.up_conv1(jnp.concatenate([self._upsample(u2, d1.shape[1]), d1], axis=-1))\n", "\n", " # Final layers.\n", " x = self.final_norm(u1)\n", " x = nnx.gelu(x)\n", " return self.final_conv(x)" ] }, { "cell_type": "markdown", "metadata": { 
"id": "XJaqiL07HD9D" }, "source": [ "### Defining the diffusion model\n", "\n", "Here, we will define the diffusion model that wraps the previously defined components, such as the `UNet` class, and adds the operations needed to run the diffusion process. The `DiffusionModel` class implements the diffusion process with:\n", "\n", "- Forward diffusion (adding noise)\n", "- Reverse diffusion (denoising)\n", "- Custom noise scheduling" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4ml8cKFUvCdv" }, "outputs": [], "source": [ "class DiffusionModel:\n", " def __init__(self,\n", " model: UNet,\n", " num_steps: int,\n", " beta_start: float,\n", " beta_end: float):\n", " \"\"\"Initialize diffusion process parameters.\n", "\n", " Args:\n", " model (UNet): The U-Net model for image generation.\n", " num_steps (int): The number of diffusion steps in the process.\n", " beta_start (float): The starting value for beta, controlling the noise level.\n", " beta_end (float): The end value for beta.\n", " \"\"\"\n", " self.model = model\n", " self.num_steps = num_steps\n", "\n", " # Noise schedule parameters.\n", " self.beta = self._cosine_beta_schedule(num_steps, beta_start, beta_end)\n", " self.alpha = 1 - self.beta\n", " self.alpha_cumulative = jnp.cumprod(self.alpha)\n", "\n", " self.sqrt_alpha_cumulative = jnp.sqrt(self.alpha_cumulative)\n", " self.sqrt_one_minus_alpha_cumulative = jnp.sqrt(1 - self.alpha_cumulative)\n", " self.sqrt_recip_alpha = jnp.sqrt(1 / self.alpha)\n", "\n", " # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t), taking alpha_bar before the first step as 1.\n", " alpha_cumulative_prev = jnp.concatenate([jnp.ones((1,)), self.alpha_cumulative[:-1]])\n", " self.posterior_variance = self.beta * (1 - alpha_cumulative_prev) / (1 - self.alpha_cumulative + 1e-7)\n", "\n", " def _cosine_beta_schedule(self,\n", " num_steps: int,\n", " beta_start: float,\n", " beta_end: float) -> jax.Array:\n", " \"\"\"Cosine schedule for noise levels.\"\"\"\n", " steps = jnp.linspace(0, num_steps, num_steps + 1)\n", " x = steps / num_steps\n", " alphas = jnp.cos(((x + 0.008) / 1.008) * jnp.pi * 0.5) ** 2\n", " alphas = alphas / alphas[0]\n", " betas = 1 - 
(alphas[1:] / alphas[:-1])\n", " betas = jnp.clip(betas, beta_start, beta_end)\n", " return jnp.concatenate([betas[0:1], betas])\n", "\n", " def forward(self,\n", " x: jax.Array,\n", " t: jax.Array,\n", " key: jax.Array) -> Tuple[jax.Array, jax.Array]:\n", " \"\"\"Forward diffusion process - adds noise according to schedule.\n", "\n", " Args:\n", " x (jax.Array): The input image.\n", " t (jax.Array): The timestep(s) at which the noise is added.\n", " key (jax.Array): A JAX PRNG key for generating random noise.\n", "\n", " Returns:\n", " Tuple[jax.Array, jax.Array]\n", " \"\"\"\n", " noise = jax.random.normal(key, x.shape)\n", " noisy_x = (\n", " jnp.sqrt(self.alpha_cumulative[t])[:, None, None, None] * x +\n", " jnp.sqrt(1 - self.alpha_cumulative[t])[:, None, None, None] * noise\n", " )\n", " return noisy_x, noise\n", "\n", " def reverse(self, x: jax.Array, key: jax.Array) -> jax.Array:\n", " \"\"\"Performs the reverse diffusion process, denoising the input image gradually.\n", "\n", " Args:\n", " x (jax.Array): The noise image batch per timestep.\n", " key (jax.Array): A JAX PRNG key for the random noise.\n", " \"\"\"\n", " x_t = x\n", " for t in reversed(range(self.num_steps)):\n", " t_batch = jnp.array([t] * x.shape[0])\n", " predicted = self.model(x_t, t_batch) # Predicted noise using the U-Net.\n", "\n", " key, subkey = jax.random.split(key) # Split the JAX PRNG key.\n", " noise = jax.random.normal(subkey, x_t.shape) if t > 0 else 0 # Sample the noise for the current timestep.\n", "\n", " # The denoising step.\n", " x_t = (1 / jnp.sqrt(self.alpha[t])) * (\n", " x_t - ((1 - self.alpha[t]) / jnp.sqrt(1 - self.alpha_cumulative[t])) * predicted\n", " ) + jnp.sqrt(self.beta[t]) * noise\n", "\n", " return x_t # The final denoised image." 
] }, { "cell_type": "markdown", "metadata": { "id": "wKnYRqMAI06f" }, "source": [ "## Defining the loss function and training step\n", "\n", "In this section, we’ll define the components for training our diffusion model, including:\n", "\n", "- The loss function (`loss_fn()`), which incorporates [SNR weighting](https://en.wikipedia.org/wiki/Signal-to-noise_ratio) and a gradient penalty; and\n", "- The training step (`train_step()`) with [gradient clipping](https://arxiv.org/pdf/1905.11881) for stability." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rq9Ic8WYCCJI" }, "outputs": [], "source": [ "def loss_fn(model: UNet,\n", " images: jax.Array,\n", " t: jax.Array,\n", " noise: jax.Array,\n", " sqrt_alpha_cumulative: jax.Array,\n", " sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:\n", " \"\"\"Computes the diffusion loss function with SNR weighting and adaptive noise scaling.\n", "\n", " Args:\n", " model(UNet): The U-Net model for image generation.\n", " images (jax.Array): A batch of images used for training.\n", " t (jax.Array): The timestep(s) at which the noise is added to each image.\n", " noise (jax.Array): The noise added to the images.\n", " sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values.\n", " sqrt_one_minus_alpha_cumulative (jax.Array): Square root of (1 - cumulative alpha values).\n", "\n", " Returns:\n", " jax.Array: The total loss value.\n", " \"\"\"\n", "\n", " # Generate noisy images.\n", " noisy_images = (\n", " sqrt_alpha_cumulative[t][:, None, None, None] * images +\n", " sqrt_one_minus_alpha_cumulative[t][:, None, None, None] * noise\n", " )\n", "\n", " # Predict the noise using the U-Net.\n", " predicted = model(noisy_images, t)\n", "\n", " # Compute the SNR-weighted loss.\n", " snr = (sqrt_alpha_cumulative[t] / sqrt_one_minus_alpha_cumulative[t])[:, None, None, None]\n", " loss_weights = snr / (1 + snr)\n", "\n", " squared_error = (noise - predicted) ** 2\n", " main_loss = 
jnp.mean(loss_weights * squared_error)\n", "\n", " # Perform gradient penalty (regularization) with a reduced coefficient.\n", " grad = jax.grad(lambda x: model(x, t).mean())(noisy_images)\n", " grad_penalty = 0.02 * (jnp.square(grad).mean())\n", "\n", " # The total loss.\n", " return main_loss + grad_penalty\n", "\n", "# Flax NNX JIT-compilation for performance (`flax.nnx.jit`).\n", "@nnx.jit\n", "def train_step(model: UNet,\n", " optimizer: nnx.Optimizer,\n", " images: jax.Array,\n", " t: jax.Array,\n", " noise: jax.Array,\n", " sqrt_alpha_cumulative: jax.Array,\n", " sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:\n", " \"\"\"Performs a single training step with gradient clipping.\n", "\n", " Args:\n", " model(UNet): The U-Net model for image generation that is being trained.\n", " optimizer (flax.nnx.Optimizer): The Flax NNX optimizer for parameter updates.\n", " images (jax.Array): A batch of images used for training.\n", " t (jax.Array): The timestep(s) at which the noise is added to each image.\n", " noise (jax.Array): The noise added to the images during training.\n", " sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values from the diffusion schedule.\n", " sqrt_one_minus_alpha_cumulative (jax.Array): Square root of (1 - cumulative alpha values) from the diffusion schedule.\n", "\n", " Returns:\n", " jax.Array: The loss value after a single training step.\n", " \"\"\"\n", " # The loss and gradients using `flax.nnx.value_and_grad`.\n", " loss, grads = nnx.value_and_grad(loss_fn)(\n", " model, images, t, noise,\n", " sqrt_alpha_cumulative, sqrt_one_minus_alpha_cumulative\n", " )\n", "\n", " # Apply conservative gradient clipping.\n", " clip_threshold = 0.3\n", " grads = jax.tree_util.tree_map(\n", " lambda g: jnp.clip(g, -clip_threshold, clip_threshold),\n", " grads\n", " )\n", " # Update the parameters using the optimizer.\n", " optimizer.update(grads)\n", " # Return the loss after a single training step.\n", " return loss" ] }, 
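{ "cell_type": "markdown", "metadata": {}, "source": [ "As a reading aid, the objective computed by `loss_fn()` above can be written out as follows. This is a sketch transcribed from the code, not a derivation from the literature. The symbols $\\bar{\\alpha}_t$ (for `alpha_cumulative[t]`), $\\epsilon_\\theta$ (the U-Net's noise prediction), $s_t$ and $w_t$ are notation introduced here for illustration only.\n", "\n", "$$x_t = \\sqrt{\\bar{\\alpha}_t}\\, x_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\, \\epsilon, \\qquad \\epsilon \\sim \\mathcal{N}(0, I)$$\n", "\n", "$$s_t = \\frac{\\sqrt{\\bar{\\alpha}_t}}{\\sqrt{1 - \\bar{\\alpha}_t}}, \\qquad w_t = \\frac{s_t}{1 + s_t}$$\n", "\n", "$$\\mathcal{L} = \\operatorname{mean}\\left[ w_t \\left( \\epsilon - \\epsilon_\\theta(x_t, t) \\right)^2 \\right] + 0.02 \\cdot \\operatorname{mean}\\left[ \\left( \\nabla_{x_t} \\operatorname{mean}\\left[ \\epsilon_\\theta(x_t, t) \\right] \\right)^2 \\right]$$\n", "\n", "Note that $s_t$ as coded is the ratio of the signal and noise scales, that is, the square root of the usual power SNR $\\bar{\\alpha}_t / (1 - \\bar{\\alpha}_t)$. Since $w_t$ lies in $(0, 1)$ and shrinks as $t$ grows, heavily noised samples contribute less to the first term." ] }, 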
{ "cell_type": "markdown", "metadata": { "id": "4slhkQ6vI5tZ" }, "source": [ "### Model training configuration\n", "\n", "Next, we’ll define the model configuration ahead of the training loop.\n", "\n", "We need to set up:\n", "\n", "- Model hyperparameters\n", "- An optimizer with the learning rate schedule" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w4CwR-6ivIjS" }, "outputs": [], "source": [ "# Set the model and training hyperparameters.\n", "key = jax.random.PRNGKey(42) # PRNG seed for reproducibility.\n", "in_channels = 1\n", "out_channels = 1\n", "features = 64 # Base number of feature channels in the U-Net.\n", "num_steps = 1000 # Number of diffusion timesteps.\n", "num_epochs = 5000 # Number of training steps.\n", "batch_size = 64 # Not used below: each training step uses the full training set.\n", "learning_rate = 1e-4\n", "beta_start = 1e-4 # The starting value for beta (noise level schedule).\n", "beta_end = 0.02 # The end value for beta (noise level schedule).\n", "\n", "# Initialize model components.\n", "key, subkey = jax.random.split(key) # Split the JAX PRNG key for initialization.\n", "model = UNet(in_channels, out_channels, features, rngs=nnx.Rngs(default=subkey)) # Instantiate the U-Net.\n", "\n", "diffusion = DiffusionModel(\n", " model=model,\n", " num_steps=num_steps,\n", " beta_start=beta_start,\n", " beta_end=beta_end\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yLjb_t026uy3", "outputId": "2cda0980-ac98-4fd7-ee3a-02728a64f1f7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape: (1, 8, 8, 1)\n", "Output shape: (1, 8, 8, 1)\n", "\n", "Model initialized successfully\n" ] } ], "source": [ "# Learning rate schedule configuration.\n", "# Start with the warmup, then cosine decay.\n", "warmup_steps = 1000 # Number of warmup steps.\n", "total_steps = num_epochs # Total number of training steps.\n", "\n", "# Multiple schedules using `optax.join_schedules`:\n", "# Linear transition (`optax.linear_schedule`) (for the warmup),\n", "# 
and cosine learning rate decay (`optax.cosine_decay_schedule`).\n", "schedule_fn = optax.join_schedules(\n", " schedules=[\n", " optax.linear_schedule(\n", " init_value=0.0,\n", " end_value=learning_rate,\n", " transition_steps=warmup_steps\n", " ),\n", " optax.cosine_decay_schedule(\n", " init_value=learning_rate,\n", " decay_steps=total_steps - warmup_steps,\n", " alpha=0.01\n", " )\n", " ],\n", " boundaries=[warmup_steps] # Where the schedule transitions from the warmup to cosine decay.\n", ")\n", "\n", "# Optimizer configuration (AdamW) with gradient clipping.\n", "optimizer = nnx.ModelAndOptimizer(model, optax.chain(\n", " optax.clip_by_global_norm(0.5), # Gradient clipping.\n", " optax.adamw(\n", " learning_rate=schedule_fn,\n", " weight_decay=2e-5,\n", " b1=0.9,\n", " b2=0.999,\n", " eps=1e-8\n", " )\n", "))\n", "\n", "# Model initialization with dummy input.\n", "dummy_input = jnp.ones((1, 8, 8, 1))\n", "dummy_t = jnp.zeros((1,), dtype=jnp.int32)\n", "output = model(dummy_input, dummy_t)\n", "\n", "print(\"Input shape:\", dummy_input.shape)\n", "print(\"Output shape:\", output.shape)\n", "print(\"\\nModel initialized successfully\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LrzTfkDPJm2X" }, "source": [ "### Implementing the training loop\n", "\n", "Finally, we need to implement the main training loop for the diffusion model with:\n", "\n", "- The progressive timestep sampling strategy\n", "- [Exponential moving average (EMA)](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average) loss tracking\n", "- Adaptive noise generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZnQqHCAoVfi1", "outputId": "a105e2de-ba88-44d0-bad5-3a9a69e54826" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, Loss: 1.2441\n", "Epoch 100, Loss: 1.1178\n", "Epoch 200, Loss: 0.8737\n", "Epoch 300, Loss: 0.7176\n", "Epoch 400, Loss: 0.6327\n", 
"Epoch 500, Loss: 0.5682\n", "Epoch 600, Loss: 0.5024\n", "Epoch 700, Loss: 0.4417\n", "Epoch 800, Loss: 0.3805\n", "Epoch 900, Loss: 0.3254\n", "Epoch 1000, Loss: 0.2803\n", "Epoch 1100, Loss: 0.2534\n", "Epoch 1200, Loss: 0.2339\n", "Epoch 1300, Loss: 0.2221\n", "Epoch 1400, Loss: 0.2141\n", "Epoch 1500, Loss: 0.2085\n", "Epoch 1600, Loss: 0.2046\n", "Epoch 1700, Loss: 0.1991\n", "Epoch 1800, Loss: 0.1951\n", "Epoch 1900, Loss: 0.1923\n", "Epoch 2000, Loss: 0.1919\n", "Epoch 2100, Loss: 0.1913\n", "Epoch 2200, Loss: 0.1888\n", "Epoch 2300, Loss: 0.1858\n", "Epoch 2400, Loss: 0.1861\n", "Epoch 2500, Loss: 0.1867\n", "Epoch 2600, Loss: 0.1855\n", "Epoch 2700, Loss: 0.1832\n", "Epoch 2800, Loss: 0.1834\n", "Epoch 2900, Loss: 0.1839\n", "Epoch 3000, Loss: 0.1844\n", "Epoch 3100, Loss: 0.1838\n", "Epoch 3200, Loss: 0.1816\n", "Epoch 3300, Loss: 0.1824\n", "Epoch 3400, Loss: 0.1815\n", "Epoch 3500, Loss: 0.1823\n", "Epoch 3600, Loss: 0.1834\n", "Epoch 3700, Loss: 0.1823\n", "Epoch 3800, Loss: 0.1811\n", "Epoch 3900, Loss: 0.1806\n", "Epoch 4000, Loss: 0.1804\n", "Epoch 4100, Loss: 0.1814\n", "Epoch 4200, Loss: 0.1802\n", "Epoch 4300, Loss: 0.1813\n", "Epoch 4400, Loss: 0.1799\n", "Epoch 4500, Loss: 0.1811\n", "Epoch 4600, Loss: 0.1820\n", "Epoch 4700, Loss: 0.1829\n", "Epoch 4800, Loss: 0.1828\n", "Epoch 4900, Loss: 0.1832\n", "Epoch 5000, Loss: 0.1827\n", "\n", "Training completed.\n" ] } ], "source": [ "# Initialize training metrics.\n", "losses: List[float] = [] # Store the EMA loss history.\n", "moving_avg_loss: Optional[float] = None # The EMA of the loss value.\n", "beta: float = 0.99 # The EMA decay factor for loss smoothing.\n", "\n", "for epoch in range(num_epochs + 1):\n", " # Split the JAX PRNG key for independent random operations.\n", " key, subkey1 = jax.random.split(key)\n", " key, subkey2 = jax.random.split(key)\n", "\n", " # Progressive timestep sampling - weights early steps more heavily as training progresses.\n", " # This helps model focus on fine 
details in later epochs while maintaining stability.\n", " progress = epoch / num_epochs\n", " t_weights = jnp.linspace(1.0, 0.1 * (1.0 - progress), num_steps)\n", " t = jax.random.choice(\n", " subkey1,\n", " num_steps,\n", " shape=(images_train.shape[0],),\n", " p=t_weights/t_weights.sum()\n", " )\n", "\n", " # Generate the Gaussian noise for the current batch of images.\n", " noise = jax.random.normal(subkey2, images_train.shape)\n", "\n", " # Execute the training step with noise prediction and parameter updates.\n", " loss = train_step(\n", " model, optimizer, images_train, t, noise,\n", " diffusion.sqrt_alpha_cumulative, diffusion.sqrt_one_minus_alpha_cumulative\n", " )\n", "\n", " # Update the exponential moving average (EMA) of the loss for smoother tracking.\n", " if moving_avg_loss is None:\n", " moving_avg_loss = loss\n", " else:\n", " moving_avg_loss = beta * moving_avg_loss + (1 - beta) * loss\n", "\n", " losses.append(moving_avg_loss)\n", "\n", " # Log the training progress at regular intervals.\n", " if epoch % 100 == 0:\n", " print(f\"Epoch {epoch}, Loss: {moving_avg_loss:.4f}\")\n", "\n", "print(\"\\nTraining completed.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "s-iqch4HKlBV" }, "source": [ "### Training loss visualization\n", "\n", "To visualize the training loss, we can use a logarithmic scale to better display the exponential decay of the loss values over time. This representation helps identify both early rapid improvements and later fine-tuning phases of the training process.\n", "\n", "Based on the results, the model appears to perform well, as the training loss falls over time during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "1bjvWNCcbN24", "outputId": "457fd13f-377f-4021-ddc2-e36940b42550" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the training loss history with logarithmic scaling.\n", "plt.figure(figsize=(10, 5)) # Create figure with wide aspect ratio for clarity\n", "plt.plot(losses) # losses: List[float] - historical EMA loss values.\n", "plt.title('Training Loss Over Time')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.yscale('log') # Use the log scale to better visualize exponential decay.\n", "plt.grid(True) # Add a grid for easier value reading.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "M2ql0KwYJLqn" }, "source": [ "## Visualization functions\n", "\n", "Here, we can create several utilities for:\n", "\n", "- Sample generation;\n", "- Forward/reverse process visualization; and\n", "- Training progress tracking." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 212 }, "id": "thP6DDl56iXM", "outputId": "68b47408-bbc1-40e8-fb90-47d43230984e" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@partial(nnx.jit, static_argnums=(3,))\n", "def reverse_diffusion_batch(model: UNet,\n", "                            x: jax.Array,\n", "                            key: jax.Array,\n", "                            num_steps: int) -> jax.Array:\n", "    \"\"\"Efficiently generates samples from the trained diffusion model using batched reverse diffusion (with `jax.lax.scan`).\n", "\n", "    Args:\n", "        model (UNet): The trained U-Net model for image generation.\n", "        x (jax.Array): A batch of noisy images (or pure noise).\n", "        key (jax.Array): A JAX PRNG key for generating random noise.\n", "        num_steps (int): Number of denoising steps in the reverse diffusion process.\n", "\n", "    Returns:\n", "        jax.Array: The denoised batch after `num_steps` iterations.\n", "    \"\"\"\n", "    # Define the schedule for beta (noise level) and alpha (signal strength).\n", "    beta = jnp.linspace(1e-4, 0.02, num_steps)\n", "    alpha = 1 - beta\n", "    alpha_cumulative = jnp.cumprod(alpha)\n", "\n", "    def scan_step(carry: Tuple[jax.Array, jax.Array],\n", "                  step: int) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:\n", "        \"\"\"Applies a single denoising step.\"\"\"\n", "        # Carry-over information.\n", "        x, key = carry\n", "\n", "        # Create a batch of timesteps for the current iteration.\n", "        t_batch = jnp.full((x.shape[0],), step)\n", "\n", "        # Predict the noise using the U-Net model.\n", "        predicted = model(x, t_batch)\n", "\n", "        # Sample fresh noise for every step except the final one (step == 0), which adds none.\n", "        key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "        noise = jnp.where(step > 0, jax.random.normal(subkey, x.shape), 0)\n", "\n", "        # Update the image using denoising.\n", "        x_new = 1 / jnp.sqrt(alpha[step]) * (\n", "            x - (1 - alpha[step]) / jnp.sqrt(1 - alpha_cumulative[step]) * predicted\n", "        ) + jnp.sqrt(beta[step]) * noise\n", "\n", "        # Return the updated image and carry-over information.\n", "        return (x_new, key), x_new\n", "\n", "    steps = jnp.arange(num_steps - 1, -1, -1)\n", "    (final_x, _), _ = jax.lax.scan(scan_step, (x, key), steps)\n", "    return final_x\n", "\n", "def plot_samples(model: UNet,\n", "                 diffusion: DiffusionModel,\n", "                 images: jax.Array,\n", "                 key: jax.Array,\n", "                 num_samples: int = 9) -> None:\n", "    \"\"\"Visualize original vs reconstructed images.\"\"\"\n", "    indices = jax.random.randint(key, (num_samples,), 0, len(images))\n", "    samples = images[indices]\n", "\n", "    key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "    noisy = diffusion.forward(samples, jnp.full((num_samples,), diffusion.num_steps-1), subkey)[0]\n", "\n", "    key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "    reconstructed = reverse_diffusion_batch(model, noisy, subkey, diffusion.num_steps)\n", "\n", "    fig, axes = plt.subplots(2, num_samples, figsize=(8, 2))\n", "\n", "    for i in range(num_samples):\n", "        axes[0, i].imshow(samples[i, ..., 0], cmap='gray')\n", "        axes[0, i].axis('off')\n", "        axes[1, i].imshow(reconstructed[i, ..., 0], cmap='gray')\n", "        axes[1, i].axis('off')\n", "\n", "    axes[0, 0].set_title('Original')\n", "    axes[1, 0].set_title('Reconstructed')\n", "    plt.tight_layout()\n", "    plt.show()\n", "\n", "# Create a plot of original vs reconstructed images.\n", "key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "plot_samples(model, diffusion, images_test, subkey)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 277 }, "id": "iqfjpn8havnI", "outputId": "756595c5-5380-46fd-f625-e21b4581381e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Full Forward and Reverse Process:\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@partial(nnx.jit, static_argnums=(3,))\n", "def compute_forward_sequence(model: UNet,\n", "                             image: jax.Array,\n", "                             key: jax.Array,\n", "                             num_vis_steps: int) -> jax.Array:\n", "    \"\"\"Computes the forward diffusion sequence efficiently.\"\"\"\n", "    # Prepare image sequence and noise parameters.\n", "    image_repeated = jnp.repeat(image[None], num_vis_steps, axis=0)\n", "    timesteps = jnp.linspace(0, 999, num_vis_steps).astype(jnp.int32)  # Assuming 1000 steps.\n", "    beta = jnp.linspace(1e-4, 0.02, 1000)\n", "    alpha = 1 - beta\n", "    alpha_cumulative = jnp.cumprod(alpha)\n", "\n", "    # Generate and apply noise progressively.\n", "    noise = jax.random.normal(key, image_repeated.shape)\n", "    noisy_images = (\n", "        jnp.sqrt(alpha_cumulative[timesteps])[:, None, None, None] * image_repeated +\n", "        jnp.sqrt(1 - alpha_cumulative[timesteps])[:, None, None, None] * noise\n", "    )\n", "    return noisy_images\n", "\n", "@partial(nnx.jit, static_argnums=(3,))\n", "def compute_reverse_sequence(model: UNet,\n", "                             noisy_image: jax.Array,\n", "                             key: jax.Array,\n", "                             num_vis_steps: int) -> jax.Array:\n", "    \"\"\"Computes the reverse diffusion sequence efficiently.\"\"\"\n", "    # Denoise completely, then linearly interpolate between the noisy and denoised images for display.\n", "    # Note: the intermediate frames are interpolations, not the actual intermediate reverse-diffusion states.\n", "    final_image = reverse_diffusion_batch(model, noisy_image[None], key, 1000)[0]\n", "    alphas = jnp.linspace(0, 1, num_vis_steps)\n", "    reverse_sequence = (\n", "        (1 - alphas)[:, None, None, None] * noisy_image +\n", "        alphas[:, None, None, None] * final_image\n", "    )\n", "    return reverse_sequence\n", "\n", "def plot_forward_and_reverse(model: UNet,\n", "                             diffusion: DiffusionModel,\n", "                             image: jax.Array,\n", "                             key: jax.Array,\n", "                             num_steps: int = 9) -> None:\n", "    \"\"\"Plot both forward and reverse diffusion processes with optimized computation.\"\"\"\n", "    # Compute the forward/reverse transformations.\n", "    key1, key2 = jax.random.split(key)\n", "    forward_sequence = compute_forward_sequence(model, image, key1, num_steps)\n", "    reverse_sequence = compute_reverse_sequence(model, forward_sequence[-1], key2, num_steps)\n", "\n", "    # Plot the grid.\n", "    fig, (ax1, ax2) = plt.subplots(2, num_steps, figsize=(8, 2))\n", "    fig.suptitle('Forward and reverse diffusion process', y=1.1)\n", "\n", "    timesteps = jnp.linspace(0, diffusion.num_steps-1, num_steps).astype(jnp.int32)\n", "\n", "    # Visualize forward diffusion.\n", "    for i in range(num_steps):\n", "        ax1[i].imshow(forward_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')\n", "        ax1[i].axis('off')\n", "        ax1[i].set_title(f't={timesteps[i]}')\n", "    ax1[0].set_ylabel('Forward', rotation=90, labelpad=10)\n", "\n", "    # Visualize reverse diffusion.\n", "    for i in range(num_steps):\n", "        ax2[i].imshow(reverse_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')\n", "        ax2[i].axis('off')\n", "        ax2[i].set_title(f't={timesteps[num_steps-1-i]}')\n", "    ax2[0].set_ylabel('Reverse', rotation=90, labelpad=10)\n", "\n", "    plt.tight_layout()\n", "    plt.show()\n", "\n", "# Create a plot.\n", "key, subkey = jax.random.split(key)\n", "print(\"\\nFull forward and reverse diffusion processes:\")\n", "plot_forward_and_reverse(model, diffusion, images_test[0], subkey)" ] }, { "cell_type": "markdown", "metadata": { "id": "o43bRWpiM6Mt" }, "source": [ "## Summary\n", "\n", "In this tutorial, we implemented a simple diffusion model using JAX and Flax NNX, and trained it with Optax. The model used a U-Net architecture with attention mechanisms, training was JIT-compiled with `flax.nnx.jit`, and we visualized the forward and reverse diffusion processes."
] } ], "metadata": { "accelerator": "TPU", "colab": { "gpuType": "V28", "machine_shape": "hm", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 0 }