{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "![MOSEK ApS](https://www.mosek.com/static/images/branding/webgraphmoseklogocolor.png )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Wasserstein Barycenters using Mosek" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Wasserstein distance is a way to measure the distance between two probability distributions. It can be used to summarize, compare, match and reduce the dimensionality of empirical probability measures, which are fundamental tasks in many machine learning applications." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Wasserstein distance of order $p$ between two probability measures $\\mu$ and $\\upsilon$ in $P(\\Omega)$ is defined as:\n", "
\n", "
\n", "$$W_p(\\mu,\\upsilon) \\overset{\\mathrm{def}}{=} \\bigg( \\inf_{\\pi \\in \\Pi(\\mu, \\upsilon)} \\int_{\\Omega^2} D(x,y)^p \\, d\\pi(x,y)\\bigg)^{1/p}\n", "$$\n", "
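\n", "where $\\Pi(\\mu, \\upsilon)$ is the set of all probability measures on $\\Omega^2$ whose marginals are $\\mu$ and $\\upsilon$, and $D$ is a distance function on $\\Omega$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the distributions are discrete, i.e. $\\mu = \\sum_{i=1}^{n} \\mu_i \\delta_{X_i}$ and $\\upsilon = \\sum_{j=1}^{m} \\upsilon_j \\delta_{Y_j}$, then $W_p^p(\\mu,\\upsilon)$ equals the optimal objective value of the following LP model:\n", "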
\n", "
\n", "$$ \\mbox{minimize} \\quad \\sum_{i=1}^{n}\\sum_{j=1}^{m} D(X_i,Y_j)^p\\pi_{ij}$$\n", "
\n", "$$ \\mbox{s.t.} \\quad \\sum_{j=1}^{m} \\pi_{ij} = \\mu_i , \\quad i = 1,\\ldots,n $$\n", "
\n", "$$ \\quad \\sum_{i=1}^{n} \\pi_{ij} = \\upsilon_j, \\quad j = 1,\\ldots,m $$\n", "
\n", "$$ \\pi_{ij} \\geq 0, \\quad \\forall\\, i,j$$\n", "
\n", "where $D(X_i,Y_j)$ is the distance between the support points $X_i$ and $Y_j$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are more efficient ways to approximate this metric, but the exact LP formulation is used here in order to compare the performance and modeling structure of Fusion, Pyomo and CVXPY." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wasserstein Barycenter" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Wasserstein barycenter problem involves the Wasserstein distances from one measure to many measures. Given measures $\\upsilon_1,\\ldots,\\upsilon_N$, we want to find the measure $\\mu$ that minimizes the weighted sum of the (powered) distances to the given measures, that is, to solve the problem:\n", "$$ \\mbox{minimize}_\\mu \\sum_{i=1}^{N} \\lambda_i W_p^p(\\mu,\\upsilon_i) $$\n", "for some fixed weights $\\lambda_i \\geq 0$ that satisfy $$\\sum_{i=1}^{N}\\lambda_i = 1.$$\n", "For simplicity, uniform weights $\\lambda_i = 1/N$ are used here. The barycenter problem then becomes:\n", "$$ \\mbox{minimize}_\\mu \\frac{1}{N} \\sum_{i=1}^{N} W_p^p(\\mu,\\upsilon_i). $$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, the Wasserstein barycenter of the digit one is computed and visualized using $20$ handwritten '1' digits of size $28 \\times 28$ pixels from the MNIST database http://yann.lecun.com/exdb/mnist/. Computations are carried out on an Intel(R) Xeon(R) CPU E5-2687W v4 @ 3.00GHz processor. Similar experiments were carried out by Cuturi and Doucet in http://proceedings.mlr.press/v32/cuturi14.pdf." 
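] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As a small worked illustration of the barycenter problem (a toy example on the real line chosen for this purpose, not part of the MNIST experiment), take $p = 2$, $N = 2$, uniform weights, squared distances $W_2^2$, and the two Dirac measures $\\upsilon_1 = \\delta_0$ and $\\upsilon_2 = \\delta_2$. Since all the mass of $\\upsilon_i$ sits at a single point $y$, the only feasible transport plan moves every point $x$ of $\\mu$ to $y$, so $W_2^2(\\mu,\\delta_y) = \\int (x-y)^2 \\, d\\mu(x)$ and\n",
"\n",
"$$ \\frac{1}{2}\\Big( W_2^2(\\mu,\\delta_0) + W_2^2(\\mu,\\delta_2) \\Big) = \\int \\frac{x^2 + (x-2)^2}{2} \\, d\\mu(x) = \\int \\big( (x-1)^2 + 1 \\big) \\, d\\mu(x) \\geq 1, $$\n",
"\n",
"with equality only for $\\mu = \\delta_1$. The barycenter therefore moves all the mass to the midpoint, whereas the Euclidean average of the two histograms, the mixture $\\frac{1}{2}(\\delta_0 + \\delta_2)$, gives the larger value $\\frac{1}{2}(2 + 2) = 2$. In other words, the Wasserstein barycenter averages the locations of the mass rather than mixing the measures."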
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABIEAAAILCAYAAAB/zYxFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xu0nVV5L/5nJoQQrnIRiCGA3MQ7SExBWkul9ofUFmyPVHpq1dITa8VLtT1V2zGk9niK91MvRbHwC7ZWy69oRUqPUg7VU7lIQlNuEUEaJJKCgFyFkGTP3x9ZdgSy5srKus5kfj5jZOy932e/73z2Gvlm7/3kXWumnHMAAAAAsH2bNe0GAAAAABg/QyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0IAdhjk5pXRSRPx5RMyOiL/MOZ/d6/N3THPzTrHLMEvCNuvxeDSeyGvTJNaSTeifbEKdZBPqJJtQp36zmXLOAy2QUpodEd+NiJdHxOqIuDYiTs8531w6Z/e0V/6pdOJA68G27pp8eTyU7x/7N0zZhK0jm1An2YQ6ySbUqd9sDvN0sMURcVvO+fac8xMR8cWIOGWI6wGjIZtQJ9mEOskm1Ek2YQyGGQItiIg7N/l4defYk6SUlqSUlqWUlq2LtUMsB/RJNqFOsgl1kk2ok2zCGAwzBOp2m9Fmzy3LOZ+bc16Uc140J+YOsRzQJ9mEOskm1Ek2oU6yCWMwzBBodUQs3OTjAyLiruHaAUZANqFOsgl1kk2ok2zCGAwzBLo2Ig5PKT0zpbRjRLwmIi4eTVvAEGQT6iSbUCfZhDrJJozBwFvE55zXp5TOjIivxcYt+87POd80ss6Agcgm1Ek2oU6yCXWSTRiPgYdAERE550sj4tIR9QKMiGxCnWQT6iSbUCfZhNEb5ulgAAAAAGwjDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABqww7QboA7zvrFfsXbAzg8Ua7eftEuxtuG++4fqCQAApu22jx5brH331/6iWFv0/jOLtX3/4sqhegIYlDuBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAF2B2vI2le8uFi75LBPF2szMVOsHX/qW4u1vc+7qr/GoHGzn/usYu20i64o1l6/+z3F2o9nnijWXvE75d1Kdrrk28Ua0J8dDjm4WLv4/36pWJudyv83t/jdbyrW9rzA91sY1g77l3fK/fypnyrWyj8l2wEMRmHNO19SrC1/xyeKtRPfXP6+Oe/vB/t597FTF4/8mtMw1BAopbQqIh6OiA0RsT7nvGgUTQHDkU2ok2xCnWQT6iSbMHqjuBPo53LO947gOsBoySbUSTahTrIJdZJNGCGvCQQAAADQgGGHQDkivp5SWp5SWjKKhoCRkE2ok2xCnWQT6iSbMGLDPh3s+JzzXSmlfSPispTSd3LO39z0EzphXRIRsVPsPORyQJ9kE+okm1An2YQ6ySaM2FB3AuWc7+q8vScivhwRm71cds753Jzzopzzojkxd5jlgD7JJtRJNqFOsgl1kk0YvYHvBEop7RIRs3LOD3fe/4WIeN/IOmPkvv+Lo38JqEcXpGJt75GvRj9kc9vz4HP3LNb+625rirV1uXzNOWl2sbb658v/Fhx2SfmaDEc2G7JhQ7G0ZsOPi7X9Zs8r1o59y7Ji
7da/3alYm3n88WKNjWSTiIgHlu5SrB3TY65w74bHxtANEbLZkjved1yx9rnX/nmxNtPjmvccXR51HPT3/XS1uZ3vfLRY6/FjeXWGeTrYfhHx5ZTST67zNznn/z2SroBhyCbUSTahTrIJdZJNGIOBh0A559sj4oUj7AUYAdmEOskm1Ek2oU6yCeNhi3gAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0YZncwGvGFhxcUa4d89vZibf04mgGAyuXH1xZrX3v0sGLtN3f/QbH2kflXF2s///NvKtZ2uuTbxRq0ZoeDFhZr5x75+R5n7lisHH/FW4u1w+O6ftqC5u1+zL3F2gvL8evp0AvuKtYG/T01L79pwDPr4k4gAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADbBG/HZr56aO6Hv+Lk5YOdL0/ueqXirUj1iwf6JoAsL1ad+SCYu03d//Hga65ev1jxdqcRwbd7Ba2PzvM379YO/JLPyjWjpgz2D7UB144e6DzoDUPvPa4Yu3vnv+hHmfOHWi99bevGui8FrgTCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAFvEb4e+99vdt6r8uXmP9DjLPBCm5aHTHxr5NWdipljb8QF5h2HN2m23Ym3JZy8a+Xqn/ut/K9b2/+frRr4ebKvWHbxfsfbB/S8t1man8lbvh/7t7xRrh/3D1f01Bo275gPnFGvr8ryBrnnyd04t1mbFnQNdswV+EwAAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAW8QBT9l8OWTHyaz4480SxduBZV458PWjN99/y/GLtl3f554Guedu6tcXaPp/YeaBrQmv+/VXlrMxELtZueuLxYu3Ij64u1tb31xY0YfZhzyzW1uXlxdpMzAy03pw37VisbRjoim3Y4p1AKaXzU0r3pJRu3OTYXimly1JKt3be7jneNoGnkk2ok2xCnWQT6iSbMFn9PB1saUSc9JRj74qIy3POh0fE5Z2PgclaGrIJNVoasgk1WhqyCTVaGrIJE7PFIVDO+ZsRcf9TDp8SERd03r8gIk4dcV/AFsgm1Ek2oU6yCXWSTZisQV8Yer+c85qIiM7bfUufmFJaklJallJati7Kz3UHRkI2oU6yCXWSTaiTbMKYjH13sJzzuTnnRTnnRXNi7riXA/okm1An2YQ6ySbUSTZh6ww6BLo7pTQ/IqLz9p7RtQQMQTahTrIJdZJNqJNswpgMukX8xRHxuog4u/P2KyPriL7ssP9+xdrvvfifRrrWgV+aPdLrMVayCXWSzW3R4vI28Of+9idHvtxrVpxRrO1/eXlrXYYim9ugdb+wqFj71ukf7nHmTsXKr1z09mLt0Duv7qctRks2t0G3nPW0kV/z2Ze/sVh71vdvHvl6Lehni/gvRMRVEfGslNLqlNIZsTGML08p3RoRL+98DEyQbEKdZBPqJJtQJ9mEydrinUA559MLpRNH3AuwFWQT6iSbUCfZhDrJJkzW2F8YGgAAAIDpMwQCAAAAaIAhEAAAAEADDIEAAAAAGjDoFvFMWd5tl2JtydNuG+lau1xZvt6Gka4E27cfvum4rsfftc/He5xlVg/Tcvuv7lqsHTu3fN5Mj2vetm5tsbbPJ3buoytg1S/NLtb2njWvWJuJXKw963+tLtbW99cWNOOB13b/mfa6Ez7a46wdi5UrHit/vz3y7EeKtQ2PP95jPUr8dgEAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAt4gEm5A1nXtr1+CzzeKjSea8+Z+TX/NXP/n6xtvDyK0e+HmzL1rzzJV2Pf/vUD/U4q7xF/BFffVOx9qz/WNFvW9C8x/ZNXY/vnMrbwM+K7udERPzxB36rWNv75qv6b4y++M0DAAAAoAGGQAAAAAANMAQC
AAAAaIAhEAAAAEADDIEAAAAAGmB3sO3QIDsNzUmzx9AJsKnD5/7HxNY69uJ3lPuIaybWB2wLfvOWO7seP27uhh5nlXc5OeGGVxdrB354ebGWe6wG26s0p7yb0CG//L2ux/ecVd4B7Ftryz8HP/sj9xZrG9Y9UaxBi2bvvVex9l9ff1nX4zMxUzznwZlyxubdVz6P0XMnEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAbaI3w712pqvZJ19aWEk7v+t44q1F839VqGy00BrLV9brh358fuKtV6bXsP2avW7X1KsvWbXTxYq5W3g12z4cbG260m3F2u+3cKTPf7zLyzWvnrYp7se7/WT7nu/d0qxtuOt5WwCT7byQ4cUa1/e62tbfb0X/8PvFWtHfOmarb4eg9vinUAppfNTSveklG7c5NhZKaUfpJRWdP6cPN42gaeSTaiTbEKdZBPqJJswWf08HWxpRJzU5fjHcs5Hdf5cOtq2gD4sDdmEGi0N2YQaLQ3ZhBotDdmEidniECjn/M2IuH8CvQBbQTahTrIJdZJNqJNswmQN88LQZ6aUru/cvrdn6ZNSSktSSstSSsvWRY8XsABGRTahTrIJdZJNqJNswhgMOgQ6JyIOjYijImJNRHyk9Ik553NzzotyzovmxNwBlwP6JJtQJ9mEOskm1Ek2YUwGGgLlnO/OOW/IOc9ExGcjYvFo2wIGIZtQJ9mEOskm1Ek2YXwG2iI+pTQ/57ym8+GrIuLGXp/P6N133H4jvd4RX19Srv1oxUjXYnxkc/pe+ubyFpd7zhpsK/iS07/2pmLtiFu+PdK1GI5sTsaGE15UrF34xuJ/IsdM7LjVa5207I3F2oK4aauvx3TI5vTt9Ad3bfU59254rFhLH9inx5l3bPVaTIdsTt/T93twpNd79h/dVqxtGOlKbMkWh0AppS9ExAkRsU9KaXVEvDciTkgpHRUROSJWRUT5JyFgLGQT6iSbUCfZhDrJJkzWFodAOefTuxw+bwy9AFtBNqFOsgl1kk2ok2zCZA2zOxgAAAAA2whDIAAAAIAGGAIBAAAANMAQCAAAAKABA20Rz/Q9cspDI71eerTHX4UZm/bBph78jWOLtT/d7+M9zpy91WvNxEyxtteKrb8ebM/WvGVtsXbEnK3fBv7k75xarC34FdvAQ792OGhhsfbBQy7scWb33B5/xVuLZxz+T8v7bQuaN/tZhxVr3zrqiz3O7H4vyRvuOLF4xob77u+3LcbMnUAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAbYIp6IiNhlla2moV/3vLhcm5NGm6X/ee9Rxdo+n7lqpGvBtmDDCS8q1i485lM9ztz6LeJXf6O8rfWBcedWXw9a9eBn5hRrR8zZ+mweeKGfW2EUbvmdfYq1mZgp1v51bfd7SVa///DiOXPj2v4bY6zcCQQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAt4is2a7fdirUDnvbgSNd6xoevHOn1gNH42od+pljbI66eYCdQhzVvWVusDbLVdETE5Y/t3PX4wR+7oXhOeeNcaFM+7oXF2jee///2ODMVKz97w3/penyXf7DVNIzCd077VLHW6/vcW24+vevxvWRzm+BOIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAA2wRX7GZZx9crF185PkjXWuHgxYWa+vvuHOkawFP9jt3/myxtudXbirWbFFNi15ywKqRX/NDb/yNrsd3eHj5yNeC7dUDz9q5WJuJPNA1H79ov67Hd4nbB7oeMBoPXr931+N7TbgPBrPFO4FSSgtTSleklFamlG5KKb2tc3yvlNJlKaVbO2/3HH+7wE/IJtRJNqFOsgl1kk2YrH6eDrY+It6Zc352RBwbEW9OKT0nIt4VEZfnnA+PiMs7
HwOTI5tQJ9mEOskm1Ek2YYK2OATKOa/JOV/Xef/hiFgZEQsi4pSIuKDzaRdExKnjahLYnGxCnWQT6iSbUCfZhMnaqheGTikdHBFHR8Q1EbFfznlNxMbgRsS+hXOWpJSWpZSWrYu1w3ULdCWbUCfZhDrJJtRJNmH8+h4CpZR2jYiLIuLtOeeH+j0v53xuznlRznnRnJg7SI9AD7IJdZJNqJNsQp1kEyajryFQSmlObAzk53POX+ocvjulNL9Tnx8R94ynRaBENqFOsgl1kk2ok2zC5Gxxi/iUUoqI8yJiZc75o5uULo6I10XE2Z23XxlLhw2b9cSGYu2u9eVbHZ+xw9ZPwF96yXeKtc99d3GxNuuaPcp9fPDKre6D/snm9uO+tbsUazMP3z3BThgF2Rzehp97UbH2nv0/3uPMeQOtt9Pt93Y9vn6gq1Er2RzerN12K9Yu+dMP9zhzp2JlzYbHirW9b3i0n7bYxsnmeN3xvuOKtVlxXbH25Uf3KdYO/8s1XY/7vrlt2OIQKCKOj4jXRsQNKaUVnWPviY1hvDCldEZEfD8iXj2eFoEC2YQ6ySbUSTahTrIJE7TFIVDO+V8iIhXKJ462HaBfsgl1kk2ok2xCnWQTJmurdgcDAAAAYNtkCAQAAADQAEMgAAAAgAYYAgEAAAA0oJ/dwZiSmRU3F2uvft8fFGuXntV9i849Zu1YPOcde5W3iH/D4hXF2nGPnVmsAf1518JLi7V3vObNxdpuX7x6HO3A1N11fHk76QN2GGwb+J/5t18r1p62+vsDXRNac8dbn1+s7T3rG8XaTORi7c71O5cXvPr6vvoCesil19zunc0/+tKvF2uH3H7VUC0xXe4EAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBbx26i9zytvy3fsz7y16/GVL//0QGv91D91v15ExBFvWD7QNWFbdsjfP1GsffeUcu2IOTt2PX7M3PJaL3/3/y3Wrv7inPKJsA075Of/feTX/NHD5W2o91hXzi0wXn/8piXF2o6xbIKdwPZpj9vK28D3PO97I26EargTCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgd7Dt0OGv775j1y/Hiwe63hFhBzDY1OwrrivW3vpbZxZrnzz/E12PHzanvD3YX/3zzxRrh8fVxRpsyx74xIHl4sfLpU8/cEixdti7HizW1vfTFBAL339lsXby+1800DXtAAbj9bS/Ku8qHWdPrg/q4U4gAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADbBEPMEI7/J/lxdrbD37JVl/PNvC0aJeLrinWXnnRMQNe9Y4BzwOA7dMrF5S/p+4dPbaWZ5u2xTuBUkoLU0pXpJRWppRuSim9rXP8rJTSD1JKKzp/Th5/u8BPyCbUSTahTrIJdZJNmKx+7gRaHxHvzDlfl1LaLSKWp5Qu69Q+lnP+8PjaA3qQTaiTbEKdZBPqJJswQVscAuWc10TEms77D6eUVkbEgnE3BvQmm1An2YQ6ySbUSTZhsrbqhaFTSgdHxNER8ZMn65+ZUro+pXR+SmnPwjlLUkrLUkrL1sXaoZoFupNNqJNsQp1kE+okmzB+fQ+BUkq7RsRFEfH2nPNDEXFORBwaEUfFxsntR7qdl3M+N+e8KOe8aE7MHUHLwKZkE+okm1An2YQ6ySZMRl9DoJTSnNgYyM/nnL8UEZFzvjvnvCHnPBMRn42IxeNrE+hGNqFOsgl1kk2ok2zC5PSzO1iKiPMiYmXO+aObHJ+/yae9KiJuHH17QIlsQp1kE+okm1An2YTJ6md3sOMj4rURcUNKaUXn2Hsi4vSU0lERkSNiVUS8cSwdAiWyCXWSTaiTbEKdZBMmqJ/dwf4lIlKX0qWjbwfol2xCnWQT6iSbUCfZhMnaqt3BAAAAANg2GQIBAAAANMAQCAAAAKABhkAAAAAA
DTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAEp5zy5xVL6YUTc0flwn4i4d2KL91ZLL/rYXC29jKKPg3LOTx9FM6Mmm1ukj83V0otsTkctvehjc7X0IpuTV0sfEfX0UksfEfX0IpuTV0sfEfX0oo/NTSybEx0CPWnhlJblnBdNZfGnqKUXfWyull5q6WMSavpaa+lFH5urpZda+piEmr7WWnrRx+Zq6aWWPiahlq+1lj4i6umllj4i6umllj4moZavtZY+IurpRR+bm2Qvng4GAAAA0ABDIAAAAIAGTHMIdO4U136qWnrRx+Zq6aWWPiahpq+1ll70sblaeqmlj0mo6WutpRd9bK6WXmrpYxJq+Vpr6SOinl5q6SOinl5q6WMSavlaa+kjop5e9LG5ifUytdcEAgAAAGByPB0MAAAAoAGGQAAAAAANmMoQKKV0UkrplpTSbSmld02jh04fq1JKN6SUVqSUlk147fNTSveklG7c5NheKaXLUkq3dt7uOaU+zkop/aDzuKxIKZ08gT4WppSuSCmtTCndlFJ6W+f4NB6TUi8Tf1wmTTZls0sfVWSz5VxGyGZnbdl8ch+yWQHZlM0ufcjmlNWSy04vsimb/fYxscdk4q8JlFKaHRHfjYiXR8TqiLg2Ik7POd880UY29rIqIhblnO+dwtovjYhHIuJzOefndY59MCLuzzmf3fkHa8+c8x9OoY+zIuKRnPOHx7n2U/qYHxHzc87XpZR2i4jlEXFqRLw+Jv+YlHo5LSb8uEySbP7n2rL55D6qyGaruYyQzU3Wls0n9yGbUyab/7m2bD65D9mcoppy2elnVcimbPbXx8SyOY07gRZHxG0559tzzk9ExBcj4pQp9DFVOedvRsT9Tzl8SkRc0Hn/gtj4l2EafUxcznlNzvm6zvsPR8TKiFgQ03lMSr1s72QzZLNLH1Vks+FcRshmRMhmlz5kc/pkM2SzSx+yOV1y2SGbm/Uhmx3TGAItiIg7N/l4dUzvH6QcEV9PKS1PKS2ZUg+b2i/nvCZi41+OiNh3ir2cmVK6vnP73thvE9xUSungiDg6Iq6JKT8mT+klYoqPywTIZplsRj3ZbCyXEbLZi2yGbE6RbJbJZsjmlNSUywjZ7EU2p5TNaQyBUpdj09qn/vic84si4hUR8ebOrWpEnBMRh0bEURGxJiI+MqmFU0q7RsRFEfH2nPNDk1q3z16m9rhMiGzWr/lsNpjLCNncFsimbP6EbNZFNtvLZk25jJDNEtmcYjanMQRaHRELN/n4gIi4awp9RM75rs7beyLiy7Hx9sFpurvzHMGfPFfwnmk0kXO+O+e8Iec8ExGfjQk9LimlObExCJ/POX+pc3gqj0m3Xqb1uEyQbJbJZgXZbDSXEbLZi2zK5jTJZplsyua0VJPLCNkskc3pZnMaQ6BrI+LwlNIzU0o7RsRrIuLiSTeRUtql80JMkVLaJSJ+ISJu7H3W2F0cEa/rvP+6iPjKNJr4SQg6XhUTeFxSSikizouIlTnnj25SmvhjUuplGo/LhMlmmWxOOZsN5zJCNnuRTdmcJtksk03ZnJYqchkhm73I5pSzmXOe+J+IODk2vmr79yLij6bUwyER8W+dPzdNuo+I+EJsvM1rXWycWJ8REXtHxOURcWvn7V5T6uOvIuKGiLg+NoZi/gT6+OnYeKvm9RGxovPn5Ck9JqVeJv64TPqPbMpmlz6qyGbLuex8/bIpm0/tQzYr+CObstmlD9mc8p8actnpQzbLfcjmFLM58S3iAQAAAJi8aTwdDAAAAIAJMwQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABA
AwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABhgCAQAAADTAEAgAAACgAYZAAAAAAA0wBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANGCHSS62Y5qbd4pdJrkkVOPxeDSeyGvTtPvoRjZpmWxCnWQT6iSbUKd+sznUECildFJE/HlEzI6Iv8w5n93r83eKXeKn0onDLAnbrGvy5RNbSzahf7IJdZJNqJNsQp36zebATwdLKc2OiE9FxCsi4jkRcXpK6TmDXg8YDdmEOskm1Ek2oU6yCeMxzGsCLY6I23LOt+ecn4iIL0bEKaNpCxiCbEKdZBPqJJtQJ9mEMRhmCLQgIu7c5OPVnWNPklJaklJallJati7WDrEc0CfZhDrJJtRJNqFOsgljMMwQqNsLDuXNDuR8bs55Uc550ZyYO8RyQJ9kE+okm1An2YQ6ySaMwTBDoNURsXCTjw+IiLuGawcYAdmEOskm1Ek2oU6yCWMwzBDo2og4PKX0zJTSjhHxmoi4eDRtAUOQTaiTbEKdZBPqJJswBgNvEZ9zXp9SOjMivhYbt+w7P+d808g6AwYim1An2YQ6ySbUSTZhPAYeAkVE5JwvjYhLR9QLMCKyCXWSTaiTbEKdZBNGb5ingwEAAACwjTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCASQi4KXAAAZnUlEQVQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADdph2AwAAwPZrhwXPKNbWHbxvsfbIATtt9Vrzfrhuq8+JiNjh/ywv1mYf9sxi7dJvfrlY+8WX/HKxtn7V9/trDGDE3AkEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiALeIBpuzpVz6tWPvrg/+5WFv87jcVa3tecNUwLQHAyNz6loOKtW/8+oeKtX1mz9vqtR6ZWVusbYhcrN34xG7F2n6zryxfM5e3sb//Jc8o1na3RTwwJe4EAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAbYHYwteuyUxcXar/3ZPxZrS/ZYVayd+LvlXY3mfeXbffUF24uZnIq1dXlDsfaKd3yzWLv6gjlD9QQtmbVT9919vnv2UcVzTjj2xmLte3/y7GJt7j9e239jsJ3Y9Y5y7SWXvqNYSzt1/x540nNvKp5z6+8dWaz90QWfK9aOnvtosbZzmlus9bLHdx8p1sr7lEF7Hjr92GLtmx/+VLH2ygXHjKOd7d5QQ6CU0qqIeDgiNkTE+pzzolE0BQxHNqFOsgl1kk2ok2zC6I3iTqCfyznfO4LrAKMlm1An2YQ6ySbUSTZhhLwmEAAAAEADhh0C5Yj4ekppeUppSbdPSCktSSktSyktWxdrh1wO6JNsQp1kE+ok
m1An2YQRG/bpYMfnnO9KKe0bEZellL6Tc37SK5XmnM+NiHMjInZPe3kNNJgM2YQ6ySbUSTahTrIJIzbUnUA557s6b++JiC9HRHkbKWBiZBPqJJtQJ9mEOskmjN7AdwKllHaJiFk554c77/9CRLxvZJ1RjfuPLP816bUN/EzMDHTNBV/pqy0KZLNes591WNfjJ+71LwNdb04qbx8fYYv42shmvVb99xd1Pf7dV3+yeM7sVP5/tOcsekGxtvAf+++LyZDN8Xv6OVeVawNc73s9arNiRbH2P3779cXaXS/ZqVj7t9/9RB9dbW72vQ8Va+sHumJbZLMdz3nbjcVar98p7zvjuGJt7/PK/+60bping+0XEV9OKf3kOn+Tc/7fI+kKGIZsQp1kE+okm1An2YQxGHgIlHO+PSJeOMJegBGQTaiTbEKdZBPqJJswHraIBwAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQgGF2B6MRs477UbkWqdeZxcqCKx4eoiPYNj30vL27Hv/N3X8w4U4AoD2zr7iuWHvs119crPX6efeFHz+zWFuw6sr+GoMG/OAPX1KsXbrwk8XaTI/fKW0DPxh3AgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADDIEAAAAAGmCLeCIi4r4zjivWLnnRh4q1mZhXrH3qgUPLC377hr76Asr++jvl7WwPChmDcbp3w6PF2kFffaBYmxlHM0Bfdth/v2LtYyd8oVibiVysLfiAbeChH48e/kSx1itjM75zjpw7gQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADbBFPBER8cCJjxVr82eXt4GfFalYu+Cck4u1fcN2mrTn3hfMHun1dv+HXUd6PWjV4wvWbfU5P849trNdcfMw7QBjcs/JhxRrPzX3omLt7g3j6Aa2P+tfdkyx9oWXfaZY6/U75ZI7X9ZjxYf7aYuncCcQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABW9wiPqV0fkS8MiLuyTk/r3Nsr4j424g4OCJWRcRpOecfja9NRmLx84ulvz72vGJtJmaKteVry3PE+Zf/sFiz0+bwZHPb87KTr5t2C0yAbG57Lnj5Z6fdAhMgm9y7uPwT6D6z5xVrR1745mLtsLh6qJ6Qze3JHb9dztjRc8u/U17b43fKO999eLE2O/xsPYh+7gRaGhEnPeXYuyLi8pzz4RFxeedjYLKWhmxCjZaGbEKNloZsQo2WhmzCxGxxCJRz/mZE3P+Uw6dExAWd9y+IiFNH3BewBbIJdZJNqJNsQp1kEyZr0NcE2i/nvCYiovN239G1BAxBNqFOsgl1kk2ok2zCmGzxNYGGlVJaEhFLIiJ2ip3HvRzQJ9mEOskm1Ek2oU6yCVtn0DuB7k4pzY+I6Ly9p/SJOedzc86Lcs6L5sTcAZcD+iSbUCfZhDrJJtRJNmFMBh0CXRwRr+u8/7qI+Mpo2gGGJJtQJ9mEOskm1Ek2YUz62SL+CxFxQkTsk1JaHRHvjYizI+LClNIZEfH9iHj1OJtkNI75zL8Vay+em4q1mR6zwtf+3ZnF2iErr+qvMQYim21YuW5dsbbrmnKN6ZFNqJNstiEff1SxduXJHy3WvvX47sXaYb9nG/hxks3tx+8ddXmxNqvH75T/885fLNZmX2Eb+FHb4hAo53x6oXTiiHsBtoJsQp1kE+okm1An2YTJGvTpYAAAAABsQwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0YIu7g7Htue+M47oe/529P1Q8ZybmFWsvvf60Yu2Q/24beNjUDgsPKNaO2PnGrb7eX93fPc8REXO+vmyrrwcA27NHDtipWNtndvnn3VvWjaMb2E4tfn7Xw0v2WFo8ZSZmirXbLz2kWFsQ/9F3W/THnUAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpg
CAQAAADQAEMgAAAAgAbYIn479IZ3XtL1+Pwe22LeveGxYm39/7dvj9W+129b0IT7f7q8Rfybn/aVCXYCAACjt+Y967senxWpeM7yteX7TxZ84Mqhe6J/7gQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANAAQyAAAACABtgdbBt13xnHFWtL9vhk1+MzMVM85xXLlxRrzzj/qv4bAwCAKZl53b0DnXfh/Yt7VB8frBnYTr3yoJu6Hp+JXDznN6767WLt0PjXoXuif+4EAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBbx26gnfumBYm1WpGKl5Bl/NnvIjgBg2zQ7ZrofT+Xvm3PG1QywRbMPe2ax9q0XXtjjzNLPyBH/uPwFxdoR8e1+2oLtykOnH1us/cm+n+p6vPx7aMRuV80buidGY4t3AqWUzk8p3ZNSunGTY2ellH6QUlrR+XPyeNsEnko2oU6yCXWSTaiTbMJk9fN0sKURcVKX4x/LOR/V+XPpaNsC+rA0ZBNqtDRkE2q0NGQTarQ0ZBMmZotDoJzzNyPi/gn0AmwF2YQ6ySbUSTahTrIJkzXMC0OfmVK6vnP73p6lT0opLUkpLUspLVsXa4dYDuiTbEKdZBPqJJtQJ9mEMRh0CHRORBwaEUdFxJqI+EjpE3PO5+acF+WcF82JuQMuB/RJNqFOsgl1kk2ok2zCmAw0BMo5351z3pBznomIz0bE4tG2BQxCNqFOsgl1kk2ok2zC+Ay0RXxKaX7OeU3nw1dFxI29Pp8BLX5+sXTJi84p1mai+/Z7L73+tOI5u3/7hv77olqyCXWSzbptKPyf2Ibcfev4iIhfW/naYm1e/PvQPTEZsrltuuvk+QOdNxO5WNvz32YP2g5jIJvT95y3lR/ymej+/fHTDxxWPGf+58rX29B/W4zAFodAKaUvRMQJEbFPSml1RLw3Ik5IKR0VETkiVkXEG8fYI9CFbEKdZBPqJJtQJ9mEydriECjnfHqXw+eNoRdgK8gm1Ek2oU6yCXWSTZisYXYHAwAAAGAbYQgEAAAA0ABDIAAAAIAGGAIBAAAANGCgLeKZjIcO3aVYmz+7+zbwERGzInU9vssH9xi6J6C3B3/1kZFe76uXHFusHRRXjXQtoH9rVuxfrB1ii3gYq4eeNdiG0uc9eGCxtt9f99oOG7ZPOyw8oFg7d+HFxdpM4V6Sz3zuF4vnLHjoyv4bY6zcCQQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAt4iv2sfd/qlib6bFZ5fK13Wd7O/7Hw8VzBttoE9q0/mXHFGt/dcw5Pc6c3fXoj2YeL56x73LpBIBNveB5qwY67+51exRrMw+Xf06G7dXN792/WJuJ3KPW/XfRvW9eP3RPjJ87gQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANMAQCAAAAKABhkAAAAAADbBFfMVePDcVazM95ncXP3h01+MbVt46dE9AxH3PnVusvWDH7tvA9/JHd/1CsTbvK9/e6usB47f/1d23xwVGaPHzux5+ywF/M9DlXrXHdcXaYbcsKNb+5mU/Vayt/8FdA/UCk7L+ZccUa999xWeKtVlR/l108Z+9revxfb96Zf+NMTXuBAIAAABogCEQAAAAQAMMgQAAAAAaYAgEAAAA0ABDIAAAAIAGGAIBAAAANGCLW8SnlBZGxOciYv+ImImIc3POf55S2isi/jYiDo6IVRFxWs75R+NrtT0zkXvUylvTfnXV87oef0bcPHRP1EM2p+f0JZeN9HrfuPwFxdoz46qRrsX4yWYbdrv5vmJtwwT7oH/bYzbT3LnF2obFzynWnv2xG4u1X9lzebF2z4bdirWVj5W3WP/6miOLtYN2Kz/Uf3rAOV2PH7jDvOI50WNb62fPmVOsLZx9Z7H23refVqwd+ge2iB/W9pjNmtz33PK/E71+p1y+tny/
yPzLf9j1uO9/24Z+7gRaHxHvzDk/OyKOjYg3p5SeExHviojLc86HR8TlnY+ByZFNqJNsQp1kE+okmzBBWxwC5ZzX5Jyv67z/cESsjIgFEXFKRFzQ+bQLIuLUcTUJbE42oU6yCXWSTaiTbMJkbdVrAqWUDo6IoyPimojYL+e8JmJjcCNi38I5S1JKy1JKy9bF2uG6BbqSTaiTbEKdZBPqJJswfn0PgVJKu0bERRHx9pzzQ/2el3M+N+e8KOe8aE6Un48IDEY2oU6yCXWSTaiTbMJk9DUESinNiY2B/HzO+Uudw3enlOZ36vMj4p7xtAiUyCbUSTahTrIJdZJNmJwtDoFSSikizouIlTnnj25SujgiXtd5/3UR8ZXRtweUyCbUSTahTrIJdZJNmKwtbhEfEcdHxGsj4oaU0orOsfdExNkRcWFK6YyI+H5EvHo8LbZrTppdrK0r7x4f+ZqnjaEbKiSb24n9lpW352SbJJsN2HDLbdNuga233WXz8Ze9oFi77C8/XazN6rGN+kz0+CEz7i9WXrVLufbH+5S3pO+9Xq+t4LfeTU+sL9Ze9/HfL9YO/diVI+2DzWx32Zy0HRYeUKytePdfFGvrcvmekDPff2axtvfKq/prjCptcQiUc/6XiOJ3ihNH2w7QL9mEOskm1Ek2oU6yCZO1VbuDAQAAALBtMgQCAAAAaIAhEAAAAEADDIEAAAAAGmAIBAAAANCAfraIZ4zuO+O4Ym1dXl6szUSPLaV77bQJDO3HMztOuwVghP7X6pd3PX78oV8rnvO9vzmqWDv011cUazBKOy9bVawddfVvFmvp2j2Ktcdf8ONiba89Hi3W7n9g12Jtp3lPFGu5x8+trzzkpq7HnztvdfGcDy49rVjb4/byz88LLv9OsbahWIE63PHrBxZr63L5b3Cv3yn3Ps828NsrdwIBAAAANMAQCAAAAKABhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgi/gpe/zpqVibk2YXa1c/Xr7mgg9cOUxLwBZcc8bRxdrlX7ytWLtn/W5dj+92ywPFc2xLC+N370ef2fX4Y58sb2t9y8+eX6wd9ftnFmvP+LDv0YzOhh/+sFg74FfLtXHYcwzXXFE8fkDxnAUxWMZ8v2WbVv6VsufvlIf83e8Wa4fHNcN0RMXcCQQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANsDvYlO190/pibV0u71PwhgveUqwdOOCuCEB/8rIbi7WPHfbsAa54y+DNAEOb9/ff7nr8tH/+f4rnrPzAEeXrzRu6JQDoXy6XPvGjg4q1I/94ZbFmx7ztlzuBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAAN2OIW8SmlhRHxuYjYPyJmIuLcnPOfp5TOioj/FhE/7Hzqe3LOl46r0e3VTl/tvi1tRMQrv3pMsWYbeGQT6iSb248NDzxYrB3xxmsn2AmjIJtQJ9kc3oIPlH83vOQDe/Y486HRN0P1tjgEioj1EfHOnPN1KaXdImJ5SumyTu1jOecPj689oAfZhDrJJtRJNqFOsgkTtMUhUM55TUSs6bz/cEppZUQsGHdjQG+yCXWSTaiTbEKdZBMma6teEyildHBEHB0R13QOnZlSuj6ldH5Kqdd9ZsAYySbUSTahTrIJdZJNGL++h0AppV0j4qKIeHvO+aGIOCciDo2Io2Lj5PYjhfOWpJSWpZSWrYu1I2gZ2JRsQp1kE+okm1An2YTJ6GsIlFKaExsD+fmc85ciInLOd+ecN+ScZyLisxGxuNu5Oedzc86Lcs6L5sTcUfUNhGxCrWQT6iSbUCfZhMnZ4hAopZQi4ryIWJlz/ugmx+dv8mmviogbR98eUCKbUCfZhDrJJtRJNmGy+tkd7PiIeG1E3JBSWtE59p6IOD2ldFRE5IhYFRFvHEuHQIlsQp1kE+okm1An2YQJ6md3sH+JiNSldOno2wH6
JZtQJ9mEOskm1Ek2YbK2ancwAAAAALZNhkAAAAAADTAEAgAAAGiAIRAAAABAAwyBAAAAABpgCAQAAADQAEMgAAAAgAYYAgEAAAA0wBAIAAAAoAGGQAAAAAANMAQCAAAAaIAhEAAAAEADUs55coul9MOIuKPz4T4Rce/EFu+tll70sblaehlFHwflnJ8+imZGTTa3SB+bq6UX2ZyOWnrRx+Zq6UU2J6+WPiLq6aWWPiLq6UU2J6+WPiLq6UUfm5tYNic6BHrSwiktyzkvmsriT1FLL/rYXC291NLHJNT0tdbSiz42V0svtfQxCTV9rbX0oo/N1dJLLX1MQi1fay19RNTTSy19RNTTSy19TEItX2stfUTU04s+NjfJXjwdDAAAAKABhkAAAAAADZjmEOjcKa79VLX0oo/N1dJLLX1MQk1fay296GNztfRSSx+TUNPXWksv+thcLb3U0sck1PK11tJHRD291NJHRD291NLHJNTytdbSR0Q9vehjcxPrZWqvCQQAAADA5Hg6GAAAAEADDIEAAAAAGjCVIVBK6aSU0i0ppdtSSu+aRg+dPlallG5IKa1IKS2b8Nrnp5TuSSnduMmxvVJKl6WUbu283XNKfZyVUvpB53FZkVI6eQJ9LEwpXZFSWplSuiml9LbO8Wk8JqVeJv64TJpsymaXPqrIZsu5jJDNztqy+eQ+ZLMCsimbXfqQzSmrJZedXmRTNvvtY2KPycRfEyilNDsivhsRL4+I1RFxbUScnnO+eaKNbOxlVUQsyjnfO4W1XxoRj0TE53LOz+sc+2BE3J9zPrvzD9aeOec/nEIfZ0XEIznnD49z7af0MT8i5uecr0sp7RYRyyPi1Ih4fUz+MSn1clpM+HGZJNn8z7Vl88l9VJHNVnMZIZubrC2bT+5DNqdMNv9zbdl8ch+yOUU15bLTz6qQTdnsr4+JZXMadwItjojbcs6355yfiIgvRsQpU+hjqnLO34yI+59y+JSIuKDz/gWx8S/DNPqYuJzzmpzzdZ33H46IlRGxIKbzmJR62d7JZshmlz6qyGbDuYyQzYiQzS59yOb0yWbIZpc+ZHO65LJDNjfrQzY7pjEEWhARd27y8eqY3j9IOSK+nlJanlJaMqUeNrVfznlNxMa/HBGx7xR7OTOldH3n9r2x3ya4qZTSwRFxdERcE1N+TJ7SS8QUH5cJkM0y2Yx6stlYLiNksxfZDNmcItksk82QzSmpKZcRstmLbE4pm9MYAqUux6a1T/3xOecXRcQrIuLNnVvViDgnIg6NiKMiYk1EfGRSC6eUdo2IiyLi7Tnnhya1bp+9TO1xmRDZrF/z2WwwlxGyuS2QTdn8Cdmsi2y2l82achkhmyWyOcVsTmMItDoiFm7y8QERcdcU+oic812dt/dExJdj4+2D03R35zmCP3mu4D3/f3t3jCJFEIZh+KtAk800MlTwFsYGm5mZbeAx9g5eQIxEzAQ39wQm7qoYyB5gz7BoGXQJItOTWX9DPw8UM8ww8E/BmxQ9PRVD9N5veu8/e++/krzKpH1prd3JEsLb3vv78XLJnhyapWpfJtLmOm1uoM2ddplo8xhtarOSNtdpU5tVNtNlos012qxts+IQ6FOSx621h621u0meJ7mYPURr7WTciCmttZMkT5N8Pf6p/+4iydl4fpbkQ8UQfyIYnmXCvrTWWpLXSb733l/+9db0PVmbpWJfJtPmOm0Wt7njLhNtHqNNbVbS5jptarPKJrpMtHmMNovb7L1PX0lOs9y1/TrJedEMj5JcjvVt9hxJ3mW5zOs2y4n1iyT3k3xM8mM83iua402SL0muskTxYMIcT7JcqnmV5PNYp0V7sjbL9H2ZvbSpzQNzbKLNPXc5vr82tfnvHNrcwNKmNg/Moc3itYUuxxzaXJ9Dm4VtTv+LeAAAAADmq/g5GAAAAACTOQQCAAAA2AGHQAAAAAA74BAIAAAAYAccAgEAAADsgEMgAAAAgB1wCAQAAACwA78BY010
KJtAWJAAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import struct\n", "import numpy as np\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "# Number of images used in the barycenter calculation\n", "n = 20\n", "\n", "# Read an IDX file (the MNIST file format): a 4-byte header (two zero bytes,\n", "# a data-type code and the number of dimensions), one big-endian 32-bit size\n", "# per dimension, then the raw data (unsigned bytes for MNIST)\n", "def read_idx(filename):\n", "    with open(filename, 'rb') as f:\n", "        zero, data_type, dims = struct.unpack('>HBB', f.read(4))\n", "        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))\n", "        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)\n", "\n", "data = read_idx('train-images.idx3-ubyte')\n", "labels = read_idx('train-labels.idx1-ubyte')\n", "# Select the images of the digit 1 and keep the first n of them\n", "ones = data[labels == 1]\n", "train_1 = ones[:n]\n", "\n", "# Plot 10 '1' digits drawn at random from the training set\n", "plt.figure(figsize=(20,10))\n", "for i in range(10):\n", "    plt.subplot(2,5,i+1)\n", "    plt.imshow(ones[np.random.randint(0,ones.shape[0])])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Barycenters using Mosek Fusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With $p=2$ and one transport plan $\\pi^k$ for each image $k$, the barycenter problem becomes the following LP:\n", "
\n", "
\n", "$$ \\mbox{minimize} \\quad \\frac{1}{N} \\sum_{k=1}^{N} \\sum_{i=1}^{n} \\sum_{j=1}^{n} D(X_i,Y_j)^2\\pi_{ij}^k$$\n", "
\n", "$$\\mbox{s.t.} \\quad \\sum_{j=1}^{n} \\pi_{ij}^{k} = \\mu_i, \\quad \\forall\\, k,i \\quad (1)$$\n", "
\n", "$$ \\quad \\sum_{i=1}^{n} \\pi_{ij}^{k} = \\upsilon_j^{k}, \\quad \\forall\\, k,j \\quad (2) $$\n", "
\n", "$$ \\pi_{ij}^{k} \\geq 0 \\quad \\forall\\, k,i,j$$\n", "
\n", "where $D(X_i,Y_j)$ is the Euclidean distance between the positions of pixels $i$ and $j$ on the $28 \\times 28$ grid, $n = 784$ is the number of pixels and $N$ is the number of sample images." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from mosek.fusion import *\n", "import time\n", "import sys\n", "import numpy as np\n", "\n", "class Wasserstein_Fusion:\n", "\n", "    def __init__(self):\n", "        self.time = 0.0\n", "        self.M = Model('Wasserstein')\n", "        self.result = None\n", "        self.objective = None\n", "\n", "    def single_pmf(self, data = None, img=False):\n", "\n", "        ''' Takes an image (img=True) or an array of images (img=False) and returns the normalized pixel intensities as probability mass function(s) '''\n", "\n", "        if not img:\n", "            v=[]\n", "            for image in data:\n", "                arr = np.asarray(image).ravel(order='K')\n", "                v.append(arr/np.sum(arr))\n", "        else:\n", "            v = np.asarray(data).ravel(order='K')\n", "            v = v/np.sum(v)\n", "        return v\n", "\n", "    def ms_distance(self, m ,n, constant=False):\n", "\n", "        ''' Squared Euclidean distances between all pairs of the m = n*n pixels of an n x n image '''\n", "\n", "        if constant:\n", "            d = np.ones((m,m))\n", "        else:\n", "            d = np.empty((m,m))\n", "            coor = []\n", "            for i in range(n):\n", "                for j in range(n):\n", "                    coor.append(np.array([i,j]))\n", "            for i in range(m):\n", "                for j in range(m):\n", "                    d[i][j] = np.linalg.norm(coor[i]-coor[j])**2\n", "        return d\n", "\n", "    def Wasserstein_Distance(self, bc ,data, img = False):\n", "\n", "        ''' Computes the squared Wasserstein distance W_2^2 between a barycenter and a single image (pass img=True) by solving the transport LP '''\n", "\n", "        v = np.array(self.single_pmf(data, img))\n", "        n = v.shape[0]\n", "        d = self.ms_distance(n,data.shape[1])\n", "        with Model('Wasserstein') as M:\n", "            #Add variable\n", "            pi = M.variable('pi',[n,n], Domain.greaterThan(0.0))\n", "\n", "            #Add constraints\n", "            M.constraint('c1' , Expr.sum(pi,0), Domain.equalsTo(v))\n", "            M.constraint('c2' , Expr.sum(pi,1), Domain.equalsTo(bc))\n", "\n", "            M.objective('Obj.' 
, ObjectiveSense.Minimize, Expr.dot(d, pi))\n", " \n", " M.solve()\n", " objective = M.primalObjValue()\n", " \n", " return objective\n", " \n", " def Wasserstein_BaryCenter(self,data):\n", "\n", " M = self.M\n", " start_time = time.time()\n", " k = data.shape[0]\n", " v = np.array(self.single_pmf(data))\n", " n = v.shape[1]\n", " d = self.ms_distance(n,data.shape[1])\n", "\n", " #Add variables \n", " mu = M.variable('Mu', n, Domain.greaterThan(0.0)) \n", " pi = (M.variable('Pi', [k,n,n] , Domain.greaterThan(0.0)))\n", "\n", " #Add constraints \n", "\n", " #Constraint (1)\n", " M.constraint('B', Expr.sub(Expr.sum(pi,1) , Var.repeat(mu,1,k).transpose()), Domain.equalsTo(0.0))\n", " #Constraint (2)\n", " M.constraint('C', Expr.sum(pi,2), Domain.equalsTo(v))\n", "\n", " M.objective('Obj' , ObjectiveSense.Minimize, Expr.sum(Expr.mul(Expr.mul(Expr.reshape(pi.asExpr(), k, n*n) , d.ravel()), 1/k)))\n", "\n", " M.setLogHandler(sys.stdout)\n", " M.solve()\n", " self.result = mu.level()\n", " M.selectedSolution(SolutionType.Interior)\n", " self.objective = M.primalObjValue()\n", " self.time = time.time() - start_time\n", "\n", " return mu.level()\n", "\n", " def reset(self):\n", " self.M = Model('Wasserstein')\n", " " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Problem\n", " Name : Wasserstein \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 31360 \n", " Cones : 0 \n", " Scalar variables : 12293905 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 0.35 \n", "Lin. dep. - number : 19 \n", "Presolve terminated. 
Time: 5.70 \n", "GP based matrix reordering started.\n", "GP based matrix reordering terminated.\n", "Problem\n", " Name : Wasserstein \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 31360 \n", " Cones : 0 \n", " Scalar variables : 12293905 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 24 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 17440\n", "Optimizer - Cones : 0\n", "Optimizer - Scalar variables : 1395520 conic : 0 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 16.44 dense det. time : 0.21 \n", "Factor - ML order time : 0.04 GP order time : 14.45 \n", "Factor - nonzeros before factor : 1.55e+06 after factor : 1.44e+07 \n", "Factor - dense dim. : 0 flops : 1.64e+10 \n", "Factor - GP saved nzs : 1.75e+06 GP saved flops : 3.15e+09 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 5.5e+03 3.2e+02 8.3e+07 0.00e+00 1.182613040e+07 0.000000000e+00 4.9e+01 23.93 \n", "1 6.8e-01 4.0e-02 1.0e+04 -1.00e+00 1.120231068e+07 -1.638527125e+05 6.1e-03 24.96 \n", "2 5.8e-02 3.4e-03 8.8e+02 2.61e+01 1.558624710e+04 -2.985003979e+03 5.2e-04 26.57 \n", "3 4.5e-02 2.6e-03 6.8e+02 1.08e+01 3.256531514e+03 -6.820367762e+02 4.0e-04 27.67 \n", "4 4.2e-02 2.5e-03 6.4e+02 5.25e+00 2.407863248e+03 -4.995626420e+02 3.8e-04 28.62 \n", "5 3.8e-02 2.2e-03 5.7e+02 4.40e+00 1.571023848e+03 -3.171235705e+02 3.4e-04 29.52 \n", "6 3.3e-02 1.9e-03 5.0e+02 3.48e+00 1.104555995e+03 -2.125212567e+02 3.0e-04 30.36 \n", "7 2.9e-02 1.7e-03 4.3e+02 2.91e+00 7.928892307e+02 -1.415543053e+02 2.6e-04 31.30 \n", "8 1.2e-02 7.1e-04 1.8e+02 2.47e+00 2.252271180e+02 -1.976502285e+01 1.1e-04 32.37 \n", "9 5.2e-03 3.1e-04 7.9e+01 1.44e+00 9.137862868e+01 -2.802278257e+00 4.7e-05 33.33 \n", "10 1.2e-03 6.9e-05 1.8e+01 1.16e+00 2.223917710e+01 2.047778184e+00 1.1e-05 34.94 \n", "11 7.7e-04 4.5e-05 1.2e+01 1.04e+00 1.552925437e+01 
2.419701859e+00 6.9e-06 35.99 \n", "12 3.9e-04 2.3e-05 5.9e+00 1.02e+00 9.411514139e+00 2.769092362e+00 3.5e-06 37.12 \n", "13 1.3e-04 8.1e-06 2.0e+00 1.01e+00 5.275173812e+00 3.000408169e+00 1.2e-06 38.69 \n", "14 5.5e-05 3.3e-06 8.2e-01 1.00e+00 3.990965725e+00 3.070294675e+00 4.9e-07 39.91 \n", "15 2.5e-05 1.5e-06 3.8e-01 1.00e+00 3.515474345e+00 3.095350472e+00 2.2e-07 40.99 \n", "16 1.2e-05 7.0e-07 1.8e-01 1.00e+00 3.304309669e+00 3.106260800e+00 1.1e-07 42.22 \n", "17 7.1e-06 4.3e-07 1.1e-01 1.00e+00 3.229689796e+00 3.109866447e+00 6.4e-08 43.16 \n", "18 3.1e-06 2.4e-07 4.8e-02 1.00e+00 3.165660320e+00 3.112133382e+00 2.8e-08 44.08 \n", "19 1.2e-06 9.5e-08 1.9e-02 1.00e+00 3.135045283e+00 3.113923931e+00 1.1e-08 45.09 \n", "20 2.5e-07 1.9e-08 3.9e-03 1.00e+00 3.119054866e+00 3.114731758e+00 2.3e-09 46.13 \n", "21 1.2e-09 9.1e-11 1.8e-05 1.00e+00 3.114878927e+00 3.114858639e+00 1.1e-11 47.44 \n", "22 3.4e-12 2.5e-13 4.8e-08 1.00e+00 3.114858825e+00 3.114858771e+00 2.8e-14 48.27 \n", "23 3.7e-13 5.7e-14 3.3e-13 1.00e+00 3.114858772e+00 3.114858772e+00 2.9e-18 49.09 \n", "Basis identification started.\n", "Basis identification terminated. Time: 0.51\n", "Optimizer terminated. Time: 53.52 \n", "\n", "\n", "Interior-point solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 1e+00 Viol. con: 4e-13 var: 0e+00 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. con: 0e+00 var: 6e-14 \n", "\n", "Basic solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 1e+00 Viol. con: 4e-17 var: 1e-06 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. 
con: 0e+00 var: 2e-09 \n", "\n", "Time Spent to solve problem with Fusion: \n", " 74.02755951881409\n", "Time Spent in solver: \n", " 53.52407097816467\n", "The average Wasserstein distance between digits and the barycenter: \n", " 3.1148587717997467\n" ] } ], "source": [ "fusion_model = Wasserstein_Fusion()\n", "f_bc = fusion_model.Wasserstein_BaryCenter(train_1)\n", "print('\\nTime Spent to solve problem with Fusion: \\n {0}'.format(fusion_model.time))\n", "print('Time Spent in solver: \\n {0}'.format(fusion_model.M.getSolverDoubleInfo(\"optimizerTime\")))\n", "print('The average Wasserstein distance between digits and the barycenter: \\n {0}'.format(fusion_model.objective))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEKNJREFUeJzt3X2wVPV9x/H3BwSMgCJREREBFY1PFTO32ESnkjgxmjQDcWoiSRQztmRqbOtM/qi1MxE7Teu0McbpWB1SmWASNTTRSltjNJgJTccHrsQokSQgQwC5ggQVULk83G//2IOzXu/+7nLv7p69/D6vmZ179nz37PnODh/O2fOwP0UEZpafYWU3YGblcPjNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphz+NidpvaS3Je2S9Jqk/5E0uey+6iVpgaTvlt2HvZfDPzR8KiLGABOBLcC/HuwbSDqs4V21wFDteyhw+IeQiNgN/AA4E0DSJyX9QtIOSRslLTjwWklTJYWkayVtAJ4o9hr+svo9JT0vaU4xfZakxyVtl7RF0k3F/GGSbpT0kqTfS1oiaXyv9cyTtEHSNkl/V9QuBW4CPlvsufyymH+UpHskdUl6WdI/SBpe1K6R9H+Sbpe0HViANYXDP4RIOgL4LPBUMetN4GpgHPBJ4C8OBLnKRcAZwMeBxcAXqt7vXGAS8IikscBPgEeBE4BTgWXFS/8KmFO81wnAa8CdvdZzIXA6cDHwVUlnRMSjwD8C34+IMRFxbvHaxcC+Yh3nAZcAf1b1XucD64DjgK/V+fHYwYoIP9r4AawHdgGvUwnMZuCcGq/9JnB7MT0VCODkqvooYDswvXj+deDfium5wC9qvO9q4OKq5xOBvcBhVes5sar+DHBlMb0A+G5VbQLQDbyvat5c4KfF9DXAhrI/9xwe/j41NMyJiJ8Uu8azgZ9JOhOYAtwKnA2MpBLu/+i17MYDExHRLWkJ8AVJt1AJ3Z8W5cnASzXWPwV4SFJP1bz9VIJ8wCtV028BYxLvNQLoknRg3rDqPntNW5N4t38IiYj9EfEgleBdCNwHLAUmR8RRwN2Aei/W6/li4PNUds/fiogni/kbgVNqrHojcFlEjKt6HB4RL9fTdh/v1Q0cU/VeR0bEWYllrAkc/iFEFbOBo6nsio8FtkfEbk
kzgc/19x5F2HuA24DvVJX+Gzhe0g2SRkkaK+n8onY38DVJU4o+ji36qMcWYKqkYcX6u4DHgNskHVkcTDxF0kV1vp81iMM/NPyXpF3ADioHwOZFxK+A64C/l7QT+CqwpM73uxc4B3jn/HtE7AQ+BnyKyi78GuAjRfkOKnsYjxXreorKQbl6HPga8ntJK4vpq6l8TXmRysHDH1A5jmAtpOIgi2VE0tXA/Ii4sOxerDze8memOF14HbCw7F6sXA5/RiR9HHiVyvfw+0pux0rm3X6zTHnLb5apll7kM1Kj4nBGt3KVZlnZzZvsie7e13r0aVDhL27cuAMYDvx7RNyaev3hjOZ8XTyYVZpZwtOxrP8XFQa8219canoncBmVu8zmFpecmtkQMJjv/DOBtRGxLiL2AA9Que7czIaAwYR/Eu++AWNTMe9dJM2X1Cmpcy/dg1idmTXSYMLf10GF95w3jIiFEdERER0jGDWI1ZlZIw0m/Juo3AZ6wIlU7jU3syFgMOFfAUyXNE3SSOBKKjd/mNkQMOBTfRGxT9L1wI+pnOpbVNxpZmZDwKDO80fEI8AjDerFzFrIl/eaZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmWjpEtx16Nv3th5P17vHvGcTpHaff1ZVcdt+69QNpyerkLb9Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlimf57ek+NC5yfpFl69M1n+08pyatV/fcnRy2VOvWp+s2+AMKvyS1gM7gf3AvojoaERTZtZ8jdjyfyQitjXgfcyshfyd3yxTgw1/AI9JelbS/L5eIGm+pE5JnXvpHuTqzKxRBrvbf0FEbJZ0HPC4pF9HxPLqF0TEQmAhwJFK3OVhZi01qC1/RGwu/m4FHgJmNqIpM2u+AYdf0mhJYw9MA5cAqxrVmJk112B2+ycAD0k68D73RcSjDenKWmfY8GR53eVHJOsbfnResj7mnDdq1na/PTK5rDXXgMMfEeuA9BUgZta2fKrPLFMOv1mmHH6zTDn8Zply+M0y5Vt6Mzd8+rRkvWdk+qLMYdN2JesXnfhSzdqPn/hgcllrLm/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM+Tx/7rZtT5aPmpr+JzJl3GvJ+oa3av8898jXlVzWmstbfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUz7Pn7k9fzA1Wf/QCS8m69370/+EOl+ZXLPW41/uLpW3/GaZcvjNMuXwm2XK4TfLlMNvlimH3yxTDr9ZpnyeP3Nbzzs8Wb967Npk/bHtZyfrPVH7nv1Jy3cnl7Xm6nfLL2mRpK2SVlXNGy/pcUlrir+1f7HBzNpSPbv93wYu7TXvRmBZREwHlhXPzWwI6Tf8EbEc6P1bT7OBxcX0YmBOg/sysyYb6AG/CRHRBVD8Pa7WCyXNl9QpqXMv3QNcnZk1WtOP9kfEwojoiIiOEYxq9urMrE4DDf8WSRMBir9bG9eSmbXCQMO/FJhXTM8DHm5MO2bWKv2e55d0PzALOEbSJuBm4FZgiaRrgQ3AFc1s0gbusKknJet7j0wvf8eajybrV538TLL+1JMfqFk7bNebyWUjWbXB6jf8ETG3RuniBvdiZi3ky3vNMuXwm2XK4TfLlMNvlimH3yxTvqX3ENcz5ohkvXta+rbaaWN2Jeu/fev4ZP39v6x9S69eWJNc1qf6mstbfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUz7Pf4jrPmFMsj7i8D3J+q496V9fWrElfcvw62fVrh29b19yWWsub/nNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0z5PP8hbuRr6SHSPveBZ5P15a+emqxv2pIeoHl8YoTvYWNGJ5fd//obyboNjrf8Zply+M0y5fCbZcrhN8uUw2
+WKYffLFMOv1mmfJ7/ELf3n9Lnyh94aFay/vnLn0jW33hgUrKunto1n8cvV79bfkmLJG2VtKpq3gJJL0t6rnh8orltmlmj1bPb/23g0j7m3x4RM4rHI41ty8yard/wR8RyYHsLejGzFhrMAb/rJT1ffC2oeYG3pPmSOiV17iV9nbmZtc5Aw38XcAowA+gCbqv1wohYGBEdEdExgvSPQZpZ6wwo/BGxJSL2R0QP8C1gZmPbMrNmG1D4JU2sevppYFWt15pZe+r3PL+k+4FZwDGSNgE3A7MkzaAyhPp64EtN7NH6Mfz02vfcb1h+XHLZ93/4lWT9nmcvSNbP/+JvkvXN/5L+PQArT7/hj4i5fcy+pwm9mFkL+fJes0w5/GaZcvjNMuXwm2XK4TfLlG/pPQTojZ01a7unjEsuu2f/8GR9zLi3k/UVz5yWrJ/6ylvJupXHW36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFM+z38I2PWHU2rWTp3alVx27ZqJyfqEk9I/39izNb39GL6t9jUI+5NLWrN5y2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrn+Q8Bu8fVvif/3KPSP829cXz6fv8TxqSH0d60+dhknTd2petWGm/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM1TNE92TgXuB4oAdYGBF3SBoPfB+YSmWY7s9ExGvNazVfwyekh9ne+tE9NWs/25QeIvuLZzyVrN/95Kxk/eRNtdcNsP/VV5N1K089W/59wFci4gzgj4AvSzoTuBFYFhHTgWXFczMbIvoNf0R0RcTKYnonsBqYBMwGFhcvWwzMaVaTZtZ4B/WdX9JU4DzgaWBCRHRB5T8IIL1vamZtpe7wSxoD/BC4ISJ2HMRy8yV1SurcS/dAejSzJqgr/JJGUAn+9yLiwWL2FkkTi/pEYGtfy0bEwojoiIiOEYxqRM9m1gD9hl+SgHuA1RHxjarSUmBeMT0PeLjx7ZlZs9RzS+8FwFXAC5KeK+bdBNwKLJF0LbABuKI5LZpGjkzWTztpS83aqMP2JZd9cvvJyfr7NoxI1ketWJWs++e521e/4Y+InwOqUb64se2YWav4Cj+zTDn8Zply+M0y5fCbZcrhN8uUw2+WKf909xCw+7QJyfqfn/hgzdotz/9JctmI9LrHr+1J1vfvqPtKb2sz3vKbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zpnyef4hYNTvtifrS7fNqFlbfcF3ksueeed1yfpRq9LrTl8FYO3MW36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFOK/m7obqAjNT7Ol3/t26xZno5l7IjttX5q/1285TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMtVv+CVNlvRTSasl/UrSXxfzF0h6WdJzxeMTzW/XzBqlnh/z2Ad8JSJWShoLPCvp8aJ2e0R8vXntmVmz9Bv+iOgCuorpnZJWA5Oa3ZiZNddBfeeXNBU4D3i6mHW9pOclLZJ0dI1l5kvqlNS5l+5BNWtmjVN3+CWNAX4I3BARO4C7gFOAGVT2DG7ra7mIWBgRHRHRMYJRDWjZzBqhrvBLGkEl+N+LiAcBImJLROyPiB7gW8DM5rVpZo1Wz9F+AfcAqyPiG1XzJ1a97NPAqsa3Z2bNUs/R/guAq4AXJD1XzLsJmCtpBhDAeuBLTenQzJqinqP9Pwf6uj/4kca3Y2at4iv8zDLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaZaOkS3pFeB31XNOgbY1rIGDk679taufYF7G6hG9jYlIo6t54UtDf97Vi51RkRHaQ0ktGtv7doXuLeBKqs37/abZcrhN8tU2eFfWPL6U9q1t3btC9zbQJXSW6nf+c2sPGVv+c2sJA6/WaZKCb+kSyX9RtJaSTeW0UMtktZLeqEYdryz5F4WSdoqaV
XVvPGSHpe0pvjb5xiJJfXWFsO2J4aVL/Wza7fh7lv+nV/ScOC3wMeATcAKYG5EvNjSRmqQtB7oiIjSLwiR9MfALuDeiDi7mPfPwPaIuLX4j/PoiPibNultAbCr7GHbi9GkJlYPKw/MAa6hxM8u0ddnKOFzK2PLPxNYGxHrImIP8AAwu4Q+2l5ELAe295o9G1hcTC+m8o+n5Wr01hYioisiVhbTO4EDw8qX+tkl+ipFGeGfBGyser6JEj+APgTwmKRnJc0vu5k+TIiILqj8YwKOK7mf3vodtr2Veg0r3zaf3UCGu2+0MsLf19Bf7XS+8YKI+CBwGfDlYvfW6lPXsO2t0sew8m1hoMPdN1oZ4d8ETK56fiKwuYQ++hQRm4u/W4GHaL+hx7ccGCG5+Lu15H7e0U7Dtvc1rDxt8Nm103D3ZYR/BTBd0jRJI4ErgaUl9PEekkYXB2KQNBq4hPYbenwpMK+Yngc8XGIv79Iuw7bXGlaekj+7dhvuvpQr/IpTGd8EhgOLIuJrLW+iD5JOprK1h8oIxveV2Zuk+4FZVG753ALcDPwnsAQ4CdgAXBERLT/wVqO3WVR2Xd8Ztv3Ad+wW93Yh8L/AC0BPMfsmKt+vS/vsEn3NpYTPzZf3mmXKV/iZZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zpn6f/UCLZWmyE1mAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fus_bc = np.reshape(f_bc,(28,28))\n", "plt.imshow(fus_bc)\n", "plt.title('Barycenter')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling the same problem with Pyomo" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same problem is formulated using Pyomo and solved using Mosek. Unlike Fusion Pyomo requires rules and summations to formulate the problem." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import pyomo.environ as pyo\n", "import time\n", "class Wasserstein_Pyomo:\n", " \n", " def __init__(self):\n", " self.time = 0.0\n", " self.result = None\n", " self._solver = 'mosek'\n", " self.M = pyo.ConcreteModel()\n", " \n", " \n", " def single_pmf(self, data = None, img=False):\n", " \n", " ''' Takes a image or array of images and extracts the probabilty mass function'''\n", " \n", " if not img:\n", " v=[]\n", " for image in data:\n", " arr = np.asarray(image).ravel(order='K')\n", " v.append(arr/np.sum(arr))\n", " else:\n", " v = np.asarray(data).ravel(order='K')\n", " v = v/np.sum(v)\n", " return v\n", " \n", " def ms_distance(self, m ,n, constant=False):\n", " \n", " ''' Squared Euclidean distance calculation between the pixels '''\n", " \n", " if constant:\n", " d = np.ones((m,m))\n", " else:\n", " d = np.empty((m,m))\n", " coor = []\n", " for i in range(n):\n", " for j in range(n):\n", " coor.append(np.array([i,j]))\n", " for i in range(m):\n", " for j in range(m):\n", " d[i][j] = np.linalg.norm(coor[i]-coor[j])**2\n", " return d\n", " \n", " def Wasserstein_BaryCenter(self,data):\n", " \n", " ''' Calculation of wasserstein barycenter of given images by solving the minimization problem '''\n", " \n", " M = self.M\n", " k = data.shape[0]\n", " v = np.array(self.single_pmf(data))\n", " n = v.shape[1]\n", " d = self.ms_distance(n,data.shape[1])\n", " \n", " #Define indices\n", " M.i = 
range(n)\n", " M.j = range(n)\n", " M.k = range(k)\n", " \n", " #Add variables\n", " M.pi = pyo.Var(M.k, M.i, M.j, domain = pyo.NonNegativeReals)\n", " M.mu = pyo.Var(M.i, domain = pyo.NonNegativeReals)\n", " M.t = pyo.Var(M.k, domain = pyo.NonNegativeReals)\n", " \n", " M.obj = pyo.Objective(expr = sum(M.t[k] for k in M.k)/k, sense= pyo.minimize)\n", " \n", " #Define constraint rules\n", " def c3_rule(model, k, j): #Rule for Constraint (3)\n", " return sum(model.pi[k,i,j] for i in model.i) == v[k][j]\n", " def c2_rule(model, k, i): #Rule for Constraint (2)\n", " return sum(model.pi[k,i,j] for j in model.j) == model.mu[i]\n", " def c1_rule(model, k): #Rule for Constraint (1)\n", " return sum(d[i][j]*model.pi[k,i,j] for i in model.i for j in model.j) <= model.t[k]\n", " \n", " # Add Constraints\n", " M.c3 = pyo.Constraint(M.k, M.j , rule = c3_rule)\n", " M.c2 = pyo.Constraint(M.k, M.i , rule = c2_rule)\n", " M.c1 = pyo.Constraint(M.k, rule = c1_rule)\n", " \n", " return M\n", " \n", " def run(self,data):\n", " start_time = time.time()\n", " model = self.Wasserstein_BaryCenter(data)\n", " opt = pyo.SolverFactory(self._solver)\n", " self.result = opt.solve(model, tee = True)\n", " self.time = time.time() - start_time\n", " bc = []\n", " [bc.append(model.mu[i]()) for i in range(data.shape[1]*data.shape[2])]\n", " return np.array(bc)\n", " \n", " def reset(self):\n", " self.M = pyo.ConcreteModel()\n", " " ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 31380 \n", " Cones : 0 \n", " Scalar variables : 12293924 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in 
eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 0.39 \n", "Lin. dep. - number : 19 \n", "Presolve terminated. Time: 9.03 \n", "GP based matrix reordering started.\n", "GP based matrix reordering terminated.\n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 31380 \n", " Cones : 0 \n", " Scalar variables : 12293924 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 24 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 17440\n", "Optimizer - Cones : 0\n", "Optimizer - Scalar variables : 1395520 conic : 0 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 16.13 dense det. time : 0.22 \n", "Factor - ML order time : 0.04 GP order time : 14.46 \n", "Factor - nonzeros before factor : 1.55e+06 after factor : 1.44e+07 \n", "Factor - dense dim. : 0 flops : 1.64e+10 \n", "Factor - GP saved nzs : 1.75e+06 GP saved flops : 3.15e+09 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 5.5e+03 3.2e+02 8.3e+07 0.00e+00 1.182613040e+07 0.000000000e+00 4.9e+01 26.72 \n", "1 6.8e-01 4.0e-02 1.0e+04 -1.00e+00 1.120231111e+07 -1.638527136e+05 6.1e-03 27.80 \n", "2 5.8e-02 3.4e-03 8.8e+02 2.61e+01 1.558624799e+04 -2.985004062e+03 5.2e-04 29.37 \n", "3 4.5e-02 2.6e-03 6.8e+02 1.08e+01 3.256531589e+03 -6.820367758e+02 4.0e-04 30.39 \n", "4 4.2e-02 2.5e-03 6.4e+02 5.25e+00 2.407863322e+03 -4.995626462e+02 3.8e-04 31.33 \n", "5 3.8e-02 2.2e-03 5.7e+02 4.40e+00 1.571023931e+03 -3.171235815e+02 3.4e-04 32.18 \n", "6 3.3e-02 1.9e-03 5.0e+02 3.48e+00 1.104556065e+03 -2.125212675e+02 3.0e-04 33.06 \n", "7 2.9e-02 1.7e-03 4.3e+02 2.91e+00 7.928892926e+02 -1.415543160e+02 2.6e-04 34.05 \n", "8 1.2e-02 7.1e-04 1.8e+02 2.47e+00 2.252273003e+02 -1.976505585e+01 1.1e-04 35.13 \n", "9 5.2e-03 3.1e-04 7.9e+01 1.44e+00 9.137867985e+01 -2.802280352e+00 
4.7e-05 36.09 \n", "10 1.2e-03 6.9e-05 1.8e+01 1.16e+00 2.223918431e+01 2.047776659e+00 1.1e-05 37.94 \n", "11 7.7e-04 4.5e-05 1.2e+01 1.04e+00 1.552926266e+01 2.419700435e+00 6.9e-06 38.99 \n", "12 3.9e-04 2.3e-05 5.9e+00 1.02e+00 9.413170641e+00 2.768991242e+00 3.5e-06 39.96 \n", "13 1.3e-04 8.1e-06 2.0e+00 1.01e+00 5.274879516e+00 3.000437008e+00 1.2e-06 41.68 \n", "14 5.5e-05 3.3e-06 8.2e-01 1.00e+00 3.991039163e+00 3.070294699e+00 4.9e-07 42.90 \n", "15 2.5e-05 1.5e-06 3.8e-01 1.00e+00 3.515553240e+00 3.095347848e+00 2.2e-07 44.03 \n", "16 1.2e-05 7.0e-07 1.8e-01 1.00e+00 3.304324846e+00 3.106260526e+00 1.1e-07 45.22 \n", "17 7.1e-06 4.3e-07 1.1e-01 1.00e+00 3.229734360e+00 3.109864486e+00 6.4e-08 46.19 \n", "18 3.1e-06 2.4e-07 4.8e-02 1.00e+00 3.165698228e+00 3.112133278e+00 2.8e-08 47.16 \n", "19 1.2e-06 9.4e-08 1.9e-02 1.00e+00 3.135045322e+00 3.113924772e+00 1.1e-08 48.29 \n", "20 2.5e-07 1.9e-08 3.9e-03 1.00e+00 3.119048189e+00 3.114732226e+00 2.3e-09 49.29 \n", "21 1.2e-09 9.1e-11 1.8e-05 1.00e+00 3.114878911e+00 3.114858640e+00 1.1e-11 50.64 \n", "22 3.6e-12 2.5e-13 4.7e-08 1.00e+00 3.114858824e+00 3.114858771e+00 2.8e-14 51.46 \n", "23 4.1e-13 7.1e-14 2.0e-12 1.00e+00 3.114858772e+00 3.114858772e+00 2.9e-18 52.27 \n", "Basis identification started.\n", "Basis identification terminated. Time: 0.42\n", "Optimizer terminated. Time: 58.87 \n", "\n", "\n", "Interior-point solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 6e+00 Viol. con: 7e-13 var: 0e+00 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. con: 0e+00 var: 4e-14 \n", "\n", "Basic solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 6e+00 Viol. con: 4e-17 var: 1e-06 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. 
con: 0e+00 var: 2e-09 \n", "\n", "Time spent to solve problem with Pyomo: \n", " 900.9914240837097\n", "Time spent in solver: \n", " 58.86694002151489\n", "The average Wasserstein distance between digits and the barycenter: \n", " 3.114858771795237\n" ] } ], "source": [ "pyomo_model = Wasserstein_Pyomo()\n", "p_bc = pyomo_model.run(train_1)\n", "print('\\nTime spent to solve problem with Pyomo: \\n {0}'.format(pyomo_model.time))\n", "print('Time spent in solver: \\n {0}'.format(pyomo_model.result.solver.wallclock_time))\n", "print('The average Wasserstein distance between digits and the barycenter: \\n {0}'.format(pyomo_model.M.obj()))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEKZJREFUeJzt3X2MHPV9x/H3x4efYhuweTCHMTYYt8GAgORqqoIaIxIekiI7anlwEzASlaMS2iLxRxGVglOVCkUhBEU0kVMsTBMglEBwWkQgLgpNCq7PxAGDaQHX2MYXH8YQP4DP9vnbP3aMFvv2d8fd7s6ef5+XtLrZ+e7MfLXyxzO7M7M/RQRmlp8RZTdgZuVw+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHP4WJ2m9pA8k7ZT0rqR/lzS17L4GStIiST8ouw87lMM/PFweEeOBdmAL8J2PuwJJR9S9qyYYrn0PBw7/MBIRu4FHgFkAkr4g6deStkvaKGnRgddKmi4pJF0vaQPwH8VRw19Vr1PSi5LmFdNnSHpa0jZJWyTdWswfIekWSW9IekfSw5ImHbSdBZI2SNoq6e+K2qXArcBVxZHLb4r5R0m6V1KXpLck/YOktqJ2naRfSbpL0jZgEdYQDv8wIukTwFXA88WsXcC1wNHAF4C/PBDkKp8BTgcuAZYCX65a39nAFOAJSROAnwNPAicCpwHLi5f+NTCvWNeJwLvAPQdt5wLg94GLgK9JOj0ingT+EfhRRIyPiLOL1y4F9hXbOBe4GPiLqnWdB6wDjgduH+DbYx9XRPjRwg9gPbATeI9KYDYDZ9V47beBu4rp6UAAp1bVRwPbgJnF828C/1RMzwd+XWO9a4GLqp63A3uBI6q2c1JV/b+Bq4vpRcAPqmqTgR5gbNW8+cAzxfR1wIay3/ccHv48NTzMi4ifF4fGc4FfSJoFTAPuAM4ERlEJ978etOzGAxMR0SPpYeDLkr5OJXR/VpSnAm/U2P404DFJ+6vm9VIJ8gG/rZp+HxifWNdIoEvSgXkjqvs8aNoaxIf9w0hE9EbEo1SCdwHwALAMmBoRRwHfA3TwYgc9Xwp8icrh+fsR8VwxfyMwo8amNwKXRcTRVY8xEfHWQNruY109wLFV6zoyIs5ILGMN4PAPI6qYC0ykcig+AdgWEbslzQb+vL91FGHfD9wJ/EtV6d+AEyTdJGm0pAmSzit
q3wNulzSt6OO4oo+B2AJMlzSi2H4X8BRwp6Qjiy8TZ0j6zADXZ3Xi8A8PP5W0E9hO5QuwBRHxMnAD8PeSdgBfAx4e4PruB84CPjz/HhE7gM8Bl1M5hH8NuLAo303lCOOpYlvPU/lSbiAOfAx5R9ILxfS1VD6mvELly8NHqHyPYE2k4ksWy4ika4GFEXFB2b1Yebznz0xxuvAGYHHZvVi5HP6MSLoEeJvK5/AHSm7HSubDfrNMec9vlqmmXuQzSqNjDOOauUmzrOxmF3ui5+BrPfo0pPAXN27cDbQB/xwRd6ReP4ZxnKeLhrJJM0tYEcv7f1Fh0If9xaWm9wCXUbnLbH5xyamZDQND+cw/G3g9ItZFxB7gISrXnZvZMDCU8E/hozdgbCrmfYSkhZI6JXXupWcImzOzehpK+Pv6UuGQ84YRsTgiOiKiYySjh7A5M6unoYR/E5XbQA84icq95mY2DAwl/CuBmZJOkTQKuJrKzR9mNgwM+lRfROyTdCPwMyqn+pYUd5qZ2TAwpPP8EfEE8ESdejGzJvLlvWaZcvjNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlqmmDtFth5+um/8oWf9g8iGDOH3o9+7ZlFx235sbk3UbGu/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM+Ty/pc0+K1meM39lsv7TlefWrK39+vHJZWde5/P8jTSk8EtaD+wAeoF9EdFRj6bMrPHqsee/MCK21mE9ZtZE/sxvlqmhhj+ApyStkrSwrxdIWiipU1LnXnqGuDkzq5ehHvafHxGbJR0PPC3p1Yh4tvoFEbEYWAxwpCbVvsvDzJpqSHv+iNhc/O0GHgNm16MpM2u8QYdf0jhJEw5MAxcDa+rVmJk11lAO+ycDj0k6sJ4HIuLJunRlzTOiLVle96fjk/WNP/mDZH3sOdtr1no+GJlc1hpr0OGPiHXA2XXsxcyayKf6zDLl8JtlyuE3y5TDb5Yph98sU76lN3Ntp01P1qOf3UPvWTuT9U+3v1WztvJXn0yv3BrKe36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFM+z5+77vRvr46ZMSZZ/+RxW5L1d3aPq1kb/Y6Sy1pjec9vlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2XK5/kzt+fcGcn6p9pfT9b37U//9Pem946qWds/OrmoNZj3/GaZcvjNMuXwm2XK4TfLlMNvlimH3yxTDr9ZpnyeP3Nbz0rfr3/l0enz/E9tnZWs799fe/9y0i8+SC5rjdXvnl/SEkndktZUzZsk6WlJrxV/Jza2TTOrt4Ec9t8HXHrQvFuA5RExE1hePDezYaTf8EfEs8C2g2bPBZYW00uBeXXuy8wabLBf+E2OiC6A4u/xtV4oaaGkTkmde+kZ5ObMrN4a/m1/RCyOiI6I6BiJ7+QwaxWDDf8WSe0Axd/u+rVkZs0w2PAvAxYU0wuAx+vTjpk1S7/n+SU9CMwBjpW0CbgNuAN4WNL1wAbgikY2aYN3xLSpyfruY9LLf+fVOcn6/BmrkvXVK06rWWvbtTO5bCSrNlT9hj8i5tcoXVTnXsysiXx5r1mmHH6zTDn8Zply+M0y5fCbZcq39B7m9k+oPUQ2QM/UPcl6+/hdyfpr79e8shuAY1cniqtfTS5rjeU9v1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKZ/nP8z1nJA+z982Zl+yvqc3PQT3b7pPTNZ3nqGataPb0utmX7o3Gxrv+c0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTPk8/2FuTFf657GvnvVKsv7c1lOS9c2bJiXrxyRG+B4xNj08eG+Ph3drJO/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM+Tz/YW77nXuT9UeWXZCsX/4nzyfrv3t
oSrI+ord2rfe93yWXtcbqd88vaYmkbklrquYtkvSWpNXF4/ONbdPM6m0gh/33AZf2Mf+uiDineDxR37bMrNH6DX9EPAtsa0IvZtZEQ/nC70ZJLxYfCybWepGkhZI6JXXuxddqm7WKwYb/u8AM4BygC7iz1gsjYnFEdEREx0hGD3JzZlZvgwp/RGyJiN6I2A98H5hd37bMrNEGFX5J7VVPvwisqfVaM2tN/Z7nl/QgMAc4VtIm4DZgjqRzgADWA19pYI/Wj7aZp9asbV1xQnLZT5z7TrL+yKqOZP3sa9Yl69vumpasW3n6DX9EzO9j9r0N6MXMmsiX95plyuE3y5TDb5Yph98sUw6/WaZ8S+9hQDt21az1tKdv6R0btYfQBhg78YNk/eUVtU8zApz629q9Wbm85zfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXz/IeBXZ8+uWZtysnpW3Y3v3Fcsj7x5HeT9bbu9HUCR3Rvr1nbl1zSGs17frNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUz7PfxjYfXRbzVrHMW8ll+1+d0Ky3j5hR3r5rcck65H4rQErl/f8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmBjJE91TgfuAEYD+wOCLuljQJ+BEwncow3VdGRPrmbxuUtuPS99xvubD2nfHPbDgtueyXZq1M1u977oJkfcabPcl679atybqVZyB7/n3AzRFxOvCHwFclzQJuAZZHxExgefHczIaJfsMfEV0R8UIxvQNYC0wB5gJLi5ctBeY1qkkzq7+P9Zlf0nTgXGAFMDkiuqDyHwRwfL2bM7PGGXD4JY0HfgzcFBG1f5jt0OUWSuqU1LmX9OdDM2ueAYVf0kgqwf9hRDxazN4iqb2otwPdfS0bEYsjoiMiOkYyuh49m1kd9Bt+SQLuBdZGxLeqSsuABcX0AuDx+rdnZo0ykFt6zweuAV6StLqYdytwB/CwpOuBDcAVjWnRNHpUsj592ts1a6NG9CaX/a+t6SG2x/1f+p/IqFWvJuu9Ecm6laff8EfEL4FaP85+UX3bMbNm8RV+Zply+M0y5fCbZcrhN8uUw2+WKYffLFP+6e5hoGfm5GT9qik/q1n7RuclyWXHjktfcn3khv3Jeu/2AV/pbS3Ge36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFM+zz8MjNr0XrL+5Ntn1qyt++yS5LJn3n1Dsn70S9uS9fRVANbKvOc3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTKlaOLvqh+pSXGe/GvfZo2yIpazPbbV+qn9j/Ce3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2XK4TfLVL/hlzRV0jOS1kp6WdLfFPMXSXpL0uri8fnGt2tm9TKQH/PYB9wcES9ImgCskvR0UbsrIr7ZuPbMrFH6DX9EdAFdxfQOSWuBKY1uzMwa62N95pc0HTgXWFHMulHSi5KWSJpYY5mFkjolde4lPTSUmTXPgMMvaTzwY+CmiNgOfBeYAZxD5cjgzr6Wi4jFEdERER0jGV2Hls2sHgYUfkkjqQT/hxHxKEBEbImI3ojYD3wfmN24Ns2s3gbybb+Ae4G1EfGtqvntVS/7IrCm/u2ZWaMM5Nv+84FrgJckrS7m3QrMl3QOEMB64CsN6dDMGmIg3/b/Eujr/uAn6t+OmTWLr/Azy5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmWrqEN2S3gberJp1LLC1aQ18PK3aW6v2Be5tsOrZ27SIOG4gL2xq+A/ZuNQZER2lNZDQqr21al/g3garrN582G+WKYffLFNlh39xydtPadXeWrUvcG+DVUpvpX7mN7PylL3nN7OSOPxmmSol/JIulfQ/kl6XdEsZPdQiab2kl4phxztL7mWJpG5Ja6rmTZL0tKTXir99jpFYUm8tMWx7Ylj5Ut+7Vhv
uvumf+SW1Af8LfA7YBKwE5kfEK01tpAZJ64GOiCj9ghBJfwzsBO6PiDOLed8AtkXEHcV/nBMj4m9bpLdFwM6yh20vRpNqrx5WHpgHXEeJ712irysp4X0rY88/G3g9ItZFxB7gIWBuCX20vIh4Fth20Oy5wNJieimVfzxNV6O3lhARXRHxQjG9AzgwrHyp712ir1KUEf4pwMaq55so8Q3oQwBPSVolaWHZzfRhckR0QeUfE3B8yf0crN9h25vpoGHlW+a9G8xw9/VWRvj7Gvqrlc43nh8RnwIuA75aHN7awAxo2PZm6WNY+ZYw2OHu662M8G8CplY9PwnYXEIffYqIzcXfbuAxWm/o8S0HRkgu/naX3M+HWmnY9r6GlacF3rtWGu6+jPCvBGZKOkXSKOBqYFkJfRxC0rjiixgkjQMupvWGHl8GLCimFwCPl9jLR7TKsO21hpWn5Peu1Ya7L+UKv+JUxreBNmBJRNze9Cb6IOlUKnt7qIxg/ECZvUl6EJhD5ZbPLcBtwE+Ah4GTgQ3AFRHR9C/eavQ2h8qh64fDth/4jN3k3i4A/hN4CdhfzL6Vyufr0t67RF/zKeF98+W9ZpnyFX5mmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/Wab+H4aeLdESTBCYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pyo_bc = np.reshape(p_bc, (28,28))\n", "#print('Visualization of the barycenter:')\n", "plt.imshow(pyo_bc)\n", "plt.title('Barycenter')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modeling the same problem with CVXPY" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import cvxpy as cp\n", "import time\n", "class Wasserstein_CVXPY:\n", " \n", " def __init__(self):\n", " self.time = 0.0\n", " self.result = None\n", " self.prob = None\n", "\n", " def single_pmf(self, data = None, img=False):\n", " \n", " ''' Takes a image or array of images and extracts the probabilty mass function'''\n", " \n", " if not img:\n", " v=[]\n", " for image in data:\n", " arr = np.asarray(image).ravel(order='K')\n", " v.append(arr/np.sum(arr))\n", " else:\n", " v = np.asarray(data).ravel(order='K')\n", " v = v/np.sum(v)\n", " return v\n", " \n", " def ms_distance(self, m ,n, constant=False):\n", " \n", " ''' Squared Euclidean distance calculation between the pixels '''\n", " \n", " if constant:\n", " d = np.ones((m,m))\n", " else:\n", " d = np.empty((m,m))\n", " coor = []\n", " for i in range(n):\n", " for j in range(n):\n", " coor.append(np.array([i,j]))\n", " for i in range(m):\n", " for j in range(m):\n", " d[i][j] = np.linalg.norm(coor[i]-coor[j])**2\n", " return d\n", " \n", " def Wasserstein_Distance(self, bc ,data, img = False):\n", " \n", " ''' Calculation of wasserstein distance between a barycenter and an image by solving \n", " the minimization problem '''\n", " \n", " v = np.array(self.single_pmf(data, img))\n", " n = v.shape[0]\n", " d = self.ms_distance(n,data.shape[1])\n", " \n", " pi = cp.Variable((n,n), nonneg=True)\n", " obj = cp.Minimize((np.ones(n).T @ cp.multiply(d,pi) @ np.ones(n)))\n", " \n", " Cons=[]\n", " Cons.append((np.ones(n) @ pi).T == bc)\n", " Cons.append((pi @ np.ones(n)) == v)\n", " \n", " prob = 
cp.Problem(obj, constraints= Cons)\n", " \n", " return prob.solve(solver=cp.MOSEK, verbose = True)\n", " \n", " def Wasserstein_BaryCenter(self,data):\n", " \n", " ''' Calculation of wasserstein barycenter of given images by solving the minimization problem '''\n", " \n", " start_time = time.time()\n", " k = data.shape[0]\n", " v = np.array(self.single_pmf(data))\n", " n = v.shape[1]\n", " d = self.ms_distance(n,data.shape[1])\n", " \n", " #Add variables\n", " pi= []\n", " t= []\n", " mu = cp.Variable(n, nonneg = True)\n", " for i in range(k):\n", " pi.append(cp.Variable((n,n), nonneg = True))\n", " t.append(cp.Variable(nonneg = True))\n", " \n", " obj = cp.Minimize(np.sum(t)/k)\n", " \n", " #Add constraints\n", " Cons=[]\n", " for i in range(k):\n", " Cons.append( t[i] >= np.ones(n).T @ cp.multiply(d,pi[i]) @ np.ones(n) ) #Constraint (1)\n", " Cons.append( (np.ones(n) @ pi[i]).T == mu) #Constraint (2)\n", " Cons.append( (pi[i] @ np.ones(n)) == v[i]) #Constraint (3)\n", " \n", " self.prob = cp.Problem(obj, constraints= Cons)\n", " self.result = self.prob.solve(solver=cp.MOSEK,verbose = True)\n", " self.time = time.time() - start_time\n", " \n", " return mu.value\n", " \n", " def reset(self):\n", " self.prob = None\n", " self.result = None\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 12325304 \n", " Cones : 0 \n", " Scalar variables : 12293924 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 2.55 \n", "Lin. dep. 
- number : 19 \n", "Presolve terminated. Time: 15.54 \n", "GP based matrix reordering started.\n", "GP based matrix reordering terminated.\n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 12325304 \n", " Cones : 0 \n", " Scalar variables : 12293924 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 24 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 17440\n", "Optimizer - Cones : 0\n", "Optimizer - Scalar variables : 1395520 conic : 0 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 16.43 dense det. time : 0.21 \n", "Factor - ML order time : 0.04 GP order time : 14.57 \n", "Factor - nonzeros before factor : 1.55e+06 after factor : 1.44e+07 \n", "Factor - dense dim. : 0 flops : 1.64e+10 \n", "Factor - GP saved nzs : 1.75e+06 GP saved flops : 3.15e+09 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 5.5e+03 3.2e+02 8.3e+07 0.00e+00 1.182613040e+07 0.000000000e+00 4.9e+01 34.19 \n", "1 6.8e-01 4.0e-02 1.0e+04 -1.00e+00 1.120231112e+07 -1.638527133e+05 6.1e-03 35.19 \n", "2 5.8e-02 3.4e-03 8.8e+02 2.61e+01 1.558624808e+04 -2.985004070e+03 5.2e-04 36.64 \n", "3 4.5e-02 2.6e-03 6.8e+02 1.08e+01 3.256531589e+03 -6.820367740e+02 4.0e-04 37.60 \n", "4 4.2e-02 2.5e-03 6.4e+02 5.25e+00 2.407863324e+03 -4.995626454e+02 3.8e-04 38.57 \n", "5 3.8e-02 2.2e-03 5.7e+02 4.40e+00 1.571023937e+03 -3.171235819e+02 3.4e-04 39.50 \n", "6 3.3e-02 1.9e-03 5.0e+02 3.48e+00 1.104556071e+03 -2.125212684e+02 3.0e-04 40.49 \n", "7 2.9e-02 1.7e-03 4.3e+02 2.91e+00 7.928892985e+02 -1.415543170e+02 2.6e-04 41.34 \n", "8 1.2e-02 7.1e-04 1.8e+02 2.47e+00 2.252273014e+02 -1.976505604e+01 1.1e-04 42.35 \n", "9 5.2e-03 3.1e-04 7.9e+01 1.44e+00 9.137868017e+01 -2.802280320e+00 4.7e-05 43.28 \n", "10 1.2e-03 6.9e-05 1.8e+01 1.16e+00 2.223918434e+01 2.047776637e+00 1.1e-05 44.86 \n", "11 7.7e-04 4.5e-05 1.2e+01 
1.04e+00 1.552926278e+01 2.419700413e+00 6.9e-06 45.89 \n", "12 3.9e-04 2.3e-05 5.9e+00 1.02e+00 9.413192190e+00 2.768989925e+00 3.5e-06 46.88 \n", "13 1.3e-04 8.1e-06 2.0e+00 1.01e+00 5.274875679e+00 3.000437376e+00 1.2e-06 48.55 \n", "14 5.5e-05 3.3e-06 8.2e-01 1.00e+00 3.991040086e+00 3.070294698e+00 4.9e-07 49.88 \n", "15 2.5e-05 1.5e-06 3.8e-01 1.00e+00 3.515554299e+00 3.095347811e+00 2.2e-07 50.96 \n", "16 1.2e-05 7.0e-07 1.8e-01 1.00e+00 3.304325086e+00 3.106260520e+00 1.1e-07 52.26 \n", "17 7.1e-06 4.3e-07 1.1e-01 1.00e+00 3.229734970e+00 3.109864458e+00 6.4e-08 53.26 \n", "18 3.1e-06 2.4e-07 4.8e-02 1.00e+00 3.165698749e+00 3.112133276e+00 2.8e-08 54.22 \n", "19 1.2e-06 9.4e-08 1.9e-02 1.00e+00 3.135045328e+00 3.113924783e+00 1.1e-08 55.44 \n", "20 2.5e-07 1.9e-08 3.9e-03 1.00e+00 3.119048079e+00 3.114732233e+00 2.3e-09 56.55 \n", "21 1.2e-09 9.1e-11 1.8e-05 1.00e+00 3.114878914e+00 3.114858640e+00 1.1e-11 57.87 \n", "22 3.5e-12 2.5e-13 4.7e-08 1.00e+00 3.114858824e+00 3.114858771e+00 2.8e-14 58.71 \n", "23 4.6e-13 5.7e-14 2.8e-14 1.00e+00 3.114858772e+00 3.114858772e+00 2.9e-18 59.58 \n", "Basis identification started.\n", "Basis identification terminated. Time: 0.40\n", "Optimizer terminated. Time: 68.77 \n", "\n", "\n", "Interior-point solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 6e+00 Viol. con: 3e-12 var: 0e+00 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. con: 9e-16 var: 6e-14 \n", "\n", "Basic solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1148587718e+00 nrm: 6e+00 Viol. con: 1e-06 var: 0e+00 \n", " Dual. obj: 3.1148587718e+00 nrm: 2e+02 Viol. 
con: 2e-09 var: 2e-09 \n", "\n", "Time Spent to solve problem with CVXPY: \n", " 122.6916720867157\n", "Time Spent in solver: \n", " 68.76518487930298\n", "The average Wasserstein distance between digits and the barycenter: \n", " 3.1148587717983913\n" ] } ], "source": [ "cvxpy_model = Wasserstein_CVXPY()\n", "result = cvxpy_model.Wasserstein_BaryCenter(train_1)\n", "print('\\nTime Spent to solve problem with CVXPY: \\n {0}'.format(cvxpy_model.time))\n", "print('Time Spent in solver: \\n {0}'.format(cvxpy_model.prob.solver_stats.solve_time))\n", "print('The average Wasserstein distance between digits and the barycenter: \\n {0}'.format(cvxpy_model.result))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEKNJREFUeJzt3X2wVPV9x/H3BwSMgCJREREBFY1PFTO32ESnkjgxmjQDcWoiSRQztmRqbOtM/qi1MxE7Teu0McbpWB1SmWASNTTRSltjNJgJTccHrsQokSQgQwC5ggQVULk83G//2IOzXu/+7nLv7p69/D6vmZ179nz37PnODh/O2fOwP0UEZpafYWU3YGblcPjNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphz+NidpvaS3Je2S9Jqk/5E0uey+6iVpgaTvlt2HvZfDPzR8KiLGABOBLcC/HuwbSDqs4V21wFDteyhw+IeQiNgN/AA4E0DSJyX9QtIOSRslLTjwWklTJYWkayVtAJ4o9hr+svo9JT0vaU4xfZakxyVtl7RF0k3F/GGSbpT0kqTfS1oiaXyv9cyTtEHSNkl/V9QuBW4CPlvsufyymH+UpHskdUl6WdI/SBpe1K6R9H+Sbpe0HViANYXDP4RIOgL4LPBUMetN4GpgHPBJ4C8OBLnKRcAZwMeBxcAXqt7vXGAS8IikscBPgEeBE4BTgWXFS/8KmFO81wnAa8CdvdZzIXA6cDHwVUlnRMSjwD8C34+IMRFxbvHaxcC+Yh3nAZcAf1b1XucD64DjgK/V+fHYwYoIP9r4AawHdgGvUwnMZuCcGq/9JnB7MT0VCODkqvooYDswvXj+deDfium5wC9qvO9q4OKq5xOBvcBhVes5sar+DHBlMb0A+G5VbQLQDbyvat5c4KfF9DXAhrI/9xwe/j41NMyJiJ8Uu8azgZ9JOhOYAtwKnA2MpBLu/+i17MYDExHRLWkJ8AVJt1AJ3Z8W5cnASzXWPwV4SFJP1bz9VIJ8wCtV028BYxLvNQLoknRg3rDqPntNW5N4t38IiYj9EfEgleBdCNwHLAUmR8RRwN2Aei/W6/li4PNUds/fiogni/kbgVNqrHojcFlEjKt6HB4RL9fTdh/v1Q0cU/VeR0bEWYllrAkc/iFEFbOBo6nsio8FtkfEbkkzgc/19x5F2HuA24DvVJ
X+Gzhe0g2SRkkaK+n8onY38DVJU4o+ji36qMcWYKqkYcX6u4DHgNskHVkcTDxF0kV1vp81iMM/NPyXpF3ADioHwOZFxK+A64C/l7QT+CqwpM73uxc4B3jn/HtE7AQ+BnyKyi78GuAjRfkOKnsYjxXreorKQbl6HPga8ntJK4vpq6l8TXmRysHDH1A5jmAtpOIgi2VE0tXA/Ii4sOxerDze8memOF14HbCw7F6sXA5/RiR9HHiVyvfw+0pux0rm3X6zTHnLb5apll7kM1Kj4nBGt3KVZlnZzZvsie7e13r0aVDhL27cuAMYDvx7RNyaev3hjOZ8XTyYVZpZwtOxrP8XFQa8219canoncBmVu8zmFpecmtkQMJjv/DOBtRGxLiL2AA9Que7czIaAwYR/Eu++AWNTMe9dJM2X1Cmpcy/dg1idmTXSYMLf10GF95w3jIiFEdERER0jGDWI1ZlZIw0m/Juo3AZ6wIlU7jU3syFgMOFfAUyXNE3SSOBKKjd/mNkQMOBTfRGxT9L1wI+pnOpbVNxpZmZDwKDO80fEI8AjDerFzFrIl/eaZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmWjpEtx16Nv3th5P17vHvGcTpHaff1ZVcdt+69QNpyerkLb9Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlimf57ek+NC5yfpFl69M1n+08pyatV/fcnRy2VOvWp+s2+AMKvyS1gM7gf3AvojoaERTZtZ8jdjyfyQitjXgfcyshfyd3yxTgw1/AI9JelbS/L5eIGm+pE5JnXvpHuTqzKxRBrvbf0FEbJZ0HPC4pF9HxPLqF0TEQmAhwJFK3OVhZi01qC1/RGwu/m4FHgJmNqIpM2u+AYdf0mhJYw9MA5cAqxrVmJk112B2+ycAD0k68D73RcSjDenKWmfY8GR53eVHJOsbfnResj7mnDdq1na/PTK5rDXXgMMfEeuA9BUgZta2fKrPLFMOv1mmHH6zTDn8Zply+M0y5Vt6Mzd8+rRkvWdk+qLMYdN2JesXnfhSzdqPn/hgcllrLm/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM+Tx/7rZtT5aPmpr+JzJl3GvJ+oa3av8898jXlVzWmstbfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUz7Pn7k9fzA1Wf/QCS8m69370/+EOl+ZXLPW41/uLpW3/GaZcvjNMuXwm2XK4TfLlMNvlimH3yxTDr9ZpnyeP3Nbzzs8Wb967Npk/bHtZyfrPVH7nv1Jy3cnl7Xm6nfLL2mRpK2SVlXNGy/pcUlrir+1f7HBzNpSPbv93wYu7TXvRmBZREwHlhXPzWwI6Tf8EbEc6P1bT7OBxcX0YmBOg/sysyYb6AG/CRHRBVD8Pa7WCyXNl9QpqXMv3QNcnZk1WtOP9kfEwojoiIiOEYxq9urMrE4DDf8WSRMBir9bG9eSmbXCQMO/FJhXTM8DHm5MO2bWKv2e55d0PzALOEbSJuBm4FZgiaRrgQ3AFc1s0gbusKknJet7j0wvf8eajybrV538TLL+1JMfqFk7bNebyWUjWbXB6jf8ETG3RuniBvdiZi3ky3vNMuXwm2XK4TfLlMNvlimH3yxTvqX3ENcz5ohkvXta+rbaaWN2Jeu/fev4ZP39v6x9S69eWJNc1qf6mstbfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUz7Pf4jrPmFMsj7i8D3J+q496V9fWrElfcvw62fVrh29b19yWWsub/nNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0z5PP8hbuRr6SHSPveBZ5P15a+emqxv2pIeoHl8YoTvYWNGJ5fd//obyboNjrf8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmfJ7/EL
f3n9Lnyh94aFay/vnLn0jW33hgUrKunto1n8cvV79bfkmLJG2VtKpq3gJJL0t6rnh8orltmlmj1bPb/23g0j7m3x4RM4rHI41ty8yard/wR8RyYHsLejGzFhrMAb/rJT1ffC2oeYG3pPmSOiV17iV9nbmZtc5Aw38XcAowA+gCbqv1wohYGBEdEdExgvSPQZpZ6wwo/BGxJSL2R0QP8C1gZmPbMrNmG1D4JU2sevppYFWt15pZe+r3PL+k+4FZwDGSNgE3A7MkzaAyhPp64EtN7NH6Mfz02vfcb1h+XHLZ93/4lWT9nmcvSNbP/+JvkvXN/5L+PQArT7/hj4i5fcy+pwm9mFkL+fJes0w5/GaZcvjNMuXwm2XK4TfLlG/pPQTojZ01a7unjEsuu2f/8GR9zLi3k/UVz5yWrJ/6ylvJupXHW36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFM+z38I2PWHU2rWTp3alVx27ZqJyfqEk9I/39izNb39GL6t9jUI+5NLWrN5y2+WKYffLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrn+Q8Bu8fVvif/3KPSP829cXz6fv8TxqSH0d60+dhknTd2petWGm/5zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNM1TNE92TgXuB4oAdYGBF3SBoPfB+YSmWY7s9ExGvNazVfwyekh9ne+tE9NWs/25QeIvuLZzyVrN/95Kxk/eRNtdcNsP/VV5N1K089W/59wFci4gzgj4AvSzoTuBFYFhHTgWXFczMbIvoNf0R0RcTKYnonsBqYBMwGFhcvWwzMaVaTZtZ4B/WdX9JU4DzgaWBCRHRB5T8IIL1vamZtpe7wSxoD/BC4ISJ2HMRy8yV1SurcS/dAejSzJqgr/JJGUAn+9yLiwWL2FkkTi/pEYGtfy0bEwojoiIiOEYxqRM9m1gD9hl+SgHuA1RHxjarSUmBeMT0PeLjx7ZlZs9RzS+8FwFXAC5KeK+bdBNwKLJF0LbABuKI5LZpGjkzWTztpS83aqMP2JZd9cvvJyfr7NoxI1ketWJWs++e521e/4Y+InwOqUb64se2YWav4Cj+zTDn8Zply+M0y5fCbZcrhN8uUw2+WKf909xCw+7QJyfqfn/hgzdotz/9JctmI9LrHr+1J1vfvqPtKb2sz3vKbZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zpnyef4hYNTvtifrS7fNqFlbfcF3ksueeed1yfpRq9LrTl8FYO3MW36zTDn8Zply+M0y5fCbZcrhN8uUw2+WKYffLFOK/m7obqAjNT7Ol3/t26xZno5l7IjttX5q/1285TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMtVv+CVNlvRTSasl/UrSXxfzF0h6WdJzxeMTzW/XzBqlnh/z2Ad8JSJWShoLPCvp8aJ2e0R8vXntmVmz9Bv+iOgCuorpnZJWA5Oa3ZiZNddBfeeXNBU4D3i6mHW9pOclLZJ0dI1l5kvqlNS5l+5BNWtmjVN3+CWNAX4I3BARO4C7gFOAGVT2DG7ra7mIWBgRHRHRMYJRDWjZzBqhrvBLGkEl+N+LiAcBImJLROyPiB7gW8DM5rVpZo1Wz9F+AfcAqyPiG1XzJ1a97NPAqsa3Z2bNUs/R/guAq4AXJD1XzLsJmCtpBhDAeuBLTenQzJqinqP9Pwf6uj/4kca3Y2at4iv8zDLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaZaOkS3pFeB31XNOgbY1rIGDk679taufYF7G6hG9jYlIo6t54UtDf97Vi51RkRHaQ0ktGtv7doXuLeBKqs37/abZcrhN8tU2eFfWPL6U9q1t3btC9zbQJXSW6nf+c2sPGVv+c2sJA6/WaZKCb+kSyX9RtJaSTeW0UMtktZLeqEYdryz5F4WSdoqaVXVvPGSHpe0pvjb5xiJJf
XWFsO2J4aVL/Wza7fh7lv+nV/ScOC3wMeATcAKYG5EvNjSRmqQtB7oiIjSLwiR9MfALuDeiDi7mPfPwPaIuLX4j/PoiPibNultAbCr7GHbi9GkJlYPKw/MAa6hxM8u0ddnKOFzK2PLPxNYGxHrImIP8AAwu4Q+2l5ELAe295o9G1hcTC+m8o+n5Wr01hYioisiVhbTO4EDw8qX+tkl+ipFGeGfBGyser6JEj+APgTwmKRnJc0vu5k+TIiILqj8YwKOK7mf3vodtr2Veg0r3zaf3UCGu2+0MsLf19Bf7XS+8YKI+CBwGfDlYvfW6lPXsO2t0sew8m1hoMPdN1oZ4d8ETK56fiKwuYQ++hQRm4u/W4GHaL+hx7ccGCG5+Lu15H7e0U7Dtvc1rDxt8Nm103D3ZYR/BTBd0jRJI4ErgaUl9PEekkYXB2KQNBq4hPYbenwpMK+Yngc8XGIv79Iuw7bXGlaekj+7dhvuvpQr/IpTGd8EhgOLIuJrLW+iD5JOprK1h8oIxveV2Zuk+4FZVG753ALcDPwnsAQ4CdgAXBERLT/wVqO3WVR2Xd8Ztv3Ad+wW93Yh8L/AC0BPMfsmKt+vS/vsEn3NpYTPzZf3mmXKV/iZZcrhN8uUw2+WKYffLFMOv1mmHH6zTDn8Zpn6f/UCLZWmyE1mAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(np.reshape(result.squeeze(), (28,28)))\n", "plt.title('Barycenter')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Also, the Mosek logs show that the number of variables is the same for each modeling language. However, while Pyomo and Fusion produce the same number of constraints, CVXPY adds many extra constraints that result from the nonnegativity of the variables." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparison of Modeling Languages" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Time to solve the problem" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total Time:\n", "Fusion: 74.02755951881409\n", "Pyomo : 900.9914240837097\n", "CVXPY : 122.6916720867157\n", "\n", "Time spent in the solver:\n", "Fusion: 53.52407097816467\n", "Pyomo : 58.86694002151489\n", "CVXPY : 68.76518487930298\n" ] } ], "source": [ "print('Total Time:')\n", "print('Fusion: {0}'.format(fusion_model.time))\n", "print('Pyomo : {0}'.format(pyomo_model.time))\n", "print('CVXPY : {0}'.format(cvxpy_model.time))\n", "print('\\nTime spent in the solver:')\n", "print('Fusion: {0}'.format(fusion_model.M.getSolverDoubleInfo(\"optimizerTime\")))\n", "print('Pyomo : {0}'.format(pyomo_model.result.solver.wallclock_time))\n", "print('CVXPY : {0}'.format(cvxpy_model.prob.solver_stats.solve_time))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.style.use('seaborn')\n", "plt.figure(figsize=(10,4))\n", "\n", "total_t = [fusion_model.time, pyomo_model.time, cvxpy_model.time]\n", "solver_t = [fusion_model.M.getSolverDoubleInfo(\"optimizerTime\"), pyomo_model.result.solver.wallclock_time, cvxpy_model.prob.solver_stats.solve_time]\n", "\n", "#Total time plot\n", "plt.subplot(1,2,1)\n", "plt.bar(['Fusion', 'Pyomo', 'CVXPY'], height= total_t,\n", " width=0.5, color=(0.3, 0.6, 0.2, 0.5))\n", "plt.ylabel(\"Total Time (s)\")\n", "plt.title(\"Comparison of Total Time\")\n", "\n", "#Solver time plot\n", "plt.subplot(1,2,2)\n", "plt.bar(['Fusion', 'Pyomo', 'CVXPY'], height=solver_t,\n", " width=0.5, color=(0.5, 0.6, 0.9, 0.8))\n", "plt.ylabel(\"Solver Time (s)\")\n", "plt.title(\"Comparison of Solver Time\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Discussion\n", "\n", "In these runs, **Fusion passes the model data to Mosek faster** than the other two. CVXPY comes second, but its solver time is longer, mainly due to presolve (for example, CVXPY enters variable bounds as constraints). In terms of total time, Pyomo is far behind the others because all the transformations are made in Python, as opposed to Fusion and CVXPY, which call a C library. However, this is a huge model with about 31 thousand constraints and 12 million variables, and the difference is likely to be much smaller for normal-sized models.\n", "\n", "On the other hand, Fusion and CVXPY allow you to express the model in vectorized form (using **matrix, vector** notation), while Pyomo is mainly based on **sum expressions** that can be defined by **rule functions**, which makes modelling much easier. Therefore, the **time and effort spent on constructing the model in Pyomo was much smaller** than with the other languages.\n", "\n", "The decision of which one to use depends on the problem and the preferences of the user."
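] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough size check, assume that the LP holds one transport plan $\\pi^{(k)} \\in \\mathbb{R}^{n \\times n}$ per image $k$, plus the barycenter $\\mu \\in \\mathbb{R}^{n}$, and imposes two marginal equalities per pixel and image. With $n = 28 \\cdot 28 = 784$ pixels and $N = 20$ images this gives\n", "$$ \\#\\mbox{variables} = N n^2 + n = 20 \\cdot 784^2 + 784 = 12\\,293\\,904, \\qquad \\#\\mbox{constraints} = 2 N n = 2 \\cdot 20 \\cdot 784 = 31\\,360, $$\n", "in line with the figures quoted above. For $N = 50$ the same count gives $30\\,733\\,584$ variables and $78\\,400$ constraints, which agrees with the Fusion log of the 50-image run up to one variable. If every bound $\\pi^{(k)}_{ij} \\geq 0$ is entered as a separate constraint row, as CVXPY does, another $N n^2 = 30\\,732\\,800$ rows are added, which accounts for most of the $30\\,812\\,084$ constraints in the CVXPY log."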
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CVXPY and Fusion on huge models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " $\\quad$The plots above show the performance of the modeling languages for 20 images. It is also worth testing how the solving times behave as the model grows. To do so, the same problem is solved with 50 images using CVXPY and Fusion." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "n = 50\n", "train_1 = ones[:n]\n", "cvxpy_model.reset()\n", "fusion_model.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### CVXPY" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 30812084 \n", " Cones : 0 \n", " Scalar variables : 30733634 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 6.01 \n", "Lin. dep. - number : 49 \n", "Presolve terminated. 
Time: 37.73 \n", "GP based matrix reordering started.\n", "GP based matrix reordering terminated.\n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 30812084 \n", " Cones : 0 \n", " Scalar variables : 30733634 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 24 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 43692\n", "Optimizer - Cones : 0\n", "Optimizer - Scalar variables : 3560928 conic : 0 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 152.40 dense det. time : 0.58 \n", "Factor - ML order time : 77.89 GP order time : 70.64 \n", "Factor - nonzeros before factor : 4.53e+06 after factor : 9.72e+07 \n", "Factor - dense dim. : 0 flops : 3.04e+11 \n", "Factor - GP saved nzs : 3.57e+07 GP saved flops : 5.09e+11 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 3.8e+03 1.0e+02 5.9e+07 0.00e+00 1.207747296e+07 0.000000000e+00 2.4e+01 195.47\n", "1 9.7e-01 2.6e-02 1.5e+04 -1.00e+00 1.138202629e+07 -1.932435859e+05 6.1e-03 199.68\n", "2 1.7e-01 4.7e-03 2.7e+03 4.69e+01 2.749170572e+04 -2.705073018e+03 1.1e-03 206.20\n", "3 1.2e-01 3.3e-03 1.9e+03 8.18e+00 4.793007222e+03 -5.773295566e+02 7.6e-04 211.01\n", "4 1.1e-01 3.1e-03 1.7e+03 3.34e+00 3.771396588e+03 -4.640287740e+02 7.1e-04 214.99\n", "5 1.1e-01 2.9e-03 1.6e+03 3.09e+00 3.173273039e+03 -3.951025274e+02 6.7e-04 219.11\n", "6 9.7e-02 2.6e-03 1.5e+03 2.94e+00 2.470985631e+03 -3.117937667e+02 6.1e-04 223.02\n", "7 6.5e-02 1.8e-03 1.0e+03 2.74e+00 1.021523834e+03 -1.260039800e+02 4.1e-04 227.61\n", "8 8.4e-03 2.3e-04 1.3e+02 2.13e+00 8.372742526e+01 -2.798064226e+00 5.2e-05 234.39\n", "9 4.0e-03 1.1e-04 6.1e+01 1.18e+00 3.888935256e+01 -6.475636806e-01 2.5e-05 240.88\n", "10 2.0e-03 5.4e-05 3.1e+01 1.09e+00 2.102550501e+01 1.461108737e+00 1.2e-05 245.43\n", "11 1.6e-03 4.5e-05 2.5e+01 1.04e+00 1.781785581e+01 1.824394084e+00 
1.0e-05 249.44\n", "12 1.4e-03 4.0e-05 2.2e+01 1.03e+00 1.560289396e+01 1.976506458e+00 8.7e-06 253.43\n", "13 1.2e-03 3.1e-05 1.9e+01 1.03e+00 1.404619009e+01 2.298815530e+00 7.5e-06 257.53\n", "14 1.1e-03 2.8e-05 1.7e+01 1.02e+00 1.287129495e+01 2.405788065e+00 6.7e-06 261.85\n", "15 8.5e-04 2.2e-05 1.3e+01 1.02e+00 1.082240990e+01 2.583882531e+00 5.3e-06 267.28\n", "16 7.5e-04 1.9e-05 1.1e+01 1.01e+00 9.871522333e+00 2.665118594e+00 4.6e-06 271.43\n", "17 6.3e-04 1.6e-05 9.6e+00 1.01e+00 8.789904345e+00 2.753524954e+00 3.9e-06 275.72\n", "18 2.6e-04 6.0e-06 3.9e+00 1.01e+00 5.486142029e+00 3.037439240e+00 1.6e-06 280.57\n", "19 2.1e-04 4.8e-06 3.2e+00 1.00e+00 5.049911898e+00 3.067269696e+00 1.3e-06 285.09\n", "20 1.1e-04 2.8e-06 1.7e+00 1.00e+00 4.192997697e+00 3.120001209e+00 6.9e-07 289.72\n", "21 8.3e-05 2.1e-06 1.3e+00 1.00e+00 3.937217141e+00 3.136827758e+00 5.2e-07 293.98\n", "22 4.3e-05 1.4e-06 6.7e-01 1.00e+00 3.575235178e+00 3.153092018e+00 2.7e-07 298.06\n", "23 2.6e-05 8.4e-07 4.0e-01 1.00e+00 3.417790137e+00 3.165912215e+00 1.6e-07 302.48\n", "24 1.3e-05 4.4e-07 2.1e-01 1.00e+00 3.306492542e+00 3.174555301e+00 8.5e-08 306.79\n", "25 5.5e-06 1.8e-07 8.5e-02 1.00e+00 3.233672573e+00 3.179897390e+00 3.5e-08 311.42\n", "26 7.7e-07 1.9e-08 1.2e-02 1.00e+00 3.190109552e+00 3.182620943e+00 4.8e-09 316.45\n", "27 3.3e-07 8.2e-09 5.1e-03 1.00e+00 3.185934904e+00 3.182730244e+00 2.1e-09 320.91\n", "28 2.7e-08 6.7e-10 4.1e-04 1.00e+00 3.183064915e+00 3.182803646e+00 1.7e-10 326.93\n", "29 6.5e-10 1.5e-11 9.3e-06 1.00e+00 3.182810905e+00 3.182805007e+00 3.8e-12 331.90\n", "30 3.4e-11 1.4e-13 3.8e-08 1.00e+00 3.182805061e+00 3.182805037e+00 1.5e-14 335.98\n", "Basis identification started.\n", "Basis identification terminated. Time: 1.86\n", "Optimizer terminated. Time: 361.28 \n", "\n", "\n", "Interior-point solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1828050614e+00 nrm: 2e+01 Viol. 
con: 1e-10 var: 0e+00 \n", " Dual. obj: 3.1828050373e+00 nrm: 2e+02 Viol. con: 4e-16 var: 9e-14 \n", "\n", "Basic solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1828050375e+00 nrm: 2e+01 Viol. con: 8e-07 var: 0e+00 \n", " Dual. obj: 3.1828050375e+00 nrm: 2e+02 Viol. con: 7e-15 var: 9e-14 \n", "\n", "Time Spent to solve problem with CVXPY: \n", " 512.632221698761\n", "Time Spent in solver: \n", " 361.275288105011\n", "The average Wasserstein distance between digits and the barycenter: \n", " 3.182805061402046\n" ] } ], "source": [ "res_cvx = cvxpy_model.Wasserstein_BaryCenter(train_1)\n", "print('\\nTime Spent to solve problem with CVXPY: \\n {0}'.format(cvxpy_model.time))\n", "print('Time Spent in solver: \\n {0}'.format(cvxpy_model.prob.solver_stats.solve_time))\n", "print('The average Wasserstein distance between digits and the barycenter: \\n {0}'.format(cvxpy_model.result))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.style.use(\"default\")\n", "plt.figure(figsize=(3.3,3.3))\n", "plt.imshow(np.reshape(res_cvx.squeeze(), (28,28)))\n", "plt.title('Barycenter')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fusion" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Problem\n", " Name : Wasserstein \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 78400 \n", " Cones : 0 \n", " Scalar variables : 30733585 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 0.89 \n", "Lin. dep. - number : 49 \n", "Presolve terminated. Time: 13.31 \n", "GP based matrix reordering started.\n", "GP based matrix reordering terminated.\n", "Problem\n", " Name : Wasserstein \n", " Objective sense : min \n", " Type : LO (linear optimization problem)\n", " Constraints : 78400 \n", " Cones : 0 \n", " Scalar variables : 30733585 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 24 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 43692\n", "Optimizer - Cones : 0\n", "Optimizer - Scalar variables : 3560928 conic : 0 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 148.91 dense det. time : 0.60 \n", "Factor - ML order time : 77.43 GP order time : 67.30 \n", "Factor - nonzeros before factor : 4.53e+06 after factor : 9.72e+07 \n", "Factor - dense dim. 
: 0 flops : 3.04e+11 \n", "Factor - GP saved nzs : 3.57e+07 GP saved flops : 5.09e+11 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 3.8e+03 1.0e+02 5.9e+07 0.00e+00 1.207747296e+07 0.000000000e+00 2.4e+01 166.33\n", "1 9.7e-01 2.6e-02 1.5e+04 -1.00e+00 1.138202628e+07 -1.932435836e+05 6.1e-03 170.72\n", "2 1.7e-01 4.7e-03 2.7e+03 4.69e+01 2.749170586e+04 -2.705073033e+03 1.1e-03 176.96\n", "3 1.2e-01 3.3e-03 1.9e+03 8.18e+00 4.793007234e+03 -5.773295562e+02 7.6e-04 181.92\n", "4 1.1e-01 3.1e-03 1.7e+03 3.34e+00 3.771396617e+03 -4.640287759e+02 7.1e-04 186.01\n", "5 1.1e-01 2.9e-03 1.6e+03 3.09e+00 3.173273061e+03 -3.951025289e+02 6.7e-04 189.91\n", "6 9.7e-02 2.6e-03 1.5e+03 2.94e+00 2.470985651e+03 -3.117937685e+02 6.1e-04 193.88\n", "7 6.5e-02 1.8e-03 1.0e+03 2.74e+00 1.021523849e+03 -1.260039818e+02 4.1e-04 198.83\n", "8 8.4e-03 2.3e-04 1.3e+02 2.13e+00 8.372815147e+01 -2.798139621e+00 5.2e-05 205.42\n", "9 4.0e-03 1.1e-04 6.1e+01 1.18e+00 3.888892297e+01 -6.475319201e-01 2.5e-05 211.68\n", "10 2.0e-03 5.4e-05 3.1e+01 1.09e+00 2.102558681e+01 1.461095266e+00 1.2e-05 216.47\n", "11 1.6e-03 4.5e-05 2.5e+01 1.04e+00 1.781792814e+01 1.824383152e+00 1.0e-05 221.00\n", "12 1.4e-03 4.0e-05 2.2e+01 1.03e+00 1.560302313e+01 1.976495303e+00 8.7e-06 225.04\n", "13 1.2e-03 3.1e-05 1.9e+01 1.03e+00 1.404624228e+01 2.298817130e+00 7.5e-06 228.87\n", "14 1.1e-03 2.8e-05 1.7e+01 1.02e+00 1.287134161e+01 2.405789410e+00 6.7e-06 233.13\n", "15 8.5e-04 2.2e-05 1.3e+01 1.02e+00 1.082241807e+01 2.583885971e+00 5.3e-06 238.64\n", "16 7.5e-04 1.9e-05 1.1e+01 1.01e+00 9.871538265e+00 2.665120797e+00 4.6e-06 242.84\n", "17 6.3e-04 1.6e-05 9.6e+00 1.01e+00 8.789936046e+00 2.753525214e+00 3.9e-06 247.20\n", "18 2.6e-04 6.0e-06 3.9e+00 1.01e+00 5.486120490e+00 3.037438811e+00 1.6e-06 252.18\n", "19 2.1e-04 4.8e-06 3.2e+00 1.00e+00 5.049889518e+00 3.067269697e+00 1.3e-06 256.42\n", "20 1.1e-04 2.8e-06 1.7e+00 1.00e+00 4.193063274e+00 3.120006971e+00 6.9e-07 261.06\n", "21 
8.3e-05 2.1e-06 1.3e+00 1.00e+00 3.937318985e+00 3.136828389e+00 5.2e-07 265.26\n", "22 4.3e-05 1.4e-06 6.7e-01 1.00e+00 3.575228557e+00 3.153094207e+00 2.7e-07 269.78\n", "23 2.6e-05 8.4e-07 4.0e-01 1.00e+00 3.417767887e+00 3.165915025e+00 1.6e-07 274.64\n", "24 1.3e-05 4.4e-07 2.1e-01 1.00e+00 3.306549094e+00 3.174551508e+00 8.5e-08 278.83\n", "25 5.5e-06 1.8e-07 8.5e-02 1.00e+00 3.233604776e+00 3.179902683e+00 3.5e-08 283.63\n", "26 7.7e-07 1.3e-08 1.2e-02 1.00e+00 3.190062404e+00 3.182728989e+00 4.7e-09 289.08\n", "27 3.4e-07 5.7e-09 5.1e-03 1.00e+00 3.186012714e+00 3.182772597e+00 2.1e-09 293.89\n", "28 2.9e-08 4.9e-10 4.4e-04 1.00e+00 3.183083026e+00 3.182803490e+00 1.8e-10 299.39\n", "29 1.1e-09 1.5e-11 1.6e-05 1.00e+00 3.182815207e+00 3.182804980e+00 6.6e-12 303.47\n", "30 1.9e-11 1.7e-13 2.4e-08 1.00e+00 3.182805053e+00 3.182805037e+00 1.0e-14 307.66\n", "31 4.4e-12 1.4e-13 2.1e-12 1.00e+00 3.182805038e+00 3.182805038e+00 1.3e-18 312.21\n", "Basis identification started.\n", "Basis identification terminated. Time: 1.87\n", "Optimizer terminated. Time: 323.71 \n", "\n", "\n", "Interior-point solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1828050375e+00 nrm: 1e+00 Viol. con: 7e-12 var: 0e+00 \n", " Dual. obj: 3.1828050375e+00 nrm: 2e+02 Viol. con: 0e+00 var: 9e-14 \n", "\n", "Basic solution summary\n", " Problem status : PRIMAL_AND_DUAL_FEASIBLE\n", " Solution status : OPTIMAL\n", " Primal. obj: 3.1828050375e+00 nrm: 1e+00 Viol. con: 4e-17 var: 8e-07 \n", " Dual. obj: 3.1828050375e+00 nrm: 2e+02 Viol. 
con: 0e+00 var: 7e-12 \n", "\n", "Time Spent to solve problem with Fusion: \n", " 369.8690595626831\n", "Time Spent in solver: \n", " 323.7116389274597\n", "The average Wasserstein distance between digits and the barycenter: \n", " 3.182805037502097\n" ] } ], "source": [ "res_f = fusion_model.Wasserstein_BaryCenter(train_1)\n", "print('\\nTime Spent to solve problem with Fusion: \\n {0}'.format(fusion_model.time))\n", "print('Time Spent in solver: \\n {0}'.format(fusion_model.M.getSolverDoubleInfo(\"optimizerTime\")))\n", "print('The average Wasserstein distance between digits and the barycenter: \\n {0}'.format(fusion_model.objective))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(3.3,3.3))\n", "plt.imshow(np.reshape(res_f,(28,28)))\n", "plt.title('Barycenter')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8,3.5))\n", "\n", "total_t50 = [fusion_model.time, cvxpy_model.time]\n", "solver_t50 = [fusion_model.M.getSolverDoubleInfo(\"optimizerTime\"), cvxpy_model.prob.solver_stats.solve_time]\n", "\n", "#Total time plot\n", "plt.subplot(1,2,1)\n", "plt.bar(['Fusion', 'CVXPY'], height= total_t50,\n", " width=0.4, color=(0.3, 0.6, 0.2, 0.5))\n", "plt.ylabel(\"Total Time (s)\")\n", "plt.title(\"Comparison of Total Time\")\n", "\n", "#Solver time plot\n", "plt.subplot(1,2,2)\n", "plt.bar(['Fusion','CVXPY'], height=solver_t50,\n", " width=0.4, color=(0.5, 0.6, 0.9, 0.8))\n", "plt.ylabel(\"Solver Time (s)\")\n", "plt.title(\"Comparison of Solver Time\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"Creative
This work is licensed under a Creative Commons Attribution 4.0 International License. The **MOSEK** logo and name are trademarks of Mosek ApS. The code is provided as-is. Compatibility with future releases of **MOSEK** or the `Fusion API` is not guaranteed. For more information contact our [support](mailto:support@mosek.com). " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }