/*
 * Decompiled with CFR 0.152.
 */
package contraband.utils;

import beast.base.evolution.tree.Tree;
import beast.base.inference.parameter.RealParameter;
import contraband.math.MatrixUtilsContra;
import contraband.math.NodeMath;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.util.FastMath;

/**
 * Static helpers that load tip trait values from a {@link RealParameter} and fill the per-node
 * quantities (A, C, E, f, L, m, r) stored in a {@link NodeMath}.
 *
 * <p>The class holds no mutable state; every method only mutates the objects passed to it.
 */
public class PruneLikelihoodUtils {
    /** log(2π), reused by the per-node normalising terms. */
    private static final double LOG_TWO_PI = FastMath.log(Math.PI * 2);

    /**
     * Copies every leaf's trait values into a flat, row-major array.
     *
     * <p>Leaf {@code i} (assumed to be node number {@code i} of {@code tree}) is looked up in
     * {@code realParameter} by its node ID. Its first {@code nTraits} values are written to
     * {@code traitValuesArr[i * nTraits + j]}.
     *
     * @param realParameter  parameter holding one row of trait values per taxon, keyed by node ID
     * @param tree           tree whose leaves are nodes {@code 0 .. getLeafNodeCount() - 1}
     * @param nTraits        number of values copied per leaf
     * @param traitValuesArr destination; must hold at least {@code leafCount * nTraits} entries
     * @throws NullPointerException if a leaf has no row in {@code realParameter}; the overload
     *     taking a {@link NodeMath} skips such leaves instead
     */
    public static void populateTraitValuesArr(RealParameter realParameter, Tree tree, int nTraits,
                                              double[] traitValuesArr) {
        final int leafCount = tree.getLeafNodeCount();
        for (int i = 0; i < leafCount; ++i) {
            final String taxon = tree.getNode(i).getID();
            final Double[] row = (Double[]) realParameter.getRowValues(taxon);
            if (row == null) {
                // Report which taxon is missing instead of failing with a bare NPE on row[j].
                throw new NullPointerException("No trait values found for taxon '" + taxon + "'");
            }
            copyRow(row, i, nTraits, traitValuesArr);
        }
    }

    /**
     * Copies every leaf's trait values into a flat, row-major array, tolerating missing taxa.
     *
     * <p>Same layout as {@link #populateTraitValuesArr(RealParameter, Tree, int, double[])}. A leaf
     * whose ID has no row ({@code getRowValues} returns {@code null}) is passed to
     * {@code nodeMath.setSpeciesToIgnore(i)}, and its slots in {@code traitValuesArr} are left
     * untouched.
     *
     * @param realParameter  parameter holding one row of trait values per taxon, keyed by node ID
     * @param tree           tree whose leaves are nodes {@code 0 .. getLeafNodeCount() - 1}
     * @param nodeMath       receives the indices of leaves that have no data
     * @param nTraits        number of values copied per leaf
     * @param traitValuesArr destination; must hold at least {@code leafCount * nTraits} entries
     */
    public static void populateTraitValuesArr(RealParameter realParameter, Tree tree, NodeMath nodeMath,
                                              int nTraits, double[] traitValuesArr) {
        final int leafCount = tree.getLeafNodeCount();
        for (int i = 0; i < leafCount; ++i) {
            final Double[] row = (Double[]) realParameter.getRowValues(tree.getNode(i).getID());
            if (row == null) {
                nodeMath.setSpeciesToIgnore(i);
                continue;
            }
            copyRow(row, i, nTraits, traitValuesArr);
        }
    }

    /** Unboxes the first {@code nTraits} values of {@code row} into block {@code leafIdx} of {@code dest}. */
    private static void copyRow(Double[] row, int leafIdx, int nTraits, double[] dest) {
        final int offset = leafIdx * nTraits;
        for (int j = 0; j < nTraits; ++j) {
            dest[offset + j] = row[j];
        }
    }

    /**
     * Writes the trait values of every leaf that has data into consecutive rows of a matrix.
     *
     * <p>Leaves without a row in {@code realParameter} are skipped and no matrix row is written for
     * them. Row {@code k} of the result therefore holds the k-th leaf (in node-number order) that
     * has data, not necessarily leaf {@code k}.
     *
     * @param realParameter     parameter holding one row of trait values per taxon, keyed by node ID
     * @param tree              tree whose leaves are nodes {@code 0 .. getLeafNodeCount() - 1}
     * @param nTraits           number of columns written per row
     * @param traitValuesMatrix destination matrix, filled in place
     */
    public static void populateTraitValuesMatrix(RealParameter realParameter, Tree tree, int nTraits,
                                                 RealMatrix traitValuesMatrix) {
        final int leafCount = tree.getLeafNodeCount();
        int row = 0;
        for (int i = 0; i < leafCount; ++i) {
            final Double[] values = (Double[]) realParameter.getRowValues(tree.getNode(i).getID());
            if (values == null) {
                continue;
            }
            for (int j = 0; j < nTraits; ++j) {
                traitValuesMatrix.setEntry(row, j, values[j]);
            }
            ++row;
        }
    }

    /**
     * Sets the A, C, E and f terms of node {@code nodeIdx} from a scalar {@code varianceScale}.
     *
     * <p>A = C = -1 / (2·varianceScale), E = 1 / varianceScale, and
     * f = -½·(nTraits·log 2π + nTraits·log(varianceScale) + nodeMath.getTraitRateMatrixDeterminant()).
     *
     * <p>NOTE(review): f reads like the log-normalising constant of a Gaussian whose covariance is
     * {@code varianceScale} times the trait rate matrix. That only holds if
     * {@code getTraitRateMatrixDeterminant()} returns a log-determinant. Confirm in NodeMath.
     *
     * @param nodeMath      per-node storage that is updated
     * @param varianceScale scalar whose reciprocal becomes E (presumably the branch's variance
     *                      multiplier; confirm at the call site)
     * @param nTraits       number of traits
     * @param nodeIdx       index of the node being filled
     */
    public static void populateACEf(NodeMath nodeMath, double varianceScale, int nTraits, int nodeIdx) {
        final double inverseScale = 1.0 / varianceScale;
        final double aAndC = -0.5 * inverseScale;
        nodeMath.setAForNode(nodeIdx, aAndC);
        nodeMath.setEForNode(nodeIdx, inverseScale);
        nodeMath.setCForNode(nodeIdx, aAndC);
        nodeMath.setfForNode(nodeIdx, -0.5 * (nTraits * LOG_TWO_PI + nTraits * Math.log(varianceScale)
                + nodeMath.getTraitRateMatrixDeterminant()));
    }

    /**
     * Initialises L, r and m for a tip node, using the full inverse trait rate matrix.
     *
     * <p>After storing the tip's traits x (via {@code setTraitsVecForTip}) and its expectation:
     * <ul>
     *   <li>L = C</li>
     *   <li>r = A·q + f, where q = {@code tVecDotMatrixDotVec(x, R⁻¹)} (presumably xᵀR⁻¹x)</li>
     *   <li>m = E·{@code matrixPreMultiply(x, R⁻¹)}, computed in NodeMath's temp vector</li>
     * </ul>
     * A, C, E and f must already be set, for example by {@link #populateACEf}.
     *
     * @param nodeMath       per-node storage that is updated
     * @param traitValuesArr trait values passed to {@code setTraitsVecForTip} (presumably the flat
     *                       array filled by {@code populateTraitValuesArr})
     * @param nTraits        number of traits
     * @param nodeIdx        index of the tip node
     */
    public static void populateLmrForTip(NodeMath nodeMath, double[] traitValuesArr, int nTraits, int nodeIdx) {
        nodeMath.setTraitsVecForTip(traitValuesArr, nodeIdx);
        nodeMath.setExpectationForTip(nodeIdx);
        nodeMath.setLForNode(nodeIdx, nodeMath.getCForNode(nodeIdx));
        nodeMath.setRForNode(nodeIdx, nodeMath.getAForNode(nodeIdx)
                * MatrixUtilsContra.tVecDotMatrixDotVec(nodeMath.getTraitsVec(), nodeMath.getTraitRateMatrixInverse(), nTraits)
                + nodeMath.getfForNode(nodeIdx));
        MatrixUtilsContra.matrixPreMultiply(nodeMath.getTraitsVec(), nodeMath.getTraitRateMatrixInverse(),
                nTraits, nTraits, nodeMath.getTempVec());
        MatrixUtilsContra.vectorMapMultiply(nodeMath.getTempVec(), nodeMath.getEForNode(nodeIdx), nodeMath.getTempVec());
        nodeMath.setMVecForNode(nodeIdx, nodeMath.getTempVec());
    }

