In [72]:
import theano.tensor as tt
from keras import activations, initializations
from keras.layers.core import Dense, MaskedLayer, Layer, Merge
from keras.layers.recurrent import GRU
from keras.models import Graph
from keras.utils.theano_utils import shared_zeros

In [46]:
class SoftSequentialAttentionLayer(MaskedLayer):
    """Soft attention over the output sequence of a recurrent layer (work in progress).

    The layer is wired to two previous layers through ``set_previous``:

    * ``sequential_memory`` -- a layer with ``return_sequences=True`` whose
      output sequence is attended over;
    * ``attention_driver`` -- a layer whose output conditions the attention.

    Only the unnormalised attention scores are implemented so far (see
    ``_unnormalized_scores``). The variable-length softmax and the weighted
    sum over time are still missing, so ``get_output`` raises
    ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self, memmory_dim, driver_dim, inner_dim=128, init='glorot_uniform', inner_activation='relu'):
        """Create the attention weights.

        Args:
            memmory_dim: size of the last axis of the sequential memory input.
                The (misspelled) name is kept so keyword callers keep working.
            driver_dim: size of the last axis of the attention driver input.
            inner_dim: width of the hidden projection used to score each step.
            init: name of the weight initialisation, resolved through
                ``initializations.get``.
            inner_activation: name of the hidden activation, resolved through
                ``activations.get``.
        """
        super(SoftSequentialAttentionLayer, self).__init__()
        # The draft read an undefined name `memory_dim` below, which raised
        # NameError (or silently picked up a stale notebook global).
        memory_dim = memmory_dim
        self.init = initializations.get(init)
        self.inner_activation = activations.get(inner_activation)
        self.W_m = self.init((memory_dim, inner_dim))
        self.W_d = self.init((driver_dim, inner_dim))
        self.W_a = self.init((inner_dim, 1))
        self.b_inner = shared_zeros(inner_dim)
        self.b_out = shared_zeros(1)
        # Register the weights so that `get_params()` can report them.
        # NOTE(review): assumes this Keras version's Layer.get_params() returns
        # `self.params` -- confirm against the installed Keras.
        self.params = [self.W_m, self.W_d, self.W_a, self.b_inner, self.b_out]

    def set_previous(self, *previous_layers):
        """Connect the two input layers, in the order (sequential_memory, attention_driver).

        Raises:
            ValueError: if not exactly two layers are passed, or if the first
                layer does not have a truthy ``return_sequences`` attribute.
        """
        type_name = self.__class__.__name__
        if len(previous_layers) != 2:
            # Report the count; the draft formatted the whole tuple of layers.
            raise ValueError("{}.set_previous expects 2 input layers, got {}".format(
                type_name, len(previous_layers)))
        sequential_memory, attention_driver = previous_layers
        # getattr: a layer without the attribute gets the explanatory
        # ValueError below rather than an AttributeError.
        if not getattr(sequential_memory, 'return_sequences', False):
            raise ValueError("The first input of {} should be a recurrent layer with"
                             " return_sequences=True".format(type_name))
        self.sequential_memory = sequential_memory
        self.attention_driver = attention_driver

    def get_input(self, train=False):
        """Return ``[sequential_memory_output, attention_driver_output]``."""
        return [self.sequential_memory.get_output(train=train),
                self.attention_driver.get_output(train=train)]

    def _unnormalized_scores(self, train=False):
        """Return ``(time_major_memory, scores)`` for the implemented part of the layer.

        Shapes as stated by the original author (not enforced here):
          sequential_memory: (nb_samples, time (padded with zeros), input_dim)
          attention_driver:  (nb_samples, input_dim)
        """
        sequential_memory, attention_driver = self.get_input(train=train)
        # Move time to the leading axis -> (time, nb_samples, input_dim),
        # because theano.scan iterates over the main dimension.
        sequential_memory = sequential_memory.dimshuffle((1, 0, 2))
        # The draft referenced an undefined name `driver` here (NameError).
        h = self.inner_activation(tt.dot(sequential_memory, self.W_m)
                                  + tt.dot(attention_driver, self.W_d)
                                  + self.b_inner)
        a = tt.exp(tt.dot(h, self.W_a) + self.b_out)
        return sequential_memory, a

    def get_output(self, train=False):
        """Attention output -- not implemented yet.

        Raises:
            NotImplementedError: always. The draft returned ``None`` here,
                which would only fail later, far from the cause.
        """
        # TODO: normalise the scores from `_unnormalized_scores` over the valid
        # (unpadded) time steps and return the weighted sum of the memory.
        # The draft fetched a padding mask with
        #   self.get_padded_shuffled_mask(train, sequential_memory, pad=1)
        # but never used it.
        # NOTE(review): that helper appears to belong to Keras' recurrent
        # layers rather than MaskedLayer -- confirm before reinstating the call.
        raise NotImplementedError(
            "{}.get_output is not implemented yet: the attention scores still "
            "have to be normalised and applied to the sequential memory".format(
                self.__class__.__name__))

    def _variable_length_softmax_step(self, a_t, sum_t):
        """Placeholder for one step of the variable-length softmax.

        The draft body was an incomplete ``return`` statement (a syntax error
        that stopped the whole class from being defined). Presumably intended
        as a ``theano.scan`` step -- TODO confirm and implement.
        """
        raise NotImplementedError(
            "_variable_length_softmax_step is not implemented yet")

In [47]:
class CustomGraph(Graph):
    """``Graph`` whose ``add_node`` additionally supports ``merge_mode='distinct'``.

    With ``merge_mode='distinct'`` the layers named in ``inputs`` are handed to
    ``layer.set_previous`` as separate positional arguments instead of being
    wrapped in a single ``Merge`` layer.
    """

    def add_node(self, layer, name, input=None, inputs=None, merge_mode='concat', create_output=False):
        """Register ``layer`` under ``name`` and connect it to its input(s).

        Args:
            layer: the layer to add.
            name: unique identifier for the new node.
            input: identifier of a single existing node or graph input.
            inputs: identifiers of several existing nodes and/or graph inputs.
            merge_mode: ``'distinct'`` passes the looked-up layers to
                ``layer.set_previous`` unmerged; any other value is forwarded
                to ``Merge(..., mode=merge_mode)``.
            create_output: when true, also call ``add_output(name, input=name)``.

        Raises:
            Exception: on a duplicate ``name`` or an unknown identifier.
        """
        # Fresh list per call: the former `inputs=[]` default was one shared
        # object that ended up referenced from every node_config entry.
        inputs = list(inputs) if inputs else []

        if hasattr(layer, 'set_name'):
            layer.set_name(name)
        if name in self.namespace:
            raise Exception('Duplicate node identifier: ' + name)
        if input:
            if input not in self.namespace:
                raise Exception('Unknown node/input identifier: ' + input)
            if input in self.nodes:
                layer.set_previous(self.nodes[input])
            elif input in self.inputs:
                layer.set_previous(self.inputs[input])
        if inputs:
            to_merge = []
            for n in inputs:
                if n in self.nodes:
                    to_merge.append(self.nodes[n])
                elif n in self.inputs:
                    to_merge.append(self.inputs[n])
                else:
                    raise Exception('Unknown identifier: ' + n)
            # 'distinct' is the only branch that does not build a Merge layer:
            # the target layer receives every input separately.
            if merge_mode == 'distinct':
                layer.set_previous(*to_merge)
            else:
                merge = Merge(to_merge, mode=merge_mode)
                layer.set_previous(merge)

        # Bookkeeping: make the node addressable and record how it was wired.
        self.namespace.add(name)
        self.nodes[name] = layer
        self.node_config.append({'name': name,
                                 'input': input,
                                 'inputs': inputs,
                                 'merge_mode': merge_mode})
        # Fold the layer's parameters and related objects into the graph-wide lists.
        layer.init_updates()
        params, regularizers, constraints, updates = layer.get_params()
        self.params += params
        self.regularizers += regularizers
        self.constraints += constraints
        self.updates += updates

        if create_output:
            self.add_output(name, input=name)

In [54]:
# Sizes handed to the attention layer.
# NOTE(review): 128 assumes GRU(32, ...) falls back to a default output size of
# 128 in this Keras version, and 4 assumes Dense(32, 4) means (input, output)
# -- confirm both against the installed Keras.
MEMORY_DIM = 128
DRIVER_DIM = 4

graph = CustomGraph()
graph.add_input(name='context_sequences', ndim=3)
graph.add_node(GRU(32, return_sequences=True), name='dense1', input='context_sequences')
graph.add_node(Dense(32, 4), name='dense2', input='context_sequences')
# The attention layer has two required constructor arguments; calling it with
# none raises TypeError on a fresh kernel.
graph.add_node(SoftSequentialAttentionLayer(MEMORY_DIM, DRIVER_DIM),
               name='attention', inputs=['dense1', 'dense2'],
               merge_mode='distinct')
graph.add_output(name='output1', input='dense2')
graph.add_output(name='output2', input='attention')

In [55]:
# Inspect the layers registered so far, keyed by the name given to add_node.
graph.nodes

{'attention': <__main__.SoftSequentialAttentionLayer at 0x10873d630>,
 'dense1': <keras.layers.recurrent.GRU at 0x1085caeb8>,
 'dense2': <keras.layers.core.Dense at 0x10873f438>}

In [56]:
# Inspect every identifier the graph knows; per the output below this includes
# the graph input name as well as the node names.
graph.namespace

{'attention', 'context_sequences', 'dense1', 'dense2'}

In [62]:
import numpy as np

In [69]:
# Toy operands for a shape experiment: a rank-3 integer array and a 2-D matrix
# whose first axis has the same length (4) as the last axis of `x`.
x = np.arange(60).reshape((5, 3, 4))
a = np.arange(8).reshape((4, 2))

In [71]:
# Shape check: dotting the (5, 3, 4) array with the (4, 2) matrix contracts the
# shared axis of length 4.
x.dot(a).shape

(5, 3, 2)