/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.intf.SocialRecommender;

/**
 * TrustSVD rating predictor.
 *
 * <p>The predicted rating of user {@code u} for item {@code j} is
 *
 * <pre>
 *   globalMean + b_u + b_j + q_j . p_u
 *     + |N(u)|^-1/2 * sum_{i in N(u)} q_j . y_i
 *     + |T(u)|^-1/2 * sum_{v in T(u)} q_j . w_v
 * </pre>
 *
 * <p>{@code N(u)} is the list of columns in row {@code u} of the training matrix. {@code T(u)} is
 * the list of columns in row {@code u} of the social matrix. Each neighbourhood term is skipped
 * when its list is empty.
 *
 * <p>Training has two parts. The first fits this prediction to every training rating. The second
 * fits {@code p_u . w_v} to every non-zero social-matrix value {@code t_uv}.
 *
 * <p>Update timing differs by parameter. Biases, {@code Q} and {@code Y} are updated immediately
 * after each rating. Gradients for {@code P} and {@code W} are accumulated over a whole iteration
 * and applied once at its end.
 */
public class TrustSVD extends SocialRecommender {
    /** Factors for users appearing as social-matrix columns (one row per user). */
    private DenseMatrix W;
    /** Implicit-feedback factors for rated items (one row per item). */
    private DenseMatrix Y;
    /** Per-item weight: 1/sqrt(trainMatrix.columnSize(j)), or 1.0 when that count is 0. */
    private DenseVector wlr_j;
    /** Per-user weight: 1/sqrt(socialMatrix.columnSize(u)), or 1.0 when that count is 0. */
    private DenseVector wlr_tc;
    /** Per-user weight: 1/sqrt(socialMatrix.rowSize(u)), or 1.0 when that count is 0. */
    private DenseVector wlr_tr;

    /**
     * Allocates and initialises the model.
     *
     * <p>This covers the biases, {@code W} and {@code Y}, the neighbourhood caches, and the
     * per-user and per-item regularization weights.
     *
     * @throws Exception if the superclass initialisation fails
     */
    @Override
    public void initModel() throws Exception {
        super.initModel();
        this.userBias = new DenseVector(this.numUsers);
        this.itemBias = new DenseVector(this.numItems);
        this.W = new DenseMatrix(this.numUsers, this.numFactors);
        this.Y = new DenseMatrix(this.numItems, this.numFactors);
        if (this.initByNorm) {
            this.userBias.init(this.initMean, this.initStd);
            this.itemBias.init(this.initMean, this.initStd);
            this.W.init(this.initMean, this.initStd);
            this.Y.init(this.initMean, this.initStd);
        } else {
            this.userBias.init();
            this.itemBias.init();
            this.W.init();
            this.Y.init();
        }
        this.wlr_tc = new DenseVector(this.numUsers);
        this.wlr_tr = new DenseVector(this.numUsers);
        this.wlr_j = new DenseVector(this.numItems);
        this.userItemsCache = this.trainMatrix.rowColumnsCache(this.guavaCacheSpec);
        this.userFriendsCache = socialMatrix.rowColumnsCache(this.guavaCacheSpec);
        for (int u = 0; u < this.numUsers; ++u) {
            this.wlr_tc.set(u, inverseSqrtOrOne(socialMatrix.columnSize(u)));
            this.wlr_tr.set(u, inverseSqrtOrOne(socialMatrix.rowSize(u)));
        }
        for (int j = 0; j < this.numItems; ++j) {
            this.wlr_j.set(j, inverseSqrtOrOne(this.trainMatrix.columnSize(j)));
        }
    }

    /**
     * Runs up to {@code numIters} training iterations, stopping early once {@code isConverged}
     * reports convergence.
     *
     * @throws Exception if a neighbourhood cache lookup fails
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            // Gradients for P and W are summed over the whole pass and applied in one batch below.
            DenseMatrix userGrad = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix trustGrad = new DenseMatrix(this.numUsers, this.numFactors);
            for (MatrixEntry me : this.trainMatrix) {
                this.updateFromRating(me.row(), me.column(), me.get(), userGrad, trustGrad);
            }
            for (MatrixEntry me : socialMatrix) {
                this.updateFromTrust(me.row(), me.column(), me.get(), userGrad, trustGrad);
            }
            this.P = this.P.add(userGrad.scale(-this.lRate));
            this.W = this.W.add(trustGrad.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Predicts the rating of user {@code u} for item {@code j}.
     *
     * @param u user index
     * @param j item index
     * @return the predicted rating; it is not clipped to any rating scale
     * @throws Exception if a neighbourhood cache lookup fails
     */
    @Override
    public double predict(int u, int j) throws Exception {
        return this.estimate(u, j, this.itemsRatedBy(u), this.socialNeighboursOf(u));
    }

