/*
 * Decompiled with CFR 0.152.
 */
package org.cbio.causality.binintanalysis;

import java.awt.Color;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.biopax.paxtools.pattern.miner.SIFEnum;
import org.cbio.causality.analysis.BranchDataProvider;
import org.cbio.causality.analysis.GeneBranch;
import org.cbio.causality.analysis.Graph;
import org.cbio.causality.analysis.RadialInfluenceTree;
import org.cbio.causality.analysis.UpstreamTree;
import org.cbio.causality.binintanalysis.Dataset1;
import org.cbio.causality.data.drug.DrugData;
import org.cbio.causality.data.go.GO;
import org.cbio.causality.data.portal.BroadAccessor;
import org.cbio.causality.data.portal.CBioPortalAccessor;
import org.cbio.causality.data.portal.ExpDataManager;
import org.cbio.causality.model.Alteration;
import org.cbio.causality.model.AlterationPack;
import org.cbio.causality.model.Change;
import org.cbio.causality.network.HPRD;
import org.cbio.causality.network.IntAct;
import org.cbio.causality.network.MSigDBTFT;
import org.cbio.causality.network.PathwayCommons;
import org.cbio.causality.network.SPIKE;
import org.cbio.causality.network.SignaLink;
import org.cbio.causality.util.CollectionUtil;
import org.cbio.causality.util.FDR;
import org.cbio.causality.util.StudentsT;
import org.cbio.causality.util.Summary;

/**
 * Searches for expression targets (symbols of the expression-control graph) whose expression differs
 * between samples in which at least one upstream MutSig gene is mutated and samples in which none is.
 *
 * <p>For each target, the "upstream" genes are its direct expression regulators plus everything reaching
 * those regulators through up to {@code depth} state-change steps. This set is intersected with the
 * MutSig genes that have mutation, copy-number and expression data. A Student's t-test p-value is then
 * computed on the target's expression between the two sample groups, and targets are selected by FDR.
 */
public class ExpressionAffectedTargetFinder {
    /** Optional restriction of analysed targets; {@code null} means every symbol of the expression graph. */
    private static final Set<String> focusExp = null;
    /** Optional restriction of the MutSig genes considered; {@code null} means no restriction. */
    private static final Set<String> focusMut = null;
    /** Not referenced anywhere in this class. */
    private static final boolean PVAL_BY_PERMUTATION = false;
    /** Root output directory; a sub-directory named after the dataset is created below it. */
    private static final String OUTPUT_DIR = "binint/ExpressionAffectedTargetFinder/";
    /** Minimum number of usable samples required in each compared group. */
    private static final int MIN_GROUP_SIZE = 3;
    /** Node color of a target whose mean expression goes up in the mutated group. */
    private static final Color TARGET_UP_COLOR = new Color(220, 255, 220);
    /** Node color of a target whose mean expression goes down (or is unknown) in the mutated group. */
    private static final Color TARGET_DOWN_COLOR = new Color(255, 255, 200);

    /** Graph of state-change relations, traversed for indirect upstream genes. */
    private Graph travSt;
    /** Graph of expression-control relations; its symbols are the candidate targets. */
    private Graph travExp;
    private Dataset1 dataset;
    CBioPortalAccessor portalAcc;
    ExpDataManager expMan;
    /** Target symbol -> MutSig genes upstream of it (pruned in place by removeNonAffectors). */
    Map<String, Set<String>> mutatedUpstr;
    /** Target symbol -> true if the mean expression change (normal vs. mutated) is positive. */
    Map<String, Boolean> targetChange;
    /** MutSig genes that are present in the networks and have all required alteration data. */
    Set<String> mutsig;
    double mutsigThr;
    int depth;

