/*
 * Decompiled with CFR 0.152.
 */
package librec.baseline;

import java.math.BigDecimal;
import java.math.RoundingMode;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Randoms;

/**
 * Item-cluster baseline: a mixture model in which each item is softly assigned to one of
 * {@code numFactors} latent clusters, and each cluster owns a categorical distribution over the
 * {@code numLevels} values of {@code ratingScale}. Parameters are fitted with EM.
 *
 * <p>Model state:
 * <ul>
 *   <li>{@code Pkr} ({@code numFactors x numLevels}): probability of rating level r in cluster k.</li>
 *   <li>{@code Pi} ({@code numFactors}): prior weight of cluster k.</li>
 *   <li>{@code Gamma} ({@code numItems x numFactors}): posterior membership of item i in
 *       cluster k, rounded half-up to 6 decimal places.</li>
 *   <li>{@code Nir} ({@code numItems x numLevels}): number of item i's training ratings whose
 *       value is level r of {@code ratingScale}.</li>
 *   <li>{@code Ni} ({@code numItems}): {@code size()} of item i's training column (presumably
 *       its rating count - TODO confirm against {@code SparseVector}).</li>
 * </ul>
 */
public class ItemCluster extends GraphicRecommender {
    private DenseMatrix Pkr;
    private DenseVector Pi;
    private DenseMatrix Gamma;
    private DenseMatrix Nir;
    private DenseVector Ni;

    /**
     * Randomly initialises the per-cluster rating distributions {@code Pkr} and the cluster
     * priors {@code Pi}, allocates {@code Gamma}, and tallies each item's training ratings per
     * rating level ({@code Nir}) and in total ({@code Ni}).
     */
    @Override
    public void initModel() throws Exception {
        // Pkr is drawn before Pi so the random stream is consumed in the original order.
        this.Pkr = new DenseMatrix(this.numFactors, this.numLevels);
        for (int k = 0; k < this.numFactors; ++k) {
            double[] probs = Randoms.randProbs(this.numLevels);
            for (int r = 0; r < this.numLevels; ++r) {
                this.Pkr.set(k, r, probs[r]);
            }
        }
        this.Pi = new DenseVector(Randoms.randProbs(this.numFactors));
        this.Gamma = new DenseMatrix(this.numItems, this.numFactors);
        this.Nir = new DenseMatrix(this.numItems, this.numLevels);
        this.Ni = new DenseVector(this.numItems);
        for (int i = 0; i < this.numItems; ++i) {
            SparseVector ri = this.trainMatrix.column(i);
            for (VectorEntry ve : ri) {
                double rui = ve.get();
                int r = ratingScale.indexOf(rui);
                this.Nir.add(i, r, 1.0);
            }
            this.Ni.set(i, ri.size());
        }
    }

