1 #include "caffe2/operators/batch_gather_ops.h" 5 REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
6 REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
8 OPERATOR_SCHEMA(BatchGather)
11 .TensorInferenceFunction([](
const OperatorDef& def,
12 const vector<TensorShape>& in) {
13 vector<TensorShape> out(1);
14 ArgumentHelper helper(def);
16 vector<int> output_dims;
17 const auto& data_dims = GetDimsVector(in[0]);
18 const auto& indices_dims = GetDimsVector(in[1]);
19 output_dims.push_back(data_dims[0]);
21 output_dims.end(), indices_dims.begin(), indices_dims.end());
23 output_dims.end(), data_dims.begin() + 2, data_dims.end());
25 out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT);
29 Batch gather operation, first dimension in DATA is the batch size. 30 Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather 31 entries of the outer-most dimension of DATA indexed by INDICES, and concatenate 32 them in an output tensor of rank (q - 1) + (r - 1). 49 .Input(0, "DATA",
"Tensor of rank r >= 2.")
50 .Input(1,
"INDICES",
"Tensor of int32/int64 indices, of any rank q.")
51 .Output(0,
"OUTPUT",
"Tensor of rank (q - 1) + (r - 1).");
53 OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1);
// Interior of the gradient maker for BatchGather (its class header and some
// body lines are not visible in this listing). It inherits GradientMakerBase's
// constructors and describes a gradient op of type "BatchGatherGradient".
56 using GradientMakerBase::GradientMakerBase;
// Builds the gradient operator definition(s) for a BatchGather forward op.
// NOTE(review): the call receiving the arguments below is on a line not shown
// here; presumably SingleGradientDef — confirm.
57 vector<OperatorDef> GetGradientDefs()
override {
// Operator type of the emitted gradient op.
60 "BatchGatherGradient",
// Gradient-op inputs: the forward op's DATA and INDICES, plus the gradient of
// the forward op's output 0. `Op` is presumably an alias for the forward op
// class, declared on a line not visible here — confirm.
62 vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
// Gradient-op output: only the gradient of forward input 0 (DATA); INDICES
// receives no gradient.
63 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...