// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// File: concat_split_op.cc
1 #include "caffe2/operators/concat_split_op.h"
2 
3 namespace caffe2 {
4 namespace {
5 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
6  const OperatorDef& def) {
7  auto op_device =
8  def.has_device_option() ? def.device_option() : DeviceOption();
9  vector<DeviceOption> in_dev(def.input_size(), op_device);
10  vector<DeviceOption> out_dev(def.output_size(), op_device);
11 
12  // If we obtain split from input tensor, then 2nd input's type is always CPU.
13  if (def.input_size() == SplitOp<CPUContext>::kSplitOpInputSize) {
14  CAFFE_ENFORCE_GT(in_dev.size(), 1);
15  in_dev[1] = DeviceOption();
16  }
17  return std::make_pair(in_dev, out_dev);
18 }
19 } // namespace.
20 
21 REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
22 OPERATOR_SCHEMA(Split)
23  .NumInputs(1, 2)
24  .NumOutputs(1, INT_MAX)
25  .Input(0, "input", "The tensor to split")
26  .Input(1, "split", "Optional list of output lengths (see also arg 'split')")
27  .Arg("axis", "Which axis to split on")
28  .Arg("split", "length of each output")
29  .Arg("order", "Either NHWC or NCWH, will split on C axis, defaults to NCHW")
30  .DeviceInferenceFunction(splitOpDevInfer)
31  .SetDoc(R"DOC(
32 Split a tensor into a list of tensors, along the specified
33 'axis'. The lengths of the split can be specified using argument 'split' or
34 optional second input blob to the operator. Otherwise, the tensor is split
35 to equal sized parts.
36 )DOC")
37  .InheritOnnxSchema("Split");
38 
39 namespace {
40 OpSchema::Cost CostInferenceForConcat(
41  const OperatorDef& def,
42  const vector<TensorShape>& in) {
43  ArgumentHelper helper(def);
44  const int axis = helper.HasArgument("axis")
45  ? helper.GetSingleArgument<int>("axis", -1)
46  : GetDimFromOrderString(
47  helper.GetSingleArgument<string>("order", "NCHW"));
48  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
49  const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
50  CAFFE_ENFORCE_GT(in.size(), 0);
51  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
52  if (add_axis) {
53  out_shape.insert(out_shape.begin() + canonical_axis, in.size());
54  } else {
55  for (int i = 1; i < in.size(); ++i) {
56  out_shape[canonical_axis] += in[i].dims(canonical_axis);
57  }
58  }
59  int size = 1;
60  for (auto& s : out_shape) {
61  size *= s;
62  }
63 
64  struct OpSchema::Cost cost;
65  cost.flops = 0;
66  cost.bytes_moved = size * sizeof(float);
67  cost.params_bytes = 0;
68  return cost;
69 }
70 
71 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
72 concatOpDevInfer(const OperatorDef& def) {
73  auto op_device =
74  def.has_device_option() ? def.device_option() : DeviceOption();
75  vector<DeviceOption> in_dev(def.input_size(), op_device);
76  vector<DeviceOption> out_dev(def.output_size(), op_device);
77 
78  // 2nd output's type is always CPU irrespective of op's device option.
79  CAFFE_ENFORCE_GT(out_dev.size(), 1);
80  out_dev[1] = DeviceOption();
81  return std::make_pair(in_dev, out_dev);
82 }
83 } // namespace
84 
85 REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
86 OPERATOR_SCHEMA(Concat)
87  .NumInputs(1, INT_MAX)
88  .NumOutputs(2)
89  .Arg("axis", "Which axis to concat on")
90  .Arg(
91  "order",
92  "Either NHWC or NCHW, will concat on C axis, defaults to NCHW")
93  .Arg(
94  "add_axis",
95  "Pass 1 to add the axis specified in arg 'axis' to all "
96  "input tensors")
97  .TensorInferenceFunction([](const OperatorDef& def,
98  const vector<TensorShape>& in) {
99  ArgumentHelper helper(def);
100  const int axis = helper.HasArgument("axis")
101  ? helper.GetSingleArgument<int>("axis", -1)
102  : GetDimFromOrderString(
103  helper.GetSingleArgument<string>("order", "NCHW"));
104  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
105  const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
106  CAFFE_ENFORCE_GT(in.size(), 0);
107  vector<int> split_shape(1, in.size());
108  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
109  if (add_axis) {
110  for (int i = 1; i < in.size(); ++i) {
111  CAFFE_ENFORCE_EQ(
112  in[0].dims().size(),
113  in[i].dims().size(),
114  "All inputs of Concat should have same dims when add_axis = 1. "
115  "Got different sizes for inputs 0 and ",
116  i);
117  for (int j = 0; j < in[0].dims().size(); ++j) {
118  CAFFE_ENFORCE_EQ(
119  in[0].dims(j),
120  in[i].dims(j),
121  "All inputs of Concat should have same dims when add_axis = 1. "
122  "Got different dims for inputs 0 and ",
123  i,
124  ". At dim: ",
125  j);
126  }
127  }
128  out_shape.insert(out_shape.begin() + canonical_axis, in.size());
129  } else {
130  for (int i = 1; i < in.size(); ++i) {
131  CAFFE_ENFORCE_EQ(
132  in[0].dims().size(),
133  in[i].dims().size(),
134  "All inputs of Concat should have same dims except "
135  "canonical_axis dim that is equal to ",
136  canonical_axis,
137  "Got different sizes for inputs 0 and ",
138  i);
139  for (int j = 0; j < in[0].dims().size(); ++j) {
140  if (j == canonical_axis) {
141  continue;
142  }
143  CAFFE_ENFORCE_EQ(
144  in[0].dims(j),
145  in[i].dims(j),
146  "All inputs of Concat should have same dims except "
147  "canonical_axis dim that is equal to ",
148  canonical_axis,
149  "Got different dims for inputs 0 and ",
150  i,
151  ". At dim: ",
152  j);
153  }
154  }
155 
156  for (int i = 1; i < in.size(); ++i) {
157  out_shape[canonical_axis] += in[i].dims(canonical_axis);
158  }
159  }
160  if (def.output_size() == 1) {
161  return vector<TensorShape>{
162  CreateTensorShape(out_shape, in[0].data_type())};
163  }
164  return vector<TensorShape>{
165  CreateTensorShape(out_shape, in[0].data_type()),
166  CreateTensorShape(split_shape, TensorProto::INT32)};
167  })
168  .CostInferenceFunction(CostInferenceForConcat)
169  .DeviceInferenceFunction(concatOpDevInfer)
170  .SetDoc("Concatenate a list of tensors into a single tensor")
171  .Output(0, "concat_result", "Concatenated tensor")
172  .Output(1, "split_info", "The dimensions of the inputs.")
173  .InheritOnnxSchema("Concat");
174 
175 // Backward compatibility names.
176 REGISTER_CPU_OPERATOR(DepthSplit, SplitOp<CPUContext>);
177 REGISTER_CPU_OPERATOR(DepthConcat, ConcatOp<CPUContext>);
178 OPERATOR_SCHEMA(DepthSplit)
179  .NumInputs(1, 2)
180  .NumOutputs(1, INT_MAX)
181  .SetDoc("Backward compatible operator name for Split.");
182 OPERATOR_SCHEMA(DepthConcat)
183  .NumInputs(1, INT_MAX)
184  .NumOutputs(2)
185  .SetDoc("Backward compatible operator name for Concat.");
186 
188  using GradientMakerBase::GradientMakerBase;
189  vector<OperatorDef> GetGradientDefs() override {
190  vector<string> output_grads;
191  for (int i = 0; i < def_.output_size(); ++i) {
192  if (!GradOut(i).IsEmpty()) {
193  output_grads.push_back(GO(i));
194  }
195  }
196  if (output_grads.empty()) {
197  return {};
198  }
199  return SingleGradientDef(
200  "Concat",
201  "",
202  output_grads,
203  vector<string>{GI(0), "_" + GI(0) + "_dims"});
204  }
205 };
206 REGISTER_GRADIENT(Split, GetSplitGradient);
207 REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
208 
210  using GradientMakerBase::GradientMakerBase;
211  vector<OperatorDef> GetGradientDefs() override {
212  if (GradOut(0).IsEmpty()) {
213  return {};
214  }
215  vector<string> grads;
216  for (int i = 0; i < def_.input_size(); ++i) {
217  grads.push_back(GI(i));
218  }
219  return SingleGradientDef("Split", "", vector<string>{GO(0), O(1)}, grads);
220  }
221 };
222 REGISTER_GRADIENT(Concat, GetConcatGradient);
223 REGISTER_GRADIENT(DepthConcat, GetConcatGradient);
224 } // namespace caffe2
// NOTE(review): the lines below are doxygen-tooltip residue from documentation
// extraction, not program text; kept as comments so the file stays compilable.
// A global dictionary that holds information about what Caffe2 modules have
// been loaded in the current runtime.
// static vector<OperatorDef> SingleGradientDef(const Args&... args)
// A helper function to allow one to create one single operator def, which is
// usually the case for many simple operators.