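% Thresholding: F1(i) = median of F_x(i) over the pixels x with u(x) >= 0.7
%               F2(i) = median of F_x(i) over the pixels x with u(x) <  0.3
% where F_x(i) is bin i of the normalised local cumulative histogram at pixel x.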

% function fast_loc_hist_seg

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Load the input image
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


% Load the image and normalise it to [0,1]. u_bound keeps the original
% maximum intensity so that later code can map Im0 back to the original
% intensity scale via u_bound*Im0.
[Im0,map] = imread('cheetah2.bmp'); f = 2.; f2 = 1.1;

Im0 = double(Im0);
% A truecolor file decodes to Ny x Nx x 3; "[Ny,Nx] = size(Im0)" would then
% fold the channels into Nx and every later stage would mix channels.
% Collapse to one luminance channel (ITU-R BT.601 weights, as used by
% rgb2gray); the weights sum to ~1, so the 0..255 intensity scale is kept.
if size(Im0,3) == 3
    Im0 = 0.2989*Im0(:,:,1) + 0.5870*Im0(:,:,2) + 0.1140*Im0(:,:,3);
end
% NOTE(review): for an indexed file (non-empty map) the raw palette indices
% are used as intensities; this is only right for a gray-ramp palette - confirm.

[Ny,Nx] = size(Im0);

u_bound = max(Im0(:));
% Guard against an all-zero image, which would otherwise turn Im0 into NaN.
if u_bound > 0
    Im0 = Im0 / u_bound;
end

% Show the (normalised) input image.
cpt_fig = 1;
figure(cpt_fig); clf;
imagesc(Im0); axis image; axis off; colormap gray; hold on;
%pause
       

% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% % compute local histograms (START)
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% For every pixel (y,x), histogram the original-scale intensities
% (u_bound*Im0) inside the (2*rad+1) x (2*rad+1) window centred on it,
% clipped at the image border:
%   P(y,x,:) - normalised histogram (sums to 1)
%   F(y,x,:) - normalised cumulative histogram (F(y,x,bin) == 1)
bin = 256;
bin_size = 256/bin;
centers = bin_size:bin_size:256;   % hist bin centres; out-of-range values land in the end bins
F = zeros(Ny,Nx,bin);
P = zeros(Ny,Nx,bin);
rad = 3;

for y = 1:Ny
    rows = max(1,y-rad):min(Ny,y+rad);
    for x = 1:Nx
        cols = max(1,x-rad):min(Nx,x+rad);
        % Slice the window directly. Building a full-size mask per pixel, as
        % before, made this double loop quadratic in the number of pixels.
        win = u_bound*Im0(rows,cols);
        counts = hist(win(:), centers);
        total = sum(counts);                 % number of pixels in the window
        P(y,x,:) = counts/total;
        F(y,x,:) = cumsum(counts)/total;
    end
end

% Per-bin region descriptors (one entry per bin), updated in the main loop.
F1 = zeros(1,bin);
F2 = F1;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% compute local histograms (END)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Alternating minimisation:
%   1) one dual step on (pxu,pyu) for the TV-regularised problem in u
%      (the update has the form of Chambolle's projection iteration),
%      then u = v - div(p)/lambda;
%   2) a region step on v that pulls each pixel toward the region
%      (u >= mu1 vs u < mu2) whose median local CDF is closer in L1,
%      followed by clipping v to [0,1].
lambda = 100;   % 1/theta in the paper
gamma = 1/5;    % lambda in the paper
dt = 1/8;       % dual step size
gb = ones(size(Im0));   % per-pixel weight dividing |grad| in the dual step (uniform here)

Zeros = zeros(size(Im0));
pxu = Zeros;
pyu = Zeros;
% Alternative initialisation: u = Zeros; u(50:150,50:150) = 1;
u = Im0;
v = u;

alpha = gamma/lambda;
mu1 = 0.7;   % u >= mu1 defines the "inside" region (0.5 also tried)
mu2 = 0.3;   % u <  mu2 defines the "outside" region (0.5 also tried)
nb_iters = 1000;

% View F as an (Ny*Nx) x nbins matrix: row r is the cumulative histogram of
% the pixel with linear index r, so a region's per-bin medians are a single
% median() call. F does not change inside the loop, so this is hoisted.
nbins = size(F,3);
Fmat = reshape(F, Ny*Nx, nbins);

for cpt = 1:nb_iters

    cpt

    % Compute u_new
    divp = backwardx(pxu,Ny,Nx) + backwardy(pyu,Ny,Nx);
    Term = divp - lambda*v;
    term1 = forwardx(Term,Ny,Nx);
    term2 = forwardy(Term,Ny,Nx);
    Norm = sqrt(term1.^2 + term2.^2);
    denom = 1. + dt*Norm./gb;
    pxu = (pxu + dt*term1)./denom;
    pyu = (pyu + dt*term2)./denom;
    u = v - divp/lambda;

    % Region descriptors: per-bin median of the local CDFs over each region.
    in  = u(:) >= mu1;
    out = u(:) <  mu2;
    % median of an empty set is NaN, and that NaN would spread through A/B
    % into every pixel of v; keep the previous descriptor instead.
    if any(in)
        F1 = median(Fmat(in,:), 1);
    end
    if any(out)
        F2 = median(Fmat(out,:), 1);
    end

    % Per-pixel L1 distance (summed over bins) to each region's CDF.
    A = sum(abs(bsxfun(@minus, F, reshape(F1,1,1,[]))), 3);
    B = sum(abs(bsxfun(@minus, F, reshape(F2,1,1,[]))), 3);
    regionTerm = A - B;

    v = u - alpha*regionTerm;

    v(v>1) = 1;
    v(v<0) = 0;

    % Every 50 iterations show u and its 0.5 level set over the input image.
    if rem(cpt,50)==0
        ct = 0.5;
        cpt_fig = 10;
        figure(cpt_fig); clf;
        imagesc(u); axis image; axis off; colormap gray; hold on;
        %figure(cpt_fig+1); clf;
        %imagesc(v);axis image;axis off; colormap gray; hold on;
        figure(cpt_fig+2); clf;
        imagesc(Im0); axis image; axis off; colormap gray; hold on;
        contour(u, [ct ct],'b','linewidth',1.5);
        pause(0.1)
    end

end