Caffe2 - C++ API
A deep learning, cross platform ML framework
embedding_lookup.cc
1 #include "caffe2/perfkernels/embedding_lookup.h"
2 
3 #include "caffe2/core/types.h"
4 #include "caffe2/perfkernels/common.h"
5 #include "caffe2/perfkernels/typed_axpy.h"
6 #include "caffe2/utils/cpuid.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // Base implementation does runtime dispatch for each segment of reduction
12 template <
13  typename IndexType,
14  typename InType,
15  typename OutType,
16  bool IS_WEIGHT_POSITIONAL = false>
17 static void EmbeddingLookupGenericSlow(
18  const TIndex block_size,
19  const TIndex output_size,
20  const TIndex index_size,
21  const TIndex data_size,
22  const InType* input,
23  const IndexType* indices,
24  const int* lengths,
25  const float* weights, // optional, can be null for sum reducer
26  const float* scale_bias, // optional scale & bias params for uint8 input
27  bool normalize_by_lengths,
28  OutType* out) {
29  TIndex current = 0;
30  for (int m = 0; m < output_size; ++m) {
31  memset(out, 0, sizeof(OutType) * block_size);
32  EigenVectorArrayMap<OutType> out_vector(out, block_size);
33  for (int i = 0; i < lengths[m]; ++i) {
34  CAFFE_ENFORCE_LT(current, index_size);
35  TIndex idx = indices[current];
36  CAFFE_ENFORCE(
37  0 <= idx && idx < data_size,
38  "Index ",
39  current,
40  " is out of bounds: ",
41  idx,
42  ", range 0 to ",
43  data_size);
44  CAFFE_ENFORCE_LT(idx, data_size);
45 #ifdef __GNUC__
46  if (current + 1 < index_size) {
47  __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
48  }
49 #endif // __GNUC__
50 
51  float w = 1.f, b = 0.f;
52  if (weights) {
53  w = weights[IS_WEIGHT_POSITIONAL ? i : current];
54  }
55  if (scale_bias) {
56  b = w * scale_bias[2 * indices[current] + 1];
57  w = w * scale_bias[2 * indices[current]];
58  }
59 
60  TypedAxpy<InType, OutType>(
61  block_size, w, input + block_size * indices[current], out);
62 
63  if (scale_bias) {
64  out_vector = out_vector + b;
65  }
66 
67  ++current;
68  }
69  if (normalize_by_lengths && lengths[m]) {
70  // hack: context is not really used
71  math::Scale<OutType, CPUContext>(
72  block_size, 1.f / lengths[m], out, out, nullptr);
73  }
74  out += block_size;
75  }
76  CAFFE_ENFORCE_EQ(
77  current,
78  index_size,
79  "Your input seems to be incorrect: the sum of lengths values should be "
80  "the size of the indices tensor, but it appears not.");
81 }
82 
// Stamps out, for one (index type, input type, output type, weight mode)
// combination:
//   * EmbeddingLookup_<IdxT>_<InT>_<OutT>_<POSITIONAL>__base, a plain function
//     that proxies back to the generic implementation above, and
//   * the matching explicit specialization of EmbeddingLookup<>, which hands
//     the same arguments to AVX2_FMA_DO and then to BASE_DO.
// NOTE(review): AVX2_FMA_DO / BASE_DO are defined outside this file
// (presumably caffe2/perfkernels/common.h); the pasted symbol names must stay
// exactly as they are for those macros to find their targets - confirm before
// renaming anything here.
#define EMBEDDING_SPECIALIZATION(IdxT, InT, OutT, POSITIONAL)                 \
  void EmbeddingLookup_##IdxT##_##InT##_##OutT##_##POSITIONAL##__base(        \
      const TIndex block_size,                                                \
      const TIndex output_size,                                               \
      const TIndex index_size,                                                \
      const TIndex data_size,                                                 \
      const InT* input,                                                       \
      const IdxT* indices,                                                    \
      const int* lengths,                                                     \
      const float* weights,                                                   \
      const float* scale_bias,                                                \
      bool normalize_by_lengths,                                              \
      OutT* out) {                                                            \
    EmbeddingLookupGenericSlow<IdxT, InT, OutT, POSITIONAL>(                  \
        block_size,                                                           \
        output_size,                                                          \
        index_size,                                                           \
        data_size,                                                            \
        input,                                                                \
        indices,                                                              \
        lengths,                                                              \
        weights,                                                              \
        scale_bias,                                                           \
        normalize_by_lengths,                                                 \
        out);                                                                 \
  }                                                                           \
  template <>                                                                 \
  void EmbeddingLookup<IdxT, InT, OutT, POSITIONAL>(                          \
      const TIndex block_size,                                                \
      const TIndex output_size,                                               \
      const TIndex index_size,                                                \
      const TIndex data_size,                                                 \
      const InT* input,                                                       \
      const IdxT* indices,                                                    \
      const int* lengths,                                                     \
      const float* weights,                                                   \
      const float* scale_bias,                                                \
      bool normalize_by_lengths,                                              \
      OutT* out) {                                                            \
    AVX2_FMA_DO(                                                              \
        EmbeddingLookup_##IdxT##_##InT##_##OutT##_##POSITIONAL,               \
        block_size,                                                           \
        output_size,                                                          \
        index_size,                                                           \
        data_size,                                                            \
        input,                                                                \
        indices,                                                              \
        lengths,                                                              \
        weights,                                                              \
        scale_bias,                                                           \
        normalize_by_lengths,                                                 \
        out);                                                                 \
    BASE_DO(                                                                  \
        EmbeddingLookup_##IdxT##_##InT##_##OutT##_##POSITIONAL,               \
        block_size,                                                           \
        output_size,                                                          \
        index_size,                                                           \
        data_size,                                                            \
        input,                                                                \
        indices,                                                              \
        lengths,                                                              \
        weights,                                                              \
        scale_bias,                                                           \
        normalize_by_lengths,                                                 \
        out);                                                                 \
  }
156 
157 EMBEDDING_SPECIALIZATION(int32_t, float, float, false);
158 EMBEDDING_SPECIALIZATION(int64_t, float, float, false);
159 EMBEDDING_SPECIALIZATION(int32_t, float16, float, false);
160 EMBEDDING_SPECIALIZATION(int64_t, float16, float, false);
161 EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float, false);
162 EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float, false);
163 
164 EMBEDDING_SPECIALIZATION(int32_t, float, float, true);
165 EMBEDDING_SPECIALIZATION(int64_t, float, float, true);
166 EMBEDDING_SPECIALIZATION(int32_t, float16, float, true);
167 EMBEDDING_SPECIALIZATION(int64_t, float16, float, true);
168 EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float, true);
169 EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float, true);
170 
171 #undef EMBEDDING_SPECIALIZATION
172 
173 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...