Caffe2 - C++ API
A deep learning, cross-platform ML framework
one_hot_ops.h
1 #ifndef CAFFE_OPERATORS_ONE_HOT_OPS_H_
2 #define CAFFE_OPERATORS_ONE_HOT_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class OneHotOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15 
16  OneHotOp(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws) {}
18 
19  bool RunOnDevice() override {
20  auto& indices = Input(0);
21  CAFFE_ENFORCE_EQ(
22  indices.ndim(),
23  1,
24  "indices input must be 1D tensor of data type TIndex");
25 
26  // Index size input must be in CPU context
27  auto& index_size_tensor = OperatorBase::Input<Tensor<CPUContext>>(1);
28  CAFFE_ENFORCE_EQ(
29  index_size_tensor.size(),
30  1,
31  "index_size_tensor input must be scalar of data type TIndex");
32 
33  auto batch_size = indices.size();
34  auto index_size = *index_size_tensor.template data<TIndex>();
35  auto one_hots = Output(0);
36  one_hots->Resize(batch_size, index_size);
37  auto output_size = one_hots->size();
38  if (output_size == 0) {
39  return true;
40  }
41 
42  DoOneHotOp(batch_size, index_size, indices, one_hots);
43  return true;
44  }
45 
46  protected:
47  void DoOneHotOp(
48  TIndex batch_size,
49  TIndex index_size,
50  const Tensor<Context>& indices,
51  Tensor<Context>* output);
52 };
53 
54 template <class Context>
55 class BatchOneHotOp final : public Operator<Context> {
56  public:
57  USE_OPERATOR_CONTEXT_FUNCTIONS;
58  BatchOneHotOp(const OperatorDef& operator_def, Workspace* ws)
59  : Operator<Context>(operator_def, ws) {}
60 
61  bool RunOnDevice() override {
62  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(X));
63  }
64 
65  template <typename T>
66  bool DoRunWithType();
67 
68  protected:
69  INPUT_TAGS(X, LENS, VALS);
70  OUTPUT_TAGS(ONE_HOT);
71 
72  private:
73  // allows for fast random access to a given dict and is re-used across runs
74  std::vector<TIndex> valsOffsets_;
75 };
76 
77 template <class Context>
78 class BatchBucketOneHotOp final : public Operator<Context> {
79  public:
80  USE_OPERATOR_CONTEXT_FUNCTIONS;
81  BatchBucketOneHotOp(const OperatorDef& operator_def, Workspace* ws)
82  : Operator<Context>(operator_def, ws) {}
83 
84  bool RunOnDevice() override;
85 
86  protected:
87  INPUT_TAGS(X, LENS, BOUNDARIES);
88  OUTPUT_TAGS(ONE_HOT);
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE_OPERATORS_ONE_HOT_OPS_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information.
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.