/*
 * Decompiled with CFR 0.152.
 */
package librec.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.recalot.common.configuration.Configuration;
import com.recalot.common.configuration.ConfigurationItem;
import com.recalot.common.configuration.Configurations;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.intf.IterativeRecommender;
import librec.util.Randoms;

@Configurations(value={@Configuration(key="beta", type=ConfigurationItem.ConfigurationItemType.Double, description="time decay factor"), @Configuration(key="numBins", type=ConfigurationItem.ConfigurationItemType.Integer, description="number of bins over all the items")})
/**
 * Time-aware SVD rating predictor (the terms mirror Koren's timeSVD++ formulation).
 *
 * <p>For user {@code u}, item {@code i} and day {@code t} (days since {@code minTimestamp}) the
 * prediction computed by {@link #predict(int, int)} is:
 *
 * <pre>
 *   mu + (b_i + Bit[i][bin(t)]) * (Cu[u] + Cut[u][t])
 *      + b_u + Alpha[u] * dev(u, t) + But(u, t)
 *      + |R(u)|^-1/2 * sum_{j in R(u)} Y[j] . Q[i]
 *      + sum_k (P[u][k] + Auk[u][k] * dev(u, t) + Pukt(u)(k, t)) * Q[i][k]
 * </pre>
 *
 * where {@code R(u)} is the list of training items of {@code u}. Parameters are learned by
 * stochastic gradient descent on the squared error plus L2 regularisation.
 */
