3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
12 from future.utils
import viewitems
16 logger = logging.getLogger(__name__)
20 Construct Concat layer 21 Assume that first dimension is batch, 26 input_record = self.new_record(schema.Struct( 27 ('input1', schema.Scalar((np.float32, (embedding_dim, )))), 28 ('input2', schema.Scalar((np.float32, (embedding_dim, )))), 29 ('input3', schema.Scalar((np.float32, (embedding_dim, )))), 32 output = self.model.Concat(input_record) 34 schema.Scalar((np.float32, ((len(input_record.fields) * embedding_dim, )))), 38 # Note that in Concat layer we assume first dimension is batch. 39 # so input is B * embedding_dim 40 # add_axis=1 make it B * 1 * embedding_dim 41 # Concat on axis=1 make it B * N * embedding_dim 43 output = self.model.Concat(input_record, axis=1, add_axis=1) 45 schema.Scalar((np.float32, ((len(input_record.fields), embedding_dim)))), 50 def __init__(self, model, input_record, axis=1, add_axis=0,
51 name=
'concat', **kwargs):
52 super(Concat, self).__init__(model, name, input_record, **kwargs)
55 assert not (axis == 0
and add_axis == 1), \
56 "It's not allowed to add axis=0" 58 "Incorrect input type. Excpected Struct, but received: {0}".\
62 for field_name, field_type
in viewitems(input_record.fields):
64 "Incorrect input type for {}. Excpected Scalar, but got: {}".\
65 format(field_name, field_type)
68 shape = list(field_type.field_type().shape)
70 shape.insert(axis - 1, 1)
71 assert len(shape) >= axis,\
72 "Concat expects that limited dimensions of the input tensor" 74 logger.info(
'Concat Layer input shapes: ' + str(shapes))
79 [self.get_next_blob_reference(
'output')]
85 concat_dim += shape[axis - 1]
87 assert shape == shapes[0],\
88 "Shapes {0} and {1} are not compatible for Concat".\
89 format(shape, shapes[0])
90 output_dims = shapes[0]
91 output_dims[axis - 1] = concat_dim
93 logger.info(
'Concat Layer output_dims: ' + str(output_dims))
95 (np.float32, output_dims),
96 self.get_next_blob_reference(
'output'))
98 def add_ops(self, net):
100 self.input_record.field_blobs(),
102 self.output_schema.field_blobs()[0],
103 self.output_schema.field_blobs()[0] +
"_concat_dims"