{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "Story2Hallucination_GIF.ipynb", "provenance": [], "collapsed_sections": [], "machine_shape": "hm" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "cells": [ { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qzftMs2bJkqS", "outputId": "8e68a5e2-ac52-445d-940c-3a0a07be0ffd" }, "source": [ "import subprocess\n", "\n", "CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n", "print(\"CUDA version:\", CUDA_version)\n", "\n", "if CUDA_version == \"10.0\":\n", " torch_version_suffix = \"+cu100\"\n", "elif CUDA_version == \"10.1\":\n", " torch_version_suffix = \"+cu101\"\n", "elif CUDA_version == \"10.2\":\n", " torch_version_suffix = \"\"\n", "else:\n", " torch_version_suffix = \"+cu110\"\n", "\n", "! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "CUDA version: 10.1\n", "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n", "Requirement already satisfied: torch==1.7.1+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.1+cu101)\n", "Requirement already satisfied: torchvision==0.8.2+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.2+cu101)\n", "Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8)\n", "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (1.19.5)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (3.7.4.3)\n", "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (0.8)\n", "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.2+cu101) (7.0.0)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7aYst9jhJo_2", "outputId": "902a2419-3e13-4b6b-d635-616f623b0d86" }, "source": [ "!pip install big-sleep --upgrade" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Requirement already up-to-date: big-sleep in /usr/local/lib/python3.6/dist-packages (0.4.6)\n", "Requirement already satisfied, skipping upgrade: pytorch-pretrained-biggan in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.1.1)\n", "Requirement already satisfied, skipping upgrade: torch>=1.7.1 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (1.7.1+cu101)\n", "Requirement already satisfied, skipping upgrade: einops>=0.3 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.3.0)\n", "Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from big-sleep) (4.41.1)\n", "Requirement already satisfied, skipping upgrade: torchvision>=0.8.2 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.8.2+cu101)\n", "Requirement already satisfied, skipping upgrade: fire in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.4.0)\n", "Requirement already satisfied, skipping upgrade: ftfy in /usr/local/lib/python3.6/dist-packages (from big-sleep) (5.8)\n", "Requirement already satisfied, skipping upgrade: regex in /usr/local/lib/python3.6/dist-packages (from big-sleep) (2019.12.20)\n", "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (2.23.0)\n", "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (1.19.5)\n", "Requirement already satisfied, skipping upgrade: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (1.16.63)\n", "Requirement already satisfied, skipping upgrade: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from torch>=1.7.1->big-sleep) (0.8)\n", "Requirement already satisfied, skipping upgrade: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.7.1->big-sleep) (3.7.4.3)\n", "Requirement already satisfied, skipping upgrade: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision>=0.8.2->big-sleep) (7.0.0)\n", "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from fire->big-sleep) (1.15.0)\n", "Requirement already satisfied, skipping upgrade: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->big-sleep) (1.1.0)\n", "Requirement already satisfied, skipping upgrade: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy->big-sleep) (0.2.5)\n", "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (2.10)\n", "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (1.24.3)\n", "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (3.0.4)\n", "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (2020.12.5)\n", "Requirement already satisfied, skipping upgrade: botocore<1.20.0,>=1.19.63 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (1.19.63)\n", "Requirement already satisfied, skipping upgrade: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (0.10.0)\n", "Requirement already satisfied, skipping upgrade: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (0.3.4)\n", "Requirement already satisfied, skipping upgrade: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.20.0,>=1.19.63->boto3->pytorch-pretrained-biggan->big-sleep) (2.8.1)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "GPhyqxuyL0QC" }, "source": [ "from IPython.display import Image, display\n", "import string\n", "import torch\n", "from torchvision.utils import save_image\n", "from torchvision import transforms\n", "\n", "import numpy as np\n", "\n", "from big_sleep import Imagine\n", "from big_sleep.clip import tokenize\n", "\n", "from nltk.corpus import stopwords\n", "\n", "from skimage.measure import compare_ssim\n", "\n", "import cv2\n", "from pathlib import Path\n", "import ipywidgets\n", "\n", "import PIL\n", "from PIL import ImageFont, ImageDraw\n", "\n", "TEXT = 'story_hallucinator' \n", "SAVE_EVERY = 1\n", "SAVE_PROGRESS = True\n", "LEARNING_RATE = 0.1\n", "ITERATIONS = 1\n", "\n", "REPEATS = 5\n", "SPAN = 6\n", "\n", "def train_step(self, epoch, i, rand=0):\n", " total_loss = 0\n", "\n", " for _ in range(self.gradient_accumulate_every):\n", " losses = self.model(self.encoded_text) \n", " loss = (sum(losses) / self.gradient_accumulate_every) + rand*np.random.randn()\n", " total_loss += loss\n", " loss.backward()\n", "\n", " self.optimizer.step()\n", " self.optimizer.zero_grad()\n", "\n", " mres = self.model.model()\n", " return transforms.ToPILImage()(mres[len(mres)-1].cpu()).convert(\"RGB\")\n", "\n", "filename = TEXT.replace(' ', '_')\n", "\n", "def burnin(words):\n", " # burn in first image\n", " for i in range(10):\n", " phrase = \" \".join(words[:SPAN])\n", " im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))\n", " im_model.encoded_text = tokenize(im_model.text).cuda()\n", " train_step(im_model, 0, i)\n", "\n", "def add_text_to_im(img,msg_orig):\n", " W, H = img.size\n", " draw = ImageDraw.Draw(img)\n", " font = ImageFont.truetype(\"/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf\", 18)\n", " msgs = [msg_orig]\n", " w, h = draw.textsize(msg_orig, font=font)\n", " if w>W:\n", " split = span // 2\n", " msgs = [\" \".join(words[epoch:epoch+split]), \" \".join(words[epoch+split:epoch+span])]\n", " for shift, msg in enumerate(msgs): \n", " w, h = draw.textsize(msg, font=font)\n", " x, y = (W-w)/2, 7*(H-h)/8 + shift*h\n", " adj = 1\n", " #move right\n", " shadowColor = \"black\"\n", " draw.text((x-adj, y), msg, fill=shadowColor, font=font)\n", " #move left\n", " draw.text((x+adj, y), msg, fill=shadowColor, font=font)\n", " #move up\n", " draw.text((x, y+adj), msg, fill=shadowColor, font=font)\n", " #move down\n", " draw.text((x, y-adj), msg, fill=shadowColor, font=font)\n", " #diagnal left up\n", " draw.text((x-adj, y+adj), msg, fill=shadowColor, font=font)\n", " #diagnal right up\n", " draw.text((x+adj, y+adj), msg, fill=shadowColor, font=font)\n", " #diagnal left down\n", " draw.text((x-adj, y-adj), msg, fill=shadowColor, font=font)\n", " #diagnal right down\n", " draw.text((x+adj, y-adj), msg, fill=shadowColor, font=font)\n", " draw.text((x, y), msg, fill=\"white\", font=font)\n", "\n", "def vizualize_words(words,segment_num):\n", " # viz\n", " image_list = []\n", " L = len(words)\n", " step_size = SPAN\n", " iterations = 15\n", " for epoch in range(0,L,step_size):\n", " for i in range(iterations):\n", " phrase = \" \".join(words[epoch+int(i*step_size/iterations):epoch+SPAN+int(i*step_size/iterations)])\n", " im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))\n", " im_model.encoded_text = tokenize(im_model.text).cuda()\n", " image_cur = train_step(im_model, epoch, i)\n", " add_text_to_im(image_cur,phrase)\n", " image_list.append(image_cur)\n", " image_list[0].save(fp=f'aidungeon{segment_num}.gif', format='GIF', append_images=image_list[1:], save_all=True, duration=50, loop=0)\n", " \n", "iter_num=0\n", "result = \"This is the start of the AI Dungeon game\"" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "da5dU-JJI2iY" }, "source": [ "#### \n", "## REFRESH IMAGE MODEL\n", "####\n", "\n", "im_model = Imagine(\n", " text = TEXT,\n", " save_every = SAVE_EVERY,\n", " lr = LEARNING_RATE,\n", " iterations = ITERATIONS,\n", " save_progress = SAVE_PROGRESS\n", ")\n", "\n", "# burn in first image\n", "for i in range(20):\n", " phrase = \" \".join(result.split())\n", " im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))\n", " im_model.encoded_text = tokenize(im_model.text).cuda()\n", " train_step(im_model, 0, i)\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ToCV9g0qxiAW" }, "source": [ "result = \"\"\"\n", "You are Marley, a mutant trying to survive after the deadly plague. You have\n", " scales on your chest and feather on your left arm.\n", "\n", "In the town you were born in, your strange condition was considered a sin, and\n", " you have been banished since you were eleven. After a long journey, you find\n", " a ravaged cottage. You look inside and see that it is not inhabited anymore.\n", " The place has become a graveyard. A few of the bodies are still moving, but\n", " they are all dead now\n", "\"\"\"\n", "vizualize_words(result.split(),iter_num)\n", "iter_num += 1" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "background_save": true, "referenced_widgets": [ "a1473c04f0dd4e449477b8eec1e470f3" ] }, "id": "VjHcJ5vklGGW", "outputId": "c20a5d77-9cb3-4033-8b96-16051fa475f3" }, "source": [ "with open(f'aidungeon1.gif', 'rb') as f_temp:\n", " progress= ipywidgets.Image(\n", " value=f_temp.read(),\n", " format='gif',\n", " width=512,\n", " height=512)\n", " display(progress)" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a1473c04f0dd4e449477b8eec1e470f3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Image(value=b'GIF89a\\x00\\x02\\x00\\x02\\x87\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x…" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zapl09h1nnp9", "outputId": "2796da06-1c29-464e-ba73-a137d2a15919" }, "source": [ "!ls" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "aidungeon0.gif\taidungeon1.gif\taidungeon2.gif\tsample_data\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "N79_1hzyn464" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }