import argparse
from exp_main import Exp_main

# Command-line parser for the training / testing script.
parse = argparse.ArgumentParser()

# training
# %(default)s is expanded by argparse, so the help text cannot drift from the actual default.
parse.add_argument('--batch_size', type=int, default=1024,
                   help='Setting the batch size, the default value is %(default)s.')
parse.add_argument('--learning_rate', default=1e-4, type=float,
                   help='Setting the learning rate, the default value is 0.0001.')
parse.add_argument('--epochs', default=50, type=int, help='Setting the epochs, the default value is 50.')
parse.add_argument('--valid_flag', default=1, type=int,
                   help='The validation flag, the default value is %(default)s.')
parse.add_argument('--keep_num', default=10, type=int, help='The number of keep model.')
parse.add_argument('--patience', default=5, type=int,
                   help='The patience value, the default value is %(default)s.')
parse.add_argument('--momentum', default=0.99, type=float,
                   help='The momentum value, the default value is %(default)s.')
parse.add_argument('--dropout', default=1.0, type=float,
                   help='The dropout value, the default value is %(default)s.')
parse.add_argument('--runs', default=2, type=int, help='The number of runs times.')
parse.add_argument('--norm_type', default=0, type=int, help='The norm type.')
parse.add_argument('--dropout_type', default=0, type=int, help='The dropout type.')
parse.add_argument('--optimizer', default='SGD', type=str, help='SGD Or Adam.')

# model structure
parse.add_argument('--batch_normalized', default=0, type=int, help='The batch normalized flag. 1: batch normalized.')
parse.add_argument('--input_size', default=1, type=int, help='The input size of model.')
parse.add_argument('--seq_len', default=25, type=int, help='The sequence length of model.')
parse.add_argument('--output_size', default=1, type=int, help='The output size of model.')
parse.add_argument('--outputs_type', default='mean', type=str, help='The type of outputs. mean or last.')
parse.add_argument('--shuffle', default=1, type=int, help='The input shuffle flag. 1: shuffle;')
parse.add_argument('--is_training', default=1, type=int, help='training flag.')

# CgRNN
parse.add_argument('--num_proj', default=[1], type=int, nargs='+', help='The reduction dimension of CgRNN.')
parse.add_argument('--num_units', default=[256], type=int, nargs='+', help='The size of CgRNN units.')
parse.add_argument('--abs_layer', default=1, type=int, help='The abs layer setting of CgRNN.')
parse.add_argument('--dense', default=[], type=int, nargs='+',
                   help='The dense layer setting of CgRNN (list of ints).')
parse.add_argument('--complex_dense', default=[], type=int, nargs='+',
                   help='The complex dense layer setting of CgRNN (list of ints).')
parse.add_argument('--complex_inout', default=1, type=int,
                   help='The complex in/out flag of CgRNN.')
parse.add_argument('--single_gate', default=0, type=int, help='The number of gates, SINGLE or DOUBLE.')
parse.add_argument('--dim_reduce', default=1, type=int, help='The CgRNN whether run dimension reduction.')
parse.add_argument('--memory_hh_bn', default=0, type=int,
                   help='The gates of CgRNN whether use the batch normalization.')
parse.add_argument('--memory_hx_bn', default=0, type=int,
                   help='The gates of CgRNN whether use the batch normalization.')
parse.add_argument('--memory_bn', default=0, type=int, help='The gates of CgRNN whether use the batch normalization.')
parse.add_argument('--canditate_hx_bn', default=0, type=int,
                   help='The candidate hx of CgRNN whether use the batch normalization.')
parse.add_argument('--canditate_hh_bn', default=0, type=int,
                   help='The candidate hh of CgRNN whether use the batch normalization.')

# Read Data path
parse.add_argument('--disk', type=int, default=0,
                   help='The disk setting.')
parse.add_argument('--dataset_path', type=str,
                   default='/home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/partition_data/',
                   help='The path of dataset')
