{ "cells": [ { "cell_type": "markdown", "id": "765216b4", "metadata": {}, "source": [ "# Exploration of stable diffusion spaces\n", "## (prompt embedding space and random latent space)\n", "by <a href=\"https://insana.net\">Giuseppe Insana</a>, December 2022\n" ] }, { "cell_type": "markdown", "id": "c78c82ca", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "markdown", "id": "c94575a9", "metadata": {}, "source": [ "# Contents\n", "- Code\n", " - [Helper sd functions](#SDHelper)\n", " - [Stable Diffusion Explorer functions](#SDX)\n", "- [Start pipeline](#StartPipe)\n", "- [Usage examples](#Examples)\n", " - [Simple txt2img](#Simple)\n", " - [Interpolation between text prompts](#Interpolation)\n", " - [Walking beyond the correct point produced by a prompt](#GoingBeyond)\n", " - [Circular walk through the diffusion noise space with two seeds](#CircleWalk)\n", " - [Spherical spiral walk through the diffusion noise space with three seeds](#SpiralWalk)\n", " - [Multiple samples for same prompt](#Sampling)\n", " - [Multiple variations from the same initial latent](#Variations)\n", " - [Mix of two variation latents](#VariationMix)\n", " - [Interpolation between two variation latents](#VariationWalk)\n", "- [What happens when you let your kids write prompts and explore the space](#KidsExperiments)" ] }, { "cell_type": "code", "execution_count": null, "id": "7b88b1dd", "metadata": {}, "outputs": [], "source": [ "# general\n", "import os\n", "import sys\n", "from math import pi, ceil, sqrt\n", "from tqdm.notebook import trange, tqdm\n", "\n", "# sd\n", "import torch\n", "from torch import Tensor\n", "import safetensors\n", "import transformers\n", "from diffusers import (\n", " StableDiffusionPipeline,\n", " EulerDiscreteScheduler,\n", " EulerAncestralDiscreteScheduler,\n", " DDIMScheduler,\n", " DPMSolverMultistepScheduler,\n", " PNDMScheduler,\n", " LMSDiscreteScheduler,\n", ")\n", "\n", "# image\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "from IPython.display import Image as IImage # for gifs\n", "from mpl_toolkits.axes_grid1 import ImageGrid # for image grid\n", "\n", "\n", "# setup\n", "%matplotlib inline\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"garbage_collection_threshold:0.6, max_split_size_mb:516\"\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "# torch.set_grad_enabled(False)\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": "97081886", "metadata": {}, "source": [ "## Helper functions" ] }, { "cell_type": "code", "execution_count": null, "id": "4f7c6498", "metadata": {}, "outputs": [], "source": [ "# image display helper functions\n", "def display_images(images, prompt=\"\", subtitles=[]):\n", " \"\"\"\n", " simple image display via matplotlib\n", " prompt can be specified, seed is taken from global variable\n", " subtitles can be a list with as many elements as the images, to specify different labels for the images\n", " \"\"\"\n", " if len(images) > 1:\n", " fig, axs = plt.subplots(1, max(2, len(images)), figsize=(12, 6))\n", " fig.tight_layout()\n", " plt.subplots_adjust(wspace=0, hspace=0)\n", " plt.margins(x=0, y=0)\n", " plt.axis(\"off\")\n", " for c, img in enumerate(images):\n", " axs[c].tick_params(length=0, labelbottom=False, labelleft=False)\n", " axs[c].imshow(img)\n", " axs[c].set_title(subtitles[c] if len(subtitles) else \"\")\n", " fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n", " else:\n", " if prompt:\n", " fig = plt.figure()\n", " 
fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n",
" plt.margins(x=0, y=0)\n",
" plt.axis(\"off\")\n",
" plt.imshow(images[0])\n",
" else:\n",
" display(images[0])\n",
"\n",
"\n",
"def display_images_grid(images, prompt=\"\", subtitles=[], grid_size=None, scale=2):\n",
" \"\"\"\n",
" simple image display in a grid via matplotlib\n",
" if grid_size is not specified, the nearest square grid of appropriate size will be used\n",
" \"\"\"\n",
" if not grid_size:\n",
" grid_size = ceil(sqrt(len(images)))\n",
"\n",
" fig = plt.figure(figsize=(grid_size * scale, grid_size * scale))\n",
" grid = ImageGrid(\n",
" fig,\n",
" nrows_ncols=(grid_size, grid_size), # creates grid of axes\n",
" axes_pad=0, # 0.1 # pad between axes in inch.\n",
" )\n",
" for ax, im in zip(grid, images):\n",
" ax.tick_params(length=0, labelbottom=False, labelleft=False)\n",
" ax.imshow(im)\n",
" if prompt:\n",
" fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n",
" plt.show()\n",
"\n",
"\n",
"def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):\n",
" \"\"\"\n",
" export a list of images as a gif, optionally with rubber band repetition\n",
" the gif will be both saved to file and displayed in notebook\n",
" \"\"\"\n",
" my_images = images.copy()\n",
" if rubber_band:\n",
" my_images += images[2:-1][::-1]\n",
" my_images[0].save(\n",
" filename,\n",
" save_all=True,\n",
" append_images=my_images[1:],\n",
" duration=1000 // frames_per_second,\n",
" loop=0,\n",
" )\n",
" display(IImage(filename))" ] }, { "cell_type": "markdown", "id": "2a5d5b18", "metadata": {}, "source": [ "# SDHelper\n", "## Diffusion helper functions\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "1c0c5801", "metadata": {}, "outputs": [], "source": [
"def torch_md_linspace(start: Tensor, stop: Tensor, num: int):\n",
" \"\"\"\n",
" Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.\n",
" Replicates the multi-dimensional behaviour of numpy.linspace for PyTorch tensors.\n",
"\n",
" e.g.:\n",
" start = torch.tensor([[0, 1], [2, 3]])\n",
" stop = torch.tensor([[10, 10], [10, 10]])\n",
" steps = 3\n",
" \"\"\"\n",
" # create a tensor of 'num' steps from 0 to 1\n",
" steps = torch.arange(num, dtype=torch.float16, device=start.device) / (num - 1)\n",
"\n",
" for i in range(start.ndim):\n",
" steps = steps.unsqueeze(-1)\n",
"\n",
" # the output starts at 'start' and increments until 'stop' in each dimension\n",
" out = start[None] + steps * (stop - start)[None]\n",
"\n",
" return out\n",
"\n",
"\n",
"# test:\n",
"# start = torch.tensor([[0, 1], [2, 3]])\n",
"# stop = torch.tensor([[10, 10], [10, 10]])\n",
"# steps = 3\n",
"# np.isclose(torch_md_linspace(start, stop, num=steps), np.linspace(start, stop, num=steps)).all() # True\n",
"\n",
"\n",
"def eprint(*myargs, **kwargs):\n",
" \"\"\"\n",
" print to stderr, useful for error messages and to not clobber stdout\n",
" \"\"\"\n",
" print(*myargs, file=sys.stderr, **kwargs)\n",
"\n",
"\n",
"def text_enc(prompts, maxlen=None, device=\"cuda\"):\n",
" \"\"\"\n",
" A function to take a textual prompt and convert it into embeddings\n",
" example: text_enc([\"A dog wearing a white hat\"])\n",
" \"\"\"\n",
" if maxlen is None:\n",
" maxlen = pipe.tokenizer.model_max_length\n",
" inp = pipe.tokenizer(\n",
" prompts,\n",
" padding=\"max_length\",\n",
" max_length=maxlen,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" ).input_ids.to(device)\n",
" return pipe.text_encoder(inp)[0].half()\n",
"\n",
"\n",
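"# usage sketch for text_enc (commented out, following the test snippet convention above);\n",
"# it assumes the global 'pipe' created in the 'Prepare pipeline' cell below is already loaded:\n",
"# emb = text_enc([\"A dog wearing a white hat\"])\n",
"# emb.shape  # -> torch.Size([1, 77, 1024]) for SD 2.1 (77 token positions, 1024-dim embeddings)\n",
"\n",
"\n",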
"def latents_to_pil(latents):\n", " \"\"\"\n", " Function to convert latents to images\n", " \"\"\"\n", " latents = (1 / 0.18215) * latents\n", " with torch.no_grad():\n", " image = pipe.vae.decode(latents).sample\n", " image = (image / 2 + 0.5).clamp(0, 1)\n", " image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n", " images = (image * 255).round().astype(\"uint8\")\n", " pil_images = [Image.fromarray(image) for image in images]\n", " return pil_images\n", "\n", "\n", "def sample_space(\n", " latents, emb, g, steps, save_int=False, return_int=False, device=\"cuda\"\n", "):\n", " \"\"\"\n", " return latent representation\n", " optionally save or return intermediate states\n", " \"\"\"\n", " if save_int and not os.path.exists(f\"./steps\"):\n", " os.mkdir(f\"./steps\")\n", "\n", " intermediates = []\n", "\n", " # Setting number of steps in scheduler\n", " scheduler.set_timesteps(steps)\n", "\n", " # Adding noise to the latents\n", " latents = latents.to(device).half() * scheduler.init_noise_sigma\n", "\n", " # Iterating through defined steps\n", " for i, ts in enumerate(tqdm(scheduler.timesteps, desc=\"iterations\", leave=False)):\n", " # We need to scale the i/p latents to match the variance\n", " inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)\n", "\n", " # Predicting noise residual using U-Net\n", " with torch.no_grad():\n", " u, t = pipe.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)\n", "\n", " # Performing Guidance\n", " pred = u + g * (t - u)\n", "\n", " # Conditioning the latents\n", " latents = scheduler.step(pred, ts, latents).prev_sample\n", "\n", " # Saving intermediate images\n", " if save_int or return_int:\n", " intermediate = latents_to_pil(latents)[0]\n", " if save_int:\n", " intermediate.save(f\"steps/{i:04}.jpeg\")\n", " if return_int:\n", " intermediates.append(intermediate)\n", " if return_int:\n", " return intermediates[0:-1]\n", " else:\n", " return latents_to_pil(latents)\n", "\n", "\n", "def save_images(images, path=\"images\"):\n", " \"\"\"\n", " save a list of images to a path, creating the directory if it does not exist\n", " \"\"\"\n", " if not os.path.exists(f\"./{path}\"):\n", " os.mkdir(f\"./{path}\")\n", " for index, image in enumerate(images):\n", " image.save(f\"./{path}/{index:04}.jpg\")\n", "\n", "\n", "def slerp(val, low, high):\n", " \"\"\"\n", " spherical interpolation\n", " from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3\n", " compatible with image variation used by stable-diffusion-webui\n", " \"\"\"\n", " low_norm = low / torch.norm(low, dim=1, keepdim=True)\n", " high_norm = high / torch.norm(high, dim=1, keepdim=True)\n", " dot = (low_norm * high_norm).sum(1)\n", "\n", " if dot.mean() > 0.9995:\n", " return low * val + high * (1 - val)\n", "\n", " omega = torch.acos(dot)\n", " so = torch.sin(omega)\n", " res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (\n", " torch.sin(val * omega) / so\n", " ).unsqueeze(1) * high\n", " return res" ] }, { "cell_type": "markdown", "id": "ad88743a", "metadata": {}, "source": [ "# SDX\n", "## Space exploration functions\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "f8ec19fe", "metadata": {}, "outputs": [], "source": [ "def _initial_latents(seed=None, width=768, height=768, skip_random=0, verbose=False):\n", " \"\"\"\n", " return initial latents given an optional seed, optionally skipping a series of them\n", " \"\"\"\n", " # Setting the seed\n", " if seed is not 
None:\n", " if verbose:\n", " eprint(f\"using random seed {seed}\")\n", " torch.manual_seed(seed)\n", " if skip_random:\n", " if verbose:\n", " eprint(f\"skipping {skip_random} random\")\n", " # skip a series of random for latents (we want Nth image in a series generated from an initial seed)\n", " _ = torch.randn(\n", " (skip_random, pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " initial_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " if verbose:\n", " eprint(\"initial latent, {}\".format(initial_latents.sum()))\n", "\n", " return initial_latents\n", "\n", "\n", "def _enclat2img(\n", " encodings=[],\n", " initial_latents=[],\n", " multiple_latents=[],\n", " g=7.5,\n", " steps=10,\n", " neg_prompt=None,\n", " device=device,\n", " verbose=False,\n", "):\n", " images = []\n", "\n", " # adding an unconditional prompt helps in the generation process\n", " if neg_prompt is None:\n", " uncond = text_enc([\"\"] * 1, encodings.shape[1], device=device)\n", " elif type(neg_prompt) != str:\n", " eprint(f\"ERROR: neg_prompt must be a string, not '{type(neg_prompt)}'\")\n", " return\n", " else:\n", " if verbose:\n", " eprint(f\"using negative prompt '{neg_prompt}'\")\n", " uncond = text_enc([neg_prompt] * 1, encodings.shape[1], device=device)\n", "\n", " for _, encoding in enumerate(tqdm(encodings, desc=\"prompt\", leave=False)):\n", " emb = torch.cat([uncond, encoding.reshape_as(uncond)])\n", " if len(multiple_latents):\n", " for _, latents in enumerate(\n", " tqdm(multiple_latents, desc=\"latent\", leave=False)\n", " ):\n", " images += sample_space(\n", " torch.unsqueeze(latents, dim=0), emb, g, steps, device=device\n", " )\n", " else:\n", " images += sample_space(\n", " torch.unsqueeze(initial_latents, dim=0), emb, g, steps, device=device\n", " )\n", "\n", " return images\n", "\n", "\n", "def prompt2img(\n", " prompts=[],\n", " neg_prompt=None,\n", " n_samples=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Return a list of images equal to the number of prompts (optionally multiplied by n_samples)\n", " prompt: list of strings or a single string\n", " n_samples: how many samples to produce for each given prompt\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", "\n", " if type(prompts) == str:\n", " prompts = [prompts]\n", " if n_samples < 1:\n", " eprint(\"ERROR: n_samples must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " if n_samples > 1:\n", " sample_latents = [\n", " torch.unsqueeze(initial_latents, dim=0)\n", " ] # the first one generated from the seed\n", " for _ in range(n_samples - 1): # add as many more as requested\n", " sample_latent = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if verbose:\n", " eprint(\"adding latent, {}\".format(sample_latent.sum()))\n", " sample_latents.append(torch.unsqueeze(sample_latent, dim=0))\n", " multiple_latents = 
torch.cat(sample_latents)\n", "\n", " encodings = text_enc(prompts, device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def beyond_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " walk_steps=1,\n", " walk_stepsize=0.02,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Walk forward in prompt embedding space by walk_stepsize to produce walk_steps + 1 images,\n", " each a step (of optionally specified size) forward from the previous\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " walk_steps: number of steps to walk forward\n", " walk_stepsize: how much to walk forward in prompt embedding space at each step\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " encoding = text_enc([prompt], device=device)\n", "\n", " delta = torch.ones_like(encoding) * walk_stepsize\n", " new_encodings = []\n", " for step_index in range(0, walk_steps + 1):\n", " new_encodings.append(encoding)\n", " encoding = encoding + delta # nudge prompt embedding\n", " if verbose:\n", " print(\"nudged prompt by {}\".format(walk_stepsize * step_index))\n", " encodings = torch.cat(new_encodings)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def interpolate_prompts(\n", " prompts=[],\n", " neg_prompt=None,\n", " interpolate_steps=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Given two prompts, interpolate among the two embeddings and produce a number of images equal to interpolate_steps\n", " Return a list of images exploring the embedding space between first and second prompt\n", " prompt: list of strings or a single string\n", " interpolate_steps: number of images to produce between the one from the first and the one from the second prompt\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " Extension: interpolate four prompts and create square grid\n", " \"\"\"\n", "\n", " if type(prompts) != list or len(prompts) != 2:\n", " eprint(\"ERROR: you need to pass a list of 2 prompts!\")\n", " return\n", " if interpolate_steps < 1:\n", " eprint(\"ERROR: interpolate steps 
must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " encodings = text_enc(prompts, device=device)\n", " encodings = torch_md_linspace(encodings[0], encodings[1], interpolate_steps + 2)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def revolve_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " walk_type=\"circle\",\n", " walk_steps=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " seed2=None,\n", " seed3=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Walk around a prompt in prompt in latent space to produce walk_steps images\n", " For a \"circle\" walk, it uses two latents which can be determined specifying seed and seed2\n", " In case of \"spiral\", the walk will be along a spherical spiral using three latents, optionally determined by seed, seed2 and seed3\n", " Note: the image that would singularly generated from seed would appear as the first one in the set\n", " and the one from seed2 would be found at around 1/4th of the total image count\n", " in case of spiral walk, the image from seed1 is the first one,\n", " the image from seed2 is found at around 1/4th of the total image count\n", " and the image from seed3 (approximate) would be found at around 1/3rd of the total image count\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " walk_type: \"circle\" (default) or \"spiral\"\n", " walk_steps: how many steps to take in total (equals number of returned images)\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " seed2: optional seed to determine circular walk\n", " seed3: optional seed to further determine spiral walk\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " # initialize alternate latents\n", " if seed2 is not None:\n", " torch.manual_seed(seed2)\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"spiral\" and seed3 is not None:\n", " torch.manual_seed(seed3)\n", " alt_variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"circle\" and seed3 is not None:\n", " eprint(f\"NOTICE: seed3 not used for 'circle' walk_type'\")\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " # if no alternate seed were specified, we'll use the next random after the one used for initial latents\n", " if seed2 is None:\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"spiral\" and seed3 is None:\n", " alt_variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", 
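"    # the circle/spiral walks below combine independent gaussian latents using\n",
"    # cos/sin (or spherical) coefficients; since the squared coefficients sum to 1,\n",
"    # each combined latent is again ~N(0, I), so every point of the walk stays in\n",
"    # the noise distribution the U-Net expects, e.g.:\n",
"    #   x, y = torch.randn(100000), torch.randn(100000)\n",
"    #   (0.6 * x + 0.8 * y).std()  # ~1.0, because 0.6**2 + 0.8**2 == 1\n",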
"\n", " encodings = text_enc([prompt], device=device)\n", "\n", " if walk_type == \"circle\": # around a prompt and two random latents in a circle\n", " ##circular roundwalk:\n", " # stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1] * pi\n", " # walk_scale_x = torch.cos(stepspace)\n", " # walk_scale_y = torch.sin(stepspace)\n", " # (accelerates and decelerates, not very smooth in interpolation)\n", "\n", " # spreadout circular walk:\n", " spread_factor = 30 # the lower this, the more the points will be pushed away from 0, 90, 180, 270 bearings and concentrated towards 45, 135..\n", " stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]\n", " stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor\n", " stepspace *= pi\n", " walk_scale_x = torch.cos(stepspace)\n", " walk_scale_y = torch.sin(stepspace)\n", "\n", " noise_x = torch.tensordot(walk_scale_x, initial_latents, dims=0)\n", " noise_y = torch.tensordot(walk_scale_y, variation_latents, dims=0)\n", " multiple_latents = torch.add(noise_x, noise_y)\n", " elif walk_type == \"spiral\": # spherical spiral walk with three random latents\n", " c = 2 # use 4 for double the amount of turns in the spherical spiral walk\n", "\n", " # circular spherical spiral walk:\n", " # theta = torch.linspace(1, 2, walk_steps + 1)[0:-1] * pi\n", " # theta2 = torch.linspace(1, 0, walk_steps + 1)[0:-1] * pi\n", " # walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)\n", " # walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)\n", " # walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])\n", " # walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)\n", " # walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)\n", " # walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])\n", " # walk_scale_z = torch.cos(torch.linspace(0, 2 * pi, walk_steps * 2 + 1)[0:-1])\n", "\n", " # spread spherical spiral walk:\n", " spread_factor = 30\n", " theta = torch.linspace(1, 2, walk_steps // 2 + 1)[0:-1]\n", " theta -= torch.sin((theta + 0.25) * pi * 4) / spread_factor\n", " theta *= pi\n", " theta2 = torch.linspace(1, 0, walk_steps // 2 + 1)[0:-1]\n", " theta2 -= torch.sin((theta2 + 0.25) * pi * 4) / spread_factor\n", " theta2 *= pi\n", " walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)\n", " walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)\n", " walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])\n", " walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)\n", " walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)\n", " walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])\n", " stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]\n", " stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor\n", " stepspace *= pi\n", " walk_scale_z = torch.cos(stepspace)\n", "\n", " noise_z = torch.tensordot(walk_scale_z, initial_latents, dims=0)\n", " noise_x = torch.tensordot(walk_scale_x, variation_latents, dims=0)\n", " noise_y = torch.tensordot(walk_scale_y, alt_variation_latents, dims=0)\n", "\n", " multiple_latents = torch.add(torch.add(noise_x, noise_y), noise_z)\n", " else:\n", " eprint(f\"ERROR: unknown walk_type '{walk_type}'\")\n", " return\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def prompt_variations(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " variations=1,\n", " 
variation_strength=0.1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Return an image followed by a list of variations, each one of specified variation_strength from the first\n", " prompt: list of strings or a single string\n", " variations: how many variations to return after the normal image\n", " variation_strength: how much to mix the initial latent and the variant ones (hence how different from initial image)\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", " if variations < 1:\n", " eprint(\"ERROR: number of variations must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " sample_latents = [\n", " torch.unsqueeze(initial_latents, dim=0)\n", " ] # the first one generated from the seed\n", " for _ in range(variations): # add as many as requested\n", " sample_latent = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " sample_latent = slerp(variation_strength, initial_latents, sample_latent)\n", " if verbose:\n", " eprint(\"adding variation latent: {}\".format(sample_latent.sum()))\n", " sample_latents.append(torch.unsqueeze(sample_latent, dim=0))\n", " multiple_latents = torch.cat(sample_latents)\n", "\n", " encodings = text_enc([prompt], device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def variate_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " variation_strength=0,\n", " var_steps=0,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " seed2=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " create an image from the mix (in desired amount) of two random initial latents (optionally specified by seed and seed2)\n", " alternatively, if var_steps is specified, it will interpolate between the two random latents, returning var_steps images\n", " (i.e. 
like trying a linearly increasing set of values of variation_strength, from 0 to 1)\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " variation_strength: how much to mix the initial latent and the variant one\n", " var_steps: how many steps to go from an initial latent and a variation latent\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " seed2: optional seed to determine circular walk\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " if var_steps > 0 and variation_strength > 0:\n", " eprint(\"NOTICE: variation_strength will be ignored when var_steps specified\")\n", "\n", " if var_steps < 0 or variation_strength < 0:\n", " eprint(\"ERROR: do not use negative values\")\n", " return\n", "\n", " if var_steps <= 0 and variation_strength <= 0:\n", " eprint(\"ERROR: nothing to do. specify either var_steps or variation_strength\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", "\n", " # initialize variation latent\n", " if seed2 is not None:\n", " torch.manual_seed(seed2)\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " if var_steps > 0: # gradually interpolate between two random latents\n", " var_latents = []\n", " stepspace = torch.linspace(0, 1, var_steps) # include last point\n", " for stepvalue in stepspace:\n", " if verbose:\n", " eprint(\"generating variation at {}\".format(stepvalue)) # debug\n", " var_latents.append(\n", " torch.unsqueeze(\n", " slerp(stepvalue, initial_latents, variation_latents), dim=0\n", " )\n", " )\n", " multiple_latents = torch.cat(var_latents)\n", " else:\n", " multiple_latents = []\n", "\n", " # if no alternate seed were specified, we'll use the next random after the one used for initial latents\n", " if seed2 is None:\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " initial_latents = slerp(variation_strength, initial_latents, variation_latents)\n", "\n", " encodings = text_enc([prompt], device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )" ] }, { "cell_type": "markdown", "id": "d461f9d1", "metadata": {}, "source": [ "# StartPipe\n", "## Choose Model and Scheduler\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "d6dd30f8", "metadata": {}, "outputs": [], "source": [ "# available:\n", "# DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DDPMScheduler\n", "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n", "scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder=\"scheduler\")" ] }, { "cell_type": "markdown", "id": "d5021128", "metadata": {}, "source": [ "## Prepare pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "fa1bfa2a", "metadata": {}, "outputs": [], "source": [ 
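"# torch_dtype=torch.float16 loads the UNet, VAE and text encoder in half precision,\n",
"# matching the .half() casts used by the helper functions above.\n",
"# Optional (not used in this notebook): on GPUs with limited VRAM,\n",
"# pipe.enable_attention_slicing() can be called after pipe.to(device)\n",
"# to reduce peak memory at a small speed cost.\n",
"\n",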
"pipe = StableDiffusionPipeline.from_pretrained(\n", " model_id, scheduler=scheduler, torch_dtype=torch.float16\n", ")\n", "pipe.safety_checker = None\n", "pipe.requires_safety_checker = False\n", "pipe = pipe.to(device)" ] }, { "cell_type": "markdown", "id": "7d09a4c8", "metadata": {}, "source": [ "# Examples" ] }, { "cell_type": "markdown", "id": "6a0ac6e5", "metadata": {}, "source": [ "# Simple\n", "## Simple txt2img\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "36af43bc", "metadata": {}, "outputs": [], "source": [ "prompt = \"hyper detailed digital painting of scenery, shibuya tokyo, post-apocalypse, ruins, rust, sky, skyscraper, abandoned, blue sky, broken window, building, cloud, crane machine, outdoors, overgrown, pillar, sunset\"\n", "seed = 24509723452345\n", "prompt2img(prompt, seed=seed)[0]" ] }, { "cell_type": "markdown", "id": "8e42d7f3", "metadata": {}, "source": [ "<img src=\"ruins.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "791b2da4", "metadata": {}, "outputs": [], "source": [ "seed = 64\n", "prompt = \"photo of a tiger demon\"\n", "images = prompt2img(prompt, width=512, height=768, g=7.5, steps=10, seed=seed)\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "c3d6e47d", "metadata": {}, "source": [ "<img src=\"tigerdemon.jpg\" />" ] }, { "cell_type": "markdown", "id": "6aa39f12", "metadata": {}, "source": [ "# With a negative prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "e2859125", "metadata": {}, "outputs": [], "source": [ "# Testing negative prompts\n", "images = [None, None]\n", "seed = 66\n", "prompt1 = \"photo of a castle in the middle of a forest with trees and bushes, detailed vegetation\"\n", "prompt2 = \"green, leaves, summer, spring\"\n", "images[0] = prompt2img(prompt1, neg_prompt=None, seed=seed)[0]\n", "images[1] = prompt2img(prompt1, neg_prompt=prompt2, seed=seed)[0]\n", "# side by side comparison\n", "display_images(images, prompt=prompt1, subtitles=[\"\", \"- \" + prompt2])" ] }, { "cell_type": "markdown", "id": "56a1ff8c", "metadata": {}, "source": [ "<img src=\"castle_negprompt.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "22658d5d", "metadata": {}, "outputs": [], "source": [ "# remove the prompt from itself for unexpected results:\n", "prompt = \"watercolour of a tiger\"\n", "prompt2img(prompt, neg_prompt=prompt, seed=1023458422345243)[0]" ] }, { "cell_type": "markdown", "id": "5aa7f2b8", "metadata": {}, "source": [ "<img src=\"tiger_no_tiger.jpg\" />" ] }, { "cell_type": "markdown", "id": "c8168e3d", "metadata": {}, "source": [ "## Multiple prompts can be specified" ] }, { "cell_type": "code", "execution_count": null, "id": "12f895b9", "metadata": {}, "outputs": [], "source": [ "images = prompt2img([\"cute puppy\", \"cute kitten\"], width=512, height=512, seed=123)\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "fc878e2a", "metadata": {}, "source": [ "<img src=\"puppy_n_kitten.jpg\" />" ] }, { "cell_type": "markdown", "id": "375c3dce", "metadata": {}, "source": [ "# Interpolation\n", "[back to ToC](#Contents)\n", "## Interpolation between text prompts\n", "(and what lies between two different points in prompt encoding space)" ] }, { "cell_type": "code", "execution_count": null, "id": "a2efb5c8", "metadata": {}, "outputs": [], "source": [ "prompt1 = \"a photo of a boy running on the beach\"\n", "prompt2 = \"a photo of a cadillac on a highway\"\n", "seed = 749109862\n", "images = interpolate_prompts(\n", 
" [prompt1, prompt2], width=512, height=512, seed=seed, interpolate_steps=5\n", ")\n", "display_images(images, prompt1 + \"<->\" + prompt2)" ] }, { "cell_type": "markdown", "id": "b6b2c111", "metadata": {}, "source": [ "<img src=\"boy_to_car.jpg\" />" ] }, { "cell_type": "markdown", "id": "3656ec4f", "metadata": {}, "source": [ "# GoingBeyond\n", "[back to ToC](#Contents)\n", "## Forward walk in embedding latent space from a prompt\n", "(what lies beyond a point in encoding space)" ] }, { "cell_type": "code", "execution_count": null, "id": "bc7ab142", "metadata": {}, "outputs": [], "source": [ "prompt = \"girl with green hair eating rice noodles\"\n", "images = beyond_prompt(\n", " prompt,\n", " neg_prompt=\"malformed\",\n", " width=512,\n", " height=512,\n", " seed=4093245,\n", " walk_steps=8,\n", " walk_stepsize=0.015,\n", ")\n", "display_images_grid(images, prompt)" ] }, { "cell_type": "markdown", "id": "1f4c40cc", "metadata": {}, "source": [ "<img src=\"green_haired_girl.jpg\" />" ] }, { "cell_type": "markdown", "id": "326ebe57", "metadata": {}, "source": [ "# CircleWalk\n", "[back to ToC](#Contents)\n", "## Circular walks around a prompt in noise latent space\n", "(optionally guided using two seeds)" ] }, { "cell_type": "code", "execution_count": null, "id": "fc4df5a7", "metadata": {}, "outputs": [], "source": [ "prompt = \"An oil masterpiece painting of horses in a field next to a farm in Normandy\"\n", "seed = 132432456352\n", "seed2 = 42\n", "walk_steps = 48\n", "images = revolve_prompt(prompt, walk_steps=walk_steps, seed=seed, seed2=seed2)\n", "display_images_grid(images, prompt)\n", "\n", "# save_images(images, path=f\"horses_r{walk_steps}\")\n", "export_as_gif(\n", " f\"horses_r{walk_steps}.gif\", images, frames_per_second=2, rubber_band=False\n", ")" ] }, { "cell_type": "markdown", "id": "b2b6057a", "metadata": {}, "source": [ "<img src=\"horses_r48.jpg\" />" ] }, { "cell_type": "markdown", "id": "b46267ea", "metadata": {}, "source": [ "<video width=\"768\" height=\"768\" src=\"horses_r48.mp4\" controls />" ] }, { "cell_type": "markdown", "id": "cb02cf58", "metadata": {}, "source": [ "# SpiralWalk\n", "[back to ToC](#Contents)\n", "## Spherical spiral walks around a prompt in noise latent space\n", "(optionally guided using three seeds)" ] }, { "cell_type": "code", "execution_count": null, "id": "17bcb583", "metadata": {}, "outputs": [], "source": [ "prompt = \"hires photo of shark, underwater scenery with tropical fishes and coral sea floor, caustics\"\n", "seed = 100\n", "seed = 132432456352\n", "seed2 = 42\n", "seed3 = 897234234\n", "walk_steps = 24\n", "images = revolve_prompt(\n", " prompt,\n", " walk_type=\"spiral\",\n", " walk_steps=walk_steps,\n", " width=512,\n", " height=512,\n", " seed=seed,\n", " seed2=seed2,\n", " seed3=seed3,\n", ")\n", "display_images_grid(images, prompt)" ] }, { "cell_type": "markdown", "id": "2a01890a", "metadata": {}, "source": [ "<img src=\"shark_coral_reef_s24.jpg\" />" ] }, { "cell_type": "markdown", "id": "bf90f11e", "metadata": {}, "source": [ "# Sampling\n", "[back to ToC](#Contents)\n", "## Multiple samples for the same prompt\n", "- sample several different latents for the same prompt\n", "- use skip_random to then directly go to one of the generated images (not working for EulerA)" ] }, { "cell_type": "code", "execution_count": null, "id": "7ac23a10", "metadata": {}, "outputs": [], "source": [ "seed = 10234584620131114\n", "prompt = \"a watercolour painting of Cambridge Jesus Green\"\n", "images = prompt2img(\n", " prompt, width=512, 
height=512, seed=seed, n_samples=9\n", ")\n", "display_images_grid(images)\n", "\n", "# recreate the 5th variation:\n", "images = prompt2img(\n", " prompt, width=512, height=512, seed=seed, skip_random=4\n", ")\n", "images[0]" ] }, { "cell_type": "markdown", "id": "398cd968", "metadata": {}, "source": [ "<img src=\"cambridge_watercolours.jpg\" />\n", "<img src=\"cambridge_watercolours_var5.jpg\" />" ] }, { "cell_type": "markdown", "id": "d9c4d939", "metadata": {}, "source": [ "# Variations\n", "[back to ToC](#Contents)\n", "## Multiple variations from the same initial latent" ] }, { "cell_type": "code", "execution_count": null, "id": "fc6b8cea", "metadata": {}, "outputs": [], "source": [ "prompt = \"An oil painting of horses in a field next to a farm in Normandy\"\n", "seed = 1022134\n", "images = prompt_variations(prompt, variations=8, variation_strength=0.1, seed=seed)\n", "display_images_grid(images)" ] }, { "cell_type": "markdown", "id": "1964b0ab", "metadata": {}, "source": [ "<img src=\"horses_variations.jpg\" />" ] }, { "cell_type": "markdown", "id": "693251f5", "metadata": {}, "source": [ "# VariationMix\n", "[back to ToC](#Contents)\n", "## Variation using second seed and specified strength" ] }, { "cell_type": "code", "execution_count": null, "id": "10aca36d", "metadata": {}, "outputs": [], "source": [ "seed = 1023458422345243\n", "seed2 = 35634563\n", "prompt = \"movie cover of Schwarzenegger as the Terminator riding a Vespa\"\n", "images = prompt2img(prompts=prompt, neg_prompt=None, width=512, height=512, seed=seed)\n", "images += variate_prompt(\n", " prompt=prompt,\n", " width=512,\n", " height=512,\n", " seed=seed,\n", " seed2=seed2,\n", " variation_strength=0.25,\n", ")\n", "display_images(images, prompt=prompt, subtitles=[\"\", f\"var{seed2} 25%\"])" ] }, { "cell_type": "markdown", "id": "1780b649", "metadata": {}, "source": [ "<img src=\"vespa_terminator.jpg\" />" ] }, { "cell_type": "markdown", "id": "cc2fdfaf", "metadata": {}, "source": [ "# VariationWalk\n", "[back to ToC](#Contents)\n", "## Gradual interpolation between two random latents" ] }, { "cell_type": "code", "execution_count": null, "id": "d33ccc67", "metadata": {}, "outputs": [], "source": [ "seed = 1023458422345243\n", "seed2 = 35634563\n", "prompt = \"movie cover of Schwarzenegger as the Terminator riding a Vespa\"\n", "images = variate_prompt(\n", " prompt=prompt, width=512, height=512, seed=seed, seed2=seed2, var_steps=6\n", ")\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "3b9ff52e", "metadata": {}, "source": [ "<img src=\"vespa_terminator_varwalk.jpg\" />" ] }, { "cell_type": "markdown", "id": "de43fd36", "metadata": {}, "source": [ "# KidsExperiments\n", "[back to ToC](#Contents)\n", "## And this is what happens when you let your kids write sd prompts 😉:" ] }, { "cell_type": "code", "execution_count": null, "id": "d8399147", "metadata": {}, "outputs": [], "source": [ "seed = 12345\n", "images = interpolate_prompts(\n", " [\"cow cat pawlephant\", \"muleskin beetledog\"],\n", " seed=seed,\n", " interpolate_steps=1,\n", " g=7.5,\n", " steps=10,\n", ")\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "5cd7a90d", "metadata": {}, "source": [ "<img src=\"cowcat_to_beetledog.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "2b800f52", "metadata": {}, "outputs": [], "source": [ "prompt2img(\"a cow covered in oreo cookies\", seed=3534534, g=7.5, steps=10)[0]" ] }, { "cell_type": "markdown", "id": "976c82ad", "metadata": {}, "source": [ "<img 
src=\"cow_covered_in_oreos.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "9e502969", "metadata": {}, "outputs": [], "source": [ "seed = 3534534\n", "seed2 = 3534533\n", "walk_steps = 360\n", "images = variate_prompt(\n", " \"a cow covered in oreo cookies\", seed=seed, seed2=seed2, var_steps=walk_steps\n", ")\n", "save_images(images, path=\"cow_covered_in_oreos\")" ] }, { "cell_type": "markdown", "id": "4df16b49", "metadata": {}, "source": [ "## ...which when converted to mp4 becomes:" ] }, { "cell_type": "markdown", "id": "629fc5bc", "metadata": {}, "source": [ "<video width=\"768\" height=\"768\" src=\"cow_covered_in_oreos.mp4\" controls />" ] }, { "cell_type": "code", "execution_count": null, "id": "e4d3651a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "kernelspec": { "display_name": "tensorflow_env", "language": "python", "name": "tensorflow_env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }