function demo1
% Simple example code for the use of Gaussian Mixture Models (GMMs) to
% learn and reproduce movements represented as a combination of linear
% systems with a velocity dx computed iteratively as dx = sum_i h_i (A_i x
% + b_i). h_i is a weight defined by GMM which is suitably rescaled in
% order to allow motion commands which are present only in the regions of
% the task demonstrations and fade away outside these (according to the
% task constraints, that is, local variability). Thus, interaction
% capabilities with human users are achieved. A_i and b_i form a matrix and
% vector associated with motion primitive of state i of the GMM.
%
% The main components of this program are: parameter setting, learning,
% reproduction and result plotting.
%
% Authors:  Antonio Pistillo, Sylvain Calinon, 2011
%           http://programming-by-demonstration.org/
%
% This source code is given for free! However, we would be grateful if
% you refer to the following paper in any academic publication that
% uses this code or part of it:
%
% @inproceedings{Pistillo11IROS,
%  author="Pistillo, A. and Calinon, S. and Caldwell, D. G.",
%  title="Bilateral Physical Interaction with a Robot Manipulator through a
%  Weighted Combination of Flow Fields",
%  booktitle="Proc. {IEEE/RSJ} Intl Conf. on Intelligent Robots and Systems ({IROS})",
%  year="2011",
%  month="September",
%  address="San Francisco, CA, USA",
%  pages="3047-3052"
% }

clc
clear all
close all

%% Parameters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

nbStates = 18; % number of Gaussian components of the GMM to be learnt
nbData = 200; % number of data points for each sample (after resampling)
nbVar = 3; % dimensionality of each data point in the demonstrations (only pos data)
posId = 1:nbVar; velId = nbVar+1:2*nbVar; % row indexes for pos and vel data in the data matrices created below

folderTask = 'data'; % directory path for demo data files to be loaded
listSamples = [1 2 3 4]; % loads data001.txt, data002.txt, data003.txt and data004.txt
% The first half of listSamples must belong to task A and the second half
% to task B (here: 1 and 2 for task A, 3 and 4 for task B).
nbSamples = length(listSamples); % number of demo samples (in total and not for each task)
if mod(nbSamples,2) ~= 0
  % The learning step pairs demo n (task A) with demo n+nbSamples/2 (task B)
  error('demo1:unbalancedSamples', ...
        'listSamples must contain the same number of demonstrations for task A and task B.');
end
nbSamplesPerTask = nbSamples/2;

task = 1; % index of the demo whose first point seeds the reproduction (a task "A" demo)
dt = 0.0109; % sampling period associated to demo data set

%% Create dataset
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
disp('Load data...');

for n=1:nbSamples
  % Files are named dataNNN.txt (zero-padded index) and carry one header line
  fileName = fullfile(folderTask, sprintf('data%03d.txt', listSamples(n)));
  mTmp = importdata(fileName, ' ', 1);
  motion = mTmp.data(:,1:nbVar)';

  % Resampling of every demo to nbData points
  nbDataTmp = size(motion,2);
  xx = linspace(1,nbDataTmp,nbData);
  d(n).Data = spline(1:nbDataTmp, motion, xx);

  % Smooth position (moving average over 30 points)
  for j=1:nbVar
    d(n).Data(posId(j),:) = smooth(d(n).Data(posId(j),:),30);
  end

  % Compute velocity by forward differences; the last point is duplicated
  % so that the final velocity is zero
  d(n).Data(velId,:) = ([d(n).Data(posId,2:end) d(n).Data(posId,end)] - d(n).Data(posId,:)) ./ dt;
end
Data = [d.Data]; % all demos side by side: (2*nbVar) x (nbSamples*nbData)

%% Learning (the structure "m" will be filled with the learned parameters)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
disp('Learning...');

m.dt = dt;
DataTiming = [];
% Add time information to initialize EM with ordered states. Each column
% block holds one task A demo followed by one task B demo, i.e. 2*nbData
% columns, so the time row must span 2*nbData samples whatever nbSamples is.
timeRow = (1:2*nbData)*10;
for n=1:nbSamplesPerTask
  DataTiming = [DataTiming [timeRow; ...
                            [d(n).Data(posId,:) d(n+nbSamplesPerTask).Data(posId,:)]]];
end

