/*
 * Decompiled with CFR 0.152.
 */
package contraband.prunelikelihood;

import beast.base.core.Input;
import beast.base.evolution.branchratemodel.BranchRateModel;
import beast.base.evolution.tree.Node;
import beast.base.evolution.tree.Tree;
import beast.base.inference.Distribution;
import beast.base.inference.State;
import beast.base.inference.parameter.RealParameter;
import contraband.math.MatrixUtilsContra;
import contraband.math.NodeMath;
import contraband.utils.PruneLikelihoodUtils;
import java.util.List;
import java.util.Random;

/**
 * Base class for phylogenetic-comparative-method (PCM) likelihoods that are evaluated by pruning.
 *
 * <p>{@link #pruneNode} recurses from the root and, for every visited node, aggregates a scalar
 * {@code L}, a vector {@code m} and a scalar {@code r} over its children. The results are stored
 * in the {@link NodeMath} instance. The process-specific pieces are supplied through three hooks:
 * {@link #calculateLmrForTips}, {@link #calculateLmrForInternalNodes} and
 * {@link #calculateLikelihood}. Their default implementations do nothing (the last one returns
 * {@code 1.0}).
 */
public abstract class PruneLikelihoodProcess extends Distribution {
    public final Input<Tree> treeInput =
            new Input<>("tree", "Tree object containing tree.", Input.Validate.REQUIRED);
    /**
     * Optional per-branch model. {@code pruneNode} multiplies each branch length by
     * {@code getRateForBranch(..)} of the model it is given.
     * NOTE(review): {@link #populateLogP()} passes this field to {@code pruneNode}, which
     * dereferences it unconditionally. Leaving this input unset therefore leads to a
     * NullPointerException at likelihood time. Confirm whether it should be REQUIRED.
     */
    public final Input<BranchRateModel.Base> branchRateModelInput =
            new Input<>("branchRateModel", "the rate or optimum on each branch");
    public final Input<NodeMath> nodeMathInput = new Input<>("nodeMath",
            "Node information that will be used in PCM likelihood calculation.", Input.Validate.REQUIRED);
    public final Input<RealParameter> traitsValuesInput =
            new Input<>("traits", "Trait values at tips.", Input.Validate.REQUIRED);

    private Tree tree;
    /** Number of traits per species (minor dimension 1 of the trait parameter). */
    private int nTraits;
    /** Number of leaves in the tree. */
    private int nSpecies;
    /** Minor dimension 2 of the trait parameter. */
    private int nSpeciesWithData;
    private RealParameter traitsValues;
    private BranchRateModel.Base branchRateModel;
    private NodeMath nodeMath;
    /** Flat array of length {@code nSpecies * nTraits}, filled by PruneLikelihoodUtils. */
    private double[] traitValuesArr;
    /** When true, an extra 1.0 is added to the variance of ordinary tips during pruning. */
    private boolean popSE;

    /**
     * Reads the inputs, sizes the flat trait-value array and fills it via
     * {@code PruneLikelihoodUtils.populateTraitValuesArr}.
     *
     * @throws RuntimeException if no NodeMath input was supplied
     */
    public void initAndValidate() {
        super.initAndValidate();
        this.tree = treeInput.get();
        this.branchRateModel = branchRateModelInput.get();
        this.traitsValues = traitsValuesInput.get();
        this.nTraits = traitsValues.getMinorDimension1();
        this.nSpeciesWithData = traitsValues.getMinorDimension2();
        this.nSpecies = tree.getLeafNodeCount();
        this.traitValuesArr = new double[nSpecies * nTraits];

        if (nodeMathInput.get() == null) {
            throw new RuntimeException("PruneLikelihoodProcess::NodeMath is required for pmc likelihood.");
        }
        this.nodeMath = nodeMathInput.get();
        PruneLikelihoodUtils.populateTraitValuesArr(traitsValues, tree, nodeMath, nTraits, traitValuesArr);
    }

