/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.HashBasedTable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.RatingContext;
import librec.intf.GraphicRecommender;
import librec.util.Gamma;

/**
 * Item-bigram topic model used for item ranking.
 *
 * <p>Each user's training items are put into a sequence (sorted by {@link RatingContext}'s natural
 * order, built from the rating timestamps in {@code timeMatrix}). Every (user, item) rating carries
 * a latent topic {@code k}. The item {@code i} is modelled conditionally on both that topic and the
 * previous item {@code j} in the user's sequence. The first item of a sequence uses the sentinel
 * predecessor {@code j = numItems}, which is why every {@code j} dimension below has
 * {@code numItems + 1} slots.
 *
 * <p>Topic assignments are resampled with collapsed Gibbs sampling in {@link #eStep()}. The
 * Dirichlet hyper-parameters {@code alpha} (user-topic) and {@code beta} (one value per
 * topic/predecessor pair) are re-estimated with digamma-based fixed-point updates in
 * {@link #mStep()}. NOTE(review): this appears to follow Wallach's bigram topic model
 * (ICML 2006) applied to item sequences — confirm against the intended reference.
 *
 * <p>Memory: {@code Nkji}, {@code Pkji} and {@code PkjiSum} are dense
 * {@code numFactors x (numItems + 1) x numItems} arrays, so the model only scales to small item
 * catalogues.
 */
public class ItemBigram extends GraphicRecommender {
    /** User id -> that user's training items in sequence order. */
    private Map<Integer, List<Integer>> userItemsMap;
    /** Nkji[k][j][i]: number of ratings of item i assigned topic k whose predecessor item is j. */
    private int[][][] Nkji;
    /** Nkj(k, j): number of ratings assigned topic k whose predecessor item is j (sum of Nkji over i). */
    private DenseMatrix Nkj;
    /** Averaged estimate of P(i | k, j), produced by {@link #estimateParams()}. */
    private double[][][] Pkji;
    /** Running sum of per-sample P(i | k, j) estimates, accumulated by {@link #readoutParams()}. */
    private double[][][] PkjiSum;
    /** Dirichlet hyper-parameter for each (topic k, predecessor j) pair. */
    private DenseMatrix beta;

    public ItemBigram() {
        this.isRankingPred = true;
    }

    /**
     * Builds each user's ordered item sequence, allocates all count/parameter tables, and assigns
     * every training rating a uniformly random initial topic.
     *
     * @throws Exception as permitted by the base-class contract
     */
    @Override
    public void initModel() throws Exception {
        userItemsMap = new HashMap<Integer, List<Integer>>();
        for (int u = 0; u < numUsers; ++u) {
            userItemsMap.put(u, sortedItemsOf(u));
        }

        Nuk = new DenseMatrix(numUsers, numFactors);
        Nu = new DenseVector(numUsers);
        Nkji = new int[numFactors][numItems + 1][numItems];
        Nkj = new DenseMatrix(numFactors, numItems + 1);

        PukSum = new DenseMatrix(numUsers, numFactors);
        PkjiSum = new double[numFactors][numItems + 1][numItems];
        Pkji = new double[numFactors][numItems + 1][numItems];

        alpha = new DenseVector(numFactors);
        alpha.setAll(initAlpha);
        beta = new DenseMatrix(numFactors, numItems + 1);
        beta.setAll(initBeta);

        z = HashBasedTable.create();
        for (Map.Entry<Integer, List<Integer>> en : userItemsMap.entrySet()) {
            int u = en.getKey();
            List<Integer> items = en.getValue();
            for (int m = 0; m < items.size(); ++m) {
                int i = items.get(m);
                int j = predecessor(items, m);
                int k = (int) (Math.random() * numFactors);
                z.put(u, i, k);
                Nuk.add(u, k, 1.0);
                Nu.add(u, 1.0);
                Nkji[k][j][i]++;
                Nkj.add(k, j, 1.0);
            }
        }
    }

