/*
 * Decompiled with CFR 0.152.
 */
package contraband.coalescent;

import beast.base.core.Description;
import beast.base.core.Input;
import beast.base.evolution.tree.Node;
import beast.base.evolution.tree.Tree;
import beast.base.inference.CalculationNode;
import beast.base.inference.parameter.RealParameter;
import beast.base.inference.util.InputUtil;
import contraband.coalescent.CoalUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

@Description(value="Coalescent correction for continuous trait on tree")
public class CoalCorrection
extends CalculationNode {
    /** Tree input; read in initAndValidate() and again each time the matrix is refilled. */
    public final Input<Tree> treeInput = new Input<>("tree", "Tree object containing tree.", Input.Validate.REQUIRED);
    /** Population sizes, indexed by node number; initAndValidate() requires exactly one value per tree node. */
    public final Input<RealParameter> popSizesInput = new Input<>("popSizes", "Temporary input containing the pop size of each node in the tree.", Input.Validate.REQUIRED);
    /** Tree most recently fetched from {@link #treeInput}. */
    private Tree tree;
    /**
     * Per-node lineage-count distribution, indexed [nodeNr][lineages - 1].
     * Leaf rows are the constant {1.0}; internal rows have nodeCount entries and are
     * rewritten by fillNLineageDistInPlace().
     */
    private double[][] nLineageDistAtEnd;
    /** Leaf-by-leaf result matrix returned by getCorrectedPhyloTMat(). */
    private double[][] correctedPhyloTMat;
    /** All-zero array of length nodeCount, used as the source for clearing rows via System.arraycopy. */
    double[] nullarray;
    /** Boxed buffer that receives the current values of {@link #popSizesInput}. */
    Double[] popSizes;
    /** Primitive copy of {@link #popSizes}, refreshed whenever the matrix is refilled. */
    double[] popSizesd;
    /** Lazily filled by getHeledGij(): rates i*(i-1)/2 for i = to..from, indexed [from][to]. */
    double[][][] ratesCache;
    /** Lazily filled by getHeledGij(): coefficients paired with {@link #ratesCache}, indexed [from][to]. */
    double[][][] cisCache;
    /** commonAncestor[a][b] is the node number under which leaves a and b were first joined in calcCommonAncestors(). */
    int[][] commonAncestor;
    /** Per-node list of descendant leaf numbers; leaf entries hold only the leaf itself. */
    List<Integer>[] descendantsList;
    /** Lookup table: expy[k] = exp(-k / (N - 1)); shared by all instances. */
    static double[] expy;
    /** Lookup table: intexp[k] = exp(-k) for k = 0..749; shared by all instances. */
    static double[] intexp;
    /** Number of entries in {@link #expy}. */
    static int N;
    /**
     * Set to true by requiresRecalculation() (population sizes dirty) and restore(), and to
     * false after the lineage distributions are recomputed. Not read anywhere in this class.
     */
    boolean hasDirt = true;
    /** Lineage count obtained from child 0 of each internal node on the last traversal. */
    int[] leftX;
    /** Lineage count obtained from child 1 of each internal node on the last traversal. */
    int[] rightX;
    /** Parent node number seen for each node on the last traversal (-1 for the root). */
    int[] parentIdxs;
    /** Per-node value -branchLength / popSize from the last traversal (note the negative sign). */
    double[] allBrLengthsInCU;
    /** Snapshot of {@link #popSizesd} taken by store(). */
    double[] storedPopSizesD;
    /** Snapshot of {@link #correctedPhyloTMat} taken by store(). */
    double[][] storedCorrectedPhyloMat;
    /** Snapshot of {@link #commonAncestor} taken by store(). */
    int[][] storedCommonAncestor;
    /** Snapshot of {@link #leftX} taken by store(). */
    int[] storedLeftX;
    /** Snapshot of {@link #rightX} taken by store(). */
    int[] storedRightX;
    /** Snapshot of {@link #parentIdxs} taken by store(). */
    int[] storedParentIdxs;
    /** Snapshot of {@link #allBrLengthsInCU} taken by store(). */
    double[] storedAllBrLengthsInCU;
    /** Keyed cache of G_ij values; no method in this class currently reads or writes it. */
    Map<X, Double> GijCache = new HashMap<X, Double>();

    /**
     * Builds the exponential lookup tables and allocates every per-node and per-leaf
     * work array for the tree currently supplied by {@code treeInput}.
     *
     * @throws RuntimeException if the number of population sizes differs from the
     *         number of nodes in the tree
     */
    public void initAndValidate() {
        this.initExp();
        this.tree = this.treeInput.get();
        final RealParameter popSizeParameter = this.popSizesInput.get();
        final int leafCount = this.tree.getLeafNodeCount();
        final int nodeCount = this.tree.getNodeCount();

        this.correctedPhyloTMat = new double[leafCount][leafCount];
        this.popSizes = new Double[popSizeParameter.getDimension()];
        this.popSizesd = new double[this.popSizes.length];

        // Leaves carry the fixed distribution {1.0}; internal nodes get a full-width row.
        this.nLineageDistAtEnd = new double[nodeCount][];
        for (int nodeNr = 0; nodeNr < nodeCount; ++nodeNr) {
            this.nLineageDistAtEnd[nodeNr] = nodeNr < leafCount ? new double[]{1.0} : new double[nodeCount];
        }

        this.nullarray = new double[nodeCount];
        this.ratesCache = new double[nodeCount][nodeCount][];
        this.cisCache = new double[nodeCount][nodeCount][];
        this.commonAncestor = new int[leafCount][leafCount];

        // Every node gets a descendant list; a leaf's list starts out holding the leaf itself.
        this.descendantsList = new List[nodeCount];
        for (int nodeNr = 0; nodeNr < nodeCount; ++nodeNr) {
            List<Integer> descendants = new ArrayList<Integer>();
            if (nodeNr < leafCount) {
                descendants.add(nodeNr);
            }
            this.descendantsList[nodeNr] = descendants;
        }

        this.leftX = new int[nodeCount];
        this.rightX = new int[nodeCount];
        this.allBrLengthsInCU = new double[nodeCount];
        this.parentIdxs = new int[nodeCount];

        if (popSizeParameter.getValues().length != nodeCount) {
            throw new RuntimeException("The number of population sizes in popSizes is different from the number of nodes in the tree.");
        }

        // Buffers used by store()/restore().
        this.storedPopSizesD = new double[this.popSizes.length];
        this.storedCorrectedPhyloMat = new double[leafCount][leafCount];
        this.storedCommonAncestor = new int[leafCount][leafCount];
        this.storedLeftX = new int[nodeCount];
        this.storedRightX = new int[nodeCount];
        this.storedParentIdxs = new int[nodeCount];
        this.storedAllBrLengthsInCU = new double[nodeCount];
    }

    /**
     * Post-order traversal that rewrites {@code nLineageDistAtEnd} for {@code node} and
     * everything below it, and refreshes the per-node bookkeeping arrays
     * ({@code parentIdxs}, {@code leftX}, {@code rightX}, {@code allBrLengthsInCU}).
     *
     * @param node subtree root to process
     * @return the lineage count for {@code node}: 0 for a direct-ancestor leaf, 1 for any
     *         other leaf, and the sum of the two children's counts for an internal node.
     *         The value is negated when the recorded parent, child counts or scaled
     *         branch length of this node (or of anything below it) differ from the
     *         previous traversal.
     */
    private int fillNLineageDistInPlace(Node node) {
        if (node.isLeaf()) {
            if (node.isDirectAncestor()) {
                return 0;
            }
            int leafNr = node.getNr();
            int leafParentNr = node.getParent().getNr();
            if (this.parentIdxs[leafNr] != leafParentNr) {
                this.parentIdxs[leafNr] = leafParentNr;
                return -1;
            }
            return 1;
        }

        boolean changed = false;
        int nodeNr = node.getNr();
        double popSize = this.popSizesd[nodeNr];

        // A negative count coming back from a child only flags "something changed below".
        Node left = node.getChild(0);
        int leftNr = left.getNr();
        int leftLineages = this.fillNLineageDistInPlace(left);
        if (leftLineages < 0) {
            changed = true;
            leftLineages = -leftLineages;
        }
        Node right = node.getChild(1);
        int rightNr = right.getNr();
        int rightLineages = this.fillNLineageDistInPlace(right);
        if (rightLineages < 0) {
            changed = true;
            rightLineages = -rightLineages;
        }

        boolean atRoot = node.isRoot();
        int parentNr = atRoot ? -1 : node.getParent().getNr();
        // Stored negated, ready to be multiplied by a coalescent rate inside exp().
        double scaledLength = -node.getLength() / popSize;
        if (this.parentIdxs[nodeNr] != parentNr || this.leftX[nodeNr] != leftLineages
                || this.rightX[nodeNr] != rightLineages || this.allBrLengthsInCU[nodeNr] != scaledLength) {
            changed = true;
        }
        this.parentIdxs[nodeNr] = parentNr;
        this.leftX[nodeNr] = leftLineages;
        this.rightX[nodeNr] = rightLineages;
        this.allBrLengthsInCU[nodeNr] = scaledLength;

        int totalLineages = leftLineages + rightLineages;
        double[] leftDist = this.nLineageDistAtEnd[leftNr];
        double[] rightDist = this.nLineageDistAtEnd[rightNr];
        double[] nodeDist = this.nLineageDistAtEnd[nodeNr];
        // Clear the slots that are accumulated into below.
        System.arraycopy(this.nullarray, 0, nodeDist, 0, totalLineages);

        if (leftLineages == 0) {
            // Child 0 contributes no lineages (direct-ancestor leaf).
            // NOTE(review): every lineage count receives the same total weight here,
            // not rightDist[k - 1] - confirm this is the intended formula.
            double weight = this.sumLeading(rightDist, rightLineages);
            for (int lineagesIn = 1; lineagesIn <= totalLineages; ++lineagesIn) {
                this.addBranchContribution(nodeDist, lineagesIn, weight, atRoot, scaledLength);
            }
        } else if (rightLineages == 0) {
            // Mirror image of the case above, with child 1 contributing no lineages.
            double weight = this.sumLeading(leftDist, leftLineages);
            for (int lineagesIn = 1; lineagesIn <= totalLineages; ++lineagesIn) {
                this.addBranchContribution(nodeDist, lineagesIn, weight, atRoot, scaledLength);
            }
        } else {
            // Convolve the two child distributions: lineagesIn = fromLeft + fromRight.
            for (int lineagesIn = 2; lineagesIn <= totalLineages; ++lineagesIn) {
                double weight = 0.0;
                for (int fromLeft = 1; fromLeft <= leftLineages; ++fromLeft) {
                    int fromRight = lineagesIn - fromLeft;
                    if (fromRight <= 0 || fromRight > rightLineages) {
                        continue;
                    }
                    weight += leftDist[fromLeft - 1] * rightDist[fromRight - 1];
                }
                this.addBranchContribution(nodeDist, lineagesIn, weight, atRoot, scaledLength);
            }
        }
        return changed ? -totalLineages : totalLineages;
    }

    /** Returns dist[0] + ... + dist[count - 1], accumulated in index order. */
    private double sumLeading(double[] dist, int count) {
        double sum = 0.0;
        for (int k = 0; k < count; ++k) {
            sum += dist[k];
        }
        return sum;
    }

    /**
     * Adds the contribution of {@code lineagesIn} entering lineages (with the given weight)
     * to {@code target}: at the root the weight is stored directly in slot
     * {@code lineagesIn - 1}; otherwise it is spread over 1..lineagesIn outgoing lineages
     * using getHeledGij().
     */
    private void addBranchContribution(double[] target, int lineagesIn, double weight, boolean atRoot, double scaledLength) {
        if (atRoot) {
            target[lineagesIn - 1] = weight;
            return;
        }
        for (int lineagesOut = 1; lineagesOut <= lineagesIn; ++lineagesOut) {
            target[lineagesOut - 1] += weight * this.getHeledGij(lineagesIn, lineagesOut, scaledLength);
        }
    }

    /**
     * Recomputes the lineage-count distributions for the whole tree and returns the
     * sum over k of (root distribution entry for k lineages) times
     * {@code CoalUtils.getMeanRootHeight(k, rootPopSize)}.
     * Clears {@code hasDirt} as a side effect.
     *
     * @param tree tree whose root is processed
     * @return the weighted sum described above
     */
    private double getExpGenealHeightAtRoot(Tree tree) {
        final Node root = tree.getRoot();
        final int rootNr = root.getNr();
        // The sign of the returned count only carries a "changed" flag; drop it.
        final int lineagesAtRoot = Math.abs(this.fillNLineageDistInPlace(root));
        this.hasDirt = false;

        double expectedHeight = 0.0;
        for (int lineages = 1; lineages <= lineagesAtRoot; ++lineages) {
            double probability = this.nLineageDistAtEnd[rootNr][lineages - 1];
            expectedHeight += probability * CoalUtils.getMeanRootHeight(lineages, this.popSizesd[rootNr]);
        }
        return expectedHeight;
    }

    /**
     * Walks from {@code node} up to the root and accumulates a per-branch term weighted
     * by the product of getHeledGij(2, 2, ...) over the branches already passed.
     * Presumably this is the expected coalescence time of two lineages that first share
     * the branch above {@code node} - confirm against the model description.
     *
     * <p>The per-branch term is written as {@code height * (1 - g) + within}, which is
     * algebraically the same as {@code (height + within / (1 - g)) * (1 - g)} but does not
     * evaluate 0/0 (NaN) when a non-root branch has length zero and g is exactly 1.
     *
     * @param node        starting node; the walk follows parents until the root
     * @param doubleArray population sizes indexed by node number
     * @return the accumulated sum
     */
    private double getExpCoalTimePair(Node node, Double[] doubleArray) {
        double heightAtBranchStart = node.getHeight();
        double expectedTime = 0.0;
        // Product of the g values of all branches traversed so far.
        double carriedWeight = 1.0;
        while (true) {
            boolean atRoot = node.isRoot();
            double branchLength = node.getLength();
            double popSize = doubleArray[node.getNr()];
            double withinBranch;
            double g;
            if (atRoot) {
                // The root branch is unbounded: the term is the population size, nothing is carried on.
                withinBranch = popSize;
                g = 0.0;
            } else {
                double rate = 1.0 / popSize;
                withinBranch = (1.0 - (branchLength * rate + 1.0) * Math.exp(-branchLength * rate)) / rate;
                g = this.getHeledGij(2, 2, -branchLength / popSize);
            }
            expectedTime += (heightAtBranchStart * (1.0 - g) + withinBranch) * carriedWeight;
            if (atRoot) {
                break;
            }
            carriedWeight *= g;
            heightAtBranchStart += branchLength;
            node = node.getParent();
        }
        return expectedTime;
    }

    /**
     * Refreshes the population sizes from the input, recomputes the root term and the
     * common-ancestor table, and overwrites {@code correctedPhyloTMat}.
     *
     * @param stringArray receives the ID of each leaf at the index of its node number;
     *                    must have room for at least leafCount entries
     */
    private void fillPhyloTMatInPlace(String[] stringArray) {
        this.tree = this.treeInput.get();
        final double rootHeight = this.tree.getRoot().getHeight();

        // Pull the current parameter values, then unbox them into the primitive mirror.
        this.popSizesInput.get().getValues(this.popSizes);
        int slot = 0;
        while (slot < this.popSizes.length) {
            this.popSizesd[slot] = this.popSizes[slot];
            ++slot;
        }

        final double rootTerm = this.getExpGenealHeightAtRoot(this.tree);
        final int leafCount = this.tree.getLeafNodeCount();
        this.calcCommonAncestors(this.tree.getRoot());

        // One pair term per internal node; leaf slots stay 0 and are never read below.
        final Node[] nodes = this.tree.getNodesAsArray();
        final double[] pairTermByNode = new double[nodes.length];
        for (int nodeNr = this.tree.getLeafNodeCount(); nodeNr < nodes.length; ++nodeNr) {
            pairTermByNode[nodeNr] = this.getExpCoalTimePair(nodes[nodeNr], this.popSizes);
        }

        for (int row = 0; row < leafCount; ++row) {
            stringArray[row] = nodes[row].getID();
            this.correctedPhyloTMat[row][row] = rootTerm + (rootHeight - nodes[row].getHeight());
            // Fill the upper triangle and mirror it into the lower one.
            for (int col = row + 1; col < leafCount; ++col) {
                double shared = rootHeight + rootTerm - pairTermByNode[this.commonAncestor[row][col]];
                this.correctedPhyloTMat[row][col] = shared;
                this.correctedPhyloTMat[col][row] = shared;
            }
        }
    }

    /**
     * Recomputes and returns the corrected leaf-by-leaf matrix.
     *
     * @param stringArray output array; entry i is overwritten with the ID of the leaf
     *                    with node number i
     * @return the internal matrix itself (not a copy); it is overwritten by the next call
     *         and swapped out by restore()
     */
    public double[][] getCorrectedPhyloTMat(String[] stringArray) {
        this.fillPhyloTMatInPlace(stringArray);
        return this.correctedPhyloTMat;
    }

    /**
     * Evaluates the sum over k of c[k] * exp(rate[k] * d), where rate[k] = i*(i-1)/2 for
     * i = n2..n and the coefficients c depend only on (n, n2). Rates and coefficients
     * are computed on first use and kept in {@code ratesCache} / {@code cisCache}.
     *
     * @param n  larger lineage count (first cache index)
     * @param n2 smaller lineage count (second cache index); callers in this class pass n2 <= n
     * @param d  scale passed to exp(); callers in this class pass -branchLength / popSize
     * @return the weighted sum of exponentials
     */
    private double getHeledGij(int n, int n2, double d) {
        if (this.ratesCache[n][n2] == null) {
            final int termCount = n - n2 + 1;

            final double[] newRates = new double[termCount];
            for (int k = 0, lineages = n2; lineages < n + 1; ++lineages, ++k) {
                newRates[k] = (double)lineages * ((double)lineages - 1.0) / 2.0;
            }

            // Built incrementally: each step rescales the coefficients found so far and
            // sets the next one to minus the sum of the whole array.
            final double[] newCoefficients = new double[termCount];
            newCoefficients[0] = 1.0;
            for (int next = 1; next < termCount; ++next) {
                final double nextRate = newRates[next];
                for (int k = 0; k < next; ++k) {
                    newCoefficients[k] = newCoefficients[k] * (nextRate / (nextRate - newRates[k]));
                }
                double total = 0.0;
                for (int k = 0; k < newCoefficients.length; ++k) {
                    total += newCoefficients[k];
                }
                newCoefficients[next] = -total;
            }

            this.ratesCache[n][n2] = newRates;
            this.cisCache[n][n2] = newCoefficients;
        }

        final double[] rates = this.ratesCache[n][n2];
        final double[] coefficients = this.cisCache[n][n2];
        double result = 0.0;
        for (int k = 0; k < rates.length; ++k) {
            result += coefficients[k] * this.exp(rates[k] * d);
        }
        return result;
    }

    /**
     * Fills {@code commonAncestor} for every pair of leaves that lie on opposite sides of
     * {@code node} or of any internal node below it, and rebuilds the descendant-leaf
     * list of each internal node visited.
     *
     * @param node subtree root
     * @return the internal list of leaf numbers below {@code node} (the leaf's own
     *         single-entry list for a leaf); this is live state, not a copy
     */
    protected List<Integer> calcCommonAncestors(Node node) {
        int nodeNr = node.getNr();
        if (node.isLeaf()) {
            return this.descendantsList[nodeNr];
        }
        List<Integer> leftLeaves = this.calcCommonAncestors(node.getLeft());
        List<Integer> rightLeaves = this.calcCommonAncestors(node.getRight());

        // Leaves from different children meet for the first time at this node.
        for (int leftLeaf : leftLeaves) {
            for (int rightLeaf : rightLeaves) {
                this.commonAncestor[leftLeaf][rightLeaf] = nodeNr;
                this.commonAncestor[rightLeaf][leftLeaf] = nodeNr;
            }
        }

        List<Integer> leavesBelow = this.descendantsList[nodeNr];
        leavesBelow.clear();
        leavesBelow.addAll(leftLeaves);
        leavesBelow.addAll(rightLeaves);
        return leavesBelow;
    }

    /**
     * Ensures the shared lookup tables used by exp() exist:
     * {@code expy[k] = exp(-k / (N - 1))} for k in [0, N) and
     * {@code intexp[k] = exp(-k)} for k in [0, 750).
     *
     * <p>The tables are static and their contents depend only on N, so they are built
     * once and reused by later instances. Each table is filled in a local array and
     * assigned to the static field only when complete, so the shared field never refers
     * to a partially filled table.
     */
    private void initExp() {
        final int fractionSlots = N;
        if (expy == null || expy.length != fractionSlots) {
            double[] fractionTable = new double[fractionSlots];
            for (int k = 0; k < fractionSlots; ++k) {
                fractionTable[k] = Math.exp((double)(-k) / ((double)fractionSlots - 1.0));
            }
            expy = fractionTable;
        }
        if (intexp == null || intexp.length != 750) {
            double[] integerTable = new double[750];
            for (int k = 0; k < 750; ++k) {
                integerTable[k] = Math.exp(-k);
            }
            intexp = integerTable;
        }
    }

    /**
     * Table-driven approximation of e^d: the integer part of d selects an entry of
     * {@code intexp}, the fractional part an entry of {@code expy}, and the two are multiplied.
     *
     * <p>NOTE(review): a positive {@code d} produces negative table indices and therefore an
     * ArrayIndexOutOfBoundsException; the only caller in this class passes a rate
     * multiplied by -branchLength / popSize.
     *
     * @param d exponent
     * @return 0.0 when d is below -746, otherwise the table product
     */
    private double exp(double d) {
        if (d < -746.0) {
            return 0.0;
        }
        final int wholePart = (int)d;
        final double fraction = -d % 1.0;
        final int fractionSlot = (int)(fraction * (double)N);
        return expy[fractionSlot] * intexp[-wholePart];
    }

    /**
     * Reports whether cached results may be stale.
     *
     * @return true when the population-size input is dirty (in which case {@code hasDirt}
     *         is also raised); otherwise whatever the tree reports via somethingIsDirty()
     */
    public boolean requiresRecalculation() {
        final boolean popSizesChanged = InputUtil.isDirty(this.popSizesInput);
        if (!popSizesChanged) {
            return this.treeInput.get().somethingIsDirty();
        }
        this.hasDirt = true;
        return true;
    }

    /**
     * Copies the current matrix, common-ancestor table, per-node bookkeeping arrays and
     * population sizes into their {@code stored*} counterparts, then delegates to the
     * superclass.
     */
    protected void store() {
        final int leafCount = this.tree.getLeafNodeCount();
        for (int row = 0; row < leafCount; ++row) {
            System.arraycopy(this.correctedPhyloTMat[row], 0, this.storedCorrectedPhyloMat[row], 0, leafCount);
            System.arraycopy(this.commonAncestor[row], 0, this.storedCommonAncestor[row], 0, leafCount);
        }

        final int nodeCount = this.tree.getNodeCount();
        System.arraycopy(this.leftX, 0, this.storedLeftX, 0, nodeCount);
        System.arraycopy(this.rightX, 0, this.storedRightX, 0, nodeCount);
        System.arraycopy(this.parentIdxs, 0, this.storedParentIdxs, 0, nodeCount);
        System.arraycopy(this.allBrLengthsInCU, 0, this.storedAllBrLengthsInCU, 0, nodeCount);

        final int popSizeCount = this.popSizesInput.get().getDimension();
        System.arraycopy(this.popSizesd, 0, this.storedPopSizesD, 0, popSizeCount);

        super.store();
    }

    /**
     * Reinstates the state captured by store() by swapping each live array with its
     * {@code stored*} twin (no element copying), raises {@code hasDirt}, and delegates to
     * the superclass.
     */
    protected void restore() {
        // One-dimensional double arrays.
        final double[] livePopSizes = this.popSizesd;
        this.popSizesd = this.storedPopSizesD;
        this.storedPopSizesD = livePopSizes;

        final double[] liveBranchLengths = this.allBrLengthsInCU;
        this.allBrLengthsInCU = this.storedAllBrLengthsInCU;
        this.storedAllBrLengthsInCU = liveBranchLengths;

        // One-dimensional int arrays.
        final int[] liveLeft = this.leftX;
        this.leftX = this.storedLeftX;
        this.storedLeftX = liveLeft;

        final int[] liveRight = this.rightX;
        this.rightX = this.storedRightX;
        this.storedRightX = liveRight;

        final int[] liveParents = this.parentIdxs;
        this.parentIdxs = this.storedParentIdxs;
        this.storedParentIdxs = liveParents;

        // Leaf-by-leaf tables.
        final double[][] liveMatrix = this.correctedPhyloTMat;
        this.correctedPhyloTMat = this.storedCorrectedPhyloMat;
        this.storedCorrectedPhyloMat = liveMatrix;

        final int[][] liveAncestors = this.commonAncestor;
        this.commonAncestor = this.storedCommonAncestor;
        this.storedCommonAncestor = liveAncestors;

        this.hasDirt = true;
        super.restore();
    }

    // Size of the fractional-part lookup table used by exp(): 2^20 = 1048576 entries.
    static {
        N = 1 << 20;
    }

    class X {
        Integer from;
        Integer to;
        Double scale;

        X(int n, int n2, double d) {
            this.from = n;
            this.to = n2;
            this.scale = d;
        }
    }
}

