#!/usr/bin/env python2
# python >= 2.7
import argparse
import logging
import os
import sys
import re
import tempfile
import copy
import getpass
import subprocess
import shutil
from string import Template
from datetime import date
from multiprocessing import cpu_count, Pool

# Per-user directory that holds easy_qsub script templates.
template_dir = os.path.join(os.path.expanduser('~'), '.easy_qsub')
# Template used when -t/--template is not given or the given file does not exist.
template_default = os.path.join(template_dir, 'default.pbs')
# Content written to the default template file when that file does not exist yet.
# The text is a string.Template: $name, $queue, $ncpus, $mem, $walltime and $cmd
# are filled in by generate_script(); "$$" is the Template escape for a literal
# "$", so $$PBS_O_WORKDIR and $$HOSTNAME end up as plain shell variables.
template_default_text = '''#PBS -S /bin/bash
#PBS -N $name
#PBS -q $queue
#PBS -l ncpus=$ncpus
#PBS -l mem=$mem
#PBS -l walltime=$walltime
#PBS -V

cd $$PBS_O_WORKDIR
echo run on node: $$HOSTNAME >&2

$cmd

'''


def parse_args():
    """Parse the command line, configure logging and return the namespace.

    Positional arguments are the command to submit and an optional list of
    input files.  ``-lp`` and ``-ls`` (run locally in parallel / serially)
    exclude each other.  The remaining options fill the placeholders of the
    job script template.  An empty ``--template`` value falls back to the
    default template path.

    The root logger is configured here as a side effect: no ``-v`` gives
    WARNING, ``-v`` gives INFO and ``-vv`` (or more) gives DEBUG.
    """
    parser = argparse.ArgumentParser(
        description='Easily submitting PBS jobs with script template. Multiple input files supported.',
        epilog=('Note: if "{}" appears in a command, it will be replaced with the current filename. '
                'More format supported: "{%}" for basename, "{^suffix}" for clipping "suffix", '
                '"{%^suffix}" for clipping suffix from basename. See more: https://github.com/shenwei356/easy_qsub'))

    parser.add_argument('command', type=str, help='command to submit')
    parser.add_argument('files', type=str, nargs='*', help='input files')

    # the two local execution modes cannot be combined
    local_mode = parser.add_mutually_exclusive_group()
    local_flags = (
        ('-lp', '--local_p', 'run commands locally, parallelly'),
        ('-ls', '--local_s', 'run commands locally, serially'),
    )
    for short_flag, long_flag, text in local_flags:
        local_mode.add_argument(short_flag, long_flag, action='store_true', help=text)

    # (short flag, long flag, type, default, help) for the template fields
    job_options = (
        ('-N', '--name', str, "easy_qsub", 'job name'),
        ('-n', '--ncpus', int, cpu_count(), 'cpu number [logical cpu number]'),
        ('-m', '--mem', str, '5gb', 'memory [5gb]'),
        ('-q', '--queue', str, 'batch', 'queue [batch]'),
        ('-w', '--walltime', str, '30:00:00:00', 'walltime [30:00:00:00]'),
        ('-t', '--template', str, template_default, 'script template'),
    )
    for short_flag, long_flag, kind, fallback, text in job_options:
        parser.add_argument(short_flag, long_flag, type=kind, default=fallback, help=text)

    parser.add_argument('-o', '--outfile', type=str, help='output script')
    parser.add_argument('-v', '--verbose', action="count", default=0,
                        help=('verbosely print information. -vv for just printing command '
                              'not creating scripts and submitting jobs'))

    args = parser.parse_args()
    args.template = args.template or template_default

    # index by the -v count, capped at the most verbose level
    levels = (logging.WARNING, logging.INFO, logging.DEBUG)
    logging.basicConfig(level=levels[min(args.verbose, 2)],
                        format="[%(levelname)s] %(message)s")

    return args


def check_default_template():
    """Make sure the template directory and the default template exist.

    The directory is created when missing; the default template file is
    written from ``template_default_text`` only when it is not there yet, so
    an existing (possibly user-edited) file is never overwritten.
    """
    if not os.path.exists(template_dir):
        os.mkdir(template_dir)

    if os.path.exists(template_default):
        return

    with open(template_default, 'wt') as out:
        out.write(template_default_text)


def generate_script(args):
    """Render the job script for ``args`` and return the path it was written to.

    The template named by ``args.template`` is read as a ``string.Template``;
    when that file does not exist a warning is logged and the default template
    is used instead.  The fields ``name``, ``mem``, ``queue``, ``ncpus``,
    ``walltime`` and ``cmd`` (taken from ``args.command``) are substituted.

    The script goes to ``args.outfile`` when given, otherwise to a new file
    named ``<today>_*.qsub`` inside ``<tempdir>/easy_qsub-<user>``.

    Raises whatever ``Template.substitute`` raises (``KeyError`` /
    ``ValueError``) for a template with unknown or malformed placeholders.
    """
    file_template = args.template
    if not os.path.exists(file_template):
        logging.warning(
            "Template file not found: {}. Use default template instead.".format(file_template))
        file_template = template_default
    with open(file_template, 'rt') as fh:
        template = Template(fh.read())

    # Render before touching the output file, so a broken template neither
    # truncates an existing --outfile nor leaves an empty temp file behind.
    script_text = template.substitute({'name': args.name, 'mem': args.mem, 'queue': args.queue,
                                       'ncpus': args.ncpus, 'walltime': args.walltime,
                                       'cmd': args.command})

    file_script = args.outfile
    if file_script:
        with open(file_script, 'wt') as fh:
            fh.write(script_text)
        return file_script

    tmpdir = os.path.join(tempfile.gettempdir(), 'easy_qsub-' + getpass.getuser())
    if not os.path.exists(tmpdir):
        os.mkdir(tmpdir)
    fd, file_script = tempfile.mkstemp(prefix=str(date.today()) + '_', suffix='.qsub', dir=tmpdir)
    # mkstemp returns an already-open OS-level descriptor; write through it so
    # it gets closed instead of leaking one descriptor per generated script.
    with os.fdopen(fd, 'w') as fh:
        fh.write(script_text)

    return file_script


