/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.GraphicRecommender;
import librec.util.Gaussian;
import librec.util.Randoms;
import librec.util.Stats;

@Configurations(value={@Configuration(key="q", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double, description="smoothing weight", value="10"), @Configuration(key="b", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double, description="tempered EM parameter beta, suggested by Wu Bin", value="1.0"), @Configuration(key="burnIn", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Double, description="burn-in period", value="1400"), @Configuration(key="sampleLag", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Integer, value="10", description="sample lag (if -1 only one sample taken)"), @Configuration(key="numFactors", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Integer, value="10"), @Configuration(key="numIters", requirement=ConfigurationItem.ConfigurationItemRequirementType.Required, type=ConfigurationItem.ConfigurationItemType.Integer, value="100")})
public class GPLSA
extends GraphicRecommender {
    private Table<Integer, Integer, Map<Integer, Double>> Q;
    private DenseMatrix Mu;
    private DenseMatrix Sigma;
    private DenseVector mu;
    private DenseVector sigma;
    public double q;
    public double b;

    @Override
    public void initModel() throws Exception {
        double sum;
        this.Puk = new DenseMatrix(this.numUsers, this.numFactors);
        for (int u = 0; u < this.numUsers; ++u) {
            double[] probs = Randoms.randProbs(this.numFactors);
            for (int k = 0; k < this.numFactors; ++k) {
                this.Puk.set(u, k, probs[k]);
            }
        }
        double mean = this.globalMean;
        double sd = Stats.sd(this.trainMatrix.getData(), mean);
        this.mu = new DenseVector(this.numUsers);
        this.sigma = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector ru = this.trainMatrix.row(u);
            int Nu = ru.size();
            if (Nu < 1) continue;
            double mu_u = (ru.sum() + this.q * mean) / ((double)Nu + this.q);
            this.mu.set(u, mu_u);
            sum = 0.0;
            for (VectorEntry ve : ru) {
                sum += Math.pow(ve.get() - mu_u, 2.0);
            }
            double sigma_u = Math.sqrt((sum += this.q * Math.pow(sd, 2.0)) / ((double)Nu + this.q));
            this.sigma.set(u, sigma_u);
        }
        this.Q = HashBasedTable.create();
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rate = me.get();
            double r = (rate - this.mu.get(u)) / this.sigma.get(u);
            me.set(r);
            this.Q.put((Object)u, (Object)i, new HashMap());
        }
        this.Mu = new DenseMatrix(this.numItems, this.numFactors);
        this.Sigma = new DenseMatrix(this.numItems, this.numFactors);
        for (int i = 0; i < this.numItems; ++i) {
            SparseVector ci = this.trainMatrix.column(i);
            int Ni = ci.size();
            if (Ni < 1) continue;
            double mu_i = ci.mean();
            sum = 0.0;
            for (VectorEntry ve : ci) {
                sum += Math.pow(ve.get() - mu_i, 2.0);
            }
            double sd_i = Math.sqrt(sum / (double)Ni);
            for (int z = 0; z < this.numFactors; ++z) {
                this.Mu.set(i, z, mu_i + this.smallValue * Math.random());
                this.Sigma.set(i, z, sd_i + this.smallValue * Math.random());
            }
        }
    }

    @Override
    protected void eStep() {
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double r = me.get();
            double denominator = 0.0;
            double[] numerator = new double[this.numFactors];
            for (int z = 0; z < this.numFactors; ++z) {
                double val;
                double pdf = Gaussian.pdf(r, this.Mu.get(i, z), this.Sigma.get(i, z));
                numerator[z] = val = Math.pow(this.Puk.get(u, z) * pdf, this.b);
                denominator += val;
            }
            Map factorProbs = (Map)this.Q.get((Object)u, (Object)i);
            for (int z = 0; z < this.numFactors; ++z) {
                double prob = denominator > 0.0 ? numerator[z] / denominator : 0.0;
                factorProbs.put(z, prob);
            }
        }
    }

    @Override
    protected void mStep() {
        for (int u = 0; u < this.numUsers; ++u) {
            int z;
            List<Integer> items = this.trainMatrix.getColumns(u);
            if (items.size() < 1) continue;
            double[] numerator = new double[this.numFactors];
            double denominator = 0.0;
            for (z = 0; z < this.numFactors; ++z) {
                for (int i : items) {
                    int n = z;
                    numerator[n] = numerator[n] + (Double)((Map)this.Q.get((Object)u, (Object)i)).get(z);
                }
                denominator += numerator[z];
            }
            for (z = 0; z < this.numFactors; ++z) {
                this.Puk.set(u, z, numerator[z] / denominator);
            }
        }
        for (int i = 0; i < this.numItems; ++i) {
            List<Integer> users = this.trainMatrix.getRows(i);
            if (users.size() < 1) continue;
            for (int z = 0; z < this.numFactors; ++z) {
                double numerator = 0.0;
                double denominator = 0.0;
                for (int u : users) {
                    double r = this.trainMatrix.get(u, i);
                    double prob = (Double)((Map)this.Q.get((Object)u, (Object)i)).get(z);
                    numerator += r * prob;
                    denominator += prob;
                }
                double mu = denominator > 0.0 ? numerator / denominator : 0.0;
                this.Mu.set(i, z, mu);
                numerator = 0.0;
                denominator = 0.0;
                for (int u : users) {
                    double r = this.trainMatrix.get(u, i);
                    double prob = (Double)((Map)this.Q.get((Object)u, (Object)i)).get(z);
                    numerator += Math.pow(r - mu, 2.0) * prob;
                    denominator += prob;
                }
                double sigma = denominator > 0.0 ? Math.sqrt(numerator / denominator) : 0.0;
                this.Sigma.set(i, z, sigma);
            }
        }
    }

    @Override
    public double predict(int u, int i) throws Exception {
        double sum = 0.0;
        for (int z = 0; z < this.numFactors; ++z) {
            sum += this.Puk.get(u, z) * this.Mu.get(i, z);
        }
        return this.mu.get(u) + this.sigma.get(u) * sum;
    }
}

