import numbers

import tensorflow as tf
from tensorflow.python.ops.nn_ops import _get_noise_shape


def complex_dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
    '''
    Implementation of complex dropout based on tf.nn.dropout.

    The idea is straightforward: just like it's done in the real case, if a
    complex number is dropped out it is set to zero (the same keep/drop draw
    masks both its real and imaginary part). The remaining numbers are scaled
    by 1 / keep_prob.

    Args:
        x: Tensor to drop out. It is split with tf.real / tf.imag and
            reassembled with tf.complex, so the result is complex.
        keep_prob: Probability of keeping each element. A Python number must
            lie in (0, 1]; exactly 1.0 returns `x` unchanged. A scalar tensor
            is also accepted (it is converted to the real dtype of `x`).
        noise_shape: Optional shape of the random keep/drop mask, resolved by
            tf.nn.dropout's `_get_noise_shape` helper (None means the full
            shape of `x`).
        seed: Optional seed forwarded to tf.random_uniform.
        name: Optional name for the enclosing name scope
            (default "complex_dropout").

    Returns:
        A complex tensor with the same shape as `x`.

    Raises:
        ValueError: If `keep_prob` is a Python number outside (0, 1].
    '''
    # Reject invalid probabilities eagerly, like tf.nn.dropout: keep_prob == 0
    # would otherwise divide by zero and silently produce inf/NaN, and
    # keep_prob > 1 would mis-scale every kept element.
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
        raise ValueError("keep_prob must be a scalar tensor or a float in the "
                         "range (0, 1], got %g" % keep_prob)

    with tf.name_scope(name, "complex_dropout", [x]):
        # Early return if nothing needs to be dropped.
        if keep_prob == 1.0:
            return x

        x = tf.convert_to_tensor(x, name="x")
        # Work in the real dtype matching x (float32 for complex64, float64
        # for complex128) so the mask can multiply tf.real(x) / tf.imag(x).
        real_dtype = x.dtype.real_dtype
        keep_prob = tf.convert_to_tensor(keep_prob, dtype=real_dtype,
                                         name="keep_prob")

        noise_shape = _get_noise_shape(x, noise_shape)
        # uniform [keep_prob, 1.0 + keep_prob)
        random_tensor = keep_prob + tf.random_uniform(
            noise_shape, seed=seed, dtype=real_dtype)
        # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob), i.e. each
        # element is kept with probability keep_prob.
        binary_tensor = tf.floor(random_tensor)
        ret = tf.complex(tf.real(x) / keep_prob * binary_tensor,
                         tf.imag(x) / keep_prob * binary_tensor)

    return ret

# def complex_dropout_RNN(x, ,keep_prob, noise_shape=None, seed=None, name=None, reuse=None):
#     '''
#     Implementation of complex dropout based on tf.nn.dropout.
#     The idea is straightforward, just like its done in the real
#     case if a complex number is dropped out it is set to zero.
#     The remaining numbers are scaled according to the keep probability.
#     '''
#
#     with tf.variable_scope(name, reuse=reuse):
#         # Early return if nothing needs to be dropped.
#         if keep_prob == 1.0:
#             return x,last_h
#
#         noise_shape = _get_noise_shape(x, noise_shape)
#
#         in_shape = tf.Tensor.get_shape(x).as_list()[-1]
#         # uniform [keep_prob, 1.0 + keep_prob)
#         random_tensor = keep_prob
#         # random_tensor += tf.random_uniform(
#         #     noise_shape, seed=seed, dtype=tf.float32)
#         random_tensor = tf.get_variable("dropout",
#                                shape=[in_shape],
#                                dtype=tf.float32,
#                                initializer=tf.random_uniform_initializer,
#                                trainable=False)
#         # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
#
#         binary_tensor = tf.floor(random_tensor + keep_prob)
#         print('****drop', random_tensor)
#         ret_x = tf.complex(tf.div(tf.real(x), keep_prob) * binary_tensor,
#                          tf.div(tf.imag(x), keep_prob) * binary_tensor)
#
#
#
#     return ret_x


