{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Nonnegative matrix factorization\n", "\n", "A derivative work by Judson Wilson, 6/2/2014. \n", "Adapted from the CVX example of the same name, by Argyris Zymnis, Joelle Skaf and Stephen Boyd\n", "\n", "## Introduction\n", "\n", "We are given a matrix $A \\in \\mathbf{\\mbox{R}}^{m \\times n}$ and are interested in solving the problem:\n", "$$\n", " \\begin{array}{ll}\n", " \\mbox{minimize} & \\| A - YX \\|_F \\\\\n", " \\mbox{subject to} & Y \\succeq 0 \\\\\n", " & X \\succeq 0,\n", " \\end{array}\n", "$$\n", "where $Y \\in \\mathbf{\\mbox{R}}^{m \\times k}$ and $X \\in \\mathbf{\\mbox{R}}^{k \\times n}$, and the inequalities are elementwise.\n", "\n", "This problem is not jointly convex in $(Y, X)$, but it is convex in $X$ for fixed $Y$ and convex in $Y$ for fixed $X$.\n", "This example therefore generates a random matrix $A$ and obtains an\n", "*approximate* solution to the above problem by first generating\n", "a random initial guess for $Y$ and then alternately minimizing\n", "over $X$ and $Y$ for a fixed number of iterations.\n", "\n", "## Generate problem data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import cvxpy as cp\n", "import numpy as np\n", "\n", "# Fix the random seed so the problem data are reproducible.\n", "np.random.seed(0)\n", "\n", "# Generate a random nonnegative data matrix A as the product of an\n", "# m x k and a k x n matrix, so A has rank at most k.\n", "m = 10\n", "n = 10\n", "k = 5\n", "A = np.random.rand(m, k).dot(np.random.rand(k, n))\n", "\n", "# Initialize Y randomly.\n", "Y_init = np.random.rand(m, k)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Perform alternating minimization" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1, residual norm 2.766300564135502\n", "Iteration 2, residual norm 0.5840356930600721\n", "Iteration 3, residual norm 0.3356679970549085\n", "Iteration 4, residual norm 0.18670276027770083\n", "Iteration 5, residual norm 0.12819921698143966\n", "Iteration 6, residual norm 0.09295501592922492\n", "Iteration 7, residual norm 0.06766021043574907\n", "Iteration 8, 
residual norm 0.04958204907945361\n", "Iteration 9, residual norm 0.03897402158866238\n", "Iteration 10, residual norm 0.02979328283505179\n", "Iteration 11, residual norm 0.022938564327729952\n", "Iteration 12, residual norm 0.021943924920767337\n", "Iteration 13, residual norm 0.01810297853945281\n", "Iteration 14, residual norm 0.014551161988556204\n", "Iteration 15, residual norm 0.014039687334395924\n", "Iteration 16, residual norm 0.009354606824469416\n", "Iteration 17, residual norm 0.008643141637584189\n", "Iteration 18, residual norm 0.007278100007476402\n", "Iteration 19, residual norm 0.008486679700021057\n", "Iteration 20, residual norm 0.008827511916396866\n", "Iteration 21, residual norm 0.008396764193205366\n", "Iteration 22, residual norm 0.005265185332845983\n", "Iteration 23, residual norm 0.006931929503816392\n", "Iteration 24, residual norm 0.007356156596477946\n", "Iteration 25, residual norm 0.0039053948996930054\n", "Iteration 26, residual norm 0.003989885269615319\n", "Iteration 27, residual norm 0.002920361405226024\n", "Iteration 28, residual norm 0.007779246694466739\n", "Iteration 29, residual norm 0.007339011292898449\n", "Iteration 30, residual norm 0.005008539285258121\n" ] } ], 
"source": [ "# Reuse the same initial Y (rather than generating a new one)\n", "# each time this cell is executed.\n", "Y = Y_init\n", "\n", "# Perform alternating minimization.\n", "MAX_ITERS = 30\n", "residual = np.zeros(MAX_ITERS)\n", "for iter_num in range(1, 1+MAX_ITERS):\n", "    # At the beginning of an iteration, X and Y are NumPy\n", "    # arrays, NOT CVXPY variables.\n", "\n", "    # For odd iterations, treat Y as constant and optimize over X.\n", "    if iter_num % 2 == 1:\n", "        X = cp.Variable(shape=(k, n))\n", "        constraint = [X >= 0]\n", "    # For even iterations, treat X as constant and optimize over Y.\n", "    else:\n", "        Y = cp.Variable(shape=(m, k))\n", "        constraint = [Y >= 0]\n", "\n", "    # Solve the problem. SCS's iteration limit is raised because, with the\n", "    # default limit, a few iterations end with status cp.OPTIMAL_INACCURATE\n", "    # (e.g., some entries of X or Y are negative beyond standard tolerances).\n", "    obj = cp.Minimize(cp.norm(A - Y @ X, 'fro'))\n", "    prob = cp.Problem(obj, constraint)\n", "    prob.solve(solver=cp.SCS, max_iters=10000)\n", "\n", "    if prob.status != cp.OPTIMAL:\n", "        raise RuntimeError('Solver did not converge (status: {})'.format(prob.status))\n", "\n", "    print('Iteration {}, residual norm {}'.format(iter_num, prob.value))\n", "    residual[iter_num-1] = prob.value\n", "\n", "    # Convert the variable to a NumPy array constant for the next iteration.\n", "    if iter_num % 2 == 1:\n", "        X = X.value\n", "    else:\n", "        Y = Y.value" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Output results" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: Residual Norm vs. Iteration Number]" ] }, "metadata": {}, "output_type": "display_data" }, 
{ "name": "stdout", "output_type": "stream", "text": [ "Original matrix:\n", "[[1.323426 1.11061189 1.69137835 1.20020115 1.13216889 0.5980743\n", " 1.64965406 0.340611 1.69871738 0.78278448]\n", " [1.73721109 1.40464204 1.90898877 1.60774132 1.53717253 0.62647405\n", " 1.76242265 0.41151492 1.8048194 1.20313124]\n", " [1.4071438 1.10269406 1.75323063 1.18928983 1.23428169 0.60364688\n", " 1.63792853 0.40855006 1.57257432 1.17227344]\n", " [1.3905141 1.33367163 1.07723947 1.67735654 1.33039096 0.42003169\n", " 1.22641711 0.21470465 1.47350799 0.84931787]\n", " [1.42153652 1.13598552 2.00816457 1.11463462 1.17914429 0.69942578\n", " 1.90353699 0.45664487 1.81023916 1.09668578]\n", " [1.60813803 1.23214532 1.73741086 1.3148874 1.27589039 0.40755835\n", " 1.31904948 0.3469129 1.34256526 0.76924618]\n", " [0.90607895 0.6632877 1.25412229 0.81696721 0.87218892 0.50032884\n", " 1.245879 0.25079329 1.25017792 0.72155621]\n", " [1.5691922 1.47359672 1.76518996 1.66268312 1.43746574 0.72486628\n", " 1.97409333 0.39239642 2.09234807 1.16325748]\n", " [1.18723548 1.00282008 1.41532595 1.03836298 0.90382914 0.38460446\n", " 1.213473 0.23641422 1.32784402 0.27179726]\n", " [0.75789915 0.75119989 0.99502166 0.65444815 0.56073096 0.341146\n", " 1.02555143 0.24273668 1.01035919 0.49427978]]\n", "Left factor Y:\n", "[[ 7.56475742e-01 3.42102372e-01 8.40426641e-01 7.02845111e-01\n", " 4.38002833e-03]\n", " [ 6.36189366e-01 8.27831861e-01 5.28165827e-01 5.60609403e-01\n", " 3.34595403e-02]\n", " [ 5.54834858e-01 6.37954560e-01 8.01726231e-01 1.96879041e-01\n", " 3.74736667e-02]\n", " [ 2.72955779e-01 9.53749151e-01 6.14934798e-02 9.81276972e-01\n", " -4.26647247e-05]\n", " [ 7.93952558e-01 3.50946872e-01 1.18853643e+00 3.85961318e-01\n", " 2.96701863e-02]\n", " [ 7.26183347e-01 4.41639937e-01 2.71711699e-03 7.33393633e-01\n", " 4.55176129e-02]\n", " [ 4.89263105e-01 4.20725095e-01 7.56036398e-01 
6.24033457e-02\n", " -5.38302416e-04]\n", " [ 6.09810836e-01 7.55780427e-01 1.03636918e+00 9.08549910e-01\n", " 1.91844947e-03]\n", " [ 8.31578328e-01 8.75528332e-05 2.93543168e-01 1.10037225e+00\n", " -2.65884776e-04]\n", " [ 4.26650967e-01 5.53761974e-02 6.52855369e-01 6.43132832e-01\n", " 1.47569255e-02]]\n", "Right factor X:\n", "[[ 1.07015116e+00 4.25961964e-01 1.59511553e+00 6.26808607e-01\n", " 8.98124301e-01 3.62801718e-01 9.53757673e-01 1.88661317e-01\n", " 9.64559055e-01 1.43675625e-01]\n", " [ 8.72908811e-01 7.03553498e-01 6.45229205e-01 1.10121868e+00\n", " 9.93621271e-01 3.12383803e-01 7.45085312e-01 1.25155585e-01\n", " 8.84272390e-01 7.94988511e-01]\n", " [ 1.41086863e-04 1.70049131e-01 2.73427259e-01 2.50933223e-02\n", " 8.38007474e-03 2.51575697e-01 5.99473425e-01 1.39362252e-01\n", " 5.06840502e-01 4.22844259e-01]\n", " [ 2.70906925e-01 5.46340550e-01 1.04256418e-02 4.63290841e-01\n", " 1.39889787e-01 7.65220031e-03 2.22742919e-01 3.60875098e-02\n", " 3.41601146e-01 2.72448408e-02]\n", " [ 5.44108256e+00 4.62667224e+00 6.26354249e+00 7.23656013e-01\n", " 1.81220987e+00 -2.57729003e-07 2.90739234e+00 2.81123997e+00\n", " -2.15606388e-06 6.43189790e+00]]\n", "Residual A - Y * X:\n", "[[ 9.02157264e-04 5.23117764e-04 -5.79950842e-04 -5.74317402e-04\n", " -4.61768644e-04 -5.28680186e-05 1.62394448e-04 2.76277321e-04\n", " 4.85227596e-04 -5.60481823e-04]\n", " [-2.33027425e-04 3.21455250e-04 2.17040399e-04 1.56606195e-04\n", " -2.41256203e-04 -1.01386736e-04 7.36342995e-05 -1.73587325e-05\n", " -5.22429324e-05 -2.04432888e-04]\n", " [-8.35846517e-04 2.46121871e-04 5.93720663e-04 5.38806481e-04\n", " -8.42363429e-05 -1.36215640e-04 2.31633730e-06 -1.52108618e-04\n", " -3.23620331e-04 -5.42078084e-06]\n", " [ 2.62860853e-04 1.83780003e-05 -3.20542830e-04 -1.49712163e-04\n", " -1.31334078e-04 8.78805144e-05 1.46798183e-04 -2.03546983e-05\n", " 4.79256197e-04 -5.81320754e-04]\n", " [-6.22557723e-04 6.31892711e-04 4.34719938e-04 4.01388769e-04\n", " 
-3.52745774e-04 -2.12014739e-04 8.42548761e-05 -4.17321003e-05\n", " -1.50760383e-04 -3.01455643e-04]\n", " [-8.46202248e-04 3.61714835e-04 6.15005890e-04 5.85452470e-04\n", " -2.39872783e-04 -1.59000367e-04 6.24749082e-05 -1.69461803e-04\n", " -3.16622183e-04 -8.20910778e-05]\n", " [ 1.15561552e-03 -1.28864368e-03 -1.77288000e-03 -5.10264071e-04\n", " 6.38713553e-04 7.17730381e-04 2.05892579e-04 -2.69449092e-04\n", " 1.71225020e-03 -1.13410340e-03]\n", " [ 1.57913703e-04 6.21168134e-04 -4.04695033e-05 -1.48187018e-04\n", " -4.38037868e-04 -1.45409129e-04 1.34145488e-04 1.47289692e-04\n", " 1.98184939e-04 -5.09549810e-04]\n", " [ 5.51365483e-04 -1.32683206e-03 -1.26345269e-03 6.01647636e-05\n", " 9.72529426e-04 6.10472383e-04 -1.48674297e-05 -3.54468161e-04\n", " 9.92202367e-04 -1.42249517e-04]\n", " [-1.63514531e-03 -1.59800828e-04 1.08957766e-03 1.01954949e-03\n", " 3.41048252e-04 -1.06257705e-04 -1.57094132e-04 -3.64204427e-04\n", " -7.26930797e-04 4.63755883e-04]]\n", "Residual after 30 iterations: 0.005008539285258121\n" ] } ], 
"source": [ "#\n", "# Plot residuals.\n", "#\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "# Show plots inline in the notebook.\n", "%matplotlib inline\n", "\n", "# Set plot properties (usetex=True requires a working LaTeX installation).\n", "plt.rc('text', usetex=True)\n", "plt.rc('font', family='serif')\n", "font = {'weight' : 'normal',\n", "        'size' : 16}\n", "plt.rc('font', **font)\n", "\n", "# Plot the residual against the iteration number (1, ..., MAX_ITERS).\n", "plt.plot(np.arange(1, MAX_ITERS + 1), residual)\n", "plt.xlabel('Iteration Number')\n", "plt.ylabel('Residual Norm')\n", "plt.show()\n", "\n", "#\n", "# Print results.\n", "#\n", "print('Original matrix:')\n", "print(A)\n", "print('Left factor Y:')\n", "print(Y)\n", "print('Right factor X:')\n", "print(X)\n", "print('Residual A - Y * X:')\n", "print(A - Y.dot(X))\n", "print('Residual after {} iterations: {}'.format(iter_num, prob.value))\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 1 }