Caffe2 - C++ API
A deep learning, cross platform ML framework
filler_op.h
1 #ifndef CAFFE2_OPERATORS_FILLER_OP_H_
2 #define CAFFE2_OPERATORS_FILLER_OP_H_
3 
#include <functional>
#include <numeric>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // FillerOp takes in either zero or one input.
12 //
13 // If the number of input is 1, the shape will be identical to that of the input
14 // at run time with optional additional dimensions appended at the end as
15 // specified by "extra_shape" argument. In that case the "shape" parameter
16 // should not be set.
17 //
18 // If the number of inputs is 0, the full shape must be provided via "shape"
19 // argument
20 template <class Context>
21 class FillerOp : public Operator<Context> {
22  public:
23  FillerOp(const OperatorDef& operator_def, Workspace* ws)
24  : Operator<Context>(operator_def, ws),
25  shape_(ToVectorTIndex(OperatorBase::GetRepeatedArgument<int>("shape"))),
26  extra_shape_(ToVectorTIndex(
27  OperatorBase::GetRepeatedArgument<int>("extra_shape"))),
28  input_as_shape_(
29  OperatorBase::GetSingleArgument<bool>("input_as_shape", false)) {
30  if (InputSize()) {
31  if (shape_.size() != 0) {
32  CAFFE_THROW(
33  "Cannot set the shape argument and pass in an input at "
34  "the same time");
35  }
36  } else {
37  if (!extra_shape_.empty()) {
38  CAFFE_THROW("Cannot set extra_shape when there is no input");
39  }
40  if (input_as_shape_) {
41  CAFFE_THROW("An input must be given if input_as_shape is true");
42  }
43  if (shape_.size() == 0 &&
44  OperatorBase::HasSingleArgumentOfType<int>("shape")) {
45  CAFFE_THROW("Fill 'shape' argument was a scalar, list expected");
46  }
47  }
48  }
49 
50  virtual ~FillerOp() {}
51  USE_OPERATOR_CONTEXT_FUNCTIONS;
52 
53  bool RunOnDevice() override {
54  auto* output = Operator<Context>::Output(0);
55  if (InputSize()) {
56  auto shape = vector<TIndex>{};
57  if (input_as_shape_) {
58  // Shape input must be in CPU context
59  auto& input = OperatorBase::Input<Tensor<CPUContext>>(0);
60  CAFFE_ENFORCE_EQ(
61  input.ndim(),
62  1,
63  "When input_as_shape is true, the input must be a 1D tensor of "
64  "data type TIndex");
65  auto* shape_data = input.template data<TIndex>();
66  shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
67  } else {
68  auto& input = Input(0);
69  shape.insert(shape.end(), input.dims().begin(), input.dims().end());
70  }
71  shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
72  output->Resize(shape);
73  } else {
74  output->Resize(shape_);
75  }
76  return Fill(output);
77  }
78 
79  virtual bool Fill(Tensor<Context>* output) = 0;
80 
81  protected:
82  vector<TIndex> shape_;
83  vector<TIndex> extra_shape_;
84  bool input_as_shape_;
85 };
86 
87 template <typename T, class Context>
88 class UniformFillOp final : public FillerOp<Context> {
89  public:
90  USE_OPERATOR_CONTEXT_FUNCTIONS;
91  UniformFillOp(const OperatorDef& operator_def, Workspace* ws)
92  : FillerOp<Context>(operator_def, ws),
93  min_(OperatorBase::template GetSingleArgument<T>("min", 0)),
94  max_(OperatorBase::template GetSingleArgument<T>("max", 1)) {
95  if (InputSize() == 3) {
96  CAFFE_ENFORCE(
97  !OperatorBase::HasSingleArgumentOfType<T>("min"),
98  "Cannot set both min arg and min input blob");
99  CAFFE_ENFORCE(
100  !OperatorBase::HasSingleArgumentOfType<T>("max"),
101  "Cannot set both max arg and max input blob");
102  } else {
103  CAFFE_ENFORCE_LT(
104  min_, max_, "Max value should be bigger than min value.");
105  }
106  }
107 
108  bool Fill(Tensor<Context>* output) override {
109  T min = min_;
110  T max = max_;
111  if (InputSize() == 3) {
112  CAFFE_ENFORCE_EQ(1, Input(1).size(), "min blob must be scalar");
113  CAFFE_ENFORCE_EQ(1, Input(2).size(), "max blob must be scalar");
114  min = *Input(1).template data<T>();
115  max = *Input(2).template data<T>();
116  if (min > max) {
117  auto shape = output->dims();
118  shape[0] = 0;
119  output->Resize(shape);
120  output->template mutable_data<T>();
121  return true;
122  }
123  }
124  math::RandUniform<T, Context>(
125  output->size(),
126  min,
127  max,
128  output->template mutable_data<T>(),
129  &context_);
130  return true;
131  }
132 
133  private:
134  T min_;
135  T max_;
136 };
137 
138 template <class Context>
139 class UniqueUniformFillOp final : public FillerOp<Context> {
140  public:
141  USE_OPERATOR_CONTEXT_FUNCTIONS;
142  UniqueUniformFillOp(const OperatorDef& operator_def, Workspace* ws)
143  : FillerOp<Context>(operator_def, ws) {
144  TensorProto_DataType dtype =
145  static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
146  "dtype", TensorProto_DataType_INT32));
147 
148  switch (dtype) {
149  case TensorProto_DataType_INT32:
150  CheckRange<int>();
151  body_ = &UniqueUniformFillOp::FillWithType<int>;
152  break;
153  case TensorProto_DataType_INT64:
154  CheckRange<int64_t>();
155  body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
156  break;
157  case TensorProto_DataType_UNDEFINED:
158  CAFFE_THROW(
159  "UniqueUniformFill op cannot have undefined 'dtype' argument");
160  // break;
161  default:
162  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
163  }
164  }
165 
166  bool Fill(Tensor<Context>* output) override {
167  return (this->*body_)(output);
168  }
169 
170  private:
171  template <typename T>
172  void CheckRange() {
173  CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<T>("min"));
174  CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<T>("max"));
175  CAFFE_ENFORCE_LT(
176  OperatorBase::GetSingleArgument<T>("min", 0),
177  OperatorBase::GetSingleArgument<T>("max", 0),
178  "Max value should be bigger than min value.");
179  }
180 
181  template <typename T>
182  bool FillWithType(Tensor<Context>* output) {
183  T min = OperatorBase::GetSingleArgument<T>("min", 0);
184  T max = OperatorBase::GetSingleArgument<T>("max", 0);
185 
186  const T* avoid_data = nullptr;
187  size_t avoid_size = 0;
188  if (InputSize() >= 2) {
189  auto& avoid = Input(1);
190  avoid_data = avoid.template data<T>();
191  avoid_size = avoid.size();
192  }
193  math::RandUniformUnique<T, Context>(
194  output->size(),
195  min,
196  max,
197  output->template mutable_data<T>(),
198  avoid_size,
199  avoid_data,
200  &context_);
201  return true;
202  }
203 
204  bool (UniqueUniformFillOp::*body_)(Tensor<Context>* output);
205 };
206 
207 template <class Context>
208 class ConstantFillOp final : public FillerOp<Context> {
209  public:
210  USE_OPERATOR_CONTEXT_FUNCTIONS;
211  ConstantFillOp(const OperatorDef& operator_def, Workspace* ws)
212  : FillerOp<Context>(operator_def, ws) {
213  TensorProto_DataType dtype =
214  static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
215  "dtype", TensorProto_DataType_FLOAT));
216 
217  if (!OperatorBase::HasArgument("dtype") &&
218  OperatorBase::HasArgument("value")) {
219  // If 'dtype' is not provided, infer type based on the type of 'value'
220  // Currently, single argument contains either float, int64 or bytes
221  if (OperatorBase::HasSingleArgumentOfType<float>("value")) {
222  dtype = TensorProto_DataType_FLOAT;
223  } else if (OperatorBase::HasSingleArgumentOfType<int64_t>("value")) {
224  dtype = TensorProto_DataType_INT64;
225  } else {
226  CAFFE_THROW("Argument 'value' is of unexpected type");
227  }
228  VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
229  << "the same as that of argument 'value': " << dtype;
230  }
231 
232  switch (dtype) {
233  case TensorProto_DataType_FLOAT:
234  body_ = &ConstantFillOp::FillWithType<float>;
235  break;
236  case TensorProto_DataType_DOUBLE:
237  body_ = &ConstantFillOp::FillWithType<double>;
238  break;
239  case TensorProto_DataType_BOOL:
240  body_ = &ConstantFillOp::FillWithType<bool>;
241  break;
242  case TensorProto_DataType_INT8:
243  body_ = &ConstantFillOp::FillWithType<int8_t>;
244  break;
245  case TensorProto_DataType_INT16:
246  body_ = &ConstantFillOp::FillWithType<int16_t>;
247  break;
248  case TensorProto_DataType_INT32:
249  body_ = &ConstantFillOp::FillWithType<int>;
250  break;
251  case TensorProto_DataType_INT64:
252  body_ = &ConstantFillOp::FillWithType<int64_t>;
253  break;
254  case TensorProto_DataType_UINT8:
255  body_ = &ConstantFillOp::FillWithType<uint8_t>;
256  break;
257  case TensorProto_DataType_UINT16:
258  body_ = &ConstantFillOp::FillWithType<uint16_t>;
259  break;
260  case TensorProto_DataType_STRING:
261  body_ = &ConstantFillOp::FillWithString;
262  break;
263  case TensorProto_DataType_UNDEFINED:
264  CAFFE_THROW("ConstantFill op cannot have undefined 'dtype' argument");
265  // break;
266  default:
267  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
268  }
269  }
270 
271  bool Fill(Tensor<Context>* output) override {
272  return (this->*body_)(output);
273  }
274 
275  template <typename T>
276  bool FillWithType(Tensor<Context>* output) {
277  T value = OperatorBase::GetSingleArgument<T>("value", 0);
278  auto* data = output->template mutable_data<T>();
279  if (output->size()) {
280  math::Set<T, Context>(output->size(), value, data, &context_);
281  }
282  return true;
283  }
284 
285  bool FillWithString(Tensor<Context>* output) {
286  auto value = OperatorBase::GetSingleArgument<std::string>("value", "");
287  auto* data = output->template mutable_data<std::string>();
288  for (int i = 0; i < output->size(); ++i) {
289  data[i] = value;
290  }
291  return true;
292  }
293 
294  private:
295  bool (ConstantFillOp::*body_)(Tensor<Context>* output);
296 };
297 
298 template <class Context>
299 class DiagonalFillOp final : public FillerOp<Context> {
300  public:
301  USE_OPERATOR_CONTEXT_FUNCTIONS;
302  DiagonalFillOp(const OperatorDef& operator_def, Workspace* ws)
303  : FillerOp<Context>(operator_def, ws) {
304  TensorProto_DataType dtype =
305  static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
306  "dtype", TensorProto_DataType_FLOAT));
307 
308  if (!OperatorBase::HasArgument("dtype") &&
309  OperatorBase::HasArgument("value")) {
310  // If 'dtype' is not provided, infer type based on the type of 'value'
311  // Currently, single argument contains either float, int64 or bytes
312  if (OperatorBase::HasSingleArgumentOfType<float>("value")) {
313  dtype = TensorProto_DataType_FLOAT;
314  } else if (OperatorBase::HasSingleArgumentOfType<int64_t>("value")) {
315  dtype = TensorProto_DataType_INT64;
316  } else {
317  CAFFE_THROW("Argument 'value' is of unexpected type");
318  }
319  VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
320  << "the same as that of argument 'value': " << dtype;
321  }
322 
323  switch (dtype) {
324  case TensorProto_DataType_FLOAT:
325  body_ = &DiagonalFillOp::FillWithType<float>;
326  break;
327  case TensorProto_DataType_DOUBLE:
328  body_ = &DiagonalFillOp::FillWithType<double>;
329  break;
330  case TensorProto_DataType_BOOL:
331  body_ = &DiagonalFillOp::FillWithType<bool>;
332  break;
333  case TensorProto_DataType_INT8:
334  body_ = &DiagonalFillOp::FillWithType<int8_t>;
335  break;
336  case TensorProto_DataType_INT16:
337  body_ = &DiagonalFillOp::FillWithType<int16_t>;
338  break;
339  case TensorProto_DataType_INT32:
340  body_ = &DiagonalFillOp::FillWithType<int>;
341  break;
342  case TensorProto_DataType_INT64:
343  body_ = &DiagonalFillOp::FillWithType<int64_t>;
344  break;
345  case TensorProto_DataType_UINT8:
346  body_ = &DiagonalFillOp::FillWithType<uint8_t>;
347  break;
348  case TensorProto_DataType_UINT16:
349  body_ = &DiagonalFillOp::FillWithType<uint16_t>;
350  break;
351  case TensorProto_DataType_UNDEFINED:
352  CAFFE_THROW("Cannot have undefined 'dtype' argument");
353  default:
354  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
355  }
356  }
357 
358  bool Fill(Tensor<Context>* output) override {
359  return (this->*body_)(output);
360  }
361 
362  template <typename T>
363  bool FillWithType(Tensor<Context>* output);
364 
365  private:
366  void VerifyOutputShape(Tensor<Context>* output) {
367  CAFFE_ENFORCE(output->ndim() >= 2, "Input shape must be >= 2D");
368  }
369 
370  TIndex GetStepSize(Tensor<Context>* output) {
371  TIndex step;
372  if (output->ndim() == 2) {
373  step = output->dim(1) + 1;
374  } else {
375  TIndex prev_i = output->dim(0);
376  for (auto i : output->dims()) {
377  if (i != prev_i) {
378  CAFFE_THROW("All dimensions of input must be of equal length");
379  }
380  }
381  vector<TIndex> cumprod(output->ndim());
382  auto dims = output->dims();
383  std::partial_sum(
384  dims.begin(),
385  dims.end() - 1,
386  cumprod.begin(),
387  std::multiplies<TIndex>());
388  step = 1 +
389  std::accumulate(
390  cumprod.begin(), cumprod.end(), static_cast<TIndex>(0));
391  VLOG(0) << step;
392  }
393  return step;
394  }
395 
396  bool (DiagonalFillOp::*body_)(Tensor<Context>* output);
397 };
398 
399 template <typename T, class Context>
400 class GaussianFillOp final : public FillerOp<Context> {
401  public:
402  USE_OPERATOR_CONTEXT_FUNCTIONS;
403  GaussianFillOp(const OperatorDef& operator_def, Workspace* ws)
404  : FillerOp<Context>(operator_def, ws),
405  mean_(OperatorBase::template GetSingleArgument<float>("mean", 0)),
406  std_(OperatorBase::template GetSingleArgument<float>("std", 1)) {
407  DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
408  }
409 
410  bool Fill(Tensor<Context>* output) override {
411  math::RandGaussian<T, Context>(
412  output->size(),
413  mean_,
414  std_,
415  output->template mutable_data<T>(),
416  &context_);
417  return true;
418  }
419 
420  private:
421  T mean_;
422  T std_;
423 };
424 
425 template <typename T, class Context>
426 class XavierFillOp final : public FillerOp<Context> {
427  public:
428  USE_OPERATOR_CONTEXT_FUNCTIONS;
429  XavierFillOp(const OperatorDef& operator_def, Workspace* ws)
430  : FillerOp<Context>(operator_def, ws) {}
431 
432  bool Fill(Tensor<Context>* output) override {
433  const int fan_in = output->size() / output->dim32(0);
434  T scale = std::sqrt(T(3) / fan_in);
435  math::RandUniform<T, Context>(
436  output->size(),
437  -scale,
438  scale,
439  output->template mutable_data<T>(),
440  &context_);
441  return true;
442  }
443 };
444 
445 template <typename T, class Context>
446 class MSRAFillOp final : public FillerOp<Context> {
447  public:
448  USE_OPERATOR_CONTEXT_FUNCTIONS;
449  MSRAFillOp(const OperatorDef& operator_def, Workspace* ws)
450  : FillerOp<Context>(operator_def, ws) {}
451 
452  bool Fill(Tensor<Context>* output) override {
453  const int fan_out = output->size() / output->dim32(1);
454  T scale = std::sqrt(T(2) / fan_out);
455  math::RandGaussian<T, Context>(
456  output->size(),
457  0.0,
458  scale,
459  output->template mutable_data<T>(),
460  &context_);
461  return true;
462  }
463 };
464 
465 // This is mostly used just as a debugging purpose stuff: it fills a tensor
466 // sequentially with values 0, 1, 2..., which can then be used to check e.g.
467 // reshape operations by allowing one to read the indices more easily.
468 template <typename T, class Context>
469 class RangeFillOp final : public FillerOp<Context> {
470  public:
471  USE_OPERATOR_CONTEXT_FUNCTIONS;
472  RangeFillOp(const OperatorDef& operator_def, Workspace* ws)
473  : FillerOp<Context>(operator_def, ws) {}
474 
475  bool Fill(Tensor<Context>* output) override;
476 };
477 
478 template <class Context>
479 class LengthsRangeFillOp : public Operator<Context> {
480  public:
481  USE_OPERATOR_CONTEXT_FUNCTIONS;
482  USE_SIMPLE_CTOR_DTOR(LengthsRangeFillOp);
483 
484  bool RunOnDevice() override {
485  auto& input = Input(0);
486  auto* output = Output(0);
487  auto* input_data = input.template data<int32_t>();
488 
489  CAFFE_ENFORCE_EQ(input.ndim(), 1, "Input must be a vector.");
490 
491  auto len_sum = std::accumulate(input_data, input_data + input.size(), 0);
492 
493  output->Resize(len_sum);
494  auto* output_data = output->template mutable_data<int32_t>();
495 
496  int32_t offset = 0;
497  for (int i = 0; i < input.size(); ++i) {
498  auto len = input_data[i];
499  auto start = output_data + offset;
500  std::iota(
501  start,
502  start + len,
503  0); // make the third argument the arg of this operator
504  offset += len;
505  }
506  return true;
507  }
508 };
509 
510 template <int VALUE_TYPE = TensorProto_DataType_FLOAT>
511 inline std::vector<TensorShape> FillerTensorInference(
512  const OperatorDef& def,
513  const vector<TensorShape>& in) {
514  vector<TensorShape> out(1);
515  ArgumentHelper helper(def);
516  out[0].set_data_type(static_cast<TensorProto_DataType>(
517  helper.GetSingleArgument<int>("dtype", VALUE_TYPE)));
518 
519  if (in.size()) {
520  // TODO
521  bool input_as_shape =
522  helper.GetSingleArgument<bool>("input_as_shape", false);
523  if (input_as_shape) {
524  out[0].set_unknown_shape(true);
525  return out;
526  }
527  for (int d : in[0].dims()) {
528  out[0].add_dims(d);
529  }
530  } else {
531  auto shape = helper.GetRepeatedArgument<int>("shape");
532  for (int d : shape) {
533  out[0].add_dims(d);
534  }
535  }
536  return out;
537 }
538 
539 } // namespace caffe2
540 
541 #endif // CAFFE2_OPERATORS_FILLER_OP_H_
TIndex dim(const int i) const
Returns the i-th dimension of the tensor.
Definition: tensor.h:671
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Definition: tensor.h:657
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
A helper class to index into arguments.
Definition: proto_utils.h:198
const vector< TIndex > & dims() const
Returns the dimensions of the tensor as a vector.
Definition: tensor.h:611
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37
vector< TIndex > ToVectorTIndex(const std::vector< int > &src)
A utility function to convert vector<int> to vector<TIndex>.
Definition: tensor.h:33
int ndim() const
Returns the number of dimensions of the data.
Definition: tensor.h:589