    /**
     * Applies one SGD step for a single training rating.
     *
     * <p>Biases, {@code Q} and {@code Y} are updated in place. The {@code P} and {@code W} gradient
     * contributions are added to the accumulators.
     */
    private void updateFromRating(int u, int j, double ruj, DenseMatrix userGrad,
            DenseMatrix trustGrad) throws Exception {
        List<Integer> nu = this.itemsRatedBy(u);
        List<Integer> tu = this.socialNeighboursOf(u);
        double bu = this.userBias.get(u);
        double bj = this.itemBias.get(j);
        // The prediction must be computed before any parameter below is modified.
        double euj = this.estimate(u, j, nu, tu) - ruj;
        this.loss += euj * euj;

        double wNu = Math.sqrt(nu.size());
        double wTu = Math.sqrt(tu.size());
        // Guard against an empty rated-item list: the original 1.0 / wNu would be Infinity there.
        // This uses the same fallback as the wlr_* weights; it is identical whenever nu is non-empty.
        double userWeight = inverseSqrtOrOne(nu.size());
        double itemWeight = this.wlr_j.get(j);

        double sgd = euj + this.regB * userWeight * bu;
        this.userBias.add(u, -this.lRate * sgd);
        sgd = euj + this.regB * itemWeight * bj;
        this.itemBias.add(j, -this.lRate * sgd);
        this.loss += this.regB * userWeight * bu * bu;
        this.loss += this.regB * itemWeight * bj * bj;

        // Snapshot the normalised neighbourhood sums before Y is modified in the loop below.
        double[] sumYs = this.normalisedColumnSums(this.Y, nu, wNu);
        double[] sumTs = this.normalisedColumnSums(this.W, tu, wTu);

        // NOTE(review): regularization is added once per observed entry. The effective penalty on
        // a parameter therefore grows with how often it appears in the data.
        for (int f = 0; f < this.numFactors; ++f) {
            double puf = this.P.get(u, f);
            double qjf = this.Q.get(j, f);
            double deltaU = euj * qjf + this.regU * userWeight * puf;
            double deltaJ = euj * (puf + sumYs[f] + sumTs[f]) + this.regI * itemWeight * qjf;
            userGrad.add(u, f, deltaU);
            this.Q.add(j, f, -this.lRate * deltaJ);
            this.loss += this.regU * userWeight * puf * puf + this.regI * itemWeight * qjf * qjf;

            for (int i : nu) {
                double yif = this.Y.get(i, f);
                double regYi = this.wlr_j.get(i);
                double deltaY = euj * qjf / wNu + this.regI * regYi * yif;
                this.Y.add(i, f, -this.lRate * deltaY);
                this.loss += this.regI * regYi * yif * yif;
            }

            for (int v : tu) {
                double wvf = this.W.get(v, f);
                double regV = this.wlr_tc.get(v);
                double deltaT = euj * qjf / wTu + this.regU * regV * wvf;
                trustGrad.add(v, f, deltaT);
                this.loss += this.regU * regV * wvf * wvf;
            }
        }
    }

    /**
     * Accumulates the gradient of the social-fit loss for one social-matrix entry.
     *
     * <p>Entries whose value is exactly 0 are ignored.
     */
    private void updateFromTrust(int u, int v, double tuv, DenseMatrix userGrad,
            DenseMatrix trustGrad) {
        if (tuv == 0.0) {
            return;
        }
        double eut = DenseMatrix.rowMult(this.P, u, this.W, v) - tuv;
        this.loss += this.regS * eut * eut;
        double csgd = this.regS * eut;
        double userWeight = this.wlr_tr.get(u);
        for (int f = 0; f < this.numFactors; ++f) {
            double puf = this.P.get(u, f);
            double wvf = this.W.get(v, f);
            userGrad.add(u, f, csgd * wvf + this.regS * userWeight * puf);
            trustGrad.add(v, f, csgd * puf);
            this.loss += this.regS * userWeight * puf * puf;
        }
    }

    /** Computes the model prediction for (u, j), given u's rated items and social neighbours. */
    private double estimate(int u, int j, List<Integer> nu, List<Integer> tu) {
        double pred = this.globalMean + this.userBias.get(u) + this.itemBias.get(j)
                + DenseMatrix.rowMult(this.P, u, this.Q, j);
        if (!nu.isEmpty()) {
            pred += this.sumOfDots(this.Y, nu, j) / Math.sqrt(nu.size());
        }
        if (!tu.isEmpty()) {
            pred += this.sumOfDots(this.W, tu, j) / Math.sqrt(tu.size());
        }
        return pred;
    }

    /** Returns the sum over {@code rows} of {@code factors[row] . Q[j]}. */
    private double sumOfDots(DenseMatrix factors, List<Integer> rows, int j) {
        double sum = 0.0;
        for (int r : rows) {
            sum += DenseMatrix.rowMult(factors, r, this.Q, j);
        }
        return sum;
    }

    /**
     * Computes a per-factor sum of the {@code factors} rows listed in {@code rows}.
     *
     * <p>Each sum is divided by {@code norm} when {@code norm} is positive, and left as is
     * otherwise.
     */
    private double[] normalisedColumnSums(DenseMatrix factors, List<Integer> rows, double norm) {
        double[] sums = new double[this.numFactors];
        for (int f = 0; f < this.numFactors; ++f) {
            double sum = 0.0;
            for (int r : rows) {
                sum += factors.get(r, f);
            }
            sums[f] = norm > 0.0 ? sum / norm : sum;
        }
        return sums;
    }

    /** Returns the training-matrix columns in row {@code u}, via {@code userItemsCache}. */
    @SuppressWarnings("unchecked") // cache values are lists of column indices
    private List<Integer> itemsRatedBy(int u) throws Exception {
        return (List<Integer>) this.userItemsCache.get(u);
    }

    /** Returns the social-matrix columns in row {@code u}, via {@code userFriendsCache}. */
    @SuppressWarnings("unchecked") // cache values are lists of column indices
    private List<Integer> socialNeighboursOf(int u) throws Exception {
        return (List<Integer>) this.userFriendsCache.get(u);
    }

    /** Returns {@code 1 / sqrt(count)} for a positive count, and {@code 1.0} otherwise. */
    private static double inverseSqrtOrOne(int count) {
        return count > 0 ? 1.0 / Math.sqrt(count) : 1.0;
    }
}

