"""
    Implementation of the complex memory cells used in the paper.
    Including:
        1.) The original URNN-cell.
        2.) Our cgRNN-cell.
"""
import collections
import numpy as np
import tensorflow as tf
from tensorflow import random_uniform_initializer as urnd_init
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
from network.custom_regularizers import complex_dropout
from IPython.core.debugger import Tracer
debug_here = Tracer()

_URNNStateTuple = collections.namedtuple("URNNStateTuple", ("o", "h"))


class URNNStateTuple(_URNNStateTuple):
    """Tuple used by URNN Cells for `state_size`, `zero_state`, and output state.
       Stores two elements: `(c, h)`, in that order.
       Only used when `state_is_tuple=True`.
    """
    slots__ = ()

    @property
    def dtype(self):
        (c, h) = self
        if c.dtype != h.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" %
                            (str(c.dtype), str(h.dtype)))
        return c.dtype


def hilbert(xr):
    '''
    Implements the hilbert transform, a mapping from C to R.
    Args:
        xr: The input sequence.
    Returns:
        xc: A complex sequence of the same length.
    '''
    with tf.variable_scope('hilbert_transform'):
        n = tf.Tensor.get_shape(xr).as_list()[0]
        # Run the fft on the columns no the rows.
        x = tf.transpose(tf.fft(tf.transpose(xr)))
        h = np.zeros([n])
        if n > 0 and 2*np.fix(n/2) == n:
            # even and nonempty
            h[0:int(n/2+1)] = 1
            h[1:int(n/2)] = 2
        elif n > 0:
            # odd and nonempty
            h[0] = 1
            h[1:int((n+1)/2)] = 2
        tf_h = tf.constant(h, name='h', dtype=tf.float32)
        if len(x.shape) == 2:
            hs = np.stack([h]*x.shape[-1], -1)
            reps = tf.Tensor.get_shape(x).as_list()[-1]
            hs = tf.stack([tf_h]*reps, -1)
        elif len(x.shape) == 1:
            hs = tf_h
        else:
            raise NotImplementedError
        tf_hc = tf.complex(hs, tf.zeros_like(hs))
        xc = x*tf_hc
        return tf.transpose(tf.ifft(tf.transpose(xc)))


def unitary_init(shape, dtype=tf.float32, partition_info=None):
    '''
    Initialize using an unitary matrix, generated by using an SVD and
    multiplying UV, while discarding the singular value matrix.
    '''
    limit = np.sqrt(6 / (shape[0] + shape[1]))
    rand_r = np.random.uniform(-limit, limit, shape[0:2])
    rand_i = np.random.uniform(-limit, limit, shape[0:2])
    crand = rand_r + 1j*rand_i
    u, s, vh = np.linalg.svd(crand)
    # use u and vg to create a unitary matrix:
    unitary = np.matmul(u, np.transpose(np.conj(vh)))

    test_eye = np.matmul(np.transpose(np.conj(unitary)), unitary)
    print('I - Wi.H Wi', np.linalg.norm(test_eye) - unitary)
    # test
    # plt.imshow(np.abs(np.matmul(unitary, np.transpose(np.conj(unitary))))); plt.show()
    stacked = np.stack([np.real(unitary), np.imag(unitary)], -1)
    assert stacked.shape == tuple(shape), "Unitary initialization shape mismatch."
    # debug_here()
    return tf.constant(stacked, dtype)


def arjovski_init(shape, dtype=tf.float32, partition_info=None):
    '''
    Use Arjovsky's unitary basis as initialization.
    Reference:
         Arjovsky et al. Unitary Evolution Recurrent Neural Networks
         https://arxiv.org/abs/1511.06464
    '''
    print("Arjosky basis initialization.")
    assert shape[0] == shape[1]
    omega1 = np.random.uniform(-np.pi, np.pi, shape[0])
    omega2 = np.random.uniform(-np.pi, np.pi, shape[0])
    omega3 = np.random.uniform(-np.pi, np.pi, shape[0])

    vr1 = np.random.uniform(-1, 1, [shape[0], 1])
    vi1 = np.random.uniform(-1, 1, [shape[0], 1])
    v1 = vr1 + 1j*vi1
    vr2 = np.random.uniform(-1, 1, [shape[0], 1])
    vi2 = np.random.uniform(-1, 1, [shape[0], 1])
    v2 = vr2 + 1j*vi2

    D1 = np.diag(np.exp(1j*omega1))
    D2 = np.diag(np.exp(1j*omega2))
    D3 = np.diag(np.exp(1j*omega3))

    vvh1 = np.matmul(v1, np.transpose(np.conj(v1)))
    beta1 = 2./np.matmul(np.transpose(np.conj(v1)), v1)
    R1 = np.eye(shape[0]) - beta1*vvh1

    vvh2 = np.matmul(v2, np.transpose(np.conj(v2)))
    beta2 = 2./np.matmul(np.transpose(np.conj(v2)), v2)
    R2 = np.eye(shape[0]) - beta2*vvh2

    perm = np.random.permutation(np.eye(shape[0], dtype=np.float32)) \
        + 1j*np.zeros(shape[0])

    fft = np.fft.fft
    ifft = np.fft.ifft

    step1 = fft(D1)
    step2 = np.matmul(R1, step1)
    step3 = np.matmul(perm, step2)
    step4 = np.matmul(D2, step3)
    step5 = ifft(step4)
    step6 = np.matmul(R2, step5)
    unitary = np.matmul(D3, step6)
    eye_test = np.matmul(np.transpose(np.conj(unitary)), unitary)
    unitary_test = np.linalg.norm(np.eye(shape[0]) - eye_test)
    print('I - Wi.H Wi', unitary_test, unitary.dtype)
    assert unitary_test < 1e-10, "Unitary initialization not unitary enough."
    stacked = np.stack([np.real(unitary), np.imag(unitary)], -1)
    assert stacked.shape == tuple(shape), "Unitary initialization shape mismatch."
    return tf.constant(stacked, dtype)


def mod_relu(z, scope='', reuse=None):
    """
        Implementation of the modRelu from Arjovski et al.
        f(z) = relu(|z| + b)(z / |z|) or
        f(r,theta) = relu(r + b)e^(i*theta)
        b is initialized to zero, this leads to a network, which
        is linear during early optimization.
    Input:
        z: complex input.
        b: 'dead' zone radius.
    Returns:
        z_out: complex output.

    Reference:
         Arjovsky et al. Unitary Evolution Recurrent Neural Networks
         https://arxiv.org/abs/1511.06464
    """
    with tf.variable_scope('mod_relu' + scope, reuse=reuse):
        b = tf.get_variable('b', [], dtype=tf.float32,
                            initializer=urnd_init(-0.01, 0.01))
        modulus = tf.sqrt(tf.real(z)**2 + tf.imag(z)**2)
        rescale = tf.nn.relu(modulus + b) / (modulus + 1e-6)
        # return tf.complex(rescale * tf.real(z),
        #                   rescale * tf.imag(z))
        rescale = tf.complex(rescale, tf.zeros_like(rescale))
        return tf.multiply(rescale, z)


def relu(x, scope='', reuse=None):
    """ Relu wrapper """
    return tf.nn.relu(x)


def tanh(x, scope='', reuse=None):
    """ tanh wrapper """
    return tf.nn.tanh(x)


def split_relu(z, scope='', reuse=None):
    """ A split relu applying relus on the real and
        imaginary parts separately following
        Trabelsi https://arxiv.org/abs/1705.09792"""
    with tf.variable_scope('split_relu' + scope):
        x = tf.real(z)
        y = tf.imag(z)
        return tf.complex(tf.nn.relu(x), tf.nn.relu(y))


def z_relu(z, scope='', reuse=None):
    """
    The z-relu, which is active only
    in the first quadrant.
    As proposed by Guberman:
    https://arxiv.org/abs/1602.09046
    """
    with tf.variable_scope('z_relu'):
        factor1 = tf.cast(tf.real(z) > 0, tf.float32)
        factor2 = tf.cast(tf.imag(z) > 0, tf.float32)
        combined = factor1*factor2
        rescale = tf.complex(combined, tf.zeros_like(combined))
        return tf.multiply(rescale, z)


def hirose(z, scope='', reuse=None):
    """
    Compute the non-linearity proposed by Hirose.
    See for example:
    Complex Valued nonlinear Adaptive Filters
    Mandic and Su Lee Goh
    Chapter 4.3.1 (Amplitude-Phase split complex approach)
    """
    with tf.variable_scope('hirose' + scope, reuse=reuse):
        m = tf.get_variable('m', [], tf.float32,
                            initializer=urnd_init(0.9, 1.1))
        modulus = tf.sqrt(tf.real(z)**2 + tf.imag(z)**2)
        # use m*m to enforce positive m.
        rescale = tf.complex(tf.nn.tanh(modulus/(m*m))/modulus,
                             tf.zeros_like(modulus))
        return tf.multiply(rescale, z)


def single_sigmoid_real(z, scope='', reuse=None):
    """
    What happens if we throw the imaginary part away?
    Problem: Half of the weights don't contribute.
    """
    with tf.variable_scope('sigmoid_real_' + scope, reuse=reuse):
        rz = tf.nn.sigmoid(tf.real(z))
        return tf.complex(rz, tf.zeros_like(rz))


def single_sigmoid_imag(z, scope='', reuse=None):
    """
    What happens if we throw the real part away?
    Problem: Half of the weights don't contribute.
    """
    with tf.variable_scope('sigmoid_imag_' + scope, reuse=reuse):
        iz = tf.nn.sigmoid(tf.imag(z))
        return tf.complex(iz, tf.zeros_like(iz))


