import os
import time
import json
import shutil
import numpy as np
import tensorflow as tf
from scipy import io as sci_io
import matplotlib.pyplot as plt
from network.build_model_vector import ComplexRNN
from data_loader.data_loader_vector import DataLoader as DataLoader
from data_loader.data_loader_vector_disk import DataLoader as DataLoader_disk
from network.utils import recover_image, record_json, save_convergence_figure,EarlyStopping


class Exp_main(object):
    """Experiment driver for the ComplexRNN model.

    Builds the TF1 graph at construction time and binds the public ``train`` /
    ``test`` entry points either to the in-memory pipeline (``DataLoader``,
    datasets under ``args.dataset_path``) or to the disk pipeline
    (``DataLoader_disk``, file lists ``args.train_list`` / ``args.valid_list`` /
    ``args.test_list``), depending on ``args.disk``.
    """

    # Shape (nz, nx) of the image rebuilt from the per-sample test outputs
    # before it is written to a .mat file.
    IMAGE_NZ = 1024
    IMAGE_NX = 96

    def __init__(self, args, setting):
        """Store ``args``, derive result paths, build the model and pick train/test variants.

        Args:
            args: experiment namespace (argparse-style). It is mutated in place
                to gain ``train_path`` and ``test_path``.
            setting: experiment name, used as a sub-folder of
                ``args.results_path`` and ``args.checkpoint``.
        """
        self.args = args
        self.args.train_path = os.path.join(self.args.results_path, setting, 'train')
        self.args.test_path = os.path.join(self.args.results_path, setting, 'test')
        self._init_parameter_json(setting)
        # Legacy flag kept for backward compatibility; nothing in this class reads it.
        self.early_stopping_flag = 0
        self._build_modle()
        # Called once per epoch with the validation loss; exposes ``early_stop``.
        # It receives the session, model and checkpoint dir, so it presumably
        # also checkpoints the best models -- TODO confirm in network.utils.
        self.early_stopping = EarlyStopping(patience=self.args.patience, keep_num=self.args.keep_num)
        # Best ``keep_num`` validation losses in ascending order (see _save_model).
        self.loss_list = np.full(self.args.keep_num, np.inf, dtype=np.float32)
        if self.args.disk:
            self.train = self.train_disk
            self.test = self.test_disk
        else:
            self.train = self.train_memory
            self.test = self.test_memory

    def _init_parameter_json(self, setting):
        """Build ``model_params``, ``train_params`` and ``data_loader_config`` from ``self.args``."""
        self.model_params = {
            'n_units': self.args.num_units,
            'num_proj': self.args.num_proj,
            'input_seq': self.args.input_size,  # input size
            'seq_len': self.args.seq_len,  # the number of plane wave
            'output_size': self.args.output_size,
            'train_test': 'train',
            'dense': [],
            'keep_num': self.args.keep_num,
            'bn': self.args.batch_normalized
        }
        self.train_params = {
            'epochs': self.args.epochs,
            'batch_size': self.args.batch_size,
            'learning_rate': self.args.learning_rate,
            'optimizer': self.args.optimizer,  # tf.train.GradientDescentOptimizer OR AdamOptimizer
            'outputs_type': self.args.outputs_type,  # mean Or last
            'dropout': self.args.dropout,
            'train_path': os.path.join(self.args.results_path, setting, 'train'),
            'test_path': os.path.join(self.args.results_path, setting, 'test'),
        }
        # The sizes must come from the matching args (as in model_params); they
        # previously all copied ``batch_size`` by mistake.
        self.data_loader_config = {
            'batch_size': self.args.batch_size,
            'shuffle': self.args.shuffle,
            'batch_normalized': self.args.batch_normalized,
            'input_size': self.args.input_size,
            'seq_len': self.args.seq_len,
            'output_size': self.args.output_size,
            'norm_type': self.args.norm_type,
        }

    @staticmethod
    def _session_config():
        """Return the ``tf.ConfigProto`` shared by every session (GPU 0, memory growth on)."""
        gpu_options = tf.GPUOptions(allow_growth=True,
                                    visible_device_list=str(0),
                                    per_process_gpu_memory_fraction=1.0)
        return tf.ConfigProto(allow_soft_placement=True,
                              log_device_placement=False,
                              gpu_options=gpu_options)

    def _checkpoint_paths(self, setting):
        """Return every checkpoint path recorded in ``<args.checkpoint>/<setting>``.

        Raises:
            FileNotFoundError: when TensorFlow finds no checkpoint state there
                (``get_checkpoint_state`` returns None).
        """
        ckpt_dir = os.path.join(self.args.checkpoint, setting)
        cpkt = tf.compat.v1.train.get_checkpoint_state(ckpt_dir)
        if cpkt is None:
            raise FileNotFoundError('No checkpoint state found in %s' % ckpt_dir)
        return cpkt.all_model_checkpoint_paths

    def _save_json(self, setting):
        """Record the experiment configuration in ``<args.checkpoint>/<setting>``.

        Writes ``model_params.json``, ``train_params.json`` and ``args.txt``.
        It also replaces ``network/`` with a fresh copy of ``./network``.
        """
        save_model_folder = os.path.join(self.args.checkpoint, setting)
        os.makedirs(save_model_folder, exist_ok=True)

        # Save the model parameters
        with open(os.path.join(save_model_folder, 'model_params.json'), 'w') as save_f:
            json.dump(self.model_params, save_f, indent=4, separators=(',', ':'))

        # Save the train parameters
        with open(os.path.join(save_model_folder, 'train_params.json'), 'w') as save_f:
            json.dump(self.train_params, save_f, indent=4, separators=(',', ':'))

        # Save the args
        with open(os.path.join(save_model_folder, 'args.txt'), 'w') as f:
            f.write('------------------ start ------------------' + '\n')
            for each_arg, value in vars(self.args).items():
                f.write(each_arg + ' : ' + str(value) + '\n')
            f.write('------------------- end -------------------')

        # Save the network structures; drop any stale copy from a previous run.
        net_save_path = os.path.join(save_model_folder, 'network')
        if os.path.exists(net_save_path):
            shutil.rmtree(net_save_path)
        shutil.copytree('./network', net_save_path)

    def save_convergence_figure(self, train_loss, valid_loss, save_path):
        """Plot per-epoch train/valid loss curves and save the figure to ``save_path``."""
        plt.figure()
        plt.plot(train_loss, c='#00BFFF', label='train')
        plt.plot(valid_loss, c='r', label='valid')
        plt.legend()
        plt.savefig(save_path)
        plt.clf()
        plt.close()

    def _build_modle(self):
        """Build the ComplexRNN model from ``self.args`` and keep a handle to its graph."""
        self.CgRNN_model = ComplexRNN(self.args)
        self.CgRNN_model.build()
        self.graph = self.CgRNN_model.graph

    def recover_image(self, pre_outputs, true_outputs, nz, nx):
        """Reassemble flat outputs into ``(nz, nx)`` float64 images, row by row.

        Image element ``[i, j]`` is flat position ``i * nx + j`` of the output
        array read in C (row-major) order. In other words, it is the first
        ``nz * nx`` values of ``outputs.reshape(-1)``.

        Returns:
            (pre_image, true_image), both new ``float64`` arrays of shape (nz, nx).
        """
        def _to_image(values):
            flat = np.asarray(values).reshape(-1)[:nz * nx]
            return flat.reshape(nz, nx).astype(np.float64)

        return _to_image(pre_outputs), _to_image(true_outputs)

    def recover_image_disk(self, pre_outputs, true_outputs, nz, nx):
        """Reassemble column 0 of the outputs into ``(nz, nx)`` float64 images, column by column.

        Image element ``[i, j]`` is row ``j * nz + i`` of ``outputs[:, 0]``.
        In other words, the first ``nz * nx`` samples fill the image in Fortran
        (column-major) order.

        Returns:
            (pre_image, true_image), both new ``float64`` arrays of shape (nz, nx).
        """
        def _to_image(values):
            column = np.asarray(values)[:nz * nx, 0]
            return np.ascontiguousarray(column.reshape((nz, nx), order='F'), dtype=np.float64)

        return _to_image(pre_outputs), _to_image(true_outputs)

    def _save_model(self, valid_loss, sess, setting, epoch):
        """Checkpoint the model when ``valid_loss`` beats the worst of the ``keep_num`` best losses.

        ``self.loss_list`` stays sorted ascending. The new loss is inserted at
        its rank, and the previous worst entry falls off the end.
        """
        save_model_folder = os.path.join(self.args.checkpoint, setting)
        os.makedirs(save_model_folder, exist_ok=True)
        saved_path_model = os.path.join(save_model_folder, 'Model')

        if valid_loss < self.loss_list[-1]:
            print('*********** Save Model ***********')
            first_index = np.where(valid_loss < self.loss_list)[0][0]
            # Shift the tail right by one slot, discarding the current worst loss.
            self.loss_list[first_index + 1:] = self.loss_list[first_index:-1].copy()
            self.loss_list[first_index] = valid_loss
            self.CgRNN_model.saver.save(sess, saved_path_model, global_step=epoch)

    def _report_iteration(self, epoch, i, iterations, train_loss_iteration, time_iteration):
        """Print a progress line every 100 iterations (never at iteration 0).

        Returns the start time of the next reporting window. This is a fresh
        timestamp when a line was printed, otherwise ``time_iteration`` unchanged.
        """
        if i == 0 or i % 100 != 0:
            return time_iteration
        # Approximate: time of the last ~100 iterations scaled to all remaining iterations.
        left_time = (time.time() - time_iteration) * ((self.args.epochs - epoch) * iterations - i) / 100
        train_p_line = 'Iteration: {} ---- Loss: {:.4} ---- Time: {:.6} ---- Left time:{:.6}'.format(
            i, np.mean(train_loss_iteration), time.time() - time_iteration, left_time)
        next_start = time.time()
        print(train_p_line)
        return next_start

    def _end_epoch(self, sess, setting, epoch, train_loss, valid_loss, time_epoch, train_print_f):
        """Log the epoch summary, run early stopping and return True if training should stop.

        ``valid_loss`` must be a float: the ``{:.4}`` format spec rejects ints.
        """
        left_time = (time.time() - time_epoch) * (self.args.epochs - epoch - 1)
        train_p_line = 'Epoch: {} ---- Train Loss: {:.4} ---- Valid Loss: {:.4} ---- Time: {:.6} ---- Left time:{:.6}'.format(
            epoch, train_loss, valid_loss, time.time() - time_epoch, left_time)
        print(train_p_line)
        with open(train_print_f, 'a') as train_f:
            print(train_p_line, file=train_f)

            # Save the trained model / update the early-stopping state.
            self.early_stopping(valid_loss, sess, self.CgRNN_model, setting, self.args.checkpoint, epoch)

            if self.early_stopping.early_stop:
                print("Early stopping")
                print("Early stopping", file=train_f)
                return True
        return False

    def train_memory(self, setting):
        """Train with the whole training set loaded in memory (``DataLoader``).

        For each epoch: shuffle, run every training batch, optionally validate
        (``args.valid_flag``), log to ``<results_path>/<setting>/train_print.txt``
        and let ``self.early_stopping`` decide whether to stop. At the end the
        loss curves are saved to ``<checkpoint>/<setting>/convegence.png``.
        """
        with tf.Session(graph=self.graph, config=self._session_config()) as sess:

            # Save the parameters and args
            self._save_json(setting)

            # Initialization global variables
            sess.run(tf.compat.v1.global_variables_initializer())

            # Load data
            data_loader_train = DataLoader(self.data_loader_config)
            data_loader_train.set_dataset_partition('train')
            data_loader_train.load_dataset(os.path.join(self.args.dataset_path, self.args.train_dataset))

            # ************** Train ***************
            train_time = time.time()
            print('****** Train *******')
            self.CgRNN_model.set_train_test('train')

            epoch_valid_loss = []  # convergence valid data
            epoch_train_loss = []  # convergence train data
            train_print_f = os.path.join(self.args.results_path, setting, 'train_print.txt')
            os.makedirs(os.path.dirname(train_print_f), exist_ok=True)
            self.CgRNN_model.train_writer.add_graph(sess.graph)

            output_feed = [self.CgRNN_model.opt,
                           self.CgRNN_model.loss_summary,
                           self.CgRNN_model.loss]
            for epoch in range(self.args.epochs):

                # Reshuffle and rewind the training split for this epoch.
                data_loader_train.set_dataset_partition('train')
                data_loader_train.train_shuffle()
                data_loader_train.reset()
                iterations = data_loader_train.get_iterations()

                train_loss_iteration = []
                time_iteration = time.time()
                time_epoch = time.time()
                for i in range(iterations):
                    x, y = data_loader_train.get_train_batch()

                    feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                                 self.CgRNN_model.momentum: self.args.momentum,
                                 self.CgRNN_model.dropout: self.args.dropout,
                                 self.CgRNN_model.training: self.args.is_training}

                    _, loss_summary, loss = sess.run(output_feed, feed_dict=feed_dict)
                    train_loss_iteration.append(loss)
                    self.CgRNN_model.train_writer.add_summary(loss_summary, epoch * iterations + i)
                    time_iteration = self._report_iteration(epoch, i, iterations,
                                                            train_loss_iteration, time_iteration)

                if self.args.valid_flag:
                    valid_loss = self.vaild_memory(sess)
                    epoch_valid_loss.append(valid_loss)
                else:
                    # Float, not int: '{:.4}' raises ValueError on ints.
                    # NOTE(review): early stopping then sees a constant loss.
                    valid_loss = 0.0
                epoch_train_loss.append(np.mean(train_loss_iteration))

                if self._end_epoch(sess, setting, epoch, epoch_train_loss[-1], valid_loss,
                                   time_epoch, train_print_f):
                    break

            # Save the convergence figure
            self.save_convergence_figure(epoch_train_loss, epoch_valid_loss,
                                         os.path.join(self.args.checkpoint, setting, 'convegence.png'))
            cost = time.time() - train_time
            print('Train cost-time: %.4f' % cost)

    def vaild_memory(self, sess):
        """Return the mean loss over the 'test' partition of ``args.train_dataset``.

        The loss is averaged per test item first, then across items. The
        ``dropout`` placeholder is fed 1.0 (presumably a keep-probability) and
        ``training`` is False.
        """
        valid_time = time.time()
        ## Load valid data
        data_loader_valid = DataLoader(self.data_loader_config)
        data_loader_valid.set_dataset_partition('test')
        data_loader_valid.load_dataset(os.path.join(self.args.dataset_path, self.args.train_dataset))

        iterations = data_loader_valid.get_iterations()
        test_num = data_loader_valid.test_num
        output_feed = [self.CgRNN_model.loss, self.CgRNN_model.outputs]
        total_valid_loss = []

        for num in range(test_num):
            data_loader_valid.reset()
            valid_loss = []
            for _ in range(iterations):
                x, y = data_loader_valid.get_test_batch(num)
                feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                             self.CgRNN_model.momentum: self.args.momentum,
                             self.CgRNN_model.dropout: 1.0, self.CgRNN_model.training: False}
                loss, _ = sess.run(output_feed, feed_dict=feed_dict)
                valid_loss.append(loss)
            total_valid_loss.append(np.mean(valid_loss))
        print('Valid time: ', time.time() - valid_time)

        return np.mean(total_valid_loss)

    def test_memory(self, setting):
        """Evaluate every saved checkpoint on ``args.test_dataset`` (in-memory loader).

        For each model, losses are appended to ``<results_path>/<setting>/test_print.txt``.
        Each test item's predicted and true images are saved as
        ``model<k>/<test_name>.mat``.
        """
        test_time = time.time()
        with tf.Session(graph=self.graph, config=self._session_config()) as sess:

            ##  Set test data loader
            data_loader_test = DataLoader(self.data_loader_config)
            data_loader_test.set_dataset_partition('test')
            dataset_path = os.path.join(self.args.dataset_path, self.args.test_dataset)
            data_loader_test.load_dataset(dataset_path)
            iterations = data_loader_test.get_iterations()
            test_num = data_loader_test.test_num
            folder_path = os.path.join(self.args.results_path, setting)
            os.makedirs(folder_path, exist_ok=True)
            test_print_f = os.path.join(folder_path, 'test_print.txt')
            batch_size = self.args.batch_size
            output_feed = [self.CgRNN_model.loss, self.CgRNN_model.outputs]

            # Load the model
            print(self.args.checkpoint, setting)
            model_paths = self._checkpoint_paths(setting)

            for model_num, model_path in enumerate(model_paths):
                print('******* Prediction Model %d *******' % model_num)
                self.CgRNN_model.saver.restore(sess, model_path)

                # Record test logs
                with open(test_print_f, 'a') as test_f:
                    now_time = '\nlog: ' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + '\n'
                    test_f.write(now_time)
                    test_f.write(dataset_path + '\n')
                    test_f.write('model: ' + model_path + '\n')

                    for num in range(test_num):
                        test_name = data_loader_test.test_name[num]
                        predict_output = np.zeros((data_loader_test.test_len, data_loader_test.output_size))
                        true_y = np.zeros((data_loader_test.test_len, data_loader_test.output_size))

                        data_loader_test.reset()
                        predict_losses = []
                        predict_start_time = time.time()
                        for i in range(iterations):
                            x, y = data_loader_test.get_test_batch(num)
                            feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                                         self.CgRNN_model.momentum: self.args.momentum,
                                         self.CgRNN_model.dropout: 1.0, self.CgRNN_model.training: False}
                            loss, output = sess.run(output_feed, feed_dict=feed_dict)
                            # The slice end clamps at test_len, so the last partial batch needs no special case.
                            start = i * batch_size
                            predict_output[start:start + batch_size, :] = output
                            true_y[start:start + batch_size, :] = y
                            predict_losses.append(loss)

                        test_p_line = 'Test image: {} ---- Loss: {:.4} ---- Predict time: {:.6}'.format(
                            test_name, np.mean(predict_losses), time.time() - predict_start_time, )
                        print(test_p_line)
                        print(test_p_line, file=test_f)

                        # Save the test data
                        predict_image, true_image = self.recover_image(predict_output, true_y,
                                                                       self.IMAGE_NZ, self.IMAGE_NX)
                        model_save_path = os.path.join(folder_path, 'model%d' % model_num)
                        os.makedirs(model_save_path, exist_ok=True)
                        test_save_path = os.path.join(model_save_path, test_name.replace(' ', '') + '.mat')
                        sci_io.savemat(test_save_path, {'pre_outputs': predict_image, 'true_outputs': true_image, })

            print('Total Test time: ', time.time() - test_time)

    def train_disk(self, setting):
        """Train by streaming the files listed in ``args.train_list`` (``DataLoader_disk``).

        It follows the same epoch/validation/early-stopping flow as
        ``train_memory``. Validation uses ``vaild_disk``. The convergence
        figure is saved once after the loop ends, including after early stopping.
        """
        with tf.Session(graph=self.graph, config=self._session_config()) as sess:

            # ************** Train ***************
            train_time = time.time()

            print('****** Train Disk*******')
            # Save the parameters and args
            self._save_json(setting)

            # Initialization global variables
            sess.run(tf.compat.v1.global_variables_initializer())
            self.CgRNN_model.train_writer.add_graph(sess.graph)

            # Load data
            data_loader_train = DataLoader_disk(self.data_loader_config)
            data_loader_train.set_dataset_partition('train')
            data_loader_train.get_list(self.args.train_list)

            self.CgRNN_model.set_train_test('train')

            epoch_valid_loss = []  # convergence valid data
            epoch_train_loss = []  # convergence train data
            train_print_f = os.path.join(self.args.results_path, setting, 'train_print.txt')
            os.makedirs(os.path.dirname(train_print_f), exist_ok=True)

            output_feed = [self.CgRNN_model.opt,
                           self.CgRNN_model.loss_summary,
                           self.CgRNN_model.loss]
            for epoch in range(self.args.epochs):

                # Rewind the file list and load this epoch's data.
                data_loader_train.list_index = 0
                data_loader_train.get_data()
                iterations = data_loader_train.get_iterations()
                data_loader_train.reset()

                train_loss_iteration = []
                time_iteration = time.time()
                time_epoch = time.time()
                for i in range(iterations):
                    x, y = data_loader_train.get_batch()

                    feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                                 self.CgRNN_model.dropout: self.args.dropout,
                                 self.CgRNN_model.training: self.args.is_training}

                    _, loss_summary, loss = sess.run(output_feed, feed_dict=feed_dict)
                    train_loss_iteration.append(loss)
                    self.CgRNN_model.train_writer.add_summary(loss_summary, epoch * iterations + i)
                    time_iteration = self._report_iteration(epoch, i, iterations,
                                                            train_loss_iteration, time_iteration)

                if self.args.valid_flag:
                    valid_loss = self.vaild_disk(sess)
                    epoch_valid_loss.append(valid_loss)
                else:
                    # Float, not int: '{:.4}' raises ValueError on ints.
                    # NOTE(review): early stopping then sees a constant loss.
                    valid_loss = 0.0
                epoch_train_loss.append(np.mean(train_loss_iteration))

                if self._end_epoch(sess, setting, epoch, epoch_train_loss[-1], valid_loss,
                                   time_epoch, train_print_f):
                    break

            # Save the convergence figure
            self.save_convergence_figure(epoch_train_loss, epoch_valid_loss,
                                         os.path.join(self.args.checkpoint, setting, 'convegence.png'))
            cost = time.time() - train_time
            print('Train cost-time: %.4f' % cost)

    def vaild_disk(self, sess):
        """Return the mean loss over the files listed in ``args.valid_list``.

        The loss is averaged per file first, then across files. The ``dropout``
        placeholder is fed 1.0 (presumably a keep-probability) and ``training``
        is False.
        """
        ## Load valid data
        data_loader_valid = DataLoader_disk(self.data_loader_config)
        data_loader_valid.set_dataset_partition('valid')
        data_loader_valid.get_list(self.args.valid_list)

        test_num = data_loader_valid.list_len
        total_valid_loss = []
        data_loader_valid.list_index = 0
        output_feed = [self.CgRNN_model.loss, self.CgRNN_model.outputs]

        for _ in range(test_num):
            data_loader_valid.get_data()
            data_loader_valid.reset()
            iterations = data_loader_valid.get_iterations()
            predict_losses = []
            for _ in range(iterations):
                x, y = data_loader_valid.get_batch()
                feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                             self.CgRNN_model.dropout: 1.0,
                             self.CgRNN_model.training: False}
                loss, _ = sess.run(output_feed, feed_dict=feed_dict)
                predict_losses.append(loss)
            total_valid_loss.append(np.mean(predict_losses))
        return np.mean(total_valid_loss)

    def test_disk(self, setting):
        """Evaluate every saved checkpoint on the files listed in ``args.test_list``.

        For each model, losses are appended to ``<results_path>/<setting>/test_print.txt``.
        Each test file's predicted and true images are saved as
        ``model<k>/<name>.mat``.
        """
        test_time = time.time()
        model_paths = self._checkpoint_paths(setting)

        with tf.Session(graph=self.graph, config=self._session_config()) as sess:
            ##  Set test data loader
            data_loader_test = DataLoader_disk(self.data_loader_config)
            data_loader_test.set_dataset_partition('test')
            data_loader_test.get_list(self.args.test_list)

            folder_path = os.path.join(self.args.results_path, setting)
            os.makedirs(folder_path, exist_ok=True)
            test_print_f = os.path.join(folder_path, 'test_print.txt')
            batch_size = self.args.batch_size
            output_feed = [self.CgRNN_model.loss, self.CgRNN_model.outputs]

            for model_num, model_path in enumerate(model_paths):
                print('*******Prediction Model %d' % model_num)
                self.CgRNN_model.saver.restore(sess, model_path)

                test_num = data_loader_test.list_len
                # Rewind the file list so every checkpoint sees the same test files.
                data_loader_test.index = 0
                data_loader_test.list_index = 0

                # Record test logs (one handle per model, always closed).
                with open(test_print_f, 'a') as test_f:
                    now_time = '\nlog: ' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + '\n'
                    test_f.write(now_time)
                    test_f.write(self.args.test_list + '\n')
                    test_f.write('model: ' + model_path + '\n')

                    for _ in range(test_num):
                        data_loader_test.get_data()
                        test_name = data_loader_test.name[0]
                        test_iterations = data_loader_test.get_iterations()

                        predict_output = np.zeros((data_loader_test.len, data_loader_test.output_size))
                        true_y = np.zeros((data_loader_test.len, data_loader_test.output_size))

                        data_loader_test.reset()
                        predict_losses = []
                        predict_start_time = time.time()
                        for i in range(test_iterations):
                            x, y = data_loader_test.get_batch()
                            feed_dict = {self.CgRNN_model.input_data: x, self.CgRNN_model.true_data: y,
                                         self.CgRNN_model.dropout: 1.0, self.CgRNN_model.training: False}
                            loss, output = sess.run(output_feed, feed_dict=feed_dict)
                            # The slice end clamps at len, so the last partial batch needs no special case.
                            start = i * batch_size
                            predict_output[start:start + batch_size, :] = output
                            true_y[start:start + batch_size, :] = y
                            predict_losses.append(loss)

                        test_p_line = 'Test image: %s ---- Loss: %.4f ---- Predict time: %.6f' % (
                            test_name.ljust(30, ' '), np.mean(predict_losses), time.time() - predict_start_time)
                        print(test_p_line)
                        print(test_p_line, file=test_f)

                        # Save the test data
                        predict_image, true_image = self.recover_image_disk(predict_output, true_y,
                                                                            self.IMAGE_NZ, self.IMAGE_NX)
                        model_save_path = os.path.join(folder_path, 'model%d' % model_num)
                        os.makedirs(model_save_path, exist_ok=True)
                        test_save_path = os.path.join(model_save_path, test_name.replace(' ', '') + '.mat')
                        sci_io.savemat(test_save_path, {'pre_outputs': predict_image, 'true_outputs': true_image, })

            print('Total Test time: ', time.time() - test_time)
