In [None]:
import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

CUDA version: 10.1
Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [None]:
!pip install big-sleep --upgrade

Requirement already up-to-date: big-sleep in /usr/local/lib/python3.6/dist-packages (0.4.6)


In [None]:
from IPython.display import Image, display
import string
import torch
from torchvision.utils import save_image
from torchvision import transforms

import numpy as np

from big_sleep import Imagine
from big_sleep.clip import tokenize

from nltk.corpus import stopwords

from skimage.measure import compare_ssim

import cv2
from pathlib import Path
import ipywidgets

import PIL
from PIL import ImageFont, ImageDraw

# --- big-sleep Imagine settings ---
TEXT = 'story_hallucinator'   # initial prompt only; each training loop overwrites im_model.text
SAVE_EVERY = 1
SAVE_PROGRESS = True
LEARNING_RATE = 0.1
ITERATIONS = 1

# --- caption / animation settings ---
REPEATS = 5   # NOTE(review): not referenced in the visible cells
SPAN = 6      # number of words shown in one caption window

def train_step(self, epoch, i, rand=0):
    """Run one optimisation step on a big-sleep ``Imagine`` object and return its current image.

    Args:
        self: the ``Imagine`` instance (passed explicitly; this is a module-level function).
            Its ``model``, ``encoded_text``, ``optimizer`` and ``gradient_accumulate_every``
            attributes are used.
        epoch, i: not used by the body; accepted so callers can pass their loop counters.
        rand: scale of a random scalar added to each accumulated loss.
            NOTE(review): a constant offset has zero gradient, so this does not change
            the parameter update. It is kept for interface compatibility.

    Returns:
        An RGB ``PIL.Image`` built from the last element returned by
        ``self.model.model()`` (presumably the last image of a batch in [0, 1] — confirm).
    """
    for _ in range(self.gradient_accumulate_every):
        losses = self.model(self.encoded_text)
        loss = (sum(losses) / self.gradient_accumulate_every) + rand * np.random.randn()
        loss.backward()

    self.optimizer.step()
    self.optimizer.zero_grad()

    # Rendering the preview only needs a forward pass; don't build an autograd graph for it.
    with torch.no_grad():
        mres = self.model.model()
    return transforms.ToPILImage()(mres[-1].cpu()).convert("RGB")

# File-name-safe version of the prompt (spaces -> underscores).
filename = "_".join(TEXT.split(" "))

def burnin(words, steps=10):
    """Warm up the global ``im_model`` on the first ``SPAN`` words before animating.

    Args:
        words: list of word strings; only ``words[:SPAN]`` is used as the prompt.
        steps: number of ``train_step`` calls to run (default 10, matching the old fixed count).
    """
    # The prompt is the same for every step, so strip punctuation and encode it once.
    phrase = " ".join(words[:SPAN])
    im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))
    im_model.encoded_text = tokenize(im_model.text).cuda()
    for i in range(steps):
        train_step(im_model, 0, i)

def add_text_to_im(img, msg_orig,
                   font_path="/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf",
                   font_size=18):
    """Stamp ``msg_orig`` onto ``img`` in place as a white caption with a 1-px black outline.

    The caption is centred horizontally and placed near the bottom (its top edge sits at
    7/8 of the free vertical space). If it is wider than the image and has more than one
    word, it is wrapped onto two lines at the middle word.

    Args:
        img: PIL image to draw on; modified in place.
        msg_orig: caption text.
        font_path: TrueType font file (defaults to the Liberation Mono Bold path used before).
        font_size: size passed to ``ImageFont.truetype`` (default 18).
    """
    W, H = img.size
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype(font_path, font_size)

    def text_size(text):
        # ImageDraw.textsize was removed in Pillow 10; fall back to textbbox there.
        if hasattr(draw, "textsize"):
            return draw.textsize(text, font=font)
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top

    msgs = [msg_orig]
    words = msg_orig.split()
    if text_size(msg_orig)[0] > W and len(words) > 1:
        # Too wide for one line: wrap the caption's own words at the middle.
        half = len(words) // 2
        msgs = [" ".join(words[:half]), " ".join(words[half:])]

    # The 8 one-pixel neighbours; drawing the text there in black forms the outline.
    outline_offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]
    for line_no, msg in enumerate(msgs):
        w, h = text_size(msg)
        x, y = (W - w) / 2, 7 * (H - h) / 8 + line_no * h
        for dx, dy in outline_offsets:
            draw.text((x + dx, y + dy), msg, fill="black", font=font)
        draw.text((x, y), msg, fill="white", font=font)

