{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Infinite patterns by Alexander Mordvintsev", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "OaBXW2_cWj4y", "colab_type": "text" }, "source": [ "# Infinite patterns by Alexander Mordvintsev\n", "\n", "\n", "with   \"drawing\"\n", "

 

\n", "\n", "![alt text](https://storage.googleapis.com/cilex-common/temp/infinite-patterns-notebook-imgs/deepdream.png)\n", "\n", "Google’s **Alexander Mordvintsev** is the creator of DeepDream, a computer vision program that uses a neural network to find and create patterns in images. The results are often dream-like, hallucinogenic, hyper-processed images.\n", "\n", "## Available to everyone\n", "\n", "The idea for Infinite Patterns came from a trip to the [Google Arts & Culture Lab](https://experiments.withgoogle.com/collection/arts-culture) in Paris, where [Pinar&Viola](https://www.pinar-viola.com/) saw images created by [DeepDream](https://en.wikipedia.org/wiki/DeepDream) firsthand and noticed similarities between the neural network’s creations and their own body of work. They saw an opportunity for collaboration: “It felt as if all we have ever done in our career was done to bring us here,” said the artists.\n", "\n", "In December 2018, Pinar&Viola joined the lab as artists-in-residence, where they worked with Alex to create new work. **Alex created a tool for the artists**, allowing them to create infinite patterns with DeepDream. First, the artists curated a selection of pictures, which were fed into the neural network tool to extract inspiring patterns. Next, they combined these ML patterns with the original images to create the final design.\n", "\n", "**Today, the tool made by Alexander is available for you to use in the creation of your own artworks!**\n", "\n", "This tool can create a pattern from any picture. More information about how the algorithm works is available [here](https://distill.pub/2018/differentiable-parameterizations/). " ] }, { "cell_type": "markdown", "metadata": { "id": "BwMfcuH3aWS-", "colab_type": "text" }, "source": [ "# Instructions ([video](https://www.youtube.com/watch?v=pFXaE7MLiTU))" ] }, { "cell_type": "markdown", "metadata": { "id": "Rw_stzrEWhTR", "colab_type": "text" }, "source": [ "1. 
Click **CONNECT** in the top right corner\n", "\n", "![Connect](https://storage.googleapis.com/cilex-common/temp/infinite-patterns-notebook-imgs/inst2.png)\n", "\n", "2. Run the cells by clicking the play button in the top left corner of each cell\n", "\n", "![Run cell](https://storage.googleapis.com/cilex-common/temp/infinite-patterns-notebook-imgs/run-cell-2.png)\n", "\n", "3. Click the \"download\" link in the upper-left corner of the resulting image to save it." ] }, { "cell_type": "markdown", "metadata": { "id": "gBDgjeQvaaL8", "colab_type": "text" }, "source": [ "# Upload an image" ] }, { "cell_type": "code", "metadata": { "id": "xy0Fs7UZ21ak", "colab_type": "code", "cellView": "form", "outputId": "b09e7e58-c135-4ec7-bb4f-d5842060e473", "colab": { "base_uri": "https://localhost:8080/", "height": 221 } }, "source": [ "#@title ← Click this button to upload a different image\n", "\n", "try:\n", "  import lucid\n", "except ImportError:\n", "  !pip install -q lucid\n", "\n", "from __future__ import print_function\n", "import os\n", "import io\n", "import string\n", "import numpy as np\n", "import PIL.Image\n", "import base64\n", "from glob import glob\n", "import itertools\n", "\n", "import cv2\n", "\n", "import matplotlib.pylab as pl\n", "\n", "import tensorflow as tf\n", "from tensorflow.contrib import slim\n", "\n", "import IPython\n", "from IPython.display import clear_output, Image, display, HTML\n", "\n", "from google.colab import files\n", "from google.colab import output\n", "from google.colab import drive\n", "\n", "from lucid.modelzoo import vision_models\n", "import lucid.misc.io.showing as show\n", "from lucid.optvis import objectives\n", "from lucid.optvis import render\n", "from lucid.misc.tfutil import create_session\n", "from lucid.optvis import style\n", "from lucid.modelzoo.util import forget_xy\n", "\n", "\n", "def imwrite(fn, img):\n", "  if len(img.shape) == 4:\n", "    img = img[0]\n", "  img = np.uint8(img.clip(0, 1)*255)\n", "  im = PIL.Image.fromarray(img)\n", "  
im.save(fn, quality=95)\n", "\n", "def show_tiled_image(img):\n", " url = show._image_url(img, fmt='jpeg')\n", " h, w = img.shape[:2]\n", "\n", " display(HTML('''\n", " \n", "
download
\n", " '''.format(url=url, h=h, w=w)))\n", "\n", "\n", "def anorm(a, axis=None, keepdims=False):\n", "  return (a*a).sum(axis=axis, keepdims=keepdims)**0.5\n", "\n", "\n", "def composite_activation(x):\n", "  x = tf.atan(x)\n", "  # Coefficients computed by:\n", "  #   def rms(x):\n", "  #     return np.sqrt((x*x).mean())\n", "  #   a = np.arctan(np.random.normal(0.0, 1.0, 10**6))\n", "  #   print(rms(a), rms(a*a))\n", "  return tf.concat([x/0.67, (x*x)/0.6], -1)\n", "\n", "\n", "def composite_activation_unbiased(x):\n", "  x = tf.atan(x)\n", "  # Coefficients computed by:\n", "  #   a = np.arctan(np.random.normal(0.0, 1.0, 10**6))\n", "  #   aa = a*a\n", "  #   print(a.std(), aa.mean(), aa.std())\n", "  return tf.concat([x/0.67, (x*x-0.45)/0.396], -1)\n", "\n", "\n", "def relu_normalized(x):\n", "  x = tf.nn.relu(x)\n", "  # Coefficients computed by:\n", "  #   a = np.random.normal(0.0, 1.0, 10**6)\n", "  #   a = np.maximum(a, 0.0)\n", "  #   print(a.mean(), a.std())\n", "  return (x-0.40)/0.58\n", "\n", "\n", "def image_cppn(\n", "    size,\n", "    offset=0.0,\n", "    num_output_channels=3,\n", "    num_hidden_channels=24,\n", "    num_layers=8,\n", "    activation_fn=composite_activation,\n", "    normalize=False):\n", "  coord_range = tf.to_float(tf.range(size))/tf.to_float(size)*2.0*np.pi\n", "  #coord_range = tf.linspace(-np.pi, np.pi, size)\n", "  y, x = tf.meshgrid(coord_range, coord_range, indexing='ij')\n", "  net = tf.expand_dims(tf.stack([x, y], -1), 0) # add batch dimension\n", "  net += offset\n", "  net = tf.concat([tf.sin(net), tf.cos(net)], -1)\n", " \n", "\n", "  with slim.arg_scope([slim.conv2d], kernel_size=1, activation_fn=None):\n", "    for i in range(num_layers):\n", "      in_n = int(net.shape[-1])\n", "      net = slim.conv2d(\n", "          net, num_hidden_channels,\n", "          # this is untruncated version of tf.variance_scaling_initializer\n", "          weights_initializer=tf.random_normal_initializer(0.0, np.sqrt(1.0/in_n)),\n", "      )\n", "      if normalize:\n", "        net = slim.instance_norm(net)\n", "      net = activation_fn(net)\n", " \n", "    rgb = 
slim.conv2d(net, num_output_channels, activation_fn=tf.nn.sigmoid,\n", "                      weights_initializer=tf.zeros_initializer())\n", "  return rgb\n", "\n", "\n", "\n", "def render_graph(fn, size=224, stripe_width=256):\n", "  graph_def = tf.GraphDef.FromString(open(fn, 'rb').read())\n", "  g = tf.Graph()\n", "  with g.as_default():\n", "    tf.import_graph_def(graph_def, name='')\n", "  with tf.Session(graph=g) as sess:\n", "    ty, tx = 'meshgrid/mul:0', 'meshgrid/mul_1:0'\n", "    y, x = sess.run([ty, tx], {'size:0': size})\n", "    stripes = []\n", "    for s in range(0, len(x), stripe_width):\n", "      stripe = sess.run('image:0', \n", "                        {tx: x[s:s+stripe_width], ty: y[s:s+stripe_width]})\n", "      stripes.append(stripe[0])\n", "  return np.vstack(stripes)\n", "\n", "\n", "model = vision_models.InceptionV1_caffe()\n", "model.load_graphdef()\n", "\n", "print('\\n↓ Now click the \"Choose Files\" button and select an image\\n')\n", "\n", "from google.colab import files\n", "uploaded = files.upload()\n", "image_name, _ = uploaded.popitem()\n", "\n", "clear_output()\n", "\n", "im = PIL.Image.open(image_name)\n", "g_image = np.float32(im)[...,:3]/255.0\n", "show.image(g_image)\n", "\n", "print('\\rimage uploaded, size:', im.size, end='')" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\rimage uploaded, size: (275, 183)" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "bS_4kf1Mhfuh", "colab_type": "text" }, "source": [ "# Render patterns" ] }, { "cell_type": "code", "metadata": { "id": "jbPlSLt_FWwG", "colab_type": "code", "cellView": "form", "outputId": "0067d7f3-a402-4240-81f0-575f187d8ba5", "colab": { "base_uri": "https://localhost:8080/", "height": 1041 } }, "source": [ "#@title ← Click this button to render\n", "#@markdown # CPPN pattern tool\n", "#@markdown This tool uses simple neural networks that map pixel coordinates to colors 
to represent the image being optimized. This approach is also known as a [compositional pattern-producing network](https://en.wikipedia.org/wiki/Compositional_pattern-producing_network) (CPPN).\n", "#@markdown Optimization tries to generate an image that produces an activation pattern similar to that of the target image in a particular layer (defined by **layer_index**) of an ImageNet-trained classification network. The **v1** objective tries to match the average pattern, while **v2** tries to match the whole distribution.\n", "#@markdown The **style_weight** parameter defines the contribution of the [Gatys et al.](https://arxiv.org/abs/1508.06576) style loss and is ignored for the **v2** objective.\n", "#@markdown The **activation** function used in the image-generating CPPN influences the resulting image style.\n", "\n", "objective = 'v2' #@param ['v1', 'v2']\n", "activation = 'composite' #@param ['composite', 'relu']\n", "style_weight = 1.34 #@param {type: \"slider\", min: 0.0, max: 2.0, step:0.01}\n", "layer_index = 7 #@param {type: \"slider\", min: 2, max: 8}\n", "\n", "\n", "sess = create_session()\n", "\n", "t_size = tf.placeholder_with_default(224, [], name='size')\n", "t_offset = tf.placeholder_with_default([0.0, 0.0], [2], name='offset')\n", "if activation == 'relu':\n", "  t_image = image_cppn(t_size, normalize=True, activation_fn=tf.nn.relu)\n", "elif activation == 'composite':\n", "  t_image = image_cppn(t_size, t_offset, normalize=False, activation_fn=composite_activation_unbiased)\n", "t_image = forget_xy(t_image)\n", "t_image = tf.identity(t_image, 'image')\n", "\n", "model.import_graph(t_image)\n", "\n", "tensor_name = 'import/%s:0'%model.layers[layer_index].name\n", "tensor = sess.graph.get_tensor_by_name(tensor_name)\n", "act = sess.run(tensor_name, {t_image:g_image[None, :,:,:3]})\n", "\n", "if objective=='v2':\n", "  act = act.reshape(-1, act.shape[-1])\n", "  act /= anorm(act, -1, True)\n", "\n", "  flat_tensor = tf.reshape(tensor, [-1, act.shape[-1]])\n", "  flat_tensor /= 
tf.norm(flat_tensor, axis=-1, keepdims=True)\n", "\n", "  cross = tf.matmul(flat_tensor, act, transpose_b=True)\n", "  t_loss0 = -tf.reduce_mean(tf.reduce_max(cross, 0))\n", "  t_loss1 = -tf.reduce_mean(tf.reduce_max(cross, 1))\n", "  t_loss = t_loss0 + t_loss1 - tf.reduce_mean(tensor)*0.05\n", "else:\n", "  target_act = act.mean((0, 1, 2))\n", "  t_dd_loss = -tf.reduce_mean(tensor*target_act)\n", "  sl = style.StyleLoss([tensor], loss_func=style.mean_l1_loss)\n", "  t_loss = t_dd_loss + sl.style_loss*style_weight\n", "\n", "t_lr = tf.constant(0.003)\n", "trainer = tf.train.AdamOptimizer(t_lr)\n", "train_op = trainer.minimize(t_loss)\n", "\n", "init_op = tf.global_variables_initializer()\n", "init_op.run()\n", "if objective=='v1':\n", "  sl.set_style({t_image:g_image[None,...]})\n", "\n", "init_op.run()\n", "try:\n", "  for i in range(500+1):\n", "    dx, dy = np.random.rand(2)*np.pi*2\n", "    _, loss = sess.run([train_op, t_loss], {t_offset:[dx, dy]})\n", "    if i%50 == 0:\n", "      clear_output()\n", "      show.image(sess.run(t_image), format='jpeg')\n", "      print(i, loss)\n", "except KeyboardInterrupt:\n", "  pass\n", "\n", "clear_output()\n", "\n", "img = sess.run(t_image, {t_size:1024})[0]\n", "show_tiled_image(img)\n" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", "
download
\n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "AZ9jgnMCRBB-", "colab_type": "code", "cellView": "form", "outputId": "93f9a36a-090d-46b2-b3ad-d70d800b498f", "colab": { "base_uri": "https://localhost:8080/", "height": 989 } }, "source": [ "#@title ← Click this button to render\n", "#@markdown # DeepDream pattern tool\n", "#@markdown This tool uses a pixel-based image representation and a [DeepDream-inspired](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/deepdream) multiscale pattern generation approach.\n", "\n", "preview_octave_n = 6\n", "preview_octave_steps = 50\n", "\n", "layer_index = 8 #@param {type: \"slider\", min: 3, max: 8}\n", "selectivity = 2 #@param {type: \"slider\", min: 0.2, max: 4.0, step: 0.1}\n", "colorful = 0.08 #@param {type: \"slider\", min: 0.0, max: 0.2, step: 0.01}\n", "\n", "\n", "k = np.float32([1,4,6,4,1])\n", "k = np.outer(k, k)\n", "k5x5 = k[:,:,None,None]/k.sum()*np.eye(3, dtype=np.float32)\n", "\n", "\n", "def lap_split(img):\n", "  '''Split the image into lo and hi frequency components'''\n", "  with tf.name_scope('split'):\n", "    lo = tf.nn.conv2d(img, k5x5, [1,2,2,1], 'SAME')\n", "    lo2 = tf.nn.conv2d_transpose(lo, k5x5*4, tf.shape(img), [1,2,2,1])\n", "    hi = img-lo2\n", "  return lo, hi\n", "\n", "def lap_split_n(img, n):\n", "  '''Build Laplacian pyramid with n splits'''\n", "  levels = []\n", "  for i in range(n):\n", "    img, hi = lap_split(img)\n", "    levels.append(hi)\n", "  levels.append(img)\n", "  return levels[::-1]\n", "\n", "def lap_merge(levels):\n", "  '''Merge Laplacian pyramid'''\n", "  img = levels[0]\n", "  for hi in levels[1:]:\n", "    with tf.name_scope('merge'):\n", "      img = tf.nn.conv2d_transpose(img, k5x5*4, tf.shape(hi), [1,2,2,1]) + hi\n", "  return img\n", "\n", "def normalize_std(img, eps=1e-10):\n", "  '''Normalize image by making its standard deviation = 1.0'''\n", "  with tf.name_scope('normalize'):\n", "    std = 
tf.sqrt(tf.reduce_mean(tf.square(img)))\n", "    return img/tf.maximum(std, eps)\n", "\n", "def lap_normalize(img, scale_n=4):\n", "  '''Perform the Laplacian pyramid normalization.'''\n", "  img = tf.expand_dims(img,0)\n", "  tlevels = lap_split_n(img, scale_n)\n", "  tlevels = list(map(normalize_std, tlevels))\n", "  out = lap_merge(tlevels)\n", "  return out[0,:,:,:]\n", "\n", "def split1d(n, approx_tile_size):\n", "  tile_n = (n + approx_tile_size//2) // approx_tile_size\n", "  tile_n = max(tile_n, 1)\n", "  tile_size = n // tile_n\n", "  splits = np.arange(tile_n+1)*tile_size\n", "  splits[-1] = n\n", "  return np.c_[splits[:-1], splits[1:]]\n", "  #return list(map(slice, splits[:-1], splits[1:]))\n", "\n", "def split2d(shape, approx_tile_size):\n", "  h, w = shape\n", "  splity = split1d(h, approx_tile_size)\n", "  splitx = split1d(w, approx_tile_size)\n", "  splits = itertools.product(splity, splitx)\n", "  tiles = list(splits)\n", "  return tiles\n", "\n", "\n", "def show_crop(dream, sz=512):\n", "  y, x = np.int32(dream.rgb.shape[:2]) // 2 - sz//2\n", "  rgb = dream.rgb[y:y+sz, x:x+sz]\n", "  show.image(rgb, format='jpeg')\n", " \n", "class DeepDream:\n", " \n", "  def __init__(self, g_image, params, shape0=(128, 128)):\n", "    self.params = params\n", "    self.octave_scale = 1.5\n", "    self.base_step = 0.02\n", "    h, w = np.int32(shape0)\n", "    self.rgb = np.float16(np.random.rand(h, w, 3)*0.01+0.5)\n", " \n", "    layer_index = params['layer_index']\n", "    graph = tf.Graph()\n", "    sess = self.sess = tf.Session(graph=graph)\n", "    with sess.as_default(), graph.as_default():\n", "      self.t_rgb = tf.placeholder(tf.float32, [None, None, 3])\n", "      model.import_graph(self.t_rgb)\n", "      tensor = sess.graph.get_tensor_by_name(\"import/%s:0\"%model.layers[layer_index].name)\n", " \n", "      act = sess.run(tensor, {self.t_rgb:g_image})\n", "      act **= params['selectivity']\n", "      target_act = act.mean((0, 1, 2))\n", "      target_act /= target_act.max()\n", "      t_dd_loss = -tf.reduce_mean(tensor*target_act)\n", "      
t_loss = t_dd_loss\n", "      t_score = -t_loss\n", " \n", "      [t_rgb_grad] = tf.gradients(t_score, self.t_rgb)\n", "      self.t_rgb_grad = lap_normalize(t_rgb_grad, 4)\n", " \n", " \n", "  def run_octave(self, iter_n, step, iter_cb=None):\n", "    h, w = np.int32(self.rgb.shape[:2])\n", "    tiles = split2d((h, w), 768)\n", "    for i in range(iter_n):\n", "      x, y = np.random.randint(256, size=2)\n", "      self.rgb[:] = np.roll(np.roll(self.rgb, x, 1), y, 0)\n", "      for ry, rx in tiles:\n", "        self.run_tile(rx, ry, step)\n", "      self.rgb[:] = np.roll(np.roll(self.rgb,-x, 1),-y, 0)\n", "      if iter_cb:\n", "        iter_cb(self.rgb)\n", "      print('\\r', end='')\n", "      print(i, end='', flush=True)\n", "    print('\\r', end='')\n", " \n", "  def run_tile(self, range_x, range_y, step):\n", "    M = np.float32([[ 0.07911157, 0.06922007, 0.0629285 ],\n", "                    [ 0.06922007, 0.07536791, 0.07108173],\n", "                    [ 0.0629285, 0.07108173, 0.08204902]])\n", "    tile = slice(*range_y), slice(*range_x)\n", "    rgb = self.rgb[tile]\n", "    g_rgb = self.sess.run(self.t_rgb_grad, {self.t_rgb: rgb})\n", "    g_rgb = g_rgb.dot(M+np.eye(3)*self.params['colorful'])\n", "    g_rgb *= step / (g_rgb.std()+1e-5)\n", "    rgb += g_rgb\n", "    rgb[:] = np.clip(rgb, 0.0, 1.0)\n", " \n", "  def upscale_image(self, img):\n", "    scale = self.octave_scale\n", "    img = cv2.resize(np.float32(img), None, None, scale, scale)\n", "    return np.float16(img)\n", " \n", "  def upscale(self, gamma=2.2):\n", "    if gamma != 1.0:\n", "      self.rgb **= 1.0/gamma\n", "    self.rgb = self.upscale_image(self.rgb)\n", "    if gamma != 1.0:\n", "      self.rgb **= gamma\n", " \n", "  def run(self, octave, iter_n, iter_cb=None):\n", "    if octave > 0:\n", "      self.upscale()\n", "    step_reduction = max(0.2, self.octave_scale**(-octave))\n", "    octave_step = self.base_step * step_reduction\n", "    self.run_octave(iter_n, octave_step, iter_cb)\n", "    print(octave, self.rgb.shape, octave_step)\n", "\n", "\n", "frames = []\n", "\n", "params = dict(\n", "  layer_index=layer_index,\n", "  selectivity=selectivity,\n", "  
colorful=colorful\n", ")\n", "dream = DeepDream(g_image, params)\n", "try:\n", "  for octave in range(preview_octave_n):\n", "    clear_output()\n", "    show.image(dream.rgb, format='jpeg')\n", "    dream.run(octave, preview_octave_steps, iter_cb=lambda a: frames.append(a.copy()))\n", "except KeyboardInterrupt:\n", "  pass\n", "finally:\n", "  clear_output()\n", "  show_tiled_image(dream.rgb)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "IAGrnM4w3Gmv", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }