import os
import numpy as np
import tensorflow as tf

from tensorflow.python.ops.rnn import static_rnn
from tensorflow.python.ops.rnn import dynamic_rnn
from tensorflow.python.ops import variable_scope
from network.custom_optimizers import RMSpropNatGrad
import network.cells_unit
import network.optimizer_unit
from network.cgRNN import arjovski_init, complex_matmul
from network.cgRNN import ComplexGatedRecurrentUnit_bn as cgRNNcell_bn
# from network.cgRNN1 import ComplexGatedRecurrentUnit as cgRNNcell

from network.utils import LossBuild,complex_dense


def log10(x):
    """Return the element-wise base-10 logarithm of ``x``.

    Uses the change-of-base identity log10(x) = ln(x) / ln(10); the constant
    10 is created with the same dtype as ``tf.log(x)`` so the division is
    dtype-consistent.
    """
    ln_x = tf.log(x)
    return ln_x / tf.log(tf.constant(10, dtype=ln_x.dtype))



def mod_relu(z, scope='', reuse=None):
    """
        Implementation of the modReLU from Arjovsky et al.
        f(z) = relu(|z| + b)(z / |z|) or
        f(r,theta) = relu(r + b)e^(i*theta)
        b is a learned scalar bias (the 'dead' zone radius when negative),
        created as variable 'b' inside the variable scope 'mod_relu' + scope.
    Input:
        z: complex input.
        scope: suffix appended to the 'mod_relu' variable-scope name.
        reuse: forwarded to tf.variable_scope (set True to share 'b').
    Returns:
        z_out: complex output with the same phase as z and a rescaled modulus.
    """
    with tf.variable_scope('mod_relu' + scope, reuse=reuse):
        # NOTE(review): b is Glorot-uniform initialised, NOT zero. For a
        # scalar variable this is presumably U(-sqrt(3), sqrt(3)), so the
        # network is not guaranteed to start out linear. Switch to
        # tf.zeros_initializer() if zero init is what is intended.
        b = tf.get_variable('b', [], dtype=tf.float32,
                            initializer=tf.initializers.glorot_uniform())
        modulus = tf.sqrt(tf.real(z) ** 2 + tf.imag(z) ** 2)
        # The 1e-6 keeps the division finite when |z| == 0.
        rescale = tf.nn.relu(modulus + b) / (modulus + 1e-6)
        # Lift the real scale factor to complex so it can multiply z directly.
        rescale = tf.complex(rescale, tf.zeros_like(rescale))
        return tf.multiply(rescale, z)

class ComplexRNN(object):
    """Complex-valued gated RNN regression model (TensorFlow 1.x graph mode).

    Graph layout produced by ``build``:
        complex inputs -> stacked cgRNN cells -> mean/last over time
        -> complex dense layers -> complex-to-real (abs or [re, im])
        -> real dense layers -> MSE loss against ``true_data``.

    Args:
        args: configuration namespace; each builder method reads the
            attributes it needs from it (e.g. ``seq_len``, ``num_units``,
            ``optimizer``, ``learning_rate``, ``keep_num``).
    """

    def __init__(self, args):
        self.args = args
        # Summary writers for the train / test curves.
        self.train_writer = tf.summary.FileWriter(self.args.train_path)
        self.test_writer = tf.summary.FileWriter(self.args.test_path)

    def set_train_test(self, mode):
        """Store ``mode`` in ``self.train_test`` (not read within this class)."""
        self.train_test = mode

    def build(self):
        """Build the model graph and store it in ``self.graph``.

        Structure:
            C-value inputs -- CgRNN -- Mean/Last -- complex dense -- C-R
            -- real dense -- outputs

        Sets ``input_data``, ``true_data``, ``dropout``, ``training``,
        ``momentum``, ``outputs``, ``loss``, ``loss_summary``, ``opt``,
        ``gv``, ``saver`` and ``graph`` (plus ``global_step`` when
        ``args.optimizer == 'Adam'``).

        Raises:
            ValueError: if ``args.outputs_type`` or ``args.optimizer`` is
                not one of the supported values.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Adam's minimize() is given a global step, so the variable must
            # exist in *this* graph. It is created only for Adam so SGD graphs
            # keep exactly the variable set (and checkpoint keys) they had.
            if self.args.optimizer == 'Adam':
                self._build_global_step()
            self._build_placeholders()

            # CgRNN layers followed by mean/last pooling over time.
            rnn_outputs = self._build_rnn(dropout_type=self.args.dropout_type)

            # complex-valued dimension reduction
            dr_outputs = self.dimension_reduction(rnn_outputs, self.args.complex_dense)

            # C to R
            C_R_outputs = self.C_R_layer(dr_outputs, self.args.abs_layer)

            # real-valued dimension reduction
            self.outputs = self._build_real_dense(C_R_outputs)

            self._build_loss_and_optimizer()
            self._build_saver()

        self.graph = graph

    def build_(self):
        """Legacy variant of ``build``; stores the graph in ``self.graph``.

        Differs from ``build`` in that it always creates ``global_step``,
        builds the cells without passing ``dropout_type``, and applies the
        complex dense layers only when ``args.dim_reduce`` is falsy.

        Raises:
            ValueError: if ``args.outputs_type`` or ``args.optimizer`` is
                not one of the supported values.
        """
        graph = tf.Graph()
        with graph.as_default():
            self._build_global_step()
            self._build_placeholders()

            rnn_outputs = self._build_rnn()

            # reduce the dimension: layers are chained (each consumes the
            # previous output) and sized by args.complex_dense.
            if not self.args.dim_reduce:
                dr_outputs = self.dimension_reduction(rnn_outputs, self.args.complex_dense)
            else:
                dr_outputs = rnn_outputs

            # ABS layer C-R
            C_R_outputs = self.C_R_layer(dr_outputs, self.args.abs_layer)

            # dense layer
            self.outputs = self._build_real_dense(C_R_outputs)

            self._build_loss_and_optimizer()
            self._build_saver()

        self.graph = graph

    def _build_global_step(self):
        """Create the non-trainable ``self.global_step`` counter in the current graph."""
        with tf.name_scope('global_step'):
            self.global_step = tf.Variable(0, trainable=False, name='global_step')

    def _build_placeholders(self):
        """Create the input, target, dropout, training-flag and momentum placeholders."""
        with tf.name_scope('input'):
            self.input_data = tf.placeholder(tf.complex64, shape=[None, self.args.seq_len, self.args.input_size],
                                             name="input_data")
        with tf.name_scope('target'):
            self.true_data = tf.placeholder(tf.float32, shape=[None, self.args.output_size],
                                            name="true_data")
        with tf.name_scope('dropout'):
            self.dropout = tf.placeholder(tf.float32, shape=[], name="dropout")

        with tf.name_scope('training'):
            self.training = tf.placeholder(tf.bool, shape=[], name="training")

        with tf.name_scope('momentum'):
            self.momentum = tf.placeholder(tf.float32, shape=[], name="momentum")

    def _build_rnn(self, **cell_kwargs):
        """Run the stacked cgRNN cells over time and return the pooled output.

        ``cell_kwargs`` are extra keyword arguments forwarded to every cell.
        """
        rnn_layers = [
            cgRNNcell_bn(size, training=self.training, momentum=self.momentum, dropout=self.dropout,
                         num_proj=n, complex_inout=self.args.complex_inout,
                         dim_reduce=self.args.dim_reduce, memory_bn=self.args.memory_bn,
                         canditate_hh_bn=self.args.canditate_hh_bn,
                         canditate_hx_bn=self.args.canditate_hx_bn, **cell_kwargs)
            for size, n in zip(self.args.num_units, self.args.num_proj)]
        multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
        # static_rnn expects a Python list of per-time-step tensors; axis 1
        # of input_data is the seq_len axis.
        input_data = tf.unstack(self.input_data, axis=1)
        outputs, _ = static_rnn(multi_rnn_cell, input_data, dtype=tf.complex64)
        print('**** output', outputs)
        return self._pool_outputs(outputs)

    def _pool_outputs(self, outputs):
        """Reduce the per-step RNN outputs to one tensor per ``args.outputs_type``."""
        if self.args.outputs_type == 'mean':
            return tf.reduce_mean(outputs, axis=0)
        if self.args.outputs_type == 'last':
            return outputs[-1]
        raise ValueError('The type of outputs is not exist. Please select from \'mean\' and \'last\'')

    def _build_real_dense(self, inputs):
        """Apply the real-valued dense layers sized by ``args.dense``."""
        outputs = inputs
        for index, units in enumerate(self.args.dense):
            with tf.name_scope('dense' + str(index)):
                # NOTE(review): ReLU is applied on the last layer too, so the
                # model outputs are non-negative; confirm the targets are.
                outputs = tf.layers.dense(outputs, units, activation=tf.nn.relu,
                                          use_bias=True, name='dense' + str(index))
        return outputs

    def _build_loss_and_optimizer(self):
        """Create the MSE loss, its summary and the training op ``self.opt``.

        Raises:
            ValueError: if ``args.optimizer`` is neither 'SGD' nor 'Adam'.
        """
        # loss is the real-valued MSE
        with tf.name_scope('loss'):
            self.loss = tf.losses.mean_squared_error(self.true_data, self.outputs)
        self.loss_summary = tf.summary.scalar('loss/loss', self.loss)

        with tf.name_scope('opt'):
            if self.args.optimizer == 'SGD':
                self.opt = tf.train.GradientDescentOptimizer(self.args.learning_rate).minimize(self.loss)
            elif self.args.optimizer == 'Adam':
                self.opt = tf.train.AdamOptimizer(self.args.learning_rate).minimize(
                    self.loss, global_step=self.global_step)
            else:
                # Previously an unknown optimizer left self.opt unset.
                raise ValueError("Unknown optimizer %r. Please select from 'SGD' and 'Adam'"
                                 % (self.args.optimizer,))
            if self.args.memory_bn or self.args.canditate_hh_bn or self.args.canditate_hx_bn:
                # Group the batch-norm moving mean/variance updates with the
                # training step so they run whenever self.opt runs.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                self.opt = tf.group([self.opt], update_ops)

    def _build_saver(self):
        """Snapshot the global variables and create the checkpoint saver."""
        with tf.name_scope('global_variable'):
            self.gv = tf.global_variables()

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.args.keep_num)

    def C_R_layer(self, inputs, abs_layer):
        """Map a complex tensor to a real one.

        Args:
            inputs: complex tensor.
            abs_layer: if truthy, return the element-wise modulus; otherwise
                concatenate real and imaginary parts on the last axis
                (doubling its size).
        """
        if abs_layer:
            outputs = tf.abs(inputs)
        else:
            outputs = tf.concat([tf.real(inputs), tf.imag(inputs)], axis=-1)
        return outputs

    def dimension_reduction(self, inputs, complex_dense_layer):
        """Apply chained complex dense layers.

        Args:
            inputs: complex tensor.
            complex_dense_layer: sequence of output sizes; layer ``i`` is
                built under scope ``'complex_dense<i>'``. An empty sequence
                returns ``inputs`` unchanged.
        """
        outputs = inputs
        for index, units in enumerate(complex_dense_layer):
            outputs = complex_dense(outputs, units, scope='complex_dense' + str(index))
        return outputs