def mod_sigmoid(z, scope='', reuse=None):
    """
    ModSigmoid implementation, using a coupled alpha and beta.
    """
    with tf.variable_scope('mod_sigmoid_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        alpha_norm = tf.nn.sigmoid(alpha)
        pre_act = alpha_norm * tf.real(z) + (1 - alpha_norm)*tf.imag(z)
        return tf.complex(tf.nn.sigmoid(pre_act), tf.zeros_like(pre_act))


def mod_sigmoid_beta(z, scope='', reuse=None):
    """
    ModSigmoid implementation. Alpha and beta and beta are uncoupled
    and constrained to (0, 1).
    """
    with tf.variable_scope('mod_sigmoid_beta_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        beta = tf.get_variable('beta', [], dtype=tf.float32,
                               initializer=tf.constant_initializer(1.0))
        alpha_norm = tf.nn.sigmoid(alpha)
        beta_norm = tf.nn.sigmoid(beta)
        pre_act = alpha_norm * tf.real(z) + beta_norm*tf.imag(z)
        return tf.complex(tf.nn.sigmoid(pre_act), tf.zeros_like(pre_act))


def real_mod_sigmoid_beta(z, scope='', reuse=None):
    """
    Real valued ModSigmoid implementation, alpha and beta and beta are uncoupled
    and constrained to (0, 1).
    """
    with tf.variable_scope('real_mod_sigmoid_beta_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        beta = tf.get_variable('beta', [], dtype=tf.float32,
                               initializer=tf.constant_initializer(1.0))
        alpha_norm = tf.nn.sigmoid(alpha)
        beta_norm = tf.nn.sigmoid(beta)
        pre_act = alpha_norm*z[0] + beta_norm*z[1]
        return tf.nn.sigmoid(pre_act)


def mod_sigmoid_gamma(z, scope='', reuse=None):
    """
    ModSigmoid implementation, with uncoupled and unbounded
    alpha and beta.
    """
    with tf.variable_scope('mod_sigmoid_beta_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        beta = tf.get_variable('beta', [], dtype=tf.float32,
                               initializer=tf.constant_initializer(1.0))
        pre_act = alpha * tf.real(z) + beta*tf.imag(z)
        return tf.complex(tf.nn.sigmoid(pre_act), tf.zeros_like(pre_act))


def mod_sigmoid_prod(z, scope='', reuse=None):
    """
    Product version of the mod sigmoid.
    """
    with tf.variable_scope('mod_sigmoid_prod_' + scope, reuse=reuse):
        prod = tf.nn.sigmoid(tf.real(z)) * tf.nn.sigmoid(tf.imag(z))
        return tf.complex(prod, tf.zeros_like(prod))


def mod_sigmoid_sum(z, scope='', reuse=None):
    """
    Use a weighted sum outside of the sigmoid, coupled weights.
    """
    with tf.variable_scope('mod_sigmoid_sum_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        sig_alpha = tf.nn.sigmoid(alpha)
        sig_sum = (sig_alpha*tf.nn.sigmoid(tf.real(z))
                   + (1.0 - sig_alpha) * tf.nn.sigmoid(tf.imag(z)))
        return tf.complex(sig_sum, tf.zeros_like(sig_sum))


def mod_sigmoid_sum_beta(z, scope='', reuse=None):
    """ Probably not a good idea, gate outputs are not constrained
    to (0, 1)."""
    with tf.variable_scope('mod_sigmoid_sum_beta_' + scope, reuse=reuse):
        alpha = tf.get_variable('alpha', [], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0))
        beta = tf.get_variable('beta', [], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))
        sig_alpha = tf.nn.sigmoid(alpha)
        sig_beta = tf.nn.sigmoid(beta)
        sig_sum = (sig_alpha*tf.nn.sigmoid(tf.real(z))
                   + sig_beta * tf.nn.sigmoid(tf.imag(z)))
        return tf.complex(sig_sum, tf.zeros_like(sig_sum))


def mod_sigmoid_split(z, scope='', reuse=None):
    """
    ModSigmoid implementation applying a sigmoid on the imaginary
    and real parts separately.
    Problem: Gate outputs are not always real. This did not work very
    well.
    """
    with tf.variable_scope('mod_sigmoid_split_' + scope, reuse=reuse):
        return tf.complex(tf.nn.sigmoid(tf.real(z)), tf.nn.sigmoid(tf.imag(z)))


def gate_phase_hirose(z, scope='', reuse=None):
    '''
    Hirose inspired gate activation filtering according to
    phase angle.
    '''
    with tf.variable_scope('phase_hirose_' + scope, reuse=reuse):
        m = tf.get_variable('m', [], tf.float32,
                            initializer=urnd_init(0.9, 1.1))
        a = tf.get_variable('a', [], tf.float32,
                            initializer=urnd_init(1.9, 2.1))
        b = tf.get_variable('b', [], tf.float32, urnd_init(3.9, 4.1))
        modulus = tf.sqrt(tf.real(z)**2 + tf.imag(z)**2)
        phase = tf.atan2(tf.imag(z), tf.real(z))
        gate = tf.tanh(modulus/(m*m)) * tf.nn.sigmoid(a*phase + b)
        return tf.complex(gate, tf.zeros_like(gate))


def moebius(z, scope='', reuse=None):
    """
    Implements a learn-able Moebius transformation.
    It converges very fast sometimes, but is often unstable.
    For this to work we need some way to constrain this to make
    sure the optimizer stays away from the singularities.
    """
    with tf.variable_scope('moebius' + scope, reuse=reuse):
        ar = tf.get_variable('ar', [], tf.float32,
                             initializer=tf.constant_initializer(1))
        ai = tf.get_variable('ai', [], tf.float32,
                             initializer=tf.constant_initializer(0))
        b = tf.get_variable('b', [2], tf.float32,
                            initializer=tf.constant_initializer(0))
        c = tf.get_variable('c', [2], tf.float32,
                            initializer=tf.constant_initializer(0))
        dr = tf.get_variable('dr', [], tf.float32,
                             initializer=tf.constant_initializer(1))
        di = tf.get_variable('di', [], tf.float32,
                             initializer=tf.constant_initializer(0))

        a = tf.complex(ar, ai)
        b = tf.complex(b[0], b[1])
        c = tf.complex(c[0], c[1])
        d = tf.complex(dr, di)
        return tf.divide(tf.multiply(a, z) + b,
                         tf.multiply(c, z) + d)


def linear(z, scope='', reuse=None, coupled=False):
    """
    Use this function to create a linear cell "activation" for comparison.
    """
    return z


def rfl_mul(h, state_size, no, reuse):
    """
    Multiplication with a reflection.
    Implementing R = I - (vv*/|v|^2)
    Input:
        h: hidden state_vector.
        state_size: The RNN state size.
        reuse: True if graph variables should be reused.
    Returns:
        R*h
    """
    with tf.variable_scope("reflection_v_" + str(no), reuse=reuse):
        vr = tf.get_variable('vr', shape=[state_size, 1], dtype=tf.float32,
                             initializer=tf.glorot_uniform_initializer())
        vi = tf.get_variable('vi', shape=[state_size, 1], dtype=tf.float32,
                             initializer=tf.glorot_uniform_initializer())

    with tf.variable_scope("ref_mul_" + str(no), reuse=reuse):
        hr = tf.real(h)
        hi = tf.imag(h)
        vstarv = tf.reduce_sum(vr**2 + vi**2)
        hr_vr = tf.matmul(hr, vr)
        hr_vi = tf.matmul(hr, vi)
        hi_vr = tf.matmul(hi, vr)
        hi_vi = tf.matmul(hi, vi)

        # tf.matmul with transposition is the same as T.outer
        # we need something of the shape [batch_size, state_size] in the end
        a = tf.matmul(hr_vr - hi_vi, vr, transpose_b=True)
        b = tf.matmul(hr_vi + hi_vr, vi, transpose_b=True)
        c = tf.matmul(hr_vr - hi_vi, vi, transpose_b=True)
        d = tf.matmul(hr_vi + hi_vr, vr, transpose_b=True)

        # the thing we return is:
        # return_re = hr - (2/vstarv)(d - c)
        # return_im = hi - (2/vstarv)(a + b)
        new_hr = hr - (2.0 / vstarv) * (a + b)
        new_hi = hi - (2.0 / vstarv) * (d - c)
        new_state = tf.complex(new_hr, new_hi)

        # v = tf.complex(vr, vi)
        # vstarv = tf.complex(tf.reduce_sum(vr**2 + vi**2), 0.0)
        # # vstarv = tf.matmul(tf.transpose(tf.conj(v)), v)
        # vvstar = tf.matmul(v, tf.transpose(tf.conj(v)))
        # refsub = (2.0/vstarv)*vvstar
        # R = tf.identity(refsub) - refsub
        return new_state