public class TimeSVD
extends IterativeRecommender {
    /**
     * Number of days spanned by the data: {@code days(maxTimestamp, minTimestamp) + 1}.
     * Kept per instance: as a static field, one model's initModel() would change the bin width
     * that another, already trained model uses in bin().
     */
    private int numDays;
    /** Mean rating day (relative to minTimestamp) of each user; the reference point of dev(u, t). */
    private DenseVector userMeanDate;
    /** Exponent applied to |t - userMeanDate[u]| in dev(u, t). */
    public double beta;
    /** Number of equal-width bins the day range is split into for the item-bias drift term Bit. */
    public int numBins;
    /** Implicit-feedback item factors y_j (numItems x numFactors). */
    private DenseMatrix Y;
    /** Per-item, per-bin bias offsets (numItems x numBins). */
    private DenseMatrix Bit;
    /** Day-specific user bias keyed by (user, day); entries are created lazily during training. */
    private Table<Integer, Integer, Double> But;
    /** Per-user drift coefficient of the user bias (multiplies dev(u, t)). */
    private DenseVector Alpha;
    /** Per-user, per-factor drift coefficients of the user factors (multiply dev(u, t)). */
    private DenseMatrix Auk;
    /** Day-specific user factor offsets: user -> table (factor, day) -> value; created lazily. */
    private Map<Integer, Table<Integer, Integer, Double>> Pukt;
    /** Per-user scaling applied to the (time-dependent) item bias. */
    private DenseVector Cu;
    /** Per-user, per-day scaling applied to the item bias (numUsers x numDays). */
    private DenseMatrix Cut;

    /**
     * Allocates and initialises all model parameters.
     *
     * <p>Also computes each user's mean rating day. Users without training items fall back to
     * the global mean day over positive training ratings.
     */
    @Override
    public void initModel() throws Exception {
        super.initModel();
        this.numDays = TimeSVD.days(this.maxTimestamp, this.minTimestamp) + 1;

        this.userBias = new DenseVector(this.numUsers);
        this.userBias.init();
        this.itemBias = new DenseVector(this.numItems);
        this.itemBias.init();
        this.Alpha = new DenseVector(this.numUsers);
        this.Alpha.init();
        this.Bit = new DenseMatrix(this.numItems, this.numBins);
        this.Bit.init();
        this.Y = new DenseMatrix(this.numItems, this.numFactors);
        this.Y.init();
        this.Auk = new DenseMatrix(this.numUsers, this.numFactors);
        this.Auk.init();
        this.But = HashBasedTable.create();
        this.Pukt = new HashMap<Integer, Table<Integer, Integer, Double>>();
        this.Cu = new DenseVector(this.numUsers);
        this.Cu.init();
        this.Cut = new DenseMatrix(this.numUsers, this.numDays);
        this.Cut.init();
        this.userItemsCache = this.trainMatrix.rowColumnsCache(this.guavaCacheSpec);

        // Global mean rating day over positive training ratings.
        double daySum = 0.0;
        int ratingCount = 0;
        for (MatrixEntry me : this.trainMatrix) {
            if (me.get() <= 0.0) {
                continue;
            }
            daySum += TimeSVD.days((long) this.timeMatrix.get(me.row(), me.column()), this.minTimestamp);
            ++ratingCount;
        }
        // Guard against an empty training set: 0 / 0 would make every fallback mean NaN.
        double globalMeanDate = ratingCount > 0 ? daySum / ratingCount : 0.0;

        this.userMeanDate = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            List<Integer> ru = this.cachedItemsOfUser(u);
            double userDaySum = 0.0;
            for (int i : ru) {
                userDaySum += TimeSVD.days((long) this.timeMatrix.get(u, i), this.minTimestamp);
            }
            this.userMeanDate.set(u, ru.isEmpty() ? globalMeanDate : userDaySum / ru.size());
        }
    }

    /**
     * Trains the model with SGD for up to {@code numIters} passes over the training ratings.
     *
     * <p>Stops early when {@code isConverged(iter)} reports convergence. Every gradient step for
     * a rating uses the parameter values read before any of that rating's updates.
     */
    @Override
    public void buildModel() throws Exception {
        for (int iter = 1; iter <= this.numIters; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int u = me.row();
                int i = me.column();
                double rui = me.get();
                int t = TimeSVD.days((long) this.timeMatrix.get(u, i), this.minTimestamp);
                int bin = this.bin(t);
                double dev_ut = this.dev(u, t);

                // Snapshot the current bias/scaling parameters.
                double bi = this.itemBias.get(i);
                double bit = this.Bit.get(i, bin);
                double bu = this.userBias.get(u);
                double cu = this.Cu.get(u);
                double cut = this.Cut.get(u, t);
                if (!this.But.contains(u, t)) {
                    this.But.put(u, t, Randoms.random());
                }
                double but = this.But.get(u, t);
                double au = this.Alpha.get(u);

                // Prediction: time-dependent biases ...
                double pui = this.globalMean + (bi + bit) * (cu + cut);
                pui += bu + au * dev_ut + but;

                // ... implicit feedback from the user's rated items ...
                List<Integer> ru = this.cachedItemsOfUser(u);
                double sum_y = 0.0;
                for (int j : ru) {
                    sum_y += DenseMatrix.rowMult(this.Y, j, this.Q, i);
                }
                double wi = ru.size() > 0 ? Math.pow(ru.size(), -0.5) : 0.0;
                pui += sum_y * wi;

                // ... and time-dependent user factors times item factors.
                Table<Integer, Integer, Double> dayFactors = this.userDayFactors(u);
                for (int k = 0; k < this.numFactors; ++k) {
                    if (!dayFactors.contains(k, t)) {
                        dayFactors.put(k, t, Randoms.random());
                    }
                    double puk = this.P.get(u, k) + this.Auk.get(u, k) * dev_ut + dayFactors.get(k, t);
                    pui += puk * this.Q.get(i, k);
                }

                double eui = pui - rui;
                this.loss += eui * eui;

                // Bias and scaling updates (gradient of 0.5 * (eui^2 + reg * param^2)).
                double sgd = eui * (cu + cut) + this.regB * bi;
                this.itemBias.add(i, -this.lRate * sgd);
                this.loss += this.regB * bi * bi;

                sgd = eui * (cu + cut) + this.regB * bit;
                this.Bit.add(i, bin, -this.lRate * sgd);
                this.loss += this.regB * bit * bit;

                sgd = eui * (bi + bit) + this.regB * cu;
                this.Cu.add(u, -this.lRate * sgd);
                this.loss += this.regB * cu * cu;

                sgd = eui * (bi + bit) + this.regB * cut;
                this.Cut.add(u, t, -this.lRate * sgd);
                this.loss += this.regB * cut * cut;

                sgd = eui + this.regB * bu;
                this.userBias.add(u, -this.lRate * sgd);
                this.loss += this.regB * bu * bu;

                sgd = eui * dev_ut + this.regB * au;
                this.Alpha.add(u, -this.lRate * sgd);
                this.loss += this.regB * au * au;

                sgd = eui + this.regB * but;
                this.But.put(u, t, but - this.lRate * sgd);
                this.loss += this.regB * but * but;

                // Column sums of Y over R(u). Iteration k of the loop below only modifies
                // column k of Y, and only after reading sumYk[k]. So precomputing here yields
                // exactly the values (and summation order) of recomputing inside the loop.
                double[] sumYk = new double[this.numFactors];
                for (int j : ru) {
                    for (int k = 0; k < this.numFactors; ++k) {
                        sumYk[k] += this.Y.get(j, k);
                    }
                }

                // Factor updates.
                for (int k = 0; k < this.numFactors; ++k) {
                    double qik = this.Q.get(i, k);
                    double puk = this.P.get(u, k);
                    double auk = this.Auk.get(u, k);
                    double pkt = dayFactors.get(k, t);
                    double pukt = puk + auk * dev_ut + pkt;

                    sgd = eui * (pukt + wi * sumYk[k]) + this.regI * qik;
                    this.Q.add(i, k, -this.lRate * sgd);
                    this.loss += this.regI * qik * qik;

                    sgd = eui * qik + this.regU * puk;
                    this.P.add(u, k, -this.lRate * sgd);
                    this.loss += this.regU * puk * puk;

                    sgd = eui * qik * dev_ut + this.regU * auk;
                    this.Auk.add(u, k, -this.lRate * sgd);
                    this.loss += this.regU * auk * auk;

                    sgd = eui * qik + this.regU * pkt;
                    dayFactors.put(k, t, pkt - this.lRate * sgd);
                    this.loss += this.regU * pkt * pkt;

                    for (int j : ru) {
                        double yjk = this.Y.get(j, k);
                        sgd = eui * wi * qik + this.regI * yjk;
                        this.Y.add(j, k, -this.lRate * sgd);
                        this.loss += this.regI * yjk * yjk;
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter)) break;
        }
    }

    /**
     * Predicts the rating of user {@code u} for item {@code i} at the time stored in
     * {@code timeMatrix(u, i)}.
     *
     * <p>Day-specific terms never learned for that day (But, Pukt, and Cut for a day outside the
     * training day range) contribute 0.
     */
    @Override
    public double predict(int u, int i) throws Exception {
        int t = TimeSVD.days((long) this.timeMatrix.get(u, i), this.minTimestamp);
        int bin = this.bin(t);
        double dev_ut = this.dev(u, t);

        // Cut only has numDays columns; indexing it with a later day would throw.
        double cut = t >= 0 && t < this.numDays ? this.Cut.get(u, t) : 0.0;
        double pred = this.globalMean;
        pred += (this.itemBias.get(i) + this.Bit.get(i, bin)) * (this.Cu.get(u) + cut);

        Double but = this.But.get(u, t);
        pred += this.userBias.get(u) + this.Alpha.get(u) * dev_ut + (but != null ? but : 0.0);

        List<Integer> ru = this.cachedItemsOfUser(u);
        double sum_y = 0.0;
        for (int j : ru) {
            sum_y += DenseMatrix.rowMult(this.Y, j, this.Q, i);
        }
        double wi = ru.size() > 0 ? Math.pow(ru.size(), -0.5) : 0.0;
        pred += sum_y * wi;

        Table<Integer, Integer, Double> dayFactors = this.Pukt.get(u);
        for (int k = 0; k < this.numFactors; ++k) {
            double puk = this.P.get(u, k) + this.Auk.get(u, k) * dev_ut;
            if (dayFactors != null) {
                Double pkt = dayFactors.get(k, t);
                if (pkt != null) {
                    puk += pkt;
                }
            }
            pred += puk * this.Q.get(i, k);
        }
        return pred;
    }

    /**
     * Signed, power-scaled distance of day {@code t} from user {@code u}'s mean rating day.
     *
     * @return sign(t - mean_u) * |t - mean_u|^beta
     */
    protected double dev(int u, int t) {
        double tu = this.userMeanDate.get(u);
        double diff = (double) t - tu;
        return Math.signum(diff) * Math.pow(Math.abs(diff), this.beta);
    }

    /**
     * Maps a day index to one of {@code numBins} equal-width bins over {@code [0, numDays)}.
     *
     * <p>The result is clamped to {@code [0, numBins - 1]}. Days inside the training range are
     * unaffected by the clamp; days outside it would otherwise index past Bit's columns.
     */
    protected int bin(int day) {
        int raw = (int) ((double) day / ((double) this.numDays + 0.0) * (double) this.numBins);
        return Math.max(0, Math.min(raw, this.numBins - 1));
    }

    /**
     * Converts a duration to whole days, interpreting it as milliseconds.
     * NOTE(review): assumes stored timestamps are in milliseconds — confirm against the data loader.
     */
    protected static int days(long diff) {
        return (int) TimeUnit.MILLISECONDS.toDays(diff);
    }

    /** Whole days between two timestamps (absolute difference, so the order does not matter). */
    protected static int days(long t1, long t2) {
        return TimeSVD.days(Math.abs(t1 - t2));
    }

    /** Returns the cached list of training items rated by user {@code u}. */
    @SuppressWarnings("unchecked")
    private List<Integer> cachedItemsOfUser(int u) throws Exception {
        return (List<Integer>) this.userItemsCache.get(u);
    }

    /** Returns user {@code u}'s (factor, day) offset table, creating an empty one on first use. */
    private Table<Integer, Integer, Double> userDayFactors(int u) {
        Table<Integer, Integer, Double> table = this.Pukt.get(u);
        if (table == null) {
            table = HashBasedTable.create();
            this.Pukt.put(u, table);
        }
        return table;
    }
}

