% Construct MSCFF model using MatConvnet DagNN

% Author: Zhiwei Li, Wuhan University, Email: lizw@whu.edu.cn

% Cite the following paper:
% Li, Z., Shen, H., Cheng, Q., Liu, Y., You, S., He, Z., 2019. Deep learning based cloud detection for medium and high
% resolution remote sensing images of different sensors. ISPRS Journal of Photogrammetry and Remote Sensing. 150, 197-212.

% MatConvnet: http://www.vlfeat.org/matconvnet/

% opts: options for model training.
% imdb: image and label database for model training.
% inputChannels: the number of channels of the input image.
% xx: divisor applied to the default filter numbers of the encoder-decoder module;
%     set xx = 0 to use 64 filters in every encoder/decoder stage.

function [net, stats] = mscff(opts, imdb, inputChannels, xx)
%MSCFF  Build the MSCFF network as a MatConvNet DagNN and train it.
%   [NET, STATS] = MSCFF(OPTS, IMDB, INPUTCHANNELS, XX)
%   opts          - training options; opts is passed to getBatch and
%                   opts.train to cnn_train_dag_step.
%   imdb          - image/label database used for training.
%   inputChannels - number of channels of the input image. A 3-channel input
%                   produces 2 output maps; any other count produces 3.
%   xx            - divisor applied to the default filter numbers
%                   [64 128 256 512 512 512 512 512 512 256 128 64];
%                   xx = 0 uses 64 filters in every stage. Defaults to 0.
%   net, stats    - network and statistics returned by cnn_train_dag_step.

if nargin < 4
    xx = 0;
end

% Number of output maps depends on the number of input channels.
if inputChannels == 3
    finalMaps = 2;
else
    finalMaps = 3;
end

% Filter numbers of the 12 encoder (1-6) / decoder (7-12) stages.
% Branch explicitly instead of dividing by zero and overwriting the Infs.
if xx == 0
    filtersNum = 64 * ones(1, 12);
else
    filtersNum = [64, 128, 256, 512, 512, 512, 512, 512, 512, 256, 128, 64] / xx;
end

%% construct MSCFF model
net = dagnn.DagNN();

%% Encoder
% Six residual stages. Stages 1-3 are followed by 2x2/stride-2 max pooling, so
% stage outputs are at 1, 1/2, 1/4, 1/8, 1/8, 1/8 of the input resolution;
% stages 5 and 6 use dilation 2 and 4 instead of further pooling.
num = 1;                 % running counter used to name variables 'x<num>'
outconv = cell(1, 6);    % output variable of each encoder stage (skip sources)
for lyn = 1 : 6
    if lyn == 1
        inChannels = inputChannels;
        inVar = 'input';
    else
        inChannels = filtersNum(lyn-1);
        inVar = ['x', num2str(num)];
    end

    if lyn <= 4
        dilation = 1;
    else
        dilation = 2*(lyn-4);
    end

    outconv{lyn} = addResidualStage(net, lyn, inVar, num, inChannels, filtersNum(lyn), dilation);

    if lyn <= 3
        block = dagnn.Pooling('poolSize', [2, 2], 'method', 'max', 'stride', 2, 'pad', 0);
        net.addLayer(['pool', num2str(lyn)], block, {outconv{lyn}}, {['x', num2str(num+11)]});
        num = num + 11;
    else
        num = num + 10;
    end
end

%% Decoder
% Six residual stages (7-12). Stages 10-12 first upsample 2x with a transposed
% convolution (re-initialised with bilinear kernels after initParams below).
% Each stage output is summed with the encoder output of the same resolution:
% stage 7 <-> encoder 6, stage 8 <-> 5, ..., stage 12 <-> 1.
sumIDX = cell(1, 6);       % sumIDX{s}: decoder sum paired with encoder stage s
fsizeSave = zeros(12, 4);  % rows 10-12: [paramIndex, kernelSize, numGroups, numClasses]
for lyn = 7 : 12
    skip = 13 - lyn;       % paired encoder stage: 6, 5, ..., 1
    if lyn >= 10
        fPix = 4; uPix = 2; cPix = 1;
        block = dagnn.ConvTranspose('size', [fPix fPix filtersNum(lyn-1) filtersNum(lyn-1)], 'upsample', uPix, 'crop', cPix, 'hasBias', false);
        net.addLayer(['deconv', num2str(lyn)], block, {['x', num2str(num)]}, {['x', num2str(num+1)]}, {['deconv', num2str(lyn), '_f']});
        initializeLearningRate(net, ['deconv', num2str(lyn)]);
        fsizeSave(lyn, :) = [net.getParamIndex(['deconv', num2str(lyn), '_f']), fPix, 1, filtersNum(lyn-1)];
        num = num + 1;
    end

    % Dilation used here: 1 for stages 7-10, 2 for stage 11, 4 for stage 12.
    % The setting given in the paper is instead 2*(9-lyn) for stages 7-8
    % and 1 for the remaining stages.
    if lyn <= 10
        dilation = 1;
    else
        dilation = 2*(lyn-10);
    end

    stageOut = addResidualStage(net, lyn, ['x', num2str(num)], num, filtersNum(lyn-1), filtersNum(lyn), dilation);
    net.addLayer(['sum', num2str(lyn)], dagnn.Sum(), {outconv{skip}, stageOut}, {['x', num2str(num+11)]});
    sumIDX{skip} = ['x', num2str(num+11)];
    num = num + 11;
end

%% Multi-scale feature fusion
% Each of the six decoder sums is mapped to finalMaps channels
% (conv + BN + ReLU) and bilinearly upsampled back to the input resolution.
Res = cell(1, 6);
for lyn = 13 : 18
    idx = lyn - 12;
    if lyn <= 16
        up = 2^(idx-1);    % 1, 2, 4, 8
    else
        up = 8;            % sumIDX{5} and sumIDX{6} are also at 1/8 resolution
    end

    block = dagnn.Conv('size', [3 3 filtersNum(idx) finalMaps], 'pad', 1, 'stride', 1, 'hasBias', true);
    net.addLayer(['conv', num2str(lyn)], block, {sumIDX{idx}}, {['x', num2str(num+1)]}, {['conv', num2str(lyn), '_f'], ['conv', num2str(lyn), '_b']});
    initializeLearningRate(net, ['conv', num2str(lyn)]);
    net.addLayer(['bn', num2str(lyn)], dagnn.BatchNorm('numChannels', finalMaps), {['x', num2str(num+1)]}, {['x', num2str(num+2)]}, {['bn', num2str(lyn), '_f'], ['bn', num2str(lyn), '_b'], ['bn', num2str(lyn), '_m']});
    initializeLearningRate(net, ['bn', num2str(lyn)]);
    net.addLayer(['relu', num2str(lyn)], dagnn.ReLU(), {['x', num2str(num+2)]}, {['x', num2str(num+3)]});

    if up ~= 1
        % Bilinear upsampling
        net.addLayer(['interp', num2str(idx)], dagnn.Interp('zoomFactor', up, 'shrinkFactor', 1), {['x', num2str(num+3)]}, {['x', num2str(num+4)]});
        num = num + 4;
    else
        num = num + 3;
    end

    Res{idx} = ['x', num2str(num)];
end

