#!/usr/bin/python
import argparse
import colorlog
from collections import defaultdict
import networkx as nx
import pyalpm
import random

# TODO: Automate this. The packages listed here depend on ghc/ghc-libs but don't have .so files.
# In --haskell-check mode resolve_pkg() still records the reverse deps of these
# packages, but does not add those reverse deps to the set of packages to rebuild.
HASKELL_DO_NOT_EXPAND = {
    'alex',
    'c2hs',
    'cabal-install',
    'cgrep',
    'git-annex',
    'git-repair',
    'happy',
    'hedgewars',
    'hopenpgp-tools',
    'pandoc-cli',
    'tamarin-prover',
    'uusi',
    'xmonad-utils',
}
# The packages listed here are always required at build time, not only for check().
# pass1() skips these nodes when it looks for a soft dependency edge to drop
# while breaking a dependency cycle.
HASKELL_REAL_MAKEDEPEND = {
    'alex',
    'happy',
    'pandoc-cli',
    'uusi',
}

# Command-line interface. Description and help strings are user-facing output.
parser = argparse.ArgumentParser(description='Rebuild generator and orderer, generates all haskell reverse deps too')
parser.add_argument('-H', '--haskell-check', action="store_true",
                    help='Filter results for haskell-only rebuilds, also implies --expand, --expand-make, --expand-check')
parser.add_argument('-e', '--expand', action="store_true",
                    help='Expand all reverse deps')
parser.add_argument('-m', '--expand-make', action="store_true",
                    help='Expand all reverse make deps')
parser.add_argument('-c', '--expand-check', action="store_true",
                    help='Expand all reverse check deps')
parser.add_argument('--ignore', nargs='?', default="", help='Ignore packages, separated by comma')
parser.add_argument('-d', '--db', nargs='*', default=["core", "extra", "community", "multilib"],
                    help='Pacman sync db to consider. Defaulting to all stable dbs.')
parser.add_argument('--dbpath', nargs='?', default="/var/lib/pacman",
                    help='Pacman sync db location. Default: /var/lib/pacman')
parser.add_argument('package', nargs='+', help='Packages to rebuild')
args = parser.parse_args()

# `--ignore` is declared with nargs='?' and no `const`, so a bare `--ignore`
# (no value) yields None instead of a string. Treat that as "ignore nothing"
# rather than crashing on None.split(). Empty names (from "" or "a,,b") are
# dropped because they can never match a package name.
args.ignore = {name for name in (args.ignore or "").split(",") if name}

# --haskell-check implies every expansion mode.
if args.haskell_check:
    args.expand = args.expand_make = args.expand_check = True
# Coloured logging on the root logger, DEBUG and above.
handler = colorlog.StreamHandler()
handler.setFormatter(colorlog.ColoredFormatter())

logger = colorlog.getLogger()
logger.addHandler(handler)
logger.setLevel(colorlog.DEBUG)

# Packages requested on the command line; names ending in ":nocheck" are left out.
rebuild_list = {name for name in args.package if not name.endswith(":nocheck")}

# Working state shared by the rest of the script.
pkglist = set(rebuild_list)       # grows as reverse deps are expanded
reverse_deps = defaultdict(set)   # package name -> names of packages depending on it
package_db = defaultdict(dict)    # package name -> metadata fields read from the sync dbs
G = nx.DiGraph()                  # edge (dep, pkg): pkg depends on dep
handle = pyalpm.Handle(".", args.dbpath)

# Read the metadata of every non-ignored package from each selected sync db.
for db_name in args.db:
    sync_db = handle.register_syncdb(db_name, 0)
    for package in sync_db.search(""):
        if package.name in args.ignore:
            continue
        package_db[package.name].update(
            (field, getattr(package, field))
            for field in ("depends", "makedepends", "checkdepends", "provides", "arch")
        )

# Reduce every dependency/provide spec to a bare name by cutting it at the
# first "=", ">" or "<" (e.g. "foo>=1.2" -> "foo").
for entry in package_db.values():
    for field in ("depends", "makedepends", "checkdepends", "provides"):
        entry[field] = {
            spec.split("=")[0].split(">")[0].split("<")[0]
            for spec in entry[field]
        }

def _is_haskell_pkg(name):
    """Return True if *name* depends on ghc or ghc-libs and its arch is not "any"."""
    info = package_db[name]
    return bool({"ghc", "ghc-libs"} & set(info["depends"])) and info["arch"] != "any"


