/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.intf.SocialRecommender;

/**
 * Trust-aware matrix factorization recommender that combines the rating matrix with a
 * social (trust) matrix.
 *
 * <p>The variant is chosen by {@link #model}:
 * <ul>
 *   <li>{@code "Tr"}: runs {@link #TrusterMF()}, learning {@link #Br}, {@link #Wr}, {@link #Vr};</li>
 *   <li>{@code "Te"}: runs {@link #TrusteeMF()}, learning {@link #Be}, {@link #We}, {@link #Ve};</li>
 *   <li>any other value (the configured default is {@code "T"}): trains both variants, one after
 *       the other, and predicts with the averaged factors of the two.</li>
 * </ul>
 *
 * <p>Both variants use full-batch gradient descent. In each iteration the gradients of every
 * training rating and every positive trust value are accumulated into scratch matrices. They
 * are then applied once, scaled by {@code -lRate}.
 */
@Configuration(key="model", type=ConfigurationItem.ConfigurationItemType.Options, options={"Tr", "Te", "T"}, value="T", description="trust models")
public class TrustMF
extends SocialRecommender {
    /** Truster-variant user factors (one row per user); paired with {@link #Vr} and {@link #Wr}. */
    protected DenseMatrix Br;
    /** Truster-variant factors of the trusted user (one row per user), paired with {@link #Br}. */
    protected DenseMatrix Wr;
    /** Truster-variant item factors (one row per item). */
    protected DenseMatrix Vr;
    /** Trustee-variant factors of the trusting user (one row per user), paired with {@link #We}. */
    protected DenseMatrix Be;
    /** Trustee-variant user factors (one row per user); paired with {@link #Ve} and {@link #Be}. */
    protected DenseMatrix We;
    /** Trustee-variant item factors (one row per item). */
    protected DenseMatrix Ve;
    /**
     * Selected variant: {@code "Tr"}, {@code "Te"} or anything else for the combined model.
     * Presumably populated from the {@code @Configuration} option above; it must be non-null,
     * because it is used as a {@code switch} selector.
     */
    public String model;

    /** Allocates and initializes the truster-variant matrices (contents depend on {@code DenseMatrix.init()}). */
    protected void initTr() {
        this.Vr = new DenseMatrix(this.numItems, this.numFactors);
        this.Br = new DenseMatrix(this.numUsers, this.numFactors);
        this.Wr = new DenseMatrix(this.numUsers, this.numFactors);
        this.Vr.init();
        this.Br.init();
        this.Wr.init();
    }

    /** Allocates and initializes the trustee-variant matrices (contents depend on {@code DenseMatrix.init()}). */
    protected void initTe() {
        this.Ve = new DenseMatrix(this.numItems, this.numFactors);
        this.Be = new DenseMatrix(this.numUsers, this.numFactors);
        this.We = new DenseMatrix(this.numUsers, this.numFactors);
        this.Ve.init();
        this.Be.init();
        this.We.init();
    }

    /**
     * Initializes the factor matrices needed by the selected {@link #model}. The combined
     * model needs both sets, because {@link #predict} reads all six matrices.
     */
    @Override
    public void initModel() throws Exception {
        switch (this.model) {
            case "Tr": {
                this.initTr();
                break;
            }
            case "Te": {
                this.initTe();
                break;
            }
            default: {
                this.initTr();
                this.initTe();
            }
        }
    }

    /** Trains the selected {@link #model}; the combined model trains the truster variant first. */
    @Override
    public void buildModel() throws Exception {
        switch (this.model) {
            case "Tr": {
                this.TrusterMF();
                break;
            }
            case "Te": {
                this.TrusteeMF();
                break;
            }
            default: {
                this.TrusterMF();
                this.TrusteeMF();
            }
        }
    }

    /**
     * Trains the truster variant ({@link #Br}, {@link #Wr}, {@link #Vr}) for up to
     * {@code numIters} iterations. It stops early when {@code isConverged(iter)} returns true.
     *
     * <p>Rating term: for every training entry it fits {@code g(predict(u, j))} to
     * {@code normalize(r_uj)}. Trust term: for every positive {@code t_uk} in the social matrix
     * (row = u, column = k) it fits {@code g(Br_u . Wr_k)} to {@code t_uk}, weighted by
     * {@code regS}. {@code g}/{@code gd} come from the superclass; presumably a squashing link
     * function and its derivative (TODO confirm).
     *
     * @throws Exception propagated from {@code isConverged}
     */
    protected void TrusterMF() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            // Gradient accumulators, applied once at the end of the iteration (batch GD).
            DenseMatrix BS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix WS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix VS = new DenseMatrix(this.numItems, this.numFactors);

            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double pred = this.predict(u, j, false);
                double euj = this.g(pred) - this.normalize(ruj);
                this.loss += euj * euj;
                double csgd = this.gd(pred) * euj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double buf = this.Br.get(u, f);
                    double vjf = this.Vr.get(j, f);
                    BS.add(u, f, csgd * vjf + this.regU * buf);
                    VS.add(j, f, csgd * buf + this.regI * vjf);
                    this.loss += this.regU * buf * buf;
                    this.loss += this.regI * vjf * vjf;
                }
            }

            for (MatrixEntry me : socialMatrix) {
                int u = me.row();
                int k = me.column();
                double tuk = me.get();
                if (!(tuk > 0.0)) {
                    continue; // only positive (non-NaN) trust values contribute
                }
                double pred = DenseMatrix.rowMult(this.Br, u, this.Wr, k);
                double euk = this.g(pred) - tuk;
                this.loss += this.regS * euk * euk;
                double csgd = this.gd(pred) * euk;
                for (int f = 0; f < this.numFactors; ++f) {
                    double buf = this.Br.get(u, f);
                    double wkf = this.Wr.get(k, f);
                    BS.add(u, f, this.regS * csgd * wkf + this.regU * buf);
                    WS.add(k, f, this.regS * csgd * buf + this.regU * wkf);
                    this.loss += this.regU * buf * buf;
                    // Regularize row k of Wr, the row whose gradient is accumulated above
                    // (previously row u was used by mistake, skewing the reported loss).
                    this.loss += this.regU * wkf * wkf;
                }
            }

            this.Br = this.Br.add(BS.scale(-this.lRate));
            this.Vr = this.Vr.add(VS.scale(-this.lRate));
            this.Wr = this.Wr.add(WS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Trains the trustee variant ({@link #Be}, {@link #We}, {@link #Ve}) for up to
     * {@code numIters} iterations. It stops early when {@code isConverged(iter)} returns true.
     *
     * <p>Rating term: for every training entry it fits {@code g(predict(u, j))} to
     * {@code normalize(r_uj)}. Trust term: for every positive {@code t_ku} in the social matrix
     * (row = k, column = u) it fits {@code g(Be_k . We_u)} to {@code t_ku}, weighted by
     * {@code regS}.
     *
     * @throws Exception propagated from {@code isConverged}
     */
    protected void TrusteeMF() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            // Gradient accumulators, applied once at the end of the iteration (batch GD).
            DenseMatrix BS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix WS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix VS = new DenseMatrix(this.numItems, this.numFactors);

            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double pred = this.predict(u, j, false);
                double euj = this.g(pred) - this.normalize(ruj);
                this.loss += euj * euj;
                double csgd = this.gd(pred) * euj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double wuf = this.We.get(u, f);
                    double vjf = this.Ve.get(j, f);
                    WS.add(u, f, csgd * vjf + this.regU * wuf);
                    VS.add(j, f, csgd * wuf + this.regI * vjf);
                    this.loss += this.regU * wuf * wuf;
                    this.loss += this.regI * vjf * vjf;
                }
            }

            for (MatrixEntry me : socialMatrix) {
                int k = me.row();
                int u = me.column();
                double tku = me.get();
                if (!(tku > 0.0)) {
                    continue; // only positive (non-NaN) trust values contribute
                }
                double pred = DenseMatrix.rowMult(this.Be, k, this.We, u);
                double eku = this.g(pred) - tku;
                this.loss += this.regS * eku * eku;
                double csgd = this.gd(pred) * eku;
                for (int f = 0; f < this.numFactors; ++f) {
                    double bkf = this.Be.get(k, f);
                    double wuf = this.We.get(u, f);
                    WS.add(u, f, this.regS * csgd * bkf + this.regU * wuf);
                    BS.add(k, f, this.regS * csgd * wuf + this.regU * bkf);
                    this.loss += this.regU * wuf * wuf;
                    this.loss += this.regU * bkf * bkf;
                }
            }

            this.Be = this.Be.add(BS.scale(-this.lRate));
            this.Ve = this.Ve.add(VS.scale(-this.lRate));
            this.We = this.We.add(WS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Applies a stepwise learning-rate decay. The rate is multiplied by 0.6 at iteration 10,
     * by 0.333 at iteration 30 and by 0.5 at iteration 100. Where it is invoked from is decided
     * by the superclass (presumably {@code isConverged}; TODO confirm). In the combined model
     * the decayed rate therefore carries over from {@link #TrusterMF()} into {@link #TrusteeMF()}.
     *
     * @param iter current 1-based iteration number
     */
    @Override
    protected void updateLRate(int iter) {
        if (iter == 10) {
            this.lRate *= 0.6;
        } else if (iter == 30) {
            this.lRate *= 0.333;
        } else if (iter == 100) {
            this.lRate *= 0.5;
        }
    }

    /**
     * Predicts the score of user {@code u} for item {@code j} under the selected {@link #model}.
     *
     * @param u       user index
     * @param j       item index
     * @param bounded if true, return {@code denormalize(g(raw))}; otherwise the raw inner product
     * @return the raw or bounded prediction
     */
    @Override
    public double predict(int u, int j, boolean bounded) {
        double pred;
        switch (this.model) {
            case "Tr":
                pred = DenseMatrix.rowMult(this.Br, u, this.Vr, j);
                break;
            case "Te":
                pred = DenseMatrix.rowMult(this.We, u, this.Ve, j);
                break;
            default: {
                // Combined model: average the two variants' user and item factors, then
                // take their inner product. row(x, false) presumably avoids a deep copy.
                DenseVector uv = this.Br.row(u).add(this.We.row(u, false));
                DenseVector jv = this.Vr.row(j).add(this.Ve.row(j, false));
                pred = uv.scale(0.5).inner(jv.scale(0.5));
                break;
            }
        }
        return bounded ? this.denormalize(this.g(pred)) : pred;
    }
}