def _apply_complex_dropout_mask(z, binary_mask, keep_prob):
    '''Scale both parts of `z` by `binary_mask / keep_prob` (inverted dropout).'''
    real = tf.real(z)
    imag = tf.imag(z)
    # The stored masks are float32; cast so complex128 inputs work as well
    # (tf.cast is a no-op when the dtype already matches).
    mask = tf.cast(binary_mask, real.dtype)
    return tf.complex(real / keep_prob * mask, imag / keep_prob * mask)


def complex_dropout_RNN(x, last_h, keep_prob, dropout_type, name=None, reuse=None):
    '''
    Complex dropout for an RNN step, applied to the input `x` and the previous
    state `last_h`.

    Just like in the real case, a dropped complex number is set to zero (the
    same mask entry zeroes its real and imaginary part) and the remaining
    numbers are scaled by 1 / keep_prob.

    Each mask is floor(u + keep_prob), where `u` is a non-trainable float32
    variable initialized from U[0, 1), so each unit is kept with probability
    keep_prob. This function never re-samples those variables: every call that
    reuses the same variable scope applies the same mask until the variables
    are re-initialized. NOTE(review): presumably intended as per-sequence
    ("variational") RNN dropout — confirm how callers re-initialize the masks.
    The mask has shape [last_dim] and is broadcast over leading dimensions.

    Args:
        x: Input tensor; its last dimension must be statically known and sets
            the size of the "dropout_x" mask.
        last_h: Previous hidden state. The "dropout_h" mask is sized by its
            static last dimension (falling back to x's when unknown).
        keep_prob: Keep probability. A Python number must lie in (0, 1];
            exactly 1.0 returns (x, last_h) unchanged without creating
            variables.
        dropout_type: 0 -> independent masks for x and last_h;
            1 -> x's mask is applied to both (last_h's last dimension must
            then broadcast against x's).
        name: Variable scope name; defaults to a unique "complex_dropout_rnn"
            scope.
        reuse: Forwarded to tf.variable_scope.

    Returns:
        (ret_x, ret_h): the dropped-out complex input and state.

    Raises:
        ValueError: If keep_prob is a Python number outside (0, 1], if
            dropout_type is not 0 or 1, or if x's last dimension is unknown.
    '''
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
        raise ValueError("keep_prob must be a scalar tensor or a float in the "
                         "range (0, 1], got %g" % keep_prob)

    # Early return if nothing needs to be dropped.
    if keep_prob == 1.0:
        return x, last_h

    # Validate before creating any variables.
    if dropout_type not in (0, 1):
        raise ValueError("Error dropout type: expected 0 or 1, got %r"
                         % (dropout_type,))

    with tf.variable_scope(name, default_name="complex_dropout_rnn",
                           reuse=reuse):
        in_size = x.get_shape().as_list()[-1]
        if in_size is None:
            raise ValueError("complex_dropout_RNN needs a statically known "
                             "last dimension for x")

        last_h = tf.convert_to_tensor(last_h, name="last_h")
        h_shape = last_h.get_shape()
        h_size = h_shape.as_list()[-1] if h_shape.ndims else None
        if h_size is None:
            # Unknown static size: keep the previous behaviour (x's size).
            h_size = in_size

        # Both variables are created for either dropout_type so the variable
        # set (and therefore checkpoints) does not depend on the mode.
        random_tensor_x = tf.get_variable(
            "dropout_x",
            shape=[in_size],
            dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            trainable=False)
        random_tensor_h = tf.get_variable(
            "dropout_h",
            shape=[h_size],
            dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            trainable=False)

        # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
        binary_tensor_x = tf.floor(random_tensor_x + keep_prob)
        binary_tensor_h = tf.floor(random_tensor_h + keep_prob)

        # dropout_type 0: different masks for x and state;
        # dropout_type 1: the same (x's) mask for both.
        state_mask = binary_tensor_h if dropout_type == 0 else binary_tensor_x

        ret_x = _apply_complex_dropout_mask(x, binary_tensor_x, keep_prob)
        ret_h = _apply_complex_dropout_mask(last_h, state_mask, keep_prob)
        return ret_x, ret_h