/*
 * Decompiled with CFR 0.152.
 */
package contraband.test;

import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeParser;
import beast.base.inference.parameter.RealParameter;
import contraband.math.MatrixUtilsContra;
import contraband.math.NodeMath;
import contraband.utils.PruneLikelihoodUtils;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link PruneLikelihoodUtils}.
 *
 * <p>Covers:
 * <ul>
 *   <li>copying keyed trait values into a flat array and into a matrix;</li>
 *   <li>the per-node A/C/E/f and L/m/r pruning quantities, with and without shrinkage;</li>
 *   <li>the population-variance adjustments applied to a trait matrix.</li>
 * </ul>
 *
 * <p>Each test builds its own tree and trait parameter as locals, so no state is shared between tests.
 */
public class PruneLikelihoodUtilsTest {
    /** Tolerance for values that are computed (as opposed to copied verbatim from the input). */
    static final double EPSILON = 1.0E-7;

    /**
     * Checks that trait values keyed by species name are written in the tree's tip order.
     * The expected values are the input values verbatim, only reordered (to sp1, sp2, sp3, sp4),
     * so the comparisons are exact.
     */
    @Test
    public void testPopulateTraitValues() {
        TreeParser tree = new TreeParser(
                "((sp3:0.1201847336,sp4:0.1201847336):0.8798152664,(sp1:0.9098543416,sp2:0.9098543416):0.09014565845):0.0;",
                false, false, true, 0);
        int nTraits = 2;
        int nSpecies = 4;
        RealParameter traitValues = new RealParameter();
        traitValues.initByName(
                "value", Arrays.asList(
                        0.983714690867666, -7.54729477473779, -7.86424514338822, -2.97908131550921,
                        7.23079460908758, -0.780647498381348, -1.39605330265115, 3.72028693114977),
                "keys", "sp3 sp4 sp1 sp2",
                "minordimension", nTraits);

        // Flat (row-major) layout.
        double[] flat = new double[nSpecies * nTraits];
        PruneLikelihoodUtils.populateTraitValuesArr(traitValues, tree, nTraits, flat);
        Assert.assertArrayEquals(new double[]{
                7.23079460908758, -0.780647498381348, -1.39605330265115, 3.72028693114977,
                0.983714690867666, -7.54729477473779, -7.86424514338822, -2.97908131550921}, flat, 0.0);

        // Matrix layout: one row per tip, one column per trait.
        RealMatrix matrix = new Array2DRowRealMatrix(new double[nSpecies][nTraits]);
        PruneLikelihoodUtils.populateTraitValuesMatrix(traitValues, tree, nTraits, matrix);
        Assert.assertArrayEquals(new double[]{7.23079460908758, -0.780647498381348}, matrix.getRow(0), 0.0);
        Assert.assertArrayEquals(new double[]{-1.39605330265115, 3.72028693114977}, matrix.getRow(1), 0.0);
        Assert.assertArrayEquals(new double[]{0.983714690867666, -7.54729477473779}, matrix.getRow(2), 0.0);
        Assert.assertArrayEquals(new double[]{-7.86424514338822, -2.97908131550921}, matrix.getRow(3), 0.0);
    }