    /**
     * Returns user {@code u}'s training items ordered by the natural order of their
     * {@link RatingContext} (which carries the rating timestamp from {@code timeMatrix}).
     */
    private List<Integer> sortedItemsOf(int u) {
        List<Integer> unsortedItems = trainMatrix.getColumns(u);
        List<RatingContext> rcs = new ArrayList<RatingContext>(unsortedItems.size());
        for (Integer n : unsortedItems) {
            rcs.add(new RatingContext(u, n, (long) timeMatrix.get(u, n)));
        }
        Collections.sort(rcs);

        List<Integer> sortedItems = new ArrayList<Integer>(rcs.size());
        for (RatingContext rc : rcs) {
            sortedItems.add(rc.getItem());
        }
        return sortedItems;
    }

    /**
     * Returns the item that precedes position {@code m} of {@code items}, or the sentinel
     * {@code numItems} when {@code m == 0} (no predecessor).
     */
    private int predecessor(List<Integer> items, int m) {
        return m > 0 ? items.get(m - 1) : numItems;
    }

    /**
     * One Gibbs sweep: for every training rating, removes its current topic from the counts,
     * samples a new topic from the full conditional
     * {@code (Nuk + alpha_k)/(Nu + sum(alpha)) * (Nkji + beta_kj)/(Nkj + rowSum(beta_k))},
     * and adds the new assignment back into the counts.
     */
    @Override
    protected void eStep() {
        double sumAlpha = alpha.sum();
        // beta is not modified during the E-step, so its row sums are computed once per sweep
        // instead of once per (rating, topic) pair.
        double[] sumBeta = new double[numFactors];
        for (int t = 0; t < numFactors; ++t) {
            sumBeta[t] = beta.sumOfRow(t);
        }
        // cumulative (unnormalised) topic weights; fully overwritten for every rating
        double[] cdf = new double[numFactors];

        for (Map.Entry<Integer, List<Integer>> en : userItemsMap.entrySet()) {
            int u = en.getKey();
            List<Integer> items = en.getValue();
            for (int m = 0; m < items.size(); ++m) {
                int i = items.get(m);
                int j = predecessor(items, m);
                int k = (Integer) z.get(u, i);

                // exclude the current assignment from the counts
                Nuk.add(u, k, -1.0);
                Nu.add(u, -1.0);
                Nkji[k][j][i]--;
                Nkj.add(k, j, -1.0);

                double userDenom = Nu.get(u) + sumAlpha;
                for (int t = 0; t < numFactors; ++t) {
                    double v1 = (Nuk.get(u, t) + alpha.get(t)) / userDenom;
                    double v2 = (Nkji[t][j][i] + beta.get(t, j)) / (Nkj.get(t, j) + sumBeta[t]);
                    cdf[t] = t > 0 ? v1 * v2 + cdf[t - 1] : v1 * v2;
                }

                // Inverse-CDF sampling. The bound stops at the last topic so a zero total weight
                // (e.g. underflow) cannot index past the end of the arrays.
                double rand = Math.random() * cdf[numFactors - 1];
                k = 0;
                while (k < numFactors - 1 && !(rand < cdf[k])) {
                    ++k;
                }

                // record the new assignment
                z.put(u, i, k);
                Nuk.add(u, k, 1.0);
                Nu.add(u, 1.0);
                Nkji[k][j][i]++;
                Nkj.add(k, j, 1.0);
            }
        }
    }

    /** Re-estimates the Dirichlet hyper-parameters {@code alpha} and then {@code beta}. */
    @Override
    protected void mStep() {
        updateAlpha();
        updateBeta();
    }

