Caffe2 - C++ API
A deep learning, cross platform ML framework
fused_8bit_rowwise_embedding_lookup.cc
1 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
2 
3 #include "caffe2/core/types.h"
4 #include "caffe2/perfkernels/common.h"
5 #include "caffe2/perfkernels/typed_axpy.h"
6 #include "caffe2/utils/cpuid.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // Base implementation does runtime dispatch for each segment of reduction
12 template <
13  typename IndexType,
14  typename InType,
15  typename OutType,
16  bool IS_WEIGHT_POSITIONAL = false>
17 static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
18  const TIndex block_size,
19  const TIndex output_size,
20  const TIndex index_size,
21  const TIndex data_size,
22  const InType* input,
23  const IndexType* indices,
24  const int* lengths,
25  const float* weights, // optional, can be null for sum reducer
26  bool normalize_by_lengths,
27  OutType* out) {
28  // block_size is the number of elements and fused_block_size is the size of
29  // an entire row, including scale and bias.
30  const auto scale_bias_offset = 8 / sizeof(InType);
31  const TIndex fused_block_size = block_size + scale_bias_offset;
32  TIndex current = 0;
33  for (int m = 0; m < output_size; ++m) {
34  memset(out, 0, sizeof(OutType) * block_size);
35  EigenVectorArrayMap<OutType> out_vector(out, block_size);
36  for (int i = 0; i < lengths[m]; ++i) {
37  CAFFE_ENFORCE_LT(current, index_size);
38  TIndex idx = indices[current];
39  CAFFE_ENFORCE(
40  0 <= idx && idx < data_size,
41  "Index ",
42  current,
43  " is out of bounds: ",
44  idx,
45  ", range 0 to ",
46  data_size);
47  CAFFE_ENFORCE_LT(idx, data_size);
48 #ifdef __GNUC__
49  if (current + 1 < index_size) {
50  __builtin_prefetch(
51  input + fused_block_size * indices[current + 1], 0, 1);
52  }
53 #endif // __GNUC__
54 
55  const float* scale_bias = reinterpret_cast<const float*>(
56  input + fused_block_size * indices[current] + block_size);
57 
58  float weight = 1.0f;
59  if (weights) {
60  weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
61  }
62  const float scale = weight * scale_bias[0];
63  const float bias = weight * scale_bias[1];
64 
65  TypedAxpy<InType, OutType>(
66  block_size, scale, input + fused_block_size * indices[current], out);
67 
68  out_vector += bias;
69 
70  ++current;
71  }
72  if (normalize_by_lengths && lengths[m]) {
73  // hack: context is not really used
74  math::Scale<OutType, CPUContext>(
75  block_size, 1.f / lengths[m], out, out, nullptr);
76  }
77  out += block_size;
78  }
79  CAFFE_ENFORCE_EQ(
80  current,
81  index_size,
82  "Your input seems to be incorrect: the sum of lengths values should be "
83  "the size of the indices tensor, but it appears not.");
84 }
85 
86 // Proxy back to generic implementation
87 #define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \
88  IndexType, InType, OutType) \
89  void \
90  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \
91  const TIndex block_size, \
92  const TIndex output_size, \
93  const TIndex index_size, \
94  const TIndex data_size, \
95  const InType* input, \
96  const IndexType* indices, \
97  const int* lengths, \
98  const float* weights, \
99  bool normalize_by_lengths, \
100  OutType* out) { \
101  Fused8BitRowwiseEmbeddingLookupGenericSlow< \
102  IndexType, \
103  InType, \
104  OutType, \
105  false>( \
106  block_size, \
107  output_size, \
108  index_size, \
109  data_size, \
110  input, \
111  indices, \
112  lengths, \
113  weights, \
114  normalize_by_lengths, \
115  out); \
116  } \
117  template <> \
118  void Fused8BitRowwiseEmbeddingLookup<IndexType, InType, OutType, false>( \
119  const TIndex block_size, \
120  const TIndex output_size, \
121  const TIndex index_size, \
122  const TIndex data_size, \
123  const InType* input, \
124  const IndexType* indices, \
125  const int* lengths, \
126  const float* weights, \
127  bool normalize_by_lengths, \
128  OutType* out) { \
129  const int32_t one = 1; \
130  CAFFE_ENFORCE_EQ( \
131  reinterpret_cast<const uint8_t*>(&one)[0], \
132  1, \
133  "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
134  AVX2_FMA_DO( \
135  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \
136  block_size, \
137  output_size, \
138  index_size, \
139  data_size, \
140  input, \
141  indices, \
142  lengths, \
143  weights, \
144  normalize_by_lengths, \
145  out); \
146  BASE_DO( \
147  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \
148  block_size, \
149  output_size, \
150  index_size, \
151  data_size, \
152  input, \
153  indices, \
154  lengths, \
155  weights, \
156  normalize_by_lengths, \
157  out); \
158  }
159 
160 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float);
161 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float);
162 
163 #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION
164 
165 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...