/*
 * Decompiled with CFR 0.152.
 */
package librec.baseline;

import java.math.BigDecimal;
import java.math.RoundingMode;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Randoms;

/**
 * User-clustering recommender fitted with expectation-maximization (EM).
 *
 * <p>Each user is modelled as belonging to one of {@code numFactors} latent clusters. Cluster
 * {@code k} has a prior {@code Pi[k]} and a categorical distribution {@code Pkr[k][r]} over the
 * {@code numLevels} rating levels. Every rating a user gives is treated as an independent draw
 * from that cluster's distribution. Training alternates two steps:
 *
 * <ul>
 *   <li>E-step: compute {@code Gamma[u][k]}, the posterior probability that user {@code u} belongs
 *       to cluster {@code k} given the current {@code Pi} and {@code Pkr};
 *   <li>M-step: re-estimate {@code Pkr} and {@code Pi} from those responsibilities.
 * </ul>
 *
 * <p>A user's rating is predicted as the expected rating value under the mixture of cluster
 * distributions weighted by {@code Gamma[u]}.
 */
public class UserCluster extends GraphicRecommender {
    /** Row k, column r: P(rating level r | cluster k). */
    private DenseMatrix Pkr;
    /** Entry k: prior P(cluster k). */
    private DenseVector Pi;
    /** Row u, column k: responsibility P(cluster k | ratings of user u). */
    private DenseMatrix Gamma;
    /** Row u, column r: number of training ratings of user u at rating level r. */
    private DenseMatrix Nur;
    /** Entry u: number of entries in user u's training row. */
    private DenseVector Nu;