    /** Runs the analysis on several network resources and prints the Venn counts of their results. */
    public static void main(String[] args) throws IOException {
        double fdrThr = 0.05;
        Dataset1 dataset = Dataset1.BRCA;
        double mutsigThr = 0.05;
        int depth = 3;
        System.out.println("depth = " + depth);

        Graph graphSt = PathwayCommons.getGraph(SIFEnum.CONTROLS_STATE_CHANGE_OF);
        Graph graphExp = PathwayCommons.getGraph(SIFEnum.CONTROLS_EXPRESSION_OF);
        graphExp.merge(MSigDBTFT.getGraph());
        List<String> resPC = runFinder(dataset, graphSt, graphExp, mutsigThr, depth, fdrThr);

        graphSt = SPIKE.getGraphPostTl().copy();
        graphExp = SPIKE.getGraphTR().copy();
        graphExp.merge(MSigDBTFT.getGraph());
        List<String> resSPIKE = runFinder(dataset, graphSt, graphExp, mutsigThr, depth, fdrThr);

        // NOTE(review): unlike SPIKE, the SignaLink graph is merged into without copy(); if the accessor
        // caches its graph, the cached instance is modified — confirm this is acceptable.
        graphSt = SignaLink.getGraphPostTl();
        graphExp = SignaLink.getGraphTR();
        graphExp.merge(MSigDBTFT.getGraph());
        List<String> resSignaLink = runFinder(dataset, graphSt, graphExp, mutsigThr, depth, fdrThr);

        graphSt = HPRD.getGraph(true);
        graphSt.merge(IntAct.getGraph(true));
        graphExp = MSigDBTFT.getGraph();
        List<String> resHPRD = runFinder(dataset, graphSt, graphExp, mutsigThr, depth, fdrThr);

        System.out.println();
        CollectionUtil.printVennCounts(resPC, resSPIKE, resSignaLink, resHPRD);
    }

    /** Builds a finder for the given graphs and returns its FDR-selected targets. */
    private static List<String> runFinder(Dataset1 dataset, Graph graphSt, Graph graphExp, double mutsigThr,
            int depth, double fdrThr) throws IOException {
        return new ExpressionAffectedTargetFinder(dataset, graphSt, graphExp, mutsigThr, depth).find(fdrThr);
    }

    /**
     * Computes influence p-values for all candidate targets and selects them by FDR.
     *
     * @param fdrThr false discovery rate threshold passed to {@code FDR.select}
     * @return the selected target symbols
     */
    public List<String> find(double fdrThr) throws IOException {
        Map<String, Double> pvals = this.calcInfluencePvals();
        System.out.println("pvals.size() = " + pvals.size());
        List<String> list = FDR.select(pvals, null, fdrThr);
        System.out.println("result size = " + list.size());
        return list;
    }

    /**
     * Prunes the upstream sets of the given targets, writes influence graphs, and prints GO enrichment
     * of up/down-regulated targets and a drug listing.
     */
    private void generateResultDetails(Map<String, Double> pvals, List<String> list) {
        this.removeNonAffectors(list);
        System.out.println("non-effectors removed");
        this.generateInfluenceGraphs(list);
        System.out.println("influence graphs generated");

        // NOTE(review): nothing is ever added to this set, so the drug query below always receives an
        // empty set — presumably it was meant to collect genes from the results; confirm intent.
        HashSet<String> druggable = new HashSet<String>();
        HashSet<String> ups = new HashSet<String>();
        HashSet<String> dws = new HashSet<String>();
        for (String s : list) {
            if (this.isTargetUp(s)) {
                ups.add(s);
            } else {
                dws.add(s);
            }
        }

        System.out.println("\nUpregulated genes GO enrichment (" + ups.size() + " genes)");
        GO.printEnrichment(ups, pvals.keySet(), 0.05);
        System.out.println("\nDownregulated genes GO enrichment (" + dws.size() + " genes)");
        GO.printEnrichment(dws, pvals.keySet(), 0.05);

        System.out.println("\n\n---------------------Drugs");
        Map<String, Set<String>> drugs = DrugData.getDrugs(druggable);
        for (String drug : DrugData.sortDrugs(drugs)) {
            if (drugs.get(drug).size() <= 1) continue;
            System.out.println(drug + "\t" + drugs.get(drug) + "\tfdaAppr = " + DrugData.isFDAApproved(drug)
                + "\tcancer-drug = " + DrugData.isCancerDrug(drug));
        }
    }

    /**
     * @param dataset   dataset providing the portal profiles and the MutSig study code
     * @param graphSt   state-change graph used for indirect upstream traversal
     * @param graphExp  expression-control graph whose symbols are the candidate targets
     * @param mutsigThr threshold passed to {@code BroadAccessor.getMutsigGenes}
     * @param depth     number of state-change steps allowed above the direct expression regulators
     * @throws IOException if loading the portal or MutSig data fails
     */
    public ExpressionAffectedTargetFinder(Dataset1 dataset, Graph graphSt, Graph graphExp, double mutsigThr,
            int depth) throws IOException {
        this.dataset = dataset;
        this.mutsigThr = mutsigThr;
        this.depth = depth;
        this.travSt = graphSt;
        this.travExp = graphExp;
        this.loadData();
    }

