Caffe2 - C++ API
A deep learning, cross platform ML framework
lengths_reducer_fused_8bit_rowwise_ops.h
1 #ifndef CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
2 #define CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/operators/fused_rowwise_8bit_conversion_ops.h"
8 #include "caffe2/operators/reducer_functors.h"
9 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
10 #include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 
14 template <class Context, bool with_weights = 0, bool is_mean = 0>
15 class SparseLengthsFused8BitRowwiseOp : public Operator<Context> {
16  public:
17  static_assert(
18  !(with_weights && is_mean),
19  "Cannot have with_weights and is_mean a the same time");
20 
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22  USE_SIMPLE_CTOR_DTOR(SparseLengthsFused8BitRowwiseOp)
23 
24  bool RunOnDevice() override {
26  this, Input(INDICES));
27  }
28 
29  template <typename IndexType>
30  bool DoRunWithType() {
31  const auto& data = Input(DATA);
32  const auto& indices = Input(INDICES);
33  const auto& lengths = Input(LENGTHS);
34  auto* output = Output(0);
35 
36  CAFFE_ENFORCE_EQ(indices.ndim(), 1, "INDICES must be a vector");
37  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be a vector");
38 
39  const float* weights = nullptr;
40  if (with_weights) {
41  const auto& weights_input = Input(WEIGHTS);
42  CAFFE_ENFORCE_EQ(weights_input.ndim(), 1, "WEIGHTS must be a vector");
43  CAFFE_ENFORCE_EQ(
44  weights_input.size(),
45  indices.size(),
46  "WEIGHTS should have the same length as INDICES.");
47  weights = weights_input.template data<float>();
48  }
49 
50  CAFFE_ENFORCE_GT(data.dim(1), 8, "DATA must have more than 8 columns");
51  // Subtract 8 from the #columns of data for the 4 bytes for scale and 4
52  // bytes for bias that we use in the fused representation (per row).
53  const std::vector<TIndex> shape = {lengths.dim(0), data.dim(1) - 8};
54  output->Resize(shape);
55 
57  /*block_size=*/output->dim(1),
58  /*output_size=*/output->dim(0),
59  /*index_size=*/indices.size(),
60  /*data_size=*/data.dim(0),
61  /*input=*/data.template data<uint8_t>(),
62  /*indices=*/indices.template data<IndexType>(),
63  /*lengths=*/lengths.template data<int>(),
64  /*weights=*/weights,
65  /*normalize_by_lengths=*/is_mean,
66  /*out=*/output->template mutable_data<float>());
67 
68  return true;
69  }
70 
71  private:
72  enum {
73  DATA = 0,
74  WEIGHTS = 1,
75  INDICES = 1 + with_weights,
76  LENGTHS = 2 + with_weights,
77  };
78 };
79 
80 } // namespace caffe2
81 
82 #endif // CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void Fused8BitRowwiseEmbeddingLookup(const TIndex block_size, const TIndex output_size, const TIndex index_size, const TIndex data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.