    /**
     * Recomputes {@code logP} by pruning the whole tree.
     *
     * <p>{@code logP} is set to negative infinity in two cases:
     * <ul>
     *   <li>a child of the root is a direct (sampled) ancestor;</li>
     *   <li>NodeMath reports a singular matrix after pruning.</li>
     * </ul>
     * Otherwise {@code logP} is the value of {@link #calculateLikelihood} applied to the root's
     * aggregated {@code L}, {@code m} and {@code r}.
     */
    protected void populateLogP() {
        Node root = tree.getRoot();
        // Trees where the root directly carries a sampled ancestor are rejected outright.
        if (root.getChild(0).isDirectAncestor() || root.getChild(1).isDirectAncestor()) {
            this.logP = Double.NEGATIVE_INFINITY;
            return;
        }

        nodeMath.setLikelihoodForSampledAncestors(0.0);
        nodeMath.initializeNodeStatArrays();
        flagParentsOfIgnoredSpecies();

        pruneNode(root, nTraits, traitValuesArr, branchRateModel, nodeMath, popSE);
        if (nodeMath.isSingularMatrix()) {
            this.logP = Double.NEGATIVE_INFINITY;
            return;
        }

        int rootNr = root.getNr();
        double rootL = nodeMath.getLForNode(rootNr);
        double[] rootM = nodeMath.getMVecForNode(rootNr);
        double rootR = nodeMath.getRForNode(rootNr);
        nodeMath.populateRootValuesVec(rootNr);
        this.logP = calculateLikelihood(nodeMath, rootL, rootM, rootR, rootNr);
    }

    /**
     * Records missing-data bookkeeping for every leaf that NodeMath says should be ignored.
     *
     * <p>For each such leaf, the leaf's parent is marked as having missing data. The node number
     * of the leaf's sibling is then stored as the parent's species-to-ignore index.
     * NOTE(review): this assumes the parent is binary. Only children 0 and 1 are inspected.
     */
    private void flagParentsOfIgnoredSpecies() {
        for (int leafNr = 0; leafNr < tree.getLeafNodeCount(); leafNr++) {
            if (!nodeMath.isSpeciesToIgnore(leafNr)) {
                continue;
            }
            Node parent = tree.getNode(leafNr).getParent();
            int parentNr = parent.getNr();
            nodeMath.setNodeHasMissingData(parentNr);
            int firstChildNr = parent.getChild(0).getNr();
            int siblingNr = (firstChildNr == leafNr) ? parent.getChild(1).getNr() : firstChildNr;
            nodeMath.setSpeciesToIgnoreIndex(parentNr, siblingNr);
        }
    }

    public int getNTraits() {
        return nTraits;
    }

    public int getNSpecies() {
        return nSpecies;
    }

    /** Returns the value last stored in {@code logP} (for example by {@link #populateLogP()}). */
    public double getLogP() {
        return logP;
    }

    public NodeMath getNodeMath() {
        return nodeMath;
    }

    /** Returns the node number of the tree's current root. */
    public int getRootIndex() {
        return tree.getRoot().getNr();
    }

    public int getNumberOfSpeciesWithData() {
        return nSpeciesWithData;
    }

    public void setPopSE(boolean popSE) {
        this.popSE = popSE;
    }

    /** Replaces the trait-value array used by {@link #populateLogP()}. The array is stored, not copied. */
    public void setTraitValuesArr(double[] traitValuesArr) {
        this.traitValuesArr = traitValuesArr;
    }

