// NOTE(review): this file looks like a lossy source-browser extraction. The
// original line numbers are inlined into the text (e.g. "12 inline int ..."),
// and gaps in that numbering (3, 8-11, 14, 16, 18-19, ...) mark lines that are
// not present here, so it will not compile as-is. Comments in this file only
// describe what the remaining lines show.
//
// Shared header for the Split and Concat operators (include guard
// CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_).
//
// GetDimFromOrderString(str): parses `str` with StringToStorageOrder and
// returns an int that the operator constructors assign to axis_. Only the
// NHWC and NCHW case labels survive; the index each one returns sits on a
// missing line (presumably 16 and 18 -- TODO confirm). The CAFFE_THROW with
// "Unsupported storage order: <str>" presumably handles every other order
// (its `default:` label, if any, is not visible).
1 #ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_ 2 #define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/core/types.h" 7 #include "caffe2/utils/math.h" 12 inline int GetDimFromOrderString(
const string& str) {
13 auto order = StringToStorageOrder(str);
// Dispatch on the parsed order (presumably a `switch` on line 14, not shown).
15 case StorageOrder::NHWC:
17 case StorageOrder::NCHW:
20 CAFFE_THROW(
"Unsupported storage order: ", str);
// Split operator class fragment, templated on the device Context. The class
// name is presumably SplitOp; the `class ... {` header line is not visible.
//
// kSplitOpInputSize (2): when the op receives exactly this many inputs,
// RunOnDevice reads the split sizes from the second input instead of from
// the "split" argument.
//
// Constructor (only partially visible) reads these operator arguments:
//   "split"    -> split_    (repeated int)
//   "axis"     -> axis_     (default -1)
//   "add_axis" -> add_axis_ (default 0)
//   "order"    -> axis_ via GetDimFromOrderString (default "NCHW")
// The message on original lines 38-39 says "axis" and "order" must not both
// be given. The enforcing condition, and the branch that picks "axis" or
// "order" as the source of axis_, are on lines not shown here.
26 template <
class Context>
29 static const int kSplitOpInputSize = 2;
31 USE_OPERATOR_CONTEXT_FUNCTIONS;
34 split_(OperatorBase::GetRepeatedArgument<int>(
"split")) {
38 "You shouldn't specify both the dim to split, and the order " 39 "in the case of 4-D images.");
41 axis_ = OperatorBase::GetSingleArgument<int>(
"axis", -1);
43 add_axis_ = OperatorBase::GetSingleArgument<int>(
"add_axis", 0);
// Alternative source for axis_: derived from the "order" string. Presumably
// taken when no "axis" argument is given; the selecting branch is not visible.
45 axis_ = GetDimFromOrderString(
46 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"));
51 bool RunOnDevice()
override;
// Concat operator class fragment, templated on the device Context. The class
// name is presumably ConcatOp; the `class ... {` header line is not visible.
//
// Constructor (only partially visible) reads these operator arguments:
//   "axis"     -> axis_     (default -1)
//   "add_axis" -> add_axis_ (default 0)
//   "order"    -> axis_ via GetDimFromOrderString (default "NCHW")
// The message on original lines 70-71 says "axis" and "order" must not both
// be given. The enforcing condition, and the branch that picks "axis" or
// "order" as the source of axis_, are on lines not shown here.
61 template <
class Context>
64 USE_OPERATOR_CONTEXT_FUNCTIONS;
70 "You shouldn't specify both the dim to concat, and the order " 71 "in the case of 4-D images.");
73 axis_ = OperatorBase::GetSingleArgument<int>(
"axis", -1);
74 add_axis_ = OperatorBase::GetSingleArgument<int>(
"add_axis", 0);
// Alternative source for axis_: derived from the "order" string. Presumably
// taken when no "axis" argument is given; the selecting branch is not visible.
76 axis_ = GetDimFromOrderString(
77 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"));
82 bool RunOnDevice()
override;
// Split RunOnDevice body. The signature line (93) is not visible; this is
// presumably SplitOp<Context>::RunOnDevice. It splits Input(0) along axis_
// into OutputSize() outputs.
//
// Visible flow:
//  1. Canonicalize axis_ with input.canonical_axis_index and require it to be
//     inside the input rank ("Axis not in input ndim range.").
//  2. Pick the per-output sizes `axis_data` from one of three sources:
//     - the second input, a CPU int tensor, when
//       InputSize() == kSplitOpInputSize. Its size() must equal OutputSize().
//       Per the message, passing the "split" argument as well is rejected.
//     - an even split, when the "split" argument is empty. This requires the
//       axis dimension (input_channels) to be divisible by OutputSize().
//     - the "split" argument otherwise. Per the message, its length must
//       equal OutputSize().
//  3. Compare `add_axis_ ? OutputSize() : sum(axis_data)` against a value on
//     a missing line ("Sum of split dimensions do not match"). That value is
//     presumably input_channels -- TODO confirm.
//  4. Compute `before`/`after`, the products of the dims before and after the
//     axis.
//  5. Resize each output and copy its slab out of the input with
//     math::CopyMatrix, advancing a byte offset into the input.
92 template <
class Context>
94 auto& input = Input(0);
95 int canonical_axis = input.canonical_axis_index(axis_);
97 canonical_axis, input.ndim(),
"Axis not in input ndim range.");
98 const int input_channels = input.dim32(canonical_axis);
// Backing storage for axis_data in the even-split case below. axis_data
// itself is declared on a missing line (99).
100 vector<int> equal_split;
101 if (InputSize() == kSplitOpInputSize) {
106 "If you set split with an input blob, do not pass in " 107 "split in the argument.");
108 auto& split_tensor = OperatorBase::Input<TensorCPU>(1);
109 CAFFE_ENFORCE_EQ(split_tensor.size(), OutputSize());
110 axis_data = split_tensor.template data<int>();
111 }
else if (split_.size() == 0) {
113 input_channels % OutputSize(),
115 "If you did not specify split explicitly, the number of " 116 "input channels should be divisible by the output size.");
117 equal_split.resize(OutputSize(), input_channels / OutputSize());
118 axis_data = equal_split.data();
124 "The number of splits specified should be equal to the " 125 "number of outputs.");
126 axis_data = split_.data();
130 add_axis_ ? OutputSize()
131 : std::accumulate(axis_data, axis_data + OutputSize(), 0),
133 "Sum of split dimensions do not match: should be ",
135 vector<TIndex> output_dims(input.dims());
136 int before = 1, after = 1;
137 for (
int i = 0; i < canonical_axis; ++i) {
138 before *= input.dim32(i);
140 for (
int i = canonical_axis + 1; i < input.ndim(); ++i) {
141 after *= input.dim32(i);
// NOTE(review): the condition guarding this erase (line 143) is not visible.
// Presumably the erase runs only when add_axis_ is set -- confirm.
144 output_dims.erase(output_dims.begin() + canonical_axis);
146 size_t input_offset = 0;
147 for (
int i = 0; i < OutputSize(); ++i) {
148 auto* output = Output(i);
149 auto axis_dim = add_axis_ ? 1 : axis_data[i];
// NOTE(review): line 150, presumably an `if (!add_axis_)` guard, is missing.
// Unguarded after the erase above, this would overwrite the dim that followed
// the erased axis, or index past the end if the axis was the last one.
// Confirm against the full source.
151 output_dims[canonical_axis] = axis_data[i];
153 output->Resize(output_dims);
154 math::CopyMatrix<Context>(
158 static_cast<const char*
>(input.raw_data()) + input_offset,
159 input.dim32(canonical_axis) * after,
160 output->raw_mutable_data(input.meta()),
163 input.meta().copy());
// Advance past this output's slab, in bytes.
// NOTE(review): axis_dim and `after` are both int, so `axis_dim * after` is
// computed in int before any widening by itemsize(). A slab with more than
// INT_MAX elements would overflow. The same applies to `before`/`after` and
// to `input.dim32(canonical_axis) * after` above.
164 input_offset += axis_dim * after * input.itemsize();
// Concat RunOnDevice body. The signature line (170) is not visible; this is
// presumably ConcatOp<Context>::RunOnDevice. It concatenates all inputs
// along axis_ into Output(0) and records each input's extent along that axis
// in Output(1).
//
// Visible flow:
//  1. Canonicalize axis_ against input_zero.ndim() + (add_axis_ ? 1 : 0) and
//     require it to be < that adjusted rank.
//  2. Require every input to have the same element type (meta) as Input(0).
//  3. Walk Input(0)'s dims to build `before`/`after` and check that the other
//     inputs match. Per the error message, dims may differ only along the
//     axis and only when add_axis is 0. Several lines of this loop are
//     missing.
//  4. Set axis_data[i] to 1 per input when add_axis_, else to that input's
//     axis dim. output_channels is their sum. The output shape then either
//     gains a new axis of that size (insert) or has the axis dim replaced
//     (assignment); the selecting condition is on a missing line.
//  5. Copy each input into the output with math::CopyMatrix at a running byte
//     offset. output_channels * after is passed as one of the stride
//     arguments, presumably the destination's.
169 template <
class Context>
171 auto* output = Output(0);
// Output(1) is a CPU int tensor with one entry per input. It receives each
// input's extent along the concat axis (1 per input when add_axis_).
172 TensorCPU* split = OperatorBase::Output<TensorCPU>(1);
173 split->
Resize(vector<TIndex>(1, InputSize()));
174 int* axis_data = split->template mutable_data<int>();
175 auto& input_zero = Input(0);
// With add_axis_, the axis indexes a shape one rank larger than the inputs.
176 int adj_size = input_zero.ndim() + (add_axis_ ? 1 : 0);
177 int canonical_axis = canonical_axis_index_(axis_, adj_size);
178 CAFFE_ENFORCE_LT(canonical_axis, adj_size,
"Axis not in input ndim range.");
179 for (
int i = 1; i < InputSize(); ++i) {
181 Input(i).meta() == input_zero.meta(),
182 "All inputs must have the same type, expected: ",
183 input_zero.meta().name(),
185 Input(i).meta().name(),
190 int before = 1, after = 1;
191 vector<TIndex> output_dims(input_zero.dims());
192 for (
int i = 0; i < input_zero.ndim(); ++i) {
// NOTE(review): lines 194-195 and 198-202 are missing. Presumably the concat
// axis is skipped here when !add_axis_, and the remaining dims are folded
// into before/after before the cross-input check -- confirm.
193 if (i == canonical_axis && !add_axis_) {
196 int dim = input_zero.
dim32(i);
197 if (i < canonical_axis) {
203 for (
int j = 1; j < InputSize(); ++j) {
204 int dim_j = Input(j).dim32(i);
207 "Expect dimension = ",
215 ". The input tensors can only have different dimensions " 216 "when arg 'add_axis' = 0 and along the axis = ",
226 int output_channels = 0;
227 for (
int i = 0; i < InputSize(); ++i) {
228 axis_data[i] = add_axis_ ? 1 : Input(i).dim32(canonical_axis);
229 output_channels += axis_data[i];
// NOTE(review): the lines choosing between these two statements (230-231,
// 233) are missing. Presumably insert runs when add_axis_ is set and the
// assignment runs otherwise -- confirm.
232 output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
234 output_dims[canonical_axis] = output_channels;
236 output->Resize(output_dims);
237 size_t output_offset = 0;
238 for (
int i = 0; i < InputSize(); ++i) {
239 auto& input = Input(i);
240 auto axis_dim = add_axis_ ? 1 : input.dim32(canonical_axis);
241 math::CopyMatrix<Context>(
247 static_cast<char*
>(output->raw_mutable_data(input_zero.meta())) +
249 output_channels * after,
251 input_zero.meta().copy());
// Advance past this input's slab in the output, in bytes.
// NOTE(review): axis_dim and `after` are both int, so `axis_dim * after` is
// computed in int before any widening by itemsize(). A slab with more than
// INT_MAX elements would overflow. The same applies to
// `output_channels * after` above.
252 output_offset += axis_dim * after * input.itemsize();
259 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.