    /**
     * Shrinkage variant of {@link #populateLmrForTip}: no trait rate matrix is applied.
     *
     * <p>L = C; r = {@code vecTransScalarMultiply(x, A)} + f (presumably A·xᵀx); m = E·x.
     *
     * @param nodeMath       per-node storage that is updated
     * @param traitValuesArr trait values passed to {@code setTraitsVecForTip}
     * @param nTraits        number of traits
     * @param nodeIdx        index of the tip node
     */
    public static void populateLmrForTipWithShrinkage(NodeMath nodeMath, double[] traitValuesArr, int nTraits,
                                                      int nodeIdx) {
        nodeMath.setTraitsVecForTip(traitValuesArr, nodeIdx);
        nodeMath.setExpectationForTip(nodeIdx);
        nodeMath.setLForNode(nodeIdx, nodeMath.getCForNode(nodeIdx));
        nodeMath.setRForNode(nodeIdx, MatrixUtilsContra.vecTransScalarMultiply(nodeMath.getTraitsVec(),
                nodeMath.getAForNode(nodeIdx), nTraits) + nodeMath.getfForNode(nodeIdx));
        MatrixUtilsContra.vectorMapMultiply(nodeMath.getTraitsVec(), nodeMath.getEForNode(nodeIdx), nodeMath.getTempVec());
        nodeMath.setMVecForNode(nodeIdx, nodeMath.getTempVec());
    }

    /**
     * Updates L, r and m of an internal node from the values currently stored for it.
     *
     * <p>With s = A + L (read before any update):
     * <ul>
     *   <li>r ← r + f + ½·nTraits·log 2π − (½·nTraits·log(−2s) + ½·log(detInv))
     *       − ¼·(1/s)·{@code tVecDotMatrixDotVec(m, R)}</li>
     *   <li>m ← −½·(E/s)·m</li>
     *   <li>L ← C − ¼·(E/s)·E</li>
     * </ul>
     * Here detInv = {@code getTraitRateMatrixInverseDeterminant()}. s must be negative, otherwise
     * {@code log(-2s)} is NaN.
     *
     * <p>NOTE(review): this method takes the log of {@code getTraitRateMatrixInverseDeterminant()}.
     * The shrinkage variant adds it unlogged, and {@link #populateACEf} adds
     * {@code getTraitRateMatrixDeterminant()} unlogged. Confirm which quantities NodeMath stores.
     *
     * @param nodeMath per-node storage that is updated
     * @param nTraits  number of traits
     * @param nodeIdx  index of the internal node
     */
    public static void populateLmrForInternalNode(NodeMath nodeMath, int nTraits, int nodeIdx) {
        final double aPlusL = nodeMath.getAForNode(nodeIdx) + nodeMath.getLForNode(nodeIdx);
        final double invAPlusL = 1.0 / aPlusL;
        final double logDetTerm = 0.5 * nTraits * FastMath.log(-2.0 * aPlusL)
                + 0.5 * FastMath.log(nodeMath.getTraitRateMatrixInverseDeterminant());
        final double[] mVec = nodeMath.getMVecForNode(nodeIdx);
        nodeMath.setRForNode(nodeIdx, nodeMath.getRForNode(nodeIdx) + nodeMath.getfForNode(nodeIdx)
                + 0.5 * nTraits * LOG_TWO_PI - logDetTerm
                - 0.25 * invAPlusL * MatrixUtilsContra.tVecDotMatrixDotVec(mVec, nodeMath.getTraitRateMatrix(), nTraits));
        final double scaledE = nodeMath.getEForNode(nodeIdx) * invAPlusL;
        MatrixUtilsContra.vectorMapMultiply(mVec, -0.5 * scaledE, nodeMath.getTempVec());
        nodeMath.setMVecForNode(nodeIdx, nodeMath.getTempVec());
        nodeMath.setLForNode(nodeIdx, nodeMath.getCForNode(nodeIdx) - 0.25 * scaledE * nodeMath.getEForNode(nodeIdx));
    }

    /**
     * Shrinkage variant of {@link #populateLmrForInternalNode}: no trait rate matrix is applied.
     *
     * <p>With s = A + L:
     * <ul>
     *   <li>r ← r + f + ½·nTraits·log 2π − ½·(nTraits·log(−2s) + detInv)
     *       − ¼·{@code vecTransScalarMultiply(m, 1/s)}</li>
     *   <li>m ← −½·(E/s)·m</li>
     *   <li>L ← C − ¼·(E/s)·E</li>
     * </ul>
     * NOTE(review): detInv ({@code getTraitRateMatrixInverseDeterminant()}) is added here without
     * taking its log, unlike in the non-shrinkage variant. Confirm this is intended.
     *
     * @param nodeMath per-node storage that is updated
     * @param nTraits  number of traits
     * @param nodeIdx  index of the internal node
     */
    public static void populateLmrForInternalNodeWithShrinkage(NodeMath nodeMath, int nTraits, int nodeIdx) {
        final double aPlusL = nodeMath.getAForNode(nodeIdx) + nodeMath.getLForNode(nodeIdx);
        final double invAPlusL = 1.0 / aPlusL;
        final double logDetSum = nTraits * FastMath.log(-2.0 * aPlusL) + nodeMath.getTraitRateMatrixInverseDeterminant();
        final double[] mVec = nodeMath.getMVecForNode(nodeIdx);
        nodeMath.setRForNode(nodeIdx, nodeMath.getRForNode(nodeIdx) + nodeMath.getfForNode(nodeIdx)
                + 0.5 * nTraits * LOG_TWO_PI - 0.5 * logDetSum
                - 0.25 * MatrixUtilsContra.vecTransScalarMultiply(mVec, invAPlusL, nTraits));
        final double scaledE = nodeMath.getEForNode(nodeIdx) * invAPlusL;
        MatrixUtilsContra.vectorMapMultiply(mVec, -0.5 * scaledE, nodeMath.getTempVec());
        nodeMath.setMVecForNode(nodeIdx, nodeMath.getTempVec());
        nodeMath.setLForNode(nodeIdx, nodeMath.getCForNode(nodeIdx) - 0.25 * scaledE * nodeMath.getEForNode(nodeIdx));
    }

