Caffe2 - C++ API
A deep learning, cross-platform ML framework
expand_squeeze_dims_op.h
1 #ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
2 #define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class ExpandDimsOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  ExpandDimsOp(const OperatorDef& operator_def, Workspace* ws)
14  : Operator<Context>(operator_def, ws),
15  dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
16  auto originalSize = dims_.size();
17  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
18  std::sort(dims_.begin(), dims_.end());
19  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
20  if (dims_.size() < originalSize) {
21  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
22  }
23  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
24  }
25 
26  bool RunOnDevice() override {
27  auto& input = Input(0);
28  auto* output = Output(0);
29  output->CopyFrom(input, &context_);
30  if (dims_.empty()) {
31  return true;
32  }
33 
34  auto newDims = input.dims();
35  CAFFE_ENFORCE_GE(
36  input.dims().size() + dims_.size(),
37  dims_.back() + 1,
38  "Input needs at least ",
39  (1 + dims_.back() - dims_.size()),
40  " dimensions given `dims`.");
41  for (const auto dim : dims_) {
42  newDims.insert(newDims.begin() + dim, 1);
43  }
44  output->Reshape(newDims);
45  return true;
46  }
47 
48  private:
49  vector<int> dims_;
50 };
51 
52 template <class Context>
53 class SqueezeOp : public Operator<Context> {
54  public:
55  USE_OPERATOR_CONTEXT_FUNCTIONS;
56  SqueezeOp(const OperatorDef& operator_def, Workspace* ws)
57  : Operator<Context>(operator_def, ws),
58  dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
59  auto originalSize = dims_.size();
60  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
61 
62  std::sort(dims_.begin(), dims_.end());
63  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
64  if (dims_.size() < originalSize) {
65  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
66  }
67  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
68  }
69 
70  bool RunOnDevice() override {
71  auto& input = Input(0);
72  auto* output = Output(0);
73  output->CopyFrom(input, &context_);
74 
75  CAFFE_ENFORCE_GT(
76  input.ndim(),
77  dims_.back(),
78  "Input needs at least ",
79  (dims_.back() + 1),
80  " dimensions.");
81 
82  std::vector<int> newDims = ComputeDims(input.dims(), dims_);
83  output->Reshape(newDims);
84  return true;
85  }
86 
87  static std::vector<int> ComputeDims(
88  std::vector<TIndex> inputDims,
89  std::vector<int> dims) {
90  int j = 0;
91  std::vector<int> newDims;
92  for (int i = 0; i < inputDims.size(); ++i) {
93  if (j < dims.size() && dims[j] == i) {
94  CAFFE_ENFORCE_EQ(
95  inputDims[i],
96  1,
97  "Dimension ",
98  i,
99  " of input must be 1",
100  " instead of ",
101  inputDims[i],
102  ".");
103  ++j;
104  continue;
105  }
106  newDims.push_back(inputDims.at(i));
107  }
108  return newDims;
109  }
110 
111  private:
112  vector<int> dims_;
113 
114  public:
115  DISABLE_COPY_AND_ASSIGN(SqueezeOp);
116 };
117 } // namespace caffe2
118 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...