    /**
     * Fixed-point update of each {@code alpha_k} from the user-topic counts. A topic whose
     * numerator is exactly 0 (no user has any rating assigned to it) keeps its current value.
     */
    private void updateAlpha() {
        double sumAlpha = alpha.sum();
        double digammaSumAlpha = Gamma.digamma(sumAlpha);

        // the denominator does not depend on k, so compute it once
        double denominator = 0.0;
        for (int u = 0; u < numUsers; ++u) {
            denominator += Gamma.digamma(Nu.get(u) + sumAlpha) - digammaSumAlpha;
        }

        for (int k = 0; k < numFactors; ++k) {
            double ak = alpha.get(k);
            double digammaAk = Gamma.digamma(ak);
            double numerator = 0.0;
            for (int u = 0; u < numUsers; ++u) {
                numerator += Gamma.digamma(Nuk.get(u, k) + ak) - digammaAk;
            }
            if (numerator != 0.0) {
                alpha.set(k, ak * (numerator / denominator));
            }
        }
    }

    /**
     * Fixed-point update of each {@code beta_kj} from the bigram counts. {@code rowSum(beta_k)} is
     * taken once per topic, before any of that topic's entries are updated. Pairs whose numerator
     * is exactly 0 (no ratings for that topic/predecessor) keep their current value.
     */
    private void updateBeta() {
        for (int k = 0; k < numFactors; ++k) {
            double bk = beta.sumOfRow(k);
            double digammaBk = Gamma.digamma(bk);
            for (int j = 0; j <= numItems; ++j) {
                double bkj = beta.get(k, j);
                double digammaBkj = Gamma.digamma(bkj);
                double numerator = 0.0;
                for (int i = 0; i < numItems; ++i) {
                    numerator += Gamma.digamma(Nkji[k][j][i] + bkj) - digammaBkj;
                }
                if (numerator == 0.0) {
                    continue;
                }
                // the per-item denominator term is independent of i and contributes once per item
                double denominator = numItems * (Gamma.digamma(Nkj.get(k, j) + bk) - digammaBk);
                beta.set(k, j, bkj * (numerator / denominator));
            }
        }
    }

    /**
     * Adds the current sample's point estimates of P(k | u) and P(i | k, j) to the running sums
     * {@code PukSum} and {@code PkjiSum}, and increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = alpha.sum();
        // PukSum is numUsers x numFactors, so every user row must be accumulated
        for (int u = 0; u < numUsers; ++u) {
            double userDenom = Nu.get(u) + sumAlpha;
            for (int k = 0; k < numFactors; ++k) {
                PukSum.add(u, k, (Nuk.get(u, k) + alpha.get(k)) / userDenom);
            }
        }

        for (int k = 0; k < numFactors; ++k) {
            double bk = beta.sumOfRow(k);
            for (int j = 0; j <= numItems; ++j) {
                double bkj = beta.get(k, j);
                double denom = Nkj.get(k, j) + bk;
                double[] sums = PkjiSum[k][j];
                for (int i = 0; i < numItems; ++i) {
                    sums[i] += (Nkji[k][j][i] + bkj) / denom;
                }
            }
        }
        ++numStats;
    }

    /** Averages the accumulated samples into {@code Puk} and {@code Pkji}. */
    @Override
    protected void estimateParams() {
        Puk = PukSum.scale(1.0 / numStats);
        for (int k = 0; k < numFactors; ++k) {
            for (int j = 0; j <= numItems; ++j) {
                double[] sums = PkjiSum[k][j];
                double[] probs = Pkji[k][j];
                for (int i = 0; i < numItems; ++i) {
                    probs[i] = sums[i] / numStats;
                }
            }
        }
    }

    /**
     * Scores item {@code i} for user {@code u} as {@code sum_k P(k|u) * P(i | k, j)}, where
     * {@code j} is the last item of the user's training sequence (or the sentinel
     * {@code numItems} if the user has no training items).
     */
    @Override
    public double ranking(int u, int i) throws Exception {
        List<Integer> items = userItemsMap.get(u);
        int j = predecessor(items, items.size());
        double rank = 0.0;
        for (int k = 0; k < numFactors; ++k) {
            rank += Puk.get(u, k) * Pkji[k][j][i];
        }
        return rank;
    }
}

