Caffe2 - C++ API
A deep learning, cross platform ML framework
transpose_op.h
1 #ifndef CAFFE2_OPERATORS_TRANSPOSE_H_
2 #define CAFFE2_OPERATORS_TRANSPOSE_H_
3 #define MAX_BLOB_NUM 1024
4 
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class TransposeOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  USE_DISPATCH_HELPER;
16  TransposeOp(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws),
18  axes_(OperatorBase::GetRepeatedArgument<int>("axes")) {
19  // We will check the legality of axes_: it should be from 0 to axes_.size().
20  std::vector<int> axes_sorted(axes_);
21  std::sort(axes_sorted.begin(), axes_sorted.end());
22  for (int i = 0; i < axes_sorted.size(); ++i) {
23  if (axes_sorted[i] != i) {
24  CAFFE_THROW("Axes should be a permutation of 0 to ndim.");
25  }
26  }
27  }
28  ~TransposeOp() {}
29 
30  bool RunOnDevice() override {
31  const auto& X = Input(0);
32  auto* Y = Output(0);
33  const int num_axes = X.ndim();
34  const std::vector<int> x_dims(X.dims().cbegin(), X.dims().cend());
35  std::vector<int> y_dims(num_axes);
36  if (axes_.empty()) {
37  axes_.resize(num_axes);
38  for (int i = 0; i < num_axes; ++i) {
39  axes_[i] = num_axes - 1 - i;
40  }
41  y_dims.assign(X.dims().rbegin(), X.dims().rend());
42  } else {
43  CAFFE_ENFORCE_EQ(X.ndim(), axes_.size());
44  for (int i = 0; i < num_axes; ++i) {
45  y_dims[i] = X.dim32(axes_[i]);
46  }
47  }
48  Y->Resize(y_dims);
49  SetDeviceTensor(x_dims, &x_dims_device_);
50  SetDeviceTensor(y_dims, &y_dims_device_);
51  SetDeviceTensor(axes_, &axes_device_);
52 
53  // Do the actual transpose, which is implemented in DoRunWithType().
55  this, Input(0));
56  }
57 
58  protected:
59  void SetDeviceTensor(const std::vector<int>& data, Tensor<Context>* tensor) {
60  tensor->Resize(data.size());
61  context_.template Copy<int, CPUContext, Context>(
62  data.size(), data.data(), tensor->template mutable_data<int>());
63  }
64 
65  template <typename T>
66  bool DoRunWithType() {
67  const auto& X = Input(0);
68  auto* Y = Output(0);
69  math::Transpose<T, Context>(
70  axes_.size(),
71  x_dims_device_.template data<int>(),
72  y_dims_device_.template data<int>(),
73  axes_device_.template data<int>(),
74  X.size(),
75  X.template data<T>(),
76  Y->template mutable_data<T>(),
77  &context_);
78  return true;
79  }
80 
81  std::vector<int> axes_;
82 
83  Tensor<Context> x_dims_device_;
84  Tensor<Context> y_dims_device_;
85  Tensor<Context> axes_device_;
86 };
87 
88 } // namespace caffe2
89 
90 #endif // CAFFE2_OPERATORS_TRANSPOSE_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...