/*
 * Decompiled with CFR 0.152.
 */
package contraband.prunelikelihood;

import beast.base.core.Input;
import beast.base.evolution.tree.Tree;
import beast.base.inference.State;
import beast.base.inference.parameter.RealParameter;
import contraband.math.MatrixUtilsContra;
import contraband.math.NodeMath;
import contraband.prunelikelihood.PruneLikelihoodProcess;
import contraband.utils.PruneLikelihoodUtils;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;

/**
 * Pruning likelihood that uses a shrinkage estimate of the trait correlations, controlled by
 * the {@code delta} input, with an optional population-variance term.
 *
 * <p>{@link #initAndValidate()} does the following:
 * <ul>
 *   <li>reads the species trait values into a matrix;</li>
 *   <li>estimates trait correlations from that matrix, or from a separate population trait
 *       matrix when {@code populationTraits} is given;</li>
 *   <li>optionally folds a given ({@code popVar}) or estimated ({@code deltaVar}) population
 *       variance into the trait matrix;</li>
 *   <li>prepares the shrinkage estimate, trait rate matrix, its inverse and the transformed
 *       trait values on the object returned by {@code getNodeMath()}.</li>
 * </ul>
 */
public class BMPruneShrinkageLikelihood extends PruneLikelihoodProcess {

    public final Input<Double> deltaInput = new Input<>("delta",
            "Shrinkage parameter for correlations, either sampled or given.");
    public final Input<Boolean> includePopVarInput = new Input<>("includePopVar",
            "if including population variance or not.", false);
    public final Input<RealParameter> popVarInput = new Input<>("popVar",
            "population variance.");
    public final Input<Double> deltaVarInput = new Input<>("deltaVar",
            "Shrinkage parameter for population variance, either sampled or given.");
    public final Input<RealParameter> populationTraitsInput = new Input<>("populationTraits",
            "Trait values for calculating the population noise.");

    /**
     * Trait matrix handed to {@code NodeMath#populateTransformedTraitValues}. It is created with
     * one row per species with data and one column per trait. The population-variance helpers
     * may replace it.
     */
    private RealMatrix traitRM;
    /** Value of {@link #deltaInput}, cached at initialisation. */
    private double delta;
    /** Value of {@link #deltaVarInput}, used when the population variance is estimated. */
    private double lambda;
    /** Value of {@link #popVarInput}, used when the population variance is given directly. */
    private double popVar;

    /**
     * Validates inputs and prepares all matrices needed by the pruning algorithm.
     *
     * @throws IllegalArgumentException if {@code delta} is missing, or if {@code deltaVar} is
     *         missing while it is needed to estimate the population variance
     */
    @Override
    public void initAndValidate() {
        super.initAndValidate();

        final Double deltaValue = deltaInput.get();
        if (deltaValue == null) {
            throw new IllegalArgumentException(
                    "BMPruneShrinkageLikelihood::delta is required for shrinkage likelihood.");
        }
        delta = deltaValue;

        // Rows: species with data; columns: traits.
        traitRM = new Array2DRowRealMatrix(new double[getNumberOfSpeciesWithData()][getNTraits()]);
        PruneLikelihoodUtils.populateTraitValuesMatrix(
                (RealParameter) traitsValuesInput.get(), (Tree) treeInput.get(), getNTraits(), traitRM);

        if (includePopVarInput.get()) {
            setPopSE(true);
            initPopulationVariance();
        } else {
            setPopSE(false);
            getNodeMath().estimateCorrelations(traitRM);
        }

        getNodeMath().populateShrinkageEstimation(delta);
        refreshTransformedTraitValues();
    }

    /**
     * Estimates correlations and folds the population variance into {@link #traitRM}.
     * Called only when {@code includePopVar} is true.
     */
    private void initPopulationVariance() {
        final RealParameter populationTraits = populationTraitsInput.get();
        if (populationTraits != null) {
            // Correlations and the estimated population variance both come from the population data.
            lambda = requireDeltaVar();
            final RealMatrix populationRM = populationTraitMatrix(populationTraits);
            traitRM = PruneLikelihoodUtils.populateTraitValueMatrixEstimatedPopulationVariance(
                    populationRM, traitRM, getNTraits(), lambda);
            getNodeMath().estimateCorrelations(populationRM);
            return;
        }

        // No population data: estimate correlations from the species trait matrix
        // before it is modified below.
        getNodeMath().estimateCorrelations(traitRM);
        final RealParameter popVarParameter = popVarInput.get();
        if (popVarParameter != null) {
            popVar = popVarParameter.getValue();
            traitRM = PruneLikelihoodUtils.populateTraitValueMatrixGivenPopulationVariance(
                    traitRM, popVar, getNTraits());
        } else {
            lambda = requireDeltaVar();
            traitRM = PruneLikelihoodUtils.populateTraitValueMatrixEstimatedPopulationVariance(
                    traitRM, traitRM, getNTraits(), lambda);
        }
    }

    /**
     * Returns the {@code deltaVar} value.
     *
     * @return the value of {@link #deltaVarInput}
     * @throws IllegalArgumentException if {@code deltaVar} was not supplied (previously an
     *         unexplained NullPointerException from unboxing)
     */
    private double requireDeltaVar() {
        final Double value = deltaVarInput.get();
        if (value == null) {
            throw new IllegalArgumentException(
                    "BMPruneShrinkageLikelihood::deltaVar is required to estimate the population variance.");
        }
        return value;
    }