    /**
     * Walks the pruning recursion by hand for two tips and their parent.
     * The recursion runs without shrinkage, with a fixed sigma^2, correlation and root values.
     */
    @Test
    public void testPopulatePCMParams() {
        TreeParser tree = new TreeParser("((A:23.0058179,B:23.0058179):14.350951,C:37.3567689);", false, false, true, 0);
        int nTraits = 2;
        RealParameter traitValues = new RealParameter();
        traitValues.initByName(
                "value", Arrays.asList(
                        -2.62762948691895, -1.56292164859448, -1.50846427625826, -1.59482814741543,
                        -0.226074849617958, -2.11000367246907),
                "keys", "A B C",
                "minordimension", nTraits);
        double[] traitArr = new double[3 * nTraits];
        PruneLikelihoodUtils.populateTraitValuesArr(traitValues, tree, nTraits, traitArr);

        RealParameter sigmasq = new RealParameter(new Double[]{0.314574});
        RealParameter correlation = new RealParameter(new Double[]{-0.632620487603683});
        RealParameter rootValues = new RealParameter(new Double[]{-1.31465955080609, -1.79611255566212});
        NodeMath nodeMath = new NodeMath();
        nodeMath.initByName(
                "traits", traitValues,
                "sigmasq", sigmasq,
                "correlation", correlation,
                "rootValues", rootValues,
                "oneRateOnly", true);
        nodeMath.performMatrixOperations();

        // Running element-wise sum of the children's m vectors; becomes the parent's m vector below.
        double[] mSum = new double[2];

        // First tip (node 0).
        int firstTip = 0;
        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(firstTip).getLength(), nTraits, firstTip);
        Assert.assertEquals(-0.021733632865102357, nodeMath.getAForNode(firstTip), EPSILON);
        Assert.assertEquals(-0.021733632865102357, nodeMath.getCForNode(firstTip), EPSILON);
        Assert.assertEquals(0.043467265730204714, nodeMath.getEForNode(firstTip), EPSILON);
        Assert.assertEquals(-3.561501522879213, nodeMath.getfForNode(firstTip), EPSILON);
        PruneLikelihoodUtils.populateLmrForTip(nodeMath, traitArr, nTraits, firstTip);
        Assert.assertEquals(-0.021733632865102357, nodeMath.getLForNode(firstTip), EPSILON);
        Assert.assertArrayEquals(new double[]{-0.8331278807826273, -0.7430154496439219},
                nodeMath.getMVecForNode(firstTip), EPSILON);
        Assert.assertEquals(-5.2367146815829, nodeMath.getRForNode(firstTip), EPSILON);
        MatrixUtilsContra.vectorAdd(mSum, nodeMath.getMVecForNode(firstTip), mSum);

