/*
 * Decompiled with CFR 0.152.
 */
package contraband.test;

import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeParser;
import beast.base.inference.parameter.IntegerParameter;
import beast.base.inference.parameter.RealParameter;
import contraband.clock.RateCategoryClockModel;
import contraband.math.NodeMath;
import contraband.prunelikelihood.BMPruneLikelihood;
import contraband.prunelikelihood.BMPruneShrinkageLikelihood;
import contraband.utils.PruneLikelihoodUtils;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.Assert;
import org.junit.Test;

public class PruneLikelihoodProcessTest {
    private TreeParser tree;
    private String treeStr;
    private Integer nTraits;
    private String spNames;
    private List<Double> data;
    private final RealParameter traitValues = new RealParameter();

    @Test
    public void testOneRateOnlyWithoutShrinkage() {
        this.treeStr = "((A:23.0058179,B:23.0058179):14.350951,C:37.3567689);";
        this.spNames = "A B C";
        this.tree = new TreeParser(this.treeStr, false, false, true, 0);
        this.nTraits = 2;
        this.data = Arrays.asList(-2.62762948691895, -1.56292164859448, -1.50846427625826, -1.59482814741543, -0.226074849617958, -2.11000367246907);
        this.traitValues.initByName(new Object[]{"value", this.data, "keys", this.spNames, "minordimension", this.nTraits});
        double[] dArray = new double[3 * this.nTraits];
        PruneLikelihoodUtils.populateTraitValuesArr(this.traitValues, (Tree)this.tree, this.nTraits, dArray);
        RealParameter realParameter = new RealParameter(new Double[]{0.314574});
        RealParameter realParameter2 = new RealParameter(new Double[]{-0.632620487603683});
        NodeMath nodeMath = new NodeMath();
        nodeMath.initByName(new Object[]{"traits", this.traitValues, "sigmasq", realParameter, "correlation", realParameter2, "oneRateOnly", true});
        nodeMath.performMatrixOperations();
        RateCategoryClockModel rateCategoryClockModel = new RateCategoryClockModel();
        IntegerParameter integerParameter = new IntegerParameter(new Integer[]{0});
        RealParameter realParameter3 = new RealParameter(new Double[]{1.0});
        rateCategoryClockModel.initByName(new Object[]{"nCat", 1, "rateCatAssign", integerParameter, "rates", realParameter3, "tree", this.tree});
        BMPruneLikelihood bMPruneLikelihood = new BMPruneLikelihood();
        bMPruneLikelihood.pruneNode(this.tree.getRoot(), this.nTraits, dArray, rateCategoryClockModel, nodeMath, false);
        int n = this.tree.getRoot().getNr();
        Assert.assertEquals((double)-0.032723927183444676, (double)nodeMath.getLForNode(n), (double)0.0);
        Assert.assertArrayEquals((Object[])new Double[]{-0.8501607370417381, -0.9115145062074426}, (Object[])ArrayUtils.toObject((double[])nodeMath.getMVecForNode(n)));
        Assert.assertEquals((double)-13.528327328718111, (double)nodeMath.getRForNode(n), (double)0.0);
        nodeMath.populateRootValuesVec(n);
        double[] dArray2 = nodeMath.getRootValuesArr();
        Assert.assertArrayEquals((Object[])new Double[]{-1.3146595508060854, -1.7961125556621191}, (Object[])ArrayUtils.toObject((double[])dArray2));
    }