% k-means and EM on DataTiming
[PriorsTmp, MuTmp, SigmaTmp] = EM_init_kmeans(DataTiming(1:nbVar+1,:), nbStates); % k-means to init GMM parameters
[PriorsTmp, MuTmp, SigmaTmp] = EM(DataTiming(1:nbVar+1,:), PriorsTmp, MuTmp, SigmaTmp); % EM for GMMs
m.Priors = PriorsTmp; m.Mu=MuTmp(2:nbVar+1,:); m.Sigma=SigmaTmp(2:nbVar+1,2:nbVar+1,:); % discard learned time parameters

% Learning of motion primitives parameters (A_i,b_i) to compute dx = sum_i h_i (A_i x + b_i)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Weights based on standard GMM (weight for each time stamp of demo data sets)
for i=1:nbStates
  m.H(:,i) = gaussPDF(Data(posId,:),m.Mu(posId,i),m.Sigma(posId,posId,i));
end
m.H = m.H./repmat(sum(m.H,2),1,nbStates); % see Eq. (1) in the paper

%dx=Ax+b -> dx=[Ab]*[x;1] (see Eq. (4) in the paper)
X = [Data(posId,:); ones(1,size(Data,2))];
Y = Data(velId,:);
for i=1:nbStates
  % Initialization through weighted least-squares regression
  % (constrained optimization can alternatively be used here)
  W = diag(m.H(:,i).^2); % squared weights, built once per state
  m.s(i).Ab = (pinv(X * W * X') * X * W * Y')';
end

% Precompute inverse matrices and determinants to be used in Gaussian computation
for i=1:nbStates
  m.Sigma(:,:,i) = m.Sigma(:,:,i) + 1E-3.*diag(ones(nbVar,1));  % add variance artificially to avoid numerical issues
  m.invSigma(:,:,i) = inv(m.Sigma(:,:,i));
  m.detSigma(i) = det(m.Sigma(:,:,i));
end
%Compute max density values for rescaling
for i=1:nbStates
  m.hMax(i) = gaussPDF(m.Mu(:,i),m.Mu(:,i),m.Sigma(:,:,i));
end
m.nbData = nbData;
m.nbStates = nbStates;


%% Reproduction
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
disp('Reproduction...');

% Modify variance artificially with a scalar factor (optional)
c = 1.0;
for i=1:nbStates
  m.Sigma(:,:,i) = m.Sigma(:,:,i) * c; % see Eq. (4) in the paper
  % Keep every quantity derived from Sigma in sync with the rescaled
  % covariance, otherwise they go stale as soon as c differs from 1
  m.invSigma(:,:,i) = inv(m.Sigma(:,:,i));
  m.detSigma(i) = det(m.Sigma(:,:,i));
  m.hMax(i) = gaussPDF(m.Mu(:,i),m.Mu(:,i),m.Sigma(:,:,i)); % re-compute max density values for rescaling
end

startPos = d(task).Data(posId,1) + (rand(nbVar,1)-.5) .* 1E-3; % initial position for repro with noise addition to evaluate generalization

r = reproduction(startPos,m,d,posId); % call to reproduction function (see below)


%% Plot params
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
colMat=[0 0 0; .8 0 0; 0 .8 0; .8 .5 0; 0 .5 .8; .8 0 .5; .5 0 .8]';
colMat2 = jet(nbStates)';

%% Plot 3D repro
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
figure('position',[20 20 1000 800]);
hold on; box on;

% Time steps flagged as perturbed (last row of r.Data is the perturbation flag)
pertIdx = find(r.Data(end,:));
hasPert = ~isempty(pertIdx);

% Single-point plots drawn first so that the legend entries come in the
% intended order; the full curves are drawn afterwards
plot3(d(1).Data(1,1),d(1).Data(2,1),d(1).Data(3,1),'-','linewidth',2,'color',[.6 .6 .6]);
plot3(r.Data(1,1),r.Data(2,1),r.Data(3,1),'-','linewidth',3,'color',colMat(:,2));
legend3D = {'demo','repro'};
if hasPert
  plot3(r.Data(1,pertIdx(1)),r.Data(2,pertIdx(1)),r.Data(3,pertIdx(1)),'-b','linewidth',3)
  legend3D{end+1} = 'perturb.';
end
for k=1:nbStates
  plotGMM3D(m.Mu(:,k),m.Sigma(:,:,k),colMat2(:,k),.1);
  legend3D{end+1} = num2str(k); % one legend entry per Gaussian, whatever nbStates is
end
legend(legend3D{:},'Location','EastOutside')
for k=1:nbSamples
  plot3(d(k).Data(1,:),d(k).Data(2,:),d(k).Data(3,:),'-','linewidth',2,'color',[.6 .6 .6]);
  plot3(d(k).Data(1,1),d(k).Data(2,1),d(k).Data(3,1),'.','markersize',50,'color',[.6 .6 .6]);
end
plot3(r.Data(1,:),r.Data(2,:),r.Data(3,:),'-','linewidth',3,'color',colMat(:,2));
plot3(r.Data(1,1),r.Data(2,1),r.Data(3,1),'.','markersize',50,'color',colMat(:,2));
if hasPert
  % Perturbed segment with its first and last points highlighted
  plot3(r.Data(1,pertIdx),r.Data(2,pertIdx),r.Data(3,pertIdx),'-b','linewidth',3)
  plot3(r.Data(1,pertIdx(1)),r.Data(2,pertIdx(1)),r.Data(3,pertIdx(1)),'.b','markersize',50)
  plot3(r.Data(1,pertIdx(end)),r.Data(2,pertIdx(end)),r.Data(3,pertIdx(end)),'.b','markersize',50)
end
title('3D plot of pos components: Demo - Repro - GMM')
set(gca,'fontsize',10);
xlabel('x_1'); ylabel('x_2'); zlabel('x_3');
view(0,0);

%% Repro time plot
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Plot weights
figure('position',[200 20 1000 800]);
hold on; box on;

area(r.Data(end,:),'FaceColor',[.8 .8 .8]);
legendH = cell(1,nbStates+1);
legendH{1} = 'perturb.';
for i=1:nbStates
  plot(r.H(:,i),'-','linewidth',2,'color',colMat2(:,i));
  if i == 1
    legendH{i+1} = 'repro h_1';
  else
    legendH{i+1} = sprintf('r. h_{%d}', i); % braces keep multi-digit subscripts together
  end
end
legend(legendH{:});
%for i=1:nbStates
%  plot(m.H(1:nbData,i),'-','linewidth',1,'color',[.6 .6 .6]);
%end
title('time plot of rescaledGMM components weights: Perturbation - Demo weights h - Repro Weights h')
xlabel('t'); ylabel('h_i');


% % Plot x
% figure('position',[400 20 1000 800]);
% hold on; box on;
% 
% Max = 1;
% Min = - Max;
% area(r.Data(end,:)*Max,'FaceColor',[.8 .8 .8]);
% 
% for i=1:3
%   plot(d(1).Data(i,:),'-','linewidth',1,'color',colMat(:,i));
% end
% for i=1:3
%   plot(d(nbSamples/2+1).Data(i,:),':','linewidth',1,'color',colMat(:,i));
% end
% for i=1:3
%   plot(r.Data(i,:),'-','linewidth',2,'color',colMat(:,i)*.5);
% end
% legend('perturb.','demo x1 A','demo x2 A','demo x3 A','demo x1 B','demo x2 B','demo x3 B','repro x1','repro x2','repro x3')
% area(r.Data(end,:)*Min,'FaceColor',[.8 .8 .8]);
% for i=1:3
%   plot(r.Data(i,:),'-','linewidth',2,'color',colMat(:,i)*.5);
%   for k=1:nbSamples/2
%     plot(d(k).Data(i,:),'-','linewidth',1,'color',colMat(:,i));
%   end
%   for k=1+nbSamples/2:nbSamples
%     plot(d(k).Data(i,:),':','linewidth',1,'color',colMat(:,i));
%   end
% end
% ylim([-1 1])
% title('time plot of pos. components: Perturbation - Demo A - Demo B - Repro')
% xlabel('$t$','Interpreter','latex','fontsize',14); 
% ylabel('$x$','Interpreter','latex','fontsize',14); 


%% Reproduction function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function r = reproduction(currPos,m,d,posId)
% Simulates the reproduction of task A, a three-phase perturbation and then
% a free-running phase started near the beginning of task B, integrating
% dx = sum_i h_i (A_i x + b_i) with an explicit Euler step.
%
% Inputs:
%   currPos - initial position, column vector (nbVar x 1)
%   m       - model structure; the fields read here are nbData, Sigma (only
%             its third dimension, to get the number of states), s(i).Ab
%             and dt. The whole structure is also passed to
%             rescaledGMM_step.
%   d       - structure array of demonstrations with a field Data. The
%             first half is assumed to hold task A demos and the second
%             half task B demos (convention of the calling function).
%   posId   - row indexes of the position variables in d(k).Data
%
% Output:
%   r.Data  - (2*nbVar+1) x nbSteps matrix; each column is
%             [position; velocity; perturbation flag]
%   r.H     - nbSteps x nbStates matrix of the weights returned by
%             rescaledGMM_step at each step

nbDataRepro = m.nbData;
nbVar = size(currPos,1);
nbStates = size(m.Sigma,3);

%%%%% Definition of perturbation parameters. The perturbation phase has three sub-phases :
%%%%% 1) After end of task A, the robot is moved towards StopPos (linear interpolation over nbPertSteps time instants)
%%%%% 2) Once StopPos is reached, the robot remains in StopPos for nbPertSteps time instants
%%%%% 3) After end of stopping perturbation, the robot is moved towards starting position of task B (linear interpolation over nbPertSteps time instants)

