/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.intf.SocialRecommender;

/**
 * Matrix-factorisation rating model regularised by a social network.
 *
 * <p>Every iteration takes one full-batch gradient-descent step (learning rate {@code lRate})
 * on an objective made of:
 * <ul>
 *   <li>the squared error {@code (predict(u, j) - r_uj)^2} over all training ratings, plus the
 *       L2 penalties {@code regU * P_uf^2} and {@code regI * Q_jf^2}, counted once per rating
 *       entry;</li>
 *   <li>a social term {@code beta * s(u, k) * (P_uf - P_kf)^2} for every link {@code u -> k}
 *       stored in {@code socialMatrix}, where {@code s} comes from
 *       {@link #similarity(Integer, Integer)}. Links whose similarity is NaN are ignored.</li>
 * </ul>
 * The value left in {@code loss} after an iteration is half of that sum.
 *
 * <p>NOTE(review): the class name suggests the SoReg model of Ma et al. (WSDM 2011) — confirm.
 */
@Configurations(value = {
        @Configuration(key = "beta",
                requirement = ConfigurationItem.ConfigurationItemRequirementType.Required,
                type = ConfigurationItem.ConfigurationItemType.Double),
        @Configuration(key = "similarityMeasure",
                type = ConfigurationItem.ConfigurationItemType.Options,
                options = {"cos", "cos-binary", "msd", "cpc", "exjaccard", "pcc"},
                value = "pcc",
                description = "similarity measure"),
        @Configuration(key = "similarityShrinkage",
                type = ConfigurationItem.ConfigurationItemType.Integer,
                description = "similarity shrinkage")
})
public class SoReg extends SocialRecommender {

    /** Memoised user-user similarities; a fresh, empty table is created in {@link #initModel()}. */
    private Table<Integer, Integer, Double> userCorrs;

    /** Weight of the social regularisation term (presumably filled from the "beta" entry). */
    public double beta;

    /** Measure name handed to {@code correlation} (presumably filled from "similarityMeasure"). */
    public String similarityMeasure;

    /** Shrinkage handed to {@code correlation} (presumably filled from "similarityShrinkage"). */
    public int similarityShrinkage;

    /** Creates the recommender with {@code initByNorm} switched off. */
    public SoReg() {
        initByNorm = false;
    }

    /** Performs the parent initialisation, then starts from an empty similarity cache. */
    @Override
    public void initModel() throws Exception {
        super.initModel();
        userCorrs = HashBasedTable.create();
    }

    /**
     * Returns the similarity between users {@code u} and {@code v}, computing it on first use and
     * memoising it. A value cached under either key order is reused.
     *
     * <p>The result is NaN when either index is not below {@code trainMatrix.numRows()}, when
     * {@code u} has no training ratings, or when {@code correlation} yields NaN. Otherwise the raw
     * correlation {@code c} is rescaled to {@code (1 + c) / 2}.
     *
     * <p>NOTE(review): that rescaling maps [-1, 1] onto [0, 1]; for measures whose range is
     * already non-negative it squeezes values into [0.5, 1] — confirm this is intended.
     *
     * @param u first user index
     * @param v second user index
     * @return the rescaled similarity, or NaN when it cannot be computed
     */
    protected double similarity(Integer u, Integer v) {
        // Guava tables never hold null values, so a null result means "not cached yet".
        Double cached = userCorrs.get(u, v);
        if (cached == null) {
            cached = userCorrs.get(v, u);
        }
        if (cached != null) {
            return cached;
        }

        double sim = Double.NaN;
        int rows = trainMatrix.numRows();
        if (u < rows && v < rows) {
            SparseVector ratingsOfU = trainMatrix.row(u);
            if (ratingsOfU.getCount() > 0) {
                SparseVector ratingsOfV = trainMatrix.row(v);
                sim = correlation(ratingsOfU, ratingsOfV, similarityMeasure, similarityShrinkage);
                if (!Double.isNaN(sim)) {
                    sim = (1.0 + sim) / 2.0;
                }
            }
        }

        // NaN results are cached as well, so failed pairs are not recomputed.
        userCorrs.put(u, v, sim);
        return sim;
    }

    /**
     * Trains {@code P} and {@code Q} for at most {@code numIters} iterations, stopping early as
     * soon as {@code isConverged} reports convergence.
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; ++iter) {
            loss = 0.0;

            // Gradients are accumulated first; P and Q are only updated once all terms are in.
            DenseMatrix userGrad = new DenseMatrix(numUsers, numFactors);
            DenseMatrix itemGrad = new DenseMatrix(numItems, numFactors);

            accumulateRatingGradients(userGrad, itemGrad);

            for (int u = 0; u < numUsers; ++u) {
                // Out-links u -> k: gradient and loss are both accumulated here.
                accumulateSocialGradients(userGrad, u, socialMatrix.row(u), true);
                // In-links g -> u: only the gradient w.r.t. P_u is added; the matching loss term
                // is accumulated from g's side, where the link is an out-link.
                accumulateSocialGradients(userGrad, u, socialMatrix.column(u), false);
            }

            P = P.add(userGrad.scale(-lRate));
            Q = Q.add(itemGrad.scale(-lRate));
            loss *= 0.5;

            if (isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Adds the squared-error and L2-penalty gradients of every training rating to the
     * accumulators, and the corresponding terms to {@code loss}.
     */
    private void accumulateRatingGradients(DenseMatrix userGrad, DenseMatrix itemGrad)
            throws Exception {
        for (MatrixEntry entry : trainMatrix) {
            int u = entry.row();
            int j = entry.column();
            double rating = entry.get();
            double err = predict(u, j) - rating;
            loss += err * err;

            for (int f = 0; f < numFactors; ++f) {
                double puf = P.get(u, f);
                double qjf = Q.get(j, f);
                userGrad.add(u, f, err * qjf + regU * puf);
                itemGrad.add(j, f, err * puf + regI * qjf);
                loss += regU * puf * puf + regI * qjf * qjf;
            }
        }
    }

    /**
     * Pulls {@code P_u} towards the latent vector of each user indexed in {@code links}, weighted
     * by {@code beta} times their similarity. Neighbours with NaN similarity are skipped.
     *
     * @param userGrad  accumulator for the user-factor gradients
     * @param u         user whose gradient row is updated
     * @param links     row or column of {@code socialMatrix} listing u's neighbours
     * @param countLoss whether to also add {@code beta * s * diff^2} to {@code loss}
     */
    private void accumulateSocialGradients(DenseMatrix userGrad, int u, SparseVector links,
            boolean countLoss) {
        for (int w : links.getIndex()) {
            double s = similarity(u, w);
            if (Double.isNaN(s)) {
                continue;
            }
            for (int f = 0; f < numFactors; ++f) {
                double diff = P.get(u, f) - P.get(w, f);
                userGrad.add(u, f, beta * s * diff);
                if (countLoss) {
                    loss += beta * s * diff * diff;
                }
            }
        }
    }
}

