/*
 * Decompiled with CFR 0.152.
 */
package librec.ext;

import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.intf.IterativeRecommender;

/**
 * Non-negative matrix factorization (NMF) recommender.
 *
 * <p>Approximates the training matrix {@code V} as the product {@code W * H}, where {@code W} is
 * {@code numUsers x numFactors} and {@code H} is {@code numFactors x numItems}. Each iteration
 * first rescales every entry of {@code W} with {@code H} held fixed, then rescales every entry of
 * {@code H} using the freshly updated {@code W}. Both updates are multiplicative and only look at
 * the entries actually stored in {@code V}, so missing ratings do not contribute to the fit.
 */
public class NMF extends IterativeRecommender {
    /** Added to every update denominator so a zero estimate cannot cause a division by zero. */
    private static final double EPSILON = 1.0E-9;

    /** User-factor matrix, {@code numUsers x numFactors}. */
    protected DenseMatrix W;
    /** Factor-item matrix, {@code numFactors x numItems}. */
    protected DenseMatrix H;
    /** Matrix being factorized; assigned from {@code trainMatrix} in {@link #initModel()}. */
    protected SparseMatrix V;

    /** Creates an NMF recommender. */
    public NMF() {
        // This class never reads lRate (updates are multiplicative, not gradient steps); -1 is
        // presumably a "not applicable" sentinel. NOTE(review): confirm in IterativeRecommender.
        this.lRate = -1.0;
    }

    /**
     * Allocates and randomly initializes the factor matrices and binds {@code V} to the
     * training matrix.
     */
    @Override
    public void initModel() throws Exception {
        this.W = new DenseMatrix(this.numUsers, this.numFactors);
        this.H = new DenseMatrix(this.numFactors, this.numItems);
        // Multiplicative updates only rescale entries, so keeping the factors non-negative relies
        // on non-negative starting values. NOTE(review): assumes init(0.01) samples from
        // [0, 0.01) -- confirm against DenseMatrix.init.
        this.W.init(0.01);
        this.H.init(0.01);
        this.V = this.trainMatrix;
    }

    /**
     * Runs up to {@code numIters} alternating update rounds. After each round it stores the
     * training loss in {@code loss} and stops early once {@link #isConverged(int)} reports
     * convergence.
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            updateW();
            // updateH must run after updateW: it factors in the W produced this iteration.
            updateH();
            this.loss = computeLoss();
            if (this.isConverged(iter)) {
                break;
            }
        }
    }

    /**
     * Updates {@code W} with {@code H} held fixed. For every user u with at least one stored
     * entry, each factor f is multiplied by
     * {@code (H(f,:) . V(u,:)) / (H(f,:) . P(u,:) + EPSILON)}, where {@code P(u,j)} is the
     * current prediction. Both inner products range over the items stored for u only.
     */
    private void updateW() {
        for (int u = 0; u < this.W.numRows(); ++u) {
            SparseVector uv = this.V.row(u);
            if (uv.getCount() <= 0) {
                continue; // nothing observed for this user: leave its row unchanged
            }
            // Predictions at the observed items are taken once, before any factor of row u changes.
            SparseVector euv = new SparseVector(this.V.numColumns());
            for (int j : uv.getIndex()) {
                euv.set(j, this.predict(u, j));
            }
            for (int f = 0; f < this.W.numColumns(); ++f) {
                DenseVector fv = this.H.row(f, false);
                double real = fv.inner(uv);
                double estm = fv.inner(euv) + EPSILON;
                this.W.set(u, f, this.W.get(u, f) * (real / estm));
            }
        }
    }

    /**
     * Updates {@code H} with {@code W} held fixed; this mirrors {@link #updateW()} column by column.
     * Each {@code H(f,j)} for an item j with at least one stored entry is multiplied by
     * {@code (W(:,f) . V(:,j)) / (W(:,f) . P(:,j) + EPSILON)}, with both inner products
     * restricted to the users stored for j.
     */
    private void updateH() {
        // Rows of W^T are the factor columns of W.
        DenseMatrix trW = this.W.transpose();
        for (int j = 0; j < this.H.numColumns(); ++j) {
            SparseVector jv = this.V.column(j);
            if (jv.getCount() <= 0) {
                continue; // nothing observed for this item: leave its column unchanged
            }
            // Predictions at the observed users are taken once, before column j of H changes.
            SparseVector ejv = new SparseVector(this.V.numRows());
            for (int u : jv.getIndex()) {
                ejv.set(u, this.predict(u, j));
            }
            for (int f = 0; f < this.H.numRows(); ++f) {
                DenseVector fv = trW.row(f, false);
                double real = fv.inner(jv);
                double estm = fv.inner(ejv) + EPSILON;
                this.H.set(f, j, this.H.get(f, j) * (real / estm));
            }
        }
    }

    /**
     * Returns half the sum of squared prediction errors over the stored entries of {@code V}
     * whose value is strictly positive. Entries that are {@code <= 0} or NaN are skipped.
     */
    private double computeLoss() {
        double sum = 0.0;
        for (MatrixEntry me : this.V) {
            int u = me.row();
            int j = me.column();
            double ruj = me.get();
            if (!(ruj > 0.0)) {
                continue;
            }
            double euj = this.predict(u, j) - ruj;
            sum += euj * euj;
        }
        return 0.5 * sum;
    }

    /**
     * Predicts the score of user {@code u} for item {@code j} by combining row {@code u} of
     * {@code W} with column {@code j} of {@code H} via {@link DenseMatrix#product}.
     */
    @Override
    public double predict(int u, int j) {
        return DenseMatrix.product(this.W, u, this.H, j);
    }
}