    /** Loads portal/expression accessors and the MutSig genes that are usable in both networks. */
    private void loadData() throws IOException {
        this.portalAcc = new CBioPortalAccessor(this.dataset.mutCnCallExpZ);
        this.expMan = new ExpDataManager(
            this.portalAcc.getGeneticProfileById(this.dataset.exp.getProfileID()[0]),
            this.portalAcc.getCaseListById(this.dataset.exp.getCaseListID()));
        this.expMan.setTakeLog(true);
        this.mutsig = BroadAccessor.getMutsigGenes(this.dataset.code(), this.mutsigThr, true);

        // Copy before adding: the set returned by getSymbols() may be the graph's own internal set.
        Set<String> symbols = new HashSet<String>(this.travSt.getSymbols());
        symbols.addAll(this.travExp.getSymbols());
        this.mutsig.retainAll(symbols);
        this.removeMutsigWithMissingData();
        System.out.println("mutsig in network = " + this.mutsig.size());
    }

    /** Drops MutSig genes lacking any of mutation, copy-number or expression alteration data. */
    private void removeMutsigWithMissingData() {
        Iterator<String> iter = this.mutsig.iterator();
        while (iter.hasNext()) {
            AlterationPack alts = this.portalAcc.getAlterations(iter.next());
            boolean complete = alts != null
                && alts.get(Alteration.MUTATION) != null
                && alts.get(Alteration.COPY_NUMBER) != null
                && alts.get(Alteration.EXPRESSION) != null;
            if (!complete) {
                iter.remove();
            }
        }
    }

    /**
     * Builds a per-sample mask over the given genes' alteration data.
     *
     * <p>A sample is selected when it has no copy-number change and no expression change in any of the
     * genes, and additionally (if {@code mutated}) at least one gene is mutated, or (otherwise) no gene
     * has a mutation-change other than NO_CHANGE.
     *
     * @param symbols non-empty set of genes; each must have alteration data
     * @param mutated whether to build the "mutated" group mask instead of the "normal" one
     * @throws IllegalArgumentException if {@code symbols} is empty
     */
    private boolean[] selectSubset(Set<String> symbols, boolean mutated) {
        if (symbols.isEmpty()) {
            throw new IllegalArgumentException("symbols must not be empty");
        }
        AlterationPack[] alts = new AlterationPack[symbols.size()];
        int i = 0;
        for (String symbol : symbols) {
            alts[i++] = this.portalAcc.getAlterations(symbol);
        }
        boolean[] x = new boolean[alts[0].getSize()];
        for (i = 0; i < x.length; ++i) {
            boolean mutationOk = mutated ? this.atLeastOneIsMutated(alts, i) : this.nothingMutated(alts, i);
            x[i] = mutationOk && this.copyNumberNotLost(alts, i) && this.expressionNotDown(alts, i);
        }
        return x;
    }

    /** True if any of the packs has an altered mutation change at the given sample. */
    private boolean atLeastOneIsMutated(AlterationPack[] alts, int index) {
        for (AlterationPack alt : alts) {
            if (alt.getChange(Alteration.MUTATION, index).isAltered()) return true;
        }
        return false;
    }

    /** True if every pack's mutation change at the given sample is exactly NO_CHANGE. */
    private boolean nothingMutated(AlterationPack[] alts, int index) {
        for (AlterationPack alt : alts) {
            if (alt.getChange(Alteration.MUTATION, index) != Change.NO_CHANGE) return false;
        }
        return true;
    }

    /**
     * True if every pack's copy-number change at the given sample is exactly NO_CHANGE.
     * NOTE(review): despite the name, any copy-number change (not only loss) excludes the sample.
     */
    private boolean copyNumberNotLost(AlterationPack[] alts, int index) {
        for (AlterationPack alt : alts) {
            if (alt.getChange(Alteration.COPY_NUMBER, index) != Change.NO_CHANGE) return false;
        }
        return true;
    }

