{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", " return f(*args, **kwds)\n" ] } ], "source": [ "import collections\n", "\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import sklearn.metrics\n", "import sklearn.preprocessing\n", "import tensorflow as tf\n", "\n", "from vae_bernouilli import VariationalMnist, train_vae_bernouilli, encoding\n", "\n", "#sns.set(font_scale=1.5, palette='colorblind')\n", "sns.reset_orig()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "vae = VariationalMnist(28, 3)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/utils/validation.py:475: DataConversionWarning: Data with input dtype uint8 was converted to float64 by StandardScaler.\n", " warnings.warn(msg, DataConversionWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/utils/validation.py:475: DataConversionWarning: Data with input dtype uint8 was converted to float64 by StandardScaler.\n", " warnings.warn(msg, DataConversionWarning)\n", "/Users/srom/workspace/ds/env/lib/python3.6/site-packages/sklearn/utils/validation.py:475: DataConversionWarning: Data with input dtype uint8 was converted to float64 by StandardScaler.\n", " warnings.warn(msg, DataConversionWarning)\n" ] } ], "source": [ "scaler = sklearn.preprocessing.StandardScaler()\n", "\n", "inverse_scale = lambda x, scaler: scaler.inverse_transform(x.reshape(-1, 28*28)).reshape(-1, 28, 28)\n", "\n", "x_train_n = scaler.fit_transform(x_train.reshape(-1, 28*28)).reshape(-1, 28, 28)\n", "x_test_n = scaler.transform(x_test.reshape(-1, 28*28)).reshape(-1, 28, 28)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./vae_bernouilli.ckpt\n", "0 / 10000: latest = 534.1364041541162 | best = 534.1364041541162\n", "1000 / 10000: latest = 307.92270836335905 | best = 534.1364041541162\n", "2000 / 10000: latest = -1526.231314344567 | best = 534.1364041541162\n", "3000 / 10000: latest = -1504.1705758529235 | best = 534.1364041541162\n", "4000 / 10000: latest = -4932.884539762535 | best = 534.1364041541162\n", "5000 / 10000: latest = -13459.881697788984 | best = 534.1364041541162\n", "6000 / 10000: latest = -14056.31020361178 | best = 534.1364041541162\n", "7000 / 10000: latest = -15022.571205475098 | best = 534.1364041541162\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m save_path, best_elbo = train_vae_bernouilli(\n\u001b[0;32m----> 2\u001b[0;31m vae, x_train_n, x_test_n, 0.0001, 10000, 100, save_path='./vae_bernouilli.ckpt', best_elbo=532.865)\n\u001b[0m", "\u001b[0;32m~/workspace/ds/vae/vae_bernouilli.py\u001b[0m in \u001b[0;36mtrain_vae_bernouilli\u001b[0;34m(vae, x_train, x_test, learning_rate, n_epochs, batch_size, log_every, name, save_path, best_elbo)\u001b[0m\n\u001b[1;32m 118\u001b[0m feed_dict={\n\u001b[1;32m 119\u001b[0m \u001b[0mvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx_batch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0mvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m },\n\u001b[1;32m 122\u001b[0m )\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 875\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 876\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 877\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 878\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1099\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1100\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1101\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1102\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1252\u001b[0m \"\"\"\n\u001b[1;32m 1253\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1254\u001b[0;31m \u001b[0mfeeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_as_tf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1255\u001b[0m \u001b[0mfetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_as_tf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1256\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_c_op\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mop\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1252\u001b[0m \"\"\"\n\u001b[1;32m 1253\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1254\u001b[0;31m \u001b[0mfeeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_as_tf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1255\u001b[0m \u001b[0mfetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_as_tf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1256\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_c_op\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mop\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m_as_tf_output\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_as_tf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 571\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mc_api_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtf_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_c_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 572\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/framework/c_api_util.py\u001b[0m in \u001b[0;36mtf_output\u001b[0;34m(c_op, index)\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0mWrapped\u001b[0m \u001b[0mTF_Output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 183\u001b[0m \"\"\"\n\u001b[0;32m--> 184\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_api\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_Output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 185\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[0mthis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_pywrap_tensorflow_internal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_TF_Output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 959\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mthis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 960\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 961\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mthis\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py\u001b[0m in \u001b[0;36m\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 943\u001b[0m \u001b[0m__setattr__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0m_swig_setattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTF_Output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 944\u001b[0m \u001b[0m__swig_getmethods__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 945\u001b[0;31m \u001b[0m__getattr__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0m_swig_getattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTF_Output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 946\u001b[0m \u001b[0m__repr__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_swig_repr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 947\u001b[0m \u001b[0m__swig_setmethods__\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"oper\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_pywrap_tensorflow_internal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_Output_oper_set\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py\u001b[0m in \u001b[0;36m_swig_getattr\u001b[0;34m(self, class_type, name)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_swig_getattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_swig_getattr_nondynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/workspace/ds/env/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py\u001b[0m in \u001b[0;36m_swig_getattr_nondynamic\u001b[0;34m(self, class_type, name, static)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m_swig_getattr_nondynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"thisown\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mown\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "save_path, best_elbo = train_vae_bernouilli(\n", " vae, x_train_n, x_test_n, 0.0001, 10000, 100, save_path='./vae_bernouilli.ckpt', best_elbo=532.865)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_encoded = encoding(vae, x_test_n[:300], './vae_bernouilli.ckpt')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_encoding(X_code, y, figsize=(15, 8)):\n", " fig = plt.figure(figsize=figsize)\n", " ax = fig.add_subplot(111, projection='3d')\n", " \n", " vals = np.unique(y)\n", " \n", " data_per_class = collections.defaultdict(list)\n", " for i, v in enumerate(y):\n", " x_c = X_code[i]\n", " data_per_class[v].append(x_c)\n", " \n", " for v in sorted(data_per_class.keys()):\n", " data = data_per_class[v]\n", " x1 = [i for i, _, _ in data]\n", " x2 = [j for _, j, _ in data]\n", " x3 = [k for _, _, k in data]\n", " ax.scatter(x1, x2, x3, label=v)\n", " \n", " ax.legend()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_encoding(x_encoded, y_test[:300])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate(vae, num_samples, save_path):\n", " saver = tf.train.Saver()\n", " with tf.Session() as sess:\n", " saver.restore(sess, save_path)\n", " return sess.run(vae.samples, {vae.num_samples: num_samples})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "digits_n = generate(vae, 10, save_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "digits = inverse_scale(digits_n, scaler).reshape((-1, 28, 28))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.imshow(digits[1], cmap='Greys_r')\n", "plt.grid(False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_elbo" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }