{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# Sparse energy auto-encoders" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* The definition of the algorithm behind our sparse energy auto-encoder model.\n", "* It is an unsupervised feature extraction tool which tries to find a good sparse representation in an efficient manner.\n", "* This notebook is meant to be imported by other notebooks for applications to image or audio data.\n", "* Modeled after the sklearn Estimator class so that it can be integrated into an sklearn Pipeline. Note that matrices are transposed in the code with respect to the math (samples are rows, not columns) to follow sklearn conventions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "General problem:\n", "* given $X \\in R^{n \\times N}$ and the graph Laplacian $L \\in R^{N \\times N}$,\n", "* solve $\\min\\limits_{Z \\in R^{m \\times N}, D \\in R^{n \\times m}, E \\in R^{m \\times n}} \\frac{\\lambda_d}{2} \\|X - DZ\\|_F^2 + \\frac{\\lambda_e}{2} \\|Z - EX\\|_F^2 + \\lambda_s \\|Z\\|_1 + \\frac{\\lambda_g}{2} \\text{tr}(ZLZ^T)$\n", "* s.t. 
$\\|d_i\\|_2 \\leq 1$, $i = 1, \\ldots, m$\n", "\n", "Observations:\n", "* About seven times faster (on comparison_xavier) using optimized linear algebra subroutines (BLAS):\n", " * None: 9916s\n", " * ATLAS: 1335s (memory-bandwidth limited)\n", " * OpenBLAS: 1371s (seems more CPU intensive than ATLAS)\n", "\n", "Open questions:\n", "* Optimize Z first (previous implementation) or D/E first (new implementation)?\n", " * Seems to converge much faster if Z is optimized last (see comparison_xavier).\n", " * But it is two times slower.\n", " * In fit we optimize for the parameters D and E, so it makes sense to optimize them last.\n", " * Z must be optimized first if it is initialized with zeros.\n", "* Fast evaluation of la.norm(Z.T.dot(Z)): compute it cumulatively to save memory?\n", "* Consider adding an option for $E = D^T$.\n", "* Use single precision, i.e. float32? Yes: it saves memory and speeds up computation thanks to the reduced memory bandwidth." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "import numpy.linalg as la\n", "from pyunlocbox import functions, solvers\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "class auto_encoder():\n", " \"\"\"Sparse energy auto-encoder.\"\"\"\n", " \n", " def __init__(self, m=100, ls=None, ld=None, le=None, lg=None,\n", " rtol=1e-3, xtol=None, N_inner=100, N_outer=15):\n", " \"\"\"\n", " Model hyper-parameters and solver stopping criteria.\n", " \n", " Model hyper-parameters (None disables the corresponding term):\n", " m: number of atoms in the dictionary, sparse code length\n", " ls: weight of the sparsity l1 penalty\n", " ld: weight of the dictionary l2 penalty\n", " le: weight of the encoder l2 penalty\n", " lg: weight of the graph smoothness\n", " \n", " Stopping criteria:\n", " rtol: objective function convergence\n", " xtol: model parameters convergence\n", " N_inner: hard limit of inner iterations\n", " N_outer: hard limit of outer iterations\n", " \"\"\"\n", " self.m = m\n", " self.ls = ls\n", " self.ld = ld\n", " self.le = le\n", " 
self.lg = lg\n", " self.N_outer = N_outer\n", " \n", " # Solver common parameters.\n", " self.params = {'rtol': rtol,\n", " 'xtol': xtol,\n", " 'maxit': N_inner,\n", " 'verbosity': 'NONE'}\n", "\n", " def _convex_functions(self, X, L, Z):\n", " \"\"\"Define convex functions.\"\"\"\n", " \n", " f = functions.proj_b2()\n", " self.f = functions.func()\n", " self.f._eval = lambda X: 0\n", " self.f._prox = lambda X,_: f._prox(X.T, 1).T\n", " #self.f._prox = lambda X,_: _normalize(X)\n", " \n", " if self.ld is not None:\n", " self.g_d = functions.norm_l2(lambda_=self.ld/2., A=Z, y=X, tight=False)\n", " self.g_z = functions.norm_l2(lambda_=self.ld/2., A=self.D.T, y=X.T, tight=False)\n", " else:\n", " self.g_z = functions.dummy()\n", "\n", " if self.le is not None:\n", " self.h_e = functions.norm_l2(lambda_=self.le/2., A=X, y=Z, tight=False)\n", " self.h_z = functions.norm_l2(lambda_=self.le/2., y=lambda: X.dot(self.E).T, tight=True)\n", " else:\n", " self.h_z = functions.dummy()\n", "\n", " if self.lg is not None:\n", " self.j_z = functions.func()\n", " # tr(A*B) = sum(A.*B^T).\n", " #self.j_z._eval = lambda Z: self.lg/2. * np.trace(Z.dot(L.dot(Z.T)))\n", " #self.j_z._eval = lambda Z: self.lg/2. * np.multiply(L.dot(Z.T), Z.T).sum()\n", " self.j_z._eval = lambda Z: self.lg/2. * np.einsum('ij,ji->', L.dot(Z.T), Z)\n", " self.j_z._grad = lambda Z: self.lg * L.dot(Z.T).T\n", " else:\n", " self.j_z = functions.dummy()\n", "\n", " self.ghj_z = functions.func()\n", " self.ghj_z._eval = lambda Z: self.j_z._eval(Z) + self.g_z._eval(Z) + self.h_z._eval(Z)\n", " self.ghj_z._grad = lambda Z: self.j_z._grad(Z) + self.g_z._grad(Z) + self.h_z._grad(Z)\n", " \n", " if self.ls is not None:\n", " self.i_z = functions.norm_l1(lambda_=self.ls)\n", " else:\n", " self.i_z = functions.dummy()\n", "\n", " def _minD(self, X, Z):\n", " \"\"\"Convex minimization for D.\"\"\"\n", " \n", " # Lipschitz continuous gradient. 
Faster if larger dim is 'inside'.\n", " B = self.ld * la.norm(Z.T.dot(Z))\n", " \n", " solver = solvers.forward_backward(step=1./B, method='FISTA')\n", " ret = solvers.solve([self.g_d, self.f], self.D, solver, **self.params)\n", " \n", " self.objective_d.extend(ret['objective'])\n", " self.objective_z.extend([[0,0]] * len(ret['objective']))\n", " self.objective_e.extend([[0,0]] * len(ret['objective']))\n", " \n", " def _minE(self, X, Z):\n", " \"\"\"Convex minimization for E.\"\"\"\n", " \n", " # Lipschitz continuous gradient. Faster if larger dim is 'inside'.\n", " B = self.le * la.norm(X.T.dot(X))\n", " \n", " solver = solvers.forward_backward(step=1./B, method='FISTA')\n", " ret = solvers.solve([self.h_e, self.f], self.E, solver, **self.params)\n", " \n", " self.objective_e.extend(ret['objective'])\n", " self.objective_z.extend([[0,0]] * len(ret['objective']))\n", " self.objective_d.extend([[0,0]] * len(ret['objective']))\n", " \n", " def _minZ(self, X, L, Z):\n", " \"\"\"Convex minimization for Z.\"\"\"\n", " \n", " # Lipschitz constant of the smooth part: sum of the bounds of each term.\n", " B_e = self.le if self.le is not None else 0\n", " B_d = self.ld * la.norm(self.D.T.dot(self.D)) if self.ld is not None else 0\n", " B_g = self.lg * np.sqrt((L.data**2).sum()) if self.lg is not None else 0\n", " B = B_d + B_e + B_g\n", " \n", " solver = solvers.forward_backward(step=1./B, method='FISTA')\n", " ret = solvers.solve([self.ghj_z, self.i_z], Z.T, solver, **self.params)\n", " \n", " self.objective_z.extend(ret['objective'])\n", " self.objective_d.extend([[0,0]] * len(ret['objective']))\n", " self.objective_e.extend([[0,0]] * len(ret['objective']))\n", " \n", " def fit_transform(self, X, L):\n", " \"\"\"\n", " Fit the model parameters (dictionary and encoder)\n", " given training data and its graph.\n", " \n", " Parameters\n", " ----------\n", " X : ndarray, shape (N, n)\n", " Training vectors, where N is the number of samples\n", " and n is the number of features.\n", " L : scipy.sparse, shape (N, N)\n", " The Laplacian matrix of the graph.\n", " 
\n", " Returns\n", " -------\n", " Z : ndarray, shape (N, m)\n", " Sparse codes (a by-product of training), where N\n", " is the number of samples and m is the number of atoms.\n", " \"\"\"\n", " N, n = X.shape\n", " \n", " def _normalize(X, axis=1):\n", " \"\"\"Normalize the selected axis of an ndarray to unit norm.\"\"\"\n", " return X / np.sqrt(np.sum(X**2, axis))[:,np.newaxis]\n", " \n", " # Model parameters initialization.\n", " if self.ld is not None:\n", " self.D = _normalize(np.random.uniform(size=(self.m, n)).astype(X.dtype))\n", " if self.le is not None:\n", " self.E = _normalize(np.random.uniform(size=(n, self.m)).astype(X.dtype))\n", " \n", " # Initial predictions.\n", " #Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)\n", " Z = np.zeros(shape=(N, self.m), dtype=X.dtype)\n", " \n", " # Initialize convex functions.\n", " self._convex_functions(X, L, Z)\n", " \n", " # Objective functions.\n", " self.objective = []\n", " self.objective_g = []\n", " self.objective_h = []\n", " self.objective_i = []\n", " self.objective_j = []\n", " self.objective_z = []\n", " self.objective_d = []\n", " self.objective_e = []\n", " \n", " # Stopping criteria.\n", " crit = None\n", " niter = 0\n", " last = np.nan\n", " \n", " # Multi-variate non-convex optimization (outer loop).\n", " while not crit:\n", " niter += 1\n", "\n", " self._minZ(X, L, Z)\n", "\n", " if self.ld is not None:\n", " self._minD(X, Z)\n", "\n", " if self.le is not None:\n", " self._minE(X, Z)\n", "\n", " # Global objectives.\n", " self.objective_g.append(self.g_z.eval(Z.T))\n", " self.objective_h.append(self.h_z.eval(Z.T))\n", " self.objective_i.append(self.i_z.eval(Z.T))\n", " self.objective_j.append(self.j_z.eval(Z.T))\n", " \n", " if self.params['rtol'] is not None:\n", " current = 0\n", " for func in ['g', 'h', 'i', 'j']:\n", " current += getattr(self, 'objective_'+func)[-1]\n", " relative = np.abs((current - last) / current)\n", " last = current\n", " if relative < self.params['rtol']:\n", " 
crit = 'RTOL'\n", "\n", " if self.N_outer is not None and niter >= self.N_outer:\n", " crit = 'MAXIT'\n", "\n", " return Z\n", " \n", " def fit(self, X, L):\n", " \"\"\"Fit to data without returning the transformed data.\"\"\"\n", " self.fit_transform(X, L)\n", " \n", " def transform(self, X, L):\n", " \"\"\"Predict sparse codes for each sample in X.\"\"\"\n", " return self._transform_exact(X, L)\n", " \n", " def _transform_exact(self, X, L):\n", " \"\"\"Most accurate but slowest prediction.\"\"\"\n", " N = X.shape[0]\n", " Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)\n", " self._convex_functions(X, L, Z)\n", " self._minZ(X, L, Z)\n", " return Z\n", " \n", " def _transform_approx(self, X, L):\n", " \"\"\"Much faster approximation using only the encoder.\"\"\"\n", " raise NotImplementedError('Not yet implemented')\n", " \n", " def inverse_transform(self, Z):\n", " \"\"\"\n", " Return the data corresponding to the given sparse codes using\n", " the learned dictionary.\n", " \"\"\"\n", " raise NotImplementedError('Not yet implemented')\n", " \n", " def plot_objective(self):\n", " \"\"\"Plot the objective (cost, loss, energy) functions.\"\"\"\n", " plt.figure(figsize=(8,5))\n", " plt.semilogy(np.asarray(self.objective_z)[:, 0], label='Z: data term')\n", " plt.semilogy(np.asarray(self.objective_z)[:, 1], label='Z: prior term')\n", " #plt.semilogy(np.sum(objective[:,0:2], axis=1), label='Z: sum')\n", " if self.ld is not None:\n", " plt.semilogy(np.asarray(self.objective_d)[:, 0], label='D: data term')\n", " if self.le is not None:\n", " plt.semilogy(np.asarray(self.objective_e)[:, 0], label='E: data term')\n", " iterations_inner = np.shape(self.objective_z)[0]\n", " plt.xlim(0, iterations_inner-1)\n", " plt.title('Sub-problems convergence')\n", " plt.xlabel('Iteration number (inner loops)')\n", " plt.ylabel('Objective function value')\n", " plt.grid(True); plt.legend(); plt.show()\n", " print('Inner loop: {} iterations'.format(iterations_inner))\n", "\n", " 
plt.figure(figsize=(8,5))\n", " def rdiff(a, b):\n", " \"\"\"Print the relative difference between two objective values.\"\"\"\n", " print('rdiff: {}'.format(abs(a - b) / a))\n", " if self.ld is not None:\n", " name = 'g(Z) = ||X-DZ||_F^2'\n", " plt.semilogy(self.objective_g, '.-', label=name)\n", " print(name + ' = {:e}'.format(self.objective_g[-1]))\n", " rdiff(self.objective_g[-1], self.g_d.eval(self.D))\n", " if self.le is not None:\n", " name = 'h(Z) = ||Z-EX||_F^2'\n", " plt.semilogy(self.objective_h, '.-', label=name)\n", " print(name + ' = {:e}'.format(self.objective_h[-1]))\n", " rdiff(self.objective_h[-1], self.h_e.eval(self.E))\n", " name = 'i(Z) = ||Z||_1'\n", " plt.semilogy(self.objective_i, '.-', label=name)\n", " print(name + ' = {:e}'.format(self.objective_i[-1]))\n", " if self.lg is not None:\n", " name = 'j(Z) = tr(ZLZ^T)'\n", " plt.semilogy(self.objective_j, '.-', label=name)\n", " print(name + ' = {:e}'.format(self.objective_j[-1]))\n", " iterations_outer = len(self.objective_i)\n", " plt.xlim(0, iterations_outer-1)\n", " plt.title('Objectives convergence')\n", " plt.xlabel('Iteration number (outer loop)')\n", " plt.ylabel('Objective function value')\n", " plt.grid(True); plt.legend(loc='best'); plt.show()\n", " \n", " plt.figure(figsize=(8,5))\n", " objective = np.zeros((iterations_outer))\n", " for obj in ['g', 'h', 'i', 'j']:\n", " objective += np.asarray(getattr(self, 'objective_' + obj))\n", " print('Global objective: {:e}'.format(objective[-1]))\n", " plt.plot(objective, '.-')\n", " plt.xlim(0, iterations_outer-1)\n", " plt.title('Global convergence')\n", " plt.xlabel('Iteration number (outer loop)')\n", " plt.ylabel('Objective function value')\n", " plt.grid(True); plt.show()\n", " print('Outer loop: {} iterations\\n'.format(iterations_outer))\n", " \n", " return (iterations_inner, iterations_outer,\n", " self.objective_g[-1], self.objective_h[-1],\n", " self.objective_i[-1], self.objective_j[-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tools for solution analysis" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "Tools to show model parameters, sparse codes and objective functions. The *auto_encoder* class solely contains the core algorithm (and a visualization of its convergence)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def sparse_codes(Z, tol=0):\n", " \"\"\"Show the sparsity of the sparse codes.\"\"\"\n", " N, m = Z.shape\n", " \n", " print('Z in [{}, {}]'.format(np.min(Z), np.max(Z)))\n", " \n", " if tol == 0:\n", " nnz = np.count_nonzero(Z)\n", " else:\n", " nnz = np.sum(np.abs(Z) > tol)\n", " sparsity = 100.*nnz/Z.size\n", " print('Sparsity of Z: {:,} non-zero entries out of {:,} entries, '\n", " 'i.e. {:.1f}%.'.format(nnz, Z.size, sparsity))\n", "\n", " try:\n", " plt.figure(figsize=(8,5))\n", " plt.spy(Z.T, precision=tol, aspect='auto')\n", " plt.xlabel('N = {} samples'.format(N))\n", " plt.ylabel('m = {} atoms'.format(m))\n", " plt.show()\n", " except MemoryError:\n", " pass\n", " \n", " return sparsity\n", " \n", "def dictenc(D, tol=1e-5, enc=False):\n", " \"\"\"Show the norms and sparsity of the learned dictionary or encoder.\"\"\"\n", " m, n = D.shape\n", " name = 'D' if not enc else 'E'\n", " \n", " print('{} in [{}, {}]'.format(name, np.min(D), np.max(D)))\n", " \n", " d = np.sqrt(np.sum(D**2, axis=1))\n", " print('{} in [{}, {}]'.format(name.lower(), np.min(d), np.max(d)))\n", " print('Constraints on {}: {}'.format(name, np.all(d <= 1+tol)))\n", " \n", " plt.figure(figsize=(8,5))\n", " plt.plot(d, 'b.')\n", " #plt.ylim(0.5, 1.5)\n", " plt.xlim(0, m-1)\n", " if not enc:\n", " plt.title('Dictionary atom norms')\n", " plt.xlabel('Atom [1,m={}]'.format(m))\n", " else:\n", " plt.title('Encoder column norms')\n", " plt.xlabel('Column [1,n={}]'.format(m))\n", " plt.ylabel('Norm [0,1]')\n", " plt.grid(True); plt.show()\n", "\n", " plt.figure(figsize=(8,5))\n", " plt.spy(D.T, precision=1e-2, aspect='auto')\n", " if not enc:\n", " 
plt.xlabel('m = {} atoms'.format(m))\n", " plt.ylabel('data dimensionality of n = {}'.format(n))\n", " else:\n", " plt.xlabel('n = {} columns'.format(m))\n", " plt.ylabel('sparse code length m = {}'.format(n))\n", " \n", " plt.show()\n", " \n", " #plt.scatter to show intensity\n", " \n", "def atoms(D, Np=None):\n", " \"\"\"\n", " Show dictionary or encoder atoms.\n", " \n", " Atoms are shown as Np x Np images if Np is not None, else as 1D signals.\n", " \"\"\"\n", " m, n = D.shape\n", " \n", " fig = plt.figure(figsize=(8,8))\n", " # Subplot grid: add_subplot expects integer sizes and a 1-based index.\n", " Nx = int(np.ceil(np.sqrt(m)))\n", " Ny = int(np.ceil(m / float(Nx)))\n", " for k in np.arange(m):\n", " ax = fig.add_subplot(Ny, Nx, k+1)\n", " if Np is not None:\n", " img = D[k,:].reshape(Np, Np)\n", " ax.imshow(img, cmap='gray') # vmin=0, vmax=1 to disable normalization.\n", " ax.axis('off')\n", " else:\n", " ax.plot(D[k,:])\n", " ax.set_xlim(0, n-1)\n", " ax.set_ylim(-1, 1)\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " return fig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unit tests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Test the auto-encoder class and tools." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if False:\n", " # ldd numpy/core/_dotblas.so\n", " try:\n", " import numpy.core._dotblas\n", " print('fast BLAS')\n", " except ImportError:\n", " print('slow BLAS')\n", "\n", " print(np.__version__)\n", " np.__config__.show()\n", "\n", "if False:\n", "#if __name__ == '__main__':\n", " import time\n", " import scipy.sparse\n", " \n", " # Data.\n", " N, n = 25, 16\n", " X = np.random.normal(size=(N, n))\n", " \n", " # Graph.\n", " W = np.random.uniform(size=(N, N)) # W in [0,1].\n", " W = np.maximum(W, W.T) # Symmetric weight matrix, i.e. 
undirected graph.\n", " D = np.diag(W.sum(axis=0)) # Diagonal degree matrix.\n", " L = D - W # Symmetric positive semi-definite Laplacian.\n", " L = scipy.sparse.csr_matrix(L)\n", "\n", " # Algorithm.\n", " auto_encoder(m=20, ls=1, le=1, rtol=1e-5, xtol=None).fit(X, L)\n", " auto_encoder(m=20, ld=1, rtol=1e-5, xtol=None).fit(X, L)\n", " auto_encoder(m=20, lg=1, rtol=None, xtol=None).fit(X, L)\n", " auto_encoder(m=20, lg=1, ld=1, rtol=1e-5, xtol=1e-5).fit(X, L)\n", " ae = auto_encoder(m=20, ls=5, ld=10, le=100, lg=1, rtol=1e-5, N_outer=20)\n", " tstart = time.time()\n", " Z = ae.fit_transform(X, L)\n", " print('Elapsed time: {:.3f} seconds'.format(time.time() - tstart))\n", " ret = ae.plot_objective()\n", " iterations_inner, iterations_outer = ret[:2]\n", " objective_g, objective_h, objective_i, objective_j = ret[2:]\n", " \n", " # Reproducible results (min_Z given X, L, D, E is convex).\n", " err = la.norm(Z - ae.transform(X, L)) / np.sqrt(Z.size) #< 1e-3\n", " print('Error: {}'.format(err))\n", "\n", " # Results visualization.\n", " sparse_codes(Z)\n", " dictenc(ae.D)\n", " dictenc(ae.E, enc=True)\n", " atoms(ae.D, 4) # 2D atoms.\n", " atoms(ae.D) # 1D atoms.\n", " atoms(ae.E)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.3" } }, "nbformat": 4, "nbformat_minor": 0 }