```cpp
#include "batch_permutation_op.h"

namespace caffe2 {

REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Permute the batch elements of the input tensor X according to the permutation
specified in the input indices.

Warning: this op does not verify that indices is a valid permutation; gradient
computation is only correct if indices is a permutation.
)DOC")
    .Input(0, "X", "Tensor of at least 1D shape (N, D0, D1, ...).")
    .Input(
        1,
        "indices",
        "1D tensor of type int with shape (N, ) specifying a valid permutation "
        "of the indices in [0, N - 1] (inclusive).")
    .Output(
        0,
        "Y",
        "Tensor with the same shape as X where the (D0, D1, ...) dimensional "
        "batch elements of X are permuted according to the input indices.");

OPERATOR_SCHEMA(BatchPermutationGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .Input(0, "indices", "See BatchPermutation.")
    .Input(1, "dY", "Gradient of forward output 0 (Y).")
    .Output(0, "dX", "Gradient of forward input 0 (X).");

class GetBatchPermutationGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    // The gradient op consumes the forward indices (I(1)) and the gradient of
    // the forward output (GO(0)), and produces the gradient of X (GI(0)).
    return SingleGradientDef(
        "BatchPermutationGradient",
        "",
        vector<string>{I(1), GO(0)},
        vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);

} // namespace caffe2
```
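The forward/backward contract is easier to see with a concrete example. The following is a minimal sketch in plain C++. It is not the Caffe2 kernel code, which is not shown here. It assumes the gather convention `Y[i] = X[indices[i]]` along the batch axis and views X as N rows of K = D0 * D1 * ... floats. The function names `BatchPermuteForward` and `BatchPermuteBackward` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helpers, not the Caffe2 kernels. The input tensor is
// viewed as N batch rows of K = D0 * D1 * ... contiguous floats each.

// Forward, assuming the gather convention Y[i] = X[indices[i]].
std::vector<float> BatchPermuteForward(const std::vector<float>& X,
                                       const std::vector<int>& indices,
                                       std::size_t K) {
  const std::size_t N = indices.size();
  assert(X.size() == N * K);
  std::vector<float> Y(N * K);
  for (std::size_t i = 0; i < N; ++i) {
    const auto src = static_cast<std::size_t>(indices[i]);
    assert(src < N);  // range check only; duplicates are not detected
    std::copy_n(X.data() + src * K, K, Y.data() + i * K);
  }
  return Y;
}

// Backward: send row i of dY back to the row of X it was read from,
// dX[indices[i]] = dY[i]. Rows are assigned, not summed, so the result equals
// the true gradient only when indices is a permutation.
std::vector<float> BatchPermuteBackward(const std::vector<float>& dY,
                                        const std::vector<int>& indices,
                                        std::size_t K) {
  const std::size_t N = indices.size();
  assert(dY.size() == N * K);
  std::vector<float> dX(N * K, 0.0f);  // rows never referenced stay zero
  for (std::size_t i = 0; i < N; ++i) {
    const auto dst = static_cast<std::size_t>(indices[i]);
    assert(dst < N);
    std::copy_n(dY.data() + i * K, K, dX.data() + dst * K);
  }
  return dX;
}
```

Take X = {1, 2, 3, 4, 5, 6} viewed as N = 3 rows of K = 2, with indices = {2, 0, 1}. The forward pass gives {5, 6, 1, 2, 3, 4}, and passing that result back through the backward pass recovers X. Now use duplicate indices such as {0, 0, 1}. With an all-ones dY, this sketch's backward pass yields {1, 1, 1, 1, 0, 0}. The true gradient of the gather is {2, 2, 1, 1, 0, 0}. This mismatch is the failure that the schema's warning describes.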
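The op does not validate `indices`. A caller that cannot guarantee a permutation may therefore want to check it before running the op. The following is a minimal sketch of an O(N) check. `IsValidPermutation` and `AssertIsPermutation` are hypothetical helpers, not part of Caffe2.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of Caffe2: true iff `indices` contains every
// value in [0, N - 1] exactly once, where N = indices.size(). O(N) time.
bool IsValidPermutation(const std::vector<int>& indices) {
  const std::size_t N = indices.size();
  std::vector<bool> seen(N, false);
  for (int v : indices) {
    if (v < 0 || static_cast<std::size_t>(v) >= N) {
      return false;  // out of range
    }
    const auto u = static_cast<std::size_t>(v);
    if (seen[u]) {
      return false;  // duplicate, so some other index must be missing
    }
    seen[u] = true;
  }
  // N distinct values drawn from [0, N - 1] cover the whole range.
  return true;
}

// Hypothetical debug-build guard a caller could run on `indices` before
// feeding them to BatchPermutation; compiled out when NDEBUG is defined.
void AssertIsPermutation(const std::vector<int>& indices) {
  assert(IsValidPermutation(indices) && "indices must be a permutation");
  (void)indices;  // avoid an unused-parameter warning under NDEBUG
}
```

Tracking seen values with a flag vector is enough here. Once N values are all in range and pairwise distinct, they must cover [0, N - 1] exactly, so no separate "missing index" pass is needed.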