/*
 * Decompiled with CFR 0.152.
 */
package librec.data;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.TensorEntry;
import librec.util.Logs;
import librec.util.Randoms;

/**
 * A sparse tensor with three or more dimensions, stored in coordinate (entry-list) form.
 *
 * <p>Entry {@code i} has key {@code ndKeys[d].get(i)} in dimension {@code d} and value
 * {@code values.get(i)}. For every dimension that has been indexed (see {@link #buildIndex(int...)}),
 * {@code keyIndices[d]} maps a key to the positions of all entries carrying that key. Indices are
 * built lazily by lookups such as {@link #get(int...)}, {@link #fiber(int, int...)},
 * {@link #slice(int, int, int...)} and {@link #getIndex(int, int)}.
 *
 * <p>The user and item dimensions (both 0 until set) are used by {@link #getIndices(int, int)} and
 * {@link #rateMatrix()}.
 */
public class SparseTensor implements Iterable<TensorEntry>, Serializable {
    private static final long serialVersionUID = 2487513413901432943L;
    /** Number of dimensions (order) of the tensor; at least 3. */
    private int numDimensions;
    /** Size of each dimension. */
    private int[] dimensions;
    /** Per-dimension key lists; position {@code i} across all lists (and {@link #values}) is entry {@code i}. */
    private List<Integer>[] ndKeys;
    /** Entry values, aligned with the positions in {@link #ndKeys}. */
    private List<Double> values;
    /** Per-dimension key -> entry positions; only maintained for dimensions in {@link #indexedDimensions}. */
    private Multimap<Integer, Integer>[] keyIndices;
    /** Dimensions whose {@link #keyIndices} entry is currently built and kept up to date. */
    private List<Integer> indexedDimensions;
    private int userDimension;
    private int itemDimension;

    /**
     * Creates an empty tensor with the given dimension sizes.
     *
     * @param dims size of each dimension; at least three values are required
     * @throws Error if fewer than three dimensions are given
     */
    public SparseTensor(int ... dims) {
        this(dims, null, null);
    }

    /**
     * Creates a tensor with the given dimension sizes and, optionally, initial entries. The dimension
     * array, key lists and values are copied; no key index is built.
     *
     * @param dims size of each dimension; at least three values are required
     * @param nds per-dimension key lists (entry {@code i} has key {@code nds[d].get(i)} in dimension
     *     {@code d}), or {@code null} for an empty tensor
     * @param vals entry values aligned with {@code nds}, or {@code null} for an empty tensor
     * @throws Error if fewer than three dimensions are given
     */
    @SuppressWarnings("unchecked") // generic array creation for ndKeys and keyIndices
    public SparseTensor(int[] dims, List<Integer>[] nds, List<Double> vals) {
        if (dims.length < 3) {
            throw new Error("The dimension of a tensor cannot be smaller than 3!");
        }
        this.numDimensions = dims.length;
        this.dimensions = Arrays.copyOf(dims, dims.length);
        this.ndKeys = new List[this.numDimensions];
        this.keyIndices = new Multimap[this.numDimensions];
        for (int d = 0; d < this.numDimensions; ++d) {
            this.ndKeys[d] = nds == null ? new ArrayList<Integer>() : new ArrayList<Integer>(nds[d]);
            this.keyIndices[d] = HashMultimap.create();
        }
        this.values = vals == null ? new ArrayList<Double>() : new ArrayList<Double>(vals);
        this.indexedDimensions = new ArrayList<Integer>(this.numDimensions);
    }

    /**
     * Returns a copy of this tensor. Keys, values, built indices and the user/item dimension settings
     * are copied into new collections, so later changes to either tensor do not affect the other.
     */
    @Override
    public SparseTensor clone() {
        SparseTensor res = new SparseTensor(this.dimensions);
        for (int d = 0; d < this.numDimensions; ++d) {
            res.ndKeys[d].addAll(this.ndKeys[d]);
            res.keyIndices[d].putAll(this.keyIndices[d]);
        }
        res.values.addAll(this.values);
        res.indexedDimensions.addAll(this.indexedDimensions);
        res.userDimension = this.userDimension;
        res.itemDimension = this.itemDimension;
        return res;
    }

