import numpy as np
import tensorflow as tf
import json
import matplotlib.pyplot as plt
from shutil import copyfile
import os
class LossBuild():
    """Loss functions (TensorFlow 1.x graph ops) for complex-valued predictions.

    Args:
        b_mode: stored as ``self.b_mode``; no method of this class reads it,
            presumably callers consult it to pick a loss -- TODO confirm.
    """

    def __init__(self, b_mode=True):
        self.b_mode = b_mode

    def reshape_input(self, x):
        """Return rank-3 ``x`` with its last two axes swapped (perm ``[0, 2, 1]``)."""
        return tf.transpose(x, [0, 2, 1])

    def b_mode_loss(self, x_src, y_true, y_pred):
        """Mean squared log1p-error between magnitudes of weighted sums.

        For each of ``y_true`` / ``y_pred`` the element-wise product with
        ``x_src`` is summed over the last axis and its absolute value taken;
        the loss is ``mean((log1p(|sum_true|) - log1p(|sum_pred|)) ** 2)``.
        """
        true_concat = tf.abs(tf.reduce_sum(tf.multiply(x_src, y_true), axis=-1))
        pred_concat = tf.abs(tf.reduce_sum(tf.multiply(x_src, y_pred), axis=-1))

        # NOTE: no max-normalisation is applied here. An earlier unused
        # normalisation computed with ``np.divide`` on the TF tensors was
        # removed: its result never reached the loss, and NumPy ufuncs may
        # fail on symbolic graph tensors.
        loss = tf.reduce_mean(tf.math.square(tf.math.log1p(true_concat) - tf.math.log1p(pred_concat)))
        return loss

    def test_loss(self, y_true, y_pred):
        """Mean of ``(Re(d) + Im(d)) ** 2`` where ``d = y_true - y_pred``.

        NOTE(review): the expression expands to ``Re^2 + Im^2 + 2 Re Im``,
        which is not ``|d|^2`` (that would drop the cross term) -- confirm
        the cross term is intentional.
        """
        diff = y_true - y_pred
        real_ = tf.real(diff)
        imag_ = tf.imag(diff)
        loss = tf.reduce_mean(real_ ** 2 + imag_ ** 2 + 2 * real_ * imag_)
        return loss

    def complex_mse(self, y_true, y_pred):
        """Mean of ``|(y_true - y_pred) ** 2|``, i.e. the mean squared modulus of the error."""
        loss = tf.reduce_mean(tf.abs(tf.math.square(tf.subtract(y_true, y_pred))))
        return loss

    def complex_angle(self, y_true, y_pred):
        """Mean absolute phase (radians) of the complex error ``y_true - y_pred``.

        Both inputs are flattened to ``[batch, -1]`` first.
        NOTE(review): this is the phase of the difference vector, not the
        angle between ``y_true`` and ``y_pred`` -- confirm that is intended.
        """
        y_true_ = tf.layers.flatten(y_true)
        y_pred_ = tf.layers.flatten(y_pred)

        loss = tf.reduce_mean(tf.abs(tf.math.angle(y_true_ - y_pred_)))
        return loss


