Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_fused_8bit_rowwise_avx2.cc
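Before the generated kernels, a minimal scalar sketch of the computation they all implement may help: each output row is the (optionally weighted) sum of lengths[r] input rows selected by consecutive entries of indices, read with a row stride of block_size plus the fused per-row extras, and optionally divided by the segment length. The sketch below is illustrative only; the function name, the IndexType template parameter, and the use of assert in place of CAFFE_ENFORCE are assumptions, not part of this file.

#include <cassert>
#include <cstdint>

template <typename IndexType>
void scalar_reference_lookup(            // hypothetical name, for illustration only
    std::int64_t block_size,             // width of one output row (TIndex in this file)
    std::int64_t output_size,            // number of segments / output rows
    std::int64_t data_size,              // number of rows in the fused table
    const float* input,                  // each row: block_size floats + 2 fused floats
    const IndexType* indices,
    const int* lengths,
    const float* weights,                // may be nullptr; the positional variants index it by (dataInd - start)
    bool normalize_by_lengths,
    float* out) {
  const std::int64_t fused_block_size = block_size + 2;  // same stride as the float kernels below
  std::int64_t dataInd = 0;
  for (std::int64_t r = 0; r < output_size; ++r) {
    float* op = out + r * block_size;
    for (std::int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;                      // accumulate into a zeroed output row
    }
    const std::int64_t start = dataInd;
    for (; dataInd < start + lengths[r]; ++dataInd) {
      const IndexType idx = indices[dataInd];
      assert(idx >= 0 && idx < data_size);   // the kernels use CAFFE_ENFORCE here
      const float wgt = weights ? weights[dataInd] : 1.0f;
      const float* ip = input + idx * fused_block_size;
      for (std::int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * ip[j];            // the AVX2 kernels do this 8 floats at a time with FMA
      }
    }
    if (normalize_by_lengths && lengths[r]) {
      const float len_inv = 1.0f / lengths[r];
      for (std::int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
}

The specialized branches for block_size 128, 64, 32, and 16 unroll this inner loop into fixed sets of __m256 accumulators and prefetch the row of a lookahead index; the generic branch at the end handles every other block size.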
8 #include <caffe2/core/common.h>
9 #include <caffe2/core/types.h>
10 #include <immintrin.h>
11 
12 namespace caffe2 {
13 
14 template <bool IS_WEIGHT_POSITIONAL>
15 static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
16  const TIndex block_size,
17  const TIndex output_size,
18  const TIndex index_size,
19  const TIndex data_size,
20  const float* input,
21  const int32_t* indices,
22  const int* lengths,
23  const float* weights,
24  bool normalize_by_lengths,
25  float* out) {
26  const int32_t prefdist_T0 = 16;
27  const int32_t fused_block_size = block_size + 2;
28  if (block_size == 128) {
29  // unrolling 16 times
30  int32_t dataInd = 0;
31  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
32  float* op = &out[rangeIndex * block_size];
33  __m256 vop0 = _mm256_setzero_ps();
34  __m256 vop8 = _mm256_setzero_ps();
35  __m256 vop16 = _mm256_setzero_ps();
36  __m256 vop24 = _mm256_setzero_ps();
37  __m256 vop32 = _mm256_setzero_ps();
38  __m256 vop40 = _mm256_setzero_ps();
39  __m256 vop48 = _mm256_setzero_ps();
40  __m256 vop56 = _mm256_setzero_ps();
41  __m256 vop64 = _mm256_setzero_ps();
42  __m256 vop72 = _mm256_setzero_ps();
43  __m256 vop80 = _mm256_setzero_ps();
44  __m256 vop88 = _mm256_setzero_ps();
45  __m256 vop96 = _mm256_setzero_ps();
46  __m256 vop104 = _mm256_setzero_ps();
47  __m256 vop112 = _mm256_setzero_ps();
48  __m256 vop120 = _mm256_setzero_ps();
49  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
50  ++dataInd) {
51  const int32_t idx = indices[dataInd];
52  CAFFE_ENFORCE(
53  idx >= 0 && idx < data_size,
54  "Index ",
55  dataInd,
56  " is out of bounds: ",
57  idx,
58  ", range 0 to ",
59  data_size);
60  float wgt = 1.f;
61  if (weights) {
62  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
63  }
64  __m256 vwgt = _mm256_set1_ps(wgt);
65  const float* ip = &input[idx * fused_block_size];
66  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
67  ? (dataInd + prefdist_T0)
68  : dataInd;
69  const int32_t idx_pref_T0 = indices[next_T0];
70  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
71  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
72  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
73  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
74  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
75  // skip unnecessary prefetch of (&ip_next_T0[8])
76  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
77  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
78  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
79  // skip unnecessary prefetch of (&ip_next_T0[24])
80  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
81  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
82  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
83  // skip unnecessary prefetch of (&ip_next_T0[40])
84  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
85  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
86  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
87  // skip unnecessary prefetch of (&ip_next_T0[56])
88  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
89  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
90  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
91  // skip unnecessary prefetch of (&ip_next_T0[72])
92  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
93  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
94  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
95  // skip unnecessary prefetch of (&ip_next_T0[88])
96  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
97  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
98  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
99  // skip unnecessary prefetch of (&ip_next_T0[104])
100  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
101  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
102  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
103  // skip unnecessary prefetch of (&ip_next_T0[120])
104  }
105  if (normalize_by_lengths == false) {
106  _mm256_storeu_ps(&op[0], vop0);
107  _mm256_storeu_ps(&op[8], vop8);
108  _mm256_storeu_ps(&op[16], vop16);
109  _mm256_storeu_ps(&op[24], vop24);
110  _mm256_storeu_ps(&op[32], vop32);
111  _mm256_storeu_ps(&op[40], vop40);
112  _mm256_storeu_ps(&op[48], vop48);
113  _mm256_storeu_ps(&op[56], vop56);
114  _mm256_storeu_ps(&op[64], vop64);
115  _mm256_storeu_ps(&op[72], vop72);
116  _mm256_storeu_ps(&op[80], vop80);
117  _mm256_storeu_ps(&op[88], vop88);
118  _mm256_storeu_ps(&op[96], vop96);
119  _mm256_storeu_ps(&op[104], vop104);
120  _mm256_storeu_ps(&op[112], vop112);
121  _mm256_storeu_ps(&op[120], vop120);
122  } else if (lengths[rangeIndex]) {
123  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
124  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
125  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
126  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
127  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
128  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
129  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
130  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
131  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
132  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
133  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
134  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
135  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
136  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
137  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
138  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
139  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
140  }
141  }
142  } else if (block_size == 64) {
143  // unrolling 8 times
144  int32_t dataInd = 0;
145  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
146  float* op = &out[rangeIndex * block_size];
147  __m256 vop0 = _mm256_setzero_ps();
148  __m256 vop8 = _mm256_setzero_ps();
149  __m256 vop16 = _mm256_setzero_ps();
150  __m256 vop24 = _mm256_setzero_ps();
151  __m256 vop32 = _mm256_setzero_ps();
152  __m256 vop40 = _mm256_setzero_ps();
153  __m256 vop48 = _mm256_setzero_ps();
154  __m256 vop56 = _mm256_setzero_ps();
155  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
156  ++dataInd) {
157  const int32_t idx = indices[dataInd];
158  CAFFE_ENFORCE(
159  idx >= 0 && idx < data_size,
160  "Index ",
161  dataInd,
162  " is out of bounds: ",
163  idx,
164  ", range 0 to ",
165  data_size);
166  float wgt = 1.f;
167  if (weights) {
168  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
169  }
170  __m256 vwgt = _mm256_set1_ps(wgt);
171  const float* ip = &input[idx * fused_block_size];
172  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
173  ? (dataInd + prefdist_T0)
174  : dataInd;
175  const int32_t idx_pref_T0 = indices[next_T0];
176  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
177  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
178  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
179  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
180  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
181  // skip unnecessary prefetch of (&ip_next_T0[8])
182  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
183  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
184  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
185  // skip unnecessary prefetch of (&ip_next_T0[24])
186  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
187  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
188  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
189  // skip unnecessary prefetch of (&ip_next_T0[40])
190  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
191  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
192  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
193  // skip unnecessary prefetch of (&ip_next_T0[56])
194  }
195  if (normalize_by_lengths == false) {
196  _mm256_storeu_ps(&op[0], vop0);
197  _mm256_storeu_ps(&op[8], vop8);
198  _mm256_storeu_ps(&op[16], vop16);
199  _mm256_storeu_ps(&op[24], vop24);
200  _mm256_storeu_ps(&op[32], vop32);
201  _mm256_storeu_ps(&op[40], vop40);
202  _mm256_storeu_ps(&op[48], vop48);
203  _mm256_storeu_ps(&op[56], vop56);
204  } else if (lengths[rangeIndex]) {
205  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
206  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
207  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
208  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
209  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
210  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
211  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
212  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
213  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
214  }
215  }
216  } else if (block_size == 32) {
217  // unrolling 4 times
218  int32_t dataInd = 0;
219  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
220  float* op = &out[rangeIndex * block_size];
221  __m256 vop0 = _mm256_setzero_ps();
222  __m256 vop8 = _mm256_setzero_ps();
223  __m256 vop16 = _mm256_setzero_ps();
224  __m256 vop24 = _mm256_setzero_ps();
225  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
226  ++dataInd) {
227  const int32_t idx = indices[dataInd];
228  CAFFE_ENFORCE(
229  idx >= 0 && idx < data_size,
230  "Index ",
231  dataInd,
232  " is out of bounds: ",
233  idx,
234  ", range 0 to ",
235  data_size);
236  float wgt = 1.f;
237  if (weights) {
238  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
239  }
240  __m256 vwgt = _mm256_set1_ps(wgt);
241  const float* ip = &input[idx * fused_block_size];
242  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
243  ? (dataInd + prefdist_T0)
244  : dataInd;
245  const int32_t idx_pref_T0 = indices[next_T0];
246  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
247  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
248  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
249  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
250  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
251  // skip unnecessary prefetch of (&ip_next_T0[8])
252  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
253  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
254  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
255  // skip unnecessary prefetch of (&ip_next_T0[24])
256  }
257  if (normalize_by_lengths == false) {
258  _mm256_storeu_ps(&op[0], vop0);
259  _mm256_storeu_ps(&op[8], vop8);
260  _mm256_storeu_ps(&op[16], vop16);
261  _mm256_storeu_ps(&op[24], vop24);
262  } else if (lengths[rangeIndex]) {
263  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
264  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
265  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
266  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
267  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
268  }
269  }
270  } else if (block_size == 16) {
271  // unrolling 2 times
272  int32_t dataInd = 0;
273  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
274  float* op = &out[rangeIndex * block_size];
275  __m256 vop0 = _mm256_setzero_ps();
276  __m256 vop8 = _mm256_setzero_ps();
277  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
278  ++dataInd) {
279  const int32_t idx = indices[dataInd];
280  CAFFE_ENFORCE(
281  idx >= 0 && idx < data_size,
282  "Index ",
283  dataInd,
284  " is out of bounds: ",
285  idx,
286  ", range 0 to ",
287  data_size);
288  float wgt = 1.f;
289  if (weights) {
290  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
291  }
292  __m256 vwgt = _mm256_set1_ps(wgt);
293  const float* ip = &input[idx * fused_block_size];
294  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
295  ? (dataInd + prefdist_T0)
296  : dataInd;
297  const int32_t idx_pref_T0 = indices[next_T0];
298  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
299  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
300  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
301  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
302  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
303  // skip unnecessary prefetch of (&ip_next_T0[8])
304  }
305  if (normalize_by_lengths == false) {
306  _mm256_storeu_ps(&op[0], vop0);
307  _mm256_storeu_ps(&op[8], vop8);
308  } else if (lengths[rangeIndex]) {
309  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
310  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
311  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
312  }
313  }
314  } else {
315  // generic code
316  int32_t dataInd = 0;
317  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
318  float* op = &out[rangeIndex * block_size];
319  TIndex j = 0;
320  for (; j + 8 <= block_size; j += 8) {
321  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
322  }
323  for (; j < block_size; j++) {
324  op[j] = 0.0f;
325  }
326  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
327  ++dataInd) {
328  const int32_t idx = indices[dataInd];
329  CAFFE_ENFORCE(
330  idx >= 0 && idx < data_size,
331  "Index ",
332  dataInd,
333  " is out of bounds: ",
334  idx,
335  ", range 0 to ",
336  data_size);
337  float wgt = 1.f;
338  if (weights) {
339  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
340  }
341  __m256 vwgt = _mm256_set1_ps(wgt);
342  const float* ip = &input[idx * fused_block_size];
343  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
344  ? (dataInd + prefdist_T0)
345  : dataInd;
346  const int32_t idx_pref_T0 = indices[next_T0];
347  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
348  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
349  j = 0;
350  for (; j + 8 <= block_size; j += 8) {
351  _mm256_storeu_ps(
352  &op[j],
353  _mm256_fmadd_ps(
354  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
355  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
356  }
357  for (; j < block_size; j++) {
358  op[j] += wgt * ip[j];
359  }
360  }
361  if (normalize_by_lengths && lengths[rangeIndex]) {
362  float len_inv = 1.0f / lengths[rangeIndex];
363  __m256 vlen_inv = _mm256_set1_ps(len_inv);
364  j = 0;
365  for (; j + 8 <= block_size; j += 8) {
366  _mm256_storeu_ps(
367  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
368  }
369  for (; j < block_size; j++) {
370  op[j] = len_inv * op[j];
371  }
372  }
373  }
374  }
375 }
376 void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
377  const TIndex block_size,
378  const TIndex output_size,
379  const TIndex index_size,
380  const TIndex data_size,
381  const float* input,
382  const int32_t* indices,
383  const int* lengths,
384  const float* weights,
385  bool normalize_by_lengths,
386  float* out) {
387  Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<false>(
388  block_size,
389  output_size,
390  index_size,
391  data_size,
392  input,
393  indices,
394  lengths,
395  weights,
396  normalize_by_lengths,
397  out);
398 }
399 void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma(
400  const TIndex block_size,
401  const TIndex output_size,
402  const TIndex index_size,
403  const TIndex data_size,
404  const float* input,
405  const int32_t* indices,
406  const int* lengths,
407  const float* weights,
408  bool normalize_by_lengths,
409  float* out) {
410  Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<true>(
411  block_size,
412  output_size,
413  index_size,
414  data_size,
415  input,
416  indices,
417  lengths,
418  weights,
419  normalize_by_lengths,
420  out);
421 }
422 
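A hypothetical call into the int32_t/float entry point defined above, shown only to make the parameter layout concrete; the table contents, sizes, and segment structure here are illustrative assumptions, not values taken from Caffe2 itself.

#include <vector>
// ... inside namespace caffe2, after the declarations above ...
const TIndex block_size = 16;                            // hits the 2x-unrolled fast path
const TIndex data_size = 8;                              // rows in the fused table
std::vector<float> table(data_size * (block_size + 2));  // each row: block_size values + 2 fused floats
std::vector<int32_t> indices = {0, 2, 1, 7};             // 4 lookups in total, all < data_size
std::vector<int> lengths = {3, 1};                       // segment 0 sums 3 rows, segment 1 sums 1
std::vector<float> out(2 * block_size);                  // one block_size-wide row per segment
Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
    block_size,
    /*output_size=*/2,
    /*index_size=*/4,
    data_size,
    table.data(),
    indices.data(),
    lengths.data(),
    /*weights=*/nullptr,
    /*normalize_by_lengths=*/true,
    out.data());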
423 template <bool IS_WEIGHT_POSITIONAL>
424 static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
425  const TIndex block_size,
426  const TIndex output_size,
427  const TIndex index_size,
428  const TIndex data_size,
429  const float* input,
430  const int64_t* indices,
431  const int* lengths,
432  const float* weights,
433  bool normalize_by_lengths,
434  float* out) {
435  const int64_t prefdist_T0 = 16;
436  const int64_t fused_block_size = block_size + 2;
437  if (block_size == 128) {
438  // unrolling 16 times
439  int64_t dataInd = 0;
440  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
441  float* op = &out[rangeIndex * block_size];
442  __m256 vop0 = _mm256_setzero_ps();
443  __m256 vop8 = _mm256_setzero_ps();
444  __m256 vop16 = _mm256_setzero_ps();
445  __m256 vop24 = _mm256_setzero_ps();
446  __m256 vop32 = _mm256_setzero_ps();
447  __m256 vop40 = _mm256_setzero_ps();
448  __m256 vop48 = _mm256_setzero_ps();
449  __m256 vop56 = _mm256_setzero_ps();
450  __m256 vop64 = _mm256_setzero_ps();
451  __m256 vop72 = _mm256_setzero_ps();
452  __m256 vop80 = _mm256_setzero_ps();
453  __m256 vop88 = _mm256_setzero_ps();
454  __m256 vop96 = _mm256_setzero_ps();
455  __m256 vop104 = _mm256_setzero_ps();
456  __m256 vop112 = _mm256_setzero_ps();
457  __m256 vop120 = _mm256_setzero_ps();
458  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
459  ++dataInd) {
460  const int64_t idx = indices[dataInd];
461  CAFFE_ENFORCE(
462  idx >= 0 && idx < data_size,
463  "Index ",
464  dataInd,
465  " is out of bounds: ",
466  idx,
467  ", range 0 to ",
468  data_size);
469  float wgt = 1.f;
470  if (weights) {
471  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
472  }
473  __m256 vwgt = _mm256_set1_ps(wgt);
474  const float* ip = &input[idx * fused_block_size];
475  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
476  ? (dataInd + prefdist_T0)
477  : dataInd;
478  const int64_t idx_pref_T0 = indices[next_T0];
479  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
480  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
481  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
482  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
483  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
484  // skip unnecessary prefetch of (&ip_next_T0[8])
485  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
486  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
487  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
488  // skip unnecessary prefetch of (&ip_next_T0[24])
489  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
490  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
491  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
492  // skip unnecessary prefetch of (&ip_next_T0[40])
493  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
494  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
495  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
496  // skip unnecessary prefetch of (&ip_next_T0[56])
497  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
498  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
499  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
500  // skip unnecessary prefetch of (&ip_next_T0[72])
501  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
502  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
503  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
504  // skip unnecessary prefetch of (&ip_next_T0[88])
505  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
506  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
507  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
508  // skip unnecessary prefetch of (&ip_next_T0[104])
509  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
510  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
511  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
512  // skip unnecessary prefetch of (&ip_next_T0[120])
513  }
514  if (normalize_by_lengths == false) {
515  _mm256_storeu_ps(&op[0], vop0);
516  _mm256_storeu_ps(&op[8], vop8);
517  _mm256_storeu_ps(&op[16], vop16);
518  _mm256_storeu_ps(&op[24], vop24);
519  _mm256_storeu_ps(&op[32], vop32);
520  _mm256_storeu_ps(&op[40], vop40);
521  _mm256_storeu_ps(&op[48], vop48);
522  _mm256_storeu_ps(&op[56], vop56);
523  _mm256_storeu_ps(&op[64], vop64);
524  _mm256_storeu_ps(&op[72], vop72);
525  _mm256_storeu_ps(&op[80], vop80);
526  _mm256_storeu_ps(&op[88], vop88);
527  _mm256_storeu_ps(&op[96], vop96);
528  _mm256_storeu_ps(&op[104], vop104);
529  _mm256_storeu_ps(&op[112], vop112);
530  _mm256_storeu_ps(&op[120], vop120);
531  } else if (lengths[rangeIndex]) {
532  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
533  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
534  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
535  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
536  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
537  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
538  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
539  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
540  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
541  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
542  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
543  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
544  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
545  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
546  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
547  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
548  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
549  }
550  }
551  } else if (block_size == 64) {
552  // unrolling 8 times
553  int64_t dataInd = 0;
554  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
555  float* op = &out[rangeIndex * block_size];
556  __m256 vop0 = _mm256_setzero_ps();
557  __m256 vop8 = _mm256_setzero_ps();
558  __m256 vop16 = _mm256_setzero_ps();
559  __m256 vop24 = _mm256_setzero_ps();
560  __m256 vop32 = _mm256_setzero_ps();
561  __m256 vop40 = _mm256_setzero_ps();
562  __m256 vop48 = _mm256_setzero_ps();
563  __m256 vop56 = _mm256_setzero_ps();
564  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
565  ++dataInd) {
566  const int64_t idx = indices[dataInd];
567  CAFFE_ENFORCE(
568  idx >= 0 && idx < data_size,
569  "Index ",
570  dataInd,
571  " is out of bounds: ",
572  idx,
573  ", range 0 to ",
574  data_size);
575  float wgt = 1.f;
576  if (weights) {
577  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
578  }
579  __m256 vwgt = _mm256_set1_ps(wgt);
580  const float* ip = &input[idx * fused_block_size];
581  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
582  ? (dataInd + prefdist_T0)
583  : dataInd;
584  const int64_t idx_pref_T0 = indices[next_T0];
585  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
586  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
587  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
588  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
589  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
590  // skip unnecessary prefetch of (&ip_next_T0[8])
591  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
592  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
593  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
594  // skip unnecessary prefetch of (&ip_next_T0[24])
595  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
596  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
597  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
598  // skip unnecessary prefetch of (&ip_next_T0[40])
599  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
600  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
601  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
602  // skip unnecessary prefetch of (&ip_next_T0[56])
603  }
604  if (normalize_by_lengths == false) {
605  _mm256_storeu_ps(&op[0], vop0);
606  _mm256_storeu_ps(&op[8], vop8);
607  _mm256_storeu_ps(&op[16], vop16);
608  _mm256_storeu_ps(&op[24], vop24);
609  _mm256_storeu_ps(&op[32], vop32);
610  _mm256_storeu_ps(&op[40], vop40);
611  _mm256_storeu_ps(&op[48], vop48);
612  _mm256_storeu_ps(&op[56], vop56);
613  } else if (lengths[rangeIndex]) {
614  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
615  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
616  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
617  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
618  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
619  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
620  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
621  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
622  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
623  }
624  }
625  } else if (block_size == 32) {
626  // unrolling 4 times
627  int64_t dataInd = 0;
628  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
629  float* op = &out[rangeIndex * block_size];
630  __m256 vop0 = _mm256_setzero_ps();
631  __m256 vop8 = _mm256_setzero_ps();
632  __m256 vop16 = _mm256_setzero_ps();
633  __m256 vop24 = _mm256_setzero_ps();
634  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
635  ++dataInd) {
636  const int64_t idx = indices[dataInd];
637  CAFFE_ENFORCE(
638  idx >= 0 && idx < data_size,
639  "Index ",
640  dataInd,
641  " is out of bounds: ",
642  idx,
643  ", range 0 to ",
644  data_size);
645  float wgt = 1.f;
646  if (weights) {
647  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
648  }
649  __m256 vwgt = _mm256_set1_ps(wgt);
650  const float* ip = &input[idx * fused_block_size];
651  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
652  ? (dataInd + prefdist_T0)
653  : dataInd;
654  const int64_t idx_pref_T0 = indices[next_T0];
655  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
656  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
657  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
658  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
659  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
660  // skip unnecessary prefetch of (&ip_next_T0[8])
661  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
662  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
663  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
664  // skip unnecessary prefetch of (&ip_next_T0[24])
665  }
666  if (normalize_by_lengths == false) {
667  _mm256_storeu_ps(&op[0], vop0);
668  _mm256_storeu_ps(&op[8], vop8);
669  _mm256_storeu_ps(&op[16], vop16);
670  _mm256_storeu_ps(&op[24], vop24);
671  } else if (lengths[rangeIndex]) {
672  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
673  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
674  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
675  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
676  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
677  }
678  }
679  } else if (block_size == 16) {
680  // unrolling 2 times
681  int64_t dataInd = 0;
682  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
683  float* op = &out[rangeIndex * block_size];
684  __m256 vop0 = _mm256_setzero_ps();
685  __m256 vop8 = _mm256_setzero_ps();
686  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
687  ++dataInd) {
688  const int64_t idx = indices[dataInd];
689  CAFFE_ENFORCE(
690  idx >= 0 && idx < data_size,
691  "Index ",
692  dataInd,
693  " is out of bounds: ",
694  idx,
695  ", range 0 to ",
696  data_size);
697  float wgt = 1.f;
698  if (weights) {
699  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
700  }
701  __m256 vwgt = _mm256_set1_ps(wgt);
702  const float* ip = &input[idx * fused_block_size];
703  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
704  ? (dataInd + prefdist_T0)
705  : dataInd;
706  const int64_t idx_pref_T0 = indices[next_T0];
707  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
708  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
709  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
710  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
711  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
712  // skip unnecessary prefetch of (&ip_next_T0[8])
713  }
714  if (normalize_by_lengths == false) {
715  _mm256_storeu_ps(&op[0], vop0);
716  _mm256_storeu_ps(&op[8], vop8);
717  } else if (lengths[rangeIndex]) {
718  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
719  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
720  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
721  }
722  }
723  } else {
724  // generic code
725  int64_t dataInd = 0;
726  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
727  float* op = &out[rangeIndex * block_size];
728  TIndex j = 0;
729  for (; j + 8 <= block_size; j += 8) {
730  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
731  }
732  for (; j < block_size; j++) {
733  op[j] = 0.0f;
734  }
735  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
736  ++dataInd) {
737  const int64_t idx = indices[dataInd];
738  CAFFE_ENFORCE(
739  idx >= 0 && idx < data_size,
740  "Index ",
741  dataInd,
742  " is out of bounds: ",
743  idx,
744  ", range 0 to ",
745  data_size);
746  float wgt = 1.f;
747  if (weights) {
748  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
749  }
750  __m256 vwgt = _mm256_set1_ps(wgt);
751  const float* ip = &input[idx * fused_block_size];
752  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
753  ? (dataInd + prefdist_T0)
754  : dataInd;
755  const int64_t idx_pref_T0 = indices[next_T0];
756  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
757  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
758  j = 0;
759  for (; j + 8 <= block_size; j += 8) {
760  _mm256_storeu_ps(
761  &op[j],
762  _mm256_fmadd_ps(
763  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
764  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
765  }
766  for (; j < block_size; j++) {
767  op[j] += wgt * ip[j];
768  }
769  }
770  if (normalize_by_lengths && lengths[rangeIndex]) {
771  float len_inv = 1.0f / lengths[rangeIndex];
772  __m256 vlen_inv = _mm256_set1_ps(len_inv);
773  j = 0;
774  for (; j + 8 <= block_size; j += 8) {
775  _mm256_storeu_ps(
776  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
777  }
778  for (; j < block_size; j++) {
779  op[j] = len_inv * op[j];
780  }
781  }
782  }
783  }
784 }
785 void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma(
786  const TIndex block_size,
787  const TIndex output_size,
788  const TIndex index_size,
789  const TIndex data_size,
790  const float* input,
791  const int64_t* indices,
792  const int* lengths,
793  const float* weights,
794  bool normalize_by_lengths,
795  float* out) {
796  Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<false>(
797  block_size,
798  output_size,
799  index_size,
800  data_size,
801  input,
802  indices,
803  lengths,
804  weights,
805  normalize_by_lengths,
806  out);
807 }
808 void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma(
809  const TIndex block_size,
810  const TIndex output_size,
811  const TIndex index_size,
812  const TIndex data_size,
813  const float* input,
814  const int64_t* indices,
815  const int* lengths,
816  const float* weights,
817  bool normalize_by_lengths,
818  float* out) {
819  Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<true>(
820  block_size,
821  output_size,
822  index_size,
823  data_size,
824  input,
825  indices,
826  lengths,
827  weights,
828  normalize_by_lengths,
829  out);
830 }
831 
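The float16 kernels that follow mirror the float kernels above, with two differences visible in the code: each stored element is a 16-bit half that is widened with _mm256_cvtph_ps before the FMA, and the fused row stride is block_size + 4 rather than block_size + 2 (the extra 8 bytes reserved per row for the fused scale and bias now span four float16 slots). A minimal scalar sketch of one accumulation step, using a hypothetical half_to_float helper in place of the F16C intrinsics:

// Scalar view of one weighted accumulation over a half-precision row.
// half_to_float is an assumed helper standing in for _mm256_cvtph_ps.
float half_to_float(float16 h);

const float16* ip = &input[idx * (block_size + 4)];
for (TIndex j = 0; j < block_size; ++j) {
  op[j] += wgt * half_to_float(ip[j]);
}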
832 template <bool IS_WEIGHT_POSITIONAL>
833 static void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma(
834  const TIndex block_size,
835  const TIndex output_size,
836  const TIndex index_size,
837  const TIndex data_size,
838  const float16* input,
839  const int32_t* indices,
840  const int* lengths,
841  const float* weights,
842  bool normalize_by_lengths,
843  float* out) {
844  const int32_t prefdist_T0 = 16;
845  const int32_t fused_block_size = block_size + 4;
846  if (block_size == 128) {
847  // unrolling 16 times
848  int32_t dataInd = 0;
849  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
850  float* op = &out[rangeIndex * block_size];
851  __m256 vop0 = _mm256_setzero_ps();
852  __m256 vop8 = _mm256_setzero_ps();
853  __m256 vop16 = _mm256_setzero_ps();
854  __m256 vop24 = _mm256_setzero_ps();
855  __m256 vop32 = _mm256_setzero_ps();
856  __m256 vop40 = _mm256_setzero_ps();
857  __m256 vop48 = _mm256_setzero_ps();
858  __m256 vop56 = _mm256_setzero_ps();
859  __m256 vop64 = _mm256_setzero_ps();
860  __m256 vop72 = _mm256_setzero_ps();
861  __m256 vop80 = _mm256_setzero_ps();
862  __m256 vop88 = _mm256_setzero_ps();
863  __m256 vop96 = _mm256_setzero_ps();
864  __m256 vop104 = _mm256_setzero_ps();
865  __m256 vop112 = _mm256_setzero_ps();
866  __m256 vop120 = _mm256_setzero_ps();
867  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
868  ++dataInd) {
869  const int32_t idx = indices[dataInd];
870  CAFFE_ENFORCE(
871  idx >= 0 && idx < data_size,
872  "Index ",
873  dataInd,
874  " is out of bounds: ",
875  idx,
876  ", range 0 to ",
877  data_size);
878  float wgt = 1.f;
879  if (weights) {
880  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
881  }
882  __m256 vwgt = _mm256_set1_ps(wgt);
883  const float16* ip = &input[idx * fused_block_size];
884  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
885  ? (dataInd + prefdist_T0)
886  : dataInd;
887  const int32_t idx_pref_T0 = indices[next_T0];
888  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
889  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
890  vop0 = _mm256_fmadd_ps(
891  vwgt,
892  _mm256_cvtph_ps(
893  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
894  vop0);
895  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
896  vop8 = _mm256_fmadd_ps(
897  vwgt,
898  _mm256_cvtph_ps(
899  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
900  vop8);
901  // skip unnecessary prefetch of (&ip_next_T0[8])
902  vop16 = _mm256_fmadd_ps(
903  vwgt,
904  _mm256_cvtph_ps(
905  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
906  vop16);
907  // skip unnecessary prefetch of (&ip_next_T0[16])
908  vop24 = _mm256_fmadd_ps(
909  vwgt,
910  _mm256_cvtph_ps(
911  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
912  vop24);
913  // skip unnecessary prefetch of (&ip_next_T0[24])
914  vop32 = _mm256_fmadd_ps(
915  vwgt,
916  _mm256_cvtph_ps(
917  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
918  vop32);
919  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
920  vop40 = _mm256_fmadd_ps(
921  vwgt,
922  _mm256_cvtph_ps(
923  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
924  vop40);
925  // skip unnecessary prefetch of (&ip_next_T0[40])
926  vop48 = _mm256_fmadd_ps(
927  vwgt,
928  _mm256_cvtph_ps(
929  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
930  vop48);
931  // skip unnecessary prefetch of (&ip_next_T0[48])
932  vop56 = _mm256_fmadd_ps(
933  vwgt,
934  _mm256_cvtph_ps(
935  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
936  vop56);
937  // skip unnecessary prefetch of (&ip_next_T0[56])
938  vop64 = _mm256_fmadd_ps(
939  vwgt,
940  _mm256_cvtph_ps(
941  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
942  vop64);
943  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
944  vop72 = _mm256_fmadd_ps(
945  vwgt,
946  _mm256_cvtph_ps(
947  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
948  vop72);
949  // skip unnecessary prefetch of (&ip_next_T0[72])
950  vop80 = _mm256_fmadd_ps(
951  vwgt,
952  _mm256_cvtph_ps(
953  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
954  vop80);
955  // skip unnecessary prefetch of (&ip_next_T0[80])
956  vop88 = _mm256_fmadd_ps(
957  vwgt,
958  _mm256_cvtph_ps(
959  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
960  vop88);
961  // skip unnecessary prefetch of (&ip_next_T0[88])
962  vop96 = _mm256_fmadd_ps(
963  vwgt,
964  _mm256_cvtph_ps(
965  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
966  vop96);
967  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
968  vop104 = _mm256_fmadd_ps(
969  vwgt,
970  _mm256_cvtph_ps(
971  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
972  vop104);
973  // skip unnecessary prefetch of (&ip_next_T0[104])
974  vop112 = _mm256_fmadd_ps(
975  vwgt,
976  _mm256_cvtph_ps(
977  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
978  vop112);
979  // skip unnecessary prefetch of (&ip_next_T0[112])
980  vop120 = _mm256_fmadd_ps(
981  vwgt,
982  _mm256_cvtph_ps(
983  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
984  vop120);
985  // skip unnecessary prefetch of (&ip_next_T0[120])
986  }
987  if (normalize_by_lengths == false) {
988  _mm256_storeu_ps(&op[0], vop0);
989  _mm256_storeu_ps(&op[8], vop8);
990  _mm256_storeu_ps(&op[16], vop16);
991  _mm256_storeu_ps(&op[24], vop24);
992  _mm256_storeu_ps(&op[32], vop32);
993  _mm256_storeu_ps(&op[40], vop40);
994  _mm256_storeu_ps(&op[48], vop48);
995  _mm256_storeu_ps(&op[56], vop56);
996  _mm256_storeu_ps(&op[64], vop64);
997  _mm256_storeu_ps(&op[72], vop72);
998  _mm256_storeu_ps(&op[80], vop80);
999  _mm256_storeu_ps(&op[88], vop88);
1000  _mm256_storeu_ps(&op[96], vop96);
1001  _mm256_storeu_ps(&op[104], vop104);
1002  _mm256_storeu_ps(&op[112], vop112);
1003  _mm256_storeu_ps(&op[120], vop120);
1004  } else if (lengths[rangeIndex]) {
1005  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1006  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1007  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1008  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1009  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1010  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1011  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1012  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1013  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1014  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1015  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1016  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1017  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1018  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1019  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1020  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1021  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1022  }
1023  }
1024  } else if (block_size == 64) {
1025  // unrolling 8 times
1026  int32_t dataInd = 0;
1027  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1028  float* op = &out[rangeIndex * block_size];
1029  __m256 vop0 = _mm256_setzero_ps();
1030  __m256 vop8 = _mm256_setzero_ps();
1031  __m256 vop16 = _mm256_setzero_ps();
1032  __m256 vop24 = _mm256_setzero_ps();
1033  __m256 vop32 = _mm256_setzero_ps();
1034  __m256 vop40 = _mm256_setzero_ps();
1035  __m256 vop48 = _mm256_setzero_ps();
1036  __m256 vop56 = _mm256_setzero_ps();
1037  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1038  ++dataInd) {
1039  const int32_t idx = indices[dataInd];
1040  CAFFE_ENFORCE(
1041  idx >= 0 && idx < data_size,
1042  "Index ",
1043  dataInd,
1044  " is out of bounds: ",
1045  idx,
1046  ", range 0 to ",
1047  data_size);
1048  float wgt = 1.f;
1049  if (weights) {
1050  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1051  }
1052  __m256 vwgt = _mm256_set1_ps(wgt);
1053  const float16* ip = &input[idx * fused_block_size];
1054  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1055  ? (dataInd + prefdist_T0)
1056  : dataInd;
1057  const int32_t idx_pref_T0 = indices[next_T0];
1058  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1059  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1060  vop0 = _mm256_fmadd_ps(
1061  vwgt,
1062  _mm256_cvtph_ps(
1063  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1064  vop0);
1065  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1066  vop8 = _mm256_fmadd_ps(
1067  vwgt,
1068  _mm256_cvtph_ps(
1069  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1070  vop8);
1071  // skip unnecessary prefetch of (&ip_next_T0[8])
1072  vop16 = _mm256_fmadd_ps(
1073  vwgt,
1074  _mm256_cvtph_ps(
1075  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1076  vop16);
1077  // skip unnecessary prefetch of (&ip_next_T0[16])
1078  vop24 = _mm256_fmadd_ps(
1079  vwgt,
1080  _mm256_cvtph_ps(
1081  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1082  vop24);
1083  // skip unnecessary prefetch of (&ip_next_T0[24])
1084  vop32 = _mm256_fmadd_ps(
1085  vwgt,
1086  _mm256_cvtph_ps(
1087  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1088  vop32);
1089  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1090  vop40 = _mm256_fmadd_ps(
1091  vwgt,
1092  _mm256_cvtph_ps(
1093  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1094  vop40);
1095  // skip unnecessary prefetch of (&ip_next_T0[40])
1096  vop48 = _mm256_fmadd_ps(
1097  vwgt,
1098  _mm256_cvtph_ps(
1099  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1100  vop48);
1101  // skip unnecessary prefetch of (&ip_next_T0[48])
1102  vop56 = _mm256_fmadd_ps(
1103  vwgt,
1104  _mm256_cvtph_ps(
1105  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1106  vop56);
1107  // skip unnecessary prefetch of (&ip_next_T0[56])
1108  }
1109  if (normalize_by_lengths == false) {
1110  _mm256_storeu_ps(&op[0], vop0);
1111  _mm256_storeu_ps(&op[8], vop8);
1112  _mm256_storeu_ps(&op[16], vop16);
1113  _mm256_storeu_ps(&op[24], vop24);
1114  _mm256_storeu_ps(&op[32], vop32);
1115  _mm256_storeu_ps(&op[40], vop40);
1116  _mm256_storeu_ps(&op[48], vop48);
1117  _mm256_storeu_ps(&op[56], vop56);
1118  } else if (lengths[rangeIndex]) {
1119  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1120  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1121  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1122  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1123  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1124  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1125  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1126  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1127  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1128  }
1129  }
1130  } else if (block_size == 32) {
1131  // unrolling 4 times
1132  int32_t dataInd = 0;
1133  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1134  float* op = &out[rangeIndex * block_size];
1135  __m256 vop0 = _mm256_setzero_ps();
1136  __m256 vop8 = _mm256_setzero_ps();
1137  __m256 vop16 = _mm256_setzero_ps();
1138  __m256 vop24 = _mm256_setzero_ps();
1139  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1140  ++dataInd) {
1141  const int32_t idx = indices[dataInd];
1142  CAFFE_ENFORCE(
1143  idx >= 0 && idx < data_size,
1144  "Index ",
1145  dataInd,
1146  " is out of bounds: ",
1147  idx,
1148  ", range 0 to ",
1149  data_size);
1150  float wgt = 1.f;
1151  if (weights) {
1152  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1153  }
1154  __m256 vwgt = _mm256_set1_ps(wgt);
1155  const float16* ip = &input[idx * fused_block_size];
1156  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1157  ? (dataInd + prefdist_T0)
1158  : dataInd;
1159  const int32_t idx_pref_T0 = indices[next_T0];
1160  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1161  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1162  vop0 = _mm256_fmadd_ps(
1163  vwgt,
1164  _mm256_cvtph_ps(
1165  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1166  vop0);
1167  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1168  vop8 = _mm256_fmadd_ps(
1169  vwgt,
1170  _mm256_cvtph_ps(
1171  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1172  vop8);
1173  // skip unnecessary prefetch of (&ip_next_T0[8])
1174  vop16 = _mm256_fmadd_ps(
1175  vwgt,
1176  _mm256_cvtph_ps(
1177  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1178  vop16);
1179  // skip unnecessary prefetch of (&ip_next_T0[16])
1180  vop24 = _mm256_fmadd_ps(
1181  vwgt,
1182  _mm256_cvtph_ps(
1183  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1184  vop24);
1185  // skip unnecessary prefetch of (&ip_next_T0[24])
1186  }
1187  if (normalize_by_lengths == false) {
1188  _mm256_storeu_ps(&op[0], vop0);
1189  _mm256_storeu_ps(&op[8], vop8);
1190  _mm256_storeu_ps(&op[16], vop16);
1191  _mm256_storeu_ps(&op[24], vop24);
1192  } else if (lengths[rangeIndex]) {
1193  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1194  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1195  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1196  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1197  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1198  }
1199  }
1200  } else if (block_size == 16) {
1201  // unrolling 2 times
1202  int32_t dataInd = 0;
1203  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1204  float* op = &out[rangeIndex * block_size];
1205  __m256 vop0 = _mm256_setzero_ps();
1206  __m256 vop8 = _mm256_setzero_ps();
1207  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1208  ++dataInd) {
1209  const int32_t idx = indices[dataInd];
1210  CAFFE_ENFORCE(
1211  idx >= 0 && idx < data_size,
1212  "Index ",
1213  dataInd,
1214  " is out of bounds: ",
1215  idx,
1216  ", range 0 to ",
1217  data_size);
1218  float wgt = 1.f;
1219  if (weights) {
1220  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1221  }
1222  __m256 vwgt = _mm256_set1_ps(wgt);
1223  const float16* ip = &input[idx * fused_block_size];
1224  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1225  ? (dataInd + prefdist_T0)
1226  : dataInd;
1227  const int32_t idx_pref_T0 = indices[next_T0];
1228  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1229  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1230  vop0 = _mm256_fmadd_ps(
1231  vwgt,
1232  _mm256_cvtph_ps(
1233  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1234  vop0);
1235  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1236  vop8 = _mm256_fmadd_ps(
1237  vwgt,
1238  _mm256_cvtph_ps(
1239  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1240  vop8);
1241  // skip unnecessary prefetch of (&ip_next_T0[8])
1242  }
1243  if (normalize_by_lengths == false) {
1244  _mm256_storeu_ps(&op[0], vop0);
1245  _mm256_storeu_ps(&op[8], vop8);
1246  } else if (lengths[rangeIndex]) {
1247  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1248  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1249  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1250  }
1251  }
1252  } else {
1253  // generic code
1254  int32_t dataInd = 0;
1255  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1256  float* op = &out[rangeIndex * block_size];
1257  TIndex j = 0;
1258  for (; j + 8 <= block_size; j += 8) {
1259  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1260  }
1261  for (; j < block_size; j++) {
1262  op[j] = 0.0f;
1263  }
1264  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1265  ++dataInd) {
1266  const int32_t idx = indices[dataInd];
1267  CAFFE_ENFORCE(
1268  idx >= 0 && idx < data_size,
1269  "Index ",
1270  dataInd,
1271  " is out of bounds: ",
1272  idx,
1273  ", range 0 to ",
1274  data_size);
1275  float wgt = 1.f;
1276  if (weights) {
1277  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1278  }
1279  __m256 vwgt = _mm256_set1_ps(wgt);
1280  const float16* ip = &input[idx * fused_block_size];
1281  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1282  ? (dataInd + prefdist_T0)
1283  : dataInd;
1284  const int32_t idx_pref_T0 = indices[next_T0];
1285  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1286  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1287  j = 0;
1288  for (; j + 8 <= block_size; j += 8) {
1289  _mm256_storeu_ps(
1290  &op[j],
1291  _mm256_fmadd_ps(
1292  vwgt,
1293  _mm256_cvtph_ps(_mm_loadu_si128(
1294  reinterpret_cast<const __m128i*>(&ip[j]))),
1295  _mm256_loadu_ps(&op[j])));
1296  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1297  }
1298  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1299  for (; j < block_size; j++) {
1300  vtmp1[0] = ip[j];
1301  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1302  op[j] += wgt * ((float*)(&vtmp2))[0];
1303  }
1304  }
1305  if (normalize_by_lengths && lengths[rangeIndex]) {
1306  float len_inv = 1.0f / lengths[rangeIndex];
1307  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1308  j = 0;
1309  for (; j + 8 <= block_size; j += 8) {
1310  _mm256_storeu_ps(
1311  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1312  }
1313  for (; j < block_size; j++) {
1314  op[j] = len_inv * op[j];
1315  }
1316  }
1317  }
1318  }
1319 }
1320 void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float_false__avx2_fma(
1321  const TIndex block_size,
1322  const TIndex output_size,
1323  const TIndex index_size,
1324  const TIndex data_size,
1325  const float16* input,
1326  const int32_t* indices,
1327  const int* lengths,
1328  const float* weights,
1329  bool normalize_by_lengths,
1330  float* out) {
1331  Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma<false>(
1332  block_size,
1333  output_size,
1334  index_size,
1335  data_size,
1336  input,
1337  indices,
1338  lengths,
1339  weights,
1340  normalize_by_lengths,
1341  out);
1342 }
1343 void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float_true__avx2_fma(
1344  const TIndex block_size,
1345  const TIndex output_size,
1346  const TIndex index_size,
1347  const TIndex data_size,
1348  const float16* input,
1349  const int32_t* indices,
1350  const int* lengths,
1351  const float* weights,
1352  bool normalize_by_lengths,
1353  float* out) {
1354  Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma<true>(
1355  block_size,
1356  output_size,
1357  index_size,
1358  data_size,
1359  input,
1360  indices,
1361  lengths,
1362  weights,
1363  normalize_by_lengths,
1364  out);
1365 }
1366 
1367 template <bool IS_WEIGHT_POSITIONAL>
1368 static void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma(
1369  const TIndex block_size,
1370  const TIndex output_size,
1371  const TIndex index_size,
1372  const TIndex data_size,
1373  const float16* input,
1374  const int64_t* indices,
1375  const int* lengths,
1376  const float* weights,
1377  bool normalize_by_lengths,
1378  float* out) {
1379  const int64_t prefdist_T0 = 16;
1380  const int64_t fused_block_size = block_size + 4;
1381  if (block_size == 128) {
1382  // unrolling 16 times
1383  int64_t dataInd = 0;
1384  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1385  float* op = &out[rangeIndex * block_size];
1386  __m256 vop0 = _mm256_setzero_ps();
1387  __m256 vop8 = _mm256_setzero_ps();
1388  __m256 vop16 = _mm256_setzero_ps();
1389  __m256 vop24 = _mm256_setzero_ps();
1390  __m256 vop32 = _mm256_setzero_ps();
1391  __m256 vop40 = _mm256_setzero_ps();
1392  __m256 vop48 = _mm256_setzero_ps();
1393  __m256 vop56 = _mm256_setzero_ps();
1394  __m256 vop64 = _mm256_setzero_ps();
1395  __m256 vop72 = _mm256_setzero_ps();
1396  __m256 vop80 = _mm256_setzero_ps();
1397  __m256 vop88 = _mm256_setzero_ps();
1398  __m256 vop96 = _mm256_setzero_ps();
1399  __m256 vop104 = _mm256_setzero_ps();
1400  __m256 vop112 = _mm256_setzero_ps();
1401  __m256 vop120 = _mm256_setzero_ps();
1402  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1403  ++dataInd) {
1404  const int64_t idx = indices[dataInd];
1405  CAFFE_ENFORCE(
1406  idx >= 0 && idx < data_size,
1407  "Index ",
1408  dataInd,
1409  " is out of bounds: ",
1410  idx,
1411  ", range 0 to ",
1412  data_size);
1413  float wgt = 1.f;
1414  if (weights) {
1415  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1416  }
1417  __m256 vwgt = _mm256_set1_ps(wgt);
1418  const float16* ip = &input[idx * fused_block_size];
1419  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1420  ? (dataInd + prefdist_T0)
1421  : dataInd;
1422  const int64_t idx_pref_T0 = indices[next_T0];
1423  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1424  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1425  vop0 = _mm256_fmadd_ps(
1426  vwgt,
1427  _mm256_cvtph_ps(
1428  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1429  vop0);
1430  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1431  vop8 = _mm256_fmadd_ps(
1432  vwgt,
1433  _mm256_cvtph_ps(
1434  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1435  vop8);
1436  // skip unnecessary prefetch of (&ip_next_T0[8])
1437  vop16 = _mm256_fmadd_ps(
1438  vwgt,
1439  _mm256_cvtph_ps(
1440  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1441  vop16);
1442  // skip unnecessary prefetch of (&ip_next_T0[16])
1443  vop24 = _mm256_fmadd_ps(
1444  vwgt,
1445  _mm256_cvtph_ps(
1446  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1447  vop24);
1448  // skip unnecessary prefetch of (&ip_next_T0[24])
1449  vop32 = _mm256_fmadd_ps(
1450  vwgt,
1451  _mm256_cvtph_ps(
1452  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1453  vop32);
1454  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1455  vop40 = _mm256_fmadd_ps(
1456  vwgt,
1457  _mm256_cvtph_ps(
1458  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1459  vop40);
1460  // skip unnecessary prefetch of (&ip_next_T0[40])
1461  vop48 = _mm256_fmadd_ps(
1462  vwgt,
1463  _mm256_cvtph_ps(
1464  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1465  vop48);
1466  // skip unnecessary prefetch of (&ip_next_T0[48])
1467  vop56 = _mm256_fmadd_ps(
1468  vwgt,
1469  _mm256_cvtph_ps(
1470  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1471  vop56);
1472  // skip unnecessary prefetch of (&ip_next_T0[56])
1473  vop64 = _mm256_fmadd_ps(
1474  vwgt,
1475  _mm256_cvtph_ps(
1476  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1477  vop64);
1478  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1479  vop72 = _mm256_fmadd_ps(
1480  vwgt,
1481  _mm256_cvtph_ps(
1482  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1483  vop72);
1484  // skip unnecessary prefetch of (&ip_next_T0[72])
1485  vop80 = _mm256_fmadd_ps(
1486  vwgt,
1487  _mm256_cvtph_ps(
1488  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1489  vop80);
1490  // skip unnecessary prefetch of (&ip_next_T0[80])
1491  vop88 = _mm256_fmadd_ps(
1492  vwgt,
1493  _mm256_cvtph_ps(
1494  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1495  vop88);
1496  // skip unnecessary prefetch of (&ip_next_T0[88])
1497  vop96 = _mm256_fmadd_ps(
1498  vwgt,
1499  _mm256_cvtph_ps(
1500  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1501  vop96);
1502  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
1503  vop104 = _mm256_fmadd_ps(
1504  vwgt,
1505  _mm256_cvtph_ps(
1506  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1507  vop104);
1508  // skip unnecessary prefetch of (&ip_next_T0[104])
1509  vop112 = _mm256_fmadd_ps(
1510  vwgt,
1511  _mm256_cvtph_ps(
1512  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1513  vop112);
1514  // skip unnecessary prefetch of (&ip_next_T0[112])
1515  vop120 = _mm256_fmadd_ps(
1516  vwgt,
1517  _mm256_cvtph_ps(
1518  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1519  vop120);
1520  // skip unnecessary prefetch of (&ip_next_T0[120])
1521  }
1522  if (normalize_by_lengths == false) {
1523  _mm256_storeu_ps(&op[0], vop0);
1524  _mm256_storeu_ps(&op[8], vop8);
1525  _mm256_storeu_ps(&op[16], vop16);
1526  _mm256_storeu_ps(&op[24], vop24);
1527  _mm256_storeu_ps(&op[32], vop32);
1528  _mm256_storeu_ps(&op[40], vop40);
1529  _mm256_storeu_ps(&op[48], vop48);
1530  _mm256_storeu_ps(&op[56], vop56);
1531  _mm256_storeu_ps(&op[64], vop64);
1532  _mm256_storeu_ps(&op[72], vop72);
1533  _mm256_storeu_ps(&op[80], vop80);
1534  _mm256_storeu_ps(&op[88], vop88);
1535  _mm256_storeu_ps(&op[96], vop96);
1536  _mm256_storeu_ps(&op[104], vop104);
1537  _mm256_storeu_ps(&op[112], vop112);
1538  _mm256_storeu_ps(&op[120], vop120);
1539  } else if (lengths[rangeIndex]) {
1540  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1541  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1542  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1543  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1544  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1545  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1546  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1547  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1548  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1549  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1550  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1551  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1552  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1553  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1554  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1555  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1556  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1557  }
1558  }
1559  } else if (block_size == 64) {
1560  // unrolling 8 times
1561  int64_t dataInd = 0;
1562  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1563  float* op = &out[rangeIndex * block_size];
1564  __m256 vop0 = _mm256_setzero_ps();
1565  __m256 vop8 = _mm256_setzero_ps();
1566  __m256 vop16 = _mm256_setzero_ps();
1567  __m256 vop24 = _mm256_setzero_ps();
1568  __m256 vop32 = _mm256_setzero_ps();
1569  __m256 vop40 = _mm256_setzero_ps();
1570  __m256 vop48 = _mm256_setzero_ps();
1571  __m256 vop56 = _mm256_setzero_ps();
1572  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1573  ++dataInd) {
1574  const int64_t idx = indices[dataInd];
1575  CAFFE_ENFORCE(
1576  idx >= 0 && idx < data_size,
1577  "Index ",
1578  dataInd,
1579  " is out of bounds: ",
1580  idx,
1581  ", range 0 to ",
1582  data_size);
1583  float wgt = 1.f;
1584  if (weights) {
1585  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1586  }
1587  __m256 vwgt = _mm256_set1_ps(wgt);
1588  const float16* ip = &input[idx * fused_block_size];
1589  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1590  ? (dataInd + prefdist_T0)
1591  : dataInd;
1592  const int64_t idx_pref_T0 = indices[next_T0];
1593  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1594  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1595  vop0 = _mm256_fmadd_ps(
1596  vwgt,
1597  _mm256_cvtph_ps(
1598  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1599  vop0);
1600  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1601  vop8 = _mm256_fmadd_ps(
1602  vwgt,
1603  _mm256_cvtph_ps(
1604  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1605  vop8);
1606  // skip unnecessary prefetch of (&ip_next_T0[8])
1607  vop16 = _mm256_fmadd_ps(
1608  vwgt,
1609  _mm256_cvtph_ps(
1610  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1611  vop16);
1612  // skip unnecessary prefetch of (&ip_next_T0[16])
1613  vop24 = _mm256_fmadd_ps(
1614  vwgt,
1615  _mm256_cvtph_ps(
1616  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1617  vop24);
1618  // skip unnecessary prefetch of (&ip_next_T0[24])
1619  vop32 = _mm256_fmadd_ps(
1620  vwgt,
1621  _mm256_cvtph_ps(
1622  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1623  vop32);
1624  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1625  vop40 = _mm256_fmadd_ps(
1626  vwgt,
1627  _mm256_cvtph_ps(
1628  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1629  vop40);
1630  // skip unnecessary prefetch of (&ip_next_T0[40])
1631  vop48 = _mm256_fmadd_ps(
1632  vwgt,
1633  _mm256_cvtph_ps(
1634  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1635  vop48);
1636  // skip unnecessary prefetch of (&ip_next_T0[48])
1637  vop56 = _mm256_fmadd_ps(
1638  vwgt,
1639  _mm256_cvtph_ps(
1640  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1641  vop56);
1642  // skip unnecessary prefetch of (&ip_next_T0[56])
1643  }
1644  if (normalize_by_lengths == false) {
1645  _mm256_storeu_ps(&op[0], vop0);
1646  _mm256_storeu_ps(&op[8], vop8);
1647  _mm256_storeu_ps(&op[16], vop16);
1648  _mm256_storeu_ps(&op[24], vop24);
1649  _mm256_storeu_ps(&op[32], vop32);
1650  _mm256_storeu_ps(&op[40], vop40);
1651  _mm256_storeu_ps(&op[48], vop48);
1652  _mm256_storeu_ps(&op[56], vop56);
1653  } else if (lengths[rangeIndex]) {
1654  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1655  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1656  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1657  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1658  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1659  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1660  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1661  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1662  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1663  }
1664  }
1665  } else if (block_size == 32) {
1666  // unrolling 4 times
1667  int64_t dataInd = 0;
1668  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1669  float* op = &out[rangeIndex * block_size];
1670  __m256 vop0 = _mm256_setzero_ps();
1671  __m256 vop8 = _mm256_setzero_ps();
1672  __m256 vop16 = _mm256_setzero_ps();
1673  __m256 vop24 = _mm256_setzero_ps();
1674  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1675  ++dataInd) {
1676  const int64_t idx = indices[dataInd];
1677  CAFFE_ENFORCE(
1678  idx >= 0 && idx < data_size,
1679  "Index ",
1680  dataInd,
1681  " is out of bounds: ",
1682  idx,
1683  ", range 0 to ",
1684  data_size);
1685  float wgt = 1.f;
1686  if (weights) {
1687  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1688  }
1689  __m256 vwgt = _mm256_set1_ps(wgt);
1690  const float16* ip = &input[idx * fused_block_size];
1691  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1692  ? (dataInd + prefdist_T0)
1693  : dataInd;
1694  const int64_t idx_pref_T0 = indices[next_T0];
1695  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1696  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1697  vop0 = _mm256_fmadd_ps(
1698  vwgt,
1699  _mm256_cvtph_ps(
1700  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1701  vop0);
1702  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1703  vop8 = _mm256_fmadd_ps(
1704  vwgt,
1705  _mm256_cvtph_ps(
1706  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1707  vop8);
1708  // skip unnecessary prefetch of (&ip_next_T0[8])
1709  vop16 = _mm256_fmadd_ps(
1710  vwgt,
1711  _mm256_cvtph_ps(
1712  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1713  vop16);
1714  // skip unnecessary prefetch of (&ip_next_T0[16])
1715  vop24 = _mm256_fmadd_ps(
1716  vwgt,
1717  _mm256_cvtph_ps(
1718  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1719  vop24);
1720  // skip unnecessary prefetch of (&ip_next_T0[24])
1721  }
1722  if (normalize_by_lengths == false) {
1723  _mm256_storeu_ps(&op[0], vop0);
1724  _mm256_storeu_ps(&op[8], vop8);
1725  _mm256_storeu_ps(&op[16], vop16);
1726  _mm256_storeu_ps(&op[24], vop24);
1727  } else if (lengths[rangeIndex]) {
1728  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1729  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1730  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1731  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1732  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1733  }
1734  }
1735  } else if (block_size == 16) {
1736  // unrolling 2 times
1737  int64_t dataInd = 0;
1738  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1739  float* op = &out[rangeIndex * block_size];
1740  __m256 vop0 = _mm256_setzero_ps();
1741  __m256 vop8 = _mm256_setzero_ps();
1742  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1743  ++dataInd) {
1744  const int64_t idx = indices[dataInd];
1745  CAFFE_ENFORCE(
1746  idx >= 0 && idx < data_size,
1747  "Index ",
1748  dataInd,
1749  " is out of bounds: ",
1750  idx,
1751  ", range 0 to ",
1752  data_size);
1753  float wgt = 1.f;
1754  if (weights) {
1755  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1756  }
1757  __m256 vwgt = _mm256_set1_ps(wgt);
1758  const float16* ip = &input[idx * fused_block_size];
1759  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1760  ? (dataInd + prefdist_T0)
1761  : dataInd;
1762  const int64_t idx_pref_T0 = indices[next_T0];
1763  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1764  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1765  vop0 = _mm256_fmadd_ps(
1766  vwgt,
1767  _mm256_cvtph_ps(
1768  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1769  vop0);
1770  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1771  vop8 = _mm256_fmadd_ps(
1772  vwgt,
1773  _mm256_cvtph_ps(
1774  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1775  vop8);
1776  // skip unnecessary prefetch of (&ip_next_T0[8])
1777  }
1778  if (normalize_by_lengths == false) {
1779  _mm256_storeu_ps(&op[0], vop0);
1780  _mm256_storeu_ps(&op[8], vop8);
1781  } else if (lengths[rangeIndex]) {
1782  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1783  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1784  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1785  }
1786  }
1787  } else {
1788  // generic code
1789  int64_t dataInd = 0;
1790  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1791  float* op = &out[rangeIndex * block_size];
1792  TIndex j = 0;
1793  for (; j + 8 <= block_size; j += 8) {
1794  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1795  }
1796  for (; j < block_size; j++) {
1797  op[j] = 0.0f;
1798  }
1799  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1800  ++dataInd) {
1801  const int64_t idx = indices[dataInd];
1802  CAFFE_ENFORCE(
1803  idx >= 0 && idx < data_size,
1804  "Index ",
1805  dataInd,
1806  " is out of bounds: ",
1807  idx,
1808  ", range 0 to ",
1809  data_size);
1810  float wgt = 1.f;
1811  if (weights) {
1812  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1813  }
1814  __m256 vwgt = _mm256_set1_ps(wgt);
1815  const float16* ip = &input[idx * fused_block_size];
1816  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1817  ? (dataInd + prefdist_T0)
1818  : dataInd;
1819  const int64_t idx_pref_T0 = indices[next_T0];
1820  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1821  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1822  j = 0;
1823  for (; j + 8 <= block_size; j += 8) {
1824  _mm256_storeu_ps(
1825  &op[j],
1826  _mm256_fmadd_ps(
1827  vwgt,
1828  _mm256_cvtph_ps(_mm_loadu_si128(
1829  reinterpret_cast<const __m128i*>(&ip[j]))),
1830  _mm256_loadu_ps(&op[j])));
1831  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1832  }
1833  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1834  for (; j < block_size; j++) {
1835  vtmp1[0] = ip[j];
1836  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1837  op[j] += wgt * ((float*)(&vtmp2))[0];
1838  }
1839  }
1840  if (normalize_by_lengths && lengths[rangeIndex]) {
1841  float len_inv = 1.0f / lengths[rangeIndex];
1842  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1843  j = 0;
1844  for (; j + 8 <= block_size; j += 8) {
1845  _mm256_storeu_ps(
1846  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1847  }
1848  for (; j < block_size; j++) {
1849  op[j] = len_inv * op[j];
1850  }
1851  }
1852  }
1853  }
1854 }
1855 void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float_false__avx2_fma(
1856  const TIndex block_size,
1857  const TIndex output_size,
1858  const TIndex index_size,
1859  const TIndex data_size,
1860  const float16* input,
1861  const int64_t* indices,
1862  const int* lengths,
1863  const float* weights,
1864  bool normalize_by_lengths,
1865  float* out) {
1866  Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma<false>(
1867  block_size,
1868  output_size,
1869  index_size,
1870  data_size,
1871  input,
1872  indices,
1873  lengths,
1874  weights,
1875  normalize_by_lengths,
1876  out);
1877 }
1878 void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float_true__avx2_fma(
1879  const TIndex block_size,
1880  const TIndex output_size,
1881  const TIndex index_size,
1882  const TIndex data_size,
1883  const float16* input,
1884  const int64_t* indices,
1885  const int* lengths,
1886  const float* weights,
1887  bool normalize_by_lengths,
1888  float* out) {
1889  Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma<true>(
1890  block_size,
1891  output_size,
1892  index_size,
1893  data_size,
1894  input,
1895  indices,
1896  lengths,
1897  weights,
1898  normalize_by_lengths,
1899  out);
1900 }
1901 
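The unrolled AVX2 loops above implement a weighted sum over half-precision rows: each gathered row is converted to float with _mm256_cvtph_ps, scaled by the per-index weight via FMA, accumulated in registers, and optionally divided by the segment length. For readability, here is a minimal scalar sketch of the same computation for the non-positional-weight case. It is not part of this file; it assumes the float16 rows can be read as raw uint16_t bit patterns, that rows are strided by the fused_block_size computed earlier in the file, and that the F16C scalar intrinsic _cvtsh_ss is available (the file already relies on F16C for _mm256_cvtph_ps).

// Minimal scalar sketch (illustration only, not from the source file).
// Assumes float16 rows are stored as IEEE half bit patterns (uint16_t) and
// that _cvtsh_ss (F16C) is available; weights are indexed non-positionally.
static void Float16EmbeddingLookupScalarSketch(
    int64_t block_size,
    int64_t output_size,
    int64_t fused_block_size, // row stride, defined earlier in the file
    const uint16_t* input, // half-precision rows as raw bit patterns
    const int64_t* indices,
    const int* lengths,
    const float* weights, // may be nullptr
    bool normalize_by_lengths,
    float* out) {
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = &out[rangeIndex * block_size];
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;
    }
    for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
         ++dataInd) {
      const float wgt = weights ? weights[dataInd] : 1.f;
      const uint16_t* ip = &input[indices[dataInd] * fused_block_size];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * _cvtsh_ss(ip[j]); // half -> float, weighted accumulate
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex]) {
      const float len_inv = 1.0f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
}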
1902 template <bool IS_WEIGHT_POSITIONAL>
1903 static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1904  const TIndex block_size,
1905  const TIndex output_size,
1906  const TIndex index_size,
1907  const TIndex data_size,
1908  const uint8_t* input,
1909  const int32_t* indices,
1910  const int* lengths,
1911  const float* weights,
1912  bool normalize_by_lengths,
1913  float* out) {
1914  const int32_t prefdist_T0 = 16;
1915  const int32_t fused_block_size = block_size + 8;
1916  if (block_size == 128) {
1917  // unrolling 16 times
1918  int32_t dataInd = 0;
1919  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1920  float* op = &out[rangeIndex * block_size];
1921  __m256 vop0 = _mm256_setzero_ps();
1922  __m256 vop8 = _mm256_setzero_ps();
1923  __m256 vop16 = _mm256_setzero_ps();
1924  __m256 vop24 = _mm256_setzero_ps();
1925  __m256 vop32 = _mm256_setzero_ps();
1926  __m256 vop40 = _mm256_setzero_ps();
1927  __m256 vop48 = _mm256_setzero_ps();
1928  __m256 vop56 = _mm256_setzero_ps();
1929  __m256 vop64 = _mm256_setzero_ps();
1930  __m256 vop72 = _mm256_setzero_ps();
1931  __m256 vop80 = _mm256_setzero_ps();
1932  __m256 vop88 = _mm256_setzero_ps();
1933  __m256 vop96 = _mm256_setzero_ps();
1934  __m256 vop104 = _mm256_setzero_ps();
1935  __m256 vop112 = _mm256_setzero_ps();
1936  __m256 vop120 = _mm256_setzero_ps();
1937  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1938  ++dataInd) {
1939  const int32_t idx = indices[dataInd];
1940  CAFFE_ENFORCE(
1941  idx >= 0 && idx < data_size,
1942  "Index ",
1943  dataInd,
1944  " is out of bounds: ",
1945  idx,
1946  ", range 0 to ",
1947  data_size);
1948  float wgt = 1.f;
1949  float bio;
1950  if (weights) {
1951  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1952  }
1953  const float* scale_bias = reinterpret_cast<const float*>(
1954  &input[idx * fused_block_size + block_size]);
1955  bio = wgt * scale_bias[1];
1956  wgt = wgt * scale_bias[0];
1957  __m256 vbio = _mm256_set1_ps(bio);
1958  __m256 vwgt = _mm256_set1_ps(wgt);
1959  const uint8_t* ip = &input[idx * fused_block_size];
1960  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1961  ? (dataInd + prefdist_T0)
1962  : dataInd;
1963  const int32_t idx_pref_T0 = indices[next_T0];
1964  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1965  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1966  vop0 = _mm256_fmadd_ps(
1967  vwgt,
1968  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1969  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1970  _mm256_add_ps(vop0, vbio));
1971  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1972  vop8 = _mm256_fmadd_ps(
1973  vwgt,
1974  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1975  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
1976  _mm256_add_ps(vop8, vbio));
1977  // skip unnecessary prefetch of (&ip_next_T0[8])
1978  vop16 = _mm256_fmadd_ps(
1979  vwgt,
1980  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1981  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
1982  _mm256_add_ps(vop16, vbio));
1983  // skip unnecessary prefetch of (&ip_next_T0[16])
1984  vop24 = _mm256_fmadd_ps(
1985  vwgt,
1986  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1987  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
1988  _mm256_add_ps(vop24, vbio));
1989  // skip unnecessary prefetch of (&ip_next_T0[24])
1990  vop32 = _mm256_fmadd_ps(
1991  vwgt,
1992  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1993  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
1994  _mm256_add_ps(vop32, vbio));
1995  // skip unnecessary prefetch of (&ip_next_T0[32])
1996  vop40 = _mm256_fmadd_ps(
1997  vwgt,
1998  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1999  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2000  _mm256_add_ps(vop40, vbio));
2001  // skip unnecessary prefetch of (&ip_next_T0[40])
2002  vop48 = _mm256_fmadd_ps(
2003  vwgt,
2004  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2005  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2006  _mm256_add_ps(vop48, vbio));
2007  // skip unnecessary prefetch of (&ip_next_T0[48])
2008  vop56 = _mm256_fmadd_ps(
2009  vwgt,
2010  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2011  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2012  _mm256_add_ps(vop56, vbio));
2013  // skip unnecessary prefetch of (&ip_next_T0[56])
2014  vop64 = _mm256_fmadd_ps(
2015  vwgt,
2016  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2017  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2018  _mm256_add_ps(vop64, vbio));
2019  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2020  vop72 = _mm256_fmadd_ps(
2021  vwgt,
2022  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2023  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2024  _mm256_add_ps(vop72, vbio));
2025  // skip unnecessary prefetch of (&ip_next_T0[72])
2026  vop80 = _mm256_fmadd_ps(
2027  vwgt,
2028  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2029  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2030  _mm256_add_ps(vop80, vbio));
2031  // skip unnecessary prefetch of (&ip_next_T0[80])
2032  vop88 = _mm256_fmadd_ps(
2033  vwgt,
2034  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2035  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2036  _mm256_add_ps(vop88, vbio));
2037  // skip unnecessary prefetch of (&ip_next_T0[88])
2038  vop96 = _mm256_fmadd_ps(
2039  vwgt,
2040  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2041  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2042  _mm256_add_ps(vop96, vbio));
2043  // skip unnecessary prefetch of (&ip_next_T0[96])
2044  vop104 = _mm256_fmadd_ps(
2045  vwgt,
2046  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2047  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2048  _mm256_add_ps(vop104, vbio));
2049  // skip unnecessary prefetch of (&ip_next_T0[104])
2050  vop112 = _mm256_fmadd_ps(
2051  vwgt,
2052  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2053  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2054  _mm256_add_ps(vop112, vbio));
2055  // skip unnecessary prefetch of (&ip_next_T0[112])
2056  vop120 = _mm256_fmadd_ps(
2057  vwgt,
2058  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2059  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2060  _mm256_add_ps(vop120, vbio));
2061  // skip unnecessary prefetch of (&ip_next_T0[120])
2062  }
2063  if (normalize_by_lengths == false) {
2064  _mm256_storeu_ps(&op[0], vop0);
2065  _mm256_storeu_ps(&op[8], vop8);
2066  _mm256_storeu_ps(&op[16], vop16);
2067  _mm256_storeu_ps(&op[24], vop24);
2068  _mm256_storeu_ps(&op[32], vop32);
2069  _mm256_storeu_ps(&op[40], vop40);
2070  _mm256_storeu_ps(&op[48], vop48);
2071  _mm256_storeu_ps(&op[56], vop56);
2072  _mm256_storeu_ps(&op[64], vop64);
2073  _mm256_storeu_ps(&op[72], vop72);
2074  _mm256_storeu_ps(&op[80], vop80);
2075  _mm256_storeu_ps(&op[88], vop88);
2076  _mm256_storeu_ps(&op[96], vop96);
2077  _mm256_storeu_ps(&op[104], vop104);
2078  _mm256_storeu_ps(&op[112], vop112);
2079  _mm256_storeu_ps(&op[120], vop120);
2080  } else if (lengths[rangeIndex]) {
2081  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2082  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2083  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2084  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2085  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2086  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2087  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2088  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2089  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2090  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2091  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2092  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2093  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2094  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2095  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2096  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2097  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2098  }
2099  }
2100  } else if (block_size == 64) {
2101  // unrolling 8 times
2102  int32_t dataInd = 0;
2103  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2104  float* op = &out[rangeIndex * block_size];
2105  __m256 vop0 = _mm256_setzero_ps();
2106  __m256 vop8 = _mm256_setzero_ps();
2107  __m256 vop16 = _mm256_setzero_ps();
2108  __m256 vop24 = _mm256_setzero_ps();
2109  __m256 vop32 = _mm256_setzero_ps();
2110  __m256 vop40 = _mm256_setzero_ps();
2111  __m256 vop48 = _mm256_setzero_ps();
2112  __m256 vop56 = _mm256_setzero_ps();
2113  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2114  ++dataInd) {
2115  const int32_t idx = indices[dataInd];
2116  CAFFE_ENFORCE(
2117  idx >= 0 && idx < data_size,
2118  "Index ",
2119  dataInd,
2120  " is out of bounds: ",
2121  idx,
2122  ", range 0 to ",
2123  data_size);
2124  float wgt = 1.f;
2125  float bio;
2126  if (weights) {
2127  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2128  }
2129  const float* scale_bias = reinterpret_cast<const float*>(
2130  &input[idx * fused_block_size + block_size]);
2131  bio = wgt * scale_bias[1];
2132  wgt = wgt * scale_bias[0];
2133  __m256 vbio = _mm256_set1_ps(bio);
2134  __m256 vwgt = _mm256_set1_ps(wgt);
2135  const uint8_t* ip = &input[idx * fused_block_size];
2136  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2137  ? (dataInd + prefdist_T0)
2138  : dataInd;
2139  const int32_t idx_pref_T0 = indices[next_T0];
2140  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2141  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2142  vop0 = _mm256_fmadd_ps(
2143  vwgt,
2144  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2145  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2146  _mm256_add_ps(vop0, vbio));
2147  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2148  vop8 = _mm256_fmadd_ps(
2149  vwgt,
2150  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2151  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2152  _mm256_add_ps(vop8, vbio));
2153  // skip unnecessary prefetch of (&ip_next_T0[8])
2154  vop16 = _mm256_fmadd_ps(
2155  vwgt,
2156  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2157  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2158  _mm256_add_ps(vop16, vbio));
2159  // skip unnecessary prefetch of (&ip_next_T0[16])
2160  vop24 = _mm256_fmadd_ps(
2161  vwgt,
2162  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2163  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2164  _mm256_add_ps(vop24, vbio));
2165  // skip unnecessary prefetch of (&ip_next_T0[24])
2166  vop32 = _mm256_fmadd_ps(
2167  vwgt,
2168  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2169  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2170  _mm256_add_ps(vop32, vbio));
2171  // skip unnecessary prefetch of (&ip_next_T0[32])
2172  vop40 = _mm256_fmadd_ps(
2173  vwgt,
2174  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2175  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2176  _mm256_add_ps(vop40, vbio));
2177  // skip unnecessary prefetch of (&ip_next_T0[40])
2178  vop48 = _mm256_fmadd_ps(
2179  vwgt,
2180  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2181  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2182  _mm256_add_ps(vop48, vbio));
2183  // skip unnecessary prefetch of (&ip_next_T0[48])
2184  vop56 = _mm256_fmadd_ps(
2185  vwgt,
2186  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2187  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2188  _mm256_add_ps(vop56, vbio));
2189  // skip unnecessary prefetch of (&ip_next_T0[56])
2190  }
2191  if (normalize_by_lengths == false) {
2192  _mm256_storeu_ps(&op[0], vop0);
2193  _mm256_storeu_ps(&op[8], vop8);
2194  _mm256_storeu_ps(&op[16], vop16);
2195  _mm256_storeu_ps(&op[24], vop24);
2196  _mm256_storeu_ps(&op[32], vop32);
2197  _mm256_storeu_ps(&op[40], vop40);
2198  _mm256_storeu_ps(&op[48], vop48);
2199  _mm256_storeu_ps(&op[56], vop56);
2200  } else if (lengths[rangeIndex]) {
2201  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2202  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2203  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2204  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2205  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2206  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2207  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2208  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2209  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2210  }
2211  }
2212  } else if (block_size == 32) {
2213  // unrolling 4 times
2214  int32_t dataInd = 0;
2215  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2216  float* op = &out[rangeIndex * block_size];
2217  __m256 vop0 = _mm256_setzero_ps();
2218  __m256 vop8 = _mm256_setzero_ps();
2219  __m256 vop16 = _mm256_setzero_ps();
2220  __m256 vop24 = _mm256_setzero_ps();
2221  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2222  ++dataInd) {
2223  const int32_t idx = indices[dataInd];
2224  CAFFE_ENFORCE(
2225  idx >= 0 && idx < data_size,
2226  "Index ",
2227  dataInd,
2228  " is out of bounds: ",
2229  idx,
2230  ", range 0 to ",
2231  data_size);
2232  float wgt = 1.f;
2233  float bio;
2234  if (weights) {
2235  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2236  }
2237  const float* scale_bias = reinterpret_cast<const float*>(
2238  &input[idx * fused_block_size + block_size]);
2239  bio = wgt * scale_bias[1];
2240  wgt = wgt * scale_bias[0];
2241  __m256 vbio = _mm256_set1_ps(bio);
2242  __m256 vwgt = _mm256_set1_ps(wgt);
2243  const uint8_t* ip = &input[idx * fused_block_size];
2244  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2245  ? (dataInd + prefdist_T0)
2246  : dataInd;
2247  const int32_t idx_pref_T0 = indices[next_T0];
2248  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2249  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2250  vop0 = _mm256_fmadd_ps(
2251  vwgt,
2252  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2253  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2254  _mm256_add_ps(vop0, vbio));
2255  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2256  vop8 = _mm256_fmadd_ps(
2257  vwgt,
2258  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2259  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2260  _mm256_add_ps(vop8, vbio));
2261  // skip unnecessary prefetch of (&ip_next_T0[8])
2262  vop16 = _mm256_fmadd_ps(
2263  vwgt,
2264  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2265  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2266  _mm256_add_ps(vop16, vbio));
2267  // skip unnecessary prefetch of (&ip_next_T0[16])
2268  vop24 = _mm256_fmadd_ps(
2269  vwgt,
2270  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2271  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2272  _mm256_add_ps(vop24, vbio));
2273  // skip unnecessary prefetch of (&ip_next_T0[24])
2274  }
2275  if (normalize_by_lengths == false) {
2276  _mm256_storeu_ps(&op[0], vop0);
2277  _mm256_storeu_ps(&op[8], vop8);
2278  _mm256_storeu_ps(&op[16], vop16);
2279  _mm256_storeu_ps(&op[24], vop24);
2280  } else if (lengths[rangeIndex]) {
2281  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2282  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2283  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2284  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2285  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2286  }
2287  }
2288  } else if (block_size == 16) {
2289  // unrolling 2 times
2290  int32_t dataInd = 0;
2291  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2292  float* op = &out[rangeIndex * block_size];
2293  __m256 vop0 = _mm256_setzero_ps();
2294  __m256 vop8 = _mm256_setzero_ps();
2295  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2296  ++dataInd) {
2297  const int32_t idx = indices[dataInd];
2298  CAFFE_ENFORCE(
2299  idx >= 0 && idx < data_size,
2300  "Index ",
2301  dataInd,
2302  " is out of bounds: ",
2303  idx,
2304  ", range 0 to ",
2305  data_size);
2306  float wgt = 1.f;
2307  float bio;
2308  if (weights) {
2309  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2310  }
2311  const float* scale_bias = reinterpret_cast<const float*>(
2312  &input[idx * fused_block_size + block_size]);
2313  bio = wgt * scale_bias[1];
2314  wgt = wgt * scale_bias[0];
2315  __m256 vbio = _mm256_set1_ps(bio);
2316  __m256 vwgt = _mm256_set1_ps(wgt);
2317  const uint8_t* ip = &input[idx * fused_block_size];
2318  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2319  ? (dataInd + prefdist_T0)
2320  : dataInd;
2321  const int32_t idx_pref_T0 = indices[next_T0];
2322  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2323  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2324  vop0 = _mm256_fmadd_ps(
2325  vwgt,
2326  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2327  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2328  _mm256_add_ps(vop0, vbio));
2329  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2330  vop8 = _mm256_fmadd_ps(
2331  vwgt,
2332  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2333  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2334  _mm256_add_ps(vop8, vbio));
2335  // skip unnecessary prefetch of (&ip_next_T0[8])
2336  }
2337  if (normalize_by_lengths == false) {
2338  _mm256_storeu_ps(&op[0], vop0);
2339  _mm256_storeu_ps(&op[8], vop8);
2340  } else if (lengths[rangeIndex]) {
2341  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2342  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2343  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2344  }
2345  }
2346  } else {
2347  // generic code
2348  int32_t dataInd = 0;
2349  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2350  float* op = &out[rangeIndex * block_size];
2351  TIndex j = 0;
2352  for (; j + 8 <= block_size; j += 8) {
2353  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2354  }
2355  for (; j < block_size; j++) {
2356  op[j] = 0.0f;
2357  }
2358  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2359  ++dataInd) {
2360  const int32_t idx = indices[dataInd];
2361  CAFFE_ENFORCE(
2362  idx >= 0 && idx < data_size,
2363  "Index ",
2364  dataInd,
2365  " is out of bounds: ",
2366  idx,
2367  ", range 0 to ",
2368  data_size);
2369  float wgt = 1.f;
2370  float bio;
2371  if (weights) {
2372  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2373  }
2374  const float* scale_bias = reinterpret_cast<const float*>(
2375  &input[idx * fused_block_size + block_size]);
2376  bio = wgt * scale_bias[1];
2377  wgt = wgt * scale_bias[0];
2378  __m256 vbio = _mm256_set1_ps(bio);
2379  __m256 vwgt = _mm256_set1_ps(wgt);
2380  const uint8_t* ip = &input[idx * fused_block_size];
2381  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2382  ? (dataInd + prefdist_T0)
2383  : dataInd;
2384  const int32_t idx_pref_T0 = indices[next_T0];
2385  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2386  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2387  j = 0;
2388  for (; j + 8 <= block_size; j += 8) {
2389  _mm256_storeu_ps(
2390  &op[j],
2391  _mm256_fmadd_ps(
2392  vwgt,
2393  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2394  reinterpret_cast<const __m128i*>(&ip[j])))),
2395  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2396  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2397  }
2398  for (; j < block_size; j++) {
2399  op[j] += wgt * ((float)ip[j]) + bio;
2400  }
2401  }
2402  if (normalize_by_lengths && lengths[rangeIndex]) {
2403  float len_inv = 1.0f / lengths[rangeIndex];
2404  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2405  j = 0;
2406  for (; j + 8 <= block_size; j += 8) {
2407  _mm256_storeu_ps(
2408  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2409  }
2410  for (; j < block_size; j++) {
2411  op[j] = len_inv * op[j];
2412  }
2413  }
2414  }
2415  }
2416 }
2417 void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
2418  const TIndex block_size,
2419  const TIndex output_size,
2420  const TIndex index_size,
2421  const TIndex data_size,
2422  const uint8_t* input,
2423  const int32_t* indices,
2424  const int* lengths,
2425  const float* weights,
2426  bool normalize_by_lengths,
2427  float* out) {
2428  Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
2429  block_size,
2430  output_size,
2431  index_size,
2432  data_size,
2433  input,
2434  indices,
2435  lengths,
2436  weights,
2437  normalize_by_lengths,
2438  out);
2439 }
2440 void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
2441  const TIndex block_size,
2442  const TIndex output_size,
2443  const TIndex index_size,
2444  const TIndex data_size,
2445  const uint8_t* input,
2446  const int32_t* indices,
2447  const int* lengths,
2448  const float* weights,
2449  bool normalize_by_lengths,
2450  float* out) {
2451  Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
2452  block_size,
2453  output_size,
2454  index_size,
2455  data_size,
2456  input,
2457  indices,
2458  lengths,
2459  weights,
2460  normalize_by_lengths,
2461  out);
2462 }
2463 
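In the uint8_t kernels above, each fused row stores block_size quantized bytes followed by a float scale and a float bias (hence fused_block_size = block_size + 8); the per-row scale and bias are folded into the per-index weight (wgt * scale and bio = wgt * bias) so the inner loop is a single FMA plus an add. The following scalar sketch of that dequantization, for the non-positional-weight case, is illustrative only and not part of the file.

// Minimal scalar sketch of the fused 8-bit rowwise lookup (illustration only).
// Row layout, as read by the kernels above:
//   [ block_size x uint8 quantized values | float scale | float bias ]
static void Fused8BitRowwiseScalarSketch(
    int64_t block_size,
    int64_t output_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights, // may be nullptr
    bool normalize_by_lengths,
    float* out) {
  const int64_t fused_block_size = block_size + 8; // + 8 bytes of scale/bias
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = &out[rangeIndex * block_size];
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;
    }
    for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
         ++dataInd) {
      float wgt = weights ? weights[dataInd] : 1.f;
      const uint8_t* ip = &input[indices[dataInd] * fused_block_size];
      const float* scale_bias =
          reinterpret_cast<const float*>(ip + block_size);
      // Fold scale and bias into the weight, exactly as the AVX2 code does.
      const float bio = wgt * scale_bias[1];
      wgt = wgt * scale_bias[0];
      for (int64_t j = 0; j < block_size; ++j) {
        // wgt*(scale*q) + wgt*bias == original wgt * (scale*q + bias)
        op[j] += wgt * static_cast<float>(ip[j]) + bio;
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex]) {
      const float len_inv = 1.0f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
}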
2464 template <bool IS_WEIGHT_POSITIONAL>
2465 static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2466  const TIndex block_size,
2467  const TIndex output_size,
2468  const TIndex index_size,
2469  const TIndex data_size,
2470  const uint8_t* input,
2471  const int64_t* indices,
2472  const int* lengths,
2473  const float* weights,
2474  bool normalize_by_lengths,
2475  float* out) {
2476  const int64_t prefdist_T0 = 16;
2477  const int64_t fused_block_size = block_size + 8;
2478  if (block_size == 128) {
2479  // unrolling 16 times
2480  int64_t dataInd = 0;
2481  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2482  float* op = &out[rangeIndex * block_size];
2483  __m256 vop0 = _mm256_setzero_ps();
2484  __m256 vop8 = _mm256_setzero_ps();
2485  __m256 vop16 = _mm256_setzero_ps();
2486  __m256 vop24 = _mm256_setzero_ps();
2487  __m256 vop32 = _mm256_setzero_ps();
2488  __m256 vop40 = _mm256_setzero_ps();
2489  __m256 vop48 = _mm256_setzero_ps();
2490  __m256 vop56 = _mm256_setzero_ps();
2491  __m256 vop64 = _mm256_setzero_ps();
2492  __m256 vop72 = _mm256_setzero_ps();
2493  __m256 vop80 = _mm256_setzero_ps();
2494  __m256 vop88 = _mm256_setzero_ps();
2495  __m256 vop96 = _mm256_setzero_ps();
2496  __m256 vop104 = _mm256_setzero_ps();
2497  __m256 vop112 = _mm256_setzero_ps();
2498  __m256 vop120 = _mm256_setzero_ps();
2499  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2500  ++dataInd) {
2501  const int64_t idx = indices[dataInd];
2502  CAFFE_ENFORCE(
2503  idx >= 0 && idx < data_size,
2504  "Index ",
2505  dataInd,
2506  " is out of bounds: ",
2507  idx,
2508  ", range 0 to ",
2509  data_size);
2510  float wgt = 1.f;
2511  float bio;
2512  if (weights) {
2513  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2514  }
2515  const float* scale_bias = reinterpret_cast<const float*>(
2516  &input[idx * fused_block_size + block_size]);
2517  bio = wgt * scale_bias[1];
2518  wgt = wgt * scale_bias[0];
2519  __m256 vbio = _mm256_set1_ps(bio);
2520  __m256 vwgt = _mm256_set1_ps(wgt);
2521  const uint8_t* ip = &input[idx * fused_block_size];
2522  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2523  ? (dataInd + prefdist_T0)
2524  : dataInd;
2525  const int64_t idx_pref_T0 = indices[next_T0];
2526  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2527  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2528  vop0 = _mm256_fmadd_ps(
2529  vwgt,
2530  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2531  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2532  _mm256_add_ps(vop0, vbio));
2533  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2534  vop8 = _mm256_fmadd_ps(
2535  vwgt,
2536  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2537  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2538  _mm256_add_ps(vop8, vbio));
2539  // skip unnecessary prefetch of (&ip_next_T0[8])
2540  vop16 = _mm256_fmadd_ps(
2541  vwgt,
2542  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2543  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2544  _mm256_add_ps(vop16, vbio));
2545  // skip unnecessary prefetch of (&ip_next_T0[16])
2546  vop24 = _mm256_fmadd_ps(
2547  vwgt,
2548  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2549  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2550  _mm256_add_ps(vop24, vbio));
2551  // skip unnecessary prefetch of (&ip_next_T0[24])
2552  vop32 = _mm256_fmadd_ps(
2553  vwgt,
2554  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2555  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2556  _mm256_add_ps(vop32, vbio));
2557  // skip unnecessary prefetch of (&ip_next_T0[32])
2558  vop40 = _mm256_fmadd_ps(
2559  vwgt,
2560  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2561  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2562  _mm256_add_ps(vop40, vbio));
2563  // skip unnecessary prefetch of (&ip_next_T0[40])
2564  vop48 = _mm256_fmadd_ps(
2565  vwgt,
2566  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2567  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2568  _mm256_add_ps(vop48, vbio));
2569  // skip unnecessary prefetch of (&ip_next_T0[48])
2570  vop56 = _mm256_fmadd_ps(
2571  vwgt,
2572  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2573  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2574  _mm256_add_ps(vop56, vbio));
2575  // skip unnecessary prefetch of (&ip_next_T0[56])
2576  vop64 = _mm256_fmadd_ps(
2577  vwgt,
2578  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2579  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2580  _mm256_add_ps(vop64, vbio));
2581  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2582  vop72 = _mm256_fmadd_ps(
2583  vwgt,
2584  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2585  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2586  _mm256_add_ps(vop72, vbio));
2587  // skip unnecessary prefetch of (&ip_next_T0[72])
2588  vop80 = _mm256_fmadd_ps(
2589  vwgt,
2590  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2591  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2592  _mm256_add_ps(vop80, vbio));
2593  // skip unnecessary prefetch of (&ip_next_T0[80])
2594  vop88 = _mm256_fmadd_ps(
2595  vwgt,
2596  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2597  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2598  _mm256_add_ps(vop88, vbio));
2599  // skip unnecessary prefetch of (&ip_next_T0[88])
2600  vop96 = _mm256_fmadd_ps(
2601  vwgt,
2602  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2603  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2604  _mm256_add_ps(vop96, vbio));
2605  // skip unnecessary prefetch of (&ip_next_T0[96])
2606  vop104 = _mm256_fmadd_ps(
2607  vwgt,
2608  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2609  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2610  _mm256_add_ps(vop104, vbio));
2611  // skip unnecessary prefetch of (&ip_next_T0[104])
2612  vop112 = _mm256_fmadd_ps(
2613  vwgt,
2614  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2615  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2616  _mm256_add_ps(vop112, vbio));
2617  // skip unnecessary prefetch of (&ip_next_T0[112])
2618  vop120 = _mm256_fmadd_ps(
2619  vwgt,
2620  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2621  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2622  _mm256_add_ps(vop120, vbio));
2623  // skip unnecessary prefetch of (&ip_next_T0[120])
2624  }
2625  if (normalize_by_lengths == false) {
2626  _mm256_storeu_ps(&op[0], vop0);
2627  _mm256_storeu_ps(&op[8], vop8);
2628  _mm256_storeu_ps(&op[16], vop16);
2629  _mm256_storeu_ps(&op[24], vop24);
2630  _mm256_storeu_ps(&op[32], vop32);
2631  _mm256_storeu_ps(&op[40], vop40);
2632  _mm256_storeu_ps(&op[48], vop48);
2633  _mm256_storeu_ps(&op[56], vop56);
2634  _mm256_storeu_ps(&op[64], vop64);
2635  _mm256_storeu_ps(&op[72], vop72);
2636  _mm256_storeu_ps(&op[80], vop80);
2637  _mm256_storeu_ps(&op[88], vop88);
2638  _mm256_storeu_ps(&op[96], vop96);
2639  _mm256_storeu_ps(&op[104], vop104);
2640  _mm256_storeu_ps(&op[112], vop112);
2641  _mm256_storeu_ps(&op[120], vop120);
2642  } else if (lengths[rangeIndex]) {
2643  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2644  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2645  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2646  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2647  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2648  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2649  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2650  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2651  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2652  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2653  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2654  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2655  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2656  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2657  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2658  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2659  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2660  }
2661  }
2662  } else if (block_size == 64) {
2663  // unrolling 8 times
2664  int64_t dataInd = 0;
2665  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2666  float* op = &out[rangeIndex * block_size];
2667  __m256 vop0 = _mm256_setzero_ps();
2668  __m256 vop8 = _mm256_setzero_ps();
2669  __m256 vop16 = _mm256_setzero_ps();
2670  __m256 vop24 = _mm256_setzero_ps();
2671  __m256 vop32 = _mm256_setzero_ps();
2672  __m256 vop40 = _mm256_setzero_ps();
2673  __m256 vop48 = _mm256_setzero_ps();
2674  __m256 vop56 = _mm256_setzero_ps();
2675  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2676  ++dataInd) {
2677  const int64_t idx = indices[dataInd];
2678  CAFFE_ENFORCE(
2679  idx >= 0 && idx < data_size,
2680  "Index ",
2681  dataInd,
2682  " is out of bounds: ",
2683  idx,
2684  ", range 0 to ",
2685  data_size);
2686  float wgt = 1.f;
2687  float bio;
2688  if (weights) {
2689  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2690  }
2691  const float* scale_bias = reinterpret_cast<const float*>(
2692  &input[idx * fused_block_size + block_size]);
2693  bio = wgt * scale_bias[1];
2694  wgt = wgt * scale_bias[0];
2695  __m256 vbio = _mm256_set1_ps(bio);
2696  __m256 vwgt = _mm256_set1_ps(wgt);
2697  const uint8_t* ip = &input[idx * fused_block_size];
2698  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2699  ? (dataInd + prefdist_T0)
2700  : dataInd;
2701  const int64_t idx_pref_T0 = indices[next_T0];
2702  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2703  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2704  vop0 = _mm256_fmadd_ps(
2705  vwgt,
2706  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2707  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2708  _mm256_add_ps(vop0, vbio));
2709  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2710  vop8 = _mm256_fmadd_ps(
2711  vwgt,
2712  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2713  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2714  _mm256_add_ps(vop8, vbio));
2715  // skip unnecessary prefetch of (&ip_next_T0[8])
2716  vop16 = _mm256_fmadd_ps(
2717  vwgt,
2718  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2719  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2720  _mm256_add_ps(vop16, vbio));
2721  // skip unnecessary prefetch of (&ip_next_T0[16])
2722  vop24 = _mm256_fmadd_ps(
2723  vwgt,
2724  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2725  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2726  _mm256_add_ps(vop24, vbio));
2727  // skip unnecessary prefetch of (&ip_next_T0[24])
2728  vop32 = _mm256_fmadd_ps(
2729  vwgt,
2730  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2731  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2732  _mm256_add_ps(vop32, vbio));
2733  // skip unnecessary prefetch of (&ip_next_T0[32])
2734  vop40 = _mm256_fmadd_ps(
2735  vwgt,
2736  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2737  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2738  _mm256_add_ps(vop40, vbio));
2739  // skip unnecessary prefetch of (&ip_next_T0[40])
2740  vop48 = _mm256_fmadd_ps(
2741  vwgt,
2742  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2743  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2744  _mm256_add_ps(vop48, vbio));
2745  // skip unnecessary prefetch of (&ip_next_T0[48])
2746  vop56 = _mm256_fmadd_ps(
2747  vwgt,
2748  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2749  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2750  _mm256_add_ps(vop56, vbio));
2751  // skip unnecessary prefetch of (&ip_next_T0[56])
2752  }
2753  if (normalize_by_lengths == false) {
2754  _mm256_storeu_ps(&op[0], vop0);
2755  _mm256_storeu_ps(&op[8], vop8);
2756  _mm256_storeu_ps(&op[16], vop16);
2757  _mm256_storeu_ps(&op[24], vop24);
2758  _mm256_storeu_ps(&op[32], vop32);
2759  _mm256_storeu_ps(&op[40], vop40);
2760  _mm256_storeu_ps(&op[48], vop48);
2761  _mm256_storeu_ps(&op[56], vop56);
2762  } else if (lengths[rangeIndex]) {
2763  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2764  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2765  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2766  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2767  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2768  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2769  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2770  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2771  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2772  }
2773  }
2774  } else if (block_size == 32) {
2775  // unrolling 4 times
2776  int64_t dataInd = 0;
2777  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2778  float* op = &out[rangeIndex * block_size];
2779  __m256 vop0 = _mm256_setzero_ps();
2780  __m256 vop8 = _mm256_setzero_ps();
2781  __m256 vop16 = _mm256_setzero_ps();
2782  __m256 vop24 = _mm256_setzero_ps();
2783  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2784  ++dataInd) {
2785  const int64_t idx = indices[dataInd];
2786  CAFFE_ENFORCE(
2787  idx >= 0 && idx < data_size,
2788  "Index ",
2789  dataInd,
2790  " is out of bounds: ",
2791  idx,
2792  ", range 0 to ",
2793  data_size);
2794  float wgt = 1.f;
2795  float bio;
2796  if (weights) {
2797  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2798  }
2799  const float* scale_bias = reinterpret_cast<const float*>(
2800  &input[idx * fused_block_size + block_size]);
2801  bio = wgt * scale_bias[1];
2802  wgt = wgt * scale_bias[0];
2803  __m256 vbio = _mm256_set1_ps(bio);
2804  __m256 vwgt = _mm256_set1_ps(wgt);
2805  const uint8_t* ip = &input[idx * fused_block_size];
2806  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2807  ? (dataInd + prefdist_T0)
2808  : dataInd;
2809  const int64_t idx_pref_T0 = indices[next_T0];
2810  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2811  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2812  vop0 = _mm256_fmadd_ps(
2813  vwgt,
2814  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2815  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2816  _mm256_add_ps(vop0, vbio));
2817  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2818  vop8 = _mm256_fmadd_ps(
2819  vwgt,
2820  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2821  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2822  _mm256_add_ps(vop8, vbio));
2823  // skip unnecessary prefetch of (&ip_next_T0[8])
2824  vop16 = _mm256_fmadd_ps(
2825  vwgt,
2826  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2827  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2828  _mm256_add_ps(vop16, vbio));
2829  // skip unnecessary prefetch of (&ip_next_T0[16])
2830  vop24 = _mm256_fmadd_ps(
2831  vwgt,
2832  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2833  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2834  _mm256_add_ps(vop24, vbio));
2835  // skip unnecessary prefetch of (&ip_next_T0[24])
2836  }
2837  if (normalize_by_lengths == false) {
2838  _mm256_storeu_ps(&op[0], vop0);
2839  _mm256_storeu_ps(&op[8], vop8);
2840  _mm256_storeu_ps(&op[16], vop16);
2841  _mm256_storeu_ps(&op[24], vop24);
2842  } else if (lengths[rangeIndex]) {
2843  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2844  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2845  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2846  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2847  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2848  }
2849  }
2850  } else if (block_size == 16) {
2851  // unrolling 2 times
2852  int64_t dataInd = 0;
2853  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2854  float* op = &out[rangeIndex * block_size];
2855  __m256 vop0 = _mm256_setzero_ps();
2856  __m256 vop8 = _mm256_setzero_ps();
2857  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2858  ++dataInd) {
2859  const int64_t idx = indices[dataInd];
2860  CAFFE_ENFORCE(
2861  idx >= 0 && idx < data_size,
2862  "Index ",
2863  dataInd,
2864  " is out of bounds: ",
2865  idx,
2866  ", range 0 to ",
2867  data_size);
2868  float wgt = 1.f;
2869  float bio;
2870  if (weights) {
2871  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2872  }
2873  const float* scale_bias = reinterpret_cast<const float*>(
2874  &input[idx * fused_block_size + block_size]);
2875  bio = wgt * scale_bias[1];
2876  wgt = wgt * scale_bias[0];
2877  __m256 vbio = _mm256_set1_ps(bio);
2878  __m256 vwgt = _mm256_set1_ps(wgt);
2879  const uint8_t* ip = &input[idx * fused_block_size];
2880  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2881  ? (dataInd + prefdist_T0)
2882  : dataInd;
2883  const int64_t idx_pref_T0 = indices[next_T0];
2884  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2885  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2886  vop0 = _mm256_fmadd_ps(
2887  vwgt,
2888  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2889  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2890  _mm256_add_ps(vop0, vbio));
2891  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2892  vop8 = _mm256_fmadd_ps(
2893  vwgt,
2894  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2895  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2896  _mm256_add_ps(vop8, vbio));
2897  // skip unnecessary prefetch of (&ip_next_T0[8])
2898  }
2899  if (normalize_by_lengths == false) {
2900  _mm256_storeu_ps(&op[0], vop0);
2901  _mm256_storeu_ps(&op[8], vop8);
2902  } else if (lengths[rangeIndex]) {
2903  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2904  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2905  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2906  }
2907  }
2908  } else {
2909  // generic code
2910  int64_t dataInd = 0;
2911  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2912  float* op = &out[rangeIndex * block_size];
2913  TIndex j = 0;
2914  for (; j + 8 <= block_size; j += 8) {
2915  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2916  }
2917  for (; j < block_size; j++) {
2918  op[j] = 0.0f;
2919  }
2920  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2921  ++dataInd) {
2922  const int64_t idx = indices[dataInd];
2923  CAFFE_ENFORCE(
2924  idx >= 0 && idx < data_size,
2925  "Index ",
2926  dataInd,
2927  " is out of bounds: ",
2928  idx,
2929  ", range 0 to ",
2930  data_size);
2931  float wgt = 1.f;
2932  float bio;
2933  if (weights) {
2934  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2935  }
2936  const float* scale_bias = reinterpret_cast<const float*>(
2937  &input[idx * fused_block_size + block_size]);
2938  bio = wgt * scale_bias[1];
2939  wgt = wgt * scale_bias[0];
2940  __m256 vbio = _mm256_set1_ps(bio);
2941  __m256 vwgt = _mm256_set1_ps(wgt);
2942  const uint8_t* ip = &input[idx * fused_block_size];
2943  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2944  ? (dataInd + prefdist_T0)
2945  : dataInd;
2946  const int64_t idx_pref_T0 = indices[next_T0];
2947  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2948  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2949  j = 0;
2950  for (; j + 8 <= block_size; j += 8) {
2951  _mm256_storeu_ps(
2952  &op[j],
2953  _mm256_fmadd_ps(
2954  vwgt,
2955  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2956  reinterpret_cast<const __m128i*>(&ip[j])))),
2957  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2958  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2959  }
2960  for (; j < block_size; j++) {
2961  op[j] += wgt * ((float)ip[j]) + bio;
2962  }
2963  }
2964  if (normalize_by_lengths && lengths[rangeIndex]) {
2965  float len_inv = 1.0f / lengths[rangeIndex];
2966  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2967  j = 0;
2968  for (; j + 8 <= block_size; j += 8) {
2969  _mm256_storeu_ps(
2970  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2971  }
2972  for (; j < block_size; j++) {
2973  op[j] = len_inv * op[j];
2974  }
2975  }
2976  }
2977  }
2978 }
2979 void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
2980  const TIndex block_size,
2981  const TIndex output_size,
2982  const TIndex index_size,
2983  const TIndex data_size,
2984  const uint8_t* input,
2985  const int64_t* indices,
2986  const int* lengths,
2987  const float* weights,
2988  bool normalize_by_lengths,
2989  float* out) {
2990  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
2991  block_size,
2992  output_size,
2993  index_size,
2994  data_size,
2995  input,
2996  indices,
2997  lengths,
2998  weights,
2999  normalize_by_lengths,
3000  out);
3001 }
3002 void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
3003  const TIndex block_size,
3004  const TIndex output_size,
3005  const TIndex index_size,
3006  const TIndex data_size,
3007  const uint8_t* input,
3008  const int64_t* indices,
3009  const int* lengths,
3010  const float* weights,
3011  bool normalize_by_lengths,
3012  float* out) {
3013  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
3014  block_size,
3015  output_size,
3016  index_size,
3017  data_size,
3018  input,
3019  indices,
3020  lengths,
3021  weights,
3022  normalize_by_lengths,
3023  out);
3024 }
3025 
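As a rough usage illustration of the exported wrappers (toy values and a hypothetical helper name, not part of the file): gather two segments of lengths 2 and 1 from a three-row fused table with block_size = 4 and no per-index weights.

// Hedged usage sketch: all values below are made up for illustration.
void Fused8BitRowwiseEmbeddingLookupUsageSketch() {
  constexpr int kBlockSize = 4;
  constexpr int kRows = 3;
  constexpr int kFusedSize = kBlockSize + 8; // bytes per fused row
  uint8_t input[kRows * kFusedSize];
  for (int r = 0; r < kRows; ++r) {
    uint8_t* row = input + r * kFusedSize;
    for (int j = 0; j < kBlockSize; ++j) {
      row[j] = static_cast<uint8_t>(r * 10 + j); // arbitrary quantized values
    }
    // Write the trailing scale/bias the same way the kernels read them.
    float* scale_bias = reinterpret_cast<float*>(row + kBlockSize);
    scale_bias[0] = 0.5f; // scale
    scale_bias[1] = 1.0f; // bias
  }
  const int64_t indices[3] = {0, 2, 1}; // index_size = 3
  const int lengths[2] = {2, 1}; // output_size = 2 segments
  float out[2 * kBlockSize] = {0.f};
  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
      kBlockSize, // block_size
      2, // output_size
      3, // index_size
      kRows, // data_size
      input,
      indices,
      lengths,
      /*weights=*/nullptr,
      /*normalize_by_lengths=*/false,
      out);
  // out[0..3] = dequantized row 0 + row 2; out[4..7] = dequantized row 1.
}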
3026 } // namespace caffe2