def diag_mul(h, state_size, no, reuse):
    """
    Multiplication with a diagonal matrix.
    Input:
        h: hidden state_vector.
        state_size: The RNN state size.
        reuse: True if graph variables should be reused.
    Returns:
        R*h
    """
    with tf.variable_scope("diag_phis_" + str(no), reuse=reuse):
        omega = tf.get_variable('vr', shape=[state_size], dtype=tf.float32,
                                initializer=urnd_init(-np.pi, np.pi))
        dr = tf.cos(omega)
        di = tf.sin(omega)

    with tf.variable_scope("diag_mul_" + str(no)):
        D = tf.diag(tf.complex(dr, di))
        return tf.matmul(h, D)


def permutation(h, state_size, no, reuse):
    """
    Apply a random permutation to the RNN state.
    Input:
        h: the original state.
    Output:
        hp: the permuted state.
    """
    with tf.variable_scope("permutation_" + str(no), reuse):
        def _initializer(shape, dtype=np.float32, partition_info=None):
            return np.random.permutation(np.eye(state_size, dtype=np.float32))
        Pr = tf.get_variable("Permutation", dtype=tf.float32,
                             initializer=_initializer, shape=[state_size],
                             trainable=False)
        P = tf.complex(Pr, tf.constant(0.0, dtype=tf.float32))
    return tf.matmul(h, P)


def matmul_plus_bias(x, num_proj, scope, reuse, bias=True,
                     bias_init=0.0, orthogonal=False):
    """
    Compute Ax + b.
    Arguments:
        x: A real (!) input vector.
        num_proj: The desired dimension of the output.
        scope: This string under which the variables will be
               registered.
        reuse: If this bool is True, the variables will be reused.
        bias: If True a bias will be added.
        bias_init: How to initialize the bias, defaults to zero.
        orthogonal: If true A will be initialized orthogonally
                    and kept orthogonal (make sure to use the
                    Stiefel optimizer if orthogonality is desired).
    Returns:
        Ax + b: A vector of size [batch_size, num_proj]
    """
    in_shape = tf.Tensor.get_shape(x).as_list()
    with tf.variable_scope(scope, reuse=reuse):
        if orthogonal:
            with tf.variable_scope('orthogonal_stiefel', reuse=reuse):
                A = tf.get_variable('gate_O', [in_shape[-1], num_proj],
                                    dtype=tf.float32,
                                    initializer=tf.orthogonal_initializer())
        else:
            A = tf.get_variable('A', [in_shape[-1], num_proj], dtype=tf.float32,
                                initializer=tf.glorot_uniform_initializer())
        if bias:
            b = tf.get_variable('bias', [num_proj], dtype=tf.float32,
                                initializer=tf.constant_initializer(bias_init))
            print('Initializing', tf.contrib.framework.get_name_scope(), 'bias to',
                  bias_init)
            return tf.matmul(x, A) + b
        else:
            return tf.matmul(x, A)


def complex_matmul(x, num_proj, scope, reuse, bias=False, bias_init_r=0.0,
                   bias_init_i=0.0, unitary=False, split_orthogonal=False,
                   unitary_init=arjovski_init):
    """
    Compute Ax + b.
    Arguments:
        x: A complex input vector.
        num_proj: The desired dimension of the output.
        scope: This string under which the variables will be
               registered.
        reuse: If this bool is True, the variables will be reused.
        bias: If True a bias will be added.
        bias_init_r: How to initialize the real part of the bias, defaults to zero.
        bias_init_i: How to initialize the imaginary part of the bias, defaults to zero.
        split_orthogonal: If true A's real and imaginary parts will be
                    initialized orthogonally and kept orthogonal (make sure to use the
                    Stiefel optimizer if orthogonality is desired).
        unitary: If true A will be initialized and kept in a unitary state
                 (make sure to use the Stiefel optimizer)
        unitary_init: The initialization method for the unitary matrix.
    Returns:
        Ax + b: A vector of size [batch_size, num_proj]

    WARNING:
    Simply setting split_orthogonal or unitary to True is not enough.
    Use the Stiefel optimizer as well to enforce orthogonality/unitarity.
    """
    in_shape = tf.Tensor.get_shape(x).as_list()
    with tf.variable_scope(scope, reuse=reuse):
        if unitary:
            with tf.variable_scope('unitary_stiefel', reuse=reuse):
                varU = tf.get_variable('gate_U',
                                       shape=in_shape[-1:] + [num_proj] + [2],
                                       dtype=tf.float32,
                                       initializer=unitary_init)
                A = tf.complex(varU[:, :, 0], varU[:, :, 1])
        elif split_orthogonal:
            with tf.variable_scope('orthogonal_stiefel', reuse=reuse):
                Ar = tf.get_variable('gate_Ur', in_shape[-1:] + [num_proj],
                                     dtype=tf.float32,
                                     initializer=tf.orthogonal_initializer())
                Ai = tf.get_variable('gate_Ui', in_shape[-1:] + [num_proj],
                                     dtype=tf.float32,
                                     initializer=tf.orthogonal_initializer())
                A = tf.complex(Ar, Ai)
        else:
            varU = tf.get_variable('gate_A',
                                   shape=in_shape[-1:] + [num_proj] + [2],
                                   dtype=tf.float32,
                                   initializer=tf.glorot_uniform_initializer())
            A = tf.complex(varU[:, :, 0], varU[:, :, 1])
        if bias:
            varbr = tf.get_variable('bias_r', [num_proj], dtype=tf.float32,
                                    initializer=tf.constant_initializer(bias_init_r))
            varbc = tf.get_variable('bias_c', [num_proj], dtype=tf.float32,
                                    initializer=tf.constant_initializer(bias_init_i))
            b = tf.complex(varbr, varbc)
            return tf.matmul(x, A) + b
        else:
            return tf.matmul(x, A)