    /**
     * True if every pack's expression change at the given sample is exactly NO_CHANGE.
     * NOTE(review): despite the name, any expression change (not only down) excludes the sample.
     */
    private boolean expressionNotDown(AlterationPack[] alts, int index) {
        for (AlterationPack alt : alts) {
            if (alt.getChange(Alteration.EXPRESSION, index) != Change.NO_CHANGE) return false;
        }
        return true;
    }

    /**
     * Extracts the symbol's expression values for the two sample masks, keeping only samples where the
     * symbol's own copy number is unchanged and present and the value is not NaN.
     *
     * @return {@code {vals1, vals2}}, or {@code null} if expression or copy-number data is missing
     */
    private double[][] getValueSubsets(String symbol, boolean[] set1, boolean[] set2) {
        double[] exp = this.expMan.get(symbol);
        if (exp == null) {
            return null;
        }
        boolean[] cnUnchanged = this.getCopyNumberUnchanged(symbol);
        if (cnUnchanged == null) {
            return null;
        }
        return new double[][]{this.getSubset(exp, set1, cnUnchanged), this.getSubset(exp, set2, cnUnchanged)};
    }

    /** Like getValueSubsets, but also returns null if either group has fewer than MIN_GROUP_SIZE values. */
    private double[][] getTestableSubsets(String symbol, boolean[] set1, boolean[] set2) {
        double[][] vals = this.getValueSubsets(symbol, set1, set2);
        if (vals == null || vals[0].length < MIN_GROUP_SIZE || vals[1].length < MIN_GROUP_SIZE) {
            return null;
        }
        return vals;
    }

    /** T-test p-value of the mean difference between the two groups, or NaN if not testable. */
    private double calcDiffPval(String symbol, boolean[] set1, boolean[] set2) {
        double[][] vals = this.getTestableSubsets(symbol, set1, set2);
        return vals == null ? Double.NaN : StudentsT.getPValOfMeanDifference(vals[0], vals[1]);
    }

    /** {@code Summary.calcChangeOfMean} of the two groups, or NaN if not testable. */
    private double calcMeanChange(String symbol, boolean[] set1, boolean[] set2) {
        double[][] vals = this.getTestableSubsets(symbol, set1, set2);
        return vals == null ? Double.NaN : Summary.calcChangeOfMean(vals[0], vals[1]);
    }

    /** Returns the values at indices where inds and cnUnchanged are true and the value is not NaN. */
    private double[] getSubset(double[] vals, boolean[] inds, boolean[] cnUnchanged) {
        double[] buffer = new double[vals.length];
        int n = 0;
        for (int i = 0; i < vals.length; ++i) {
            if (inds[i] && cnUnchanged[i] && !Double.isNaN(vals[i])) {
                buffer[n++] = vals[i];
            }
        }
        double[] sub = new double[n];
        System.arraycopy(buffer, 0, sub, 0, n);
        return sub;
    }

    /** Per-sample flag: copy number neither altered nor absent; null if data is unavailable. */
    private boolean[] getCopyNumberUnchanged(String symbol) {
        AlterationPack alts = this.portalAcc.getAlterations(symbol);
        if (alts == null) {
            return null;
        }
        Change[] cnc = alts.get(Alteration.COPY_NUMBER);
        if (cnc == null) {
            return null;
        }
        boolean[] b = new boolean[cnc.length];
        for (int i = 0; i < b.length; ++i) {
            b[i] = !cnc[i].isAltered() && !cnc[i].isAbsent();
        }
        return b;
    }

    /**
     * Direct expression regulators of the symbol plus, if depth > 0, their state-change upstream
     * within the given depth. Always returns a fresh mutable set.
     */
    private Set<String> getUpstream(String symbol, int depth) {
        assert depth >= 0;
        HashSet<String> expUp = new HashSet<String>(this.travExp.getUpstream(symbol));
        if (depth == 0) {
            return expUp;
        }
        expUp.addAll(this.travSt.getUpstream(expUp, depth));
        return expUp;
    }

    /** Alternative upstream definition that ignores the network and returns all MutSig genes (unused). */
    private Set<String> getUpstreamX(String symbol, int depth) {
        return new HashSet<String>(this.mutsig);
    }

