Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_gather_ops.cc
1 #include "caffe2/operators/batch_gather_ops.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
6 REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
7 
8 OPERATOR_SCHEMA(BatchGather)
9  .NumInputs(2)
10  .NumOutputs(1)
11  .TensorInferenceFunction([](const OperatorDef& def,
12  const vector<TensorShape>& in) {
13  vector<TensorShape> out(1);
14  ArgumentHelper helper(def);
15 
16  vector<int> output_dims;
17  const auto& data_dims = GetDimsVector(in[0]);
18  const auto& indices_dims = GetDimsVector(in[1]);
19  output_dims.push_back(data_dims[0]);
20  output_dims.insert(
21  output_dims.end(), indices_dims.begin(), indices_dims.end());
22  output_dims.insert(
23  output_dims.end(), data_dims.begin() + 2, data_dims.end());
24 
25  out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT);
26  return out;
27  })
28  .SetDoc(R"DOC(
29 Batch gather operation, first dimension in DATA is the batch size.
30 Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather
31 entries of the outer-most dimension of DATA indexed by INDICES, and concatenate
32 them in an output tensor of rank (q - 1) + (r - 1).
33 
34 Example:
35  DATA = [
36  [1.0, 1.2, 2.4, 4.5],
37  [2.3, 3.4, 3.6, 2.3],
38  [4.5, 5.7, 1.2, 4.5],
39  ]
40  INDICES = [
41  [0, 2],
42  ]
43  OUTPUT = [
44  [1.0, 2.4],
45  [2.3, 3.6],
46  [4.5, 1.2],
47  ]
48 )DOC")
49  .Input(0, "DATA", "Tensor of rank r >= 2.")
50  .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
51  .Output(0, "OUTPUT", "Tensor of rank (q - 1) + (r - 1).");
52 
53 OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1);
54 
56  using GradientMakerBase::GradientMakerBase;
57  vector<OperatorDef> GetGradientDefs() override {
58  using Op = BatchGatherOp<CPUContext>;
59  return SingleGradientDef(
60  "BatchGatherGradient",
61  "",
62  vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
63  vector<string>{GI(0)});
64  }
65 };
66 
67 REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient);
68 
69 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually the case for many ...