    /**
     * Runs up to {@code numIters} EM iterations. Each iteration recomputes the memberships
     * {@code Gamma} (E-step), re-estimates {@code Pkr} and {@code Pi} (M-step), sets
     * {@code loss} to the negated expected complete-data log-likelihood, and stops early when
     * {@link #isConverged(int)} returns true.
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.updateMemberships();
            this.updateParameters();
            // Negated so that lower is better, which is what isConverged checks for.
            this.loss = -this.expectedLogLikelihood();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * E-step: sets {@code Gamma(i, k)} proportional to
     * {@code Pi(k) * prod_r Pkr(k, r)^Nir(i, r)}, normalised over k and rounded half-up to
     * 6 decimal places.
     *
     * <p>The product is evaluated in log space with a log-sum-exp shift. This avoids underflow
     * for items with many ratings without building exact BigDecimal products, whose size (and
     * multiplication cost) grows with every factor. {@code Nir} holds the same per-level counts
     * that iterating the item's training column would produce (see {@link #initModel()}).
     *
     * @throws IllegalStateException if some item has zero probability under every cluster
     */
    private void updateMemberships() {
        double[][] logPkr = this.logPkrTable();
        double[] w = new double[this.numFactors];
        for (int i = 0; i < this.numItems; ++i) {
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < this.numFactors; ++k) {
                w[k] = Math.log(this.Pi.get(k)) + this.ratingLogLikelihood(i, logPkr[k]);
                if (w[k] > max) {
                    max = w[k];
                }
            }
            if (this.numFactors > 0 && max == Double.NEGATIVE_INFINITY) {
                throw new IllegalStateException(
                        "item " + i + " has zero probability under every cluster");
            }
            double norm = 0.0;
            for (int k = 0; k < this.numFactors; ++k) {
                // Shift by the max so the largest weight is exactly 1 and nothing overflows.
                w[k] = Math.exp(w[k] - max);
                norm += w[k];
            }
            for (int k = 0; k < this.numFactors; ++k) {
                double zik = new BigDecimal(w[k] / norm)
                        .setScale(6, RoundingMode.HALF_UP)
                        .doubleValue();
                this.Gamma.set(i, k, zik);
            }
        }
    }

    /**
     * M-step: re-estimates {@code Pkr(k, r)} as the membership-weighted fraction of ratings at
     * level r, and {@code Pi(k)} as cluster k's share of the total membership mass.
     *
     * <p>If cluster k holds no rated item, its {@code Pkr} row is left unchanged instead of
     * becoming 0/0 = NaN, which would otherwise poison {@code Gamma} and the loss. The same
     * applies to {@code Pi} when the total mass is 0.
     */
    private void updateParameters() {
        double[] clusterMass = new double[this.numFactors];
        double totalMass = 0.0;
        for (int k = 0; k < this.numFactors; ++k) {
            // The denominator does not depend on r, so compute it once per cluster.
            double denominator = 0.0;
            for (int i = 0; i < this.numItems; ++i) {
                denominator += this.Gamma.get(i, k) * this.Ni.get(i);
            }
            if (denominator > 0.0) {
                for (int r = 0; r < this.numLevels; ++r) {
                    double numerator = 0.0;
                    for (int i = 0; i < this.numItems; ++i) {
                        numerator += this.Gamma.get(i, k) * this.Nir.get(i, r);
                    }
                    this.Pkr.set(k, r, numerator / denominator);
                }
            }
            double mass = 0.0;
            for (int i = 0; i < this.numItems; ++i) {
                mass += this.Gamma.get(i, k);
            }
            clusterMass[k] = mass;
            totalMass += mass;
        }
        if (totalMass > 0.0) {
            for (int k = 0; k < this.numFactors; ++k) {
                this.Pi.set(k, clusterMass[k] / totalMass);
            }
        }
    }

    /**
     * Returns {@code sum_i sum_k Gamma(i,k) * (log Pi(k) + sum_r Nir(i,r) * log Pkr(k,r))}.
     *
     * <p>Terms with {@code Gamma(i,k) == 0} or {@code Nir(i,r) == 0} are skipped, i.e.
     * {@code 0 * log 0} is taken as 0. Evaluated literally they give {@code 0 * -Infinity = NaN},
     * which would make {@link #isConverged(int)} stop training as soon as any cluster or rating
     * level reached zero probability. All other terms are identical to the literal formula.
     */
    private double expectedLogLikelihood() {
        double[][] logPkr = this.logPkrTable();
        double ll = 0.0;
        for (int i = 0; i < this.numItems; ++i) {
            for (int k = 0; k < this.numFactors; ++k) {
                double gik = this.Gamma.get(i, k);
                if (gik > 0.0) {
                    ll += gik * (Math.log(this.Pi.get(k)) + this.ratingLogLikelihood(i, logPkr[k]));
                }
            }
        }
        return ll;
    }

    /**
     * Returns {@code sum_r Nir(i, r) * logPk[r]}, skipping levels that item i never received so
     * that a {@code -Infinity} log-probability at an unused level does not produce NaN.
     *
     * @param i     item index
     * @param logPk natural log of one cluster's rating distribution (one entry per level)
     */
    private double ratingLogLikelihood(int i, double[] logPk) {
        double sum = 0.0;
        for (int r = 0; r < this.numLevels; ++r) {
            double nir = this.Nir.get(i, r);
            if (nir > 0.0) {
                sum += nir * logPk[r];
            }
        }
        return sum;
    }

    /** Returns the element-wise natural log of {@code Pkr}; zero entries map to -Infinity. */
    private double[][] logPkrTable() {
        double[][] table = new double[this.numFactors][this.numLevels];
        for (int k = 0; k < this.numFactors; ++k) {
            for (int r = 0; r < this.numLevels; ++r) {
                table[k][r] = Math.log(this.Pkr.get(k, r));
            }
        }
        return table;
    }

    /**
     * From the second iteration on, reports convergence when the loss increased or the change
     * is NaN. Otherwise it records the current loss as {@code lastLoss} and returns false.
     * The change is compared at float precision, as in the original.
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        float deltaLoss = (float) (this.loss - this.lastLoss);
        if (iter > 1 && (deltaLoss > 0.0f || Double.isNaN(deltaLoss))) {
            return true;
        }
        this.lastLoss = this.loss;
        return false;
    }

    /**
     * Predicts item j's rating as {@code sum_k Gamma(j,k) * sum_r ratingScale[r] * Pkr(k,r)}:
     * each cluster's expected rating, weighted by j's cluster memberships.
     *
     * <p>The prediction depends only on the item; {@code u} is not used. {@code bound} is also
     * not used by this override (NOTE(review): confirm callers do not rely on clamping here).
     */
    @Override
    public double predict(int u, int j, boolean bound) throws Exception {
        double pred = 0.0;
        for (int k = 0; k < this.numFactors; ++k) {
            double pj_k = this.Gamma.get(j, k);
            double pred_k = 0.0;
            for (int r = 0; r < this.numLevels; ++r) {
                double rui = (Double) ratingScale.get(r);
                double pkr = this.Pkr.get(k, r);
                pred_k += rui * pkr;
            }
            pred += pj_k * pred_k;
        }
        return pred;
    }
}

