import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.layers import base
from tensorflow.python.ops import init_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.python.keras.utils import tf_utils

from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import math_ops


def sqrt_init(shape, dtype=None):
    """Initializer that fills a tensor with the constant 1/sqrt(2).

    Args:
        shape: Shape of the tensor to create.
        dtype: Optional dtype of the created tensor. It is forwarded to
            `K.ones`, which picks the backend default float type when it is
            None.

    Returns:
        A tensor of the requested shape whose every entry equals 1/sqrt(2).
    """
    # A plain Python float adopts the dtype of the ones tensor, so the
    # requested dtype is honored instead of being silently ignored.
    inv_sqrt_two = 2.0 ** -0.5
    return inv_sqrt_two * K.ones(shape, dtype=dtype)


def complex_standardization(input_centred_real, input_centred_imag, Vrr, Vii, Vri,
                            layernorm=False, axis=-1):
    """Whiten centred complex activations with the inverse square root of V.

    Each feature is treated as a 2-vector (real, imag) with the symmetric
    2x2 covariance matrix

        V = [ Vrr  Vri ]
            [ Vri  Vii ]

    and is multiplied by W = V^(-1/2), computed in closed form.

    Args:
        input_centred_real: Real part of the (already centred) input.
        input_centred_imag: Imaginary part of the (already centred) input.
        Vrr: Variance of the real part, one value per feature.
        Vii: Variance of the imaginary part, one value per feature.
        Vri: Covariance between real and imaginary parts, one per feature.
        layernorm: If True, the statistics are also broadcast along axis 0
            using the static size of that axis.
        axis: Axis of the input that indexes the features.

    Returns:
        A tuple `(outputs_real, outputs_imag)` holding the standardized real
        and imaginary parts.
    """
    static_shape = input_centred_real.shape.as_list()
    ndim = K.ndim(input_centred_real)

    # Shape used to broadcast the per-feature statistics against the input.
    # The size is read from `axis` itself (not from the last axis) so that
    # any feature axis works, not only axis=-1.
    variances_broadcast = [1] * ndim
    variances_broadcast[axis] = static_shape[axis]
    if layernorm:
        variances_broadcast[0] = static_shape[0]

    # We need the inverse square root of the covariance matrix. The square
    # root is taken first and then inverted, because the determinant needed
    # for the inversion falls out of the square-root computation.

    # tau = Vrr + Vii = trace of V (>= 0 when V is positive semi-definite).
    tau = Vrr + Vii
    # delta = Vrr * Vii - Vri ** 2 = determinant of V (>= 0 under the same
    # assumption).
    delta = (Vrr * Vii) - (Vri ** 2)

    s = K.sqrt(delta)  # Determinant of the square-root matrix.
    t = K.sqrt(tau + 2 * s)

    # The square-root matrix is
    #       [ Vrr+s  Vri   ]
    # (1/t) [ Vri    Vii+s ]
    # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
    # and, since its determinant is s, the analytical 2x2 inverse
    #      [ A B ]             [  D  -B ]
    # inv( [ C D ] ) = (1/det) [ -C   A ]
    # gives
    #            [  Vii+s  -Vri   ]
    # (1/(s*t))  [ -Vri     Vrr+s ]

    # The small constant keeps the division finite when s * t vanishes.
    inverse_st = 1.0 / ((s * t) + 1e-3)
    Wrr = (Vii + s) * inverse_st
    Wii = (Vrr + s) * inverse_st
    Wri = -Vri * inverse_st

    # W = V^(-1/2) is now known; normalize with x_normalized = W . x:
    #   x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred
    #   x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred
    broadcast_Wrr = K.reshape(Wrr, variances_broadcast)
    broadcast_Wri = K.reshape(Wri, variances_broadcast)
    broadcast_Wii = K.reshape(Wii, variances_broadcast)

    outputs_real = broadcast_Wrr * input_centred_real + broadcast_Wri * input_centred_imag
    outputs_imag = broadcast_Wri * input_centred_real + broadcast_Wii * input_centred_imag

    return outputs_real, outputs_imag


def my_complex_standardization(input_centred_real, input_centred_imag, Vrr, Vii, Vri,
                            layernorm=False, axis=-1):
    """Alias of `complex_standardization`, kept for backward compatibility.

    This function used to be a line-for-line copy of
    `complex_standardization`; it now forwards to it so that the two
    cannot drift apart.

    Args:
        input_centred_real: Real part of the (already centred) input.
        input_centred_imag: Imaginary part of the (already centred) input.
        Vrr: Variance of the real part, one value per feature.
        Vii: Variance of the imaginary part, one value per feature.
        Vri: Covariance between real and imaginary parts, one per feature.
        layernorm: Forwarded unchanged to `complex_standardization`.
        axis: Axis of the input that indexes the features.

    Returns:
        The tuple `(outputs_real, outputs_imag)` returned by
        `complex_standardization`.
    """
    return complex_standardization(
        input_centred_real, input_centred_imag, Vrr, Vii, Vri,
        layernorm=layernorm,
        axis=axis
    )

def ComplexBN(input_centred_real, input_centred_imag, Vrr, Vii, Vri, beta_real, beta_imag,
              gamma_rr, gamma_ri, gamma_ii, scale=True,
              center=True, layernorm=False, axis=-1):
    """Apply complex batch normalization to an already centred input.

    When `scale` is True the input is first standardized with
    `complex_standardization` and then multiplied by the symmetric matrix

                [ gamma_rr  gamma_ri ]
        Gamma = [ gamma_ri  gamma_ii ]

    When `center` is True the shift `[beta_real, beta_imag]` is added:

        x_real_BN = gamma_rr * x_real_normed + gamma_ri * x_imag_normed + beta_real
        x_imag_BN = gamma_ri * x_real_normed + gamma_ii * x_imag_normed + beta_imag

    Args:
        input_centred_real: Real part of the centred input.
        input_centred_imag: Imaginary part of the centred input.
        Vrr, Vii, Vri: Per-feature variances and covariance; only read when
            `scale` is True.
        beta_real, beta_imag: Per-feature shifts; only read when `center`
            is True.
        gamma_rr, gamma_ri, gamma_ii: Per-feature entries of Gamma; only
            read when `scale` is True.
        scale: Whether to standardize and multiply by Gamma.
        center: Whether to add the beta shift.
        layernorm: Forwarded to `complex_standardization`.
        axis: Axis of the input that indexes the features.

    Returns:
        A complex tensor built with `tf.complex` from the normalized real
        and imaginary parts.
    """
    ndim = K.ndim(input_centred_real)
    # Per-feature parameters are reshaped to [1, ..., dim, ..., 1] so they
    # broadcast against the input. The size is taken from `axis` itself so
    # that feature axes other than the last one are handled as well.
    param_broadcast_shape = [1] * ndim
    param_broadcast_shape[axis] = input_centred_real.shape.as_list()[axis]

    if scale:
        standardized_real, standardized_imag = complex_standardization(
            input_centred_real, input_centred_imag, Vrr, Vii, Vri,
            layernorm,
            axis=axis
        )

        broadcast_gamma_rr = K.reshape(gamma_rr, param_broadcast_shape)
        broadcast_gamma_ri = K.reshape(gamma_ri, param_broadcast_shape)
        broadcast_gamma_ii = K.reshape(gamma_ii, param_broadcast_shape)

        outputs_real = (broadcast_gamma_rr * standardized_real
                        + broadcast_gamma_ri * standardized_imag)
        outputs_imag = (broadcast_gamma_ri * standardized_real
                        + broadcast_gamma_ii * standardized_imag)
    else:
        outputs_real = input_centred_real
        outputs_imag = input_centred_imag

    if center:
        # Cast the shifts to the dtype of the tensors they are added to.
        broadcast_beta_real = tf.cast(
            K.reshape(beta_real, param_broadcast_shape), outputs_real.dtype)
        broadcast_beta_imag = tf.cast(
            K.reshape(beta_imag, param_broadcast_shape), outputs_imag.dtype)
        outputs_real = outputs_real + broadcast_beta_real
        outputs_imag = outputs_imag + broadcast_beta_imag

    return tf.complex(outputs_real, outputs_imag)


