{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Auto-Encoder\n",
    "\n",
    "Work in progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "ConX, version 3.7.3\n"
     ]
    }
   ],
   "source": [
    "import conx as cx\n",
    "import keras.backend as K"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need a function to use as the activation function for the Sampler layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "LENGTH = 5 # latent size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sampler(inputs):\n",
    "    \"\"\"Reparameterization trick: sample z = mean + sigma * eps, eps ~ N(0, 1).\n",
    "\n",
    "    inputs is a merged concat: the first LENGTH columns are used as the\n",
    "    means and the remaining columns as log-variances, log(sigma^2).\n",
    "    \"\"\"\n",
    "    mean, log_var = inputs[:, :LENGTH], inputs[:, LENGTH:]\n",
    "    # we sample from the standard normal a matrix of batch_size * latent_size (taking into account minibatches)\n",
    "    std_norm = K.random_normal(shape=(K.shape(mean)[0], LENGTH), mean=0, stddev=1)\n",
    "    # sampling from Z~N(mu, sigma^2) is the same as sampling from mu + sigma*X, X~N(0,1).\n",
    "    # sigma = exp(log_var / 2): the KL term in the loss functions uses exp(x)\n",
    "    # as the variance, so exp(log_var) here would scale the noise by sigma^2.\n",
    "    return mean + K.exp(0.5 * log_var) * std_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "CAPACITY = 5 # size of encoded bank\n",
    "BETA = 1.5\n",
    "\n",
    "def _latent_kl(tensor):\n",
    "    \"\"\"Return (kl, width) for a latent tensor.\n",
    "\n",
    "    kl is the KL-divergence term averaged over the last axis (one value\n",
    "    per pattern); width is the static size of axis 1 as a plain int.\n",
    "    \"\"\"\n",
    "    # K.int_shape gives Python ints; under the TensorFlow 1.x backend\n",
    "    # tensor.shape[1] is a Dimension object, which does not support the\n",
    "    # true division used in bvae_loss.\n",
    "    width = K.int_shape(tensor)[1]\n",
    "    if width == 2 * LENGTH:\n",
    "        # merged [mean | log-variance] tensor\n",
    "        half = width // 2\n",
    "        mean, log_var = tensor[:, :half], tensor[:, half:]\n",
    "    else:\n",
    "        # NOTE(review): any other width reuses the tensor as both mean and\n",
    "        # log-variance. This is the branch taken when the loss is attached\n",
    "        # to \"encode\" (width LENGTH) -- confirm that this is intended.\n",
    "        mean, log_var = tensor, tensor\n",
    "    # kl divergence:\n",
    "    kl = -0.5 * K.mean(1 + log_var - K.square(mean) - K.exp(log_var), axis=-1)\n",
    "    return kl, width\n",
    "\n",
    "def bvae_loss(tensor):\n",
    "    \"\"\"Beta-weighted latent loss: sum over patterns of BETA * |kl - CAPACITY / width|.\"\"\"\n",
    "    kl, width = _latent_kl(tensor)\n",
    "    # use beta to force less usage of vector space:\n",
    "    # also try to use dimensions of the space:\n",
    "    return K.sum(BETA * K.abs(kl - CAPACITY / width))\n",
    "\n",
    "def vae_loss(tensor):\n",
    "    \"\"\"Plain latent loss: the KL term summed over patterns.\"\"\"\n",
    "    kl, _ = _latent_kl(tensor)\n",
    "    return K.sum(kl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = cx.Network(\"vae\")\n",
    "# NOTE(review): sigmoid bounds both \"mean\" and \"stddev\" to (0, 1); confirm\n",
    "# that bounded means / log-variances are intended for the latent banks.\n",
    "net.add(cx.Layer(\"input\", 2),\n",
    "        cx.Layer(\"mean\", LENGTH, activation=\"sigmoid\"),\n",
    "        cx.Layer(\"stddev\", LENGTH, activation=\"sigmoid\"),\n",
    "        cx.LambdaLayer(\"encode\", LENGTH, sampler), # function, that takes input layer's output\n",
    "        cx.Layer(\"output\", 1, activation=\"sigmoid\"));\n",
    "#net.additional_output_banks = [\"encode\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.connect(\"input\", \"mean\")\n",
    "net.connect(\"input\", \"stddev\")\n",
    "net.connect(\"mean\", \"encode\")\n",
    "net.connect(\"stddev\", \"encode\")\n",
    "net.connect(\"encode\", \"output\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.build_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.add_loss(\"encode\", vae_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.compile(loss=\"mse\", optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To allow an additional error function, we need to declare \"encode\" (an internal bank) as an output:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then we can provide a dictionary of error functions by name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "\n",
       "require(['base/js/namespace'], function(Jupyter) {\n",
       " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n",
       " comm.on_msg(function(msg) {\n",
       " console.log(\"received!\")\n",
       " console.log(msg)\n",
       " var data = msg[\"content\"][\"data\"];\n",
       " var images = document.getElementsByClassName(data[\"class\"]);\n",
       " for (var i = 0; i < images.length; i++) {\n",
       " if (data[\"xlink:href\"]) {\n",
       " var xlinkns=\"http://www.w3.org/1999/xlink\";\n",
       " images[i].setAttributeNS(xlinkns, \"href\", data[\"xlink:href\"]);\n",
       " }\n",
       " if (data[\"src\"]) {\n",
       " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n",
       " }\n",
       " }\n",
       " });\n",
       " });\n",
       "});\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       " \n",
       " \n",
       " \n",
       " \n",
       " \n",
       " \n",
       " Layer: output (output)\n",
       " output range: (0, 1)\n",
       " shape = (1,)\n",
       " Keras class = Dense\n",
       " activation = sigmoidoutputWeights from encode to output\n",
       " output_2/kernel:0 has shape (5, 1)\n",
       " output_2/bias:0 has shape (1,)Layer: encode (hidden)\n",
       " output range: (-Infinity, +Infinity)\n",
       " shape = (5,)\n",
       " Keras class = Lambda\n",
       " function = <function sampler at 0x7f30c800bb70>encodeWeights from mean to encodeLayer: mean (hidden)\n",
       " output range: (0, 1)\n",
       " shape = (5,)\n",
       " Keras class = Dense\n",
       " activation = sigmoidmeanWeights from stddev to encodeLayer: stddev (hidden)\n",
       " output range: (0, 1)\n",
       " shape = (5,)\n",
       " Keras class = Dense\n",
       " activation = sigmoidstddevWeights from input to mean\n",
       " mean_2/kernel:0 has shape (2, 5)\n",
       " mean_2/bias:0 has shape (5,)Weights from input to stddev\n",
       " stddev_2/kernel:0 has shape (2, 5)\n",
       " stddev_2/bias:0 has shape (5,)Layer: input (input)\n",
       " output range: (-1, 1)\n",
       " shape = (2,)\n",
       " Keras class = Inputinputvae"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.picture([1,-1], hspace=200, scale=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.dataset.load([\n",
    "    [[0, 0], [0]],\n",
    "    [[0, 1], [1]],\n",
    "    [[1, 0], [1]],\n",
    "    [[1, 1], [0]],\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================================\n",
      "Testing validation dataset with tolerance 0.1...\n",
      "Total count: 4\n",
      " correct: 0\n",
      " incorrect: 4\n",
      "Total percentage correct: 0.0\n"
     ]
    }
   ],
   "source": [
    "net.evaluate(show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "vae Dataset:\n",
      "Patterns Shape Range \n",
      "=================================================================\n",
      "inputs (2,) (0.0, 1.0) \n",
      "targets (1,) (0.0, 1.0) \n",
      "=================================================================\n",
      "Total patterns: 4\n",
      " Training patterns: 4\n",
      " Testing patterns: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "net.dataset.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
"text/plain": [ "array([[-0.5848038 , 0.86490923, -0.04118258, -0.6218531 , -2.0275912 ],\n", " [ 1.4013709 , -0.40050203, 2.3741612 , 0.33983442, 1.589536 ],\n", " [-1.9229012 , 3.6163 , 1.7998898 , 0.44848615, 1.546372 ],\n", " [ 2.0491967 , -2.4284263 , -0.8186321 , 1.8166237 , 0.07840455]],\n", " dtype=float32)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.propagate_to(\"encode\", net.dataset.inputs, sequence=True)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.7198903 , 0.86922455, 2.8350933 , 3.4182205 , 1.9641373 ],\n", " [ 0.5212698 , 0.49677068, 2.7393699 , 0.24070835, 1.8382547 ],\n", " [ 1.5064052 , -0.4074353 , 0.6014143 , 0.78827447, -2.2850318 ],\n", " [-1.8501136 , 0.9037933 , 0.26733384, -0.2333259 , 1.3512714 ]],\n", " dtype=float32)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.propagate_to(\"encode\", net.dataset.inputs, sequence=True)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.9892443418502808]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.propagate(net.dataset.inputs[0])" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.9027333 ],\n", " [0.8923163 ],\n", " [0.36847237],\n", " [0.4038974 ]], dtype=float32)" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.propagate(net.dataset._inputs[0], sequence=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`CAPACITY` is used to break the input down into a set number of basis dimensions.\n", "\n", "`BETA` (> 1) is the weight applied to the latent regularizer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#net.reset()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "========================================================\n", " | Training | Training \n", "Epochs | Error | Accuracy \n", "------ | --------- | --------- \n", "#45000 | 3.95330 | 0.00000 \n" ] } ], "source": [ "net.train(epochs=30000, report_rate=100)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.plot_activation_map(to_layer=\"encode\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }