Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_reducer_ops.h
1 #pragma once
2 #include "caffe2/core/context.h"
3 #include "caffe2/core/operator.h"
4 #include "caffe2/perfkernels/embedding_lookup.h"
5 
6 namespace caffe2 {
7 
8 // A templated class that implements SparseLengths[Sum,WeightedSum,Mean].
9 template <
10  typename T, // output type
11  class InputTypes, // supported input types, such as TensorTypes<float>
12  bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum
13  bool USE_MEAN = 0, // Whether this is SparseLengthsMean
14  bool USE_POSITIONAL_WEIGHT = 0
15  // USE_WEIGHT = 1 and USE_POSITIONAL_WEIGHT = 1
16  // -> SparseLengthsPositionalWeightedSum
17  >
18 class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
19  public:
20  USE_OPERATOR_FUNCTIONS(CPUContext);
21  CPUSparseLengthsReductionOp(const OperatorDef& operator_def, Workspace* ws)
22  : Operator<CPUContext>(operator_def, ws) {
23  static_assert(
24  !(USE_WEIGHT & USE_MEAN), "Cannot both specify weight and mean.");
25  }
26 
28 
29  // Currently, we support float and float16 inputs for input data type, and
30  // int32_t and int64_t for the index type.
31 
32  bool RunOnDevice() override {
33  return DispatchHelper<InputTypes>::call(this, Input(DATA));
34  }
35 
36  template <typename InputType>
37  bool DoRunWithType() {
38  return DispatchHelper<TensorTypes2<int32_t, int64_t>, InputType>::call(
39  this, Input(INDICES));
40  }
41 
42  template <typename InputType, typename IndexType>
43  bool DoRunWithType2() {
44  auto& dataInput = Input(DATA);
45  auto& indicesInput = Input(INDICES);
46  auto& lengthsInput = Input(LENGTHS);
47 
48  CAFFE_ENFORCE_EQ(1, indicesInput.ndim(), "INDICES must be a vector");
49  CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
50  const TIndex N = dataInput.dim(0);
51  const int D = dataInput.size_from_dim(1);
52  const TIndex M = lengthsInput.dim(0);
53  const TIndex indices_size = indicesInput.size();
54 
55  auto* output = Output(0);
56  auto shape = dataInput.dims();
57  shape[0] = M;
58  output->Resize(shape);
59  T* out_data = output->template mutable_data<T>();
60 
61  const InputType* in_data = dataInput.template data<InputType>();
62  const IndexType* indices = indicesInput.template data<IndexType>();
63  const int* lengths = lengthsInput.template data<int>();
64  const T* in_weight = nullptr;
65 
66  if (USE_WEIGHT) {
67  // static if
68  auto& weightInput = Input(WEIGHT);
69  CAFFE_ENFORCE_EQ(1, weightInput.ndim(), "WEIGHT must be a vector");
70  if (!USE_POSITIONAL_WEIGHT) {
71  CAFFE_ENFORCE_EQ(
72  weightInput.size(),
73  indices_size,
74  "Weight should have the same length as indices.");
75  }
76  in_weight = weightInput.template data<T>();
77  }
78 
79  // delegate work to perfkernel that branches based on architecture
80  EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
81  D,
82  M,
83  indices_size,
84  N,
85  in_data,
86  indices,
87  lengths,
88  in_weight,
89  nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
90  USE_MEAN,
91  out_data);
92  return true;
93  }
94 
95  private:
96  enum {
97  DATA = 0, // Data input.
98  WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum
99  INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and
100  // 2 in SparseLengthsWeightedSum
101  LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
102  // 3 in SparseLengthsWeightedSum
103  };
104 };
105 
106 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...