/*
 * Decompiled with CFR 0.152.
 */
package librec.ranking;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.SparseVector;
import librec.data.SymmMatrix;
import librec.data.VectorEntry;
import librec.intf.IterativeRecommender;
import librec.util.Lists;

/**
 * Item-based ranking recommender that learns an item-item weight matrix {@code W}
 * by coordinate descent with combined L1/L2 regularization (presumably the
 * "Sparse LInear Method" of Ning and Karypis, judging by the class name).
 *
 * <p>The prediction for user {@code u} and item {@code j} is the sum of
 * {@code r(u,k) * W(k,j)} over the candidate neighbours {@code k} of {@code j}
 * that the user has an entry for. When {@link #knn} is positive the candidates
 * are the {@code knn} items taken from the head of the similarity ranking of
 * {@code j}; otherwise every column returned by the training matrix is a candidate.
 *
 * <p>The public fields are not assigned anywhere in this class; they are
 * presumably populated externally from the keys declared in the
 * {@code @Configurations} annotation.
 */
@Configurations(value={@Configuration(key="similarityMeasure", type=ConfigurationItem.ConfigurationItemType.Options, options={"cos", "cos-binary", "msd", "cpc", "exjaccard", "pcc"}, value="pcc", description="similarity measure"), @Configuration(key="similarityShrinkage", type=ConfigurationItem.ConfigurationItemType.Integer, description="similarity shrinkage"), @Configuration(key="alpha", type=ConfigurationItem.ConfigurationItemType.Double, description="similarity filter"), @Configuration(key="knn", type=ConfigurationItem.ConfigurationItemType.Integer, description="number of nearest neighbors"), @Configuration(key="regL1", type=ConfigurationItem.ConfigurationItemType.Double, description="regularization parameters for the L1 term"), @Configuration(key="regL2", type=ConfigurationItem.ConfigurationItemType.Double, description="regularization parameters for the L2 term")})
public class SLIM
extends IterativeRecommender {
    /** Item-item weight matrix; entry (i, j) is the weight of item i when predicting item j. */
    private DenseMatrix W;
    /** Candidate neighbours per item; only built when {@code knn > 0}. */
    private Multimap<Integer, Integer> itemNNs;
    /** Columns of the training matrix; only set when {@code knn <= 0}. */
    private List<Integer> allItems;
    /** Weight of the L1 penalty. */
    public double regL1;
    /** Weight of the L2 penalty. */
    public double regL2;
    /** Number of nearest neighbours per item; values {@code <= 0} mean "use all items". */
    public int knn;
    /** Shrinkage value forwarded to {@code buildCorrs}. */
    public int similarityShrinkage;
    /** Name of the similarity measure forwarded to {@code buildCorrs}. */
    public String similarityMeasure;

    /** Creates the recommender and flags it as a ranking predictor. */
    public SLIM() {
        this.isRankingPred = true;
    }

    /**
     * Allocates {@code W}, zeroes its diagonal, builds the per-user row cache and
     * prepares the candidate neighbours of every item.
     *
     * @throws Exception propagated from the matrix, cache or similarity helpers
     */
    @Override
    public void initModel() throws Exception {
        this.W = new DenseMatrix(this.numItems, this.numItems);
        // NOTE(review): init() presumably fills W with random starting values - confirm.
        this.W.init();
        this.userCache = this.trainMatrix.rowCache(this.guavaCacheSpec);

        if (this.knn > 0) {
            SymmMatrix itemCorrs = this.buildCorrs(false, this.similarityMeasure, this.similarityShrinkage);
            this.itemNNs = HashMultimap.create();
            for (int j = 0; j < this.numItems; ++j) {
                // An item must never contribute to its own prediction.
                this.W.set(j, j, 0.0);
                Map<Integer, Double> nns = itemCorrs.row(j).toMap();
                if (this.knn < nns.size()) {
                    // Keep only the first knn entries of the sorted similarity list
                    // (presumably sorted by descending similarity - verify Lists.sortMap).
                    List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
                    for (Map.Entry<Integer, Double> kv : sorted.subList(0, this.knn)) {
                        this.itemNNs.put(j, kv.getKey());
                    }
                } else {
                    for (Integer neighbour : nns.keySet()) {
                        this.itemNNs.put(j, neighbour);
                    }
                }
            }
        } else {
            this.allItems = this.trainMatrix.columns();
            for (int j = 0; j < this.numItems; ++j) {
                this.W.set(j, j, 0.0);
            }
        }
    }

    /**
     * Learns {@code W} by cyclic coordinate descent: for every target item {@code j}
     * and every candidate neighbour {@code i}, the weight {@code W(i,j)} is replaced
     * by the soft-thresholded closed-form minimiser computed from the users that
     * have an entry for {@code i}. Stops after {@code numIters} sweeps or as soon
     * as {@link #isConverged(int)} reports convergence.
     *
     * @throws Exception propagated from {@link #predict(int, int, int)}
     */
    @Override
    public void buildModel() throws Exception {
        this.last_loss = 0.0;
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            for (int j = 0; j < this.numItems; ++j) {
                for (Integer i : this.neighborsOf(j)) {
                    if (i == j) {
                        // Keep diag(W) at the zero set by initModel(); updating W(j,j)
                        // would let an item trivially predict itself.
                        continue;
                    }
                    SparseVector Ri = this.trainMatrix.column(i);
                    int N = Ri.getCount();
                    if (N == 0) {
                        // No observations for this neighbour: averaging would yield
                        // 0/0 = NaN and poison the loss, so just drop the weight.
                        this.W.set(i, j, 0.0);
                        continue;
                    }

                    double gradSum = 0.0;
                    double rateSum = 0.0;
                    double errs = 0.0;
                    for (VectorEntry ve : Ri) {
                        int u = ve.index();
                        double rui = ve.get();
                        double ruj = this.trainMatrix.get(u, j);
                        // Residual of item j with neighbour i's own contribution left out.
                        double euj = ruj - this.predict(u, j, i);
                        gradSum += rui * euj;
                        rateSum += rui * rui;
                        errs += euj * euj;
                    }
                    gradSum /= (double)N;
                    rateSum /= (double)N;
                    errs /= (double)N;

                    double wij = this.W.get(i, j);
                    // The L1 penalty uses |wij|: the update below can yield negative weights.
                    this.loss += errs + 0.5 * this.regL2 * wij * wij + this.regL1 * Math.abs(wij);

                    // Soft-thresholding: weights whose gradient magnitude does not
                    // exceed regL1 are set to zero.
                    double update = 0.0;
                    if (this.regL1 < Math.abs(gradSum)) {
                        double shrunk = gradSum > 0.0 ? gradSum - this.regL1 : gradSum + this.regL1;
                        update = shrunk / (this.regL2 + rateSum);
                    }
                    this.W.set(i, j, update);
                }
            }
            if (this.isConverged(iter)) break;
        }
    }

    /**
     * Predicts the score of item {@code j} for user {@code u} while ignoring the
     * contribution of one neighbour.
     *
     * @param u             user index
     * @param j             target item index
     * @param excluded_item neighbour to leave out of the sum; pass {@code -1} to exclude nothing
     * @return sum of {@code r(u,k) * W(k,j)} over the candidate neighbours the user has an entry for
     * @throws Exception propagated from the user-row cache lookup
     */
    protected double predict(int u, int j, int excluded_item) throws Exception {
        SparseVector Ru = (SparseVector)this.userCache.get(u);
        double pred = 0.0;
        for (int k : this.neighborsOf(j)) {
            if (k == excluded_item || !Ru.contains(k)) continue;
            pred += Ru.get(k) * this.W.get(k, j);
        }
        return pred;
    }

    /**
     * Predicts the score of item {@code j} for user {@code u} using all candidate neighbours.
     *
     * @param u user index
     * @param j item index
     * @return the predicted score
     * @throws Exception propagated from {@link #predict(int, int, int)}
     */
    @Override
    public double predict(int u, int j) throws Exception {
        return this.predict(u, j, -1);
    }

    /**
     * Records the current loss and decides whether training may stop.
     *
     * <p>Never reports convergence on the first iteration. Afterwards it returns
     * {@code true} when the loss decreased by less than {@code 1e-5}; note that an
     * increase of the loss (negative delta) also satisfies this condition.
     *
     * @param iter 1-based iteration number
     * @return {@code true} if training should stop
     */
    @Override
    protected boolean isConverged(int iter) {
        double delta_loss = this.last_loss - this.loss;
        this.last_loss = this.loss;
        return iter > 1 ? delta_loss < 1.0E-5 : false;
    }

    /** Returns the candidate neighbours of item {@code j}: its kNN set when {@code knn > 0}, otherwise all items. */
    private Iterable<Integer> neighborsOf(int j) {
        if (this.knn > 0) {
            return this.itemNNs.get(j);
        }
        return this.allItems;
    }
}