def C_to_R(h, num_proj, reuse, scope=None, bias_init=0.0):
    '''
    Linear mapping from the complex numbers to the reals.
    See Arjovski https://arxiv.org/pdf/1511.06464.pdf (eq. 9)
    for reference.

    Arguments:
        h: The hidden input representation.
        num_proj: The desired dimension of the real output.
        reuse: If True variables will be reused.
    Returns:
        A real output vector [batch_size, num_proj]
    '''
    with tf.variable_scope(scope or "C_to_R"):
        concat = tf.concat([tf.real(h), tf.imag(h)], axis=-1)
        return matmul_plus_bias(concat, num_proj, 'final', reuse, bias_init)


class UnitaryCell(tf.nn.rnn_cell.RNNCell):
    """
    Tensorflow implementation of unitary evolution RNN as proposed by Arjosky et al.
    https://arxiv.org/pdf/1511.06464.pdf
    and extended by Wisdom et al. https://arxiv.org/abs/1611.00035
    """
    def __init__(self, num_units, activation=mod_relu, num_proj=None, reuse=None,
                 real=False, arjovski_basis=False):
        """
        Create a unitary gate-free RNN cell.
        Arguments:
            num_units: Cell size
            activation: The non-linear activation function.
            num_proj: Desired dimension of the cell output.
            reuse: If true the cell's variables will be reused.
            real: If true the cell expects a real-valued input.
                  Please note that you will also have to use
                  a real activation function, i.e. the Relu.
            arjovski_basis: If True, work with the unitary
                            matrix parametrization
                            of Arjovski et al.
        """
        super().__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation
        self._output_size = num_proj
        self._arjovski_basis = arjovski_basis
        self._real = real

    def to_string(self):
        cell_str = 'UnitaryCell' + '_' \
            + '_' + 'activation' + '_' + str(self._activation.__name__) + '_' \
            + '_arjovski_basis' + '_' + str(self._arjovski_basis) + '_' \
            + '_real_cell_' + str(self._real)
        return cell_str

    @property
    def state_size(self):
        return URNNStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        if self._output_size is None:
            return self._num_units
        else:
            return self._output_size

    def zero_state(self, batch_size, dtype=tf.float32):
        out = tf.zeros([batch_size, self._output_size], dtype=tf.float32)
        if self._real:
            rnd = tf.random_uniform([batch_size, self._num_units],
                                    minval=0, maxval=2)
            first_state = rnd/tf.norm(rnd)
        else:
            omegas = tf.random_uniform([batch_size, self._num_units],
                                       minval=0, maxval=2*np.pi)
            sx = tf.cos(omegas)
            sy = tf.sin(omegas)
            r = (1.0)/np.sqrt(self._num_units)
            first_state = tf.complex(r*sx, r*sy)
        return URNNStateTuple(out, first_state)

    def call(self, inputs, state):
        """
            Evaluate the RNN cell. Using
            h_(t+1) = U_t*f(h_t) + V_t x_t
        """
        # with tf.variable_scope("UnitaryCell"):
        last_out, last_h = state
        if self._real:
            with tf.variable_scope("orthogonal_stiefel"):
                matO = tf.get_variable("recurrent_O",
                                       shape=[self._num_units, self._num_units],
                                       dtype=tf.float32,
                                       initializer=tf.orthogonal_initializer())
                Uh = tf.matmul(last_h, matO)
        elif self._arjovski_basis:
            with tf.variable_scope("arjovski_basis", reuse=self._reuse):
                step1 = diag_mul(last_h, self._num_units, 0, self._reuse)
                step2 = tf.spectral.fft(step1)
                step3 = rfl_mul(step2, self._num_units, 0, self._reuse)
                step4 = permutation(step3, self._num_units, 0, self._reuse)
                step5 = diag_mul(step4, self._num_units, 1, self._reuse)
                step6 = tf.spectral.ifft(step5)
                step7 = rfl_mul(step6, self._num_units, 1, self._reuse)
                Uh = diag_mul(step7, self._num_units, 2, self._reuse)
        else:
            with tf.variable_scope("unitary_stiefel", reuse=self._reuse):
                varU = tf.get_variable("recurrent_U",
                                       shape=[self._num_units, self._num_units, 2],
                                       dtype=tf.float32,
                                       initializer=arjovski_init)
                U = tf.complex(varU[:, :, 0], varU[:, :, 1])
                # U = tf.Print(U, [U])
                Uh = tf.matmul(last_h, U)

        # Deal with the inputs.
        if self._real:
            Vx = matmul_plus_bias(inputs, self._num_units, 'Vx', self._reuse)
        else:
            cin = tf.complex(inputs, tf.zeros_like(inputs))
            Vx = complex_matmul(cin, self._num_units, 'Vx', self._reuse, bias=True)

        zt = Uh + Vx
        ht = self._activation(zt, '', self._reuse)

        # Mapping the state back onto the real axis.
        # By mapping.
        if not self._real:
            output = C_to_R(ht, self._output_size, reuse=self._reuse)
        else:
            output = matmul_plus_bias(ht, self._output_size, 'final', self._reuse, 0.0)

        newstate = URNNStateTuple(output, ht)
        return output, newstate


