/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.intf.SocialRecommender;
import librec.util.Randoms;

/**
 * Group-preference Bayesian Personalized Ranking (GBPR), after Pan &amp; Chen, "GBPR: Group
 * Preference Based Bayesian Personalized Ranking for One-Class Collaborative Filtering" (IJCAI 2013).
 *
 * <p>Each training sample is a tuple (u, i, G, j):
 * <ul>
 *   <li>u is a user;
 *   <li>i is an item u rated;
 *   <li>G is a group of users who also rated i, always containing u;
 *   <li>j is an item u has not rated.
 * </ul>
 * The model pushes the fused group/individual score of i above u's individual score of j.
 */
@Configurations(value = {
        @Configuration(key = "rho", requirement = ConfigurationItem.ConfigurationItemRequirementType.Required, type = ConfigurationItem.ConfigurationItemType.Double),
        @Configuration(key = "gSize", requirement = ConfigurationItem.ConfigurationItemRequirementType.Required, type = ConfigurationItem.ConfigurationItemType.Integer)})
public class GBPR
extends SocialRecommender {
    /** Weight of the group preference vs. the individual preference in {@link #predict(int, int, List)}. */
    public double rho;
    /** Maximum number of users in a sampled group G. */
    public int gSize;

    /**
     * Sets the inherited flags used by this model.
     *
     * <p>It sets {@code isRankingPred = true} and {@code initByNorm = false}. The exact meaning of
     * these flags is defined in the superclass.
     */
    public GBPR() {
        this.isRankingPred = true;
        this.initByNorm = false;
    }

    /**
     * Initialises the inherited model state, then the item bias vector and the row/column caches.
     *
     * <p>The row cache is built from {@code trainMatrix} rows (user to rated items). The column
     * cache is built from its columns (item to raters).
     */
    @Override
    public void initModel() throws Exception {
        super.initModel();
        this.itemBias = new DenseVector(this.numItems);
        this.itemBias.init();
        this.userItemsCache = this.trainMatrix.rowColumnsCache(this.guavaCacheSpec);
        this.itemUsersCache = this.trainMatrix.columnRowsCache(this.guavaCacheSpec);
    }

    /**
     * Trains the model by sampling (u, i, G, j) tuples.
     *
     * <p>Each iteration draws {@code numUsers * 100} samples. Item biases are updated immediately
     * after every sample. User/item factor updates are accumulated and added to {@code P} and
     * {@code Q} once at the end of the iteration, so every sample in an iteration reads the same
     * factors.
     *
     * @throws IllegalStateException if samples are requested but no user has rated at least one,
     *     but not every, item (no valid (u, i, j) triple exists)
     */
    @Override
    public void buildModel() throws Exception {
        int samples = numUsers * 100;
        if (numIters > 0 && samples > 0) {
            checkSampleable();
        }
        for (int iter = 1; iter <= numIters; ++iter) {
            loss = 0.0;
            DenseMatrix userUpdates = new DenseMatrix(numUsers, numFactors);
            DenseMatrix itemUpdates = new DenseMatrix(numItems, numFactors);
            for (int s = 0; s < samples; ++s) {
                // u: a user with at least one rated and at least one unrated item. The second
                // condition guarantees that sampleNegativeItem terminates.
                int u;
                List<Integer> ratedItems;
                do {
                    u = Randoms.uniform(trainMatrix.numRows());
                    ratedItems = ratedItemsOf(u);
                } while (ratedItems.isEmpty() || ratedItems.size() >= numItems);

                int i = (Integer) Randoms.random(ratedItems);
                List<Integer> group = sampleGroup(u, ratersOf(i));
                double pgui = predict(u, i, group);
                int j = sampleNegativeItem(ratedItems);

                double pgij = pgui - predict(u, j);
                loss += -Math.log(g(pgij));
                // Scale factor shared by every gradient of this sample.
                double cmg = g(-pgij);

                updateItemBiases(i, j, cmg);
                accumulateFactorUpdates(u, i, j, group, cmg, userUpdates, itemUpdates);
            }
            P = P.add(userUpdates);
            Q = Q.add(itemUpdates);
            if (isConverged(iter)) {
                break;
            }
        }
    }

    /** Fails fast when no user can yield a (u, i, j) triple, instead of sampling forever. */
    private void checkSampleable() throws Exception {
        int rows = trainMatrix.numRows();
        for (int u = 0; u < rows; ++u) {
            int rated = ratedItemsOf(u).size();
            if (rated > 0 && rated < numItems) {
                return;
            }
        }
        throw new IllegalStateException(
                "GBPR: no user has rated at least one but not all of the " + numItems + " items");
    }

    /** Returns the items user {@code u} rated, read from the row cache built in initModel. */
    @SuppressWarnings("unchecked") // cache values are the index lists produced by rowColumnsCache
    private List<Integer> ratedItemsOf(int u) throws Exception {
        return (List<Integer>) userItemsCache.get(u);
    }

    /** Returns the users who rated item {@code i}, read from the column cache built in initModel. */
    @SuppressWarnings("unchecked") // cache values are the index lists produced by columnRowsCache
    private List<Integer> ratersOf(int i) throws Exception {
        return (List<Integer>) itemUsersCache.get(i);
    }

    /**
     * Draws the group G for a positive pair (u, i).
     *
     * @param u the sampled user
     * @param raters users who rated item i
     * @return a copy of {@code raters} if it has at most {@code gSize} entries. Otherwise u followed
     *     by distinct random raters until the group has {@code gSize} members (u alone if
     *     {@code gSize <= 1}).
     */
    private List<Integer> sampleGroup(int u, List<Integer> raters) {
        List<Integer> group = new ArrayList<Integer>();
        if (raters.size() <= gSize) {
            group.addAll(raters);
            return group;
        }
        group.add(u);
        while (group.size() < gSize) {
            Integer w = (Integer) Randoms.random(raters);
            if (!group.contains(w)) {
                group.add(w);
            }
        }
        return group;
    }

    /**
     * Draws uniformly an item not contained in {@code ratedItems}.
     *
     * <p>The caller must ensure at least one such item exists, otherwise this never returns.
     */
    private int sampleNegativeItem(List<Integer> ratedItems) {
        int j;
        do {
            j = Randoms.uniform(numItems);
        } while (ratedItems.contains(j));
        return j;
    }

    /** Applies one immediate gradient step to the biases of items i (positive) and j (negative). */
    private void updateItemBiases(int i, int j, double cmg) {
        double bi = itemBias.get(i);
        itemBias.add(i, lRate * (cmg - regB * bi));
        loss += regB * bi * bi;

        double bj = itemBias.get(j);
        itemBias.add(j, lRate * (-cmg - regB * bj));
        loss += regB * bj * bj;
    }

    /**
     * Adds this sample's factor gradient steps into the per-iteration accumulators.
     *
     * <p>The gradients are computed from the current {@code P}/{@code Q}. The accumulated steps are
     * added to {@code P}/{@code Q} by the caller at the end of the iteration.
     */
    private void accumulateFactorUpdates(int u, int i, int j, List<Integer> group, double cmg,
            DenseMatrix userUpdates, DenseMatrix itemUpdates) {
        double n = 1.0 / group.size();
        // Per-factor sum of the group members' user factors, needed for the gradient of Q_i.
        double[] groupSum = new double[numFactors];
        for (int w : group) {
            // Only u itself contributes the individual terms (preference for i, against j).
            double delta = (w == u) ? 1.0 : 0.0;
            for (int f = 0; f < numFactors; ++f) {
                double pwf = P.get(w, f);
                double qif = Q.get(i, f);
                double qjf = Q.get(j, f);
                double gradPwf = rho * n * qif + (1.0 - rho) * delta * qif - delta * qjf;
                userUpdates.add(w, f, lRate * (cmg * gradPwf - regU * pwf));
                loss += regU * pwf * pwf;
                groupSum[f] += pwf;
            }
        }
        for (int f = 0; f < numFactors; ++f) {
            double puf = P.get(u, f);
            double qif = Q.get(i, f);
            double qjf = Q.get(j, f);

            double gradQif = rho * n * groupSum[f] + (1.0 - rho) * puf;
            itemUpdates.add(i, f, lRate * (cmg * gradQif - regI * qif));
            loss += regI * qif * qif;

            double gradQjf = -puf;
            itemUpdates.add(j, f, lRate * (cmg * gradQjf - regI * qjf));
            loss += regI * qjf * qjf;
        }
    }

    /** Individual score of item {@code j} for user {@code u}: item bias plus P_u · Q_j. */
    @Override
    public double predict(int u, int j) {
        return this.itemBias.get(j) + DenseMatrix.rowMult(this.P, u, this.Q, j);
    }

    /**
     * Fused group/individual score of item {@code j}.
     *
     * <p>The score is {@code rho * (mean_{w in g} P_w · Q_j + b_j) + (1 - rho) * predict(u, j)}.
     *
     * @param u the user whose individual score is blended in
     * @param j the item
     * @param g the group members. If empty, the individual score is returned, because the group
     *     mean is undefined (previously 0/0 = NaN).
     * @return the blended score
     */
    protected double predict(int u, int j, List<Integer> g) {
        double ruj = this.predict(u, j);
        if (g.isEmpty()) {
            return ruj;
        }
        double sum = 0.0;
        for (int w : g) {
            sum += DenseMatrix.rowMult(this.P, w, this.Q, j);
        }
        double rgj = sum / g.size() + this.itemBias.get(j);
        return this.rho * rgj + (1.0 - this.rho) * ruj;
    }
}