def resolve_pkg(pkg):
    """Record the reverse dependencies of *pkg* and return the ones to expand into.

    On the first call for a package, ``reverse_deps[pkg]`` is filled with every
    package in ``package_db`` whose depends, makedepends or checkdepends mention
    *pkg* or one of its provides. With ``--haskell-check`` only packages that
    pass ``_is_haskell_pkg`` are considered, and a non-Haskell *pkg* gets an
    empty entry.

    Returns:
        The set of reverse deps that should additionally be rebuilt, selected
        by ``--expand`` (depends), ``--expand-make`` (makedepends) and
        ``--expand-check`` (checkdepends). Empty if *pkg* was already resolved,
        or if ``--haskell-check`` is on and *pkg* is in HASKELL_DO_NOT_EXPAND.

    Raises:
        SystemExit: with status 1 if *pkg* is in none of the loaded dbs.
    """
    if pkg in reverse_deps:
        return set()

    if pkg not in package_db:
        logger.error(f"Package {pkg} not found in any db")
        # `exit()` is injected by the `site` module for interactive use;
        # raising SystemExit directly does not depend on it.
        raise SystemExit(1)

    # Touching the defaultdict marks pkg as resolved even if nothing depends on it.
    rdeps = reverse_deps[pkg]

    if args.haskell_check and not _is_haskell_pkg(pkg):
        return set()

    # Loop-invariant: every name under which pkg can be depended on.
    names = {pkg} | package_db[pkg]["provides"]
    expandable = not (args.haskell_check and pkg in HASKELL_DO_NOT_EXPAND)

    newpkgs = set()
    for candidate, info in package_db.items():
        if args.haskell_check and not _is_haskell_pkg(candidate):
            continue

        via_depends = names & info["depends"]
        via_make = names & info["makedepends"]
        via_check = names & info["checkdepends"]
        if not (via_depends or via_make or via_check):
            continue

        rdeps.add(candidate)
        if expandable and (
            (args.expand and via_depends)
            or (args.expand_make and via_make)
            or (args.expand_check and via_check)
        ):
            newpkgs.add(candidate)

    return newpkgs


# Expand reverse dependencies until every package in pkglist has been resolved.
while len(reverse_deps) != len(pkglist):
    unresolved = [name for name in pkglist if name not in reverse_deps]
    for name in unresolved:
        pkglist.update(resolve_pkg(name))


# Build the rebuild graph: one node per resolved package and an edge
# dep -> rdep for each reverse dependency that is itself going to be rebuilt.
G.add_nodes_from(reverse_deps)
G.add_edges_from(
    (dep, rdep)
    for dep, rdeps in reverse_deps.items()
    for rdep in rdeps
    if rdep in pkglist
)


def flatten_frozenset(nodes):
    """Return a frozenset of *nodes* with nested sets expanded by one level.

    Items that are a ``set`` or ``frozenset`` contribute their members; every
    other item is included as is.
    """
    flat = set()
    for node in nodes:
        members = node if isinstance(node, (set, frozenset)) else (node,)
        flat.update(members)
    return frozenset(flat)


# Collapse every strongly connected component with more than one member into
# a single frozenset node, so that G can be ordered topologically.
for members in list(nx.algorithms.components.strongly_connected_components(G)):
    if len(members) == 1:
        continue
    logger.info("Found circular dependency: " + str(members))

    group = flatten_frozenset(members)
    group_depends = set()
    package_db[group] = {"depends": group_depends}
    for member in members:
        group_depends |= package_db[member]["depends"]
        # Whatever depends on a member now also depends on the group.
        for entry in package_db.values():
            if member in entry["depends"]:
                entry["depends"].add(group)

    # Re-attach the members' edges to the group node, then drop the members.
    for src, dst in list(G.edges):
        if src in members:
            G.add_edge(group, dst)
        if dst in members:
            G.add_edge(src, group)
    G.remove_nodes_from(members)


# After collapsing, only self-loops are tolerated (and removed); any other
# remaining cycle is fatal.
try:
    while cycle_edges := nx.find_cycle(G):
        for edge in cycle_edges:
            is_self_loop = len(edge) == 2 and edge[0] == edge[1]
            if not is_self_loop:
                logger.error("Found circular dependency after resolving: " + str(edge))
                exit(1)
            logger.warning("Removing edge to resolve self-loop: " + str(edge))
            G.remove_edge(edge[0], edge[1])
except nx.NetworkXNoCycle:
    pass


def _pick_soft_edge(NG):
    """Return a random (dep, rdep) edge of NG where dep is not in rdep's depends, or None.

    Nodes listed in HASKELL_REAL_MAKEDEPEND are never used as the source of
    the returned edge.
    """
    candidates = [node for node in NG if node not in HASKELL_REAL_MAKEDEPEND]
    random.shuffle(candidates)
    for node in candidates:
        for rdep in NG.successors(node):
            if node not in package_db[rdep]["depends"]:
                return node, rdep
    return None