class StiefelGatedRecurrentUnit(tf.nn.rnn_cell.RNNCell):
    '''
    Implementation of a Stiefel Gated Recurrent unit.
    '''

    def __init__(self, num_units, activation=mod_relu,
                 gate_activation=mod_sigmoid,
                 num_proj=None, reuse=None, stiefel=True,
                 real=False, real_double=False,
                 complex_input=False, dropout=False,
                 single_gate=False, arjovski_basis=False):
        """
        Arguments:
            num_units: The size of the hidden state.
            activation: State to state non-linearity.
            gate_activation: The gating non-linearity.
            num_proj: Output dimension.
            reuse: Reuse graph weights in existing scope.
            stiefel: If True the cell will be used using the Stiefel
                     optimization scheme from Wisdom et al.
            real: If true a real valued cell will be created.
            real_double: Use a double real gate similar to to
                         the complex version (for comparison only).
            complex_input: If true the cell expects a complex input.
            arjovski_basis: If true Arjovski et al.'s parameterization
                            is used for the state transition matrix.
        """
        super().__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation
        # self._state_to_state_act = linear
        self._output_size = num_proj
        self._arjovski_basis = arjovski_basis
        self._input_hilbert = False  # We did not end up using this in the paper.
        self._input_split_matmul = False  # We did not end up using this in the paper.
        self._stiefel = stiefel
        self._gate_activation = gate_activation
        self._single_gate = single_gate
        self._real = real
        self._real_double = False  # Real "bilinear" double gating for comparison only.
        self._complex_input = complex_input
        self._dropout = dropout

    def to_string(self):
        cell_str = 'StiefelGatedRecurrentUnit' + '_' \
            + '_' + 'activation' + '_' + str(self._activation.__name__) + '_'
        if self._input_hilbert:
            cell_str += '_input_hilbert_'
        elif self._input_split_matmul:
            cell_str += '__input_split_matmul_'
        cell_str += '_stiefel_' + str(self._stiefel)
        if self._real is False and self._single_gate is False:
            cell_str += '_gate_activation_' + self._gate_activation.__name__
        if self._single_gate:
            cell_str += '_single_gate_'
        if self._real:
            cell_str += '_real_'
            if self._real_double:
                cell_str += '_realDouble_'
                cell_str += '_gate_activation_' + self._gate_activation.__name__
        return cell_str

    @property
    def state_size(self):
        return URNNStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        if self._output_size is None:
            if self._real:
                return self._num_units
            else:
                return self._num_units*2
        else:
            return self._output_size

    def zero_state(self, batch_size, dtype=tf.float32):
        if self._real:
            out = tf.zeros([batch_size, self._output_size], dtype=tf.float32)
            first_state = tf.zeros([batch_size, self._num_units])
        else:
            first_state = tf.complex(tf.zeros([batch_size, self._num_units]),
                                     tf.zeros([batch_size, self._num_units]))
            if self._output_size:
                out = tf.zeros([batch_size, self._output_size], dtype=tf.float32)
            else:
                out = tf.zeros([batch_size, self._num_units*2])
        return URNNStateTuple(out, first_state)

    def double_memory_gate(self, h, x, scope, bias_init=4.0):
        """
        Complex GRU gates, the idea is that gates should make use of phase information.
        """

        with tf.variable_scope(scope, self._reuse):
            if self._real:
                ghr = matmul_plus_bias(h, self._num_units, scope='ghr', reuse=self._reuse,
                                       bias=False)
                gxr = matmul_plus_bias(x, self._num_units, scope='gxr', reuse=self._reuse,
                                       bias=True, bias_init=bias_init)
                gr = ghr + gxr
                r = tf.nn.sigmoid(gr)
                ghz = matmul_plus_bias(h, self._num_units, scope='ghz', reuse=self._reuse,
                                       bias=False)
                gxz = matmul_plus_bias(x, self._num_units, scope='gxz', reuse=self._reuse,
                                       bias=True, bias_init=bias_init)
                gz = ghz + gxz
                z = tf.nn.sigmoid(gz)

                if self._real_double:
                    ghr2 = matmul_plus_bias(h, self._num_units, scope='ghr2',
                                            reuse=self._reuse, bias=False)
                    gxr2 = matmul_plus_bias(x, self._num_units, scope='gxr2',
                                            reuse=self._reuse, bias=True,
                                            bias_init=bias_init)
                    gr2 = ghr2 + gxr2
                    r = self._gate_activation([gr, gr2], 'r', self._reuse)
                    ghz2 = matmul_plus_bias(h, self._num_units, scope='ghz2',
                                            reuse=self._reuse, bias=False)
                    gxz2 = matmul_plus_bias(x, self._num_units, scope='gxz2',
                                            reuse=self._reuse, bias=True,
                                            bias_init=bias_init)
                    gz2 = ghz2 + gxz2
                    z = self._gate_activation([gz, gz2], 'z', self._reuse)
            else:
                ghr = complex_matmul(h, self._num_units, scope='ghr', reuse=self._reuse)
                gxr = complex_matmul(x, self._num_units, scope='gxr', reuse=self._reuse,
                                     bias=True, bias_init_i=bias_init,
                                     bias_init_r=bias_init)
                gr = ghr + gxr
                r = self._gate_activation(gr, 'r', self._reuse)
                ghz = complex_matmul(h, self._num_units, scope='ghz', reuse=self._reuse)
                gxz = complex_matmul(x, self._num_units, scope='gxz', reuse=self._reuse,
                                     bias=True, bias_init_i=bias_init,
                                     bias_init_r=bias_init)
                gz = ghz + gxz
                z = self._gate_activation(gz, 'z', self._reuse)
            return r, z

    def single_memory_gate(self, h, x, scope, bias_init):
        """
        Use the real and imaginary parts of the gate equation to do the gating.
        """
        with tf.variable_scope(scope, self._reuse):
            if self._real:
                raise ValueError('Real cells cannot be single gated.')
            else:
                ghs = complex_matmul(h, self._num_units, scope='ghs', reuse=self._reuse)
                gxs = complex_matmul(x, self._num_units, scope='gxs', reuse=self._reuse,
                                     bias=True, bias_init_c=bias_init,
                                     bias_init_r=bias_init)
                gs = ghs + gxs
                return (tf.complex(tf.nn.sigmoid(tf.real(gs)),
                                   tf.zeros_like(tf.real(gs))),
                        tf.complex(tf.nn.sigmoid(tf.imag(gs)),
                                   tf.zeros_like(tf.imag(gs))))

    def __call__(self, inputs, state):
        """
        Evaluate the cell equations.
        Params:
            inputs: The input values.
            state: the past cell state.
        Returns:
            output and new cell state touple.
        """
        with tf.variable_scope("ComplexGatedRecurrentUnit", reuse=self._reuse):
            _, last_h = state

            if not self._real:
                if not self._complex_input:
                    if self._input_hilbert:
                        cinputs = tf.complex(inputs, tf.zeros_like(inputs))
                        inputs = hilbert(cinputs)
                    elif self._input_split_matmul:
                        # Map the inputs from R to C.
                        cinr = matmul_plus_bias(inputs, self._num_units,
                                                'real', self._reuse)
                        cini = matmul_plus_bias(inputs, self._num_units,
                                                'imag', self._reuse)
                        inputs = tf.complex(cinr, cini)
                    else:
                        inputs = tf.complex(inputs, tf.zeros_like(inputs))

            if self._dropout:
                print('adding dropout!')
                inputs = complex_dropout(inputs, 0.2)

            # use open gates initially when working with stiefel optimization.
            if self._stiefel:
                bias_init = 4.0
            else:
                bias_init = 0.0

            if self._single_gate:
                r, z = self.single_memory_gate(last_h, inputs, 'single_memory_gate',
                                               bias_init=bias_init)
            else:
                r, z = self.double_memory_gate(last_h, inputs, 'double_memory_gate',
                                               bias_init=bias_init)

            with tf.variable_scope("canditate_h"):
                if self._real:
                    cinWx = matmul_plus_bias(inputs, self._num_units, 'wx', bias=False,
                                             reuse=self._reuse)
                    rhU = matmul_plus_bias(tf.multiply(r, last_h), self._num_units, 'rhu',
                                           bias=True, orthogonal=self._stiefel,
                                           reuse=self._reuse)
                    tmp = cinWx + rhU
                else:
                    cinWx = complex_matmul(inputs, self._num_units, 'wx', bias=False,
                                           reuse=self._reuse)
                    rhU = complex_matmul(tf.multiply(r, last_h), self._num_units, 'rhu',
                                         bias=True, unitary=self._stiefel,
                                         reuse=self._reuse)
                    tmp = cinWx + rhU

                h_bar = self._activation(tmp)

                if self._dropout:
                    print('adding dropout!')
                    h_bar = complex_dropout(h_bar, 0.25)

            new_h = (1 - z)*last_h + z*h_bar
            

            if self._output_size:
                print('c to r cell output mapping.')
                if self._real:
                    output = matmul_plus_bias(new_h, self._output_size, 'out_map',
                                              reuse=self._reuse)
                    print(self._output_size, 'real', output.shape)
                else:
                    output = C_to_R(new_h, self._output_size, reuse=self._reuse)
                    print(self._output_size, 'not real', output.shape)
            else:
                print('real concatinated cell output')
                output = tf.concat([tf.real(r), tf.imag(new_h)], axis=-1)
                print(self._output_size, output.shape)
            newstate = URNNStateTuple(output, new_h)
            return output, newstate


