#include <stdlib.h> #include <math.h> #include <stdio.h> #include <string.h> #include <time.h> #include <fftw3.h> #include "mex.h" double compute_l2_ot(double *mu, double *nu, double *phi, double *dual, double totalMass, double sigma, int maxIters, int n1, int n2); void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){ double *mu=mxGetPr(prhs[0]); double *nu=mxGetPr(prhs[1]); int maxIters=(int) mxGetScalar(prhs[2]); double sigma =(double) mxGetScalar(prhs[3]); int n1=mxGetM(prhs[0]); int n2=mxGetN(prhs[0]); int pcount=n1*n2; plhs[0] = mxCreateDoubleMatrix(n1,n2,mxREAL); plhs[1] = mxCreateDoubleMatrix(n1,n2,mxREAL); double *phi=mxGetPr(plhs[0]); double *psi=mxGetPr(plhs[1]); double sum=0; for(int i=0;i<pcount;i++){ if (mu[i]<0){ mexErrMsgTxt("Initial density contains negative values"); } if (nu[i]<0){ mexErrMsgTxt("Final density contains negative values"); } sum+=mu[i]; } double totalMass=sum/(n1*n2*1.0); double value=compute_l2_ot(mu, nu, phi, psi, totalMass, sigma, maxIters, n1, n2); for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ double x=(j+.5)/(n1*1.0); double y=(i+.5)/(n2*1.0); phi[i*n1+j]=.5*(x*x+y*y)-phi[i*n1+j]; psi[i*n1+j]=.5*(x*x+y*y)-psi[i*n1+j]; } } } typedef struct{ fftw_plan dctIn; fftw_plan dctOut; double *kernel; double *workspace; }poisson_solver; double *create_negative_laplace_kernel(int n1, int n2){ double *kernel=calloc(n1*n2,sizeof(double)); for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ double x=M_PI*j/(n1*1.0); double y=M_PI*i/(n2*1.0); double negativeLaplacian=2*n1*n1*(1-cos(x))+2*n2*n2*(1-cos(y)); kernel[i*n1+j]=negativeLaplacian; } } return kernel; } poisson_solver create_poisson_solver_workspace(int n1, int n2){ clock_t b,e; b=clock(); poisson_solver fftps; fftps.workspace=calloc(n1*n2,sizeof(double)); fftps.kernel=create_negative_laplace_kernel(n1,n2); fftps.dctIn=fftw_plan_r2r_2d(n2, n1, fftps.workspace, fftps.workspace, FFTW_REDFT10, FFTW_REDFT10, FFTW_MEASURE); fftps.dctOut=fftw_plan_r2r_2d(n2, n1, fftps.workspace, fftps.workspace, FFTW_REDFT01, FFTW_REDFT01, FFTW_MEASURE); e=clock(); mexPrintf("FFT setup time: %.2fs\n", (e-b)/(CLOCKS_PER_SEC*1.0)); mexEvalString("pause(.001);"); return fftps; } void destroy_poisson_solver(poisson_solver fftps){ free(fftps.kernel); free(fftps.workspace); fftw_destroy_plan(fftps.dctIn); fftw_destroy_plan(fftps.dctOut); } typedef struct{ int *indices; int hullCount; }convex_hull; int sgn(double x){ int truth=(x>0)-(x<0); return truth; } void init_hull(convex_hull *hull, int n){ hull->indices=calloc(n,sizeof(double)); hull->hullCount=0; } void destroy_hull(convex_hull *hull){ free(hull->indices); } void transpose_doubles(double *transpose, double *data, int n1, int n2){ for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ transpose[j*n2+i]=data[i*n1+j]; } } } double interpolate_function(double *function, double x, double y, int n1, int n2){ int xIndex=fmin(fmax(x*n1-.5 ,0),n1-1); int yIndex=fmin(fmax(y*n2-.5 ,0),n2-1); double xfrac=x*n1-xIndex-.5; double yfrac=y*n2-yIndex-.5; int xOther=xIndex+sgn(xfrac); int yOther=yIndex+sgn(yfrac); xOther=fmax(fmin(xOther, n1-1),0); yOther=fmax(fmin(yOther, n2-1),0); double v1=(1-fabs(xfrac))*(1-fabs(yfrac))*function[yIndex*n1+xIndex]; double v2=fabs(xfrac)*(1-fabs(yfrac))*function[yIndex*n1+xOther]; double v3=(1-fabs(xfrac))*fabs(yfrac)*function[yOther*n1+xIndex]; double v4=fabs(xfrac)*fabs(yfrac)*function[yOther*n1+xOther]; double v=v1+v2+v3+v4; return v; } void add_point(double *u, convex_hull *hull, int i){ if(hull->hullCount<2){ hull->indices[1]=i; hull->hullCount++; }else{ int hc=hull->hullCount; int ic1=hull->indices[hc-1]; int ic2=hull->indices[hc-2]; double oldSlope=(u[ic1]-u[ic2])/(ic1-ic2); double slope=(u[i]-u[ic1])/(i-ic1); if(slope>=oldSlope){ int hc=hull->hullCount; hull->indices[hc]=i; hull->hullCount++; }else{ hull->hullCount--; add_point(u, hull, i); } } } void get_convex_hull(double *u, convex_hull *hull, int n){ hull->indices[0]=0; hull->indices[1]=1; hull->hullCount=2; for(int i=2;i<n;i++){ add_point(u, hull, i); } } void compute_dual_indices(int *dualIndicies, double *u, convex_hull *hull, int n){ int counter=1; int hc=hull->hullCount; for(int i=0;i<n;i++){ double s=(i+.5)/(n*1.0); int ic1=hull->indices[counter]; int ic2=hull->indices[counter-1]; double slope=n*(u[ic1]-u[ic2])/(ic1-ic2); while(s>slope&&counter<hc-1){ counter++; ic1=hull->indices[counter]; ic2=hull->indices[counter-1]; slope=n*(u[ic1]-u[ic2])/(ic1-ic2); } dualIndicies[i]=hull->indices[counter-1]; } } void compute_dual(double *dual, double *u, int *dualIndicies, convex_hull *hull, int n){ get_convex_hull(u, hull, n); compute_dual_indices(dualIndicies, u, hull, n); for(int i=0;i<n;i++){ double s=(i+.5)/(n*1.0); int index=dualIndicies[i]; double x=(index+.5)/(n*1.0); double v1=s*x-u[dualIndicies[i]]; double v2=s*(n-.5)/(n*1.0)-u[n-1]; if(v1>v2){ dual[i]=v1; }else{ dualIndicies[i]=n-1; dual[i]=v2; } } } void compute_2d_dual(double *dual, double *u, convex_hull *hull, int n1, int n2){ int pcount=n1*n2; int n=fmax(n1,n2); int *argmin=calloc(n,sizeof(int)); double *temp=calloc(pcount,sizeof(double)); memcpy(temp, u, pcount*sizeof(double)); for(int i=0;i<n2;i++){ compute_dual(&dual[i*n1], &temp[i*n1], argmin, hull, n1); } transpose_doubles(temp, dual, n1, n2); for(int i=0;i<n1*n2;i++){ dual[i]=-temp[i]; } for(int j=0;j<n1;j++){ compute_dual(&temp[j*n2], &dual[j*n2], argmin, hull, n2); } transpose_doubles(dual, temp, n2, n1); free(temp); free(argmin); } void convexify(double *phi, double *dual, convex_hull *hull, int n1, int n2){ compute_2d_dual(dual, phi, hull, n1, n2); compute_2d_dual(phi, dual, hull, n1, n2); } void calc_pushforward_map(double *xMap, double *yMap, double *dual, int n1, int n2){ double xStep=1.0/n1; double yStep=1.0/n2; for(int i=0;i<n2+1;i++){ for(int j=0;j<n1+1;j++){ double x=j/(n1*1.0); double y=i/(n2*1.0); double dualxp=interpolate_function(dual, x+xStep, y, n1, n2); double dualxm=interpolate_function(dual, x-xStep, y, n1, n2); double dualyp=interpolate_function(dual, x, y+yStep, n1, n2); double dualym=interpolate_function(dual, x, y-yStep, n1, n2); xMap[i*(n1+1)+j]=.5*n1*(dualxp-dualxm); yMap[i*(n1+1)+j]=.5*n2*(dualyp-dualym); } } } void sampling_pushforward(double *rho, double *mu, double totalMass, double *xMap, double *yMap, int n1, int n2){ int pcount=n1*n2; memset(rho,0,pcount*sizeof(double)); double xCut=pow(1.0/n1,1.0/3); double yCut=pow(1.0/n2,1.0/3); for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ double mass=mu[i*n1+j]; if(mass>0){ double xStretch0=fabs(xMap[i*(n1+1)+j+1]-xMap[i*(n1+1)+j]); double xStretch1=fabs(xMap[(i+1)*(n1+1)+j+1]-xMap[(i+1)*(n1+1)+j]); double yStretch0=fabs(yMap[(i+1)*(n1+1)+j]-yMap[i*(n1+1)+j]); double yStretch1=fabs(yMap[(i+1)*(n1+1)+j+1]-yMap[i*(n1+1)+j+1]); double xStretch=fmax(xStretch0, xStretch1); double yStretch=fmax(yStretch0, yStretch1); int xSamples=2*fmax(n1*xStretch,1); int ySamples=2*fmax(n2*yStretch,1); if(xStretch<xCut&&yStretch<yCut){ double factor=1/(xSamples*ySamples*1.0); for(int l=0;l<ySamples;l++){ for(int k=0;k<xSamples;k++){ double a=(k+.5)/(xSamples*1.0); double b=(l+.5)/(ySamples*1.0); double xPoint=(1-b)*(1-a)*xMap[i*(n1+1)+j]+(1-b)*a*xMap[i*(n1+1)+j+1]+b*(1-a)*xMap[(i+1)*(n1+1)+j]+a*b*xMap[i*(n1+1)+j+1]; double yPoint=(1-b)*(1-a)*yMap[i*(n1+1)+j]+(1-b)*a*yMap[i*(n1+1)+j+1]+b*(1-a)*yMap[(i+1)*(n1+1)+j]+a*b*yMap[i*(n1+1)+j+1]; double X=xPoint*n1-.5; double Y=yPoint*n2-.5; int xIndex=X; int yIndex=Y; double xFrac=X-xIndex; double yFrac=Y-yIndex; int xOther=xIndex+1; int yOther=yIndex+1; xIndex=fmin(fmax(xIndex,0),n1-1); xOther=fmin(fmax(xOther,0),n1-1); yIndex=fmin(fmax(yIndex,0),n2-1); yOther=fmin(fmax(yOther,0),n2-1); rho[yIndex*n1+xIndex]+=(1-xFrac)*(1-yFrac)*mass*factor; rho[yOther*n1+xIndex]+=(1-xFrac)*yFrac*mass*factor; rho[yIndex*n1+xOther]+=xFrac*(1-yFrac)*mass*factor; rho[yOther*n1+xOther]+=xFrac*yFrac*mass*factor; } } } } } } double sum=0; for(int i=0;i<pcount;i++){ sum+=rho[i]/pcount; } for(int i=0;i<pcount;i++){ rho[i]*=totalMass/sum; } } double update_potential(poisson_solver fftps, double *phi, double *rho, double *nu, double sigma, int n1, int n2){ int pcount=n1*n2; double h1=0; for(int i=0;i<pcount;i++){ fftps.workspace[i]=rho[i]-nu[i]; } fftw_execute(fftps.dctIn); fftps.workspace[0]=0; for(int i=1;i<pcount;i++){ fftps.workspace[i]/=4*pcount*fftps.kernel[i]; } fftw_execute(fftps.dctOut); for(int i=0;i<pcount;i++){ phi[i]+=sigma*fftps.workspace[i]; h1+=fftps.workspace[i]*(rho[i]-nu[i]); } h1/=pcount; return h1; } double compute_w2(double *phi, double *dual, double *mu, double *nu, int n1, int n2){ int pcount=n1*n2; double value=0; for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ double x=(j+.5)/(n1*1.0); double y=(i+.5)/(n2*1.0); value+=.5*(x*x+y*y)*(mu[i*n1+j]+nu[i*n1+j])-nu[i*n1+j]*phi[i*n1+j]-mu[i*n1+j]*dual[i*n1+j]; } } value/=pcount; return value; } double step_update(double sigma, double value, double oldValue, double gradSq, double scaleUp, double scaleDown, double upper, double lower){ double diff=value-oldValue; if(diff>gradSq*sigma*upper){ return sigma*scaleUp; }else if(diff<gradSq*sigma*lower){ return sigma*scaleDown; }else{ return sigma; } } double compute_l2_ot(double *mu, double *nu, double *phi, double *dual, double totalMass, double sigma, int maxIters, int n1, int n2){ int pcount=n1*n2; poisson_solver fftps=create_poisson_solver_workspace(n1,n2); double *xMap=calloc((n1+1)*(n2+1),sizeof(double)); double *yMap=calloc((n1+1)*(n2+1),sizeof(double)); int n=fmax(n1,n2); convex_hull hull; init_hull(&hull, n); for(int i=0;i<n2+1;i++){ for(int j=0;j<n1+1;j++){ double x=j/(n1*1.0); double y=i/(n2*1.0); xMap[i*n1+j]=x; yMap[i*n1+j]=y; } } for(int i=0;i<n2;i++){ for(int j=0;j<n1;j++){ double x=(j+.5)/(n1*1.0); double y=(i+.5)/(n2*1.0); phi[i*n1+j]=.5*(x*x+y*y); dual[i*n1+j]=.5*(x*x+y*y); } } double *rho=calloc(pcount,sizeof(double)); memcpy(rho,mu,pcount*sizeof(double)); double oldValue=compute_w2(phi, dual, mu, nu, n1, n2); double scaleDown=.8; double scaleUp=1/scaleDown; double upper=.75; double lower=.25; int numDigitsIter=floor(log10(maxIters) + 1); for(int i=0;i<maxIters+1;i++){ double gradSq=update_potential(fftps, phi, rho, nu, sigma, n1, n2); convexify(phi, dual, &hull, n1, n2); double value=compute_w2(phi, dual, mu, nu, n1, n2); sigma=step_update(sigma, value, oldValue, gradSq , scaleUp, scaleDown, upper, lower); oldValue=value; calc_pushforward_map(xMap, yMap, phi, n1, n2); sampling_pushforward(rho, nu, totalMass, xMap, yMap, n1, n2); gradSq=update_potential(fftps, dual, rho, mu, sigma, n1, n2); convexify(dual, phi, &hull, n1, n2); calc_pushforward_map(xMap, yMap, dual, n1, n2); sampling_pushforward(rho, mu, totalMass, xMap, yMap, n1, n2); value=compute_w2(phi, dual, mu, nu, n1, n2); sigma=step_update(sigma, value, oldValue, gradSq , scaleUp, scaleDown, upper, lower); oldValue=value; sigma=fmax(sigma,.05); if(i%5==0){ mexPrintf("iter %*d, W2 value: %5e\n", numDigitsIter, i, value); mexEvalString("pause(.001);"); } } destroy_hull(&hull); free(rho); free(xMap); free(yMap); destroy_poisson_solver(fftps); return oldValue; }