def output_append(output_arr, output, iteration, batch_size, append_axis=0):
    """Scatter one batch of per-position sequence outputs into a 2-D grid, in place.

    The first item of the batch starts at grid position number
    ``iteration * batch_size``, counted along ``nx`` first (row-major) when
    ``append_axis == 0`` or along ``nz`` first (column-major) when
    ``append_axis == 1``. Each later item starts one position further on.
    From its start position, sequence step ``seq_iter`` of item ``x`` is
    written to ``output_arr[:, seq_iter, <start advanced by seq_iter steps>]``.
    A row index (axis 0) or column index (axis 1) that runs past the grid is
    reset to 0, not taken modulo the grid size.

    Any other ``append_axis`` value writes nothing and returns ``output_arr``
    unchanged.

    Args:
        output_arr: 4-D array ``[n_data, num_seq, nz, nx]``; modified in place.
        output: indexed as ``output[x, seq_iter, :]`` with ``n_data`` values
            per entry, presumably ``[batch_size, num_seq, n_data]``.
        iteration: batch index, used to compute the starting grid position.
        batch_size: number of items of ``output`` to write.
        append_axis: 0 for row-major, 1 for column-major traversal.

    Returns:
        ``output_arr`` itself, after modification.
    """
    # output_arr's shape: [n_data, num_seq, nz, nx]
    # output is indexed as output[x, seq_iter, :] -> presumably [batch_size, num_seq, n_data]

    n_data, num_seq, nz, nx = output_arr.shape

    # Starting (row, col) of the first batch item, derived from the number of
    # positions already consumed by previous iterations.
    if append_axis == 0:
        x_, y_ = (iteration * batch_size) // nx, (iteration * batch_size) % nx
    elif append_axis == 1:
        x_, y_ = (iteration * batch_size) % nz, (iteration * batch_size) // nz

    for x in range(batch_size):
        if append_axis == 0:
            # Row-major walk: advance along columns (nx), carry into the next row.
            x_iter, y_iter = x_, y_
            for seq_iter in range(num_seq):
                if y_iter >= nx:
                    x_iter = x_iter + 1
                    y_iter = 0

                if x_iter >= nz:
                    x_iter = 0
                # print(seq_iter, x_iter, y_iter,y_iter + 1,n_data,x)
                output_arr[:, seq_iter, x_iter, y_iter: y_iter + 1] = output[x, seq_iter, :].reshape(n_data, 1)
                y_iter = y_iter + 1
            # The next batch item starts one position after this item's start.
            y_ = y_ + 1
            if y_ >= nx:
                x_ = x_ + 1
                y_ = 0
            if x_ >= nz:
                x_ = 0

        elif append_axis == 1:
            # Column-major walk: advance along rows (nz), carry into the next column.
            x_iter, y_iter = x_, y_
            for seq_iter in range(num_seq):
                if x_iter >= nz:
                    y_iter = y_iter + 1
                    x_iter = 0
                if y_iter >= nx:
                    y_iter = 0
                output_arr[:, seq_iter, x_iter: x_iter + 1, y_iter] = output[x, seq_iter, :].reshape(n_data, 1)
                x_iter = x_iter + 1
            # The next batch item starts one position after this item's start.
            x_ = x_ + 1
            if x_ >= nz:
                y_ = y_ + 1
                x_ = 0
            if y_ >= nx:
                y_ = 0

    return output_arr


def output_appendfc(output_arr, output, iteration, batch_size, append_axis=0):
    """Scatter one batch of fully-connected outputs into a 2-D grid, in place.

    Same traversal as ``output_append``, except that each batch item carries a
    single ``n_data`` vector (no per-step values). That vector is written at
    every one of the ``num_seq`` consecutive positions starting from the
    item's start position.

    Traversal order is along ``nx`` first (row-major) when ``append_axis == 0``
    and along ``nz`` first (column-major) when ``append_axis == 1``. The first
    item starts at position ``iteration * batch_size``. A row or column index
    that runs past the grid is reset to 0. Any other ``append_axis`` writes
    nothing.

    Args:
        output_arr: 4-D array ``[n_data, num_seq, nz, nx]``; modified in place.
        output: ``[batch_size, n_data]`` (``[batch_size, 1, n_data]`` also
            works because ``output[x, :]`` is reshaped to ``(n_data, 1)``).
        iteration: batch index, used to compute the starting grid position.
        batch_size: number of items of ``output`` to write.
        append_axis: 0 for row-major, 1 for column-major traversal.

    Returns:
        ``output_arr`` itself, after modification.
    """
    n_data, num_seq, nz, nx = output_arr.shape

    if append_axis == 0:
        x_, y_ = (iteration * batch_size) // nx, (iteration * batch_size) % nx
    elif append_axis == 1:
        x_, y_ = (iteration * batch_size) % nz, (iteration * batch_size) // nz

    for x in range(batch_size):
        # One vector per batch item; both traversal directions must read it the
        # same way (the column-major branch used to index output[x, seq_iter, :],
        # which raises IndexError for the 2-D output used by the row-major branch).
        item = output[x, :].reshape(n_data, 1)
        if append_axis == 0:
            x_iter, y_iter = x_, y_
            for seq_iter in range(num_seq):
                if y_iter >= nx:
                    x_iter = x_iter + 1
                    y_iter = 0

                if x_iter >= nz:
                    x_iter = 0
                output_arr[:, seq_iter, x_iter, y_iter: y_iter + 1] = item
                y_iter = y_iter + 1
            # The next batch item starts one position after this item's start.
            y_ = y_ + 1
            if y_ >= nx:
                x_ = x_ + 1
                y_ = 0
            if x_ >= nz:
                x_ = 0

        elif append_axis == 1:
            x_iter, y_iter = x_, y_
            for seq_iter in range(num_seq):
                if x_iter >= nz:
                    y_iter = y_iter + 1
                    x_iter = 0
                if y_iter >= nx:
                    y_iter = 0
                output_arr[:, seq_iter, x_iter: x_iter + 1, y_iter] = item
                x_iter = x_iter + 1
            # The next batch item starts one position after this item's start.
            x_ = x_ + 1
            if x_ >= nz:
                y_ = y_ + 1
                x_ = 0
            if y_ >= nx:
                y_ = 0
    return output_arr