%% Fusion: concatenate the six branches along channels and predict.
lyn = 19;  % layer number following the fusion branches 13-18
net.addLayer(['cat', num2str(lyn)], dagnn.Concat('dim', 3), Res, {'ResCat'});
block = dagnn.Conv('size', [3 3 6*finalMaps finalMaps], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer(['conv', num2str(lyn)], block, {'ResCat'}, {'prediction'}, {['conv', num2str(lyn), '_f'], ['conv', num2str(lyn), '_b']});
initializeOutputLayerLearningRate(net, ['conv', num2str(lyn)]);

%% loss
% Project-specific pdist loss between 'prediction' and 'label' with p = 2 and
% noRoot (presumably squared L2 distance -- confirm in dagnn.my_Loss_pdist).
net.addLayer('pdist', dagnn.my_Loss_pdist('p', 2, 'opts', {'noRoot', true}), {'prediction', 'label'}, {'objective'});

%% parameter initialization
net.initParams();
% Replace the random initialisation of the decoder transposed convolutions
% with bilinear interpolation kernels.
for lyn = 10 : 12
    net.params(fsizeSave(lyn, 1)).value = single(bilinear_u(fsizeSave(lyn, 2), fsizeSave(lyn, 3), fsizeSave(lyn, 4)));
end

%% model training
[net, stats] = cnn_train_dag_step(net, imdb, getBatch(opts, inputChannels), opts.train);
end

function outVar = addResidualStage(net, lyn, inVar, num, inChannels, outChannels, dilation)
% Append one residual stage to NET and return the name of its output variable.
%   Adds conv<lyn>_k -> bn<lyn>_k -> relu<lyn>_k for k = 1..3 (3x3 convs,
%   stride 1, pad and dilation both equal to DILATION), then sum<lyn>_3, which
%   adds the output of unit 1 to the output of unit 3.
%   Reads INVAR and writes variables x<num+1> ... x<num+10>; returns 'x<num+10>'.
%   The first conv maps INCHANNELS to OUTCHANNELS; the others keep OUTCHANNELS.
tag = num2str(lyn);
x = @(k) ['x', num2str(num + k)];
unitIn = inVar;
unitInChannels = inChannels;
for k = 1 : 3
    s = num2str(k);
    base = 3*(k-1);
    block = dagnn.Conv('size', [3 3 unitInChannels outChannels], 'pad', dilation, 'stride', 1, 'dilate', dilation, 'hasBias', true);
    net.addLayer(['conv', tag, '_', s], block, {unitIn}, {x(base+1)}, {['conv', tag, '_', s, 'f'], ['conv', tag, '_', s, 'b']});
    initializeLearningRate(net, ['conv', tag, '_', s]);
    net.addLayer(['bn', tag, '_', s], dagnn.BatchNorm('numChannels', outChannels), {x(base+1)}, {x(base+2)}, {['bn', tag, '_', s, 'f'], ['bn', tag, '_', s, 'b'], ['bn', tag, '_', s, 'm']});
    initializeLearningRate(net, ['bn', tag, '_', s]);
    net.addLayer(['relu', tag, '_', s], dagnn.ReLU(), {x(base+2)}, {x(base+3)});
    unitIn = x(base+3);
    unitInChannels = outChannels;
end
% Residual connection: output of the first unit + output of the third unit.
outVar = x(10);
net.addLayer(['sum', tag, '_3'], dagnn.Sum(), {x(3), x(9)}, {outVar});
end

function initializeLearningRate(net, str_layer)
% Set learning-rate multipliers and weight decay on the parameters of one layer.
%   The first parameter (the '...f' filter/gain in this file) gets learning
%   rate 1 and weight decay 1. If the layer has a second parameter (the '...b'
%   bias in this file), it gets learning rate 0.1 and no weight decay.
%   Any further parameters are left untouched. NET is modified in place.
pIdx = net.layers(net.getLayerIndex(str_layer)).paramIndexes;
net.params(pIdx(1)).learningRate = 1;
net.params(pIdx(1)).weightDecay = 1;
if size(pIdx, 2) < 2
    return;
end
net.params(pIdx(2)).learningRate = 0.1;
net.params(pIdx(2)).weightDecay = 0;
end

function initializeOutputLayerLearningRate(net, str_layer)
% Set learning-rate multipliers and weight decay for the output layer.
%   First parameter:  learning rate 0.1,  weight decay 1.
%   Second parameter: learning rate 0.01, weight decay 0.
%   The layer must have at least two parameters (filters and bias).
%   NET is modified in place.
pIdx = net.layers(net.getLayerIndex(str_layer)).paramIndexes;
rates = [0.1, 0.01];
decays = [1, 0];
for p = 1 : 2
    net.params(pIdx(p)).learningRate = rates(p);
    net.params(pIdx(p)).weightDecay = decays(p);
end
end

function f = bilinear_u(k, numGroups, numClasses)
%BILINEAR_U  Create bilinear interpolation filters
%   F = BILINEAR_U(K, NUMGROUPS, NUMCLASSES) returns K-by-K bilinear kernels
%   for a transposed-convolution layer with NUMCLASSES output maps.
%   - If NUMGROUPS == NUMCLASSES, F is K x K x 1 x NUMCLASSES and every slice
%     F(:,:,1,c) holds the kernel (grouped, one input map per group).
%   - Otherwise F is K x K x max(NUMGROUPS,NUMCLASSES) x NUMCLASSES and only the
%     diagonal slices F(:,:,c,c) hold the kernel, so each output map is an
%     upsampled copy of the matching input map.

halfWidth = floor((k + 1) / 2);
isOdd = rem(k, 2) == 1;
center = halfWidth + 0.5 * ~isOdd;

% 1-D triangular profile; the 2-D kernel is its outer product.
profile = 1 - abs((1:k) - center) ./ halfWidth;
kernel = profile' * profile;

if numGroups == numClasses
    f = zeros(k, k, 1, numClasses);
    for c = 1 : numClasses
        f(:, :, 1, c) = kernel;
    end
else
    % Preallocate the size the diagonal assignments would otherwise grow to.
    f = zeros(k, k, max(numGroups, numClasses), numClasses);
    for c = 1 : numClasses
        f(:, :, c, c) = kernel;
    end
end
end

