/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.HashMap;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.MatrixEntry;
import librec.intf.SocialRecommender;

@Configurations(value={@Configuration(key="regC", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double), @Configuration(key="regZ", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double)})
public class SoRec
extends SocialRecommender {
    private DenseMatrix Z;
    public double regC;
    public double regZ;
    private Map<Integer, Integer> inDegrees;
    private Map<Integer, Integer> outDegrees;

    public SoRec() {
        this.initByNorm = false;
    }

    @Override
    public void initModel() throws Exception {
        super.initModel();
        this.Z = new DenseMatrix(this.numUsers, this.numFactors);
        this.Z.init();
        this.inDegrees = new HashMap<Integer, Integer>();
        this.outDegrees = new HashMap<Integer, Integer>();
        for (int u = 0; u < this.numUsers; ++u) {
            int in = socialMatrix.columnSize(u);
            int out = socialMatrix.rowSize(u);
            this.inDegrees.put(u, in);
            this.outDegrees.put(u, out);
        }
    }

    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            double pred;
            int u;
            this.loss = 0.0;
            DenseMatrix PS = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix QS = new DenseMatrix(this.numItems, this.numFactors);
            DenseMatrix ZS = new DenseMatrix(this.numUsers, this.numFactors);
            for (MatrixEntry me : this.trainMatrix) {
                u = me.row();
                int j = me.column();
                double ruj = me.get();
                pred = this.predict(u, j, false);
                double euj = this.g(pred) - this.normalize(ruj);
                this.loss += euj * euj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qjf = this.Q.get(j, f);
                    PS.add(u, f, this.gd(pred) * euj * qjf + this.regU * puf);
                    QS.add(j, f, this.gd(pred) * euj * puf + this.regI * qjf);
                    this.loss += this.regU * puf * puf + this.regI * qjf * qjf;
                }
            }
            for (MatrixEntry me : socialMatrix) {
                u = me.row();
                int v = me.column();
                double tuv = me.get();
                if (tuv <= 0.0) continue;
                pred = DenseMatrix.rowMult(this.P, u, this.Z, v);
                int vminus = this.inDegrees.get(v);
                int uplus = this.outDegrees.get(u);
                double weight = Math.sqrt((double)vminus / ((double)(uplus + vminus) + 0.0));
                double euv = this.g(pred) - weight * tuv;
                this.loss += this.regC * euv * euv;
                for (int f = 0; f < this.numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double zvf = this.Z.get(v, f);
                    PS.add(u, f, this.regC * this.gd(pred) * euv * zvf);
                    ZS.add(v, f, this.regC * this.gd(pred) * euv * puf + this.regZ * zvf);
                    this.loss += this.regZ * zvf * zvf;
                }
            }
            this.P = this.P.add(PS.scale(-this.lRate));
            this.Q = this.Q.add(QS.scale(-this.lRate));
            this.Z = this.Z.add(ZS.scale(-this.lRate));
            this.loss *= 0.5;
            if (this.isConverged(iter)) break;
        }
    }

    @Override
    public double predict(int u, int j, boolean bounded) {
        double pred = DenseMatrix.rowMult(this.P, u, this.Q, j);
        if (bounded) {
            return this.denormalize(this.g(pred));
        }
        return pred;
    }
}