def submit_job(args):
    """Write the job script for ``args`` and submit it with ``qsub``.

    ``qsub``'s standard output is printed.  The text before its first ``.``
    is taken as the job id and must consist of digits only.  On success, and
    if the current directory is writable, the script is moved to
    ``<args.name>.s<jobid>`` in the current directory.

    If ``qsub`` cannot be started, exits non-zero, or prints something that
    does not start with a numeric job id, an error is logged, the generated
    script is removed (this includes a user supplied ``--outfile``) and the
    process exits with status 1.
    """
    file_script = generate_script(args)
    logging.info("create script: {} with command: {}".format(file_script, args.command))

    cmd = 'qsub {}'.format(file_script)  # only used in messages

    def fail():
        logging.error("fail to run: {}".format(cmd))
        os.remove(file_script)
        sys.exit(1)

    try:
        # An argument list (no shell, no str.split) keeps script paths that
        # contain whitespace intact and works the same on Python 2 and 3.
        output = subprocess.check_output(['qsub', file_script])
    except (OSError, subprocess.CalledProcessError):
        # OSError: qsub missing / not executable; CalledProcessError: non-zero exit
        fail()

    if not isinstance(output, str):  # bytes on Python 3
        output = output.decode('utf-8', 'replace')
    # Drop the trailing newline so it can never end up in the job id.
    output = output.strip()

    print(output)

    jobid = output.split('.')[0]
    if not re.match(r'^\d+$', jobid):
        fail()

    if os.access('./', os.W_OK):
        shutil.move(file_script, '{}.s{}'.format(args.name, jobid))


# Matches one "{...}" placeholder; group 1 is the text between the braces.
placeholder_pattern = re.compile(r'{([^{}]*)}')


def clip_suffix(text, suffix):
    """Return ``text`` cut off at the last occurrence of ``suffix``.

    ``text`` is returned unchanged when ``suffix`` does not occur in it or
    only occurs at position 0 (clipping would leave nothing).
    """
    pos = text.rfind(suffix)
    return text[:pos] if pos > 0 else text


def expand_placeholder(spec, filename):
    """Return the replacement for the placeholder ``{spec}`` and ``filename``.

    Supported forms of ``spec``:
      ``''``         the filename as given
      ``'%^suffix'`` basename with ``suffix`` clipped
      ``'^suffix'``  filename with ``suffix`` clipped
      ``'%...'``     basename

    Raises ValueError for any other ``spec``.
    """
    if spec == '':
        return filename
    if spec.startswith('%^'):
        return clip_suffix(os.path.basename(filename), spec[2:])
    if spec.startswith('^'):
        return clip_suffix(filename, spec[1:])
    if spec.startswith('%'):
        return os.path.basename(filename)
    raise ValueError("unsupported fromat: {}".format(spec))


def render_command(command, filename):
    """Return ``command`` with every ``{...}`` placeholder expanded for ``filename``.

    All placeholders are expanded in one pass against the original filename,
    so one placeholder cannot influence the next.  A function is used as the
    replacement, which keeps backslashes in file names literal, and braces
    coming from an inserted file name are never matched again.

    Raises ValueError if the command contains an unsupported placeholder.
    """
    return placeholder_pattern.sub(
        lambda match: expand_placeholder(match.group(1), filename), command)


def main():
    """Entry point: build one command per input file and run or submit them.

    With placeholders in the command one command is built per input file,
    otherwise the command is a single job.  ``-vv`` only logs the commands;
    ``-lp`` / ``-ls`` run them locally; by default each one is submitted
    through ``submit_job``.
    """
    args = parse_args()

    check_default_template()

    has_placeholder = placeholder_pattern.search(args.command) is not None

    if has_placeholder and not args.files:
        logging.error("'{}' found in command, but no file glob expression given")
        sys.exit(1)
    if args.files and not has_placeholder:
        logging.error('input files are given but no "{}" found in command')
        sys.exit(1)

    if has_placeholder:
        try:
            cmds = [render_command(args.command, filename) for filename in args.files]
        except ValueError as err:
            logging.error(str(err))
            sys.exit(1)
    else:  # single job
        cmds = [args.command]

    # -vv: just show the commands, create and submit nothing
    if args.verbose > 1:
        for cmd in cmds:
            logging.debug('command: {}'.format(cmd))
        return

    if args.local_p and has_placeholder:
        pool = Pool(processes=args.ncpus)
        try:
            pool.map(os.system, cmds)
        finally:
            pool.close()
            pool.join()
    elif args.local_p or args.local_s:
        for cmd in cmds:
            os.system(cmd)
    else:
        for cmd in cmds:
            # submit_job reads the command from the namespace; keep ``args`` intact
            job_args = copy.copy(args)
            job_args.command = cmd
            submit_job(job_args)


if __name__ == '__main__':
    main()