    /**
     * Recursively prunes the subtree rooted at {@code node} and stores the node's aggregated
     * {@code L}, {@code m} and {@code r} in {@code nodeMath}.
     *
     * <p>Each non-ignored child falls into one of three cases:
     * <ol>
     *   <li><b>Leaf with non-zero branch variance.</b> Handled as a tip. If {@code popSE} is set,
     *       1.0 is first added to its variance.</li>
     *   <li><b>Node with no direct-ancestor child.</b> Pruned recursively, then handled as an
     *       internal node.</li>
     *   <li><b>Node with a direct-ancestor child.</b> The sampled ancestor is handled as a tip and
     *       the node is pruned recursively. Its contribution is added to NodeMath's
     *       sampled-ancestor likelihood, or that likelihood is set to negative infinity if the
     *       recursion produced a singular matrix.</li>
     * </ol>
     * Direct-ancestor children that reach the loop themselves are skipped.
     *
     * @param node        root of the subtree to prune
     * @param numTraits   number of traits per species
     * @param traitValues flat trait-value array, passed through to the tip hook
     * @param rateModel   per-branch model; branch variance is {@code length * getRateForBranch(..)}
     * @param nodeMath    per-node working storage, read and written
     * @param popSE       whether to add 1.0 to ordinary tip variances
     */
    public void pruneNode(Node node, int numTraits, double[] traitValues, BranchRateModel.Base rateModel,
                          NodeMath nodeMath, boolean popSE) {
        int nodeNr = node.getNr();
        double sumL = 0.0;
        double[] sumM = nodeMath.getInitMVec().clone();
        double sumR = 0.0;

        List<Node> children = node.getChildren();
        for (Node child : children) {
            int childNr = child.getNr();
            if (nodeMath.isSpeciesToIgnore(childNr)) {
                continue;
            }
            double branchVariance = child.getLength() * rateModel.getRateForBranch(child);
            nodeMath.setVarianceForTip(childNr, branchVariance);

            // Case 1: ordinary tip.
            if (child.isLeaf() && branchVariance != 0.0) {
                if (popSE) {
                    nodeMath.setVarianceForTip(childNr, nodeMath.getVarianceForNode(childNr) + 1.0);
                }
                PruneLikelihoodUtils.populateACEf(nodeMath, nodeMath.getVarianceForNode(childNr), numTraits, childNr);
                calculateLmrForTips(nodeMath, traitValues, numTraits, childNr);
                sumL += nodeMath.getLForNode(childNr);
                sumR += nodeMath.getRForNode(childNr);
                MatrixUtilsContra.vectorAdd(sumM, nodeMath.getTempVec(), sumM);
                continue;
            }
            if (child.isDirectAncestor()) {
                continue;
            }

            // Case 2: internal node without a sampled-ancestor child.
            boolean hasDirectAncestorChild =
                    child.getChild(0).isDirectAncestor() || child.getChild(1).isDirectAncestor();
            if (!hasDirectAncestorChild) {
                PruneLikelihoodUtils.populateACEf(nodeMath, branchVariance, numTraits, childNr);
                pruneNode(child, numTraits, traitValues, rateModel, nodeMath, popSE);
                calculateLmrForInternalNodes(nodeMath, numTraits, childNr);
                sumR += nodeMath.getRForNode(childNr);
                MatrixUtilsContra.vectorAdd(sumM, nodeMath.getTempVec(), sumM);
                sumL += nodeMath.getLForNode(childNr);
                continue;
            }

            // Case 3: the child carries a sampled ancestor. The ancestor's own L/m/r (not the
            // child's) are added to this node's totals.
            Node ancestor = child.getChild(1).isDirectAncestor() ? child.getChild(1) : child.getChild(0);
            int ancestorNr = ancestor.getNr();
            PruneLikelihoodUtils.populateACEf(nodeMath, branchVariance, numTraits, ancestorNr);
            calculateLmrForTips(nodeMath, traitValues, numTraits, ancestorNr);
            sumL += nodeMath.getLForNode(ancestorNr);
            sumR += nodeMath.getRForNode(ancestorNr);
            MatrixUtilsContra.vectorAdd(sumM, nodeMath.getTempVec(), sumM);

            pruneNode(child, numTraits, traitValues, rateModel, nodeMath, popSE);
            if (nodeMath.isSingularMatrix()) {
                nodeMath.setLikelihoodForSampledAncestors(Double.NEGATIVE_INFINITY);
                continue;
            }
            nodeMath.setTraitsVecForSampledAncestor(traitValues, ancestorNr);
            double ancestorLogLik = sampledAncestorLogLikelihood(nodeMath, childNr, numTraits, popSE);
            nodeMath.setLikelihoodForSampledAncestors(nodeMath.getLikelihoodForSampledAncestors() + ancestorLogLik);
        }

        if (nodeMath.hasMissingDataSpecies(nodeNr)) {
            // One child was ignored. Fold the retained child's variance into this node and copy
            // that child's expectation.
            int retainedNr = nodeMath.getSpeciesToIgnoreIndex(nodeNr);
            double combinedVariance = nodeMath.getVarianceForNode(nodeNr) + nodeMath.getVarianceForNode(retainedNr);
            nodeMath.setVarianceForTip(nodeNr, combinedVariance);
            nodeMath.setExpectationForIntNode(nodeNr, nodeMath.getExpectationForNode(retainedNr));
        } else {
            nodeMath.setVarianceForParent(nodeNr, node.getLength() * rateModel.getRateForBranch(node),
                    node.getChild(0).getNr(), node.getChild(1).getNr());
            nodeMath.setExpectationForParent(nodeNr, node.getChild(0).getNr(), node.getChild(1).getNr());
        }
        nodeMath.setLForNode(nodeNr, sumL);
        nodeMath.setMVecForNode(nodeNr, sumM);
        nodeMath.setRForNode(nodeNr, sumR);
    }