class ComplexGatedRecurrentUnit(RNNCell):
    '''
    Can we implement a complex GRU?
    '''

    def __init__(self, num_units, activation=mod_relu,
                 num_proj=None, reuse=None, single_gate=False,
                 complex_out=False):
        super().__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation
        # self._state_to_state_act = linear
        self._num_proj = num_proj
        self._arjovski_basis = False
        self._input_fourier = False
        self._input_hilbert = False
        self._input_split_matmul = False
        self._stateU = True
        self._gateO = False
        self._single_gate = single_gate
        self._gate_activation = mod_sigmoid
        self._single_gate_avg = False
        self._complex_inout = complex_out

    def to_string(self):
        cell_str = 'ComplexGatedRecurrentUnit' + '_' \
                   + '_' + 'activation' + '_' + str(self._activation.__name__) + '_'
        if self._input_fourier:
            cell_str += '_input_fourier_'
        elif self._input_hilbert:
            cell_str += '_input_hilbert_'
        elif self._input_split_matmul:
            cell_str += '__input_split_matmul_'
        cell_str += '_stateU' + '_' + str(self._stateU) \
                    + '_gateO_' + str(self._gateO) \
                    + '_singleGate_' + str(self._single_gate)
        if self._single_gate is False:
            cell_str += '_gate_activation_' + self._gate_activation.__name__
        else:
            cell_str += '_single_gate_avg_' + str(self._single_gate_avg)
        return cell_str

    @property
    def state_size(self):
        return URNNStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        if self._num_proj is None:
            return self._num_units
        else:
            if self._complex_inout:
                return self._num_proj
            else:
                return self._num_proj

    def zero_state(self, batch_size, dtype=tf.float32):
        out = tf.zeros([batch_size, self.output_size], dtype=tf.float32)
        first_state = tf.zeros([batch_size, self._num_units])
        return URNNStateTuple(out, first_state)

    def single_memory_gate(self, h, x, scope, bias_init=0.0,
                           unitary=False, orthogonal=False):
        """
        New unified gate, idea use real and imaginary outputs as gating scalars.
        """
        with tf.variable_scope(scope, self._reuse):
            gh = complex_matmul(h, int(self._num_units / 2.0), scope='gh', reuse=self._reuse,
                                unitary=unitary, orthogonal=orthogonal)
            gx = complex_matmul(x, int(self._num_units / 2.0), scope='gx', reuse=self._reuse,
                                bias=True, bias_init_r=bias_init,
                                bias_init_c=bias_init)
            g = gh + gx
            if self._single_gate_avg:
                r = mod_sigmoid_beta(g, scope='r')
                z = mod_sigmoid_beta(g, scope='z')
                return r, z
            else:
                r = tf.nn.sigmoid(tf.real(g))
                z = tf.nn.sigmoid(tf.imag(g))
                return (tf.complex(r, tf.zeros_like(r), name='r'),
                        tf.complex(z, tf.zeros_like(z), name='z'))

    def double_memory_gate(self, h, x, scope, bias_init=4.0):
        """
        Complex GRU gates, the idea is that gates should make use of phase information.
        """
        with tf.variable_scope(scope, self._reuse):
            ghr = complex_matmul(h, int(self._num_units / 2.0), scope='ghr', reuse=self._reuse)
            gxr = complex_matmul(x, int(self._num_units / 2.0), scope='gxr', reuse=self._reuse,
                                 bias=True, bias_init_c=bias_init, bias_init_r=bias_init)
            gr = ghr + gxr
            r = self._gate_activation(gr, 'r', self._reuse)
            ghz = complex_matmul(h, int(self._num_units / 2.0), scope='ghz', reuse=self._reuse)
            gxz = complex_matmul(x, int(self._num_units / 2.0), scope='gxz', reuse=self._reuse,
                                 bias=True, bias_init_c=bias_init, bias_init_r=bias_init)
            gz = ghz + gxz
            z = self._gate_activation(gz, 'z', self._reuse)
            return r, z

    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope("ComplexGatedRecurrentUnit", reuse=self._reuse):
            _, last_h_real = state

            # assemble complex state
            last_h = tf.complex(last_h_real[:, :int(self._num_units / 2)],
                                last_h_real[:, int(self._num_units / 2):])

            if self._input_fourier:
                cinputs = tf.complex(inputs, tf.zeros_like(inputs))
                cin = tf.fft(cinputs)
            elif self._input_hilbert:
                cinputs = tf.complex(inputs, tf.zeros_like(inputs))
                cin = hilbert(cinputs)
            elif self._input_split_matmul:
                # Map the inputs from R to C.
                cinr = matmul_plus_bias(inputs, int(self._num_units / 2.0), 'real', self._reuse)
                cini = matmul_plus_bias(inputs, int(self._num_units / 2.0), 'imag', self._reuse)
                cin = tf.complex(cinr, cini)
            elif self._complex_inout:
                cin = inputs
            else:
                cin = tf.complex(inputs, tf.zeros_like(inputs))

            if self._single_gate:
                r, z = self.single_memory_gate(last_h, cin, 'memory_gate', bias_init=4.0,
                                               orthogonal=self._gateO)
            else:
                r, z = self.double_memory_gate(last_h, cin, 'double_memory_gate',
                                               bias_init=4.0)

            with tf.variable_scope("canditate_h"):
                in_shape = tf.Tensor.get_shape(cin).as_list()[-1]
                var_Wx = tf.get_variable("Wx", [in_shape, int(self._num_units / 2.0), 2],
                                         dtype=tf.float32,
                                         initializer=tf.glorot_uniform_initializer())
                if self._stateU:
                    with tf.variable_scope("unitary_stiefel", reuse=self._reuse):
                        varU = tf.get_variable("recurrent_U",
                                               shape=[int(self._num_units / 2.0),
                                                      int(self._num_units / 2.0), 2],
                                               dtype=tf.float32,
                                               initializer=arjovski_init)
                        U = tf.complex(varU[:, :, 0], varU[:, :, 1])
                else:
                    varU = tf.get_variable("recurrent_U",
                                           shape=[int(self._num_units / 2.0),
                                                  int(self._num_units / 2.0), 2],
                                           dtype=tf.float32,
                                           initializer=arjovski_init)
                    U = tf.complex(varU[:, :, 0], varU[:, :, 1])

                var_bias = tf.get_variable("b", [int(self._num_units / 2.0), 2], dtype=tf.float32,
                                           initializer=tf.zeros_initializer())
                Wx = tf.complex(var_Wx[:, :, 0], var_Wx[:, :, 1])
                bias = tf.complex(var_bias[:, 0], var_bias[:, 1])
                tmp = tf.matmul(cin, Wx) + tf.matmul(tf.multiply(r, last_h), U) + bias
                h_bar = self._activation(tmp)
            new_h = (1 - z) * last_h + z * h_bar
            new_h_real = tf.concat([tf.real(new_h), tf.imag(new_h)], -1)

            if self._num_proj is None:
                output = new_h_real
            else:
                if self._complex_inout:
                    output = complex_matmul(new_h, self._num_proj, scope='C_to_C_out',
                                            reuse=self._reuse)
                    # disassemble complex state.
                    # output = tf.concat([tf.real(output), tf.imag(output)], -1)
                else:
                    output = C_to_R(new_h, self._num_proj, reuse=self._reuse)

            newstate = URNNStateTuple(output, new_h_real)
            return output, newstate