/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.intf.SocialRecommender;

/**
 * SocialMF: matrix factorization whose user factors are regularized towards the trust-weighted
 * average of the factors of the users they trust (the rows of {@code socialMatrix}).
 *
 * <p>Training is full-batch gradient descent: all gradients of one iteration are accumulated
 * against the current {@code P}/{@code Q} and then applied in a single step.
 */
public class SocialMF
extends SocialRecommender {
    /**
     * Creates the recommender with {@code initByNorm} disabled (presumably selecting the base
     * class's non-Gaussian factor initialization; confirm in {@code SocialRecommender}).
     */
    public SocialMF() {
        this.initByNorm = false;
    }

    /**
     * Trains user factors {@code P} and item factors {@code Q}.
     *
     * <p>Each iteration accumulates into {@code PS}/{@code QS} the gradient of
     * <pre>
     * 0.5 * sum_(u,j) [ (g(P_u.Q_j) - normalize(r_uj))^2 + regU*|P_u|^2 + regI*|Q_j|^2 ]
     *   + 0.5 * regS * sum_v |P_v - (1/|N_v|) * sum_(w in N_v) T_vw * P_w|^2
     * </pre>
     * (the L2 penalties are counted once per observed rating), records that objective in
     * {@code loss}, and then steps by {@code -lRate} times the gradient. Training stops after
     * {@code numIters} iterations or as soon as {@code isConverged(iter)} returns true.
     *
     * @throws Exception propagated from the convergence check
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix QS = new DenseMatrix(this.numItems, this.numFactors);

            // Rating term: squared error of the squashed prediction plus per-rating L2 penalties.
            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int j = me.column();
                double ruj = me.get();
                double pred = this.predict(u, j, false);
                double euj = this.g(pred) - this.normalize(ruj);
                this.loss += euj * euj;
                // chain rule through g: d/dpred of 0.5*(g(pred)-r)^2
                double csgd = this.gd(pred) * euj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    PS.add(u, f, csgd * qjf + this.regU * puf);
                    QS.add(j, f, csgd * puf + this.regI * qjf);
                    this.loss += this.regU * puf * puf;
                    this.loss += this.regI * qjf * qjf;
                }
            }

            this.addSocialGradients(PS);

            // P and Q were read-only while the gradients were built, so this is a true batch step.
            this.P = this.P.add(PS.scale(-this.lRate));
            this.Q = this.Q.add(QS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Adds the gradient of the trust regularizer to {@code PS} and its (un-halved) value to
     * {@code loss}.
     *
     * <p>For every user {@code v} that trusts at least one user, the residual
     * {@code d_v = P_v - (1/|N_v|) * sum_(w in N_v) T_vw * P_w} contributes {@code regS * d_v}
     * to the gradient of {@code P_v} and {@code -regS * (T_vw / |N_v|) * d_v} to the gradient
     * of every trusted {@code P_w}. Scattering from each {@code v} means a user with no outgoing
     * trust links still receives the contributions of the users who trust them. The weight uses
     * {@code |N_v|}, the same normalizer as the residual, so the step is the exact gradient of the
     * term added to {@code loss}. Each {@code d_v} is computed once per call.
     *
     * @param PS user-factor gradient accumulator, updated in place
     */
    private void addSocialGradients(DenseMatrix PS) {
        for (int v = 0; v < this.numUsers; ++v) {
            SparseVector trusted = socialMatrix.row(v);
            int numConns = trusted.getCount();
            if (numConns == 0) {
                // No trust links: v has no regularizer term and pushes nothing onto other users.
                continue;
            }

            // d_v = P_v - trust-weighted neighbour sum / |N_v|
            double[] diff = new double[this.numFactors];
            for (int w : trusted.getIndex()) {
                double tvw = socialMatrix.get(v, w);
                for (int f = 0; f < this.numFactors; ++f) {
                    diff[f] += tvw * this.P.get(w, f);
                }
            }
            for (int f = 0; f < this.numFactors; ++f) {
                diff[f] = this.P.get(v, f) - diff[f] / numConns;
                PS.add(v, f, this.regS * diff[f]);
                this.loss += this.regS * diff[f] * diff[f];
            }

            // Each trusted w appears in d_v with coefficient -T_vw/|N_v|.
            for (int w : trusted.getIndex()) {
                double weight = this.regS * socialMatrix.get(v, w) / numConns;
                for (int f = 0; f < this.numFactors; ++f) {
                    PS.add(w, f, -weight * diff[f]);
                }
            }
        }
    }

    /**
     * Predicts the score of user {@code u} for item {@code j}.
     *
     * @param u user index (row of {@code P})
     * @param j item index (row of {@code Q})
     * @param bounded if true, pass the inner product through {@code g} and then
     *     {@code denormalize} (presumably back onto the rating scale); if false, return the raw
     *     inner product of row {@code u} of {@code P} and row {@code j} of {@code Q}
     * @return the predicted score
     */
    @Override
    public double predict(int u, int j, boolean bounded) {
        double pred = DenseMatrix.rowMult(this.P, u, this.Q, j);
        if (bounded) {
            return this.denormalize(this.g(pred));
        }
        return pred;
    }
}

