1 #ifndef CAFFE2_OPERATORS_TRANSPOSE_H_ 2 #define CAFFE2_OPERATORS_TRANSPOSE_H_ 3 #define MAX_BLOB_NUM 1024 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 axes_(OperatorBase::GetRepeatedArgument<int>(
"axes")) {
20 std::vector<int> axes_sorted(axes_);
21 std::sort(axes_sorted.begin(), axes_sorted.end());
22 for (
int i = 0; i < axes_sorted.size(); ++i) {
23 if (axes_sorted[i] != i) {
24 CAFFE_THROW(
"Axes should be a permutation of 0 to ndim.");
30 bool RunOnDevice()
override {
31 const auto& X = Input(0);
33 const int num_axes = X.ndim();
34 const std::vector<int> x_dims(X.dims().cbegin(), X.dims().cend());
35 std::vector<int> y_dims(num_axes);
37 axes_.resize(num_axes);
38 for (
int i = 0; i < num_axes; ++i) {
39 axes_[i] = num_axes - 1 - i;
41 y_dims.assign(X.dims().rbegin(), X.dims().rend());
43 CAFFE_ENFORCE_EQ(X.ndim(), axes_.size());
44 for (
int i = 0; i < num_axes; ++i) {
45 y_dims[i] = X.dim32(axes_[i]);
49 SetDeviceTensor(x_dims, &x_dims_device_);
50 SetDeviceTensor(y_dims, &y_dims_device_);
51 SetDeviceTensor(axes_, &axes_device_);
59 void SetDeviceTensor(
const std::vector<int>& data,
Tensor<Context>* tensor) {
60 tensor->
Resize(data.size());
61 context_.template Copy<int, CPUContext, Context>(
62 data.size(), data.data(), tensor->template mutable_data<int>());
66 bool DoRunWithType() {
67 const auto& X = Input(0);
69 math::Transpose<T, Context>(
71 x_dims_device_.template data<int>(),
72 y_dims_device_.template data<int>(),
73 axes_device_.template data<int>(),
76 Y->template mutable_data<T>(),
81 std::vector<int> axes_;
90 #endif // CAFFE2_OPERATORS_TRANSPOSE_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about which Caffe2 modules have been loaded in the current ...