    /**
     * Textual path from {@code from} to {@code to}: "-.>" marks an expression edge, "-->" a state-change
     * edge. Returns null if no path within the depth is found.
     */
    private String getPath(String from, String to, int depth) {
        HashSet<String> expUp = new HashSet<String>(this.travExp.getUpstream(to));
        if (expUp.contains(from)) {
            return from + " -.> " + to;
        }
        if (depth > 0) {
            for (String s : expUp) {
                String path = this.getStChPath(from, s, depth);
                if (path != null) {
                    return path + " -.> " + to;
                }
            }
        }
        return null;
    }

    /** State-change-only path of at most {@code depth} edges from {@code from} to {@code to}, or null. */
    private String getStChPath(String from, String to, int depth) {
        HashSet<String> stUp = new HashSet<String>(this.travSt.getUpstream(to));
        if (stUp.contains(from)) {
            return from + " --> " + to;
        }
        if (depth > 1) {
            for (String s : stUp) {
                String path = this.getStChPath(from, s, depth - 1);
                if (path != null) {
                    return path + " --> " + to;
                }
            }
        }
        return null;
    }

    /**
     * Computes, for every candidate target with at least one usable MutSig gene upstream, the p-value of
     * its expression difference between "normal" and "mutated" samples. Also (re)populates
     * {@link #mutatedUpstr} and {@link #targetChange} for those targets.
     *
     * @return target symbol -> p-value (targets that could not be tested are omitted)
     */
    public Map<String, Double> calcInfluencePvals() {
        this.mutatedUpstr = new HashMap<String, Set<String>>();
        this.targetChange = new HashMap<String, Boolean>();
        HashMap<String, Double> pvals = new HashMap<String, Double>();
        if (focusMut != null) {
            this.mutsig.retainAll(focusMut);
        }
        for (String sym : focusExp == null ? this.travExp.getSymbols() : focusExp) {
            Set<String> up = this.getUpstream(sym, this.depth);
            up.retainAll(this.mutsig);
            if (up.isEmpty()) continue;

            boolean[] normal = this.selectSubset(up, false);
            boolean[] mutated = this.selectSubset(up, true);
            double pval = this.calcDiffPval(sym, normal, mutated);
            if (Double.isNaN(pval)) continue;

            this.mutatedUpstr.put(sym, up);
            pvals.put(sym, pval);
            this.targetChange.put(sym, Math.signum(this.calcMeanChange(sym, normal, mutated)) > 0.0);
        }
        return pvals;
    }

    /** True if the target was recorded as up-regulated; false when down-regulated or unknown. */
    private boolean isTargetUp(String target) {
        return Boolean.TRUE.equals(this.targetChange.get(target));
    }

    /**
     * Prunes, in place, each target's entry in {@link #mutatedUpstr}: first drops genes whose individual
     * effect opposes the target's overall direction or whose individual p-value exceeds 0.5, then prunes
     * greedily (see {@link #pruneGreedily}).
     */
    private void removeNonAffectors(List<String> list) {
        for (String target : list) {
            Set<String> up = this.mutatedUpstr.get(target);
            boolean targetUp = this.isTargetUp(target);
            HashSet<String> remove = new HashSet<String>();
            for (String u : up) {
                Set<String> single = Collections.singleton(u);
                boolean[] normal = this.selectSubset(single, false);
                boolean[] mutated = this.selectSubset(single, true);
                double mc = this.calcMeanChange(target, normal, mutated);
                boolean opposes = targetUp ? mc <= 0.0 : mc >= 0.0;
                // NaN comparisons are false, so untestable genes are not removed by either criterion.
                if (opposes || this.calcDiffPval(target, normal, mutated) > 0.5) {
                    remove.add(u);
                }
            }
            up.removeAll(remove);
            this.pruneGreedily(target, up);
        }
    }

