#!/usr/bin/python3

import argparse
import random
import shlex
import sys
from enum import Enum


class InterlaceOtherType(Enum):
    """Pattern for the second word of an interlace case.

    Each value reads like a tiny regex over two symbols: ``1`` stands for the
    tail character, ``0`` for the other character, and ``+`` marks the symbol
    that is repeated ``run_len`` times.
    """

    ZERO_ONE_PLUS = '01+'
    ZERO_PLUS_ONE = '0+1'
    ONE_ZERO_PLUS = '10+'
    ONE_PLUS_ZERO = '1+0'

    def __str__(self):
        # Render as the raw pattern (e.g. '1+0') so argparse lists readable choices.
        return str(self.value)


# Inclusive bounds accepted for --num-cases (enforced by bounded_cases).
NUM_CASES_MIN = 1
NUM_CASES_MAX = 50
# Default value of --num-cases.
DEFAULT_CASES = NUM_CASES_MAX
# Every generated or user-supplied alphabet must be drawn from these characters.
ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
# Upper bound on word lengths (enforced by bounded_word_len).
WORD_LEN_MAX = 500
# Lower bound for tail lengths.
TAIL_LEN_MIN = 1
# Shared generator used by all case builders. main() reseeds it with --seed,
# whose default is a random value, and records that seed so runs can be replayed.
RNG = random.Random(0)


def bounded_int(string, val_min, val_max, name='Value'):
    """Parse *string* as an int and require ``val_min <= value <= val_max``.

    Intended as (part of) an argparse ``type=`` callable.

    Raises:
        ValueError: if *string* is not an integer literal (from ``int``).
        argparse.ArgumentTypeError: if the parsed value is out of range.
    """
    parsed = int(string)
    if not val_min <= parsed <= val_max:
        raise argparse.ArgumentTypeError(f'{name} must be in range [{val_min}, {val_max}]')
    return parsed


def bounded_float(string, val_min, val_max, name='Value'):
    """Parse *string* as a float and require ``val_min <= value <= val_max``.

    Intended as (part of) an argparse ``type=`` callable. Infinite bounds are
    allowed, e.g. ``val_max=float('inf')`` accepts ``'inf'``.

    Raises:
        ValueError: if *string* is not a float literal (from ``float``).
        argparse.ArgumentTypeError: if the value is out of range or NaN.
    """
    value = float(string)
    # Phrase the test positively: every comparison with NaN is False, so
    # ``value < lo or value > hi`` would wrongly let 'nan' through.
    if not val_min <= value <= val_max:
        raise argparse.ArgumentTypeError(f'{name} must be in range [{val_min}, {val_max}]')
    return value


def bounded_cases(string):
    """argparse ``type`` for --num-cases: an int in [NUM_CASES_MIN, NUM_CASES_MAX]."""
    return bounded_int(
        string,
        val_min=NUM_CASES_MIN,
        val_max=NUM_CASES_MAX,
        name='num_cases',
    )


def bounded_alphabet_len(string):
    """argparse ``type`` for an alphabet size: an int from 1 to ``len(ALPHABET)`` (26)."""
    largest = len(ALPHABET)
    return bounded_int(string, 1, largest, name='alphabet size')


def bounded_weight(string):
    """argparse ``type`` for a non-negative float weight with no upper bound."""
    upper = float('inf')
    return bounded_float(string, 0, upper, name='weight')


def bounded_word_len(string):
    """argparse ``type`` for a word length: an int in [1, WORD_LEN_MAX]."""
    return bounded_int(string, val_min=1, val_max=WORD_LEN_MAX, name='word length')


def bounded_tail_len(string):
    """argparse ``type`` for a tail length: an int in [TAIL_LEN_MIN, WORD_LEN_MAX].

    Uses the module constant TAIL_LEN_MIN rather than a literal 1, so the
    lower bound is defined in a single place.
    """
    return bounded_int(string, TAIL_LEN_MIN, WORD_LEN_MAX, 'tail length')


def bounded_alphabet(string):
    """argparse ``type`` for --alphabet.

    The alphabet must be non-empty, must not repeat a character, and may only
    use characters from ALPHABET. On success *string* is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if any of those rules is violated.
    """
    # An empty alphabet used to be accepted; in random mode it reaches
    # _random_word, where RNG.choice('') raises IndexError.
    if not string:
        raise argparse.ArgumentTypeError('alphabet must not be empty')
    if len(set(string)) != len(string):
        raise argparse.ArgumentTypeError(f'alphabet {string} must not have repeated characters')
    for ch in string:
        if ch not in ALPHABET:
            raise argparse.ArgumentTypeError(f'character {ch} in alphabet {string} is not allowed')
    return string