def recover_image(pre_outputs, true_outputs, nz, nx):
    """Rebuild two (nz, nx) images from row-major flattened 2-D arrays.

    Image pixel ``(i, j)`` has flat index ``k = i * nx + j``. It is read from
    row ``k // width``, column ``k % width`` of the source, where
    ``width = pre_outputs.shape[1]``. The same width is used for
    ``true_outputs``.

    Returns:
        ``(recove_pre, recove_true)``: float64 arrays of shape ``(nz, nx)``.
    """
    width = pre_outputs.shape[1]
    image_pre = np.zeros((nz, nx))
    image_true = np.zeros((nz, nx))
    for flat in range(nz * nx):
        row, col = divmod(flat, nx)
        src_row, src_col = divmod(flat, width)
        image_pre[row, col] = pre_outputs[src_row, src_col]
        image_true[row, col] = true_outputs[src_row, src_col]
    return image_pre, image_true


def record_json(data, file_path):
    """Serialize ``data`` as JSON into ``file_path``, overwriting it.

    Output uses 4-space indentation and compact ``','`` / ``':'`` separators.
    """
    dump_options = {'indent': 4, 'separators': (',', ':')}
    with open(file_path, mode='w') as handle:
        json.dump(data, handle, **dump_options)



def save_convergence_figure(train_loss, vaild_loss, save_path):
    """Plot training and validation loss curves and save the figure to ``save_path``.

    The training curve is drawn in light blue and the validation curve in
    red, with a legend. The figure is cleared and closed afterwards so
    repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    ax = fig.gca()
    ax.plot(train_loss, c='#00BFFF', label='train')
    ax.plot(vaild_loss, c='r', label='valid')
    ax.legend()
    fig.savefig(save_path)
    fig.clf()
    plt.close(fig)

def arjovski_init(shape, dtype=tf.float32, partition_info=None):
    """TF1-style initializer returning a random unitary matrix as real/imag planes.

    Builds ``W = D3 R2 F^{-1} D2 P R1 F D1`` (the unitary-RNN factorisation of
    Arjovsky et al.). ``D1..D3`` are diagonal matrices of random phases,
    ``R1, R2`` are complex Householder reflections, ``P`` is a random
    permutation matrix and ``F`` is the DFT matrix. Draws come from the
    global ``np.random`` state.

    Args:
        shape: ``[n, n, 2]``; the last axis holds the real and imaginary parts.
        dtype: dtype of the returned constant.
        partition_info: unused; accepted to match the TF1 initializer signature.

    Returns:
        ``tf.constant`` of shape ``shape`` with ``[..., 0] = Re(W)`` and
        ``[..., 1] = Im(W)``.
    """
    print("Arjosky basis initialization.")
    assert shape[0] == shape[1]
    n = shape[0]
    # Draw order is kept fixed so a given np.random seed reproduces D*, v*, P.
    omega1 = np.random.uniform(-np.pi, np.pi, n)
    omega2 = np.random.uniform(-np.pi, np.pi, n)
    omega3 = np.random.uniform(-np.pi, np.pi, n)

    vr1 = np.random.uniform(-1, 1, [n, 1])
    vi1 = np.random.uniform(-1, 1, [n, 1])
    v1 = vr1 + 1j * vi1
    vr2 = np.random.uniform(-1, 1, [n, 1])
    vi2 = np.random.uniform(-1, 1, [n, 1])
    v2 = vr2 + 1j * vi2

    D1 = np.diag(np.exp(1j * omega1))
    D2 = np.diag(np.exp(1j * omega2))
    D3 = np.diag(np.exp(1j * omega3))

    # Householder reflections R = I - 2 v v^H / (v^H v).
    vvh1 = np.matmul(v1, np.transpose(np.conj(v1)))
    beta1 = 2. / np.matmul(np.transpose(np.conj(v1)), v1)
    R1 = np.eye(n) - beta1 * vvh1

    vvh2 = np.matmul(v2, np.transpose(np.conj(v2)))
    beta2 = 2. / np.matmul(np.transpose(np.conj(v2)), v2)
    R2 = np.eye(n) - beta2 * vvh2

    perm = np.random.permutation(np.eye(n, dtype=np.float32)) \
           + 1j * np.zeros(n)

    fft = np.fft.fft
    ifft = np.fft.ifft

    # np.fft transforms along the LAST axis by default, which right-multiplies
    # (M -> M F). Transform along axis 0 so F and F^{-1} act on the left as in
    # the factorisation; with the default axis the two transforms cancelled
    # and W collapsed to D3 R2 D2 P R1 D1. The unnormalised fft (x sqrt(n))
    # and ifft (x 1/sqrt(n)) scale factors cancel, so W is still unitary.
    step1 = fft(D1, axis=0)             # F D1
    step2 = np.matmul(R1, step1)        # R1 F D1
    step3 = np.matmul(perm, step2)      # P R1 F D1
    step4 = np.matmul(D2, step3)        # D2 P R1 F D1
    step5 = ifft(step4, axis=0)         # F^{-1} D2 P R1 F D1
    step6 = np.matmul(R2, step5)
    unitary = np.matmul(D3, step6)
    eye_test = np.matmul(np.transpose(np.conj(unitary)), unitary)
    unitary_test = np.linalg.norm(np.eye(n) - eye_test)
    print('I - Wi.H Wi', unitary_test, unitary.dtype)
    assert unitary_test < 1e-10, "Unitary initialization not unitary enough."
    stacked = np.stack([np.real(unitary), np.imag(unitary)], -1)
    assert stacked.shape == tuple(shape), "Unitary initialization shape mismatch."
    return tf.constant(stacked, dtype)



def mod_relu(z, scope=''):
    """modReLU activation (Arjovsky et al.) for a complex tensor.

    Computes ``relu(|z| + b) * z / (|z| + 1e-6)``. The modulus is shifted by
    a learned scalar ``b`` and passed through ReLU, while the phase of ``z``
    is kept. The ``1e-6`` term avoids division by zero.

    ``b`` is a scalar float32 variable named ``'b'``, created (or reused)
    under the variable scope ``'mod_relu' + scope``. It is initialised
    uniformly in [-0.01, 0.01], so the activation starts close to the
    identity for inputs whose modulus is well above 0.01.

    Args:
        z: complex input tensor.
        scope: suffix appended to ``'mod_relu'`` to name the variable scope.

    Returns:
        Complex tensor with the same shape as ``z``.
    """
    with tf.variable_scope('mod_relu' + scope):
        offset = tf.get_variable(
            'b', [], dtype=tf.float32,
            initializer=tf.random_uniform_initializer(-0.01, 0.01),
        )
        re_z = tf.real(z)
        im_z = tf.imag(z)
        magnitude = tf.sqrt(re_z ** 2 + im_z ** 2)
        gain = tf.nn.relu(magnitude + offset) / (magnitude + 1e-6)
        complex_gain = tf.complex(gain, tf.zeros_like(gain))
        return tf.multiply(complex_gain, z)

def complex_dense(x, num_proj, scope, bias=False, bias_init_r=0.0,
                   bias_init_c=0.0, unitary=False, orthogonal=False,
                   unitary_init=arjovski_init):
    """Complex dense layer followed by modReLU: ``mod_relu(x @ A [+ b], scope)``.

    The complex weight ``A`` has shape ``[in_dim, num_proj]``, where
    ``in_dim`` is the static size of the last axis of ``x``. It is assembled
    from float32 variables created under ``scope``:

    * ``unitary``: ``unitary_stiefel/denseU`` of shape ``[in_dim, num_proj, 2]``
      (real, imaginary planes), initialised by ``unitary_init``. With the
      default ``arjovski_init`` this requires ``in_dim == num_proj``.
    * ``orthogonal`` (only consulted when ``unitary`` is False): separate
      ``orthogonal_stiefel/dense_r`` and ``dense_i``, each with an orthogonal
      initializer.
    * otherwise: ``denseU`` of shape ``[in_dim, num_proj, 2]`` with Glorot
      uniform initialisation.

    Args:
        x: complex input tensor; presumably ``[batch, in_dim]`` for
            ``tf.matmul`` -- TODO confirm at call sites.
        num_proj: number of output features.
        scope: variable scope of the layer; also passed to ``mod_relu``.
        bias: if True, add a complex bias ``bias_r + 1j * bias_i`` of length
            ``num_proj``.
        bias_init_r, bias_init_c: constant initial real / imaginary bias.
        unitary, orthogonal: weight parameterisation selectors (see above).
        unitary_init: initializer used in the unitary case.

    Returns:
        Complex tensor ``mod_relu(x A + b)`` whose last dimension is ``num_proj``.
    """
    fan_in = tf.Tensor.get_shape(x).as_list()[-1:]
    packed_shape = fan_in + [num_proj] + [2]
    with tf.variable_scope(scope):
        if unitary:
            with tf.variable_scope('unitary_stiefel'):
                packed = tf.get_variable('denseU',
                                         shape=packed_shape,
                                         dtype=tf.float32,
                                         initializer=unitary_init)
                weight = tf.complex(packed[:, :, 0], packed[:, :, 1])
        elif orthogonal:
            with tf.variable_scope('orthogonal_stiefel'):
                weight_re = tf.get_variable('dense_r', fan_in + [num_proj],
                                            dtype=tf.float32,
                                            initializer=tf.orthogonal_initializer())
                weight_im = tf.get_variable('dense_i', fan_in + [num_proj],
                                            dtype=tf.float32,
                                            initializer=tf.orthogonal_initializer())
                weight = tf.complex(weight_re, weight_im)
        else:
            packed = tf.get_variable('denseU',
                                     shape=packed_shape,
                                     dtype=tf.float32,
                                     initializer=tf.glorot_uniform_initializer())
            weight = tf.complex(packed[:, :, 0], packed[:, :, 1])

        if not bias:
            return mod_relu(tf.matmul(x, weight), scope)

        bias_re = tf.get_variable('bias_r', [num_proj], dtype=tf.float32,
                                  initializer=tf.constant_initializer(bias_init_r))
        bias_im = tf.get_variable('bias_i', [num_proj], dtype=tf.float32,
                                  initializer=tf.constant_initializer(bias_init_c))
        offset = tf.complex(bias_re, bias_im)
        return mod_relu(tf.matmul(x, weight) + offset, scope)
        
class EarlyStopping:
    """Stop training when validation loss stops improving; checkpoint improvements.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: minimum decrease of the validation loss below the best loss
            so far that counts as an improvement (0 = any strict decrease).
        keep_num: size of the ascending table of best losses used to decide
            whether a checkpoint is written.
    """

    def __init__(self, patience=5, delta=0, keep_num=10):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        # Kept for backward compatibility; not updated by this class.
        self.val_loss_min = np.inf
        # Ascending table of the best ``keep_num`` losses (inf = empty slot).
        self.loss_list = np.zeros(keep_num, np.float32) + np.inf

    def __call__(self, val_loss, sess, model, setting, checkpoint, epoch):
        """Record one validation loss and save the model if it improved.

        Args:
            val_loss: scalar validation loss of this epoch.
            sess: session forwarded to ``model.saver.save``.
            model: object exposing ``saver.save(sess, path, global_step=...)``.
            setting: sub-directory name under ``checkpoint``.
            checkpoint: root checkpoint directory.
            epoch: used as ``global_step`` of the saved checkpoint.
        """
        score = val_loss
        if np.isnan(score):
            # Diverged training: stop immediately, including on the very first
            # call (previously a NaN first loss became the best score).
            self.early_stop = True
            print('EarlyStopping loss nan')
        elif self.best_score is None:
            self.best_score = score
            self._save_model(val_loss, sess, model, setting, epoch, checkpoint)
        elif score >= self.best_score - self.delta:
            # Not better than the best by more than ``delta``. (The old test,
            # ``score >= best + delta``, accepted worse losses as improvements
            # whenever delta > 0.)
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self._save_model(val_loss, sess, model, setting, epoch, checkpoint)
            self.counter = 0

    def _save_model(self, val_loss, sess, model, setting, epoch, checkpoint):
        """Save a checkpoint to ``checkpoint/setting/Model`` if ``val_loss`` ranks among the best."""
        save_model_folder = os.path.join(checkpoint, setting)
        os.makedirs(save_model_folder, exist_ok=True)
        saved_path_model = os.path.join(save_model_folder, 'Model')

        if val_loss < self.loss_list[-1]:
            print('*********** Save Model ***********')
            # Insert val_loss at its sorted position, dropping the worst entry.
            first_index = np.where(val_loss < self.loss_list)[0][0]
            self.loss_list[first_index + 1:] = self.loss_list[first_index:-1].copy()
            self.loss_list[first_index] = val_loss
            model.saver.save(sess, saved_path_model, global_step=epoch)

    
        
        

