/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.SocialRecommender;
import librec.util.Randoms;

public class SBPR
extends SocialRecommender {
    private Map<Integer, List<Integer>> SP;

    public SBPR() {
        this.isRankingPred = true;
        this.initByNorm = false;
    }

    @Override
    public void initModel() throws Exception {
        super.initModel();
        this.itemBias = new DenseVector(this.numItems);
        this.itemBias.init();
        this.userItemsCache = this.trainMatrix.rowColumnsCache(this.guavaCacheSpec);
        this.SP = new HashMap<Integer, List<Integer>>();
        int um = this.trainMatrix.numRows();
        for (int u = 0; u < um; ++u) {
            List uRatedItems = (List)this.userItemsCache.get((Object)u);
            if (uRatedItems.size() == 0) continue;
            List<Integer> trustedUsers = socialMatrix.getColumns(u);
            ArrayList<Integer> items = new ArrayList<Integer>();
            for (int v : trustedUsers) {
                if (v >= um) continue;
                List vRatedItems = (List)this.userItemsCache.get((Object)v);
                Iterator iterator = vRatedItems.iterator();
                while (iterator.hasNext()) {
                    int j = (Integer)iterator.next();
                    if (uRatedItems.contains(j) || items.contains(j)) continue;
                    items.add(j);
                }
            }
            this.SP.put(u, items);
        }
    }

    @Override
    public void postModel() throws Exception {
        this.SP = null;
    }

    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            int smax = this.numUsers * 100;
            for (int s = 0; s < smax; ++s) {
                int u = 0;
                int i = 0;
                int j = 0;
                List ratedItems = null;
                while ((ratedItems = (List)this.userItemsCache.get((Object)(u = Randoms.uniform(this.trainMatrix.numRows())))).size() == 0) {
                }
                i = (Integer)Randoms.random(ratedItems);
                double xui = this.predict(u, i);
                List<Integer> SPu = this.SP.get(u);
                while (ratedItems.contains(j = Randoms.uniform(this.numItems)) || SPu.contains(j)) {
                }
                double xuj = this.predict(u, j);
                if (SPu.size() > 0) {
                    int k = Randoms.random(SPu);
                    double xuk = this.predict(u, k);
                    SparseVector Tu = socialMatrix.row(u);
                    double suk = 0.0;
                    for (VectorEntry ve : Tu) {
                        double rvk;
                        int v = ve.index();
                        if (v >= this.trainMatrix.numRows() || !((rvk = this.trainMatrix.get(v, k)) > 0.0)) continue;
                        suk += 1.0;
                    }
                    double xuik = (xui - xuk) / (1.0 + suk);
                    double xukj = xuk - xuj;
                    double vals = -Math.log(this.g(xuik)) - Math.log(this.g(xukj));
                    this.loss += vals;
                    double cik = this.g(-xuik);
                    double ckj = this.g(-xukj);
                    double bi = this.itemBias.get(i);
                    this.itemBias.add(i, this.lRate * (cik / (1.0 + suk) - this.regB * bi));
                    this.loss += this.regB * bi * bi;
                    double bk = this.itemBias.get(k);
                    this.itemBias.add(k, this.lRate * (-cik / (1.0 + suk) + ckj - this.regB * bk));
                    this.loss += this.regB * bk * bk;
                    double bj = this.itemBias.get(j);
                    this.itemBias.add(j, this.lRate * (-ckj - this.regB * bj));
                    this.loss += this.regB * bj * bj;
                    for (int f = 0; f < this.numFactors; ++f) {
                        double puf = this.P.get(u, f);
                        double qif = this.Q.get(i, f);
                        double qkf = this.Q.get(k, f);
                        double qjf = this.Q.get(j, f);
                        double delta_puf = cik * (qif - qkf) / (1.0 + suk) + ckj * (qkf - qjf);
                        this.P.add(u, f, this.lRate * (delta_puf - this.regU * puf));
                        this.Q.add(i, f, this.lRate * (cik * puf / (1.0 + suk) - this.regI * qif));
                        double delta_qkf = cik * (-puf / (1.0 + suk)) + ckj * puf;
                        this.Q.add(k, f, this.lRate * (delta_qkf - this.regI * qkf));
                        this.Q.add(j, f, this.lRate * (ckj * -puf - this.regI * qjf));
                        this.loss += this.regU * puf * puf + this.regI * qif * qif;
                        this.loss += this.regI * qkf * qkf + this.regI * qjf * qjf;
                    }
                    continue;
                }
                double xuij = xui - xuj;
                double vals = -Math.log(this.g(xuij));
                this.loss += vals;
                double cij = this.g(-xuij);
                double bi = this.itemBias.get(i);
                this.itemBias.add(i, this.lRate * (cij - this.regB * bi));
                this.loss += this.regB * bi * bi;
                double bj = this.itemBias.get(j);
                this.itemBias.add(j, this.lRate * (-cij - this.regB * bj));
                this.loss += this.regB * bj * bj;
                for (int f = 0; f < this.numFactors; ++f) {
                    double puf = this.P.get(u, f);
                    double qif = this.Q.get(i, f);
                    double qjf = this.Q.get(j, f);
                    this.P.add(u, f, this.lRate * (cij * (qif - qjf) - this.regU * puf));
                    this.Q.add(i, f, this.lRate * (cij * puf - this.regI * qif));
                    this.Q.add(j, f, this.lRate * (cij * -puf - this.regI * qjf));
                    this.loss += this.regU * puf * puf + this.regI * qif * qif + this.regI * qjf * qjf;
                }
            }
            if (this.isConverged(iter)) break;
        }
    }

    @Override
    public double predict(int u, int j) {
        return this.itemBias.get(j) + DenseMatrix.rowMult(this.P, u, this.Q, j);
    }
}

