Caffe2 - C++ API
A deep learning, cross platform ML framework
prepend_dim_op.h
1 
2 #ifndef CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
3 #define CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
4 
5 #include "caffe2/core/common_omp.h"
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class PrependDimOp : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  PrependDimOp(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws),
18  dim_size_(OperatorBase::GetSingleArgument<int64_t>("dim_size", 0)) {
19  CAFFE_ENFORCE_GT(
20  dim_size_, 0, "Argument dim_size must be greater than zero.");
21  }
22 
23  bool RunOnDevice() override {
24  auto& input = Input(0);
25  auto* output = Output(0);
26 
27  CAFFE_ENFORCE(input.ndim() > 0, "Input must be at least 1D.");
28  CAFFE_ENFORCE(
29  input.dim(0) % dim_size_ == 0,
30  "First dimension must be multiple of prepend_dim.");
31 
32  vector<int64_t> actual_new_shape(input.ndim() + 1);
33  actual_new_shape[0] = dim_size_;
34  actual_new_shape[1] = input.dim(0) / dim_size_;
35  for (int i = 1; i < input.dims().size(); ++i) {
36  actual_new_shape[i + 1] = input.dim(i);
37  }
38  output->Resize(actual_new_shape);
39 
40  if (output != &input) {
41  // If we are not doing in-place computation, a copy is needed.
42  context_.template CopyItems<Context, Context>(
43  input.meta(),
44  input.size(),
45  input.raw_data(),
46  output->raw_mutable_data(input.meta()));
47  }
48  return true;
49  }
50 
51  private:
52  int64_t dim_size_;
53 };
54 
55 template <class Context>
56 class MergeDimOp : public Operator<Context> {
57  public:
58  USE_OPERATOR_CONTEXT_FUNCTIONS;
59  MergeDimOp(const OperatorDef& operator_def, Workspace* ws)
60  : Operator<Context>(operator_def, ws) {}
61 
62  bool RunOnDevice() override {
63  auto& input = Input(0);
64  auto* output = Output(0);
65 
66  CAFFE_ENFORCE(input.ndim() > 1, "Input must be at least 2D.");
67 
68  vector<int64_t> actual_new_shape(input.ndim() - 1);
69  actual_new_shape[0] = input.dim(0) * input.dim(1);
70  for (int i = 1; i < input.dims().size() - 1; ++i) {
71  actual_new_shape[i] = input.dim(i + 1);
72  }
73  output->Resize(actual_new_shape);
74 
75  if (output != &input) {
76  // If we are not doing in-place computation, a copy is needed.
77  context_.template CopyItems<Context, Context>(
78  input.meta(),
79  input.size(),
80  input.raw_data(),
81  output->raw_mutable_data(input.meta()));
82  }
83  return true;
84  }
85 
86  private:
87  int64_t dim_size_;
88 };
89 
90 } // namespace caffe2
91 
92 #endif // CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...