Caffe2 - C++ API
A deep learning, cross platform ML framework
concat_split_op.h
1 #ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
2 #define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/core/types.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 namespace {
12 inline int GetDimFromOrderString(const string& str) {
13  auto order = StringToStorageOrder(str);
14  switch (order) {
15  case StorageOrder::NHWC:
16  return 3;
17  case StorageOrder::NCHW:
18  return 1;
19  default:
20  CAFFE_THROW("Unsupported storage order: ", str);
21  return -1;
22  }
23 }
24 } // namespace
25 
26 template <class Context>
27 class SplitOp final : public Operator<Context> {
28  public:
29  static const int kSplitOpInputSize = 2;
30 
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  SplitOp(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws),
34  split_(OperatorBase::GetRepeatedArgument<int>("split")) {
35  CAFFE_ENFORCE(
36  !(OperatorBase::HasArgument("axis") &&
37  OperatorBase::HasArgument("order")),
38  "You shouldn't specify both the dim to split, and the order "
39  "in the case of 4-D images.");
40  if (OperatorBase::HasArgument("axis")) {
41  axis_ = OperatorBase::GetSingleArgument<int>("axis", -1);
42  // only exists for computing the gradient of a Concat with 'add_axis'
43  add_axis_ = OperatorBase::GetSingleArgument<int>("add_axis", 0);
44  } else {
45  axis_ = GetDimFromOrderString(
46  OperatorBase::GetSingleArgument<string>("order", "NCHW"));
47  add_axis_ = 0;
48  }
49  }
50 
51  bool RunOnDevice() override;
52 
53  protected:
54  int axis_;
55  int add_axis_;
56  vector<int> split_;
57  // Input: X, optionally split
58  // The split tensor is stored in CPU.
59 };
60 
61 template <class Context>
62 class ConcatOp final : public Operator<Context> {
63  public:
64  USE_OPERATOR_CONTEXT_FUNCTIONS;
65  ConcatOp(const OperatorDef& operator_def, Workspace* ws)
66  : Operator<Context>(operator_def, ws) {
67  CAFFE_ENFORCE(
68  !(OperatorBase::HasArgument("axis") &&
69  OperatorBase::HasArgument("order")),
70  "You shouldn't specify both the dim to concat, and the order "
71  "in the case of 4-D images.");
72  if (OperatorBase::HasArgument("axis")) {
73  axis_ = OperatorBase::GetSingleArgument<int>("axis", -1);
74  add_axis_ = OperatorBase::GetSingleArgument<int>("add_axis", 0);
75  } else {
76  axis_ = GetDimFromOrderString(
77  OperatorBase::GetSingleArgument<string>("order", "NCHW"));
78  add_axis_ = 0;
79  }
80  }
81 
82  bool RunOnDevice() override;
83 
84  protected:
85  int axis_;
86  int add_axis_;
87  // Input: a number of tensors. Output: Y, split
88  // The split are stored in CPU.
89 };
90 
91 // Implementations
92 template <class Context>
94  auto& input = Input(0);
95  int canonical_axis = input.canonical_axis_index(axis_);
96  CAFFE_ENFORCE_LT(
97  canonical_axis, input.ndim(), "Axis not in input ndim range.");
98  const int input_channels = input.dim32(canonical_axis);
99  const int* axis_data;
100  vector<int> equal_split;
101  if (InputSize() == kSplitOpInputSize) {
102  // We obtain split from the input tensor.
103  CAFFE_ENFORCE_EQ(
104  split_.size(),
105  0,
106  "If you set split with an input blob, do not pass in "
107  "split in the argument.");
108  auto& split_tensor = OperatorBase::Input<TensorCPU>(1);
109  CAFFE_ENFORCE_EQ(split_tensor.size(), OutputSize());
110  axis_data = split_tensor.template data<int>();
111  } else if (split_.size() == 0) {
112  CAFFE_ENFORCE_EQ(
113  input_channels % OutputSize(),
114  0,
115  "If you did not specify split explicitly, the number of "
116  "input channels should be divisible by the output size.");
117  equal_split.resize(OutputSize(), input_channels / OutputSize());
118  axis_data = equal_split.data();
119  } else {
120  // We obtain split from the parameters.
121  CAFFE_ENFORCE_EQ(
122  split_.size(),
123  OutputSize(),
124  "The number of splits specified should be equal to the "
125  "number of outputs.");
126  axis_data = split_.data();
127  }
128 
129  CAFFE_ENFORCE_EQ(
130  add_axis_ ? OutputSize()
131  : std::accumulate(axis_data, axis_data + OutputSize(), 0),
132  input_channels,
133  "Sum of split dimensions do not match: should be ",
134  input_channels);
135  vector<TIndex> output_dims(input.dims());
136  int before = 1, after = 1;
137  for (int i = 0; i < canonical_axis; ++i) {
138  before *= input.dim32(i);
139  }
140  for (int i = canonical_axis + 1; i < input.ndim(); ++i) {
141  after *= input.dim32(i);
142  }
143  if (add_axis_) {
144  output_dims.erase(output_dims.begin() + canonical_axis);
145  }
146  size_t input_offset = 0;
147  for (int i = 0; i < OutputSize(); ++i) {
148  auto* output = Output(i);
149  auto axis_dim = add_axis_ ? 1 : axis_data[i];
150  if (!add_axis_) {
151  output_dims[canonical_axis] = axis_data[i];
152  }
153  output->Resize(output_dims);
154  math::CopyMatrix<Context>(
155  input.itemsize(),
156  before,
157  axis_dim * after,
158  static_cast<const char*>(input.raw_data()) + input_offset,
159  input.dim32(canonical_axis) * after,
160  output->raw_mutable_data(input.meta()),
161  axis_dim * after,
162  &context_,
163  input.meta().copy());
164  input_offset += axis_dim * after * input.itemsize();
165  }
166  return true;
167 }
168 
169 template <class Context>
171  auto* output = Output(0);
172  TensorCPU* split = OperatorBase::Output<TensorCPU>(1);
173  split->Resize(vector<TIndex>(1, InputSize()));
174  int* axis_data = split->template mutable_data<int>();
175  auto& input_zero = Input(0);
176  int adj_size = input_zero.ndim() + (add_axis_ ? 1 : 0);
177  int canonical_axis = canonical_axis_index_(axis_, adj_size);
178  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
179  for (int i = 1; i < InputSize(); ++i) {
180  CAFFE_ENFORCE(
181  Input(i).meta() == input_zero.meta(),
182  "All inputs must have the same type, expected: ",
183  input_zero.meta().name(),
184  " but got: ",
185  Input(i).meta().name(),
186  " for input: ",
187  i);
188  }
189 
190  int before = 1, after = 1;
191  vector<TIndex> output_dims(input_zero.dims());
192  for (int i = 0; i < input_zero.ndim(); ++i) {
193  if (i == canonical_axis && !add_axis_) {
194  continue;
195  }
196  int dim = input_zero.dim32(i);
197  if (i < canonical_axis) {
198  before *= dim;
199  } else { // i > canonical_axis || i == canonical_axis && add_axis_
200  after *= dim;
201  }
202  // check the input dims are compatible.
203  for (int j = 1; j < InputSize(); ++j) {
204  int dim_j = Input(j).dim32(i);
205  CAFFE_ENFORCE(
206  dim == dim_j,
207  "Expect dimension = ",
208  dim,
209  " got ",
210  dim_j,
211  " at axis = ",
212  i,
213  " for input: ",
214  j,
215  ". The input tensors can only have different dimensions "
216  "when arg 'add_axis' = 0 and along the axis = ",
217  canonical_axis,
218  " <",
219  Input(0).dims(),
220  "> vs <",
221  Input(j).dims(),
222  ">.");
223  }
224  }
225 
226  int output_channels = 0;
227  for (int i = 0; i < InputSize(); ++i) {
228  axis_data[i] = add_axis_ ? 1 : Input(i).dim32(canonical_axis);
229  output_channels += axis_data[i];
230  }
231  if (add_axis_) {
232  output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
233  } else {
234  output_dims[canonical_axis] = output_channels;
235  }
236  output->Resize(output_dims);
237  size_t output_offset = 0;
238  for (int i = 0; i < InputSize(); ++i) {
239  auto& input = Input(i);
240  auto axis_dim = add_axis_ ? 1 : input.dim32(canonical_axis);
241  math::CopyMatrix<Context>(
242  input.itemsize(),
243  before,
244  axis_dim * after,
245  input.raw_data(),
246  axis_dim * after,
247  static_cast<char*>(output->raw_mutable_data(input_zero.meta())) +
248  output_offset,
249  output_channels * after,
250  &context_,
251  input_zero.meta().copy());
252  output_offset += axis_dim * after * input.itemsize();
253  }
254  return true;
255 }
256 
257 } // namespace caffe2
258 
259 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Definition: tensor.h:657
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37