1 #include "caffe2/operators/transpose_op.h" 4 #include "caffe2/mkl/operators/operator_fallback_mkl.h" 5 #endif // CAFFE2_USE_MKL 9 REGISTER_CPU_OPERATOR(Transpose, TransposeOp<CPUContext>);
11 #ifdef CAFFE2_HAS_MKL_DNN 14 REGISTER_MKL_OPERATOR(Transpose, mkl::MKLFallbackOp<TransposeOp<CPUContext>>);
15 #endif // CAFFE2_HAS_MKL_DNN 17 OPERATOR_SCHEMA(Transpose)
20 .TensorInferenceFunction([](
21 const OperatorDef& def,
22 const vector<TensorShape>& in) {
23 ArgumentHelper helper(def);
24 vector<int> axes = helper.GetRepeatedArgument<
int>(
"axes");
25 vector<TensorShape> out(1);
26 out[0].set_data_type(in[0].data_type());
29 for (
auto axis = in [0].dims().rbegin(); axis != in[0].dims().rend();
31 out[0].add_dims(*axis);
34 auto tensor_size = in[0].dims().size();
36 std::all_of(axes.begin(), axes.end(), [&tensor_size](
int& axis) {
37 return axis >= 0 && axis < tensor_size;
40 CAFFE_ENFORCE(valid_axes,
"Axes argument passed in had invalid values");
42 axes.size() == tensor_size,
43 "Axes argument passed in had the incorrect size");
45 for (
auto axis = axes.begin(); axis != axes.end(); ++axis) {
46 out[0].add_dims(in[0].dims().Get(*axis));
53 Transpose the input tensor similar to numpy.transpose. For example, when 54 axes=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape 59 "A list of integers. By default, reverse the dimensions, " 60 "otherwise permute the axes according to the values given.")
61 .Input(0,
"data",
"An input tensor.")
62 .Output(0,
"transposed",
"Transposed output.")
63 .InheritOnnxSchema(
"Transpose");
66 using GradientMakerBase::GradientMakerBase;
68 bool CopyArguments()
const override {
71 vector<OperatorDef> GetGradientDefs()
override {
73 "Transpose",
"", vector<string>{GO(0)}, vector<string>{GI(0)});
74 ops[0].mutable_arg()->CopyFrom(Def().arg());
75 if (ArgumentHelper::HasArgument(Def(),
"axes")) {
77 const Argument& old_axes = GetArgument(Def(),
"axes");
78 const int axes_size = old_axes.ints_size();
79 Argument* new_arg = GetMutableArgument(
"axes",
false, &ops[0]);
80 for (
int i = 0; i < axes_size; ++i) {
81 new_arg->set_ints(old_axes.ints(i), i);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that allows one to create a single operator def, which is usually the case for many simple operators.