import os
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from scipy import io
import math


class DataLoader:
    """Batch loader over a list of MATLAB ``.mat`` files.

    Typical use: ``get_list(path)`` reads a text file of ``.mat`` paths,
    ``set_dataset_partition(mode)`` selects ``'train'``/``'test'``/``'valid'``,
    ``get_data()`` loads and normalizes the next file, and ``get_batch()``
    serves mini-batches from it. In the ``'train'`` partition ``get_batch``
    automatically loads the next file once the current one is exhausted.

    ``set_dataset_partition`` must be called before ``get_data``,
    ``get_batch`` or ``get_iterations``: ``__init__`` does not set
    ``dataset_partition``.
    """

    def __init__(self, config):
        """Read loader settings from ``config``.

        Args:
            config: mapping with keys ``'batch_size'`` (samples per batch),
                ``'shuffle'`` (truthy -> shuffle the samples of each file
                loaded in the ``'train'`` partition) and ``'norm_type'``
                (scheme id passed to ``my_normalization``).
        """
        self.batch_size = config['batch_size']
        self.shuffle_flag = config['shuffle']
        self.norm_type = config['norm_type']

        # Not read anywhere in this class; kept for external callers.
        self.train_index = 0
        self.test_index = 0

        self.index = 0       # batch cursor within the currently loaded file
        self.list_index = 0  # position in self.lines of the next file to load

    def set_dataset_partition(self, mode='train'):
        """Select the partition: ``'train'``, ``'test'`` or ``'valid'``."""
        self.dataset_partition = mode

    def reset(self):
        """Rewind the batch cursor to the start of the loaded file."""
        self.index = 0

    def train_shuffle(self):
        """Shuffle the samples of the loaded file, keeping inputs/labels aligned."""
        order = np.random.permutation(self.len)
        self.inputs = self.inputs[order, :, :]
        self.labels = self.labels[order, :]

    def get_iterations(self):
        """Return the number of ``get_batch`` calls for one pass over the data.

        For ``'train'`` the count multiplies the sample count of the currently
        loaded file by ``list_len``, i.e. it assumes every listed file holds
        the same number of samples. For ``'test'`` and ``'valid'`` only the
        currently loaded file is counted.

        Raises:
            ValueError: if the partition is not ``'train'``, ``'test'`` or
                ``'valid'``.
        """
        if self.dataset_partition == 'train':
            return math.ceil(self.len * self.list_len / self.batch_size)
        if self.dataset_partition in ('test', 'valid'):
            return math.ceil(self.len / self.batch_size)
        raise ValueError(
            'Error loading model: unknown dataset partition {!r}'.format(
                self.dataset_partition))

    def get_list(self, path):
        """Read the list of ``.mat`` paths (one per line) from ``path``.

        Blank lines are skipped, so a trailing empty line is not counted as a
        file (``get_data`` would otherwise try to load an empty path).
        """
        with open(path) as f:
            self.lines = [line for line in f if line.strip()]
        self.list_len = len(self.lines)  # number of listed .mat files

    def list_shuffle(self):
        """Shuffle the order of the file list in place."""
        # np.random.shuffle mutates its argument and returns None, so its
        # result must not be assigned back to self.lines.
        np.random.shuffle(self.lines)

    def get_data(self):
        """Load the next ``.mat`` file from the list and normalize it.

        Reads the keys ``'input_data'``, ``'teacher_data'``, ``'MV'`` and
        ``'name'``, builds ``self.inputs``/``self.labels`` through
        ``my_normalization(self.norm_type)``, resets the batch cursor and
        advances ``list_index``. In the ``'train'`` partition the samples are
        shuffled when the ``'shuffle'`` config flag is truthy.
        """
        # rstrip('\r\n') also handles list files written with CRLF endings.
        path = self.lines[self.list_index].rstrip('\r\n')
        data = io.loadmat(path)
        self.input_data = data['input_data']
        self.teacher_data = data['teacher_data']
        self.MV = data['MV']
        self.name = data['name']
        self.my_normalization(self.norm_type)

        self.input_size = self.inputs.shape[2]   # trailing feature axis size
        self.output_size = self.labels.shape[1]
        self.seq_len = self.inputs.shape[1]      # presumably time steps
        self.len = self.inputs.shape[0]          # number of samples in file
        self.index = 0
        self.list_index += 1

        if self.dataset_partition == 'train' and self.shuffle_flag:
            self.train_shuffle()

    def get_batch(self):
        """Return the next ``(batch_input, batch_label)`` pair.

        The last batch of a file may be smaller than ``batch_size``. In the
        ``'train'`` partition, after the last batch of a file is served the
        next listed file (if any) is loaded, so a batch never mixes samples
        from two files.
        """
        start = self.index * self.batch_size
        is_last = (self.index + 1) * self.batch_size >= self.len
        end = self.len if is_last else start + self.batch_size
        batch_input = self.inputs[start:end, :, :]
        batch_label = self.labels[start:end, :]
        self.index += 1

        # The slices above stay valid: get_data() rebinds self.inputs and
        # self.labels to new arrays rather than mutating the old ones.
        if (is_last and self.dataset_partition == 'train'
                and self.list_index < self.list_len):
            self.get_data()

        return batch_input, batch_label

    def my_normalization(self, type):
        """Build ``self.inputs`` and ``self.labels`` from the loaded arrays.

        ``self.inputs`` is the normalized ``input_data`` with a trailing
        feature axis appended (``[:, :, np.newaxis]``). Supported ``type``:

            type  inputs                    labels
            0     std_normalize             get_mv(inputs, teacher_data)
            1     std_normalize             std_normalize(MV)
            2     max_normalized            get_mv(inputs, teacher_data)
            3     max_normalized            max_normalized(MV)
            5     complex_standardization   std_normalize(MV)
            6     complex_standardization   max_normalized(MV)
            8     normalized                std_normalize(MV)
            9     normalized                max_normalized(MV)

        Raises:
            ValueError: for any other ``type``.
        """
        # ``type`` shadows the builtin; the name is kept for compatibility.
        # A label normalizer of None means "labels come from get_mv".
        schemes = {
            0: (self.std_normalize, None),
            1: (self.std_normalize, self.std_normalize),
            2: (self.max_normalized, None),
            3: (self.max_normalized, self.max_normalized),
            5: (self.complex_standardization, self.std_normalize),
            6: (self.complex_standardization, self.max_normalized),
            8: (self.normalized, self.std_normalize),
            9: (self.normalized, self.max_normalized),
        }
        if type not in schemes:
            raise ValueError('type error: unsupported norm_type {!r}'.format(type))
        input_norm, label_norm = schemes[type]

        self.inputs = input_norm(self.input_data)[:, :, np.newaxis]
        if label_norm is None:
            # Labels are computed from the *normalized* inputs.
            self.labels = self.get_mv(self.inputs, self.teacher_data)
        else:
            self.labels = label_norm(self.MV)

    def complex_standardization(self, inputs):
        """Whiten complex ``inputs`` with one 2x2 (real, imag) moment matrix.

        Every element is treated as the 2-vector ``(real, imag)``. A single
        matrix ``V = [[Vrr, Vri], [Vri, Vii]]`` is estimated over all
        elements, and ``W @ (real, imag)`` is returned as a complex array of
        the same shape. ``W`` approximates ``V^(-1/2)``.
        """
        x_real = np.real(inputs)
        x_imag = np.imag(inputs)
        # Raw second moments: no mean is subtracted, so these equal the
        # covariances only for zero-mean data.
        Vrr = np.mean(x_real ** 2)
        Vii = np.mean(x_imag ** 2)
        Vri = np.mean(x_real * x_imag)

        # tau = trace(V) >= 0 and delta = det(V) >= 0 (Cauchy-Schwarz).
        # Clip delta at 0 so float round-off cannot make np.sqrt return NaN.
        tau = Vrr + Vii
        delta = np.maximum(Vrr * Vii - Vri ** 2, 0.0)

        s = np.sqrt(delta)          # det(sqrt(V))
        t = np.sqrt(tau + 2 * s)

        # sqrt(V) = (1/t) [[Vrr + s, Vri], [Vri, Vii + s]]
        # (https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix).
        # Inverting that 2x2 matrix analytically (its determinant is s) gives
        # V^(-1/2) = (1/(s*t)) [[Vii + s, -Vri], [-Vri, Vrr + s]].
        # 1e-3 guards against s*t == 0 (e.g. purely real input, where s = 0),
        # at the cost of slightly biasing W away from the exact V^(-1/2).
        inverse_st = 1.0 / ((s * t) + 1e-3)
        Wrr = (Vii + s) * inverse_st
        Wii = (Vrr + s) * inverse_st
        Wri = -Vri * inverse_st

        # W is symmetric, so Wri serves as both off-diagonal entries.
        outputs_real = Wrr * x_real + Wri * x_imag
        outputs_imag = Wri * x_real + Wii * x_imag
        return outputs_real + outputs_imag * 1j

    def std_normalize(self, inputs):
        """Divide ``inputs`` by the sample std (ddof=1) of their magnitudes.

        No mean is subtracted. Returns NaN-filled output for a single element
        (the ddof=1 std is undefined there).
        """
        return inputs / np.std(np.abs(inputs), ddof=1)

    def max_normalized(self, inputs):
        """Scale ``inputs`` so that the largest magnitude becomes 1.

        An all-zero input divides by zero (NaN output).
        """
        return inputs / np.max(np.abs(inputs))

    def normalized(self, data):
        """Z-score the real and imaginary parts separately (std with ddof=1).

        Returns the complex array ``real_z + 1j * imag_z``.
        """
        real_ = data.real
        imag_ = data.imag
        real_ = (real_ - np.mean(real_)) / np.std(real_, ddof=1)
        imag_ = (imag_ - np.mean(imag_)) / np.std(imag_, ddof=1)
        return real_ + 1j * imag_

    def get_mv(self, input, tearch):
        """Return an ``(N, 1)`` array of ``|np.vdot(input[i], tearch[i])|``.

        ``np.vdot`` flattens both rows and conjugates the first argument, so
        each entry is the magnitude of the complex inner product between
        sample ``i`` of ``input`` and of ``tearch`` (the teacher data).
        Parameter names are kept unchanged for keyword-argument callers.
        """
        mv = np.zeros((input.shape[0], 1))
        for i in range(input.shape[0]):
            mv[i, 0] = np.abs(np.vdot(input[i, :], tearch[i, :]))
        return mv

            
        
