{ "cells": [ { "cell_type": "markdown", "id": "765216b4", "metadata": {}, "source": [ "# Exploration of stable diffusion spaces\n", "## (prompt embedding space and random latent space)\n", "by <a href=\"https://insana.net\">Giuseppe Insana</a>, December 2022\n" ] }, { "cell_type": "markdown", "id": "c78c82ca", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "markdown", "id": "c94575a9", "metadata": {}, "source": [ "# Contents\n", "- Code\n", " - [Helper sd functions](#SDHelper)\n", " - [Stable Diffusion Explorer functions](#SDX)\n", "- [Start pipeline](#StartPipe)\n", "- [Usage examples](#Examples)\n", " - [Simple txt2img](#Simple)\n", " - [Interpolation between text prompts](#Interpolation)\n", " - [Walking beyond the correct point produced by a prompt](#GoingBeyond)\n", " - [Circular walk through the diffusion noise space with two seeds](#CircleWalk)\n", " - [Spherical spiral walk through the diffusion noise space with three seeds](#SpiralWalk)\n", " - [Multiple samples for same prompt](#Sampling)\n", " - [Multiple variations from the same initial latent](#Variations)\n", " - [Mix of two variation latents](#VariationMix)\n", " - [Interpolation between two variation latents](#VariationWalk)\n", "- [What happens when you let your kids write prompts and explore the space](#KidsExperiments)" ] }, { "cell_type": "code", "execution_count": null, "id": "7b88b1dd", "metadata": {}, "outputs": [], "source": [ "# general\n", "import os\n", "import sys\n", "from math import pi, ceil, sqrt\n", "from tqdm.notebook import trange, tqdm\n", "\n", "# sd\n", "import torch\n", "from torch import Tensor\n", "import safetensors\n", "import transformers\n", "from diffusers import (\n", " StableDiffusionPipeline,\n", " EulerDiscreteScheduler,\n", " EulerAncestralDiscreteScheduler,\n", " DDIMScheduler,\n", " DPMSolverMultistepScheduler,\n", " PNDMScheduler,\n", " LMSDiscreteScheduler,\n", ")\n", "\n", "# image\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "from IPython.display import Image as IImage # for gifs\n", "from mpl_toolkits.axes_grid1 import ImageGrid # for image grid\n", "\n", "\n", "# setup\n", "%matplotlib inline\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"garbage_collection_threshold:0.6, max_split_size_mb:516\"\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "# torch.set_grad_enabled(False)\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": "97081886", "metadata": {}, "source": [ "## Helper functions" ] }, { "cell_type": "code", "execution_count": null, "id": "4f7c6498", "metadata": {}, "outputs": [], "source": [ "# image display helper functions\n", "def display_images(images, prompt=\"\", subtitles=[]):\n", " \"\"\"\n", " simple image display via matplotlib\n", " prompt can be specified, seed is taken from global variable\n", " subtitles can be a list with as many elements as the images, to specify different labels for the images\n", " \"\"\"\n", " if len(images) > 1:\n", " fig, axs = plt.subplots(1, max(2, len(images)), figsize=(12, 6))\n", " fig.tight_layout()\n", " plt.subplots_adjust(wspace=0, hspace=0)\n", " plt.margins(x=0, y=0)\n", " plt.axis(\"off\")\n", " for c, img in enumerate(images):\n", " axs[c].tick_params(length=0, labelbottom=False, labelleft=False)\n", " axs[c].imshow(img)\n", " axs[c].set_title(subtitles[c] if len(subtitles) else \"\")\n", " fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n", " else:\n", " if prompt:\n", " fig = plt.figure()\n", " 
fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n",
" plt.margins(x=0, y=0)\n",
" plt.axis(\"off\")\n",
" plt.imshow(images[0])\n",
" else:\n",
" display(images[0])\n",
"\n",
"\n",
"def display_images_grid(images, prompt=\"\", subtitles=[], grid_size=None, scale=2):\n",
" \"\"\"\n",
" simple image display in a grid via matplotlib\n",
" if grid_size is not specified, the nearest square grid of appropriate size will be used\n",
" \"\"\"\n",
" if not grid_size:\n",
" grid_size = ceil(sqrt(len(images)))\n",
"\n",
" fig = plt.figure(figsize=(grid_size * scale, grid_size * scale))\n",
" grid = ImageGrid(\n",
" fig,\n",
" nrows_ncols=(grid_size, grid_size), # creates grid of axes\n",
" axes_pad=0, # 0.1 # pad between axes in inch.\n",
" )\n",
" for ax, im in zip(grid, images):\n",
" ax.tick_params(length=0, labelbottom=False, labelleft=False)\n",
" ax.imshow(im)\n",
" if prompt:\n",
" fig.suptitle(\"{}\\n{}\".format(prompt, seed))\n",
" plt.show()\n",
"\n",
"\n",
"def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):\n",
" \"\"\"\n",
" export a list of images as a gif, optionally with rubber band repetition\n",
" the gif will be both saved to file and displayed in notebook\n",
" \"\"\"\n",
" my_images = images.copy()\n",
" if rubber_band:\n",
" my_images += images[2:-1][::-1]\n",
" my_images[0].save(\n",
" filename,\n",
" save_all=True,\n",
" append_images=my_images[1:],\n",
" duration=1000 // frames_per_second,\n",
" loop=0,\n",
" )\n",
" display(IImage(filename))" ] }, { "cell_type": "markdown", "id": "2a5d5b18", "metadata": {}, "source": [ "# SDHelper\n", "## Diffusion helper functions\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "1c0c5801", "metadata": {}, "outputs": [], "source": [
"def torch_md_linspace(start: Tensor, stop: Tensor, num: int):\n",
" \"\"\"\n",
" Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.\n",
" Replicates the multi-dimensional behaviour of numpy.linspace for PyTorch tensors.\n",
"\n",
" e.g.:\n",
" start = torch.tensor([[0, 1], [2, 3]])\n",
" stop = torch.tensor([[10, 10], [10, 10]])\n",
" steps = 3\n",
" \"\"\"\n",
" # create a tensor of 'num' steps from 0 to 1\n",
" steps = torch.arange(num, dtype=torch.float16, device=start.device) / (num - 1)\n",
"\n",
" for i in range(start.ndim):\n",
" steps = steps.unsqueeze(-1)\n",
"\n",
" # the output starts at 'start' and increments until 'stop' in each dimension\n",
" out = start[None] + steps * (stop - start)[None]\n",
"\n",
" return out\n",
"\n",
"\n",
"# test:\n",
"# start = torch.tensor([[0, 1], [2, 3]])\n",
"# stop = torch.tensor([[10, 10], [10, 10]])\n",
"# steps = 3\n",
"# np.isclose(torch_md_linspace(start, stop, num=steps), np.linspace(start, stop, num=steps)).all() # True\n",
"\n",
"\n",
"def eprint(*myargs, **kwargs):\n",
" \"\"\"\n",
" print to stderr, useful for error messages and to not clobber stdout\n",
" \"\"\"\n",
" print(*myargs, file=sys.stderr, **kwargs)\n",
"\n",
"\n",
"def text_enc(prompts, maxlen=None, device=\"cuda\"):\n",
" \"\"\"\n",
" A function to take a textual prompt and convert it into embeddings\n",
" example: text_enc([\"A dog wearing a white hat\"])\n",
" \"\"\"\n",
" if maxlen is None:\n",
" maxlen = pipe.tokenizer.model_max_length\n",
" inp = pipe.tokenizer(\n",
" prompts,\n",
" padding=\"max_length\",\n",
" max_length=maxlen,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" ).input_ids.to(device)\n",
" return pipe.text_encoder(inp)[0].half()\n",
"\n",
"\n",
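"# usage sketch for text_enc (commented out, following the test snippet convention above);\n",
"# it assumes the global 'pipe' created in the 'Prepare pipeline' cell below is already loaded:\n",
"# emb = text_enc([\"A dog wearing a white hat\"])\n",
"# emb.shape  # -> torch.Size([1, 77, 1024]) for SD 2.1 (77 token positions, 1024-dim embeddings)\n",
"\n",
"\n",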
"def latents_to_pil(latents):\n", " \"\"\"\n", " Function to convert latents to images\n", " \"\"\"\n", " latents = (1 / 0.18215) * latents\n", " with torch.no_grad():\n", " image = pipe.vae.decode(latents).sample\n", " image = (image / 2 + 0.5).clamp(0, 1)\n", " image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n", " images = (image * 255).round().astype(\"uint8\")\n", " pil_images = [Image.fromarray(image) for image in images]\n", " return pil_images\n", "\n", "\n", "def sample_space(\n", " latents, emb, g, steps, save_int=False, return_int=False, device=\"cuda\"\n", "):\n", " \"\"\"\n", " return latent representation\n", " optionally save or return intermediate states\n", " \"\"\"\n", " if save_int and not os.path.exists(f\"./steps\"):\n", " os.mkdir(f\"./steps\")\n", "\n", " intermediates = []\n", "\n", " # Setting number of steps in scheduler\n", " scheduler.set_timesteps(steps)\n", "\n", " # Adding noise to the latents\n", " latents = latents.to(device).half() * scheduler.init_noise_sigma\n", "\n", " # Iterating through defined steps\n", " for i, ts in enumerate(tqdm(scheduler.timesteps, desc=\"iterations\", leave=False)):\n", " # We need to scale the i/p latents to match the variance\n", " inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)\n", "\n", " # Predicting noise residual using U-Net\n", " with torch.no_grad():\n", " u, t = pipe.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)\n", "\n", " # Performing Guidance\n", " pred = u + g * (t - u)\n", "\n", " # Conditioning the latents\n", " latents = scheduler.step(pred, ts, latents).prev_sample\n", "\n", " # Saving intermediate images\n", " if save_int or return_int:\n", " intermediate = latents_to_pil(latents)[0]\n", " if save_int:\n", " intermediate.save(f\"steps/{i:04}.jpeg\")\n", " if return_int:\n", " intermediates.append(intermediate)\n", " if return_int:\n", " return intermediates[0:-1]\n", " else:\n", " return latents_to_pil(latents)\n", "\n", "\n", "def save_images(images, path=\"images\"):\n", " \"\"\"\n", " save a list of images to a path, creating the directory if it does not exist\n", " \"\"\"\n", " if not os.path.exists(f\"./{path}\"):\n", " os.mkdir(f\"./{path}\")\n", " for index, image in enumerate(images):\n", " image.save(f\"./{path}/{index:04}.jpg\")\n", "\n", "\n", "def slerp(val, low, high):\n", " \"\"\"\n", " spherical interpolation\n", " from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3\n", " compatible with image variation used by stable-diffusion-webui\n", " \"\"\"\n", " low_norm = low / torch.norm(low, dim=1, keepdim=True)\n", " high_norm = high / torch.norm(high, dim=1, keepdim=True)\n", " dot = (low_norm * high_norm).sum(1)\n", "\n", " if dot.mean() > 0.9995:\n", " return low * val + high * (1 - val)\n", "\n", " omega = torch.acos(dot)\n", " so = torch.sin(omega)\n", " res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (\n", " torch.sin(val * omega) / so\n", " ).unsqueeze(1) * high\n", " return res" ] }, { "cell_type": "markdown", "id": "ad88743a", "metadata": {}, "source": [ "# SDX\n", "## Space exploration functions\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "f8ec19fe", "metadata": {}, "outputs": [], "source": [ "def _initial_latents(seed=None, width=768, height=768, skip_random=0, verbose=False):\n", " \"\"\"\n", " return initial latents given an optional seed, optionally skipping a series of them\n", " \"\"\"\n", " # Setting the seed\n", " if seed is not 
None:\n", " if verbose:\n", " eprint(f\"using random seed {seed}\")\n", " torch.manual_seed(seed)\n", " if skip_random:\n", " if verbose:\n", " eprint(f\"skipping {skip_random} random\")\n", " # skip a series of random for latents (we want Nth image in a series generated from an initial seed)\n", " _ = torch.randn(\n", " (skip_random, pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " initial_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " if verbose:\n", " eprint(\"initial latent, {}\".format(initial_latents.sum()))\n", "\n", " return initial_latents\n", "\n", "\n", "def _enclat2img(\n", " encodings=[],\n", " initial_latents=[],\n", " multiple_latents=[],\n", " g=7.5,\n", " steps=10,\n", " neg_prompt=None,\n", " device=device,\n", " verbose=False,\n", "):\n", " images = []\n", "\n", " # adding an unconditional prompt helps in the generation process\n", " if neg_prompt is None:\n", " uncond = text_enc([\"\"] * 1, encodings.shape[1], device=device)\n", " elif type(neg_prompt) != str:\n", " eprint(f\"ERROR: neg_prompt must be a string, not '{type(neg_prompt)}'\")\n", " return\n", " else:\n", " if verbose:\n", " eprint(f\"using negative prompt '{neg_prompt}'\")\n", " uncond = text_enc([neg_prompt] * 1, encodings.shape[1], device=device)\n", "\n", " for _, encoding in enumerate(tqdm(encodings, desc=\"prompt\", leave=False)):\n", " emb = torch.cat([uncond, encoding.reshape_as(uncond)])\n", " if len(multiple_latents):\n", " for _, latents in enumerate(\n", " tqdm(multiple_latents, desc=\"latent\", leave=False)\n", " ):\n", " images += sample_space(\n", " torch.unsqueeze(latents, dim=0), emb, g, steps, device=device\n", " )\n", " else:\n", " images += sample_space(\n", " torch.unsqueeze(initial_latents, dim=0), emb, g, steps, device=device\n", " )\n", "\n", " return images\n", "\n", "\n", "def prompt2img(\n", " prompts=[],\n", " neg_prompt=None,\n", " n_samples=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Return a list of images equal to the number of prompts (optionally multiplied by n_samples)\n", " prompt: list of strings or a single string\n", " n_samples: how many samples to produce for each given prompt\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", "\n", " if type(prompts) == str:\n", " prompts = [prompts]\n", " if n_samples < 1:\n", " eprint(\"ERROR: n_samples must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " if n_samples > 1:\n", " sample_latents = [\n", " torch.unsqueeze(initial_latents, dim=0)\n", " ] # the first one generated from the seed\n", " for _ in range(n_samples - 1): # add as many more as requested\n", " sample_latent = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if verbose:\n", " eprint(\"adding latent, {}\".format(sample_latent.sum()))\n", " sample_latents.append(torch.unsqueeze(sample_latent, dim=0))\n", " multiple_latents = 
torch.cat(sample_latents)\n", "\n", " encodings = text_enc(prompts, device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def beyond_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " walk_steps=1,\n", " walk_stepsize=0.02,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Walk forward in prompt embedding space by walk_stepsize to produce walk_steps + 1 images,\n", " each a step (of optionally specified size) forward from the previous\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " walk_steps: number of steps to walk forward\n", " walk_stepsize: how much to walk forward in prompt embedding space at each step\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " encoding = text_enc([prompt], device=device)\n", "\n", " delta = torch.ones_like(encoding) * walk_stepsize\n", " new_encodings = []\n", " for step_index in range(0, walk_steps + 1):\n", " new_encodings.append(encoding)\n", " encoding = encoding + delta # nudge prompt embedding\n", " if verbose:\n", " print(\"nudged prompt by {}\".format(walk_stepsize * step_index))\n", " encodings = torch.cat(new_encodings)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def interpolate_prompts(\n", " prompts=[],\n", " neg_prompt=None,\n", " interpolate_steps=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Given two prompts, interpolate among the two embeddings and produce a number of images equal to interpolate_steps\n", " Return a list of images exploring the embedding space between first and second prompt\n", " prompt: list of strings or a single string\n", " interpolate_steps: number of images to produce between the one from the first and the one from the second prompt\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " Extension: interpolate four prompts and create square grid\n", " \"\"\"\n", "\n", " if type(prompts) != list or len(prompts) != 2:\n", " eprint(\"ERROR: you need to pass a list of 2 prompts!\")\n", " return\n", " if interpolate_steps < 1:\n", " eprint(\"ERROR: interpolate steps 
must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " encodings = text_enc(prompts, device=device)\n", " encodings = torch_md_linspace(encodings[0], encodings[1], interpolate_steps + 2)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def revolve_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " walk_type=\"circle\",\n", " walk_steps=1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " seed2=None,\n", " seed3=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Walk around a prompt in prompt in latent space to produce walk_steps images\n", " For a \"circle\" walk, it uses two latents which can be determined specifying seed and seed2\n", " In case of \"spiral\", the walk will be along a spherical spiral using three latents, optionally determined by seed, seed2 and seed3\n", " Note: the image that would singularly generated from seed would appear as the first one in the set\n", " and the one from seed2 would be found at around 1/4th of the total image count\n", " in case of spiral walk, the image from seed1 is the first one,\n", " the image from seed2 is found at around 1/4th of the total image count\n", " and the image from seed3 (approximate) would be found at around 1/3rd of the total image count\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " walk_type: \"circle\" (default) or \"spiral\"\n", " walk_steps: how many steps to take in total (equals number of returned images)\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " seed2: optional seed to determine circular walk\n", " seed3: optional seed to further determine spiral walk\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " # initialize alternate latents\n", " if seed2 is not None:\n", " torch.manual_seed(seed2)\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"spiral\" and seed3 is not None:\n", " torch.manual_seed(seed3)\n", " alt_variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"circle\" and seed3 is not None:\n", " eprint(f\"NOTICE: seed3 not used for 'circle' walk_type'\")\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " # if no alternate seed were specified, we'll use the next random after the one used for initial latents\n", " if seed2 is None:\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " if walk_type == \"spiral\" and seed3 is None:\n", " alt_variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", 
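"    # the circle/spiral walks below combine independent gaussian latents using\n",
"    # cos/sin (or spherical) coefficients; since the squared coefficients sum to 1,\n",
"    # each combined latent is again ~N(0, I), so every point of the walk stays in\n",
"    # the noise distribution the U-Net expects, e.g.:\n",
"    #   x, y = torch.randn(100000), torch.randn(100000)\n",
"    #   (0.6 * x + 0.8 * y).std()  # ~1.0, because 0.6**2 + 0.8**2 == 1\n",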
"\n", " encodings = text_enc([prompt], device=device)\n", "\n", " if walk_type == \"circle\": # around a prompt and two random latents in a circle\n", " ##circular roundwalk:\n", " # stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1] * pi\n", " # walk_scale_x = torch.cos(stepspace)\n", " # walk_scale_y = torch.sin(stepspace)\n", " # (accelerates and decelerates, not very smooth in interpolation)\n", "\n", " # spreadout circular walk:\n", " spread_factor = 30 # the lower this, the more the points will be pushed away from 0, 90, 180, 270 bearings and concentrated towards 45, 135..\n", " stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]\n", " stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor\n", " stepspace *= pi\n", " walk_scale_x = torch.cos(stepspace)\n", " walk_scale_y = torch.sin(stepspace)\n", "\n", " noise_x = torch.tensordot(walk_scale_x, initial_latents, dims=0)\n", " noise_y = torch.tensordot(walk_scale_y, variation_latents, dims=0)\n", " multiple_latents = torch.add(noise_x, noise_y)\n", " elif walk_type == \"spiral\": # spherical spiral walk with three random latents\n", " c = 2 # use 4 for double the amount of turns in the spherical spiral walk\n", "\n", " # circular spherical spiral walk:\n", " # theta = torch.linspace(1, 2, walk_steps + 1)[0:-1] * pi\n", " # theta2 = torch.linspace(1, 0, walk_steps + 1)[0:-1] * pi\n", " # walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)\n", " # walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)\n", " # walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])\n", " # walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)\n", " # walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)\n", " # walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])\n", " # walk_scale_z = torch.cos(torch.linspace(0, 2 * pi, walk_steps * 2 + 1)[0:-1])\n", "\n", " # spread spherical spiral walk:\n", " spread_factor = 30\n", " theta = torch.linspace(1, 2, walk_steps // 2 + 1)[0:-1]\n", " theta -= torch.sin((theta + 0.25) * pi * 4) / spread_factor\n", " theta *= pi\n", " theta2 = torch.linspace(1, 0, walk_steps // 2 + 1)[0:-1]\n", " theta2 -= torch.sin((theta2 + 0.25) * pi * 4) / spread_factor\n", " theta2 *= pi\n", " walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)\n", " walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)\n", " walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])\n", " walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)\n", " walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)\n", " walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])\n", " stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]\n", " stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor\n", " stepspace *= pi\n", " walk_scale_z = torch.cos(stepspace)\n", "\n", " noise_z = torch.tensordot(walk_scale_z, initial_latents, dims=0)\n", " noise_x = torch.tensordot(walk_scale_x, variation_latents, dims=0)\n", " noise_y = torch.tensordot(walk_scale_y, alt_variation_latents, dims=0)\n", "\n", " multiple_latents = torch.add(torch.add(noise_x, noise_y), noise_z)\n", " else:\n", " eprint(f\"ERROR: unknown walk_type '{walk_type}'\")\n", " return\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def prompt_variations(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " variations=1,\n", " 
variation_strength=0.1,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " Return an image followed by a list of variations, each one of specified variation_strength from the first\n", " prompt: list of strings or a single string\n", " variations: how many variations to return after the normal image\n", " variation_strength: how much to mix the initial latent and the variant ones (hence how different from initial image)\n", " neg_prompt: negative conditioning string\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", " if variations < 1:\n", " eprint(\"ERROR: number of variations must be positive!\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", " multiple_latents = []\n", "\n", " sample_latents = [\n", " torch.unsqueeze(initial_latents, dim=0)\n", " ] # the first one generated from the seed\n", " for _ in range(variations): # add as many as requested\n", " sample_latent = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", " sample_latent = slerp(variation_strength, initial_latents, sample_latent)\n", " if verbose:\n", " eprint(\"adding variation latent: {}\".format(sample_latent.sum()))\n", " sample_latents.append(torch.unsqueeze(sample_latent, dim=0))\n", " multiple_latents = torch.cat(sample_latents)\n", "\n", " encodings = text_enc([prompt], device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )\n", "\n", "\n", "def variate_prompt(\n", " prompt=\"\",\n", " neg_prompt=None,\n", " variation_strength=0,\n", " var_steps=0,\n", " g=10,\n", " steps=30,\n", " width=768,\n", " height=768,\n", " seed=None,\n", " seed2=None,\n", " skip_random=0,\n", " device=device,\n", " verbose=False,\n", "):\n", " \"\"\"\n", " create an image from the mix (in desired amount) of two random initial latents (optionally specified by seed and seed2)\n", " alternatively, if var_steps is specified, it will interpolate between the two random latents, returning var_steps images\n", " (i.e. 
like trying a linearly increasing set of values of variation_strength, from 0 to 1)\n", " prompt: string\n", " neg_prompt: negative conditioning string\n", " variation_strength: how much to mix the initial latent and the variant one\n", " var_steps: how many steps to go from an initial latent and a variation latent\n", " g: classifier free guidance\n", " steps: iteration steps for the diffusion process\n", " width, height: dimensions for the resulting image\n", " seed: initialize random generator; use None to get next available random\n", " seed2: optional seed to determine circular walk\n", " skip_random: how many random latents to discard before producing image\n", " \"\"\"\n", " if type(prompt) != str:\n", " eprint(f\"ERROR: prompt must be a string, not '{type(prompt)}'\")\n", " return\n", "\n", " if var_steps > 0 and variation_strength > 0:\n", " eprint(\"NOTICE: variation_strength will be ignored when var_steps specified\")\n", "\n", " if var_steps < 0 or variation_strength < 0:\n", " eprint(\"ERROR: do not use negative values\")\n", " return\n", "\n", " if var_steps <= 0 and variation_strength <= 0:\n", " eprint(\"ERROR: nothing to do. specify either var_steps or variation_strength\")\n", " return\n", "\n", " initial_latents = _initial_latents(\n", " seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose\n", " )\n", "\n", " # initialize variation latent\n", " if seed2 is not None:\n", " torch.manual_seed(seed2)\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " if var_steps > 0: # gradually interpolate between two random latents\n", " var_latents = []\n", " stepspace = torch.linspace(0, 1, var_steps) # include last point\n", " for stepvalue in stepspace:\n", " if verbose:\n", " eprint(\"generating variation at {}\".format(stepvalue)) # debug\n", " var_latents.append(\n", " torch.unsqueeze(\n", " slerp(stepvalue, initial_latents, variation_latents), dim=0\n", " )\n", " )\n", " multiple_latents = torch.cat(var_latents)\n", " else:\n", " multiple_latents = []\n", "\n", " # if no alternate seed were specified, we'll use the next random after the one used for initial latents\n", " if seed2 is None:\n", " variation_latents = torch.randn(\n", " (pipe.unet.config.in_channels, height // 8, width // 8)\n", " )\n", "\n", " initial_latents = slerp(variation_strength, initial_latents, variation_latents)\n", "\n", " encodings = text_enc([prompt], device=device)\n", "\n", " return _enclat2img(\n", " encodings=encodings,\n", " initial_latents=initial_latents,\n", " multiple_latents=multiple_latents,\n", " g=g,\n", " steps=steps,\n", " neg_prompt=neg_prompt,\n", " device=device,\n", " verbose=verbose,\n", " )" ] }, { "cell_type": "markdown", "id": "d461f9d1", "metadata": {}, "source": [ "# StartPipe\n", "## Choose Model and Scheduler\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "d6dd30f8", "metadata": {}, "outputs": [], "source": [ "# available:\n", "# DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DDPMScheduler\n", "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n", "scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder=\"scheduler\")" ] }, { "cell_type": "markdown", "id": "d5021128", "metadata": {}, "source": [ "## Prepare pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "fa1bfa2a", "metadata": {}, "outputs": [], "source": [ 
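"# torch_dtype=torch.float16 loads the UNet, VAE and text encoder in half precision,\n",
"# matching the .half() casts used by the helper functions above.\n",
"# Optional (not used in this notebook): on GPUs with limited VRAM,\n",
"# pipe.enable_attention_slicing() can be called after pipe.to(device)\n",
"# to reduce peak memory at a small speed cost.\n",
"\n",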
"pipe = StableDiffusionPipeline.from_pretrained(\n", " model_id, scheduler=scheduler, torch_dtype=torch.float16\n", ")\n", "pipe.safety_checker = None\n", "pipe.requires_safety_checker = False\n", "pipe = pipe.to(device)" ] }, { "cell_type": "markdown", "id": "7d09a4c8", "metadata": {}, "source": [ "# Examples" ] }, { "cell_type": "markdown", "id": "6a0ac6e5", "metadata": {}, "source": [ "# Simple\n", "## Simple txt2img\n", "[back to ToC](#Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "36af43bc", "metadata": {}, "outputs": [], "source": [ "prompt = \"hyper detailed digital painting of scenery, shibuya tokyo, post-apocalypse, ruins, rust, sky, skyscraper, abandoned, blue sky, broken window, building, cloud, crane machine, outdoors, overgrown, pillar, sunset\"\n", "seed = 24509723452345\n", "prompt2img(prompt, seed=seed)[0]" ] }, { "cell_type": "markdown", "id": "8e42d7f3", "metadata": {}, "source": [ "<img src=\"ruins.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "791b2da4", "metadata": {}, "outputs": [], "source": [ "seed = 64\n", "prompt = \"photo of a tiger demon\"\n", "images = prompt2img(prompt, width=512, height=768, g=7.5, steps=10, seed=seed)\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "c3d6e47d", "metadata": {}, "source": [ "<img src=\"tigerdemon.jpg\" />" ] }, { "cell_type": "markdown", "id": "6aa39f12", "metadata": {}, "source": [ "# With a negative prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "e2859125", "metadata": {}, "outputs": [], "source": [ "# Testing negative prompts\n", "images = [None, None]\n", "seed = 66\n", "prompt1 = \"photo of a castle in the middle of a forest with trees and bushes, detailed vegetation\"\n", "prompt2 = \"green, leaves, summer, spring\"\n", "images[0] = prompt2img(prompt1, neg_prompt=None, seed=seed)[0]\n", "images[1] = prompt2img(prompt1, neg_prompt=prompt2, seed=seed)[0]\n", "# side by side comparison\n", "display_images(images, prompt=prompt1, subtitles=[\"\", \"- \" + prompt2])" ] }, { "cell_type": "markdown", "id": "56a1ff8c", "metadata": {}, "source": [ "<img src=\"castle_negprompt.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "22658d5d", "metadata": {}, "outputs": [], "source": [ "# remove the prompt from itself for unexpected results:\n", "prompt = \"watercolour of a tiger\"\n", "prompt2img(prompt, neg_prompt=prompt, seed=1023458422345243)[0]" ] }, { "cell_type": "markdown", "id": "5aa7f2b8", "metadata": {}, "source": [ "<img src=\"tiger_no_tiger.jpg\" />" ] }, { "cell_type": "markdown", "id": "c8168e3d", "metadata": {}, "source": [ "## Multiple prompts can be specified" ] }, { "cell_type": "code", "execution_count": null, "id": "12f895b9", "metadata": {}, "outputs": [], "source": [ "images = prompt2img([\"cute puppy\", \"cute kitten\"], width=512, height=512, seed=123)\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "fc878e2a", "metadata": {}, "source": [ "<img src=\"puppy_n_kitten.jpg\" />" ] }, { "cell_type": "markdown", "id": "375c3dce", "metadata": {}, "source": [ "# Interpolation\n", "[back to ToC](#Contents)\n", "## Interpolation between text prompts\n", "(and what lies between two different points in prompt encoding space)" ] }, { "cell_type": "code", "execution_count": null, "id": "a2efb5c8", "metadata": {}, "outputs": [], "source": [ "prompt1 = \"a photo of a boy running on the beach\"\n", "prompt2 = \"a photo of a cadillac on a highway\"\n", "seed = 749109862\n", "images = interpolate_prompts(\n", 
" [prompt1, prompt2], width=512, height=512, seed=seed, interpolate_steps=5\n", ")\n", "display_images(images, prompt1 + \"<->\" + prompt2)" ] }, { "cell_type": "markdown", "id": "b6b2c111", "metadata": {}, "source": [ "<img src=\"boy_to_car.jpg\" />" ] }, { "cell_type": "markdown", "id": "3656ec4f", "metadata": {}, "source": [ "# GoingBeyond\n", "[back to ToC](#Contents)\n", "## Forward walk in embedding latent space from a prompt\n", "(what lies beyond a point in encoding space)" ] }, { "cell_type": "code", "execution_count": null, "id": "bc7ab142", "metadata": {}, "outputs": [], "source": [ "prompt = \"girl with green hair eating rice noodles\"\n", "images = beyond_prompt(\n", " prompt,\n", " neg_prompt=\"malformed\",\n", " width=512,\n", " height=512,\n", " seed=4093245,\n", " walk_steps=8,\n", " walk_stepsize=0.015,\n", ")\n", "display_images_grid(images, prompt)" ] }, { "cell_type": "markdown", "id": "1f4c40cc", "metadata": {}, "source": [ "<img src=\"green_haired_girl.jpg\" />" ] }, { "cell_type": "markdown", "id": "326ebe57", "metadata": {}, "source": [ "# CircleWalk\n", "[back to ToC](#Contents)\n", "## Circular walks around a prompt in noise latent space\n", "(optionally guided using two seeds)" ] }, { "cell_type": "code", "execution_count": null, "id": "fc4df5a7", "metadata": {}, "outputs": [], "source": [ "prompt = \"An oil masterpiece painting of horses in a field next to a farm in Normandy\"\n", "seed = 132432456352\n", "seed2 = 42\n", "walk_steps = 48\n", "images = revolve_prompt(prompt, walk_steps=walk_steps, seed=seed, seed2=seed2)\n", "display_images_grid(images, prompt)\n", "\n", "# save_images(images, path=f\"horses_r{walk_steps}\")\n", "export_as_gif(\n", " f\"horses_r{walk_steps}.gif\", images, frames_per_second=2, rubber_band=False\n", ")" ] }, { "cell_type": "markdown", "id": "b2b6057a", "metadata": {}, "source": [ "<img src=\"horses_r48.jpg\" />" ] }, { "cell_type": "markdown", "id": "b46267ea", "metadata": {}, "source": [ "<video width=\"768\" height=\"768\" src=\"horses_r48.mp4\" controls />" ] }, { "cell_type": "markdown", "id": "cb02cf58", "metadata": {}, "source": [ "# SpiralWalk\n", "[back to ToC](#Contents)\n", "## Spherical spiral walks around a prompt in noise latent space\n", "(optionally guided using three seeds)" ] }, { "cell_type": "code", "execution_count": null, "id": "17bcb583", "metadata": {}, "outputs": [], "source": [ "prompt = \"hires photo of shark, underwater scenery with tropical fishes and coral sea floor, caustics\"\n", "seed = 100\n", "seed = 132432456352\n", "seed2 = 42\n", "seed3 = 897234234\n", "walk_steps = 24\n", "images = revolve_prompt(\n", " prompt,\n", " walk_type=\"spiral\",\n", " walk_steps=walk_steps,\n", " width=512,\n", " height=512,\n", " seed=seed,\n", " seed2=seed2,\n", " seed3=seed3,\n", ")\n", "display_images_grid(images, prompt)" ] }, { "cell_type": "markdown", "id": "2a01890a", "metadata": {}, "source": [ "<img src=\"shark_coral_reef_s24.jpg\" />" ] }, { "cell_type": "markdown", "id": "bf90f11e", "metadata": {}, "source": [ "# Sampling\n", "[back to ToC](#Contents)\n", "## Multiple samples for the same prompt\n", "- sample several different latents for the same prompt\n", "- use skip_random to then directly go to one of the generated images (not working for EulerA)" ] }, { "cell_type": "code", "execution_count": null, "id": "7ac23a10", "metadata": {}, "outputs": [], "source": [ "seed = 10234584620131114\n", "prompt = \"a watercolour painting of Cambridge Jesus Green\"\n", "images = prompt2img(\n", " prompt, width=512, 
height=512, seed=seed, n_samples=9\n", ")\n", "display_images_grid(images)\n", "\n", "# recreate the 5th variation:\n", "images = prompt2img(\n", " prompt, width=512, height=512, seed=seed, skip_random=4\n", ")\n", "images[0]" ] }, { "cell_type": "markdown", "id": "398cd968", "metadata": {}, "source": [ "<img src=\"cambridge_watercolours.jpg\" />\n", "<img src=\"cambridge_watercolours_var5.jpg\" />" ] }, { "cell_type": "markdown", "id": "d9c4d939", "metadata": {}, "source": [ "# Variations\n", "[back to ToC](#Contents)\n", "## Multiple variations from the same initial latent" ] }, { "cell_type": "code", "execution_count": null, "id": "fc6b8cea", "metadata": {}, "outputs": [], "source": [ "prompt = \"An oil painting of horses in a field next to a farm in Normandy\"\n", "seed = 1022134\n", "images = prompt_variations(prompt, variations=8, variation_strength=0.1, seed=seed)\n", "display_images_grid(images)" ] }, { "cell_type": "markdown", "id": "1964b0ab", "metadata": {}, "source": [ "<img src=\"horses_variations.jpg\" />" ] }, { "cell_type": "markdown", "id": "693251f5", "metadata": {}, "source": [ "# VariationMix\n", "[back to ToC](#Contents)\n", "## Variation using second seed and specified strength" ] }, { "cell_type": "code", "execution_count": null, "id": "10aca36d", "metadata": {}, "outputs": [], "source": [ "seed = 1023458422345243\n", "seed2 = 35634563\n", "prompt = \"movie cover of Schwarzenegger as the Terminator riding a Vespa\"\n", "images = prompt2img(prompts=prompt, neg_prompt=None, width=512, height=512, seed=seed)\n", "images += variate_prompt(\n", " prompt=prompt,\n", " width=512,\n", " height=512,\n", " seed=seed,\n", " seed2=seed2,\n", " variation_strength=0.25,\n", ")\n", "display_images(images, prompt=prompt, subtitles=[\"\", f\"var{seed2} 25%\"])" ] }, { "cell_type": "markdown", "id": "1780b649", "metadata": {}, "source": [ "<img src=\"vespa_terminator.jpg\" />" ] }, { "cell_type": "markdown", "id": "cc2fdfaf", "metadata": {}, "source": [ "# VariationWalk\n", "[back to ToC](#Contents)\n", "## Gradual interpolation between two random latents" ] }, { "cell_type": "code", "execution_count": null, "id": "d33ccc67", "metadata": {}, "outputs": [], "source": [ "seed = 1023458422345243\n", "seed2 = 35634563\n", "prompt = \"movie cover of Schwarzenegger as the Terminator riding a Vespa\"\n", "images = variate_prompt(\n", " prompt=prompt, width=512, height=512, seed=seed, seed2=seed2, var_steps=6\n", ")\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "3b9ff52e", "metadata": {}, "source": [ "<img src=\"vespa_terminator_varwalk.jpg\" />" ] }, { "cell_type": "markdown", "id": "de43fd36", "metadata": {}, "source": [ "# KidsExperiments\n", "[back to ToC](#Contents)\n", "## And this is what happens when you let your kids write sd prompts 😉:" ] }, { "cell_type": "code", "execution_count": null, "id": "d8399147", "metadata": {}, "outputs": [], "source": [ "seed = 12345\n", "images = interpolate_prompts(\n", " [\"cow cat pawlephant\", \"muleskin beetledog\"],\n", " seed=seed,\n", " interpolate_steps=1,\n", " g=7.5,\n", " steps=10,\n", ")\n", "display_images(images)" ] }, { "cell_type": "markdown", "id": "5cd7a90d", "metadata": {}, "source": [ "<img src=\"cowcat_to_beetledog.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "2b800f52", "metadata": {}, "outputs": [], "source": [ "prompt2img(\"a cow covered in oreo cookies\", seed=3534534, g=7.5, steps=10)[0]" ] }, { "cell_type": "markdown", "id": "976c82ad", "metadata": {}, "source": [ "<img 
src=\"cow_covered_in_oreos.jpg\" />" ] }, { "cell_type": "code", "execution_count": null, "id": "9e502969", "metadata": {}, "outputs": [], "source": [ "seed = 3534534\n", "seed2 = 3534533\n", "walk_steps = 360\n", "images = variate_prompt(\n", " \"a cow covered in oreo cookies\", seed=seed, seed2=seed2, var_steps=walk_steps\n", ")\n", "save_images(images, path=\"cow_covered_in_oreos\")" ] }, { "cell_type": "markdown", "id": "4df16b49", "metadata": {}, "source": [ "## ...which when converted to mp4 becomes:" ] }, { "cell_type": "markdown", "id": "629fc5bc", "metadata": {}, "source": [ "<video width=\"768\" height=\"768\" src=\"cow_covered_in_oreos.mp4\" controls />" ] }, { "cell_type": "code", "execution_count": null, "id": "e4d3651a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "kernelspec": { "display_name": "tensorflow_env", "language": "python", "name": "tensorflow_env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }