import numpy as np


def complex_standardization(inputs, *, epsilon=1e-6):
    """Whiten a complex array per feature along its last axis.

    The real and imaginary parts of each feature (an index along the last
    axis) are treated as a 2-D random vector. Its 2x2 second-moment matrix

        V = [ Vrr  Vri ]
            [ Vri  Vii ]

    is estimated by averaging over every axis except the last, and the
    inputs are multiplied by W = inv(sqrt(V)) so that the result has
    (approximately) identity second moments.

    No mean is subtracted here: the moments are taken about zero, so pass
    already-centred data if zero-mean whitening is what you need.

    Parameters
    ----------
    inputs : array_like
        Array with at least two dimensions; the last axis indexes features
        and all leading axes are pooled when estimating the moments.
    epsilon : float, optional
        Small constant added to ``s * t`` before inversion to avoid a
        division by zero when V is singular. Defaults to ``1e-6``.

    Returns
    -------
    numpy.ndarray
        Complex array with the same shape as ``inputs``.

    Raises
    ------
    ValueError
        If ``inputs`` has fewer than two dimensions.
    """
    inputs = np.asarray(inputs)
    if inputs.ndim < 2:
        raise ValueError(
            "inputs must have at least 2 dimensions (samples..., features); "
            f"got shape {inputs.shape}"
        )

    # Pool statistics over everything but the trailing feature axis. For a
    # 3-D input this is exactly axis=(0, 1).
    reduce_axes = tuple(range(inputs.ndim - 1))

    real_part = np.real(inputs)
    imag_part = np.imag(inputs)

    # Entries of the per-feature second-moment matrix V, each of shape
    # (features,).
    Vrr = np.mean(real_part ** 2, axis=reduce_axes)
    Vii = np.mean(imag_part ** 2, axis=reduce_axes)
    Vri = np.mean(real_part * imag_part, axis=reduce_axes)

    # We need inv(sqrt(V)). The square root is taken first because it yields
    # the determinant that the closed-form 2x2 inverse needs.
    # tau = trace(V), delta = det(V); both are >= 0 for a PSD matrix.
    tau = Vrr + Vii
    # Rounding can push a (near-)singular determinant slightly below zero,
    # which would turn the square root into NaN, so clamp it.
    delta = np.maximum((Vrr * Vii) - (Vri ** 2), 0.0)

    s = np.sqrt(delta)  # determinant of sqrt(V)
    t = np.sqrt(tau + 2.0 * s)

    # sqrt(V) = (1/t) [ Vrr+s  Vri   ]
    #                 [ Vri    Vii+s ]
    # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
    # and, with det(sqrt(V)) = s, the analytic 2x2 inverse gives
    # W = inv(sqrt(V)) = (1/(s*t)) [  Vii+s  -Vri   ]
    #                              [ -Vri     Vrr+s ]
    inverse_st = 1.0 / ((s * t) + epsilon)
    Wrr = (Vii + s) * inverse_st
    Wii = (Vrr + s) * inverse_st
    Wri = -Vri * inverse_st

    # x_normalized = W @ [x_real, x_imag]. The (features,) weights broadcast
    # against the trailing axis of the inputs without any reshaping.
    outputs_real = Wrr * real_part + Wri * imag_part
    outputs_imag = Wri * real_part + Wii * imag_part

    return outputs_real + outputs_imag * 1j


# Demo: build a (2, 10, 10) complex array whose real part is the ramp 0..199
# and whose imaginary part is the same ramp shuffled in place, then
# standardize it and print the result.
a = np.arange(200).reshape((2, 10, 10))
b = np.arange(200).reshape((2, 10, 10))
np.random.shuffle(b)  # in-place shuffle driven by the unseeded global RNG
inputs = a + 1j * b
stand_inputs = complex_standardization(inputs)
print(stand_inputs)