    /**
     * Returns a copy of {@code traitMatrix} with each column divided by a shrunk standard deviation.
     *
     * <p>For column j the divisor is sqrt(medianWeight·median(v) + (1 − medianWeight)·v_j). Here v_j
     * is the sample variance of column j of {@code sampleMatrix}, computed by commons-math
     * {@link Variance} (bias-corrected by default). {@code traitMatrix} itself is not modified.
     *
     * @param sampleMatrix matrix whose first {@code nTraits} columns supply the variances
     * @param traitMatrix  matrix to rescale; must have exactly {@code nTraits} columns
     * @param nTraits      number of traits (columns)
     * @param medianWeight weight given to the median variance (0 = per-column variance only,
     *                     1 = median only)
     * @return a new matrix with rescaled columns
     * @throws IllegalArgumentException if {@code traitMatrix} does not have {@code nTraits} columns
     */
    public static RealMatrix populateTraitValueMatrixEstimatedPopulationVariance(RealMatrix sampleMatrix,
                                                                                RealMatrix traitMatrix,
                                                                                int nTraits, double medianWeight) {
        final Variance variance = new Variance();
        final double[] columnVariances = new double[nTraits];
        for (int j = 0; j < nTraits; ++j) {
            columnVariances[j] = variance.evaluate(sampleMatrix.getColumn(j));
        }
        final double medianVariance = new Median().evaluate(columnVariances);
        final double[] scales = new double[nTraits];
        for (int j = 0; j < nTraits; ++j) {
            scales[j] = 1.0 / Math.sqrt(medianWeight * medianVariance + (1.0 - medianWeight) * columnVariances[j]);
        }
        return scaleColumns(traitMatrix, scales);
    }

    /**
     * Returns a copy of {@code traitMatrix} with every entry divided by sqrt(populationVariance).
     *
     * @param traitMatrix        matrix to rescale; must have exactly {@code nTraits} columns
     * @param populationVariance variance shared by all traits
     * @param nTraits            number of traits (columns)
     * @return a new, rescaled matrix
     * @throws IllegalArgumentException if {@code traitMatrix} does not have {@code nTraits} columns
     */
    public static RealMatrix populateTraitValueMatrixGivenPopulationVariance(RealMatrix traitMatrix,
                                                                            double populationVariance, int nTraits) {
        final double scale = 1.0 / Math.sqrt(populationVariance);
        final double[] scales = new double[nTraits];
        for (int j = 0; j < nTraits; ++j) {
            scales[j] = scale;
        }
        return scaleColumns(traitMatrix, scales);
    }

    /**
     * Returns a copy of {@code matrix} with column j multiplied by {@code scales[j]}.
     *
     * <p>For finite entries this gives the same values as {@code matrix.multiply(diag(scales))}.
     * The differences:
     * <ul>
     *   <li>It costs O(rows·cols) instead of O(rows·cols²).</li>
     *   <li>A NaN or ±Inf entry only affects its own column. In the dense product it is also
     *       multiplied by the diagonal's off-diagonal zeros, which yields NaN in every column of its
     *       row.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code matrix} does not have exactly {@code scales.length}
     *     columns (the dense product rejected this case too)
     */
    private static RealMatrix scaleColumns(RealMatrix matrix, double[] scales) {
        if (matrix.getColumnDimension() != scales.length) {
            throw new IllegalArgumentException("Expected " + scales.length + " columns but matrix has "
                    + matrix.getColumnDimension());
        }
        final RealMatrix scaled = matrix.copy();
        final int rows = scaled.getRowDimension();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < scales.length; ++j) {
                scaled.multiplyEntry(i, j, scales[j]);
            }
        }
        return scaled;
    }
}