    /**
     * Initialises the cluster priors and per-cluster level distributions from
     * {@code Randoms.randProbs}, and counts each user's training ratings per rating level.
     *
     * @throws IllegalStateException if a training rating is not one of the rating-scale levels
     */
    @Override
    public void initModel() throws Exception {
        this.Pkr = new DenseMatrix(this.numFactors, this.numLevels);
        for (int k = 0; k < this.numFactors; ++k) {
            double[] probs = Randoms.randProbs(this.numLevels);
            for (int r = 0; r < this.numLevels; ++r) {
                this.Pkr.set(k, r, probs[r]);
            }
        }
        this.Pi = new DenseVector(Randoms.randProbs(this.numFactors));
        this.Gamma = new DenseMatrix(this.numUsers, this.numFactors);
        this.Nur = new DenseMatrix(this.numUsers, this.numLevels);
        this.Nu = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector ru = this.trainMatrix.row(u);
            for (VectorEntry ve : ru) {
                this.Nur.add(u, this.levelIndex(ve.get()), 1.0);
            }
            this.Nu.set(u, ru.size());
        }
    }

    /**
     * Runs EM for at most {@code numIters} iterations, stopping early once
     * {@link #isConverged(int)} reports convergence. After each iteration {@code loss} holds the
     * negated expected complete-data log-likelihood.
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.updateResponsibilities();
            this.updateParameters();
            this.loss = -this.expectedLogLikelihood();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * E-step. Sets {@code Gamma[u][k]} proportional to
     * {@code Pi[k] * prod_r Pkr[k][r]^Nur[u][r]}, normalised over k.
     *
     * <p>The product is evaluated in log space and normalised with the log-sum-exp trick. This
     * avoids underflow for users with many ratings without resorting to arbitrary-precision
     * arithmetic, whose exact products grow by dozens of digits per rating. Responsibilities are
     * kept at full double precision.
     */
    private void updateResponsibilities() {
        double[] logPi = this.logPriors();
        double[][] logPkr = this.logLevelProbs();
        double[] weights = new double[this.numFactors];
        for (int u = 0; u < this.numUsers; ++u) {
            double maxLog = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < this.numFactors; ++k) {
                weights[k] = this.logJoint(u, k, logPi, logPkr);
                if (weights[k] > maxLog) {
                    maxLog = weights[k];
                }
            }
            if (maxLog == Double.NEGATIVE_INFINITY) {
                // Every cluster assigns this user's ratings zero probability: fall back to a
                // uniform assignment instead of dividing zero by zero.
                for (int k = 0; k < this.numFactors; ++k) {
                    this.Gamma.set(u, k, 1.0 / this.numFactors);
                }
                continue;
            }
            // Shift by the maximum so the largest weight is exp(0) = 1 and norm >= 1.
            double norm = 0.0;
            for (int k = 0; k < this.numFactors; ++k) {
                weights[k] = Math.exp(weights[k] - maxLog);
                norm += weights[k];
            }
            for (int k = 0; k < this.numFactors; ++k) {
                this.Gamma.set(u, k, weights[k] / norm);
            }
        }
    }

    /**
     * M-step. Re-estimates the model parameters from the current responsibilities:
     * {@code Pkr[k][r] = sum_u Gamma[u][k] * Nur[u][r] / sum_u Gamma[u][k] * Nu[u]} and
     * {@code Pi[k] = sum_u Gamma[u][k] / sum_{u,k} Gamma[u][k]}.
     *
     * <p>A cluster whose weighted rating count is zero keeps its previous {@code Pkr} row rather
     * than becoming 0 / 0 = NaN, which would otherwise poison {@link #predict}.
     */
    private void updateParameters() {
        double[] clusterMass = new double[this.numFactors];
        double totalMass = 0.0;
        for (int k = 0; k < this.numFactors; ++k) {
            double[] levelMass = new double[this.numLevels];
            double ratingMass = 0.0;
            double mass = 0.0;
            for (int u = 0; u < this.numUsers; ++u) {
                double g = this.Gamma.get(u, k);
                mass += g;
                ratingMass += g * this.Nu.get(u);
                for (int r = 0; r < this.numLevels; ++r) {
                    levelMass[r] += g * this.Nur.get(u, r);
                }
            }
            if (ratingMass > 0.0) {
                for (int r = 0; r < this.numLevels; ++r) {
                    this.Pkr.set(k, r, levelMass[r] / ratingMass);
                }
            }
            clusterMass[k] = mass;
            totalMass += mass;
        }
        if (totalMass > 0.0) {
            for (int k = 0; k < this.numFactors; ++k) {
                this.Pi.set(k, clusterMass[k] / totalMass);
            }
        }
    }

    /**
     * Returns {@code sum_{u,k} Gamma[u][k] * (log Pi[k] + sum_r Nur[u][r] * log Pkr[k][r])}.
     *
     * <p>Terms with zero responsibility or zero count are skipped, treating 0 * log 0 as 0. As a
     * result, an unused rating level or an empty cluster does not turn the value into NaN (which
     * {@link #isConverged(int)} would misread as convergence).
     */
    private double expectedLogLikelihood() {
        double[] logPi = this.logPriors();
        double[][] logPkr = this.logLevelProbs();
        double total = 0.0;
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k = 0; k < this.numFactors; ++k) {
                double g = this.Gamma.get(u, k);
                if (g > 0.0) {
                    total += g * this.logJoint(u, k, logPi, logPkr);
                }
            }
        }
        return total;
    }

    /**
     * Returns {@code log Pi[k] + sum_r Nur[u][r] * log Pkr[k][r]}: the log joint probability of
     * cluster k and user u's training ratings. Levels the user never used contribute nothing,
     * which avoids evaluating 0 * log(0) = NaN.
     */
    private double logJoint(int u, int k, double[] logPi, double[][] logPkr) {
        double sum = logPi[k];
        for (int r = 0; r < this.numLevels; ++r) {
            double count = this.Nur.get(u, r);
            if (count > 0.0) {
                sum += count * logPkr[k][r];
            }
        }
        return sum;
    }

    /** Returns the element-wise natural logarithm of {@code Pi}. */
    private double[] logPriors() {
        double[] logs = new double[this.numFactors];
        for (int k = 0; k < this.numFactors; ++k) {
            logs[k] = Math.log(this.Pi.get(k));
        }
        return logs;
    }

    /** Returns the element-wise natural logarithm of {@code Pkr} as a [numFactors][numLevels] array. */
    private double[][] logLevelProbs() {
        double[][] logs = new double[this.numFactors][this.numLevels];
        for (int k = 0; k < this.numFactors; ++k) {
            for (int r = 0; r < this.numLevels; ++r) {
                logs[k][r] = Math.log(this.Pkr.get(k, r));
            }
        }
        return logs;
    }

    /**
     * Returns the position of {@code rating} in {@code ratingScale}.
     *
     * @throws IllegalStateException if the rating is not a level of the scale. Without this
     *     check, indexOf's -1 would be used as a matrix column index.
     */
    private int levelIndex(double rating) {
        int r = ratingScale.indexOf(rating);
        if (r < 0) {
            throw new IllegalStateException(
                    "Rating " + rating + " is not a level of the rating scale " + ratingScale);
        }
        return r;
    }

    /**
     * Reports convergence from iteration 2 onwards when the loss increased relative to the
     * previous iteration, or when the change is NaN. Otherwise it records the current loss as
     * {@code lastLoss} and returns false.
     */
    @Override
    protected boolean isConverged(int iter) throws Exception {
        float deltaLoss = (float) (this.loss - this.lastLoss);
        if (iter > 1 && (deltaLoss > 0.0f || Float.isNaN(deltaLoss))) {
            return true;
        }
        this.lastLoss = this.loss;
        return false;
    }

    /**
     * Predicts user {@code u}'s rating as
     * {@code sum_k Gamma[u][k] * sum_r ratingScale[r] * Pkr[k][r]}. This is the expected rating
     * value under the user's cluster mixture. The item {@code j} and the {@code bound} flag are
     * not used; the prediction depends only on the user's cluster memberships.
     */
    @Override
    public double predict(int u, int j, boolean bound) throws Exception {
        double pred = 0.0;
        for (int k = 0; k < this.numFactors; ++k) {
            double membership = this.Gamma.get(u, k);
            double expectedRating = 0.0;
            for (int r = 0; r < this.numLevels; ++r) {
                double level = (Double) ratingScale.get(r);
                expectedRating += level * this.Pkr.get(k, r);
            }
            pred += membership * expectedRating;
        }
        return pred;
    }
}

