/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseVector;
import librec.data.SparseVector;
import librec.data.SymmMatrix;
import librec.intf.Recommender;
import librec.util.Lists;
import librec.util.Stats;

/**
 * Item-based k-nearest-neighbour recommender.
 *
 * <p>{@link #initModel()} builds an item-item similarity matrix and per-item mean ratings from the
 * training data. {@link #predict(int, int)} scores an item for a user from the items that user has
 * already rated and that appear in the target item's similarity row.
 *
 * <p>Configuration keys (see the annotation): {@code similarityMeasure} (default {@code pcc}),
 * {@code similarityShrinkage} (default {@code 25}) and {@code knn} (default {@code 80}).
 */
@Configurations(value={
        @Configuration(key="similarityMeasure", type=ConfigurationItem.ConfigurationItemType.Options, options={"cos", "cos-binary", "msd", "cpc", "exjaccard", "pcc"}, value="pcc", description="similarity measure"),
        @Configuration(key="similarityShrinkage", type=ConfigurationItem.ConfigurationItemType.Integer, description="similarity shrinkage", value="25"),
        @Configuration(key="knn", type=ConfigurationItem.ConfigurationItemType.Integer, description="number of nearest neighbors", value="80")})
public class ItemKNN
extends Recommender {
    /** Item-item similarity matrix, built once in {@link #initModel()}. */
    private SymmMatrix itemCorrs;
    /** Mean training rating per item; items without training ratings get the global mean. */
    private DenseVector itemMeans;
    /** Maximum number of neighbours used per prediction; values {@code <= 0} disable truncation. */
    public int knn;
    /** Shrinkage value forwarded to the similarity computation. */
    public int similarityShrinkage;
    /** Name of the similarity measure forwarded to the similarity computation. */
    public String similarityMeasure;

    /**
     * Builds the item similarity matrix and the per-item mean ratings.
     *
     * @throws Exception if the inherited similarity computation fails
     */
    @Override
    public void initModel() throws Exception {
        // NOTE(review): the leading 'false' presumably selects item (column) similarities rather
        // than user similarities in Recommender.buildCorrs -- confirm against that method.
        this.itemCorrs = this.buildCorrs(false, this.similarityMeasure, this.similarityShrinkage);
        this.itemMeans = new DenseVector(this.numItems);
        for (int i = 0; i < this.numItems; ++i) {
            SparseVector vs = this.trainMatrix.column(i);
            // Items with no training ratings have no mean of their own; fall back to the global one.
            this.itemMeans.set(i, vs.getCount() > 0 ? vs.mean() : this.globalMean);
        }
    }

    /**
     * Predicts the score of item {@code j} for user {@code u}.
     *
     * <p>Candidate neighbours are the items in row {@code j} of the similarity matrix that user
     * {@code u} has rated (training rating {@code > 0}). In ranking mode the sign of the similarity is
     * ignored; in rating-prediction mode only positively similar items are used. When {@link #knn} is
     * positive and smaller than the number of candidates, only the {@code knn} most similar are kept.
     *
     * @param u user index
     * @param j item index
     * @return in ranking mode, the sum of neighbour similarities ({@code 0} if there are none); in
     *     rating mode, {@code mean(j) + sum(sim * (r(u,i) - mean(i))) / sum(|sim|)} over the
     *     neighbours, or the global mean if there are none
     */
    @Override
    public double predict(int u, int j) {
        Map<Integer, Double> nns = new HashMap<Integer, Double>();
        SparseVector dv = this.itemCorrs.row(j);
        for (int i : dv.getIndex()) {
            double sim = dv.get(i);
            double rate = this.trainMatrix.get(u, i);
            // Only items the user actually rated can be neighbours; rating prediction additionally
            // requires a positive similarity.
            if (rate > 0.0 && (this.isRankingPred || sim > 0.0)) {
                nns.put(i, sim);
            }
        }

        if (this.knn > 0 && this.knn < nns.size()) {
            List<Map.Entry<Integer, Double>> sorted = Lists.sortMap(nns, true);
            // Snapshot the top-k keys/values BEFORE clearing: Map.Entry objects are undefined once
            // their backing map has been modified, and sortMap's entries come from nns.
            int[] topKeys = new int[this.knn];
            double[] topSims = new double[this.knn];
            for (int r = 0; r < this.knn; ++r) {
                Map.Entry<Integer, Double> kv = sorted.get(r);
                topKeys[r] = kv.getKey();
                topSims[r] = kv.getValue();
            }
            // Reuse the same map (same capacity) and re-insert in sorted order, so the iteration
            // order -- and thus the floating-point summation order below -- is unchanged.
            nns.clear();
            for (int r = 0; r < this.knn; ++r) {
                nns.put(topKeys[r], topSims[r]);
            }
        }

        if (nns.isEmpty()) {
            return this.isRankingPred ? 0.0 : this.globalMean;
        }
        if (this.isRankingPred) {
            return Stats.sum(nns.values());
        }

        double sum = 0.0;
        double ws = 0.0;
        for (Map.Entry<Integer, Double> en : nns.entrySet()) {
            int i = en.getKey();
            double sim = en.getValue();
            double rate = this.trainMatrix.get(u, i);
            // Mean-centre each neighbour rating so items with different rating levels are comparable.
            sum += sim * (rate - this.itemMeans.get(i));
            ws += Math.abs(sim);
        }
        return ws > 0.0 ? this.itemMeans.get(j) + sum / ws : this.globalMean;
    }
}