    @Test
    public void testMultipleRatesWithoutShrinkage() {
        this.treeStr = "((A:23.0058179,B:23.0058179):14.350951,C:37.3567689);";
        this.tree = new TreeParser(this.treeStr, false, false, true, 0);
        this.nTraits = 2;
        this.data = Arrays.asList(-2.62762948691895, -0.764018322006132, -1.50846427625826, -1.02686498716963, -0.226074849617958, -1.73165056392106);
        this.spNames = "A B C";
        this.traitValues.initByName(new Object[]{"value", this.data, "keys", this.spNames, "minordimension", this.nTraits});
        double[] dArray = new double[3 * this.nTraits];
        PruneLikelihoodUtils.populateTraitValuesArr(this.traitValues, (Tree)this.tree, this.nTraits, dArray);
        RateCategoryClockModel rateCategoryClockModel = new RateCategoryClockModel();
        IntegerParameter integerParameter = new IntegerParameter(new Integer[]{0});
        RealParameter realParameter = new RealParameter(new Double[]{1.0});
        rateCategoryClockModel.initByName(new Object[]{"nCat", 1, "rateCatAssign", integerParameter, "rates", realParameter, "tree", this.tree});
        NodeMath nodeMath = new NodeMath();
        RealParameter realParameter2 = new RealParameter(new Double[]{0.3, 0.2});
        RealParameter realParameter3 = new RealParameter(new Double[]{-0.720107524122507});
        RealParameter realParameter4 = new RealParameter(new Double[]{-1.31465955080609, -1.2374605274288});
        nodeMath.initByName(new Object[]{"traits", this.traitValues, "sigmasq", realParameter2, "correlation", realParameter3, "rootValues", realParameter4, "oneRateOnly", false});
        nodeMath.performMatrixOperations();
        BMPruneLikelihood bMPruneLikelihood = new BMPruneLikelihood();
        bMPruneLikelihood.pruneNode(this.tree.getRoot(), this.nTraits, dArray, rateCategoryClockModel, nodeMath, false);
        int n = this.tree.getRoot().getNr();
        Assert.assertEquals((double)-0.032723927183444676, (double)nodeMath.getLForNode(n), (double)0.0);
        Assert.assertArrayEquals((Object[])new Double[]{-1.0902581683947343, -1.3664966897695918}, (Object[])ArrayUtils.toObject((double[])nodeMath.getMVecForNode(n)));
        Assert.assertEquals((double)-12.618549029514906, (double)nodeMath.getRForNode(n), (double)0.0);
    }

    @Test
    public void testOneRateOnlyWithShrinkage() {
        this.treeStr = "((A:12.4420263,B:12.4420263):42.9258211,C:43.5702874);";
        this.spNames = "A B C";
        this.tree = new TreeParser(this.treeStr, false, false, true, 0);
        this.nTraits = 2;
        this.data = Arrays.asList(1.0, 2.0, 3.0, 5.0, 2.0, 4.0);
        this.traitValues.initByName(new Object[]{"value", this.data, "keys", this.spNames, "minordimension", this.nTraits});
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix(new double[3][this.nTraits.intValue()]);
        PruneLikelihoodUtils.populateTraitValuesMatrix(this.traitValues, (Tree)this.tree, this.nTraits, (RealMatrix)array2DRowRealMatrix);
        RealParameter realParameter = new RealParameter(new Double[]{0.1543038});
        NodeMath nodeMath = new NodeMath();
        nodeMath.initByName(new Object[]{"traits", this.traitValues, "sigmasq", realParameter, "shrinkage", true, "oneRateOnly", true});
        nodeMath.estimateCorrelations((RealMatrix)array2DRowRealMatrix);
        nodeMath.populateShrinkageEstimation(0.25925925925926);
        nodeMath.populateTraitRateMatrix();
        nodeMath.populateInverseTraitRateMatrix();
        nodeMath.populateTransformedTraitValues((RealMatrix)array2DRowRealMatrix);
        RateCategoryClockModel rateCategoryClockModel = new RateCategoryClockModel();
        IntegerParameter integerParameter = new IntegerParameter(new Integer[]{0});
        RealParameter realParameter2 = new RealParameter(new Double[]{1.0});
        rateCategoryClockModel.initByName(new Object[]{"nCat", 1, "rateCatAssign", integerParameter, "rates", realParameter2, "tree", this.tree});
        BMPruneShrinkageLikelihood bMPruneShrinkageLikelihood = new BMPruneShrinkageLikelihood();
        bMPruneShrinkageLikelihood.setPopSE(false);
        bMPruneShrinkageLikelihood.pruneNode(this.tree.getRoot(), this.nTraits, nodeMath.getTransformedTraitValues(), rateCategoryClockModel, nodeMath, false);
        int n = this.tree.getRoot().getNr();
        Assert.assertEquals((double)-0.02164930565494149, (double)nodeMath.getLForNode(n), (double)0.0);
        Assert.assertArrayEquals((Object[])new Double[]{0.22045281696520652, 0.37109117994225865}, (Object[])ArrayUtils.toObject((double[])nodeMath.getMVecForNode(n)));
        Assert.assertEquals((double)-13.012014942429309, (double)nodeMath.getRForNode(n), (double)0.0);
        nodeMath.populateRootValuesVec(n);
        double[] dArray = nodeMath.getRootValuesArr();
        Assert.assertArrayEquals((Object[])new Double[]{5.0914523652375845, 8.570509970548557}, (Object[])ArrayUtils.toObject((double[])dArray));
    }
}

