Caffe2 - Python API
A deep learning, cross-platform ML framework
concat.py
1 ## @package concat
2 # Module caffe2.python.layers.concat
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from future.utils import viewitems
13 import numpy as np
14 
15 import logging
16 logger = logging.getLogger(__name__)
17 
class Concat(ModelLayer):
    """
    Construct a Concat layer.

    The first dimension of every input is assumed to be the batch dimension,
    and ``axis`` is counted including it: ``axis=1`` is the first dimension
    of each input's schema shape (the schema shape excludes the batch).

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields) * embedding_dim, )))),
            output
        )

        # Note that in Concat layer we assume first dimension is batch,
        # so input is B * embedding_dim.
        # add_axis=1 makes it B * 1 * embedding_dim.
        # Concat on axis=1 makes it B * N * embedding_dim.

        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields), embedding_dim)))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        """
        Args:
            model: the model this layer is added to (passed to ModelLayer).
            input_record: a schema.Struct whose fields must all be
                schema.Scalar; its fields are concatenated in field order.
            axis: axis to concatenate along, counted including the batch
                dimension (axis=0 concatenates along the batch).
            add_axis: if truthy, a size-1 dimension is inserted at `axis` in
                every input before concatenating. Not allowed with axis=0.
            name: layer name.

        Raises:
            AssertionError: on invalid arguments, non-Scalar fields, an empty
                record, an out-of-range axis, or incompatible input shapes.
        """
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis
        # Inserting a new batch axis is not supported. Test truthiness, not
        # `== 1`, because any truthy add_axis triggers the insert below.
        assert not (axis == 0 and add_axis), \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct), \
            "Incorrect input type. Expected Struct, but received: {0}".\
            format(input_record)
        # Without this check an empty record fails later with an opaque
        # IndexError on `shapes[0]` / `input_record[0]`.
        assert len(input_record.fields) > 0, \
            "Concat expects at least one input field, but received: {0}".\
            format(input_record)

        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar), \
                "Incorrect input type for {}. Expected Scalar, but got: {}".\
                format(field_name, field_type)
            # The schema shape excludes the batch dimension, so `axis` maps
            # to index `axis - 1` within `shape`.
            shape = list(field_type.field_type().shape)
            if add_axis:
                shape.insert(axis - 1, 1)
            assert len(shape) >= axis, \
                "Concat axis {0} is out of range for input {1} with " \
                "non-batch shape {2}".format(axis, field_name, shape)
            shapes.append(shape)
        logger.info('Concat Layer input shapes: %s', shapes)

        if axis == 0:
            # Concatenating along the batch keeps the per-example schema, so
            # the output reuses the first field's type with a fresh blob.
            self.output_schema = schema.from_blob_list(
                input_record[0],
                [self.get_next_blob_reference('output')]
            )
            return

        concat_index = axis - 1
        # Output size along the concat axis is the sum of the inputs' sizes.
        concat_dim = sum(shape[concat_index] for shape in shapes)

        # All inputs must agree on every dimension except the concat axis.
        # Compare copies with that dimension zeroed out so `shapes` itself is
        # never mutated.
        masked_shapes = []
        for shape in shapes:
            masked = list(shape)
            masked[concat_index] = 0
            masked_shapes.append(masked)
        reference = masked_shapes[0]
        for masked in masked_shapes:
            assert masked == reference, \
                "Shapes {0} and {1} are not compatible for Concat".\
                format(masked, reference)

        output_dims = list(reference)
        output_dims[concat_index] = concat_dim

        logger.info('Concat Layer output_dims: %s', output_dims)
        # NOTE(review): the output dtype is fixed to float32 regardless of
        # the input field dtypes -- confirm this is intended for non-float
        # inputs.
        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))

    def add_ops(self, net):
        """Add a Concat operator over all input blobs to `net`.

        The operator gets two outputs: the layer's output blob and a second
        blob named after it with a "_concat_dims" suffix. Presumably the
        second blob receives the per-input sizes along `axis` reported by the
        Concat operator -- confirm against the operator documentation.
        """
        output_blob = self.output_schema.field_blobs()[0]
        net.Concat(
            self.input_record.field_blobs(),
            [
                output_blob,
                output_blob + "_concat_dims"
            ],
            axis=self.axis,
            add_axis=self.add_axis,
        )