        // Second tip (node 1); same branch length as node 0, so A/C/E/f match.
        int secondTip = 1;
        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(secondTip).getLength(), nTraits, secondTip);
        Assert.assertEquals(-0.02173363286510235, nodeMath.getAForNode(secondTip), EPSILON);
        Assert.assertEquals(-0.02173363286510235, nodeMath.getCForNode(secondTip), EPSILON);
        Assert.assertEquals(0.04346726573020471, nodeMath.getEForNode(secondTip), EPSILON);
        Assert.assertEquals(-3.561501522879213, nodeMath.getfForNode(secondTip), EPSILON);
        PruneLikelihoodUtils.populateLmrForTip(nodeMath, traitArr, nTraits, secondTip);
        Assert.assertEquals(-0.021733632865102357, nodeMath.getLForNode(secondTip), EPSILON);
        Assert.assertArrayEquals(new double[]{-0.5799479302276102, -0.5872574081072599},
                nodeMath.getMVecForNode(secondTip), EPSILON);
        Assert.assertEquals(-4.467204212412191, nodeMath.getRForNode(secondTip), EPSILON);
        MatrixUtilsContra.vectorAdd(mSum, nodeMath.getMVecForNode(secondTip), mSum);

        // The test treats internal node 3 as the parent of nodes 0 and 1: its L and r are the sums
        // of the children's, and its m is the summed m vector. The sums inherit the children's
        // tolerance, so they are compared with EPSILON too.
        int parent = 3;
        nodeMath.setLForNode(parent, nodeMath.getLForNode(firstTip) + nodeMath.getLForNode(secondTip));
        nodeMath.setMVecForNode(parent, mSum);
        nodeMath.setRForNode(parent, nodeMath.getRForNode(firstTip) + nodeMath.getRForNode(secondTip));
        Assert.assertEquals(-0.04346726573020471, nodeMath.getLForNode(parent), EPSILON);
        Assert.assertArrayEquals(new double[]{-1.4130758110102375, -1.3302728577511818},
                nodeMath.getMVecForNode(parent), EPSILON);
        Assert.assertEquals(-9.703918893995091, nodeMath.getRForNode(parent), EPSILON);

        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(parent).getLength(), nTraits, parent);
        Assert.assertEquals(-0.03484089660678236, nodeMath.getAForNode(parent), EPSILON);
        Assert.assertEquals(-0.03484089660678236, nodeMath.getCForNode(parent), EPSILON);
        Assert.assertEquals(0.06968179321356473, nodeMath.getEForNode(parent), EPSILON);
        Assert.assertEquals(-3.089570598549913, nodeMath.getfForNode(parent), EPSILON);
        PruneLikelihoodUtils.populateLmrForInternalNode(nodeMath, nTraits, parent);
        Assert.assertEquals(-0.0193394719769881, nodeMath.getLForNode(parent), EPSILON);
        Assert.assertArrayEquals(new double[]{-0.6287062134990086, -0.5918654928494784},
                nodeMath.getMVecForNode(parent), EPSILON);
        Assert.assertEquals(-9.119795872783662, nodeMath.getRForNode(parent), EPSILON);
    }

    /**
     * Same hand-walked recursion as {@link #testPopulatePCMParams()}, but with correlations
     * estimated from the data via shrinkage.
     */
    @Test
    public void testPopulatePCMParamsWithShrinkage() {
        TreeParser tree = new TreeParser("((A:12.4420263,B:12.4420263):42.9258211,C:43.5702874);", false, false, true, 0);
        int nTraits = 2;
        RealParameter traitValues = new RealParameter();
        traitValues.initByName(
                "value", Arrays.asList(1.0, 2.0, 3.0, 5.0, 2.0, 4.0),
                "keys", "A B C",
                "minordimension", nTraits);
        RealMatrix traitMatrix = new Array2DRowRealMatrix(new double[3][nTraits]);
        PruneLikelihoodUtils.populateTraitValuesMatrix(traitValues, tree, nTraits, traitMatrix);

        RealParameter sigmasq = new RealParameter(new Double[]{0.1543038});
        NodeMath nodeMath = new NodeMath();
        nodeMath.initByName(
                "traits", traitValues,
                "sigmasq", sigmasq,
                "shrinkage", true,
                "oneRateOnly", true);
        // Order matters: each step consumes what the previous one populated.
        nodeMath.estimateCorrelations(traitMatrix);
        nodeMath.populateShrinkageEstimation(0.25925925925926);
        nodeMath.populateTraitRateMatrix();
        nodeMath.populateInverseTraitRateMatrix();
        nodeMath.populateTransformedTraitValues(traitMatrix);
        double[] transformed = nodeMath.getTransformedTraitValues();

        // Running element-wise sum of the children's m vectors; becomes the parent's m vector below.
        double[] mSum = new double[2];

        // First tip (node 0).
        int firstTip = 0;
        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(firstTip).getLength(), nTraits, firstTip);
        Assert.assertEquals(-0.040186380252226275, nodeMath.getAForNode(firstTip), EPSILON);
        Assert.assertEquals(-0.040186380252226275, nodeMath.getCForNode(firstTip), EPSILON);
        Assert.assertEquals(0.08037276050445255, nodeMath.getEForNode(firstTip), EPSILON);
        Assert.assertEquals(-2.11356981107731, nodeMath.getfForNode(firstTip), EPSILON);
        PruneLikelihoodUtils.populateLmrForTipWithShrinkage(nodeMath, transformed, nTraits, firstTip);
        Assert.assertEquals(-0.0401863802522263, nodeMath.getLForNode(firstTip), EPSILON);
        Assert.assertArrayEquals(new double[]{0.20460704078553443, 0.37944671022339144},
                nodeMath.getMVecForNode(firstTip), EPSILON);
        Assert.assertEquals(-3.26970682734958, nodeMath.getRForNode(firstTip), EPSILON);
        MatrixUtilsContra.vectorAdd(mSum, nodeMath.getMVecForNode(firstTip), mSum);

        // Second tip (node 1); same branch length as node 0, so A/C/E/f match.
        int secondTip = 1;
        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(secondTip).getLength(), nTraits, secondTip);
        Assert.assertEquals(-0.0401863802522263, nodeMath.getAForNode(secondTip), EPSILON);
        Assert.assertEquals(-0.0401863802522263, nodeMath.getCForNode(secondTip), EPSILON);
        Assert.assertEquals(0.0803727605044526, nodeMath.getEForNode(secondTip), EPSILON);
        Assert.assertEquals(-2.11356981107731, nodeMath.getfForNode(secondTip), EPSILON);
        PruneLikelihoodUtils.populateLmrForTipWithShrinkage(nodeMath, transformed, nTraits, secondTip);
        Assert.assertEquals(-0.0401863802522263, nodeMath.getLForNode(secondTip), EPSILON);
        Assert.assertArrayEquals(new double[]{0.6138211223566032, 0.8401752608249581},
                nodeMath.getMVecForNode(secondTip), EPSILON);
        Assert.assertEquals(-8.84887933857218, nodeMath.getRForNode(secondTip), EPSILON);
        MatrixUtilsContra.vectorAdd(mSum, nodeMath.getMVecForNode(secondTip), mSum);

        // The test treats internal node 3 as the parent of nodes 0 and 1 (see the non-shrinkage test).
        int parent = 3;
        nodeMath.setLForNode(parent, nodeMath.getLForNode(firstTip) + nodeMath.getLForNode(secondTip));
        nodeMath.setMVecForNode(parent, mSum);
        nodeMath.setRForNode(parent, nodeMath.getRForNode(firstTip) + nodeMath.getRForNode(secondTip));
        Assert.assertEquals(-0.0803727605044526, nodeMath.getLForNode(parent), EPSILON);
        Assert.assertArrayEquals(new double[]{0.8184281631421377, 1.2196219710483496},
                nodeMath.getMVecForNode(parent), EPSILON);
        Assert.assertEquals(-12.1185861659218, nodeMath.getRForNode(parent), EPSILON);

        PruneLikelihoodUtils.populateACEf(nodeMath, tree.getNode(parent).getLength(), nTraits, parent);
        Assert.assertEquals(-0.0116480008346305, nodeMath.getAForNode(parent), EPSILON);
        Assert.assertEquals(-0.0116480008346305, nodeMath.getCForNode(parent), EPSILON);
        Assert.assertEquals(0.023296001669261, nodeMath.getEForNode(parent), EPSILON);
        Assert.assertEquals(-3.3519633864921, nodeMath.getfForNode(parent), EPSILON);
        PruneLikelihoodUtils.populateLmrForInternalNodeWithShrinkage(nodeMath, nTraits, parent);
        Assert.assertEquals(-0.0101735952606144, nodeMath.getLForNode(parent), EPSILON);
        Assert.assertArrayEquals(new double[]{0.10359675130524983, 0.15437991959615796},
                nodeMath.getMVecForNode(parent), EPSILON);
        Assert.assertEquals(-8.32455362290912, nodeMath.getRForNode(parent), EPSILON);
    }

    /** Checks the trait-matrix adjustment driven by an estimated population variance. */
    @Test
    public void testPopulateEstimatedVarianceMatrix() {
        int nTraits = 2;
        RealMatrix traits = new Array2DRowRealMatrix(new double[][]{
                {-2.62762948691895, -1.56292164859448},
                {-1.50846427625826, -1.59482814741543},
                {-0.226074849617958, -2.11000367246907}});
        double coefficient = 0.574732079259225;
        // NOTE(review): the same matrix is passed for both matrix arguments; confirm this is intended.
        RealMatrix result = PruneLikelihoodUtils.populateTraitValueMatrixEstimatedPopulationVariance(
                traits, traits, nTraits, coefficient);
        Assert.assertArrayEquals(new double[]{-2.5567665115854, -2.2507927602031303}, result.getRow(0), EPSILON);
        Assert.assertArrayEquals(new double[]{-1.4677834012215858, -2.296741907183215}, result.getRow(1), EPSILON);
        Assert.assertArrayEquals(new double[]{-0.21997797158710666, -3.0386558368209258}, result.getRow(2), EPSILON);
    }

    /**
     * Checks the trait-matrix adjustment for a given population variance of 0.3.
     * The expected values equal the inputs divided by sqrt(0.3).
     */
    @Test
    public void testPopulateGivenPopulationVariance() {
        int nTraits = 2;
        RealMatrix traits = new Array2DRowRealMatrix(new double[][]{
                {-2.62762948691895, -1.56292164859448},
                {-1.50846427625826, -1.59482814741543},
                {-0.226074849617958, -2.11000367246907}});
        double populationVariance = 0.3;
        RealMatrix result = PruneLikelihoodUtils.populateTraitValueMatrixGivenPopulationVariance(
                traits, populationVariance, nTraits);
        Assert.assertArrayEquals(new double[]{-4.7973731425041155, -2.853491475161197}, result.getRow(0), EPSILON);
        Assert.assertArrayEquals(new double[]{-2.754066370991179, -2.911744505612018}, result.getRow(1), EPSILON);
        Assert.assertArrayEquals(new double[]{-0.41275431606781265, -3.852322026100173}, result.getRow(2), EPSILON);
    }
}