    /**
     * Repeatedly removes the single gene from {@code up} whose removal yields the lowest p-value, as long
     * as that lowers the current p-value (or ties it with a larger absolute mean change) and at least two
     * genes remain. Modifies {@code up} in place.
     */
    private void pruneGreedily(String target, Set<String> up) {
        int iteration = 0;
        while (up.size() > 1) {
            boolean[] normal = this.selectSubset(up, false);
            boolean[] mutated = this.selectSubset(up, true);
            double pval = this.calcDiffPval(target, normal, mutated);
            double dif = Math.abs(this.calcMeanChange(target, normal, mutated));

            double minPval = 1.0;
            double dOfMinPval = 0.0;
            String rem = null;
            for (String s : up) {
                HashSet<String> reduced = new HashSet<String>(up);
                reduced.remove(s);
                boolean[] nor = this.selectSubset(reduced, false);
                boolean[] mut = this.selectSubset(reduced, true);
                double p = this.calcDiffPval(target, nor, mut);
                if (p < minPval) {
                    minPval = p;
                    rem = s;
                    dOfMinPval = Math.abs(this.calcMeanChange(target, nor, mut));
                }
            }

            boolean improves = minPval < pval || (minPval == pval && dOfMinPval > dif);
            if (improves) {
                up.remove(rem);
            }
            if (iteration++ > 100) {
                System.out.println("loop of 100");
            }
            if (!improves) {
                break;
            }
        }
    }

    /** P-value of the target's expression difference for the given upstream gene set. */
    private double calcPVal(String target, Set<String> ups) {
        return this.calcDiffPval(target, this.selectSubset(ups, false), this.selectSubset(ups, true));
    }

    /** Mean change of the target's expression for the given upstream gene set. */
    private double calcMeanChange(String target, Set<String> ups) {
        return this.calcMeanChange(target, this.selectSubset(ups, false), this.selectSubset(ups, true));
    }

    /** Expression values of the target in the normal and mutated groups of the given upstream set. */
    private double[][] getValueSubsets(String target, Set<String> ups) {
        return this.getValueSubsets(target, this.selectSubset(ups, false), this.selectSubset(ups, true));
    }

    /**
     * Writes, for each result target with a non-empty upstream set, an SVG influence tree, a SIF/format
     * pair and a value table into {@code OUTPUT_DIR/<dataset name>/}. Targets whose trimmed tree has
     * exactly one branch are skipped.
     */
    public void generateInfluenceGraphs(List<String> result) {
        UpstreamTree tree = new UpstreamTree(this.travSt, this.travExp, new BranchDataProvider() {
            @Override
            public Color getColor(String gene, String root) {
                ExpressionAffectedTargetFinder outer = ExpressionAffectedTargetFinder.this;
                if (gene.equals(root)) {
                    return outer.isTargetUp(gene) ? TARGET_UP_COLOR : TARGET_DOWN_COLOR;
                }
                if (outer.mutatedUpstr.get(root).contains(gene)) {
                    int gray = outer.pvalToGray(outer.calcPVal(root, Collections.singleton(gene)));
                    return new Color(gray, gray, gray);
                }
                return Color.WHITE;
            }

            @Override
            public double getThickness(GeneBranch branch, String root) {
                ExpressionAffectedTargetFinder outer = ExpressionAffectedTargetFinder.this;
                // Copy: getAllGenes() may expose the branch's own set.
                Set<String> genes = new HashSet<String>(branch.getAllGenes());
                genes.retainAll(outer.mutatedUpstr.get(root));
                assert !genes.isEmpty() : "no mutated upstream gene in branch of " + root;
                return Math.sqrt(-Math.log(outer.calcPVal(root, genes)));
            }
        });

        String outDir = OUTPUT_DIR + this.dataset.name() + "/";
        File d = new File(outDir);
        if (!d.exists()) {
            d.mkdirs();
        }
        assert d.isDirectory();

        for (String target : result) {
            Set<String> upstream = this.mutatedUpstr.get(target);
            if (upstream.isEmpty()) continue;
            GeneBranch g = tree.getTree(target, upstream, this.depth + 1);
            g.trimToMajorPaths(upstream);
            if (g.branches.size() == 1) continue;
            RadialInfluenceTree.write(g, true, outDir + target + ".svg");
            this.writeTree(g, outDir);
            this.writeValues(target, outDir);
        }
    }

