{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Energy auto-encoder: comparison with Xavier's primal-dual MATLAB implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.linalg as la\n",
    "import scipy.io\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from pyunlocbox import functions, solvers\n",
    "import time\n",
    "import scipy, matplotlib, pyunlocbox  # For versions only.\n",
    "\n",
    "# Import auto-encoder definition.\n",
    "# NOTE(review): auto_encoder, objective, sparse_codes, dictionary and atoms are\n",
    "# not defined in this notebook; presumably they all come from this %run.\n",
    "%run auto_encoder.ipynb\n",
    "\n",
    "print('Software versions:')\n",
    "for pkg in [np, matplotlib, scipy, pyunlocbox]:\n",
    "    print(' %s: %s' % (pkg.__name__, pkg.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyper-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The $\\lambda$ weights set the relative importance of each term in the composite objective function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Passed to auto_encoder(ld=...) and objective(..., ld) below; presumably the\n",
    "# weight lambda_d of the data term in the objective of the Algorithm section.\n",
    "ld = 10  # Xavier sets the weight of the L1 regularization to 1e-1."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The set of data vectors $X \\in R^{n \\times N}$ is given by patches extracted from a grayscale image.\n",
    "* There are as many patches as pixels in the image.\n",
    "* The saved patches already have zero mean."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mat = scipy.io.loadmat('data/xavier_X.mat')\n",
    "X = mat['X'].T  # Transposed so that each row is one sample.\n",
    "\n",
    "N, n = X.shape\n",
    "# Np is used as an array shape below (and passed to atoms() later), so it must\n",
    "# be an integer: NumPy rejects float dimensions in reshape.\n",
    "Np = int(round(np.sqrt(n)))\n",
    "assert Np * Np == n, 'n = %d is not a perfect square: patches are not square' % n\n",
    "print('N = %d samples with dimensionality n = %d (patches of %dx%d).' % (N, n, Np, Np))\n",
    "\n",
    "plt.figure(figsize=(8,5))\n",
    "patches = [24, 1000, 2004, 10782]  # Indices of the samples to display.\n",
    "for k, patch in enumerate(patches):\n",
    "    img = np.reshape(X[patch, :], (Np, Np))\n",
    "    plt.subplot(1, len(patches), k+1)\n",
    "    plt.imshow(img, cmap='gray')\n",
    "    plt.title('Patch %d' % patch)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initial conditions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* $Z$ is drawn from a uniform distribution in ]0,1[.\n",
    "* Same for $D$. Its columns were then normalized to unit L2 norm.\n",
    "* The sparse code dimensionality $m$ should be greater than $n$ for an overcomplete representation but much smaller than $N$ to avoid over-fitting.\n",
    "* The dictionary and sparse codes are initialized this way in the auto_encoder class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Disabled: optional inspection of the saved initial Z and D.\n",
    "# Beware that enabling it rebinds N and n (and defines m) from the loaded shapes.\n",
    "if False:\n",
    "    mat = scipy.io.loadmat('data/xavier_initZD.mat')\n",
    "    Zinit = mat['Zinit']\n",
    "    Dinit = mat['Dinit']\n",
    "\n",
    "    m, N = Zinit.shape\n",
    "    n, m = Dinit.shape\n",
    "    print('Sparse code dimensionality m = %d --> %s dictionary' % (m, 'overcomplete' if m > n else 'undercomplete'))\n",
    "\n",
    "    print('mean(Z) = %f' % np.mean(Zinit))\n",
    "\n",
    "    # L2 norm of each column of D; the small slack absorbs rounding errors.\n",
    "    d = np.sqrt(np.sum(Dinit*Dinit, axis=0))\n",
    "    print('Constraints on D: %s' % np.all(d <= 1+1e-15))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given $X \\in R^{n \\times N}$, solve $\\min\\limits_{Z \\in R^{m \\times N}, D \\in R^{n \\times m}} \\frac{\\lambda_d}{2} \\|X - DZ\\|_F^2 + \\|Z\\|_1$ s.t. $\\|d_i\\|_2 \\leq 1$, $i = 1, \\ldots, m$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ae = auto_encoder(m=64, ld=ld, rtol=1e-5, xtol=None, N_inner=100, N_outer=30)\n",
    "#ae = auto_encoder(m=64, ld=ld, rtol=None, xtol=1e-6, N_inner=100, N_outer=15)\n",
    "tstart = time.time()\n",
    "Z = ae.fit_transform(X)\n",
    "print('Elapsed time: {:.0f} seconds'.format(time.time() - tstart))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Disabled: recompute the sparse codes with ae.transform() and time it.\n",
    "if False:\n",
    "    tstart = time.time()\n",
    "    Z = ae.transform(X)\n",
    "    print('Elapsed time: {:.0f} seconds'.format(time.time() - tstart))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ae.plot_convergence()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solution analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solution from Xavier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mat = scipy.io.loadmat('data/xavier_ZD.mat')\n",
    "Zxavier = mat['Z'].T\n",
    "Dxavier = mat['D'].T\n",
    "# loadmat returns MATLAB scalars as 2-D arrays by default: extract the single\n",
    "# value explicitly instead of relying on implicit array-to-scalar conversion.\n",
    "print('Elapsed time: %d seconds' % mat['exectime'].item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Objective"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "objective(X, Zxavier, Dxavier, ld)\n",
    "objective(X, Z, ae.D, ld)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sparse codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "sparse_codes(Zxavier)\n",
    "sparse_codes(Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dictionary(Dxavier)\n",
    "dictionary(ae.D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "atoms(Dxavier, Np)\n",
    "atoms(ae.D, Np)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}