def _random_word(alphabet, word_len):
    """Return *word_len* characters, each drawn independently from *alphabet* via RNG."""
    letters = []
    for _ in range(word_len):
        letters.append(RNG.choice(alphabet))
    return ''.join(letters)


def _equal_words(alphabet, word_len):
    """Return a case line ``'<alphabet> <w> <w>'`` whose two words are the same random word."""
    repeated = _random_word(alphabet, word_len)
    line = f'{alphabet} {repeated} {repeated}'
    return line


def _substring_case(alphabet, word_len):
    """Return a case line pairing a random word with one of its substrings.

    Two indices are drawn uniformly from ``[0, word_len - 1]``, and the
    substring spans them inclusively. If the second draw is smaller than the
    first, the substring is printed first; otherwise the full word is.
    """
    word = _random_word(alphabet, word_len)
    first = RNG.randint(0, word_len - 1)
    second = RNG.randint(0, word_len - 1)
    start, stop = min(first, second), max(first, second)
    subword = word[start:stop + 1]
    # The draw order doubles as a coin flip for which word appears first.
    if second < first:
        return f'{alphabet} {subword} {word}'
    return f'{alphabet} {word} {subword}'


def generate_substrings_lines(args):
    """Generate case lines for the ``substrings`` subcommand.

    Lines are produced in three phases:

    1. For a random selection of distinct alphabet sizes, one case with two
       identical words and one case where one word is a substring of the other.
    2. At most one case with two random words of length ``args.word_len_max``,
       kept only if neither word contains the other.
    3. Random substring cases until ``args.num_cases`` lines exist.

    Args:
        args: parsed namespace; reads ``num_cases``, ``word_len_min`` and
            ``word_len_max``.

    Returns:
        A list of ``args.num_cases`` strings (for num_cases >= 1), each of
        the form ``'<alphabet> <word_1> <word_2>'``.
    """
    lines = []
    alphabet_size_cases = len(ALPHABET)
    if 2 * alphabet_size_cases + 1 > args.num_cases:
        # Phase 1 emits two lines per alphabet size; leave roughly ten slots
        # for the later phases. Clamp at zero: RNG.sample rejects a negative k,
        # so --num-cases below 10 used to crash here.
        alphabet_size_cases = max(0, (args.num_cases - 10) // 2)
    alphabet_sizes = RNG.sample(list(range(1, len(ALPHABET) + 1)), k=alphabet_size_cases)
    for alphabet_size in alphabet_sizes:
        alphabet = ''.join(RNG.sample(ALPHABET, k=alphabet_size))
        word_len = RNG.randint(args.word_len_min, args.word_len_max)
        lines.append(_equal_words(alphabet, word_len))
        lines.append(_substring_case(alphabet, word_len))
    # Try to throw in a case where there is no machine (here: neither word is
    # a substring of the other). It is skipped if the random draw fails that test.
    random_alphabet = ''.join(RNG.sample(ALPHABET, k=RNG.randint(2, len(ALPHABET))))
    random_word_1 = _random_word(random_alphabet, args.word_len_max)
    random_word_2 = _random_word(random_alphabet, args.word_len_max)
    if random_word_1 not in random_word_2 and random_word_2 not in random_word_1:
        lines.append(f'{random_alphabet} {random_word_1} {random_word_2}')
    # Fill with random substring cases until the requested number of cases is reached.
    while len(lines) < args.num_cases:
        alphabet_size = RNG.randint(1, len(ALPHABET))
        alphabet = ''.join(RNG.sample(ALPHABET, k=alphabet_size))
        word_len = RNG.randint(args.word_len_min, args.word_len_max)
        lines.append(_substring_case(alphabet, word_len))
    return lines


def validate_substrings_args(args):
    """Reject substrings options whose word-length range is empty.

    Raises:
        ValueError: if ``args.word_len_min`` exceeds ``args.word_len_max``.
    """
    if args.word_len_max >= args.word_len_min:
        return
    raise ValueError('Word len min must be less than or equal to word len max.')


def _setup_substrings_parser(parser):
    """Register the ``substrings`` options and callbacks on *parser*.

    Adds --word-len-min / --word-len-max and stores generate_substrings_lines
    and validate_substrings_args as the namespace's ``generate_lines`` and
    ``validate`` defaults.
    """
    length_options = (
        ('--word-len-min', 1, 'Minimum word length for the longer word.'),
        ('--word-len-max', WORD_LEN_MAX, 'Maximum word length for the longer word.'),
    )
    for flag, default_value, help_text in length_options:
        parser.add_argument(flag, type=bounded_word_len, default=default_value, help=help_text)
    parser.set_defaults(
        validate=validate_substrings_args,
        generate_lines=generate_substrings_lines,
    )


def _interlace_word(word_len, tail_len, next_largest_len, tail_char, other_char):
    """Build the base word of an interlace case.

    The word ends with a run of *tail_len* copies of *tail_char*. Everything
    before the tail is a shuffled sequence of chunks. Each chunk is a
    (possibly empty) run of *tail_char* followed by one *other_char*, so
    (assuming ``tail_char != other_char``) no chunk extends the tail run.

    If *next_largest_len* is given, one chunk contains a run of exactly that
    length. The remaining filler runs are strictly shorter and also shorter
    than 30. If it is None, filler runs are at most ``max(tail_len - 2, 0)``
    long (and below 30). Short filler runs are exponentially more likely:
    run length ``i`` out of ``k`` candidates has weight ``2 ** (k - 1 - i)``.

    Callers must ensure ``tail_len + next_largest_len + 1 <= word_len``
    (validate_interlace_args checks this); otherwise the result is longer
    than *word_len*. Randomness comes from the module RNG.
    """
    tail = tail_char * tail_len
    if next_largest_len is None:
        next_largest = ''
        next_largest_len = tail_len - 1
    else:
        next_largest = tail_char * next_largest_len + other_char
    pre_tail_words = [next_largest]
    # Candidate filler run lengths. The cap at 30 also bounds the weights at
    # 2**29. Keep at least [0]: with tail_len == 1 and no next_largest_len the
    # range used to be empty, and RNG.choices([]) raised IndexError.
    pre_tail_lens = list(range(max(1, min(next_largest_len, 30))))
    weights = [2 ** i for i in reversed(pre_tail_lens)]
    length = tail_len + len(next_largest)
    while length < word_len:
        # Each chunk is run_len + 1 long, so never overshoot word_len.
        run_len_max = word_len - length - 1
        run_len = RNG.choices(
            pre_tail_lens[:run_len_max + 1], weights=weights[:run_len_max + 1], k=1
        )[0]
        chunk = tail_char * run_len + other_char
        pre_tail_words.append(chunk)
        length += len(chunk)
    pre_tail_words = RNG.sample(pre_tail_words, k=len(pre_tail_words))
    return ''.join(pre_tail_words) + tail


def _interlace_other_word(pattern, run_len, tail_char, other_char):
    """Build the second word of an interlace case for an InterlaceOtherType *pattern*.

    In pattern names ``1`` is *tail_char*, ``0`` is *other_char*, and ``+``
    marks the symbol repeated *run_len* times.

    Raises:
        ValueError: if *pattern* is not an InterlaceOtherType member.
    """
    if pattern is InterlaceOtherType.ONE_PLUS_ZERO:
        return tail_char * run_len + other_char
    if pattern is InterlaceOtherType.ONE_ZERO_PLUS:
        return tail_char + other_char * run_len
    if pattern is InterlaceOtherType.ZERO_PLUS_ONE:
        return other_char * run_len + tail_char
    if pattern is InterlaceOtherType.ZERO_ONE_PLUS:
        # '01+': a single other_char followed by the tail_char run. This
        # branch previously duplicated ZERO_PLUS_ONE ('0+1').
        return other_char + tail_char * run_len
    raise ValueError(f'Unexpected InterlaceOtherType: {pattern}')


def _other_run_lens(tail_len, num_cases):
    """Return up to *num_cases* run lengths fanning out from *tail_len*, sorted ascending.

    Candidates are taken in the order tail_len, tail_len + 1, tail_len - 1,
    tail_len + 2, ... Lengths that are negative or >= WORD_LEN_MAX are
    skipped; the second word is run_len + 1 characters long.
    """
    run_lens = []
    if tail_len < WORD_LEN_MAX:
        run_lens.append(tail_len)
    offset = 1
    while len(run_lens) < num_cases:
        if tail_len + offset < WORD_LEN_MAX:
            run_lens.append(tail_len + offset)
        if len(run_lens) >= num_cases:
            break
        if tail_len - offset >= 0:
            run_lens.append(tail_len - offset)
        offset += 1
        if offset >= WORD_LEN_MAX:
            break
    return sorted(run_lens)


def generate_interlace_lines(args):
    """Generate case lines for the ``interlace`` subcommand.

    One base word is built with _interlace_word (reversed if ``args.reverse``).
    It is paired with one ``args.other_pattern`` word per run length from
    _other_run_lens, and a coin flip decides which word is printed first.

    Args:
        args: parsed namespace; reads ``alphabet``, ``word_len``, ``tail_len``,
            ``next_largest_len``, ``reverse``, ``other_pattern`` and ``num_cases``.

    Returns:
        Up to ``args.num_cases`` strings ``'<alphabet> <word> <word>'``.
    """
    # Draw two *distinct* characters from the seeded RNG. The old
    # random.choices(...) call ignored --seed, so output was not reproducible,
    # and it sampled with replacement, so both characters could coincide.
    tail_char, other_char = RNG.sample(args.alphabet, k=2)
    word_1 = _interlace_word(args.word_len, args.tail_len, args.next_largest_len, tail_char, other_char)
    if args.reverse:
        word_1 = word_1[::-1]
    lines = []
    for run_len in _other_run_lens(args.tail_len, args.num_cases):
        word_2 = _interlace_other_word(args.other_pattern, run_len, tail_char, other_char)
        if RNG.random() < 0.5:
            lines.append(f'{args.alphabet} {word_1} {word_2}')
        else:
            lines.append(f'{args.alphabet} {word_2} {word_1}')
    return lines


def validate_interlace_args(args):
    """Check cross-argument constraints for the ``interlace`` subcommand.

    Raises:
        ValueError: if the tail is longer than the word, the alphabet has
            fewer than two characters, or (when --next-largest-len is given)
            the next-largest run does not fit in the word or is not shorter
            than the tail.
    """
    if args.word_len < args.tail_len:
        raise ValueError('Word len must be greater than or equal to tail length')
    if len(args.alphabet) < 2:
        raise ValueError('Unary alphabets are not allowed for interlace cases')
    runner_up = args.next_largest_len
    if runner_up is None:
        return
    if args.tail_len + runner_up + 1 > args.word_len:
        raise ValueError('Tail length plus next largest length plus 1 must be at most the word length')
    if runner_up + 1 > args.tail_len:
        raise ValueError('next-largest-len must be less than the tail length')


def _setup_interlace_parser(parser):
    """Register the ``interlace`` arguments and callbacks on *parser*.

    Positionals: ``word_len`` and ``tail_len`` (both required). Options:
    --next-largest-len, --alphabet (default 'ab'), --other-pattern (default
    '1+0') and --reverse. generate_interlace_lines and validate_interlace_args
    are stored as the ``generate_lines`` and ``validate`` defaults.
    """
    # A plain positional is always required, so argparse never applies a
    # default to it. The previous default=WORD_LEN_MAX here was dead and
    # suggested word_len could be omitted.
    parser.add_argument(
        'word_len',
        type=bounded_word_len,
        help='Word length to use.'
    )
    parser.add_argument(
        'tail_len',
        # bounded_tail_len has the same range but reports "tail length ..."
        # in its error instead of "word length ...".
        type=bounded_tail_len,
        help='Tail length in the base word.',
    )
    parser.add_argument(
        '--next-largest-len',
        type=bounded_word_len,
        help='The length of the next longest run of the tail char.'
    )
    parser.add_argument(
        '--alphabet',
        type=bounded_alphabet,
        default='ab',
        help='The alphabet to use'
    )
    parser.add_argument(
        '--other-pattern',
        type=InterlaceOtherType,
        choices=list(InterlaceOtherType),
        default=InterlaceOtherType.ONE_PLUS_ZERO,
        help='The pattern for the other word.'
    )
    parser.add_argument(
        '--reverse',
        action='store_true',
        default=False,
        help='Make the longest run happen at the start of the string.'
    )
    parser.set_defaults(
        generate_lines=generate_interlace_lines,
        validate=validate_interlace_args,
    )


def generate_random_lines(args):
    """Generate ``args.num_cases`` lines of two independent random words.

    When ``args.alphabet`` is None, each line gets its own random alphabet
    (a random-size sample of ALPHABET). Each word length is drawn uniformly
    from ``[args.word_len_min, args.word_len_max]``.
    """
    def one_line():
        chosen = args.alphabet
        if chosen is None:
            size = RNG.randint(1, len(ALPHABET))
            chosen = ''.join(RNG.sample(ALPHABET, k=size))
        first_len = RNG.randint(args.word_len_min, args.word_len_max)
        second_len = RNG.randint(args.word_len_min, args.word_len_max)
        first = _random_word(chosen, first_len)
        second = _random_word(chosen, second_len)
        return f'{chosen} {first} {second}'

    return [one_line() for _ in range(args.num_cases)]


def validate_random_args(args):
    """Reject random-mode options whose word-length range is empty.

    Raises:
        argparse.ArgumentTypeError: if ``args.word_len_min`` exceeds
            ``args.word_len_max``.
    """
    shortest, longest = args.word_len_min, args.word_len_max
    if shortest <= longest:
        return
    raise argparse.ArgumentTypeError('word_len_max must be greater than or equal to word_len_min')


def _setup_random_parser(parser):
    """Register the ``random`` options and callbacks on *parser*.

    Adds --alphabet (default None, meaning a fresh random alphabet per case)
    plus --word-len-min / --word-len-max. Stores generate_random_lines and
    validate_random_args as the ``generate_lines`` and ``validate`` defaults.
    """
    parser.add_argument('--alphabet', type=bounded_alphabet, default=None, help='The alphabet to use.')
    length_specs = (
        ('--word-len-min', 1, 'The shortest word length to generate.'),
        ('--word-len-max', WORD_LEN_MAX, 'The largest word length to generate.'),
    )
    for flag, default_value, help_text in length_specs:
        parser.add_argument(flag, type=bounded_word_len, default=default_value, help=help_text)
    parser.set_defaults(
        validate=validate_random_args,
        generate_lines=generate_random_lines,
    )


def _parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace. The chosen subcommand contributes
        ``generate_lines`` and ``validate`` callbacks via ``set_defaults``.

    --num-cases, --seed and --test-name belong to the top-level parser, so on
    the command line they must come before the subcommand name.
    """
    # Let argparse derive the program name from sys.argv[0]. The former
    # ArgumentParser('') set prog to '', producing "usage:  [-h] ...".
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='subparser_name')
    # Without a subcommand, the namespace lacks validate/generate_lines and
    # main() failed with AttributeError; have argparse report it instead.
    # The attribute form also works on Pythons without the required= keyword.
    subparsers.required = True

    substrings_parser = subparsers.add_parser('substrings')
    _setup_substrings_parser(substrings_parser)

    interlace_parser = subparsers.add_parser('interlace')
    _setup_interlace_parser(interlace_parser)

    random_parser = subparsers.add_parser('random')
    _setup_random_parser(random_parser)

    parser.add_argument(
        '--num-cases',
        type=bounded_cases,
        default=DEFAULT_CASES,
        help='The number of cases to generate.',
    )
    parser.add_argument(
        # The default is drawn once here; main() records it in the .desc file.
        '--seed', type=int, default=random.randint(0, 10000),
        help='The random number to seed the random number generator with.'
    )
    parser.add_argument(
        '--test-name', type=str,
        help='The name for the test case. E.g., "025-small-cases" will produce files '
             '025-small-cases.in and 025-small-cases.desc. If no name is specified, '
             'output will be printed to stdout.'
    )
    return parser.parse_args(argv)


def main():
    """Parse arguments, generate the cases, and emit them.

    The output is the number of cases on the first line, then one case per
    line. With --test-name NAME, the output is written to NAME.in and NAME.desc
    records a command that reproduces it, with --seed made explicit.
    Otherwise the output goes to stdout.
    """
    global RNG
    args = _parse_args()
    args.validate(args)
    RNG = random.Random(args.seed)

    case_lines = args.generate_lines(args)
    output = '\n'.join([f'{len(case_lines)}'] + case_lines) + '\n'

    argv_list = list(sys.argv)
    # Make the seed explicit so the recorded command is reproducible.
    # Recognise both "--seed N" and "--seed=N" so no redundant seed is added.
    seed_given = any(arg == '--seed' or arg.startswith('--seed=') for arg in argv_list[1:])
    if not seed_given:
        argv_list = [argv_list[0], '--seed', str(args.seed)] + argv_list[1:]
    # Quote each argument so the command can be pasted back into a shell even
    # when it contains spaces or metacharacters.
    command = ' '.join(shlex.quote(arg) for arg in argv_list)
    if args.test_name is not None:
        test_input_file_name = args.test_name + '.in'
        test_desc_file_name = args.test_name + '.desc'
        with open(test_input_file_name, 'w') as test_input_file:
            test_input_file.write(output)
        with open(test_desc_file_name, 'w') as test_desc_file:
            test_desc_file.write(f'Produced by:\n\t{command}\n')
    else:
        sys.stdout.write(output)


# Generate cases only when run as a script, not when imported.
if __name__ == '__main__':
    main()