nbPertSteps = 50; % duration of each perturbation sub-phase (in time instants)
nbSteps = 2*nbDataRepro + 3*nbPertSteps; % task A + three perturbation sub-phases + task B

%%% Perturbation position parameters
nbDemos = numel(d);
% Placeholder only: LastPosA is overwritten with the simulated position at n == TimeEndA
LastPosA = d(max(1,floor(nbDemos/2))).Data(posId,end); % Final position of the last task A demo
FirstPosB = d(nbDemos).Data(posId,1); % Initial position of the last task B demo = location where perturbation ends
StopPos = [-0.1 currPos(2) 0.1]'; % Location where the perturbation becomes a pure stopping perturbation
%%% Perturbation timing parameters
TimeEndA = nbDataRepro; % Time instant corresponding to the final position of task A = when perturbation starts
TimeInitStopPos = TimeEndA + nbPertSteps; % Time instant corresponding to the beginning of the pure stopping perturbation
TimeEndStopPos = TimeInitStopPos + nbPertSteps; % Time instant corresponding to the end of the pure stopping perturbation
TimeStartB = TimeEndStopPos + nbPertSteps; % Time instant corresponding to the initial position of task B = when perturbation ends

% Preallocate everything for the full length of the simulation
pert = zeros(nbSteps,1); % variable indicating if perturbations are occurring (1) or are absent (0)
velTmp = zeros(nbVar,nbStates);
r.Data = zeros(2*nbVar+1,nbSteps);
r.H = zeros(nbSteps,nbStates);

%Repro loop
for n=1:nbSteps

  %Compute velocity command of each motion primitive
  for i=1:nbStates
    velTmp(:,i) = m.s(i).Ab * [currPos;1];
  end

  %Apply stopping perturbation (this can be commented if perturbations are not needed)
  if n == TimeEndA  % the perturbation starts from wherever the simulation actually is
    LastPosA = currPos;
  end
  if n >= TimeEndA && n < TimeInitStopPos % Apply phase 1 of perturbation
    velTmp(:) = 0;
    currPos = LastPosA + ( n - TimeEndA ) * ( StopPos - LastPosA ) / nbPertSteps;
    pert(n) = 1;
  elseif n >= TimeInitStopPos && n < TimeEndStopPos % Apply phase 2 of perturbation
    velTmp(:) = 0;
    % Phase 1 ends one interpolation step short of StopPos; pin the position
    % here so the robot really rests at StopPos and phase 3 starts without a jump
    currPos = StopPos;
    pert(n) = 1;
  elseif n >= TimeEndStopPos && n < TimeStartB % Apply phase 3 of perturbation
    velTmp(:) = 0;
    currPos = StopPos + ( n - TimeEndStopPos ) * ( FirstPosB - StopPos ) / nbPertSteps;
    pert(n) = 1;
  end

  % rescaled GMM weights at each repro step
  h = rescaledGMM_step(currPos,m);

  %Update velocity and position (explicit Euler integration)
  currVel = velTmp * h;
  currPos = currPos + currVel .* m.dt;

  %Keep a trace of data
  r.Data(:,n) = [currPos; currVel; pert(n)];
  r.H(n,:) = h;

end