{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "![MOSEK ApS](https://www.mosek.com/static/images/branding/webgraphmoseklogocolor.png )\n", "\n", "# Implementing the Regularised Wasserstein Barycenter Problem with JuMP in Julia\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of this notebook is to implement a model that computes the Wasserstein barycenter of a set of images by solving an entropy-regularised minimization problem in Julia with JuMP, and to solve that model with MOSEK through MosekTools. For details about the data, the theory behind the computation of barycenters, references, and more insight into the construction of the model, please consult the corresponding [Python notebook](https://nbviewer.jupyter.org/github/MOSEK/Tutorials/blob/master/wasserstein/wasserstein-bary-reg.ipynb). The MNIST data files can be downloaded from http://yann.lecun.com/exdb/mnist/. They are distributed gzip-compressed and must be decompressed into the notebook's working directory before running the notebook." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "using LinearAlgebra\n", "using Plots\n", "pyplot()\n", "#The plotting cells call matplotlib directly through PyPlot.plt.\n", "import PyPlot\n", "\n", "using JuMP\n", "\n", "using Mosek\n", "using MosekTools\n", "\n", "using Statistics" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "PyPlot.Figure(PyObject )" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#Define the number of images for the barycenter calculation.\n", "n = 2\n", "\n", "#Read an array stored in the IDX format used by the MNIST files.\n", "function read_idx(filename)\n", " f = open(filename,\"r\")\n", " #The 4-byte magic number holds two zero bytes, the data type code\n", " #(0x08 = unsigned byte for MNIST) and the number of dimensions.\n", " data_layout = zeros(UInt8,4)\n", " readbytes!(f,data_layout,4)\n", " data_dimensions = data_layout[4]\n", " #Each dimension size follows as a big-endian Int32.\n", " data_shape = Int32[]\n", " for i = 1:data_dimensions\n", " s = zeros(UInt8,4)\n", " readbytes!(f,s,4)\n", " s = map(ntoh,reinterpret(Int32,s))\n", " push!(data_shape,s[1])\n", " end\n", " idx_data = zeros(UInt8,prod(data_shape))\n", "
read!(f,idx_data)\n", " close(f)\n", " #IDX stores data in row-major order, so reverse the shape for Julia's column-major layout.\n", " idx_data = reshape(idx_data,Tuple(reverse(data_shape)))\n", " return idx_data\n", "end\n", "\n", "data = read_idx(\"train-images-idx3-ubyte\")\n", "labels = read_idx(\"train-labels-idx1-ubyte\")\n", "\n", "#Select the images of the digit 3 and keep the first n of them for the barycenter.\n", "mask = labels .== 3\n", "train_threes = data[:,:,mask]\n", "train = train_threes[:,:,1:n]\n", "\n", "#Plot ten randomly chosen images of the selected digit.\n", "x = collect(1:28)\n", "y = reverse(x)\n", "f,ax = PyPlot.plt.subplots(2,5,sharey=true,sharex=true,figsize=(10,5))\n", "PyPlot.plt.xticks([5,10,15,20,25])\n", "\n", "for i = 1:10\n", " rand_pick = rand(1:size(train_threes,3))\n", " ax[i].pcolormesh(x,y,transpose(train_threes[:,:,rand_pick]))\n", "end" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "reg_wasserstein_barycenter (generic function with 1 method)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function single_pmf(data)\n", " #Flatten each image into a column and normalise it to sum to one,\n", " #so that column k of v is the probability mass function of image k.\n", " K = size(data,3)\n", " v = zeros(size(data,1)*size(data,2),K)\n", " for k in 1:K\n", " img = vec(data[:,:,k])\n", " v[:,k] = img./sum(img)\n", " end\n", " return v,size(v,1)\n", "end\n", "\n", "function ms_distance(m,n)\n", " #Squared Euclidean distance between the m = n^2 pixels of an n x n image,\n", " #where pixel p sits at the grid coordinates (coor_I[p],coor_J[p]).\n", " coor_I = repeat(1:n,inner=n)\n", " coor_J = repeat(1:n,outer=n)\n", " d = zeros(m,m)\n", " for i in 1:m\n", " for j in 1:m\n", " d[i,j] = (coor_I[i]-coor_I[j])^2 + (coor_J[i]-coor_J[j])^2\n", " end\n", " end\n", " return d\n", "end\n", "\n", "function reg_wasserstein_barycenter(data,lambda,relgap)\n", " #Compute the Wasserstein barycenter by solving an entropy-regularised minimization problem.\n", " #Direct mode model\n", " #M = direct_model(Mosek.Optimizer())\n", " #set_optimizer_attribute(M,\"MSK_DPAR_INTPNT_CO_TOL_REL_GAP\",relgap)\n", " \n", " #Automatic mode model\n", " M
= Model(optimizer_with_attributes(Mosek.Optimizer,\"MSK_DPAR_INTPNT_CO_TOL_REL_GAP\" => relgap))\n", " \n", " #Number of input images (size(data,3) is 1 for a single 2-D image).\n", " K = size(data,3)\n", " v,N = single_pmf(data)\n", " d = ms_distance(N,size(data,2))\n", " \n", " #Default regularisation parameter, scaled by the median squared pixel distance.\n", " if lambda === nothing\n", " lambda = 60/median(vec(d))\n", " end\n", " \n", " #Define indices\n", " M_i = 1:N\n", " M_j = 1:N\n", " M_k = 1:K\n", "\n", " #Transport plans (one N x N plan per image) and the barycenter mu.\n", " M_pi = @variable(M, M_pi[i = M_i, j = M_j, k = M_k] >= 0.0)\n", " M_mu = @variable(M, M_mu[i = M_i] >= 0.0)\n", "\n", " #Auxiliary variables for the entropy terms in the exponential cones.\n", " M_aux = @variable(M,M_aux[i = M_i, j = M_j, k = M_k])\n", " \n", " #Marginal constraints: summing plan k over i gives image k, summing it over j gives mu.\n", " @constraint(M, c3_expr[k = M_k, j = M_j], sum(M_pi[:,j,k]) == v[j,k])\n", " @constraint(M, c2_expr[k = M_k, i = M_i], sum(M_pi[i,:,k]) == M_mu[i])\n", " \n", " #Exponential cone constraints: pi*exp(aux/pi) <= 1, i.e. aux <= -pi*log(pi).\n", " @constraint(M,cExp_cone[i=M_i, j=M_j, k=M_k],[M_aux[i,j,k],M_pi[i,j,k],1] in MOI.ExponentialCone())\n", " \n", " #The objective is linear; the entropy regularisation enters through M_aux, which is bounded by\n", " #-pi*log(pi), so the model minimizes the average over k of sum(d.*pi_k) + sum(pi_k.*log.(pi_k))/lambda.\n", " W_obj = @objective(M, Min,(sum(d[i,j]*M_pi[i,j,k] for i=M_i,j=M_j,k=M_k) - \n", " sum(M_aux[i,j,k] for i=M_i,j=M_j,k=M_k)/lambda)/K)\n", " \n", " return M,M_mu\n", "end" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "show_barycenter (generic function with 1 method)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function run_regularised_model(data,lambda=nothing,relgap=1e-7)\n", " @time begin\n", " M,M_mu = reg_wasserstein_barycenter(data,lambda,relgap)\n", " optimize!(M)\n", " end\n", " println(\"Solution status = \",termination_status(M))\n", " println(\"Primal objective value = \",objective_value(M))\n", " #The barycenter is the optimal value of mu.\n", " mu_level = value.(M_mu)\n", " return mu_level\n", "end\n", "\n", "function show_barycenter(bary_center)\n", " bary_center = reshape(bary_center,(28,28))\n", " x = [i for i=1:28]\n", " y = reverse(x)\n", " 
PyPlot.plt.pcolormesh(x,y,transpose(bary_center))\n", " PyPlot.plt.title(\"Regularized Barycenter\")\n", " PyPlot.plt.show()\n", "end" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Problem\n", " Name : \n", " Objective sense : min \n", " Type : CONIC (conic optimization problem)\n", " Constraints : 3691072 \n", " Cones : 1229312 \n", " Scalar variables : 6147344 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer started.\n", "Presolve started.\n", "Linear dependency checker started.\n", "Linear dependency checker terminated.\n", "Eliminator started.\n", "Freed constraints in eliminator : 0\n", "Eliminator terminated.\n", "Eliminator - tries : 1 time : 0.00 \n", "Lin. dep. - tries : 1 time : 1.02 \n", "Lin. dep. - number : 1 \n", "Presolve terminated. Time: 7.65 \n", "Problem\n", " Name : \n", " Objective sense : min \n", " Type : CONIC (conic optimization problem)\n", " Constraints : 3691072 \n", " Cones : 1229312 \n", " Scalar variables : 6147344 \n", " Matrix variables : 0 \n", " Integer variables : 0 \n", "\n", "Optimizer - threads : 20 \n", "Optimizer - solved problem : the primal \n", "Optimizer - Constraints : 1138\n", "Optimizer - Cones : 1229312\n", "Optimizer - Scalar variables : 3687936 conic : 3687936 \n", "Optimizer - Semi-definite variables: 0 scalarized : 0 \n", "Factor - setup time : 2.62 dense det. time : 0.00 \n", "Factor - ML order time : 0.01 GP order time : 0.00 \n", "Factor - nonzeros before factor : 2.79e+05 after factor : 3.41e+05 \n", "Factor - dense dim. 
: 0 flops : 1.17e+08 \n", "ITE PFEAS DFEAS GFEAS PRSTATUS POBJ DOBJ MU TIME \n", "0 6.3e+02 5.1e+02 2.4e+07 0.00e+00 2.213290906e+07 -1.586952925e+06 1.0e+00 16.82 \n", "1 2.2e+02 1.8e+02 1.4e+07 -9.87e-01 1.997448319e+07 -3.251325127e+06 3.5e-01 21.07 \n", "2 1.4e+02 1.1e+02 1.0e+07 -8.50e-01 1.706335688e+07 -4.132584961e+06 2.2e-01 26.16 \n", "3 8.6e+01 6.9e+01 6.1e+06 -5.35e-01 1.255536087e+07 -4.350508428e+06 1.4e-01 33.89 \n", "4 3.9e+01 3.1e+01 2.3e+06 -5.91e-02 6.634369363e+06 -3.290261645e+06 6.2e-02 38.48 \n", "5 1.6e+01 1.3e+01 6.8e+05 4.56e-01 3.047700570e+06 -1.738022157e+06 2.5e-02 42.26 \n", "6 9.2e+00 7.4e+00 3.2e+05 6.93e-01 1.871313131e+06 -1.107859520e+06 1.5e-02 45.87 \n", "7 3.8e+00 3.0e+00 9.0e+04 7.82e-01 8.078383522e+05 -5.102774510e+05 5.9e-03 50.28 \n", "8 1.0e+00 8.1e-01 1.4e+04 8.51e-01 2.027574293e+05 -1.783130565e+05 1.6e-03 54.31 \n", "9 4.2e-01 3.4e-01 4.0e+03 8.87e-01 8.245554128e+04 -8.500164850e+04 6.6e-04 58.08 \n", "10 1.2e-01 9.9e-02 6.9e+02 9.00e-01 2.254804899e+04 -2.962342498e+04 2.0e-04 61.72 \n", "11 3.6e-02 2.9e-02 1.2e+02 9.02e-01 5.918580033e+03 -1.024328330e+04 5.7e-05 65.07 \n", "12 1.1e-02 8.5e-03 2.0e+01 9.00e-01 1.600621502e+03 -3.421588743e+03 1.7e-05 68.39 \n", "13 3.3e-03 2.6e-03 3.7e+00 9.03e-01 4.338656334e+02 -1.203318166e+03 5.2e-06 71.91 \n", "14 9.2e-04 7.4e-04 5.9e-01 9.15e-01 9.099381179e+01 -3.977761008e+02 1.5e-06 75.28 \n", "15 2.5e-04 2.0e-04 8.5e-02 9.22e-01 1.264148878e+00 -1.351942993e+02 3.9e-07 78.65 \n", "16 7.1e-05 5.7e-05 1.4e-02 9.31e-01 -1.815177735e+01 -5.892407032e+01 1.1e-07 81.93 \n", "17 1.6e-05 1.3e-05 1.6e-03 9.35e-01 -2.349891579e+01 -3.326506346e+01 2.6e-08 85.27 \n", "18 3.2e-06 2.6e-06 1.5e-04 9.38e-01 -2.441821029e+01 -2.642392235e+01 5.1e-09 88.70 \n", "19 7.5e-07 6.0e-07 1.7e-05 9.44e-01 -2.452958618e+01 -2.501499156e+01 1.2e-09 92.04 \n", "20 1.8e-07 1.4e-07 2.0e-06 9.49e-01 -2.454255521e+01 -2.466071465e+01 2.8e-10 95.32 \n", "21 3.6e-08 2.9e-08 1.9e-07 9.48e-01 
-2.454079676e+01 -2.456580776e+01 5.7e-11 99.05 \n", "22 9.3e-09 7.5e-09 2.6e-08 9.50e-01 -2.453852209e+01 -2.454516580e+01 1.5e-11 102.62\n", "23 2.3e-09 1.8e-09 3.3e-09 9.51e-01 -2.453774799e+01 -2.453943385e+01 3.6e-12 106.06\n", "24 8.2e-10 5.0e-10 4.7e-10 9.51e-01 -2.453754484e+01 -2.453801397e+01 9.8e-13 109.44\n", "25 8.0e-10 4.1e-10 3.6e-10 9.54e-01 -2.453752652e+01 -2.453792006e+01 8.1e-13 120.02\n", "26 1.3e-09 3.7e-10 3.1e-10 9.54e-01 -2.453751783e+01 -2.453787528e+01 7.4e-13 135.61\n", "27 1.2e-09 3.6e-10 2.9e-10 9.54e-01 -2.453751378e+01 -2.453785435e+01 7.0e-13 149.61\n", "28 1.4e-09 3.4e-10 2.7e-10 9.54e-01 -2.453750992e+01 -2.453783439e+01 6.7e-13 161.23\n", "29 1.4e-09 3.1e-10 2.3e-10 9.54e-01 -2.453750254e+01 -2.453779623e+01 6.0e-13 173.52\n", "30 1.4e-09 2.9e-10 2.1e-10 9.54e-01 -2.453749919e+01 -2.453777893e+01 5.7e-13 186.35\n", "31 1.4e-09 2.8e-10 2.1e-10 9.54e-01 -2.453749759e+01 -2.453777069e+01 5.6e-13 199.37\n", "32 1.4e-09 2.8e-10 2.1e-10 9.54e-01 -2.453749749e+01 -2.453777019e+01 5.6e-13 215.66\n", "33 1.4e-09 2.8e-10 2.1e-10 9.54e-01 -2.453749746e+01 -2.453777006e+01 5.6e-13 233.18\n", "34 1.4e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749669e+01 -2.453776604e+01 5.5e-13 246.21\n", "35 1.4e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749649e+01 -2.453776505e+01 5.5e-13 261.34\n", "36 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749644e+01 -2.453776480e+01 5.5e-13 279.47\n", "37 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749643e+01 -2.453776474e+01 5.5e-13 298.25\n", "38 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749641e+01 -2.453776462e+01 5.5e-13 315.76\n", "39 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749640e+01 -2.453776460e+01 5.5e-13 335.01\n", "40 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749640e+01 -2.453776457e+01 5.5e-13 354.66\n", "41 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749640e+01 -2.453776457e+01 5.5e-13 373.38\n", "42 1.5e-09 2.8e-10 2.0e-10 9.54e-01 -2.453749640e+01 -2.453776457e+01 5.5e-13 392.86\n", "Optimizer terminated. 
Time: 419.08 \n", "\n", "649.126111 seconds (508.56 M allocations: 33.802 GiB, 22.71% gc time)\n", "Solution status = SLOW_PROGRESS\n", "Primal objective value = -24.5374963989306\n", "******\n" ] } ], "source": [ "bary_center = run_regularised_model(train)\n", "println(\"******\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "PyPlot.Figure(PyObject )" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_barycenter(bary_center)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
This work is licensed under a Creative Commons Attribution 4.0 International License. The **MOSEK** logo and name are trademarks of Mosek ApS. The code is provided as-is. Compatibility with future releases of **MOSEK** is not guaranteed. For more information contact our [support](mailto:support@mosek.com). " ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.0.3", "language": "julia", "name": "julia-1.0" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.0.3" } }, "nbformat": 4, "nbformat_minor": 2 }