def pass1(pkg):
    """Break the cycles inside one cycle group and order its members.

    Args:
        pkg: collection of package names forming the group.

    Returns:
        A tuple ``(order, NG, nocheck_pkg)``: ``order`` is the topological
        order of the group, with ":nocheck" appended to every package that had
        an incoming edge removed; ``NG`` is the resulting acyclic graph;
        ``nocheck_pkg`` is the set of packages that lost an incoming edge.
    """
    NG = nx.DiGraph()
    NG.add_nodes_from(pkg)
    for dep, pkgs in reverse_deps.items():
        for _pkg in pkgs:
            # A package depending on itself gives no ordering information, and
            # a self-loop would make the topological sort below fail.
            if dep != _pkg and _pkg in pkg and dep in pkg:
                NG.add_edge(dep, _pkg)
    nocheck_pkg = set()

    # NG is acyclic (self-loops aside) once every node is its own component.
    while nx.algorithms.components.number_strongly_connected_components(NG) != len(NG):
        edge = _pick_soft_edge(NG)
        if edge is not None:
            node, rdep = edge
            logger.warning(f"Removing soft dep to solve circular dependency: {node} <- {rdep}")
        else:
            # No soft edge is left, so the remaining cycles consist of runtime
            # dependencies only. Drop one edge that is known to lie on a cycle
            # instead of searching forever.
            node, rdep = nx.find_cycle(NG)[0][:2]
            logger.error(f"Removing HARD dep to solve circular dependency:  {node} <- {rdep}")
        NG.remove_edge(node, rdep)
        nocheck_pkg.add(rdep)

    order = [
        _pkg + ":nocheck" if _pkg in nocheck_pkg else _pkg
        for _pkg in nx.topological_sort(NG)
    ]
    logger.debug("Cycle Solver pass1: " + " ".join(order))
    return order, NG, nocheck_pkg

def pass2(NG, nocheck_pkg):
    """Order the follow-up rebuilds for the packages pass1 marked ":nocheck".

    Starting from ``nocheck_pkg``, a package is selected only while none of
    its depends/makedepends/checkdepends is in the "broken" set. Selecting it
    appends it to the result and marks every package of ``NG`` that
    transitively has it in ``depends`` as broken, so that those get selected
    later as well.

    The selection is tried up to three times with a freshly shuffled candidate
    order each time, and the shortest resulting list is returned. The caller's
    ``nocheck_pkg`` set is not modified.

    NOTE(review): if packages in NG have each other in ``depends`` in a cycle,
    each selection can re-break the other and the loop may not terminate --
    confirm whether such input can reach this function.

    Args:
        NG: acyclic graph of the cycle group, as returned by pass1.
        nocheck_pkg: packages that pass1 marked ":nocheck".

    Returns:
        List of package names in the order chosen.

    Raises:
        RuntimeError: if no attempt produced an order within 20 restarts.
    """
    def hard_rdeps(name):
        # Members of NG that list `name` in their depends.
        return {rdep for rdep in reverse_deps[name]
                if name in package_db[rdep]["depends"] and rdep in NG}

    def broken_by(name):
        # Transitive closure of hard_rdeps() starting at `name`.
        closure = hard_rdeps(name)
        frontier = set(closure)
        while frontier:
            reached = set()
            for member in frontier:
                reached |= hard_rdeps(member)
            frontier = reached - closure
            closure |= frontier
        return closure

    def attempt():
        # One complete run from a clean state; None if it keeps dead-ending.
        order = []
        broken_set = set()
        pending = set(nocheck_pkg)
        restarts = 0
        while pending | broken_set:
            # Shuffle so that restarts and repeated attempts explore different
            # orders instead of replaying the same set iteration order.
            candidates = list(pending | broken_set)
            random.shuffle(candidates)
            for pkg in candidates:
                info = package_db[pkg]
                if (info["depends"] | info["makedepends"] | info["checkdepends"]) & broken_set:
                    continue
                order.append(pkg)
                broken_set |= broken_by(pkg)
                broken_set.discard(pkg)
                pending.discard(pkg)
                break
            else:
                # Deadend. Start over.
                restarts += 1
                if restarts > 20:
                    return None
                logger.warning("Cycle Solver pass2 loop detected. Starting over...")
                order = []
                broken_set = set()
                pending = set(nocheck_pkg)
        return order

    # Try 3 times, use the shortest result
    best_pass2 = None
    for _ in range(3):
        order = attempt()
        if order is None:
            if best_pass2 is None:
                logger.error("Cycle Solver pass2 unresolvable loop.")
                raise RuntimeError
            # An earlier attempt already succeeded; keep that result.
            break
        if best_pass2 is None or len(order) < len(best_pass2):
            best_pass2 = order

    logger.debug("Cycle Solver pass2: " + " ".join(best_pass2))
    return best_pass2

# Emit the final rebuild order: plain packages as they come, cycle groups
# expanded through the two-pass cycle solver.
result = []
for node in nx.topological_sort(G):
    if isinstance(node, str):
        result.append(node)
        continue

    # Non-string nodes are collapsed cycle groups; rerun pass1 until pass2
    # succeeds on its outcome.
    solved = None
    while solved is None:
        first_pass, NG, nocheck_pkg = pass1(node)
        try:
            solved = first_pass + pass2(NG, nocheck_pkg)
        except RuntimeError:
            logger.warning("Cycle Solver retrying from pass1...")
    result += solved

print(" ".join(result))