    /**
     * Computes the contribution of one sampled ancestor to the sampled-ancestor likelihood.
     *
     * <p>The contribution is the sum of three terms, where {@code x} is NodeMath's sampled-ancestor
     * trait vector and {@code L}, {@code m}, {@code r} are the values stored for {@code nodeNr}:
     * <ul>
     *   <li>an {@code L}-weighted term in {@code x}:
     *     <ul>
     *       <li>{@code MatrixUtilsContra.vecTransScalarMultiply(x, L, numTraits)} when
     *           {@code popSE} is set;</li>
     *       <li>otherwise {@code L * MatrixUtilsContra.tVecDotMatrixDotVec(x, inverse, numTraits)},
     *           where {@code inverse} is NodeMath's trait-rate matrix inverse;</li>
     *     </ul>
     *   </li>
     *   <li>the dot product of {@code x} and {@code m};</li>
     *   <li>{@code r}.</li>
     * </ul>
     * The terms are evaluated and summed in that order.
     */
    private static double sampledAncestorLogLikelihood(NodeMath nodeMath, int nodeNr, int numTraits, boolean popSE) {
        double weightedTerm = popSE
                ? MatrixUtilsContra.vecTransScalarMultiply(
                        nodeMath.getSampledAncestorTraitsVec(), nodeMath.getLForNode(nodeNr), numTraits)
                : nodeMath.getLForNode(nodeNr) * MatrixUtilsContra.tVecDotMatrixDotVec(
                        nodeMath.getSampledAncestorTraitsVec(), nodeMath.getTraitRateMatrixInverse(), numTraits);
        double linearTerm = MatrixUtilsContra.vectorDotMultiply(
                nodeMath.getSampledAncestorTraitsVec(), nodeMath.getMVecForNode(nodeNr));
        return weightedTerm + linearTerm + nodeMath.getRForNode(nodeNr);
    }

    /**
     * Hook: computes {@code L}, {@code m} and {@code r} for a tip (or sampled-ancestor) node.
     * Results are read back from {@code nodeMath} by {@link #pruneNode}. The default does nothing.
     */
    protected void calculateLmrForTips(NodeMath nodeMath, double[] traitValues, int numTraits, int nodeNr) {
    }

    /**
     * Hook: computes {@code L}, {@code m} and {@code r} for an internal node after its subtree has
     * been pruned. The default does nothing.
     */
    protected void calculateLmrForInternalNodes(NodeMath nodeMath, int numTraits, int nodeNr) {
    }

    /**
     * Hook: turns the root's aggregated {@code L}, {@code m} and {@code r} into the value stored in
     * {@code logP} by {@link #populateLogP()}. The default returns {@code 1.0}.
     */
    protected double calculateLikelihood(NodeMath nodeMath, double rootL, double[] rootM, double rootR,
                                         int rootNr) {
        return 1.0;
    }

    /** Returns {@code null}; no arguments are declared. */
    public List<String> getArguments() {
        return null;
    }

    /** Returns {@code null}; no conditions are declared. */
    public List<String> getConditions() {
        return null;
    }

    /** Sampling is not implemented; this is a no-op. */
    public void sample(State state, Random random) {
    }
}

