/*
 * Decompiled with CFR 0.152.
 */
package contraband.prunelikelihood;

import beast.base.core.Description;
import beast.base.core.Input;
import beast.base.inference.util.InputUtil;
import contraband.math.MatrixUtilsContra;
import contraband.math.NodeMath;
import contraband.prunelikelihood.PruneLikelihoodProcess;
import contraband.utils.PruneLikelihoodUtils;
import java.util.Arrays;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

@Description(value="This class implements likelihood for continuous traits under Brownian model.\nThe calculation uses Venelin's PCM likelihood.")
public class BMPruneLikelihood
extends PruneLikelihoodProcess {

    /**
     * A 1x1 rate matrix whose single entry, or a larger rate matrix with any real eigenvalue,
     * below this value is treated as nearly singular.
     */
    private static final double MIN_RATE_EIGENVALUE = 1.0E-5;

    /**
     * A rate matrix whose smallest/largest singular value ratio falls below this value is treated
     * as nearly singular (ill-conditioned).
     */
    private static final double MIN_SINGULAR_VALUE_RATIO = 1.0E-6;

    public final Input<Boolean> includeRootInput = new Input<>("includeRoot",
            "TRUE, if the likelihood at the root will be subtracted.", false);

    /** Scratch nTraits x nTraits copy of the trait rate matrix used for the SVD/eigen checks. */
    private RealMatrix rateRealMatrix;
    /** Result of the most recent {@link #checkNearlySingularForRateMatrix()} call. */
    private boolean nearlySingularRateMatrix;
    /** Cached value of {@link #includeRootInput}. */
    private boolean includeRoot;

    /**
     * Initialises the superclass, allocates the scratch rate matrix, and prepares the trait
     * rate matrix and its derived quantities on the node math object.
     */
    @Override
    public void initAndValidate() {
        super.initAndValidate();
        int nTraits = this.getNTraits();
        this.rateRealMatrix = new Array2DRowRealMatrix(nTraits, nTraits);
        NodeMath nodeMath = this.getNodeMath();
        nodeMath.populateTraitRateMatrix();
        nodeMath.performMatrixOperations();
        this.includeRoot = this.includeRootInput.get();
        this.setPopSE(false);
    }

    /**
     * Computes the log-likelihood of the traits.
     *
     * <p>When the node math input is dirty and its parameters were updated, the trait rate matrix
     * is rebuilt and validated first. A nearly singular rate matrix, or a singular matrix reported
     * by {@code NodeMath} after its matrix operations, yields {@code Double.NEGATIVE_INFINITY}.
     *
     * <p>When {@code includeRoot} is set, a root term is subtracted from the pruned log-likelihood.
     * This term is built from the root node's variance and
     * {@code getTraitRateMatrixDeterminant()}. It looks like the normalising constant of a
     * k-variate normal at the root, assuming that method returns a log-determinant — TODO confirm.
     *
     * @return the log-likelihood, or {@code Double.NEGATIVE_INFINITY} for an invalid rate matrix
     */
    public double calculateLogP() {
        NodeMath nodeMath = this.getNodeMath();
        // updateParameters() is only consulted when the input is dirty (short-circuit).
        boolean parametersUpdated = InputUtil.isDirty(this.nodeMathInput) && nodeMath.updateParameters();
        if (parametersUpdated) {
            nodeMath.populateTraitRateMatrix();
            this.checkNearlySingularForRateMatrix();
            if (this.nearlySingularRateMatrix) {
                return Double.NEGATIVE_INFINITY;
            }
            nodeMath.performMatrixOperations();
            if (nodeMath.isSingularMatrix()) {
                return Double.NEGATIVE_INFINITY;
            }
        }
        super.populateLogP();
        if (!this.includeRoot) {
            return this.getLogP();
        }
        double rootVariance = nodeMath.getVarianceForNode(this.getRootIndex());
        double rootLogTerm = -0.5 * nodeMath.getTraitRateMatrixDeterminant()
                - this.getNTraits() / 2.0 * Math.log(Math.PI * 2 * rootVariance);
        // NOTE(review): this adjusted value is returned but not stored back into the logP state
        // read by getLogP(); confirm callers that use the cached logP expect the unadjusted value.
        return this.getLogP() - rootLogTerm;
    }

    /**
     * Sets {@link #nearlySingularRateMatrix} according to the current trait rate matrix held by
     * the node math object.
     */
    private void checkNearlySingularForRateMatrix() {
        this.nearlySingularRateMatrix = this.isRateMatrixNearlySingular();
    }

    /**
     * Tests whether the current trait rate matrix is nearly singular.
     *
     * <p>For one trait, the single entry is compared with {@link #MIN_RATE_EIGENVALUE}. For several
     * traits, the matrix is copied into {@link #rateRealMatrix}. It is then flagged if either:
     * <ul>
     *   <li>its smallest/largest singular value ratio is below {@link #MIN_SINGULAR_VALUE_RATIO}, or</li>
     *   <li>any real eigenvalue is below {@link #MIN_RATE_EIGENVALUE}.</li>
     * </ul>
     *
     * @return {@code true} if the rate matrix should be rejected
     */
    private boolean isRateMatrixNearlySingular() {
        NodeMath nodeMath = this.getNodeMath();
        int nTraits = this.getNTraits();
        if (nTraits == 1) {
            return nodeMath.getTraitRateMatrix()[0] < MIN_RATE_EIGENVALUE;
        }
        for (int i = 0; i < nTraits; ++i) {
            // getMatrixRow fills the node math row buffer; setRow copies it into the scratch matrix.
            MatrixUtilsContra.getMatrixRow(nodeMath.getTraitRateMatrix(), i, nTraits, nodeMath.getRateMatrixRow());
            this.rateRealMatrix.setRow(i, nodeMath.getRateMatrixRow());
        }
        double[] singularValues = new SingularValueDecomposition(this.rateRealMatrix).getSingularValues();
        double minSingularValue = Arrays.stream(singularValues).min().getAsDouble();
        double maxSingularValue = Arrays.stream(singularValues).max().getAsDouble();
        // For an all-zero matrix the ratio is NaN (not flagged here); the eigenvalue check catches it.
        if (minSingularValue / maxSingularValue < MIN_SINGULAR_VALUE_RATIO) {
            return true;
        }
        for (double eigenvalue : new EigenDecomposition(this.rateRealMatrix).getRealEigenvalues()) {
            if (eigenvalue < MIN_RATE_EIGENVALUE) {
                return true;
            }
        }
        return false;
    }

    /** Delegates the tip L/m/r computation to {@code PruneLikelihoodUtils.populateLmrForTip}. */
    @Override
    protected void calculateLmrForTips(NodeMath nodeMath, double[] dArray, int n, int n2) {
        PruneLikelihoodUtils.populateLmrForTip(nodeMath, dArray, n, n2);
    }

    /** Delegates the internal-node L/m/r computation to {@code PruneLikelihoodUtils.populateLmrForInternalNode}. */
    @Override
    protected void calculateLmrForInternalNodes(NodeMath nodeMath, int n, int n2) {
        PruneLikelihoodUtils.populateLmrForInternalNode(nodeMath, n, n2);
    }

    /**
     * Evaluates the likelihood from the root's L/m/r quantities.
     *
     * <p>The result is {@code lScale * x^T R^{-1} x + x . mVector + rScalar + sampledAncestorTerm}, where:
     * <ul>
     *   <li>{@code x} is the root values array,</li>
     *   <li>{@code R^{-1}} is the inverse trait rate matrix from {@code NodeMath}, and</li>
     *   <li>{@code sampledAncestorTerm} is {@code nodeMath.getLikelihoodForSampledAncestors()}.</li>
     * </ul>
     *
     * @param nodeMath node math supplying root values, inverse rate matrix and sampled-ancestor term
     * @param lScale   scalar multiplying the quadratic form in the root values
     * @param mVector  vector dotted with the root values
     * @param rScalar  additive constant
     * @param index    not used by this implementation
     * @return the likelihood value as computed above
     */
    @Override
    protected double calculateLikelihood(NodeMath nodeMath, double lScale, double[] mVector, double rScalar, int index) {
        double[] rootValues = nodeMath.getRootValuesArr();
        double quadraticTerm = lScale * MatrixUtilsContra.tVecDotMatrixDotVec(rootValues, nodeMath.getTraitRateMatrixInverse(), this.getNTraits());
        double linearTerm = MatrixUtilsContra.vectorDotMultiply(rootValues, mVector);
        return quadraticTerm + linearTerm + rScalar + nodeMath.getLikelihoodForSampledAncestors();
    }
}

