{ "cells": [ { "cell_type": "markdown", "id": "396f7d27-5d90-4bf3-a77f-f9b1c9f36bef", "metadata": {}, "source": [ "# Diffusers integration with fastai\n", "By Tanishq Abraham\n", "\n", "This notebook demonstrates a simple-to-use fastai integration with the [HuggingFace Diffusers](https://github.com/huggingface/diffusers/) library.\n" ] }, { "cell_type": "markdown", "id": "92135e76-e9d3-4561-a2c8-b689278c938e", "metadata": {}, "source": [ "## Imports\n", "\n", "Here are all of our imports, mostly from the fastai library and the Diffusers library." ] }, { "cell_type": "code", "execution_count": 1, "id": "df56594f-4607-4208-ba8b-c15c8fed54e5", "metadata": {}, "outputs": [], "source": [ "import diffusers\n", "from fastai.vision.all import *\n", "from fastai.vision.gan import *\n", "from copy import deepcopy" ] }, { "cell_type": "markdown", "id": "6d5974cb-13e1-4e3b-b356-cb8f15c82e2f", "metadata": {}, "source": [ "## Data loading\n", "\n", "Let's load our data. We'll work with the famous MNIST dataset." ] }, { "cell_type": "code", "execution_count": 2, "id": "3e44f2e4-f3c7-46c2-b3a1-cb36c6f6c7c4", "metadata": { "id": "okr8kAqSvIUD" }, "outputs": [], "source": [ "bs = 128 # batch size\n", "size = 32 # image size" ] }, { "cell_type": "code", "execution_count": 3, "id": "f6de6207-d249-4974-bd42-7b598cdfebbb", "metadata": { "id": "B8nsEdKXvKim" }, "outputs": [], "source": [ "path = untar_data(URLs.MNIST)" ] }, { "cell_type": "markdown", "id": "bb1fb7d4-3fe5-4f23-b249-6df4955f5364", "metadata": {}, "source": [ "We use the highly flexible DataBlock API in fastai to create our DataLoaders.\n", "\n", "Note that we start with pure noise, generated with the aptly named `generate_noise` function from `fastai.vision.gan`." 
] }, { "cell_type": "code", "execution_count": 8, "id": "266399d7-9490-4ee6-bb1d-591a3f9f8aef", "metadata": { "id": "DQF8UdVvvP4R" }, "outputs": [], "source": [ "dblock = DataBlock(blocks = (TransformBlock, ImageBlock),\n", " get_x = partial(generate_noise, size=(3, size, size)),\n", " get_items = get_image_files,\n", " splitter = IndexSplitter(list(range(len(get_image_files(path))))[-bs:]),\n", " item_tfms=Resize(size), \n", " batch_tfms = Normalize.from_stats(torch.tensor([0.5]), torch.tensor([0.5])))" ] }, { "cell_type": "code", "execution_count": 11, "id": "b7fcacd4-6dbb-45cf-8ece-c01531cf8b70", "metadata": { "id": "L6iHHHFRvRPx" }, "outputs": [], "source": [ "dls = dblock.dataloaders(path, path=path, bs=bs)" ] }, { "cell_type": "code", "execution_count": 12, "id": "564918af-745e-4ae4-93fd-a8ada58f3bd3", "metadata": { "id": "ANw0OdjzvRvY" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "dls.show_batch(max_n=8)" ] }, { "cell_type": "markdown", "id": "e07e5a05-3340-4187-81f3-4e4ebadb5ca0", "metadata": {}, "source": [ "A key aspect of diffusion models is that the model's input and output have the same shape:" ] }, { "cell_type": "code", "execution_count": 13, "id": "1ed870d9-b80a-4727-a8c0-93a4de6223d8", "metadata": { "id": "XlII4jxmwUnS" }, "outputs": [], "source": [ "xb, yb = next(iter(dls.train))\n", "assert xb.shape == yb.shape" ] }, { "cell_type": "markdown", "id": "fd4d06e7-c04c-44bf-aaf5-86dd922de5b4", "metadata": {}, "source": [ "## Diffusers Callback\n", "\n", "This callback is based on my previous DDPM callback with some additional modifications.\n", "\n", "The basic idea is that we set a sampler/scheduler, which adds noise to the images so that we can train our noise-conditioned denoising network to predict that noise. During sampling, we create a `Pipeline` that couples our model with the scheduler to generate images. We can swap in different samplers and sampling parameters with the `set_sampler` and `set_sampling_params` methods of the callback." 
] }, { "cell_type": "code", "execution_count": 14, "id": "1dc7794f-49c1-449d-835f-cd7be4a2d8c6", "metadata": {}, "outputs": [], "source": [ "available_samplers = {\n", "    'DDPM': diffusers.DDPMScheduler,\n", "    'DDIM': diffusers.DDIMScheduler,\n", "    'Karras': diffusers.KarrasVeScheduler,\n", "    'LMS': diffusers.LMSDiscreteScheduler,\n", "    'PNDM': diffusers.PNDMScheduler,\n", "    'VESDE': diffusers.ScoreSdeVeScheduler,\n", "# 'VPSDE': diffusers.ScoreSdeVpScheduler\n", "}\n", "\n", "corresponding_pipelines = {\n", "    diffusers.DDPMScheduler: diffusers.DDPMPipeline,\n", "    diffusers.DDIMScheduler: diffusers.DDIMPipeline,\n", "    diffusers.KarrasVeScheduler: diffusers.KarrasVePipeline,\n", "    diffusers.PNDMScheduler: diffusers.PNDMPipeline,\n", "    diffusers.ScoreSdeVeScheduler: diffusers.ScoreSdeVePipeline,\n", "\n", "}" ] }, { "cell_type": "code", "execution_count": 15, "id": "49b60a37-e0f8-4381-9602-3ea08d4f1623", "metadata": {}, "outputs": [], "source": [ "class Diffusers(Callback):\n", "    def __init__(self, sampler='DDPM', tensor_type=TensorImage, **kwargs):\n", "        self.tensor_type=tensor_type\n", "        self.set_sampler(sampler, **kwargs)\n", "\n", "    def before_batch_training(self):\n", "        eps = self.tensor_type(self.xb[0]) # noise, x_T\n", "        x0 = self.yb[0] # original images, x_0\n", "        batch_size = x0.shape[0]\n", "        t = torch.randint(0, self.sampler.config.num_train_timesteps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps\n", "        xt = self.tensor_type(self.sampler.add_noise(x0, eps, t)) # noisify the images\n", "        self.learn.xb = (xt, t) # input to our model is noisy image and timestep\n", "        self.learn.yb = (eps,) # ground truth is the noise\n", "\n", "    def before_batch_sampling(self):\n", "        self.pipeline = self.create_pipeline()\n", "        if not hasattr(self, 'sampling_params'): self.set_sampling_params()\n", "        images = self.pipeline(batch_size=self.xb[0].shape[0], output_type=\"numpy\", **self.sampling_params).images\n", "        xt = 
self.tensor_type(images)\n", "        self.learn.pred = (xt,)\n", "        raise CancelBatchException\n", "\n", "    def before_batch(self):\n", "        if not hasattr(self, 'gather_preds'): self.before_batch_training()\n", "        else: self.before_batch_sampling()\n", "\n", "    def set_sampler(self, sampler_str, **kwargs):\n", "        self.sampler = available_samplers[sampler_str](**kwargs)\n", "\n", "\n", "    def create_pipeline(self):\n", "        assert type(self.model) == DiffusersModel, \"Need to use DiffusersModel for Pipeline to work\"\n", "        return corresponding_pipelines[type(self.sampler)](self.model.m, self.sampler)\n", "\n", "\n", "    def set_sampling_params(self, **kwargs):\n", "        self.sampling_params = kwargs" ] }, { "cell_type": "markdown", "id": "d55f993f-ce2f-4bb3-b5e2-9334d39a7fd7", "metadata": {}, "source": [ "Since Diffusers pipelines expect a Diffusers model, we build our network from `diffusers.UNet2DModel`. Its forward pass returns a special dataclass output rather than a plain tensor, so we wrap the model and return the `.sample` tensor directly so fastai knows what to do with it. 
Hence this `DiffusersModel` class:" ] }, { "cell_type": "code", "execution_count": 16, "id": "7a203b0b-e0b2-4f18-ad76-bae21ae10531", "metadata": {}, "outputs": [], "source": [ "class DiffusersModel(nn.Module):\n", "    def __init__(self, **kwargs):\n", "        super().__init__()\n", "        self.m = diffusers.UNet2DModel(**kwargs)\n", "\n", "    def forward(self, x, t):\n", "        return self.m(x,t).sample" ] }, { "cell_type": "code", "execution_count": 17, "id": "2dfa525d-82fc-4849-a622-f1be578aa1ce", "metadata": {}, "outputs": [], "source": [ "model = DiffusersModel(sample_size=32)" ] }, { "cell_type": "markdown", "id": "ea218e6d-a37b-473f-a890-b97199d45f3e", "metadata": {}, "source": [ "Now we can create a `Learner`:" ] }, { "cell_type": "code", "execution_count": 18, "id": "ab1f3da5-0649-4730-b5b9-8069d084f400", "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, model, cbs=[Diffusers('DDIM', num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)], loss_func=nn.MSELoss())" ] }, { "cell_type": "markdown", "id": "2c6cd9a6-d2dd-48dc-aa8d-0f6b34f57309", "metadata": {}, "source": [ "And use fastai features such as the learning rate finder:" ] }, { "cell_type": "code", "execution_count": 19, "id": "b6a09457-ec77-4955-8281-8aad615bbf13", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "SuggestedLRs(valley=3.0199516913853586e-05)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "markdown", "id": "ebaf03da-78f7-4e69-bff3-52a2537588aa", "metadata": {}, "source": [ "And train with one-cycle LR schedule:" ] }, { "cell_type": "code", "execution_count": 20, "id": "9cbf4d53-d5a0-4a16-8f9e-88971957ac6d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  time
0      0.057953    0.045040    03:22
1      0.043119    0.033252    03:22
2      0.041341    0.036975    03:22
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, 3e-5)" ] }, { "cell_type": "code", "execution_count": 21, "id": "8773118d-9a07-4623-906c-6f87d89f5caa", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_loss() " ] }, { "cell_type": "markdown", "id": "f1529252-472e-417a-8425-884eed136e9a", "metadata": {}, "source": [ "Let's save our model:" ] }, { "cell_type": "code", "execution_count": 16, "id": "84f13379-d033-4c80-84ea-7747aa6e41ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('/home/tmabraham/.fastai/data/mnist_png/models/diffusers-mnist.pth')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.save('diffusers-mnist')" ] }, { "cell_type": "code", "execution_count": 17, "id": "797edaff-3ada-42db-8e5c-4dc1cdb523f4", "metadata": {}, "outputs": [], "source": [ "learn = learn.load('diffusers-mnist')" ] }, { "cell_type": "markdown", "id": "5e6fcd74-bf41-44f2-ad4e-d354442da9be", "metadata": {}, "source": [ "## Sample generation\n", "\n", "Thanks to the fastai API, sample generation is as simple as this:" ] }, { "cell_type": "code", "execution_count": 22, "id": "2e452066-fd98-4b0f-b01c-46af0f66d36f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "070af088aa8049b680164ff0e11e75f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "preds[0][6].show()" ] }, { "cell_type": "markdown", "id": "3377d9af-9103-4b56-9f8b-bbea12a6af2e", "metadata": {}, "source": [ "Awesome, we got a simple MNIST digit!\n", "\n", "We can also try out different samplers with different parameters:" ] }, { "cell_type": "code", "execution_count": 26, "id": "ab0c6a74-4969-4edf-8b20-71df2e814108", "metadata": {}, "outputs": [], "source": [ "learn.diffusers.set_sampler('DDPM')" ] }, { "cell_type": "code", "execution_count": 27, "id": "70243503-ffa7-4bcb-84f2-2fdf26411f17", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "41f5e2b5c22f4146a56d4c48458f06f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000 [00:00" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "preds[0][0].show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }