function classhypo = KNN(features_train,features_test,labels_train,k,weighted)
% function classhypo = KNN(features_train,features_test,labels_train,k,weighted)
%
% Performs KNN classification to a set of data points using a given
% training data and the associated class labels.
%
% Inputs:
%   features_train: N x F matrix containing the training data,
%                   where N is the number of samples
%                   and F is the number of features.
%   features_test:  M x F matrix containing the classification data,
%                   where M is the number of samples
%                   and F is the number of features.
%   labels_train:   N x 1 vector of integer class labels for training
%                   samples (a 1 x N row vector is accepted as well).
%                   Classes are assumed to be numbered 1..max(labels_train).
%   k:              number of nearest neighbors used in the
%                   classification (can be a vector). Every value must be
%                   an integer in the range 1..N.
%   weighted:       weight nearest neighbor counts by the inverse of
%                   class frequencies estimated from the training
%                   data? (0/1, default: 1). Any non-zero value enables
%                   the weighting.
%
% Output:
%   classhypo:      M x length(k) matrix; column i holds the predicted
%                   class of every test sample when k(i) neighbors are
%                   used. Ties are resolved in favor of the smallest
%                   class index (first maximum returned by max).
%
% Authors: Okko Rasanen, Jouni Pohjalainen, June 2014

persistent pdist2_exists;

% Force a column vector so that indexing with a matrix/vector of neighbor
% indices never depends on the orientation of the label vector.
labels_train = labels_train(:);

N_train = size(features_train,1);
N_test = size(features_test,1);

if(N_train ~= length(labels_train))
    error('Different number of training samples and class labels');
end

if(N_test > 0 && size(features_train,2) ~= size(features_test,2))
    error('Different number of features in training and test data');
end

if ~exist('weighted','var') || isempty(weighted)
    weighted = 1;
end
% Normalize the flag once so that the class-count computation and the
% voting loop below always agree on whether weighting is enabled.
weighted = (weighted(1) ~= 0);

k = k(:)';
if any(k < 1) || any(k ~= round(k)) || any(k > N_train)
    error('k must contain integers between 1 and the number of training samples');
end

classhypo = zeros(N_test,length(k));
if N_test == 0
    return;
end

% Compute the distribution of samples per each class in the training data
N_classes = max(labels_train);
w = ones(N_classes,1);
if weighted
    for c = 1:N_classes
        w(c) = sum(labels_train == c);
    end
    % Classes without any training samples can never receive votes; use a
    % unit weight for them to avoid 0/0 = NaN entries in the vote matrix.
    w(w == 0) = 1;
end

% Check if pdist2 function exists for distance computations (faster but
% absent from older MATLAB versions).
if isempty(pdist2_exists)
    pdist2_exists = exist('pdist2','file');
end

% Compute Euclidean distances between test (rows) and train (columns) vectors
if pdist2_exists
    D = pdist2(features_test,features_train,'euclidean');
else
    D = zeros(N_test,N_train);
    for j = 1:N_test
        % bsxfun broadcasts the test vector over all training rows without
        % materializing an N x F copy of it.
        D(j,:) = sqrt(sum(bsxfun(@minus,features_train,features_test(j,:)).^2,2))';
    end
end

[tmp,orderi] = sort(D,2,'ascend');   %#ok<ASGLU> % Sort distances into ascending order

for i1 = 1:length(k)
    kk = k(i1);

    % Get classes of k nearest neighbors for each test sample. The reshape
    % guarantees an M x kk matrix even when M == 1 or kk == 1, where plain
    % vector indexing would otherwise return the orientation of
    % labels_train instead of that of the index.
    nearest = reshape(labels_train(orderi(:,1:kk)),N_test,kk);

    % Get class hypothesis for each sample by (weighted) majority voting
    votes = zeros(N_test,N_classes);
    for c = 1:N_classes
        votes(:,c) = sum(nearest == c,2)/w(c);
    end

    [tmp,winner] = max(votes,[],2);  %#ok<ASGLU>
    classhypo(:,i1) = winner;
end