def my_ComplexBN(input_centred_real, input_centred_imag, Vrr, Vii, Vri, beta_real, beta_imag,
              gamma_rr, gamma_ri, gamma_ii, scale=True,
              center=True, layernorm=False, axis=-1):
    """Alias of `ComplexBN`, kept for backward compatibility.

    This function used to be a line-for-line copy of `ComplexBN`; it now
    forwards to it so that the two cannot drift apart.

    Args:
        input_centred_real: Real part of the centred input.
        input_centred_imag: Imaginary part of the centred input.
        Vrr, Vii, Vri: Per-feature variances and covariance.
        beta_real, beta_imag: Per-feature shifts.
        gamma_rr, gamma_ri, gamma_ii: Per-feature scaling matrix entries.
        scale: Whether to standardize and apply the gamma scaling.
        center: Whether to add the beta shift.
        layernorm: Forwarded unchanged to `ComplexBN`.
        axis: Axis of the input that indexes the features.

    Returns:
        The complex tensor returned by `ComplexBN`.
    """
    return ComplexBN(
        input_centred_real, input_centred_imag, Vrr, Vii, Vri,
        beta_real, beta_imag, gamma_rr, gamma_ri, gamma_ii,
        scale=scale, center=center, layernorm=layernorm, axis=axis
    )


class ComplexBatchNormalization(base.Layer):
    """Complex version of the real domain
    Batch normalization layer (Ioffe and Szegedy, 2014).
    Normalize the activations of the previous complex layer at each batch,
    i.e. applies a transformation that maintains the mean of a complex unit
    close to the null vector, the 2 by 2 covariance matrix of a complex unit close to identity
    and the 2 by 2 relation matrix, also called pseudo-covariance, close to the
    null matrix.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, for inputs laid out channels-first the features
            axis is 1.
        momentum: Momentum for the moving statistics related to the real and
            imaginary parts.
        epsilon: Small float added to each of the variances related to the
            real and imaginary parts in order to avoid dividing by zero.
        center: If True, add offset of `beta` to complex normalized tensor.
            If False, `beta` is ignored and the input is not mean-centred.
            (beta is formed by beta_real and beta_imag)
        scale: If True, multiply by the `gamma` matrix.
            If False, `gamma` is not used.
        name: Optional name of the layer; also used as the name of the
            variable scope in which the weights are created.
        reuse: Value passed as the `reuse` argument of that variable scope.
        beta_initializer: Initializer for the beta_real and the beta_imag weight.
        gamma_diag_initializer: Initializer for the diagonal elements of the gamma matrix.
            which are the variances of the real part and the imaginary part.
        gamma_off_initializer: Initializer for the off-diagonal elements of the gamma matrix.
        moving_mean_initializer: Initializer for the moving means.
        moving_variance_initializer: Initializer for the moving variances.
        moving_covariance_initializer: Initializer for the moving covariance of
            the real and imaginary parts.
        beta_regularizer: Optional regularizer for the beta weights.
        gamma_diag_regularizer: Optional regularizer for the diagonal gamma weights.
        gamma_off_regularizer: Optional regularizer for the off-diagonal gamma weights.
        beta_constraint: Optional constraint for the beta weights.
        gamma_diag_constraint: Optional constraint for the diagonal gamma weights.
        gamma_off_constraint: Optional constraint for the off-diagonal gamma weights.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
    """
    _USE_V2_BEHAVIOR = False

    def __init__(self,
                 axis=-1,
                 momentum=0.9,
                 epsilon=1e-4,
                 center=True,
                 scale=True,
                 name=None,
                 reuse=None,
                 beta_initializer=init_ops.zeros_initializer(),
                 gamma_diag_initializer=init_ops.constant_initializer(1.0 / np.sqrt(2.0)),
                 gamma_off_initializer=init_ops.zeros_initializer(),
                 moving_mean_initializer=init_ops.zeros_initializer(),
                 moving_variance_initializer=init_ops.constant_initializer(1.0 / np.sqrt(2.0)),
                 moving_covariance_initializer=init_ops.zeros_initializer(),
                 beta_regularizer=None,
                 gamma_diag_regularizer=None,
                 gamma_off_regularizer=None,
                 beta_constraint=None,
                 gamma_diag_constraint=None,
                 gamma_off_constraint=None,
                 **kwargs):
        """Store the layer configuration; weights are created in `build`."""
        super(ComplexBatchNormalization, self).__init__(name=name, **kwargs)

        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.reuse = reuse
        self.beta_initializer = beta_initializer
        self.gamma_diag_initializer = gamma_diag_initializer
        self.gamma_off_initializer = gamma_off_initializer
        self.moving_mean_initializer = moving_mean_initializer
        self.moving_variance_initializer = moving_variance_initializer
        self.moving_covariance_initializer = moving_covariance_initializer
        self.beta_regularizer = beta_regularizer
        self.gamma_diag_regularizer = gamma_diag_regularizer
        self.gamma_off_regularizer = gamma_off_regularizer
        self.beta_constraint = beta_constraint
        self.gamma_diag_constraint = gamma_diag_constraint
        self.gamma_off_constraint = gamma_off_constraint

    def build(self, input_shape):
        """Create the gamma/beta weights and the moving statistics.

        Raises:
            ValueError: If the size of the normalized axis is not statically
                known.
        """
        # Work on plain ints/None so that the `is None` test below is
        # meaningful whatever object the framework hands in as the shape.
        static_shape = tf.TensorShape(input_shape).as_list()

        dim = static_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(static_shape),
                                    axes={self.axis: dim})

        # One parameter per feature along the normalized axis.
        param_shape = (dim,)

        if self.scale:
            # `reuse` must be passed by keyword: the second positional
            # parameter of tf.variable_scope is `default_name`.
            with tf.variable_scope(self.name, reuse=self.reuse):
                self.gamma_rr = self.add_weight(shape=param_shape,
                                                name='gamma_rr',
                                                dtype=tf.float32,
                                                initializer=self.gamma_diag_initializer,
                                                regularizer=self.gamma_diag_regularizer,
                                                constraint=self.gamma_diag_constraint)
                self.gamma_ii = self.add_weight(shape=param_shape,
                                                name='gamma_ii',
                                                dtype=tf.float32,
                                                initializer=self.gamma_diag_initializer,
                                                regularizer=self.gamma_diag_regularizer,
                                                constraint=self.gamma_diag_constraint)
                self.gamma_ri = self.add_weight(shape=param_shape,
                                                name='gamma_ri',
                                                dtype=tf.float32,
                                                initializer=self.gamma_off_initializer,
                                                regularizer=self.gamma_off_regularizer,
                                                constraint=self.gamma_off_constraint)
                self.moving_Vrr = self.add_weight(shape=param_shape,
                                                  initializer=self.moving_variance_initializer,
                                                  name='moving_Vrr',
                                                  dtype=tf.float32,
                                                  trainable=False)
                self.moving_Vii = self.add_weight(shape=param_shape,
                                                  initializer=self.moving_variance_initializer,
                                                  name='moving_Vii',
                                                  dtype=tf.float32,
                                                  trainable=False)
                self.moving_Vri = self.add_weight(shape=param_shape,
                                                  initializer=self.moving_covariance_initializer,
                                                  name='moving_Vri',
                                                  dtype=tf.float32,
                                                  trainable=False)
        else:
            self.gamma_rr = None
            self.gamma_ii = None
            self.gamma_ri = None
            self.moving_Vrr = None
            self.moving_Vii = None
            self.moving_Vri = None

        if self.center:
            with tf.variable_scope(self.name, reuse=self.reuse):
                self.beta_real = self.add_weight(shape=param_shape,
                                                 name='beta_real',
                                                 dtype=tf.float32,
                                                 initializer=self.beta_initializer,
                                                 regularizer=self.beta_regularizer,
                                                 constraint=self.beta_constraint)
                self.beta_imag = self.add_weight(shape=param_shape,
                                                 name='beta_imag',
                                                 dtype=tf.float32,
                                                 initializer=self.beta_initializer,
                                                 regularizer=self.beta_regularizer,
                                                 constraint=self.beta_constraint)
                self.moving_mean_real = self.add_weight(shape=param_shape,
                                                        dtype=tf.float32,
                                                        initializer=self.moving_mean_initializer,
                                                        name='moving_mean_real',
                                                        trainable=False)
                self.moving_mean_imag = self.add_weight(shape=param_shape,
                                                        dtype=tf.float32,
                                                        initializer=self.moving_mean_initializer,
                                                        name='moving_mean_imag',
                                                        trainable=False)
        else:
            # These are the attributes `call` reads, so they must exist even
            # when centring is disabled.
            self.beta_real = None
            self.beta_imag = None
            self.moving_mean_real = None
            self.moving_mean_imag = None
            # Attribute names set by earlier versions of this layer.
            self.beta = None
            self.moving_mean = None

        self.built = True

    def _get_training_value(self, training=None):
        """Combine the `training` flag with the layer's trainable state."""
        if training is None:
            training = K.learning_phase()
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
        if base_layer_utils.is_in_keras_graph():
            training = math_ops.logical_and(training, self._get_trainable_var())
        else:
            training = math_ops.logical_and(training, self.trainable)

        return training

    def complex_moments(self, inputs):
        """Compute the batch statistics of a complex input.

        Args:
            inputs: Complex tensor whose statistics are computed over every
                axis except `self.axis`.

        Returns:
            A tuple `(mu_real, mu_imag, Vrr, Vri, Vii)`. The three second
            order moments are None when `self.scale` is False. They are
            taken about the mean only when `self.center` is True.

        Raises:
            ValueError: If both `self.scale` and `self.center` are False.
        """
        if not (self.scale or self.center):
            raise ValueError('Error. Both scale and center in batchnorm are set to False.')

        input_shape = inputs.shape.as_list()
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        inputs_real = tf.math.real(inputs)
        inputs_imag = tf.math.imag(inputs)
        mu_real = tf.reduce_mean(inputs_real, axis=reduction_axes)
        mu_imag = tf.reduce_mean(inputs_imag, axis=reduction_axes)

        if self.center:
            broadcast_mu_shape = [1] * len(input_shape)
            broadcast_mu_shape[self.axis] = input_shape[self.axis]
            inputs_real = inputs_real - K.reshape(mu_real, broadcast_mu_shape)
            inputs_imag = inputs_imag - K.reshape(mu_imag, broadcast_mu_shape)

        if self.scale:
            Vrr = tf.reduce_mean(
                tf.math.square(inputs_real),
                axis=reduction_axes
            ) + self.epsilon
            Vii = tf.reduce_mean(
                tf.math.square(inputs_imag),
                axis=reduction_axes
            ) + self.epsilon
            # Vri contains the real and imaginary covariance for each feature map.
            Vri = tf.reduce_mean(
                inputs_imag * inputs_real,
                axis=reduction_axes,
            )
        else:
            Vrr = None
            Vii = None
            Vri = None
        return mu_real, mu_imag, Vrr, Vri, Vii

    def call(self, inputs, training=None):
        """Normalize `inputs` with batch or moving statistics.

        Batch statistics are used while training and the moving averages
        otherwise; in the training path the moving averages are also
        registered for update through `add_update`.

        Args:
            inputs: Complex input tensor.
            training: Optional training flag; resolved by
                `_get_training_value`.

        Returns:
            The complex tensor produced by `ComplexBN`.
        """
        training = self._get_training_value(training)
        input_shape = inputs.shape.as_list()
        inputs_real = tf.math.real(inputs)
        inputs_imag = tf.math.imag(inputs)

        broadcast_mu_shape = [1] * len(input_shape)
        broadcast_mu_shape[self.axis] = input_shape[self.axis]

        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean_real, mean_imag = self.moving_mean_real, self.moving_mean_imag
            Vrr, Vri, Vii = self.moving_Vrr, self.moving_Vri, self.moving_Vii
        else:
            (batch_mean_real, batch_mean_imag,
             batch_Vrr, batch_Vri, batch_Vii) = self.complex_moments(inputs)

            mean_real = mean_imag = None
            Vrr = Vri = Vii = None
            update_list = []
            # Statistics that are disabled are None, so only the enabled
            # ones go through `smart_cond` and the moving-average update.
            if self.center:
                mean_real = tf_utils.smart_cond(
                    training, lambda: batch_mean_real, lambda: self.moving_mean_real)
                mean_imag = tf_utils.smart_cond(
                    training, lambda: batch_mean_imag, lambda: self.moving_mean_imag)
                update_list.append(K.moving_average_update(self.moving_mean_real, mean_real, self.momentum))
                update_list.append(K.moving_average_update(self.moving_mean_imag, mean_imag, self.momentum))
            if self.scale:
                Vrr = tf_utils.smart_cond(training, lambda: batch_Vrr, lambda: self.moving_Vrr)
                Vri = tf_utils.smart_cond(training, lambda: batch_Vri, lambda: self.moving_Vri)
                Vii = tf_utils.smart_cond(training, lambda: batch_Vii, lambda: self.moving_Vii)
                update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum))
            self.add_update(update_list)

        if self.center:
            input_centred_real = inputs_real - K.reshape(mean_real, broadcast_mu_shape)
            input_centred_imag = inputs_imag - K.reshape(mean_imag, broadcast_mu_shape)
        else:
            # Without centring the second moments are taken about zero, so
            # the input must not be shifted either.
            input_centred_real = inputs_real
            input_centred_imag = inputs_imag

        input_bn = ComplexBN(
            input_centred_real, input_centred_imag, Vrr, Vii, Vri,
            self.beta_real, self.beta_imag, self.gamma_rr, self.gamma_ri,
            self.gamma_ii, self.scale, self.center,
            axis=self.axis)

        return input_bn

    def get_config(self):
        """Return the layer configuration merged with the base config.

        Initializers, regularizers and constraints are stored as the objects
        given to the constructor, not in serialized form.
        """
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': self.beta_initializer,
            'gamma_diag_initializer': self.gamma_diag_initializer,
            'gamma_off_initializer': self.gamma_off_initializer,
            'moving_mean_initializer': self.moving_mean_initializer,
            'moving_variance_initializer': self.moving_variance_initializer,
            'moving_covariance_initializer': self.moving_covariance_initializer,
            'beta_regularizer': self.beta_regularizer,
            'gamma_diag_regularizer': self.gamma_diag_regularizer,
            'gamma_off_regularizer': self.gamma_off_regularizer,
            'beta_constraint': self.beta_constraint,
            'gamma_diag_constraint': self.gamma_diag_constraint,
            'gamma_off_constraint': self.gamma_off_constraint,
        }
        base_config = super(ComplexBatchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))