    /**
     * Adds {@code val} to the entry at {@code keys}, creating the entry if it does not exist.
     *
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    public void add(double val, int ... keys) throws Exception {
        int index = this.findIndex(keys);
        if (index >= 0) {
            this.values.set(index, this.values.get(index) + val);
        } else {
            this.set(val, keys);
        }
    }

    /**
     * Sets the entry at {@code keys} to {@code val}, appending a new entry if none exists. The indices
     * of already-indexed dimensions are updated for the new entry.
     *
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    public void set(double val, int ... keys) throws Exception {
        int index = this.findIndex(keys);
        if (index >= 0) {
            this.values.set(index, val);
            return;
        }
        for (int d = 0; d < this.numDimensions; ++d) {
            this.ndKeys[d].add(keys[d]);
            if (this.isIndexed(d)) {
                this.keyIndices[d].put(keys[d], this.ndKeys[d].size() - 1);
            }
        }
        this.values.add(val);
    }

    /**
     * Removes the entry at {@code keys}, if present.
     *
     * @return {@code true} if an entry was removed
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    public boolean remove(int ... keys) throws Exception {
        int index = this.findIndex(keys);
        if (index < 0) {
            return false;
        }
        this.removeAt(index);
        return true;
    }

    /**
     * Deletes the entry stored at position {@code index} and keeps the key indices consistent.
     * Every later position shifts down by one, so each indexed dimension is rebuilt.
     */
    private void removeAt(int index) {
        for (int d = 0; d < this.numDimensions; ++d) {
            this.ndKeys[d].remove(index);
            if (this.isIndexed(d)) {
                this.buildIndex(d);
            }
        }
        this.values.remove(index);
    }

    /**
     * Returns the positions of all entries whose key in the user dimension is {@code user} and whose
     * key in the item dimension is {@code item}. Builds the user-dimension index if needed.
     */
    public List<Integer> getIndices(int user, int item) {
        List<Integer> res = new ArrayList<Integer>();
        for (int index : this.getIndex(this.userDimension, user)) {
            if (this.key(this.itemDimension, index) == item) {
                res.add(index);
            }
        }
        return res;
    }