    /**
     * Rebuilds the trait rate matrix, its inverse and the transformed trait values from
     * {@link #traitRM}, then publishes the transformed values via {@code setTraitValuesArr}.
     */
    private void refreshTransformedTraitValues() {
        getNodeMath().populateTraitRateMatrix();
        getNodeMath().populateInverseTraitRateMatrix();
        getNodeMath().populateTransformedTraitValues(traitRM);
        setTraitValuesArr(getNodeMath().getTransformedTraitValues());
    }

    /**
     * Computes the log-likelihood.
     *
     * <p>The matrices are rebuilt only when {@code NodeMath#updateParameters()} returns true;
     * presumably this means a parameter changed. Confirm against NodeMath.
     *
     * @return the log-likelihood stored by {@code populateLogP()}
     */
    @Override
    public double calculateLogP() {
        if (getNodeMath().updateParameters()) {
            refreshTransformedTraitValues();
        }
        super.populateLogP();
        return getLogP();
    }

    /**
     * Copies a keyed {@link RealParameter} into a matrix, one row per key, in key order.
     *
     * <p>The row count is {@code getMinorDimension2()} and the column count is
     * {@code getMinorDimension1()}. Assumes {@code getKeys()} returns at least that many row
     * keys (TODO confirm).
     *
     * @param realParameter population trait values
     * @return a new matrix holding the parameter's values
     */
    private RealMatrix populationTraitMatrix(RealParameter realParameter) {
        final int nCols = realParameter.getMinorDimension1();
        final int nRows = realParameter.getMinorDimension2();
        final Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(new double[nRows][nCols]);
        final String[] keys = realParameter.getKeys();
        for (int row = 0; row < nRows; row++) {
            final Double[] rowValues = (Double[]) realParameter.getRowValues(keys[row]);
            for (int col = 0; col < nCols; col++) {
                matrix.setEntry(row, col, rowValues[col].doubleValue());
            }
        }
        return matrix;
    }

    /** Delegates tip computation to {@code PruneLikelihoodUtils.populateLmrForTipWithShrinkage}. */
    @Override
    protected void calculateLmrForTips(NodeMath nodeMath, double[] dArray, int n, int n2) {
        PruneLikelihoodUtils.populateLmrForTipWithShrinkage(nodeMath, dArray, n, n2);
    }

    /**
     * Delegates internal-node computation to
     * {@code PruneLikelihoodUtils.populateLmrForInternalNodeWithShrinkage}.
     */
    @Override
    protected void calculateLmrForInternalNodes(NodeMath nodeMath, int n, int n2) {
        PruneLikelihoodUtils.populateLmrForInternalNodeWithShrinkage(nodeMath, n, n2);
    }

    /**
     * Combines the accumulated quantities into the log-likelihood at {@code nodeIndex}.
     *
     * <p>Let {@code k} be the number of traits and {@code v = getVarianceForNode(nodeIndex)}.
     * The method computes
     * {@code detTerm = -0.5 * getTraitRateMatrixDeterminant() - k / 2 * log(2 * pi * v)}
     * and returns the sum of:
     * <ul>
     *   <li>{@code vecTransScalarMultiply(root, lValue, k)};</li>
     *   <li>{@code root . mVector};</li>
     *   <li>{@code rValue};</li>
     *   <li>{@code getLikelihoodForSampledAncestors()};</li>
     * </ul>
     * minus {@code detTerm}. Here {@code root = nodeMath.getRootValuesArr()}.
     *
     * <p>NOTE(review): {@code detTerm} is subtracted, not added. Also, it is unclear whether
     * {@code getTraitRateMatrixDeterminant()} returns a determinant or a log-determinant.
     * Confirm both against the non-shrinkage likelihood before relying on this.
     *
     * @param nodeMath  per-node math state
     * @param lValue    scalar coefficient (presumably the "L" of the L/m/r recursion)
     * @param mVector   linear coefficient vector (presumably "m")
     * @param rValue    constant term (presumably "r")
     * @param nodeIndex index of the node evaluated (presumably the root)
     * @return the combined log-likelihood value
     */
    @Override
    protected double calculateLikelihood(NodeMath nodeMath, double lValue, double[] mVector, double rValue, int nodeIndex) {
        final double variance = nodeMath.getVarianceForNode(nodeIndex);
        final double detTerm = -0.5 * nodeMath.getTraitRateMatrixDeterminant()
                - (double) getNTraits() / 2.0 * Math.log(Math.PI * 2 * variance);
        final double[] rootValues = nodeMath.getRootValuesArr();
        return MatrixUtilsContra.vecTransScalarMultiply(rootValues, lValue, getNTraits())
                + MatrixUtilsContra.vectorDotMultiply(rootValues, mVector)
                + rValue
                + nodeMath.getLikelihoodForSampledAncestors()
                - detTerm;
    }

    /** No state-node arguments are declared; returns an empty list rather than {@code null}. */
    @Override
    public List<String> getArguments() {
        return List.of();
    }

    /** No conditioning state nodes are declared; returns an empty list rather than {@code null}. */
    @Override
    public List<String> getConditions() {
        return List.of();
    }

    /** Sampling from this distribution is not implemented; this is a no-op. */
    @Override
    public void sample(State state, Random random) {
    }
}

