Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_gather_ops.h
1 #ifndef CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
2 #define CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class BatchGatherOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  USE_SIMPLE_CTOR_DTOR(BatchGatherOp)
15 
16  bool RunOnDevice() override {
18  this, OperatorBase::Input<TensorCPU>(INDICES));
19  }
20 
21  template <typename TInd>
22  bool DoRunWithType() {
23  auto& data = Input(DATA);
24  auto& indices = Input(INDICES);
25  auto* output = Output(0);
26 
27  CAFFE_ENFORCE_GE(data.ndim(), 2, "DATA should be at least 2-D");
28 
29  vector<TIndex> shape;
30  shape.push_back(data.dim(0));
31  shape.insert(shape.end(), indices.dims().begin(), indices.dims().end());
32  shape.insert(shape.end(), data.dims().begin() + 2, data.dims().end());
33  output->Resize(shape);
34 
35  auto block_size = data.size_from_dim(2);
36  auto block_bytesize = block_size * data.meta().itemsize();
37  auto N = indices.size();
38  auto data_batch_bytesize = data.size_from_dim(1) * data.meta().itemsize();
39  auto gathered_batch_bytesize =
40  N * data.size_from_dim(2) * data.meta().itemsize();
41  const TInd* idxs = indices.template data<TInd>();
42  auto src_base = static_cast<const char*>(data.raw_data());
43  auto out = static_cast<char*>(output->raw_mutable_data(data.meta()));
44 
45  for (auto batch = 0; batch < data.dim(0); ++batch) {
46  for (auto i = 0; i < N; ++i) {
47  auto idx = idxs[i];
48  CAFFE_ENFORCE(
49  0 <= idx && idx < data.dim(1),
50  "INDICES element is out of DATA bounds, id=",
51  idx,
52  " data_dim=",
53  data.dim(1));
54  auto src =
55  src_base + idx * block_bytesize + batch * data_batch_bytesize;
56  auto dst = out + i * block_bytesize + batch * gathered_batch_bytesize;
57  context_.template CopyItems<Context, Context>(
58  data.meta(), block_size, src, dst);
59  }
60  }
61  return true;
62  }
63 
64  INPUT_TAGS(DATA, INDICES);
65 };
66 
67 template <class Context>
68 class BatchGatherGradientOp final : public Operator<Context> {
69  public:
70  USE_OPERATOR_CONTEXT_FUNCTIONS;
71  USE_SIMPLE_CTOR_DTOR(BatchGatherGradientOp);
72 
73  bool RunOnDevice() override {
75  this, OperatorBase::Input<TensorCPU>(INDICES));
76  }
77 
78  template <typename TInd>
79  bool DoRunWithType() {
80  return DispatchHelper<
82  TInd>::call(this, Input(DATA));
83  }
84 
85  template <typename TInd, typename TData>
86  bool DoRunWithType2() {
87  auto& data = Input(DATA);
88  auto& indices = Input(INDICES);
89  auto& grad = Input(GRAD);
90  auto* output = Output(0);
91 
92  CAFFE_ENFORCE_GE(data.ndim(), 2, "DATA should be at least 2-D");
93  CAFFE_ENFORCE_EQ(
94  data.dim(0), grad.dim(0), "batch sizes should be the same");
95 
96  output->ResizeLike(data);
97  TData* out_data = output->template mutable_data<TData>();
98  if (data.size() <= 0) {
99  return true;
100  }
101 
102  memset(out_data, 0, output->nbytes());
103 
104  const TData* grad_data = grad.template data<TData>();
105 
106  auto block_size = data.size_from_dim(2);
107  auto N = indices.size();
108  auto data_batch_size = data.size_from_dim(1);
109  auto gathered_batch_size = N * data.size_from_dim(2);
110  const TInd* idxs = indices.template data<TInd>();
111 
112  for (auto batch = 0; batch < grad.dim(0); ++batch) {
113  for (auto i = 0; i < N; ++i) {
114  auto idx = idxs[i];
115  CAFFE_ENFORCE(
116  0 <= idx && idx < data.dim(1),
117  "INDICES element is out of DATA bounds, id=",
118  idx,
119  " data_dim=",
120  data.dim(1));
121  math::Add(
122  block_size,
123  out_data + idx * block_size + batch * data_batch_size,
124  grad_data + i * block_size + batch * gathered_batch_size,
125  out_data + idx * block_size + batch * data_batch_size,
126  &context_);
127  }
128  }
129  return true;
130  }
131 
132  template <typename TInd>
133  bool DoRunWithOtherType2() {
134  CAFFE_THROW(
135  "BatchGatherGradient is not implemented on tensor of type ",
136  Input(DATA).meta().name(),
137  "Consider adding it a type in the list DispatchHelper or implementing "
138  "a generic version (which won't work for duplicated indices though)");
139  }
140 
141  INPUT_TAGS(DATA, INDICES, GRAD);
142 };
143 
144 } // namespace caffe2
145 
146 #endif // CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...