/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import librec.data.DenseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.SocialRecommender;

@Configuration(key="alpha", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double)
public class RSTE
extends SocialRecommender {
    public double alpha;

    public RSTE() {
        this.initByNorm = false;
    }

    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            double sum;
            double pred1;
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix QS = new DenseMatrix(this.numItems, this.numFactors);
            for (int u : this.trainMatrix.rows()) {
                SparseVector tu = socialMatrix.row(u);
                int[] tks = tu.getIndex();
                double ws = 0.0;
                for (int k : tks) {
                    ws += tu.get(k);
                }
                double[] sum_us = new double[this.numFactors];
                for (int f = 0; f < this.numFactors; ++f) {
                    int k;
                    int[] nArray = tks;
                    k = nArray.length;
                    for (int i = 0; i < k; ++i) {
                        int k2 = nArray[i];
                        int n = f;
                        sum_us[n] = sum_us[n] + tu.get(k2) * this.P.get(k2, f);
                    }
                }
                for (VectorEntry ve : this.trainMatrix.row(u)) {
                    int j = ve.index();
                    double rate = ve.get();
                    double ruj = this.normalize(rate);
                    pred1 = DenseMatrix.rowMult(this.P, u, this.Q, j);
                    sum = 0.0;
                    for (int k : tks) {
                        sum += tu.get(k) * DenseMatrix.rowMult(this.P, k, this.Q, j);
                    }
                    double pred2 = ws > 0.0 ? sum / ws : 0.0;
                    double pred = this.alpha * pred1 + (1.0 - this.alpha) * pred2;
                    double euj = this.g(pred) - ruj;
                    this.loss += euj * euj;
                    double csgd = this.gd(pred) * euj;
                    for (int f = 0; f < this.numFactors; ++f) {
                        double puf = this.P.get(u, f);
                        double qjf = this.Q.get(j, f);
                        double usgd = this.alpha * csgd * qjf + this.regU * puf;
                        double jd = ws > 0.0 ? sum_us[f] / ws : 0.0;
                        double jsgd = csgd * (this.alpha * puf + (1.0 - this.alpha) * jd) + this.regI * qjf;
                        PS.add(u, f, usgd);
                        QS.add(j, f, jsgd);
                        this.loss += this.regU * puf * puf + this.regI * qjf * qjf;
                    }
                }
            }
            for (int u : socialMatrix.columns()) {
                SparseVector bu = socialMatrix.column(u);
                for (int p : bu.getIndex()) {
                    if (p >= this.trainMatrix.numRows()) continue;
                    SparseVector pp = this.trainMatrix.row(p);
                    SparseVector tp = socialMatrix.row(p);
                    int[] tps = tp.getIndex();
                    for (int j : pp.getIndex()) {
                        pred1 = DenseMatrix.rowMult(this.P, p, this.Q, j);
                        sum = 0.0;
                        double ws = 0.0;
                        for (int k : tps) {
                            double tuk = tp.get(k);
                            sum += tuk * DenseMatrix.rowMult(this.P, k, this.Q, j);
                            ws += tuk;
                        }
                        double pred2 = ws > 0.0 ? sum / ws : 0.0;
                        double pred = this.alpha * pred1 + (1.0 - this.alpha) * pred2;
                        double epj = this.g(pred) - this.normalize(pp.get(j));
                        double csgd = this.gd(pred) * epj * bu.get(p);
                        for (int f = 0; f < this.numFactors; ++f) {
                            PS.add(u, f, (1.0 - this.alpha) * csgd * this.Q.get(j, f));
                        }
                    }
                }
            }
            this.loss *= 0.5;
            this.P = this.P.add(PS.scale(-this.lRate));
            this.Q = this.Q.add(QS.scale(-this.lRate));
            if (this.isConverged(iter)) break;
        }
    }

    @Override
    public double predict(int u, int j, boolean bound) {
        double pred1 = DenseMatrix.rowMult(this.P, u, this.Q, j);
        double sum = 0.0;
        double ws = 0.0;
        SparseVector tu = socialMatrix.row(u);
        for (int k : tu.getIndex()) {
            double tuk = tu.get(k);
            sum += tuk * DenseMatrix.rowMult(this.P, k, this.Q, j);
            ws += tuk;
        }
        double pred2 = ws > 0.0 ? sum / ws : 0.0;
        double pred = this.alpha * pred1 + (1.0 - this.alpha) * pred2;
        if (bound) {
            return this.denormalize(this.g(pred));
        }
        return pred;
    }
}

