import os
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from scipy import io
import math


class DataLoader():
    """Serve mini-batches of sequence data loaded from MATLAB ``.mat`` files.

    Arrays are kept as attributes and handed out in contiguous batches by
    :meth:`get_train_batch` / :meth:`get_test_batch`, which advance an
    internal cursor (``self.index``) that :meth:`reset` rewinds.

    Layouts implied by the indexing in this class:
        * ``train_inputs``: ``[train_len, seq_len, input_size]``
        * ``train_labels``: ``[train_len, output_size]``
        * ``test_inputs``:  ``[test_num, test_len, seq_len, input_size]``
        * ``test_label``:   ``[test_num, test_len, output_size]``
    """

    def __init__(self, config):
        """Create a loader.

        Args:
            config: mapping with at least ``'batch_size'`` (batch length used
                by the batch getters) and ``'batch_normalized'`` (stored as
                ``batch_normalized_flag``; the batch getters in this class
                currently do not apply it).
        """
        self.batch_size = config['batch_size']
        self.batch_normalized_flag = config['batch_normalized']
        # Same default as set_dataset_partition(), so partition-dependent
        # methods work even before the partition is chosen explicitly.
        self.dataset_partition = 'train'

        self.reset()

    def norm_data(self):
        """Scale the active partition's arrays in place (attribute rebinding).

        Inputs are divided by ``self.input_norm`` and labels by
        ``self.MV_norm``. Neither is assigned anywhere in this class, so the
        caller must set them first. A tiny epsilon guards against division
        by zero. Unknown partitions are left untouched.
        """
        eps = 1e-30
        if self.dataset_partition == 'train':
            self.train_inputs = self.train_inputs / (self.input_norm + eps)
            self.train_labels = self.train_labels / (self.MV_norm + eps)
        elif self.dataset_partition == 'test':
            # Normalise the test arrays themselves. Labels live in
            # ``test_label``, the attribute the loaders set and
            # get_test_batch reads.
            self.test_inputs = self.test_inputs / (self.input_norm + eps)
            self.test_label = self.test_label / (self.MV_norm + eps)
            # Keep the plural name populated for any external callers.
            self.test_labels = self.test_label

    def set_dataset_partition(self, mode='train'):
        """Select which partition (``'train'`` or ``'test'``) is active."""
        self.dataset_partition = mode

    def reset(self):
        """Rewind the batch cursor to the first batch."""
        self.index = 0

    def train_shuffle(self):
        """Shuffle train inputs and labels in unison (global NumPy RNG)."""
        train_index = np.arange(self.train_len)
        np.random.shuffle(train_index)
        self.train_inputs = self.train_inputs[train_index, :, :]
        self.train_labels = self.train_labels[train_index, :]

    def get_iterations(self):
        """Return the number of batches needed to cover the active partition.

        Raises:
            ValueError: if the partition is neither ``'train'`` nor ``'test'``.
        """
        if self.dataset_partition == 'train':
            return math.ceil(self.train_len / self.batch_size)
        if self.dataset_partition == 'test':
            return math.ceil(self.test_len / self.batch_size)
        raise ValueError(
            'Error loading model: unknown dataset partition {!r}.'.format(
                self.dataset_partition))

    def generate_train_test_seed(self, rate=0.7):
        """Randomly split ``range(train_len)`` into train/test index arrays.

        Args:
            rate: fraction intended for training. The test share is
                ``ceil(train_len * (1 - rate))`` indices; the rest go to
                ``self.train_seed``.
        """
        self.train_test_rate = rate

        random_list = np.arange(self.train_len)
        np.random.shuffle(random_list)

        n_test = int(np.ceil(self.train_len * (1 - rate)))
        self.train_seed = random_list[n_test:]
        self.test_seed = random_list[:n_test]

    def batch_normalized(self, data):
        """Divide ``data`` by its sample standard deviation (ddof=1, all elements)."""
        data = data / np.std(data, ddof=1)
        return data

    def load_dataset(self, data_path):
        """Load the active partition's arrays from the ``.mat`` file at ``data_path``.

        'train' reads ``train_inputs``/``train_labels``; 'test' reads
        ``test_inputs``/``test_labels`` (stored as ``test_label``) and
        ``test_name``. Shape attributes (lengths, ``seq_len``,
        ``input_size``, ``output_size``) are derived from the loaded arrays.
        """
        if self.dataset_partition == 'train':
            # inputs: [train_len, time_steps, input_size]
            # labels: [train_len, input_size]
            data = io.loadmat(data_path)
            self.train_inputs = data['train_inputs']
            self.train_labels = data['train_labels']
            self.train_len = self.train_inputs.shape[0]
            self.seq_len = self.train_inputs.shape[1]  # The time steps
            self.input_size = self.train_inputs.shape[2]  # The size of inputs
            self.output_size = self.input_size  # The size of outputs

        elif self.dataset_partition == 'test':
            data = io.loadmat(data_path)
            # inputs: [n_figure, test_len, time_step, input_data]
            # labels: [n_sample, test_len, input_data]
            self.test_inputs = data['test_inputs']
            self.test_label = data['test_labels']
            self.test_name = data['test_name']

            self.test_num = self.test_inputs.shape[0]  # The number of test images
            self.test_len = self.test_inputs.shape[1]
            self.seq_len = self.test_inputs.shape[2]  # The time steps
            self.input_size = self.test_inputs.shape[-1]
            self.output_size = self.input_size  # The size of outputs

    def load_ultrasonic_data_from_inputs_label(self, data_path):
        """Load both train and test arrays from one ``.mat`` file at ``data_path``."""
        print('Load Path:', data_path)
        data = io.loadmat(data_path)
        self.train_inputs = data['train_inputs']
        self.train_labels = data['train_labels']

        self.test_inputs = data['test_inputs']
        self.test_label = data['test_labels']
        self.test_name = data['test_name']

        self.test_num = self.test_inputs.shape[0]  # The number of test images
        self.train_len = self.train_inputs.shape[0]  # The number of train inputs
        self.test_len = self.test_inputs.shape[1]  # The number of test inputs

        self.input_size = self.train_inputs.shape[2]  # The size of inputs
        self.seq_len = self.train_inputs.shape[1]  # The time steps
        self.num_seq = self.train_inputs.shape[0]  # The number of train inputs
        self.output_size = self.input_size  # The size of outputs

    def train_dataset_shuffle(self, random_seed=1000):
        """Reseed the global NumPy RNG with ``random_seed`` and shuffle the train set.

        Note: this reseeds ``np.random`` globally, which affects all later
        draws from the global RNG.
        """
        np.random.seed(random_seed)
        self.train_shuffle()

    def get_train_batch(self):
        """Return the next ``(inputs, labels)`` train batch and advance the cursor.

        The final batch is truncated to the remaining samples.
        """
        start_index = self.index * self.batch_size
        batch_size = min(self.batch_size, self.train_len - start_index)
        end_index = start_index + batch_size
        batch_input = self.train_inputs[start_index: end_index, :, :]
        batch_label = self.train_labels[start_index: end_index, :]
        self.index += 1
        # Kept so callers can inspect the most recently served batch.
        self.last_batch_input = batch_input
        self.last_batch_label = batch_label
        return batch_input, batch_label

    def get_test_batch(self, num):
        """Return the next ``(inputs, labels)`` batch of test group ``num``.

        Args:
            num: index along the first axis of ``test_inputs``/``test_label``.

        Returns:
            inputs shaped ``(batch, seq_len, input_size)`` and labels shaped
            ``(batch, output_size)``. The final batch is truncated to the
            remaining samples.
        """
        start_index = self.index * self.batch_size
        batch_size = min(self.batch_size, self.test_len - start_index)
        end_index = start_index + batch_size
        batch_input = self.test_inputs[num, start_index: end_index, :, :].reshape(
            batch_size, self.seq_len, self.input_size)
        batch_label = self.test_label[num, start_index: end_index, :].reshape(
            batch_size, self.output_size)

        self.index += 1
        return batch_input, batch_label