    /**
     * Returns the position of the entry with exactly the given keys, or -1 if there is none. Builds
     * the index for dimension 0 if no dimension is indexed yet.
     *
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    private int findIndex(int ... keys) throws Exception {
        if (keys.length != this.numDimensions) {
            throw new Exception("The given input does not match with the tensor dimension!");
        }
        if (this.values.isEmpty()) {
            return -1;
        }
        if (this.indexedDimensions.isEmpty()) {
            this.buildIndex(0);
        }
        // Narrow the candidates using any indexed dimension, then verify every key.
        int d = this.indexedDimensions.get(0);
        for (int index : this.keyIndices[d].get(keys[d])) {
            if (this.hasKeys(index, keys)) {
                return index;
            }
        }
        return -1;
    }

    /** Returns true if the entry at {@code index} has key {@code keys[d]} in every dimension {@code d}. */
    private boolean hasKeys(int index, int[] keys) {
        for (int d = 0; d < this.numDimensions; ++d) {
            if (this.key(d, index) != keys[d]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the fiber along dimension {@code dim}: the vector of values whose keys in all other
     * dimensions equal {@code keys}, indexed by their key in {@code dim}.
     *
     * @param dim the free dimension
     * @param keys fixed keys for every dimension except {@code dim}, in dimension order
     * @throws Error if {@code keys.length != numDimensions - 1} or the tensor is empty
     */
    public SparseVector fiber(int dim, int ... keys) {
        if (keys.length != this.numDimensions - 1 || this.size() < 1) {
            throw new Error("The input indices do not match the fiber specification!");
        }
        // Pick an indexed dimension other than 'dim' to narrow candidates; build one if none exists.
        int d = -1;
        if (this.indexedDimensions.isEmpty()
                || this.indexedDimensions.contains(dim) && this.indexedDimensions.size() == 1) {
            d = dim != 0 ? 0 : 1;
            this.buildIndex(d);
        } else {
            for (int dd : this.indexedDimensions) {
                if (dd != dim) {
                    d = dd;
                    break;
                }
            }
        }
        SparseVector res = new SparseVector(this.dimensions[dim]);
        // 'keys' omits 'dim', so dimensions after 'dim' sit one slot to the left.
        Collection<Integer> indices = this.keyIndices[d].get(keys[d < dim ? d : d - 1]);
        for (int index : indices) {
            boolean found = true;
            int ndi = 0;
            for (int dd = 0; dd < this.numDimensions; ++dd) {
                if (dd == dim) {
                    continue;
                }
                if (keys[ndi++] != this.key(dd, index)) {
                    found = false;
                    break;
                }
            }
            if (found) {
                res.set(this.key(dim, index), this.value(index));
            }
        }
        return res;
    }

    /**
     * Returns whether an entry with exactly the given keys exists.
     *
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    public boolean contains(int ... keys) throws Exception {
        return this.findIndex(keys) >= 0;
    }

    /** Returns whether a key index is currently maintained for dimension {@code d}. */
    public boolean isIndexed(int d) {
        return this.indexedDimensions.contains(d);
    }

    /** Returns whether all dimensions have the same size. */
    public boolean isCubical() {
        int dim = this.dimensions[0];
        for (int d = 1; d < this.numDimensions; ++d) {
            if (this.dimensions[d] != dim) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether every non-zero entry has the same key in all dimensions. */
    public boolean isDiagonal() {
        for (TensorEntry te : this) {
            if (te.get() == 0.0) {
                continue;
            }
            int first = te.key(0);
            for (int d = 1; d < this.numDimensions; ++d) {
                if (te.key(d) != first) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the value at {@code keys}, or 0.0 if no such entry exists.
     *
     * @throws Exception if {@code keys.length} differs from the number of dimensions
     */
    public double get(int ... keys) throws Exception {
        assert (keys.length == this.numDimensions);
        int index = this.findIndex(keys);
        return index < 0 ? 0.0 : this.values.get(index);
    }

    /**
     * Randomly permutes the entry positions in place, keeping indexed dimensions consistent.
     * Position {@code i} is swapped with {@code i + Randoms.uniform(len - i)}. Assuming
     * {@code Randoms.uniform(n)} returns an int in {@code [0, n)}, this is a Fisher-Yates shuffle.
     */
    public void shuffle() {
        int len = this.size();
        for (int i = 0; i < len; ++i) {
            int j = i + Randoms.uniform(len - i);
            double temp = this.values.get(i);
            this.values.set(i, this.values.get(j));
            this.values.set(j, temp);
            for (int d = 0; d < this.numDimensions; ++d) {
                int ikey = this.key(d, i);
                int jkey = this.key(d, j);
                // Equal keys: the swap changes nothing in this dimension. Running the remove/put
                // sequence below anyway would drop position i from the index.
                if (ikey == jkey) {
                    continue;
                }
                this.ndKeys[d].set(i, jkey);
                this.ndKeys[d].set(j, ikey);
                if (!this.isIndexed(d)) {
                    continue;
                }
                this.keyIndices[d].remove(jkey, j);
                this.keyIndices[d].put(jkey, i);
                this.keyIndices[d].remove(ikey, i);
                this.keyIndices[d].put(ikey, j);
            }
        }
    }

    /**
     * (Re)builds the key -> positions index for each given dimension and marks it as indexed.
     *
     * @param dims dimensions to index
     */
    public void buildIndex(int ... dims) {
        for (int d : dims) {
            Multimap<Integer, Integer> index = this.keyIndices[d];
            List<Integer> dimKeys = this.ndKeys[d];
            index.clear();
            for (int pos = 0; pos < dimKeys.size(); ++pos) {
                index.put(dimKeys.get(pos), pos);
            }
            if (!this.indexedDimensions.contains(d)) {
                this.indexedDimensions.add(d);
            }
        }
    }

    /** Builds the key index for every dimension. */
    public void buildIndices() {
        for (int d = 0; d < this.numDimensions; ++d) {
            this.buildIndex(d);
        }
    }

    /**
     * Returns the positions of all entries whose key in dimension {@code d} is {@code key}, building
     * that dimension's index first if needed. The result is the collection returned by the index's
     * {@code Multimap.get}: it is never null and may be empty.
     */
    public Collection<Integer> getIndex(int d, int key) {
        if (!this.isIndexed(d)) {
            this.buildIndex(d);
        }
        return this.keyIndices[d].get(key);
    }

    /** Returns the keys of the entry at position {@code index}, one per dimension. */
    public int[] keys(int index) {
        int[] res = new int[this.numDimensions];
        for (int d = 0; d < this.numDimensions; ++d) {
            res[d] = this.key(d, index);
        }
        return res;
    }

    /** Returns the key in dimension {@code d} of the entry at position {@code index}. */
    public int key(int d, int index) {
        return this.ndKeys[d].get(index);
    }

    /** Returns the value of the entry at position {@code index}. */
    public double value(int index) {
        return this.values.get(index);
    }

    /**
     * Returns the keys in dimension {@code td} of all entries whose key in dimension {@code sd} is
     * {@code key}, with one element per entry (duplicates possible).
     */
    public List<Integer> getRelevantKeys(int sd, int key, int td) {
        Collection<Integer> indices = this.getIndex(sd, key);
        List<Integer> res = null;
        if (indices != null) {
            res = new ArrayList<Integer>(indices.size());
            for (int index : indices) {
                res.add(this.key(td, index));
            }
        }
        return res;
    }

    /** Returns the number of stored entries. */
    public int size() {
        return this.values.size();
    }

    /**
     * Returns the matrix slice spanned by {@code rowDim} x {@code colDim}, with every other dimension
     * fixed to the corresponding key in {@code otherKeys}.
     *
     * <p>One of the fixed dimensions is used to narrow candidates via its index; that index is built
     * if no suitable indexed dimension exists. Returns {@code null} if no entry has the fixed key in
     * that dimension. Otherwise returns a matrix, which may be empty if no candidate matches all keys.
     *
     * @param otherKeys fixed keys for every dimension except {@code rowDim} and {@code colDim}, in
     *     dimension order
     * @throws Error if {@code otherKeys.length != numDimensions - 2}
     */
    public SparseMatrix slice(int rowDim, int colDim, int ... otherKeys) {
        if (otherKeys.length != this.numDimensions - 2) {
            throw new Error("The input dimensions do not match the tensor specification!");
        }
        int d = -1;
        boolean noIndex = this.indexedDimensions.isEmpty();
        boolean onlyRowOrCol = (this.indexedDimensions.contains(rowDim) || this.indexedDimensions.contains(colDim))
                && this.indexedDimensions.size() == 1;
        boolean onlyRowAndCol = this.indexedDimensions.contains(rowDim) && this.indexedDimensions.contains(colDim)
                && this.indexedDimensions.size() == 2;
        if (noIndex || onlyRowOrCol || onlyRowAndCol) {
            // No usable indexed dimension: index the first dimension that is neither rowDim nor colDim.
            d = 0;
            while (d < this.numDimensions && (d == rowDim || d == colDim)) {
                ++d;
            }
            this.buildIndex(d);
        } else {
            for (int dd : this.indexedDimensions) {
                if (dd != rowDim && dd != colDim) {
                    d = dd;
                    break;
                }
            }
        }
        // Find the fixed key for dimension d; otherKeys omits rowDim and colDim.
        int key = -1;
        int slot = 0;
        for (int dim = 0; dim < this.numDimensions; ++dim) {
            if (dim == rowDim || dim == colDim) {
                continue;
            }
            if (dim == d) {
                key = otherKeys[slot];
                break;
            }
            ++slot;
        }
        Collection<Integer> indices = this.keyIndices[d].get(key);
        if (indices.isEmpty()) {
            return null;
        }
        Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();
        for (int index : indices) {
            boolean found = true;
            int j = 0;
            for (int dd = 0; dd < this.numDimensions; ++dd) {
                if (dd == rowDim || dd == colDim) {
                    continue;
                }
                if (otherKeys[j++] != this.key(dd, index)) {
                    found = false;
                    break;
                }
            }
            if (!found) {
                continue;
            }
            int row = this.key(rowDim, index);
            int col = this.key(colDim, index);
            dataTable.put(row, col, this.value(index));
            colMap.put(col, row);
        }
        return new SparseMatrix(this.dimensions[rowDim], this.dimensions[colDim], dataTable, colMap);
    }

    /**
     * Returns the mode-{@code n} unfolding of this tensor. The row is the key in dimension {@code n}.
     * The column is {@code sum_k key_k * prod_{m<k, m!=n} dimensions[m]} over all {@code k != n}.
     *
     * @param n the dimension that becomes the matrix rows
     */
    public SparseMatrix matricization(int n) {
        int numRows = this.dimensions[n];
        // strides[k] = product of the sizes of dimensions m < k with m != n (computed once per call).
        int[] strides = new int[this.numDimensions];
        int numCols = 1;
        for (int d = 0; d < this.numDimensions; ++d) {
            if (d == n) {
                continue;
            }
            strides[d] = numCols;
            numCols *= this.dimensions[d];
        }
        Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();
        for (TensorEntry te : this) {
            int[] keys = te.keys();
            int i = keys[n];
            int j = 0;
            for (int k = 0; k < this.numDimensions; ++k) {
                if (k != n) {
                    j += keys[k] * strides[k];
                }
            }
            dataTable.put(i, j, te.get());
            colMap.put(j, i);
        }
        return new SparseMatrix(numRows, numCols, dataTable, colMap);
    }

    /**
     * Returns a user x item matrix built from the user and item dimensions of every entry. If several
     * entries share a (user, item) pair, the value of the one visited last is kept.
     */
    public SparseMatrix rateMatrix() {
        Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();
        for (TensorEntry te : this) {
            int u = te.key(this.userDimension);
            int i = te.key(this.itemDimension);
            dataTable.put(u, i, te.get());
            colMap.put(i, u);
        }
        return new SparseMatrix(this.dimensions[this.userDimension], this.dimensions[this.itemDimension],
                dataTable, colMap);
    }

    /**
     * Returns an iterator over the stored entries in position order. The same {@link TensorEntry}
     * object is reused and repositioned on each {@code next()}. Removing through the iterator, or
     * through an entry it returned, keeps the key indices consistent and does not skip entries.
     */
    @Override
    public Iterator<TensorEntry> iterator() {
        return new TensorIterator();
    }

    /** Returns the Frobenius norm: the square root of the sum of squared entry values. */
    public double norm() {
        double res = 0.0;
        for (double val : this.values) {
            res += val * val;
        }
        return Math.sqrt(res);
    }

    /**
     * Returns the sum of products of matching entries of this tensor and {@code st}.
     *
     * @throws Exception if the two tensors' dimension sizes differ
     */
    public double innerProduct(SparseTensor st) throws Exception {
        if (!this.isDimMatch(st)) {
            throw new Exception("The dimensions of two sparse tensors do not match!");
        }
        double res = 0.0;
        for (TensorEntry te : this) {
            res += te.get() * st.get(te.keys());
        }
        return res;
    }

    /** Returns whether {@code st} has the same number of dimensions and the same size in each. */
    public boolean isDimMatch(SparseTensor st) {
        return Arrays.equals(this.dimensions, st.dimensions);
    }

    public int getUserDimension() {
        return this.userDimension;
    }

    public void setUserDimension(int userDimension) {
        this.userDimension = userDimension;
    }

    public int getItemDimension() {
        return this.itemDimension;
    }

    public void setItemDimension(int itemDimension) {
        this.itemDimension = itemDimension;
    }

    /** Returns the dimension sizes. This is the internal array, so modifying it affects this tensor. */
    public int[] dimensions() {
        return this.dimensions;
    }

    /** Returns the number of dimensions. */
    public int numDimensions() {
        return this.numDimensions;
    }

    /** Returns a header line followed by one tab-separated "keys... value" line per entry. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("N-Dimension: ").append(this.numDimensions).append(", Size: ").append(this.size()).append("\n");
        for (int index = 0; index < this.values.size(); ++index) {
            for (int d = 0; d < this.numDimensions; ++d) {
                sb.append(this.key(d, index)).append("\t");
            }
            sb.append(this.value(index)).append("\n");
        }
        return sb.toString();
    }

    /** Ad-hoc demonstration of the tensor operations; output goes to {@code Logs.debug}. */
    public static void main(String[] args) throws Exception {
        SparseTensor st = new SparseTensor(4, 4, 6);
        st.set(1.0, 1, 0, 0);
        st.set(1.5, 1, 0, 0);
        st.set(2.0, 1, 1, 0);
        st.set(3.0, 2, 0, 0);
        st.set(4.0, 1, 3, 0);
        st.set(5.0, 1, 0, 5);
        st.set(6.0, 3, 1, 4);
        Logs.debug(st);
        Logs.debug("Keys (1, 0, 0) = {}", (Object)st.get(1, 0, 0));
        Logs.debug("Keys (1, 1, 0) = {}", (Object)st.get(1, 1, 0));
        Logs.debug("Keys (1, 2, 0) = {}", (Object)st.get(1, 2, 0));
        Logs.debug("Keys (2, 0, 0) = {}", (Object)st.get(2, 0, 0));
        Logs.debug("Keys (1, 0, 6) = {}", (Object)st.get(1, 0, 6));
        Logs.debug("Keys (3, 1, 4) = {}", (Object)st.get(3, 1, 4));
        Logs.debug("Index of dimension 0 key 1 = {}", st.getIndex(0, 1));
        Logs.debug("Index of dimension 1 key 3 = {}", st.getIndex(1, 3));
        Logs.debug("Index of dimension 2 key 1 = {}", st.getIndex(2, 1));
        Logs.debug("Index of dimension 2 key 6 = {}", st.getIndex(2, 6));
        st.set(4.5, 2, 1, 1);
        Logs.debug(st);
        Logs.debug("Index of dimension 2 key 1 = {}", st.getIndex(2, 1));
        st.remove(2, 1, 1);
        Logs.debug("Index of dimension 2 key 1 = {}", st.getIndex(2, 1));
        Logs.debug("Index of keys (1, 2, 0) = {}, value = {}", st.findIndex(1, 2, 0), st.get(1, 2, 0));
        Logs.debug("Index of keys (3, 1, 4) = {}, value = {}", st.findIndex(3, 1, 4), st.get(3, 1, 4));
        Logs.debug("Keys in dimension 2 associated with dimension 0 key 1 = {}", st.getRelevantKeys(0, 1, 2));
        Logs.debug("norm = {}", (Object)st.norm());
        SparseTensor st2 = st.clone();
        Logs.debug("make a clone = {}", (Object)st2);
        Logs.debug("inner with the clone = {}", (Object)st.innerProduct(st2));
        st.set(2.5, 1, 0, 0);
        st2.remove(1, 0, 0);
        Logs.debug("st1 = {}", (Object)st);
        Logs.debug("st2 = {}", (Object)st2);
        Logs.debug("fiber (0, 0, 0) = {}", (Object)st.fiber(0, 0, 0));
        Logs.debug("fiber (1, 1, 0) = {}", (Object)st.fiber(1, 1, 0));
        Logs.debug("fiber (2, 1, 0) = {}", (Object)st.fiber(2, 1, 0));
        Logs.debug("slice (0, 1, 0) = {}", (Object)st.slice(0, 1, 0));
        Logs.debug("slice (0, 2, 1) = {}", (Object)st.slice(0, 2, 1));
        Logs.debug("slice (1, 2, 1) = {}", (Object)st.slice(1, 2, 1));
        for (TensorEntry te : st) {
            te.set(te.get() + 0.588);
        }
        Logs.debug("Before shuffle: {}", (Object)st);
        st.shuffle();
        Logs.debug("After shuffle: {}", (Object)st);
        st = new SparseTensor(3, 4, 2);
        st.set(1.0, 0, 0, 0);
        st.set(4.0, 0, 1, 0);
        st.set(7.0, 0, 2, 0);
        st.set(10.0, 0, 3, 0);
        st.set(2.0, 1, 0, 0);
        st.set(5.0, 1, 1, 0);
        st.set(8.0, 1, 2, 0);
        st.set(11.0, 1, 3, 0);
        st.set(3.0, 2, 0, 0);
        st.set(6.0, 2, 1, 0);
        st.set(9.0, 2, 2, 0);
        st.set(12.0, 2, 3, 0);
        st.set(13.0, 0, 0, 1);
        st.set(16.0, 0, 1, 1);
        st.set(19.0, 0, 2, 1);
        st.set(22.0, 0, 3, 1);
        st.set(14.0, 1, 0, 1);
        st.set(17.0, 1, 1, 1);
        st.set(20.0, 1, 2, 1);
        st.set(23.0, 1, 3, 1);
        st.set(15.0, 2, 0, 1);
        st.set(18.0, 2, 1, 1);
        st.set(21.0, 2, 2, 1);
        st.set(24.0, 2, 3, 1);
        Logs.debug("A new tensor = {}", (Object)st);
        Logs.debug("Mode X0 unfoldings = {}", (Object)st.matricization(0));
        Logs.debug("Mode X1 unfoldings = {}", (Object)st.matricization(1));
        Logs.debug("Mode X2 unfoldings = {}", (Object)st.matricization(2));
    }

    /**
     * A movable view of one stored entry. The iterator that produced it (if any) is notified on
     * removal so its cursor stays aligned with the shifted positions.
     */
    private class SparseTensorEntry implements TensorEntry {
        /** Position of the viewed entry; -1 before the first update and after removal. */
        private int index = -1;
        /** Iterator that owns this entry, or null. */
        private final TensorIterator owner;

        private SparseTensorEntry(TensorIterator owner) {
            this.owner = owner;
        }

        /** Repositions this view at {@code index} and returns it. */
        public SparseTensorEntry update(int index) {
            this.index = index;
            return this;
        }

        @Override
        public int key(int d) {
            return SparseTensor.this.ndKeys[d].get(this.index);
        }

        @Override
        public double get() {
            return SparseTensor.this.values.get(this.index);
        }

        @Override
        public void set(double value) {
            SparseTensor.this.values.set(this.index, value);
        }

        /**
         * Removes the viewed entry from the tensor and keeps the key indices consistent.
         *
         * @throws IllegalStateException if this view is not positioned on an entry
         */
        @Override
        public void remove() {
            if (this.index < 0) {
                throw new IllegalStateException("No current entry to remove");
            }
            int removed = this.index;
            SparseTensor.this.removeAt(removed);
            this.index = -1;
            if (this.owner != null) {
                this.owner.onRemoved(removed);
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int d = 0; d < SparseTensor.this.numDimensions; ++d) {
                sb.append(this.key(d)).append("\t");
            }
            sb.append(this.get());
            return sb.toString();
        }

        @Override
        public int[] keys() {
            int[] res = new int[SparseTensor.this.numDimensions];
            for (int d = 0; d < SparseTensor.this.numDimensions; ++d) {
                res[d] = this.key(d);
            }
            return res;
        }
    }

    /** Iterates entry positions in order, reusing a single {@link SparseTensorEntry}. */
    private class TensorIterator implements Iterator<TensorEntry> {
        /** Position of the next entry to return. */
        private int cursor = 0;
        private final SparseTensorEntry entry = new SparseTensorEntry(this);

        private TensorIterator() {
        }

        @Override
        public boolean hasNext() {
            return this.cursor < SparseTensor.this.values.size();
        }

        @Override
        public TensorEntry next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            return this.entry.update(this.cursor++);
        }

        /**
         * Removes the entry last returned by {@link #next()}.
         *
         * @throws IllegalStateException if {@code next()} was not called or the entry was already removed
         */
        @Override
        public void remove() {
            this.entry.remove();
        }

        /** Called after the entry at {@code removed} was deleted; later entries shifted down by one. */
        private void onRemoved(int removed) {
            if (removed < this.cursor) {
                --this.cursor;
            }
        }
    }
}