    /** Writes the normal/mutated expression values of the gene as two tab-separated columns. */
    private void writeValues(String gene, String dir) {
        double[][] vals = this.getValueSubsets(gene, this.mutatedUpstr.get(gene));
        if (vals == null) {
            System.err.println("No expression/copy-number data to write for " + gene);
            return;
        }
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(dir + gene + "-vals.txt"))) {
            writer.write("normal\tmutated");
            int rows = Math.max(vals[0].length, vals[1].length);
            for (int i = 0; i < rows; ++i) {
                writer.write("\n");
                if (i < vals[0].length) {
                    writer.write(String.valueOf(vals[0][i]));
                }
                writer.write("\t");
                if (i < vals[1].length) {
                    writer.write(String.valueOf(vals[1][i]));
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes the tree as a SIF file plus a ".format" styling file; exits the JVM on I/O failure. */
    private void writeTree(GeneBranch gwu, String dir) {
        try (BufferedWriter sifWriter = new BufferedWriter(new FileWriter(dir + gwu.gene + ".sif"));
             BufferedWriter formatWriter = new BufferedWriter(new FileWriter(dir + gwu.gene + ".format"))) {
            formatWriter.write("graph\tgrouping\ton\n");
            Color rootColor = this.isTargetUp(gwu.gene) ? TARGET_UP_COLOR : TARGET_DOWN_COLOR;
            formatWriter.write("node\t" + gwu.gene + "\tcolor\t" + colorToString(rootColor) + "\n");
            // The root's direct branches are expression edges; deeper edges are written by writePart.
            String edgeTag = SIFEnum.CONTROLS_EXPRESSION_OF.getTag();
            for (GeneBranch up : gwu.branches) {
                sifWriter.write(up.gene + "\t" + edgeTag + "\t" + gwu.gene + "\n");
                this.writeWeights(gwu.gene, gwu.gene, up, edgeTag, formatWriter);
                this.writePart(gwu.gene, up, sifWriter, formatWriter);
            }
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Recursively writes the state-change edges below {@code gwu} and their styling. */
    private void writePart(String origTarget, GeneBranch gwu, Writer writer1, Writer writer2) throws IOException {
        String edgeTag = SIFEnum.CONTROLS_STATE_CHANGE_OF.getTag();
        for (GeneBranch up : gwu.branches) {
            writer1.write(up.gene + "\t" + edgeTag + "\t" + gwu.gene + "\n");
            this.writeWeights(origTarget, gwu.gene, up, edgeTag, writer2);
            this.writePart(origTarget, up, writer1, writer2);
        }
    }

    /**
     * Writes edge color/width for {@code gwu.gene -> to} (shaded by the cumulative p-value of the branch's
     * mutated genes) and node border/fill styling for {@code gwu.gene}.
     */
    private void writeWeights(String orig, String to, GeneBranch gwu, String edgeType, Writer writer) throws IOException {
        Set<String> origUpstream = this.mutatedUpstr.get(orig);
        // Copy: getAllGenes() may expose the branch's own set.
        Set<String> upstr = new HashSet<String>(gwu.getAllGenes());
        upstr.retainAll(origUpstream);
        assert !upstr.isEmpty();
        double cumPval = this.calcPVal(orig, upstr);
        String key = gwu.gene + " " + edgeType + " " + to;
        writer.write("edge\t" + key + "\tcolor\t" + this.val2Color(cumPval) + "\n");
        writer.write("edge\t" + key + "\twidth\t3\n");
        if (origUpstream.contains(gwu.gene)) {
            double pval = this.calcPVal(orig, Collections.singleton(gwu.gene));
            writer.write("node\t" + gwu.gene + "\tbordercolor\t" + this.val2Color(pval) + "\n");
        } else {
            writer.write("node\t" + gwu.gene + "\tbordercolor\t255 255 255\n");
        }
        writer.write("node\t" + gwu.gene + "\tborderwidth\t3\n");
        writer.write("node\t" + gwu.gene + "\tcolor\t255 255 255\n");
    }

    /**
     * Maps a p-value to a gray level: 230 for p = 1 (or NaN), darkening by 46 per decade of
     * significance, down to 0 at p &lt;= 1e-5.
     */
    private int pvalToGray(double pval) {
        double score = Math.min(5.0, -Math.log10(pval));
        return 230 - (int) Math.round(46.0 * score);
    }

    /** Space-separated "R G B" gray color string for the p-value. */
    private String val2Color(double pval) {
        int v = this.pvalToGray(pval);
        return v + " " + v + " " + v;
    }

    /** Formats a color as space-separated "R G B". */
    private static String colorToString(Color c) {
        return c.getRed() + " " + c.getGreen() + " " + c.getBlue();
    }
}