parse.add_argument('--train_dataset', type=str, default='complex_train_test_normalized_std',
                   help='The name of train dataset (the dataset included the part of validating dataset)')
parse.add_argument('--test_dataset', type=str, default='complex_test_in_vivo_normalized_std',
                   help='The name of test dataset')

parse.add_argument('--train_list', type=str, default='./list/train_index.txt',
                   help='The list of train dataset from disk.')
parse.add_argument('--valid_list', type=str, default='./list/valid_index.txt',
                   help='The list of validate dataset from disk.')
parse.add_argument('--test_list', type=str, default='./list/test_index.txt',
                   help='The list of test dataset from disk.')

# Save path
parse.add_argument('--results_path', default='./results/', type=str, help='The path of saved results.')
parse.add_argument('--checkpoint', default='./checkpoints/', type=str, help='The path of saved checkpoints.')

# Test
# Optional explicit setting name to test; when omitted, the setting name is
# built from the other arguments.
parse.add_argument('--test_setting', default=None, type=str,
                   help='The setting name of the model to test (default: built from the arguments).')

## Load Parameters of program
# Parse the command line and echo the full parsed configuration, so the run
# log shows every argument in use (previously only debug output was printed:
# the type of num_units and the dense list).
print('************** args ***********')
args = parse.parse_args()
print(args)
print('********************')

def list_to_string(list_):
    """Return the items of ``list_`` joined by underscores.

    Each item is converted with ``str`` first, so ``[64, 1]`` becomes
    ``'64_1'``; an empty iterable gives ``''``.
    """
    return '_'.join(map(str, list_))


def _build_setting(run_args, run_idx):
    """Build the experiment setting name for one run.

    The name encodes the main hyper-parameters from ``run_args`` plus the run
    index; it is passed to ``Exp_main`` and to its ``train``/``test`` methods.

    Args:
        run_args: the parsed ``argparse.Namespace``.
        run_idx: index of the current run.

    Returns:
        The setting name string.
    """
    template = ('{}_disk{}_out_{}_norm_type{}_abs{}_dr{}_dropout{}_lr{}_units{}_proj{}'
                '_den{}_cden{}_sg{}_mb{}_chb{}_me{}_sf{}_run{}')
    return template.format(
        run_args.optimizer,
        run_args.disk,
        run_args.outputs_type,
        run_args.norm_type,
        run_args.abs_layer,
        run_args.dim_reduce,
        run_args.dropout,
        run_args.learning_rate,
        list_to_string(run_args.num_units),
        list_to_string(run_args.num_proj),
        list_to_string(run_args.dense),
        list_to_string(run_args.complex_dense),
        run_args.single_gate,
        run_args.memory_bn,
        run_args.canditate_hh_bn,
        # round() instead of int(): int() truncates floating-point error,
        # e.g. 0.29 * 100 == 28.999999999999996 would be encoded as 28.
        round(run_args.momentum * 100),
        run_args.shuffle,
        run_idx,
    )


if args.is_training:
    for i in range(args.runs):
        setting = _build_setting(args, i)
        print('>>>>>>>>>>>>> ', setting, ' >>>>>>>>>>>>>')
        exp = Exp_main(args, setting)
        print('>>>>>>>>>>>>>> training >>>>>>>>>>>>>>')
        exp.train(setting)
        print('>>>>>>>>>>>>>> test >>>>>>>>>>>>>>')
        exp.test(setting)
else:
    # An explicit --test_setting selects a specific trained model; otherwise the
    # model named by the current arguments is tested. getattr keeps this working
    # if the parser does not define --test_setting. (A hard-coded setting name
    # used to override the computed one here, ignoring all CLI arguments.)
    test_setting = getattr(args, 'test_setting', None)
    for i in range(args.runs):
        setting = test_setting or _build_setting(args, i)
        print('>>>>>>>>>>>>> ', setting, ' >>>>>>>>>>>>>')
        exp = Exp_main(args, setting)
        print('>>>>>>>>>>>>>> test')
        exp.test(setting)