def vizualize_words(words, segment_num, iterations=15, frame_ms=50):
    """Animate ``words`` as a sliding caption and save it as ``aidungeon{segment_num}.gif``.

    The text is walked in windows of ``SPAN`` words. Within a window, frame ``i`` starts
    the caption ``int(i * SPAN / iterations)`` words further on. Each frame:
    - runs one ``train_step`` on the global ``im_model``, using the caption with
      punctuation stripped as the prompt;
    - stamps the caption onto the returned image.

    Args:
        words: list of word strings (e.g. ``text.split()``); must be non-empty.
        segment_num: number embedded in the output file name.
        iterations: frames rendered per ``SPAN``-word window (default 15).
        frame_ms: GIF frame duration in milliseconds (default 50).

    Raises:
        ValueError: if ``words`` is empty (there would be no frame to save).
    """
    if not words:
        raise ValueError("vizualize_words() needs at least one word")

    strip_punctuation = str.maketrans('', '', string.punctuation)
    frames = []
    for epoch in range(0, len(words), SPAN):
        for i in range(iterations):
            start = epoch + int(i * SPAN / iterations)
            window = words[start:start + SPAN]
            if not window:
                # Slid past the last word; later frames of this window would be empty too.
                break
            phrase = " ".join(window)
            im_model.text = phrase.translate(strip_punctuation)
            im_model.encoded_text = tokenize(im_model.text).cuda()
            frame = train_step(im_model, epoch, i)
            add_text_to_im(frame, phrase)
            frames.append(frame)

    frames[0].save(fp=f'aidungeon{segment_num}.gif', format='GIF',
                   append_images=frames[1:], save_all=True, duration=frame_ms, loop=0)
        
# Notebook-wide state:
#   iter_num -> number used for the next animation file, aidungeon{iter_num}.gif
#   result   -> current story text; starts as the opening line used for burn-in
iter_num = 0
result = "This is the start of the AI Dungeon game"

In [None]:
#### 
##   REFRESH IMAGE MODEL
####

# Build a fresh big-sleep generator. TEXT is only the initial prompt: the burn-in
# below overwrites im_model.text / encoded_text before any training step.
im_model = Imagine(
    text = TEXT,
    save_every = SAVE_EVERY,
    lr = LEARNING_RATE,
    iterations = ITERATIONS,
    save_progress = SAVE_PROGRESS
)

# Burn in on the current `result` text (the opening line on a fresh run).
# NOTE(review): the story cell reassigns `result`, so re-running this cell after it
# burns in on the latest story segment instead; confirm that is intended.
BURNIN_STEPS = 20
phrase = " ".join(result.split())  # collapse newlines/runs of spaces to single spaces
# The prompt is identical for every step, so strip punctuation and encode it once.
im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))
im_model.encoded_text = tokenize(im_model.text).cuda()
for i in range(BURNIN_STEPS):
    train_step(im_model, 0, i)


In [None]:
# Animate this story segment into aidungeon{iter_num}.gif, then advance the counter.
# Re-running the cell therefore writes a new file rather than replacing the last one.
result = """
You are Marley, a mutant trying to survive after the deadly plague. You have
 scales on your chest and feather on your left arm.

In the town you were born in, your strange condition was considered a sin, and
 you have been banished since you were eleven. After a long journey, you find
 a ravaged cottage. You look inside and see that it is not inhabited anymore.
 The place has become a graveyard. A few of the bodies are still moving, but
 they are all dead now
"""
vizualize_words(words=result.split(), segment_num=iter_num)
iter_num = iter_num + 1

In [None]:
# Show the GIF written by the most recent story segment. The story cell increments
# iter_num after saving, so that file is aidungeon{iter_num - 1}.gif.
if iter_num == 0:
    raise RuntimeError("No story segment has been animated yet; run the story cell first.")
latest_gif = f'aidungeon{iter_num - 1}.gif'
progress = ipywidgets.Image(
    value=Path(latest_gif).read_bytes(),
    format='gif',
    width=512,
    height=512)
display(progress)

Image(value=b'GIF89a\x00\x02\x00\x02\x87\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x…

In [None]:
# List the animation files generated so far (vizualize_words writes one per call).
sorted(p.name for p in Path('.').glob('aidungeon*.gif'))

aidungeon0.gif	aidungeon1.gif	aidungeon2.gif	sample_data
