{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Background experiments for transparent feature vis", "version": "0.3.2", "views": {}, "default_view": {}, "provenance": [ { "file_id": "1mEhS8mZVKxO3HstKUhjWQ4fOleXOJUfk", "timestamp": 1530574212481 } ] }, "kernelspec": { "name": "python2", "display_name": "Python 2" }, "accelerator": "GPU" }, "cells": [ { "metadata": { "id": "JndnmDMp66FL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "##### Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "metadata": { "id": "hMqWDc_m6rUC", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "cellView": "both" }, "cell_type": "code", "source": [ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "GS4kqcEJuy5R", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Background experiments for Generation of Semi-Transparent Patterns\n", "\n", "This notebook uses [**Lucid**](https://github.com/tensorflow/lucid) to produce semi-transparent feature visualizations that are similar in spirit to the feature visualizations in [Differentiable Image Parameterizations](https://distill.pub/2018/differentiable-parameterizations/#section-rgba). 
\n", "\n", "Here, we experimented with different backgrounds and found no particularly interesting dependency: as long as the background was random, the resulting visualizations created the kind of alpha masks we expected. **In a sense, this is a null result, and we provide the code primarily for interested tinkerers.**\n", "\n", "This notebook doesn't introduce the abstractions behind Lucid; you may wish to also read the [Lucid tutorial](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/tutorial.ipynb).\n", "\n", "**Note**: The easiest way to use this notebook is as a Colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU by going to:\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**" ] }, { "metadata": { "id": "FsFc1mE51tCd", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Install, Import, Load Model" ] }, { "metadata": { "id": "tavMPe3KQ8Cs", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "# Install Lucid\n", "\n", "!pip install --quiet lucid" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "RBr8QbboRAdU", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "# Imports\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "import lucid.modelzoo.vision_models as models\n", "from lucid.misc.io import show\n", "import lucid.optvis.objectives as objectives\n", "import lucid.optvis.param as param\n", "import lucid.optvis.render as render\n", "import lucid.optvis.transform as transform" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "yNALaA0QRJVT", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "# Let's import a model from the Lucid modelzoo!\n", "\n", "model = 
models.InceptionV1()\n", "model.load_graphdef()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "zsmsB_DpWAhb", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Preparation" ] }, { "metadata": { "id": "8MyNeseOYMHN", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Lucid's `lucid.optvis.param` module (short for *parameterizations*) provides `image_sample`, which we use to draw a random background:" ] }, { "metadata": { "id": "FoaTlZs254M0", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 721 }, "outputId": "123ec376-01dd-4e2f-a328-23b57555fe5e", "executionInfo": { "status": "ok", "timestamp": 1530575238620, "user_tz": 420, "elapsed": 4446, "user": { "displayName": "Ludwig Schubert", "photoUrl": "//lh4.googleusercontent.com/-JSZvF3zetaM/AAAAAAAAAAI/AAAAAAAAAA0/DioB29jA0U0/s50-c-k-no/photo.jpg", "userId": "106277933620557364646" } } }, "cell_type": "code", "source": [ "with tf.Graph().as_default(), tf.Session():\n", "\n", "  for decorrelate in [False, True]:\n", "    print(\"\\nDecorrelate: {}\".format(decorrelate))\n", "\n", "    for decay_power in [1, 1.5, 2]:\n", "      print(\"Decay power: {}\".format(decay_power))\n", "      decay_power_reg = 20**(-decay_power + 1)\n", "      images = [param.image_sample([1, 64, 64, 3], sd=decay_power_reg*n/20., decay_power=decay_power, decorrelate=decorrelate).eval() for n in range(10)]\n", "      show(images)" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "\n", "Decorrelate: False\n", "Decay power: 1\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Decay power: 1.5\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Decay power: 2\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Decorrelate: True\n", "Decay power: 1\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Decay power: 1.5\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "Decay power: 2\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "QmsP8FUQmT4Q", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Experiment 1\n", "\n", "Here we try out our alpha parameterization on different backgrounds. Some are merely random, but drawn from different distributions, while others are optimized, either supporting the objective or opposing it. A weak hypothesis we held when starting out was that an adversarial background would help narrow down which parts of the foreground needed to be opaque. While this does seem to work, in expectation it doesn't seem to work better than a random background. In fact, it appears that an adversarial background may force the foreground to \"defend\" itself by making areas near the object opaque.\n", "\n", "## Setup" ] }, { "metadata": { "id": "vJsD1-aZVvrl", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "def make_bg_func(noise_ratio=0.5, decay_power=1.5, base_sd=0.8,\n", "                 decorrelate=True, var_mode=\"image\", var_image_jitter=0):\n", "  def bg_func(w):\n", "    if var_mode == \"image\":\n", "      var_img = param.image(w+var_image_jitter, decorrelate=decorrelate)\n", "      var_img = transform.jitter(var_image_jitter)(var_img)\n", "    elif var_mode == \"color\":\n", "      var_img = param.image(5, decorrelate=decorrelate)[:, 3:4, 3:4, :]\n", "    noise = param.image_sample([1, w, w, 3], sd=base_sd*20**(-decay_power + 1), decay_power=decay_power, decorrelate=decorrelate)\n", "    return (1-noise_ratio)*var_img + noise_ratio*noise\n", "  return bg_func\n", "\n" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Q67HH_m0f_Hb", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "def alpha_experiment(layer, neuron, bg_func, alpha_loss_coef=1e4, jitter_n=2, w=128, bg_adv=True, transforms=None):\n", "\n", "  with tf.Graph().as_default(), tf.Session():\n", "\n", "    t_img = 
param.image(w, alpha=True)\n", "    t_bg = bg_func(w)\n", "\n", "    # Compose bg and image -- this includes\n", "    #  * jittering rgb, alpha, and background relative to each other\n", "    #  * Blocking gradients so that we can optimize foreground/background separately\n", "    t_rgb, t_alpha = t_img[..., :3], t_img[..., 3:4]\n", "    jitter = transform.jitter(jitter_n)\n", "    t_rgb_, t_alpha_, t_bg_ = jitter(t_rgb), jitter(t_alpha), jitter(t_bg)\n", "    t_flat = t_rgb_*t_alpha_ + (1-t_alpha_)*tf.stop_gradient(t_bg_)\n", "    t_flat_ = tf.stop_gradient(t_rgb_*t_alpha_) + tf.stop_gradient(1-t_alpha_)*t_bg_\n", "    t_inp = tf.concat([t_flat, t_flat_], axis=0)\n", "\n", "    # Create the objective\n", "    t_alpha_mean = tf.reduce_mean(t_alpha)\n", "    obj = objectives.channel(layer, neuron, batch=0)\n", "    if bg_adv:\n", "      obj -= objectives.channel(layer, neuron, batch=1)\n", "      #obj -= objectives.neuron(layer, neuron, batch=1)\n", "    else:\n", "      obj += objectives.channel(layer, neuron, batch=1)\n", "    obj += alpha_loss_coef*objectives.Objective(lambda T: -tf.square(t_alpha_mean-0.25) )\n", "\n", "    # Optimize the visualization\n", "    T = render.make_vis_T(model, obj, t_inp, transforms=transforms)\n", "    tf.global_variables_initializer().run()\n", "    for i in range(512):\n", "      #if i%16 == 0: print \".\",\n", "      T(\"vis_op\").run()\n", "\n", "    # Show the visualization\n", "    img = t_img.eval()\n", "    rgb, alpha = img[..., :3], img[..., 3:4]\n", "    #show([img, rgb*alpha + (1-alpha)*(0.5+0.5*rgb), rgb, 1-alpha, t_bg.eval(), t_flat.eval()])\n", "    show(np.hstack([(rgb*alpha + 1-alpha)[0], t_bg.eval()[0]]))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "KF6aCRLwZqaH", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Let's make sure these two work together:" ] }, { "metadata": { "id": "8meyKA8bvxDL", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 277 }, "outputId": 
"375edbe1-c1ea-44f3-e0b4-1f9cae1ccdb0", "executionInfo": { "status": "ok", "timestamp": 1521824073860, "user_tz": 420, "elapsed": 19673, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "alpha_experiment(\"mixed4d_3x3_bottleneck_pre_relu\", 139,\n", "    bg_func=make_bg_func(noise_ratio=1.0, decay_power=1.5, base_sd=0.4, decorrelate=True),\n", "    alpha_loss_coef=1e4,\n", "    jitter_n=1,\n", "    bg_adv=False,\n", "    w=256,\n", "    transforms=[transform.jitter(8), transform.jitter(8), transform.jitter(8), transform.random_scale([0.95, 0.98, 1.0, 1.02, 1.05]), transform.jitter(8)]\n", ")" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "1ksSMk1BZcAV", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Test" ] }, { "metadata": { "id": "arcvtNmmZ4ij", "colab_type": "text" }, "cell_type": "markdown", "source": [ "A small helper to run a list of different backgrounds on a single neuron:" ] }, { "metadata": { "id": "uSk7oxxGV4Lj", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "def test(layer, neuron, alpha_loss_coef=1e4, jitter_n=2):\n", " \n", " print(\"Adversarial Background Image Full\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=0),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=True\n", " )\n", " \n", "\n", " print(\"\")\n", " print(\"Adversarial Background Image Random Offset\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", 
var_image_jitter=128),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=True\n", " )\n", " \n", "\n", " print(\"\")\n", " print(\"Adversarial Background Color\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"color\"),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=True\n", " )\n", " \n", "\n", " print(\"\")\n", " print(\"Noise Background\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=1.0, decay_power=2, base_sd=0.1, decorrelate=True, var_mode=\"color\"),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=True\n", " )\n", " \n", "\n", " print(\"\")\n", " print(\"Allied Background Color\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"color\"),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=False\n", " )\n", " \n", "\n", " print(\"\")\n", " print(\"Allied Background Image Random Offset\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=128),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=False\n", " )\n", " \n", " print(\"\")\n", " print(\"Allied Background Image Full\")\n", " alpha_experiment(layer, neuron,\n", " bg_func=make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=0),\n", " alpha_loss_coef=alpha_loss_coef,\n", " jitter_n=jitter_n,\n", " bg_adv=False\n", " )" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "pxt3CjC9aDmY", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Let's see this on three different neurons:" ] }, { "metadata": { "id": "s10iISNGtNV9", 
"colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1169 }, "outputId": "c58f097c-1835-4e5f-8389-ad59d56141e2", "executionInfo": { "status": "ok", "timestamp": 1521823041229, "user_tz": 420, "elapsed": 75588, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "test(\"mixed4d_3x3_bottleneck_pre_relu\", 139)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Adverserial Background Image Full\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Noise Background\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Full\n" ], 
"name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "UjLv57W3gDXZ", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1168 }, "outputId": "2b09fd4f-2638-4d82-af94-5bdbb3d11e35", "executionInfo": { "status": "ok", "timestamp": 1521585438313, "user_tz": 420, "elapsed": 67802, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "test(\"mixed4b_pool_reduce_pre_relu\", 16)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Adverserial Background Image Full\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Noise Background\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": 
"display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Full\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "855HuXaZmP6h", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1168 }, "outputId": "02b9010e-3520-494f-fa47-0a2bd1c75160", "executionInfo": { "status": "ok", "timestamp": 1521586160588, "user_tz": 420, "elapsed": 82377, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "test(\"mixed4d_pre_relu\", 426,)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Adverserial Background Image Full\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Adverserial Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Noise Background\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Color\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": 
[ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Random Offset\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n", "Allied Background Image Full\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "xI_moOKkmRDt", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Experiment 2\n", "\n", "One factor that made comparing the different backgrounds hard in Experiment 1 is the same problem that \"Aligned Feature Vis interpolation\" addresses: visual landmarks, such as ears, are not in the same position between those visualizations.\n", "\n", "Experiment 2 tried to address this by introducing a shared parameterization for the image foregrounds.\n", "\n", "We didn't experiment with this very much, and it appears that in our initial attempts the shared parameterization may have been too strong, causing the different visualizations to be essentially identical, instead of just aligned. 
**If one wanted to explore this direction further, that would be something to fiddle with.**\n", "\n", "## Setup" ] }, { "metadata": { "id": "YXsy9z5SVTpC", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } } }, "cell_type": "code", "source": [ "def alpha_experiment_2(layer, neuron, bg_list_func, bgs_adv=True, w=128, jitter_n=2, shared_param_coef=0.5):\n", "  with tf.Graph().as_default(), tf.Session():\n", "\n", "    bgs = bg_list_func(w)\n", "    N = len(bgs)\n", "\n", "    make_param = lambda: sum(param.lowres_tensor([w, w, 4], [w//k, w//k, 4]) for k in [1,2,4,8])/4.0 # <-- shared param is here\n", "    shared_param = make_param()\n", "    fgs = []\n", "    for _ in range(N):\n", "      fg_param = shared_param_coef*shared_param + make_param()\n", "      rgb = param.to_valid_rgb(fg_param[..., :3], decorrelate=True)\n", "      alpha = tf.nn.sigmoid(fg_param[..., 3:4])\n", "      fgs.append(tf.concat([rgb, alpha], axis=-1))\n", "\n", "    if isinstance(bgs_adv, bool):\n", "      bgs_adv = N*[bgs_adv]\n", "\n", "    flats = []\n", "    obj = 0\n", "    for n, (fg, bg, adv) in enumerate(zip(fgs, bgs, bgs_adv)):\n", "      rgb, alpha = fg[..., :3], fg[..., 3:4]\n", "      jitter = transform.jitter(jitter_n)\n", "      rgb, alpha, bg = jitter(rgb), jitter(alpha), jitter(bg)\n", "      flats.append(rgb*alpha + (1-alpha)*tf.stop_gradient(bg))\n", "      flats.append(tf.stop_gradient(rgb*alpha) + tf.stop_gradient(1-alpha)*bg)\n", "      obj += objectives.neuron(layer, neuron, batch=2*n)\n", "      adv_sign = -1 if adv else 1\n", "      obj += adv_sign * (objectives.neuron(layer, neuron, batch=2*n+1) + objectives.channel(layer, neuron, batch=2*n+1))\n", "      t_alpha_mean = tf.reduce_mean(fg[..., 3:4])\n", "\n", "    obj += 1e4*objectives.Objective(lambda T: -sum(tf.square( tf.reduce_mean(fg[..., 3:4]) -0.25) for fg in fgs ))\n", "\n", "    # Optimize the visualization\n", "    t_inp = tf.stack(flats)\n", "    T = render.make_vis_T(model, obj, t_inp)\n", "    tf.global_variables_initializer().run()\n", "    for i in range(512):\n", "      #if i%16 == 0: 
print \".\",\n", "      T(\"vis_op\").run()\n", "\n", "    show([t.eval() for t in flats[::2]])\n", "    show([t.eval() for t in fgs])\n", "    show([t.eval() for t in bgs])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "uU0t-ySSbm7X", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Once again we run a list of different backgrounds on the same neurons:" ] }, { "metadata": { "id": "iGpR0ZbFgGWY", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1379 }, "outputId": "975940ff-7fb6-4120-ca8e-f5951364b34d", "executionInfo": { "status": "ok", "timestamp": 1521585862627, "user_tz": 420, "elapsed": 208056, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "bg_list = [\n", "  (True, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=2)),\n", "  (True, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=128)),\n", "  (True, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"color\")),\n", "  (True, make_bg_func(noise_ratio=1.0, decay_power=2, base_sd=0.4, decorrelate=True)),\n", "  (True, make_bg_func(noise_ratio=1.0, decay_power=2, base_sd=0.1, decorrelate=True)),\n", "  (False, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"color\")),\n", "  (False, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=128)),\n", "  (False, make_bg_func(noise_ratio=0.0, decay_power=1.5, base_sd=0.3, decorrelate=True, var_mode=\"image\", var_image_jitter=2)),\n", "]\n", "\n", "def bg_list_func(w):\n", "  return [item[1](w)[0] for 
item in bg_list]\n", "\n", "adv_list = [item[0] for item in bg_list]\n", "\n", " \n", "alpha_experiment_2(\"mixed4b_pool_reduce_pre_relu\", 16, bg_list_func, bgs_adv=adv_list, jitter_n=2, shared_param_coef=0.0)\n", "alpha_experiment_2(\"mixed4b_pool_reduce_pre_relu\", 16, bg_list_func, bgs_adv=adv_list, jitter_n=2, shared_param_coef=0.5)\n", "\n", "alpha_experiment_2(\"mixed4d_pre_relu\", 426, bg_list_func, bgs_adv=adv_list, jitter_n=2)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "1TCMU5pMjnoU", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 471 }, "outputId": "0c24bb81-9e1d-45c1-c125-b557b135b742", "executionInfo": { "status": "ok", "timestamp": 1521587230394, "user_tz": 420, "elapsed": 69412, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "alpha_experiment_2(\"mixed4d_pre_relu\", 426, bg_list_func, bgs_adv=adv_list, jitter_n=2, shared_param_coef=0)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "OE93Vl8UnEQS", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 471 }, "outputId": "581ea22e-42d8-4e03-99e7-89d8a000cafb", "executionInfo": { "status": "ok", "timestamp": 1521587065017, "user_tz": 420, "elapsed": 69916, "user": { "displayName": "Christopher Olah", "photoUrl": "//lh6.googleusercontent.com/-BDHAgNAk34E/AAAAAAAAAAI/AAAAAAAAAMw/gTWZ3IeP8dY/s50-c-k-no/photo.jpg", "userId": "104989755527098071788" } } }, "cell_type": "code", "source": [ "alpha_experiment_2(\"mixed4d_pre_relu\", 426, bg_list_func, bgs_adv=adv_list, jitter_n=2, shared_param_coef=1)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": { "tags": [] } } ] } ] }