{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "xy2rgb.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [ "JndnmDMp66FL", "EOCfOPdCkmXW", "UjboH7xgp6pt", "6x-kPRJzisK7" ], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "[View in Colaboratory](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/differentiable-parameterizations/xy2rgb.ipynb)" ] }, { "metadata": { "id": "JndnmDMp66FL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "##### Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "metadata": { "id": "hMqWDc_m6rUC", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "GS4kqcEJuy5R", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Compositional Pattern Producing Networks for Feature Visualization\n", "\n", "This notebook uses [**Lucid**](https://github.com/tensorflow/lucid) to produce aesthetically pleasing feature visualizations using a [Differentiable Image Parameterization](https://distill.pub/2018/differentiable-parameterizations/#section-xy2rgb) called a **Compositional Pattern Producing Network** (CPPN).\n", "\n", "![](https://storage.googleapis.com/tensorflow-lucid/notebooks/xy2rgb/cppn-header.jpg)\n", "\n", "This notebook additionally demonstrates:\n", "\n", "* rendering videos of the CPPN's training process as it generates the visualizations,\n", "* rendering videos that interpolate between sets of learned CPPN parameters, and\n", "* rendering high-resolution visualizations from a set of learned CPPN parameters.\n", "\n", "This notebook doesn't introduce the abstractions behind Lucid; you may also want to read the [Lucid tutorial](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/tutorial.ipynb).\n", "\n", "**Note**: The easiest way to use this tutorial is as a Colab notebook, which lets you dive in with no setup. We recommend enabling a free GPU by going to:\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**" ] },
{ "metadata": { "id": "EOCfOPdCkmXW", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Install, import, and load a model" ] },
{ "metadata": { "id": "b-IocmWWvb_I", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# Quote the requirement so the shell doesn't treat '>' as output redirection\n", "!pip install -q \"lucid>=0.2.3\"" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "bYd1LKCCxeTd", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 145 }, "outputId": "2047001c-97fa-4a7f-9f4e-fa3ca53516c2" }, "cell_type": "code", "source": [ "# For video rendering\n", "\n", "!pip install -q moviepy\n", "!imageio_download_bin ffmpeg" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Ascertaining binaries for: ffmpeg.\r\n", "Imageio: 'ffmpeg-linux64-v3.3.1' was not found on your computer; downloading it now.\n", "Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/ffmpeg/ffmpeg-linux64-v3.3.1 (43.8 MB)\n", "Downloading: 45929032/45929032 bytes (100.0%)\n", " Done\n", "File saved as /content/.imageio/ffmpeg/ffmpeg-linux64-v3.3.1.\n" ], "name": "stdout" } ] },
{ "metadata": { "id": "JEzrNW5lBUxr", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "outputId": "d3b9bc79-4527-4f28-f218-e8514b9e8ab7" }, "cell_type": "code", "source": [ "from __future__ import print_function\n", "import io\n", "import string\n", "import numpy as np\n", "import PIL\n", "import base64\n", "from glob import glob\n", "\n", "import matplotlib.pylab as pl\n", "\n", "import tensorflow as tf\n", "from tensorflow.contrib import slim  # tf.contrib requires TensorFlow 1.x\n", "\n", "from IPython.display import clear_output, Image, display, HTML\n", "\n", "import moviepy.editor as mpy\n", "from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter\n", "\n", "\n", "from google.colab import files" ], "execution_count": 1, "outputs": [] },
{ "metadata": { "id": "mUnwK7mYFFns", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "outputId": "d3b2b3be-a564-4ad7-e6ba-c09b36e6877a" }, "cell_type": "code", "source": [ "from lucid.modelzoo import vision_models\n", "from lucid.misc.io import show, save, load\n", "from lucid.optvis import objectives\n", "from lucid.optvis import render\n", "from lucid.misc.tfutil import create_session" ], "execution_count": 2, "outputs": []
}, { "metadata": { "id": "pRnDsiHa7ZK1", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "outputId": "1edbbcf9-b7c0-4252-b422-a68d49d9b7fb" }, "cell_type": "code", "source": [ "model = vision_models.InceptionV1()\n", "model.load_graphdef()" ], "execution_count": 3, "outputs": [] },
{ "metadata": { "id": "h1_Kw_kNkzFu", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Setting up the CPPN" ] },
{ "metadata": { "id": "f5b3ShqcGNET", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "outputId": "c66dc068-23e8-49f7-a85e-306512851d65" }, "cell_type": "code", "source": [ "def composite_activation(x):\n", "  x = tf.atan(x)\n", "  # Coefficients computed by:\n", "  #   def rms(x):\n", "  #     return np.sqrt((x*x).mean())\n", "  #   a = np.arctan(np.random.normal(0.0, 1.0, 10**6))\n", "  #   print(rms(a), rms(a*a))\n", "  return tf.concat([x/0.67, (x*x)/0.6], -1)\n", "\n", "\n", "def composite_activation_unbiased(x):\n", "  x = tf.atan(x)\n", "  # Coefficients computed by:\n", "  #   a = np.arctan(np.random.normal(0.0, 1.0, 10**6))\n", "  #   aa = a*a\n", "  #   print(a.std(), aa.mean(), aa.std())\n", "  return tf.concat([x/0.67, (x*x-0.45)/0.396], -1)\n", "\n", "\n", "def relu_normalized(x):\n", "  x = tf.nn.relu(x)\n", "  # Coefficients computed by:\n", "  #   a = np.random.normal(0.0, 1.0, 10**6)\n", "  #   a = np.maximum(a, 0.0)\n", "  #   print(a.mean(), a.std())\n", "  return (x-0.40)/0.58\n", "\n", "\n", "def image_cppn(\n", "    size,\n", "    num_output_channels=3,\n", "    num_hidden_channels=24,\n", "    num_layers=8,\n", "    activation_fn=composite_activation,\n", "    normalize=False):\n", "  r = 3.0**0.5  # std(coord_range) == 1.0\n", "  coord_range = tf.linspace(-r, r, size)\n", "  y, x = tf.meshgrid(coord_range, coord_range, indexing='ij')\n", "  net = tf.expand_dims(tf.stack([x, y], -1), 0)  # add batch dimension\n", "\n", "  with slim.arg_scope([slim.conv2d], kernel_size=1, activation_fn=None):\n", "    for i in range(num_layers):\n", "      in_n = int(net.shape[-1])\n", "      net = slim.conv2d(\n", "          net, num_hidden_channels,\n", "          # this is an untruncated version of tf.variance_scaling_initializer\n", "          weights_initializer=tf.random_normal_initializer(0.0, np.sqrt(1.0/in_n)),\n", "      )\n", "      if normalize:\n", "        net = slim.instance_norm(net)\n", "      net = activation_fn(net)\n", "\n", "    rgb = slim.conv2d(net, num_output_channels, activation_fn=tf.nn.sigmoid,\n", "                      weights_initializer=tf.zeros_initializer())\n", "  return rgb" ], "execution_count": 4, "outputs": [] },
{ "metadata": { "id": "Y46FuT_jTymC", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Estimating the number of parameters in the CPPN:" ] },
{ "metadata": { "id": "MmNfy9ItSZl6", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "f2561b6c-edf9-4f6a-98e7-fb08d8c8fcab" }, "cell_type": "code", "source": [ "with tf.Graph().as_default():\n", "  image_cppn(224)\n", "  variables = tf.get_collection('variables')\n", "  param_n = sum([v.shape.num_elements() for v in variables])\n", "  print('CPPN parameter count:', param_n)" ], "execution_count": 36, "outputs": [ { "output_type": "stream", "text": [ "CPPN parameter count: 8451\n" ], "name": "stdout" } ] },
{ "metadata": { "id": "NyoBLiF0BohH", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Let's quickly sanity-check that this CPPN can learn to produce an image with the properties we expect.\n", "As a simple test, we fit the XOR function by imposing a loss on the four corner pixels of the image: the top-left and bottom-right corners are pushed toward 0 (black) and the other two corners toward 1 (white)."
] }, { "metadata": { "id": "3q4yq_O97m_Z", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 105 }, "outputId": "98948003-2952-484a-84a1-3915661480ad" }, "cell_type": "code", "source": [ "cppn_f = lambda: image_cppn(64)\n", "optimizer = tf.train.AdamOptimizer(0.01)\n", "\n", "def xor_objective(T):\n", "  a = T('input')[0]\n", "  # render_vis maximizes the objective, so the squared corner errors are negated\n", "  return -(tf.square(a[0, 0]) + tf.square(a[-1, -1]) +\n", "           tf.square(1.0-a[-1, 0]) + tf.square(1.0-a[0, -1]))\n", "\n", "vis = render.render_vis(model, xor_objective, param_f=cppn_f, optimizer=optimizer,\n", "                        transforms=[], thresholds=range(10), verbose=False)\n", "show(vis)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "P_S6xacVTYTZ", "colab_type": "text" }, "cell_type": "markdown", "source": [ "That looks reasonable enough!\n", "Let's move on to our original goal: feature visualization." ] },
{ "metadata": { "id": "7E9-h7CETPOo", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Feature Visualization\n", "\n", "Let's use our new CPPN to produce a feature visualization similar to those in the header image. The images below are taken at optimization steps 32, 64, 128, 256, and 512:" ] },
{ "metadata": { "id": "Xijb4mGhSqie", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "outputId": "89bea7e5-9487-4588-e885-855d2b790acb" }, "cell_type": "code", "source": [ "def render_feature(\n", "    cppn_f=lambda: image_cppn(224),\n", "    optimizer=tf.train.AdamOptimizer(0.005),\n", "    objective=objectives.channel(\"mixed4b_3x3_pre_relu\", 77)):\n", "  vis = render.render_vis(model, objective, param_f=cppn_f, optimizer=optimizer,\n", "                          transforms=[], thresholds=[2**i for i in range(5, 10)],\n", "                          verbose=False)\n", "  show(vis)\n", "\n", "render_feature()" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "UjboH7xgp6pt", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Varying the activation function" ] }, { "metadata": { "id": "Edk42JM5RCn9", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "outputId": "25a8a6d1-4cd4-43cd-cd63-798593c37bfe" }, "cell_type": "code", "source": [ "render_feature(\n", " cppn_f=lambda: image_cppn(224, activation_fn=composite_activation_unbiased))" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "28EP60C7RSZt", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "outputId": "33de8a12-ea7e-43d3-fdf1-e668ca86a11f" }, "cell_type": "code", "source": [ "render_feature(\n", " cppn_f=lambda: image_cppn(224, activation_fn=relu_normalized))" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "WhA1fq9fSGx4", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "outputId": "3fee825b-9ce4-44af-b1e4-d40a0bdc2b89" }, "cell_type": "code", "source": [ "render_feature(\n", " cppn_f=lambda: image_cppn(224, activation_fn=tf.abs))" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "aKgw-z8TSTfc", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Video of the training process" ] },
{ "metadata": { "id": "EGI2TEncO55Z", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The following `render_story` function does several things: it sets up the optimization problem, writes the current image as a video frame at every optimization step, and finally saves the trained CPPN weights (`.npy`) and the final image re-rendered at 400 × 400 pixels (`.jpg`)." ] },
{ "metadata": { "id": "yt88aEudVclq", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from lucid.misc.io.serialize_array import _normalize_array\n", "\n", "def render_story(obj_str, lr=0.004, step_n=512,\n", "                 normalize=False,\n", "                 activation_fn=composite_activation):\n", "  sess = create_session()\n", "\n", "  # Set up optimization problem\n", "  size = 224\n", "  t_size = tf.placeholder_with_default(size, [])\n", "  T = render.make_vis_T(\n", "      model, obj_str,\n", "      param_f=lambda: image_cppn(\n", "          t_size, normalize=normalize, activation_fn=activation_fn),\n", "      transforms=[],\n", "      optimizer=tf.train.AdamOptimizer(lr),\n", "  )\n", "  tf.global_variables_initializer().run()\n", "\n", "  # Prepare video writer and filenames\n", "  subst = {ord(':'): '_', ord('/'): '_'}\n", "  out_name = 'xy2rgb_' + obj_str.translate(subst)\n", "  video_fn = out_name + '.mp4'\n", "  writer = FFMPEG_VideoWriter(video_fn, (size, size), 60.0)\n", "\n", "  # Optimization loop; interrupting the cell ends training early, but the video\n", "  # and parameters are still saved below\n", "  try:\n", "    for i in range(step_n):\n", "      _, loss, img = sess.run([T(\"vis_op\"), T(\"loss\"), T(\"input\")])\n", "      writer.write_frame(_normalize_array(img))\n", "      if i > 0 and i % 50 == 0:\n", "        clear_output()\n", "        print(\"%d / %d score: %f\" % (i, step_n, loss))\n", "        show(img)\n", "  except KeyboardInterrupt:\n", "    pass\n", "  finally:\n", "    writer.close()\n", "\n", "  # Show the resulting video\n", "  clear_output()\n", "  display(mpy.ipython_display(video_fn, height=400))\n", "\n", "  # Save trained variables\n", "  train_vars = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n", "  params = np.array(sess.run(train_vars), object)\n", "  save(params, out_name + '.npy')\n", "\n", "  # Save final image, re-rendered at a higher resolution\n", "  final_img = T(\"input\").eval({t_size: 400})\n", "  save(final_img, out_name + '.jpg', quality=90)" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "Opq53mH1v8TF", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 420 }, "outputId": "90e466e0-8999-4b83-80f2-1d2aa15041d4" }, "cell_type": "code", "source": [ "render_story('mixed4b_pool_reduce_pre_relu:16')" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "7kFImFAtiOzw", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 420 }, "outputId": "36a11182-379a-47ab-c104-45f31ca2f5ed" }, "cell_type": "code", "source": [ "render_story('mixed4a_pool_reduce_pre_relu:52')" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "6x-kPRJzisK7", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### More interesting patterns to try" ] },
{ "metadata": { "id": "vvBmiE0u2tjh", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('mixed4c_pool_reduce_pre_relu:5')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "JicpKD-R3Z4D", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('softmax0_pre_activation/matmul:316')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "_SoUHUXz7qya", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('mixed4b_3x3_pre_relu:77')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "f6M6pdsUiHYy", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('mixed4d_3x3_bottleneck_pre_relu:114')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "TTxCQAM48-Lw", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('mixed4e_3x3_pre_relu:120')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "e2Yd6GXw2IsJ", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('head0_bottleneck_pre_relu:0')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "mZPmtOBMCPe3", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "render_story('mixed4d_3x3_bottleneck_pre_relu:139')" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "w_RY_2rskGDj", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Arbitrary-resolution images\n", "\n", "Because the CPPN is a function of continuous (x, y) coordinates, learned parameters can be rendered at any size by evaluating the network on a denser coordinate grid. `render_params` feeds saved parameters into a fresh CPPN graph in place of its variables." ] },
{ "metadata": { "id": "xZA1me_XhNPu", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 91 }, "outputId": "ed62713f-c4cf-4db3-dc23-972927bcd0ba" }, "cell_type": "code", "source": [ "# Opens a new InteractiveSession with its own graph; TensorFlow warns that another one is already active\n", "sess = create_session()\n", "t_size = tf.placeholder_with_default(224, [])\n", "t_image = image_cppn(t_size)\n", "\n", "train_vars = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n", "\n", "def render_params(params, size=224):\n", "  # Feed the saved values in place of the CPPN variables, so no initialization is needed\n", "  feed_dict = dict(zip(train_vars, params))\n", "  feed_dict[t_size] = size\n", "  return sess.run(t_image, feed_dict)[0]" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "WARNING:py.warnings:/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py:1714: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n", " warnings.warn('An interactive session is already active. This can '\n", "\n" ], "name": "stderr" } ] },
{ "metadata": { "id": "-tr5-Svfnfx1", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 820 }, "outputId": "546e59ca-2262-45f3-fc8e-040a222896e1" }, "cell_type": "code", "source": [ "params = load('xy2rgb_mixed4b_pool_reduce_pre_relu_16.npy')\n", "vis = render_params(params, 800)\n", "show(vis)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "6iXpbL5knhYq", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 820 }, "outputId": "db430114-fb75-402a-ee3a-465d6423a5e4" }, "cell_type": "code", "source": [ "params = load('xy2rgb_mixed4a_pool_reduce_pre_relu_52.npy')\n", "vis = render_params(params, 800)\n", "show(vis)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] },
{ "metadata": { "id": "FiJ872uviwXP", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Interpolating CPPN parameters\n", "\n", "This section requires the parameter files saved by `render_story` in the \"Video of the training process\" section; we interpolate between them here.\n", "\n", "`interpolate_params` blends the two parameter sets on a cosine schedule, so the video starts and ends on the first set and reaches the second halfway through. Mid-transition, the blended parameters are also scaled up by as much as 25% to exaggerate the intermediate patterns." ] },
{ "metadata": { "id": "wDKPiUvz4ETE", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def interpolate_params(param1, param2, duration=5.0, size=224):\n", "\n", "  def frame(t):\n", "    t = t / duration\n", "    t = (1.0-np.cos(2.0*np.pi*t))/2.0  # looping & easing\n", "    params = param1*(1.0-t) + param2*t  # blending\n", "    params *= 1.0 + t*(1.0-t)  # exaggerating\n", "    img = render_params(params, size=size)\n", "    return _normalize_array(img)\n", "\n", "  clip = mpy.VideoClip(frame, duration=duration)\n", "  clip.write_videofile('tmp.mp4', fps=30.0)\n", "  display(mpy.ipython_display('tmp.mp4', height=400))" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "id": "SXHgySvfY4j6", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 528 }, "outputId": "2abefba8-f5d7-4e43-8345-37b059532210" }, "cell_type": "code", "source": [ "interpolate_params(\n", "    load('xy2rgb_mixed4b_pool_reduce_pre_relu_16.npy'),\n", "    load('xy2rgb_mixed4a_pool_reduce_pre_relu_52.npy'),\n", ")" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "[MoviePy] >>>> Building video tmp.mp4\n", "[MoviePy] Writing video tmp.mp4\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ " 99%|█████████▉| 150/151 [00:01<00:00, 83.44it/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "[MoviePy] Done.\n", "[MoviePy] >>>> Video ready: tmp.mp4 \n", "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] } ] }