Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_avx2.cc
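The listing below contains the AVX2 + FMA kernels behind Caffe2's EmbeddingLookup operator, specialized by index type (int32_t or int64_t) and by embedding storage type (float or float16). Each specialization has hand-unrolled paths for block sizes 128, 64, 32 and 16, plus a generic fallback for any other block size. As a rough guide to what every variant computes, here is a minimal scalar sketch (written for this page as an illustration; it is not part of the generated source):

  // out has output_size rows of block_size floats. indices holds index_size row
  // ids into input (data_size rows of block_size values). lengths[r] says how
  // many consecutive entries of indices belong to output row r.
  TIndex pos = 0;
  for (TIndex r = 0; r < output_size; ++r) {
    float* op = &out[r * block_size];
    for (TIndex j = 0; j < block_size; ++j) {
      op[j] = 0.f;
    }
    for (int i = 0; i < lengths[r]; ++i, ++pos) {
      // IS_WEIGHT_POSITIONAL selects weights[i] instead of weights[pos].
      const float wgt = weights ? weights[pos] : 1.f;
      const float* row = &input[indices[pos] * block_size];
      for (TIndex j = 0; j < block_size; ++j) {
        op[j] += wgt * row[j];
      }
    }
    if (normalize_by_lengths && lengths[r]) {
      for (TIndex j = 0; j < block_size; ++j) {
        op[j] /= lengths[r];
      }
    }
  }

The AVX2 code below implements exactly this accumulation with 8-wide fused multiply-adds and software prefetching of the next indexed row; the float16 variants additionally widen each stored half to single precision before accumulating.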
8 #include <caffe2/core/common.h>
9 #include <caffe2/core/types.h>
10 #include <immintrin.h>
11 
12 namespace caffe2 {
13 
14 template <bool IS_WEIGHT_POSITIONAL>
15 static void EmbeddingLookup_int32_t_float_float__avx2_fma(
16  const TIndex block_size,
17  const TIndex output_size,
18  const TIndex index_size,
19  const TIndex data_size,
20  const float* input,
21  const int32_t* indices,
22  const int* lengths,
23  const float* weights,
24  const float* scale_bias,
25  bool normalize_by_lengths,
26  float* out) {
27  const int32_t prefdist_T0 = 16;
28  const int32_t fused_block_size = block_size + 0;
29  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
30  if (block_size == 128) {
31  // unrolling 16 times
32  int32_t dataInd = 0;
33  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
34  float* op = &out[rangeIndex * block_size];
35  __m256 vop0 = _mm256_setzero_ps();
36  __m256 vop8 = _mm256_setzero_ps();
37  __m256 vop16 = _mm256_setzero_ps();
38  __m256 vop24 = _mm256_setzero_ps();
39  __m256 vop32 = _mm256_setzero_ps();
40  __m256 vop40 = _mm256_setzero_ps();
41  __m256 vop48 = _mm256_setzero_ps();
42  __m256 vop56 = _mm256_setzero_ps();
43  __m256 vop64 = _mm256_setzero_ps();
44  __m256 vop72 = _mm256_setzero_ps();
45  __m256 vop80 = _mm256_setzero_ps();
46  __m256 vop88 = _mm256_setzero_ps();
47  __m256 vop96 = _mm256_setzero_ps();
48  __m256 vop104 = _mm256_setzero_ps();
49  __m256 vop112 = _mm256_setzero_ps();
50  __m256 vop120 = _mm256_setzero_ps();
51  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
52  ++dataInd) {
53  const int32_t idx = indices[dataInd];
54  CAFFE_ENFORCE(
55  idx >= 0 && idx < data_size,
56  "Index ",
57  dataInd,
58  " is out of bounds: ",
59  idx,
60  ", range 0 to ",
61  data_size);
62  float wgt = 1.f;
63  if (weights) {
64  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
65  }
66  __m256 vwgt = _mm256_set1_ps(wgt);
67  const float* ip = &input[idx * fused_block_size];
68  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
69  ? (dataInd + prefdist_T0)
70  : dataInd;
71  const int32_t idx_pref_T0 = indices[next_T0];
72  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
73  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
74  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
75  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
76  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
77  // skip unnecessary prefetch of (&ip_next_T0[8])
78  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
79  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
80  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
81  // skip unnecessary prefetch of (&ip_next_T0[24])
82  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
83  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
84  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
85  // skip unnecessary prefetch of (&ip_next_T0[40])
86  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
87  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
88  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
89  // skip unnecessary prefetch of (&ip_next_T0[56])
90  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
91  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
92  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
93  // skip unnecessary prefetch of (&ip_next_T0[72])
94  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
95  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
96  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
97  // skip unnecessary prefetch of (&ip_next_T0[88])
98  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
99  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
100  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
101  // skip unnecessary prefetch of (&ip_next_T0[104])
102  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
103  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
104  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
105  // skip unnecessary prefetch of (&ip_next_T0[120])
106  }
107  if (normalize_by_lengths == false) {
108  _mm256_storeu_ps(&op[0], vop0);
109  _mm256_storeu_ps(&op[8], vop8);
110  _mm256_storeu_ps(&op[16], vop16);
111  _mm256_storeu_ps(&op[24], vop24);
112  _mm256_storeu_ps(&op[32], vop32);
113  _mm256_storeu_ps(&op[40], vop40);
114  _mm256_storeu_ps(&op[48], vop48);
115  _mm256_storeu_ps(&op[56], vop56);
116  _mm256_storeu_ps(&op[64], vop64);
117  _mm256_storeu_ps(&op[72], vop72);
118  _mm256_storeu_ps(&op[80], vop80);
119  _mm256_storeu_ps(&op[88], vop88);
120  _mm256_storeu_ps(&op[96], vop96);
121  _mm256_storeu_ps(&op[104], vop104);
122  _mm256_storeu_ps(&op[112], vop112);
123  _mm256_storeu_ps(&op[120], vop120);
124  } else if (lengths[rangeIndex]) {
125  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
126  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
127  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
128  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
129  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
130  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
131  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
132  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
133  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
134  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
135  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
136  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
137  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
138  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
139  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
140  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
141  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
142  }
143  }
144  } else if (block_size == 64) {
145  // unrolling 8 times
146  int32_t dataInd = 0;
147  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
148  float* op = &out[rangeIndex * block_size];
149  __m256 vop0 = _mm256_setzero_ps();
150  __m256 vop8 = _mm256_setzero_ps();
151  __m256 vop16 = _mm256_setzero_ps();
152  __m256 vop24 = _mm256_setzero_ps();
153  __m256 vop32 = _mm256_setzero_ps();
154  __m256 vop40 = _mm256_setzero_ps();
155  __m256 vop48 = _mm256_setzero_ps();
156  __m256 vop56 = _mm256_setzero_ps();
157  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
158  ++dataInd) {
159  const int32_t idx = indices[dataInd];
160  CAFFE_ENFORCE(
161  idx >= 0 && idx < data_size,
162  "Index ",
163  dataInd,
164  " is out of bounds: ",
165  idx,
166  ", range 0 to ",
167  data_size);
168  float wgt = 1.f;
169  if (weights) {
170  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
171  }
172  __m256 vwgt = _mm256_set1_ps(wgt);
173  const float* ip = &input[idx * fused_block_size];
174  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
175  ? (dataInd + prefdist_T0)
176  : dataInd;
177  const int32_t idx_pref_T0 = indices[next_T0];
178  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
179  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
180  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
181  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
182  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
183  // skip unnecessary prefetch of (&ip_next_T0[8])
184  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
185  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
186  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
187  // skip unnecessary prefetch of (&ip_next_T0[24])
188  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
189  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
190  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
191  // skip unnecessary prefetch of (&ip_next_T0[40])
192  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
193  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
194  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
195  // skip unnecessary prefetch of (&ip_next_T0[56])
196  }
197  if (normalize_by_lengths == false) {
198  _mm256_storeu_ps(&op[0], vop0);
199  _mm256_storeu_ps(&op[8], vop8);
200  _mm256_storeu_ps(&op[16], vop16);
201  _mm256_storeu_ps(&op[24], vop24);
202  _mm256_storeu_ps(&op[32], vop32);
203  _mm256_storeu_ps(&op[40], vop40);
204  _mm256_storeu_ps(&op[48], vop48);
205  _mm256_storeu_ps(&op[56], vop56);
206  } else if (lengths[rangeIndex]) {
207  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
208  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
209  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
210  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
211  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
212  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
213  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
214  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
215  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
216  }
217  }
218  } else if (block_size == 32) {
219  // unrolling 4 times
220  int32_t dataInd = 0;
221  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
222  float* op = &out[rangeIndex * block_size];
223  __m256 vop0 = _mm256_setzero_ps();
224  __m256 vop8 = _mm256_setzero_ps();
225  __m256 vop16 = _mm256_setzero_ps();
226  __m256 vop24 = _mm256_setzero_ps();
227  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
228  ++dataInd) {
229  const int32_t idx = indices[dataInd];
230  CAFFE_ENFORCE(
231  idx >= 0 && idx < data_size,
232  "Index ",
233  dataInd,
234  " is out of bounds: ",
235  idx,
236  ", range 0 to ",
237  data_size);
238  float wgt = 1.f;
239  if (weights) {
240  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
241  }
242  __m256 vwgt = _mm256_set1_ps(wgt);
243  const float* ip = &input[idx * fused_block_size];
244  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
245  ? (dataInd + prefdist_T0)
246  : dataInd;
247  const int32_t idx_pref_T0 = indices[next_T0];
248  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
249  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
250  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
251  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
252  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
253  // skip unnecessary prefetch of (&ip_next_T0[8])
254  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
255  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
256  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
257  // skip unnecessary prefetch of (&ip_next_T0[24])
258  }
259  if (normalize_by_lengths == false) {
260  _mm256_storeu_ps(&op[0], vop0);
261  _mm256_storeu_ps(&op[8], vop8);
262  _mm256_storeu_ps(&op[16], vop16);
263  _mm256_storeu_ps(&op[24], vop24);
264  } else if (lengths[rangeIndex]) {
265  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
266  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
267  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
268  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
269  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
270  }
271  }
272  } else if (block_size == 16) {
273  // unrolling 2 times
274  int32_t dataInd = 0;
275  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
276  float* op = &out[rangeIndex * block_size];
277  __m256 vop0 = _mm256_setzero_ps();
278  __m256 vop8 = _mm256_setzero_ps();
279  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
280  ++dataInd) {
281  const int32_t idx = indices[dataInd];
282  CAFFE_ENFORCE(
283  idx >= 0 && idx < data_size,
284  "Index ",
285  dataInd,
286  " is out of bounds: ",
287  idx,
288  ", range 0 to ",
289  data_size);
290  float wgt = 1.f;
291  if (weights) {
292  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
293  }
294  __m256 vwgt = _mm256_set1_ps(wgt);
295  const float* ip = &input[idx * fused_block_size];
296  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
297  ? (dataInd + prefdist_T0)
298  : dataInd;
299  const int32_t idx_pref_T0 = indices[next_T0];
300  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
301  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
302  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
303  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
304  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
305  // skip unnecessary prefetch of (&ip_next_T0[8])
306  }
307  if (normalize_by_lengths == false) {
308  _mm256_storeu_ps(&op[0], vop0);
309  _mm256_storeu_ps(&op[8], vop8);
310  } else if (lengths[rangeIndex]) {
311  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
312  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
313  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
314  }
315  }
316  } else {
317  // generic code
318  int32_t dataInd = 0;
319  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
320  float* op = &out[rangeIndex * block_size];
321  TIndex j = 0;
322  for (; j + 8 <= block_size; j += 8) {
323  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
324  }
325  for (; j < block_size; j++) {
326  op[j] = 0.0f;
327  }
328  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
329  ++dataInd) {
330  const int32_t idx = indices[dataInd];
331  CAFFE_ENFORCE(
332  idx >= 0 && idx < data_size,
333  "Index ",
334  dataInd,
335  " is out of bounds: ",
336  idx,
337  ", range 0 to ",
338  data_size);
339  float wgt = 1.f;
340  if (weights) {
341  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
342  }
343  __m256 vwgt = _mm256_set1_ps(wgt);
344  const float* ip = &input[idx * fused_block_size];
345  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
346  ? (dataInd + prefdist_T0)
347  : dataInd;
348  const int32_t idx_pref_T0 = indices[next_T0];
349  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
350  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
351  j = 0;
352  for (; j + 8 <= block_size; j += 8) {
353  _mm256_storeu_ps(
354  &op[j],
355  _mm256_fmadd_ps(
356  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
357  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
358  }
359  for (; j < block_size; j++) {
360  op[j] += wgt * ip[j];
361  }
362  }
363  if (normalize_by_lengths && lengths[rangeIndex]) {
364  float len_inv = 1.0f / lengths[rangeIndex];
365  __m256 vlen_inv = _mm256_set1_ps(len_inv);
366  j = 0;
367  for (; j + 8 <= block_size; j += 8) {
368  _mm256_storeu_ps(
369  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
370  }
371  for (; j < block_size; j++) {
372  op[j] = len_inv * op[j];
373  }
374  }
375  }
376  }
377 }
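// A gloss on the kernel above (added for this page; the generated source itself
// carries no comments): every unrolled branch keeps one __m256 accumulator per
// 8 output floats (16 accumulators for block_size == 128), so an entire output
// row lives in registers until the final stores. prefdist_T0 == 16 means that
// while the row for indices[dataInd] is being accumulated, the row for
// indices[dataInd + 16] is prefetched with _MM_HINT_T0; when fewer than 16
// indices remain, next_T0 falls back to the current index. Since 16 floats span
// one 64-byte cache line, a prefetch is issued for every other 8-float chunk
// and the intermediate ones are skipped as redundant.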
378 void EmbeddingLookup_int32_t_float_float_false__avx2_fma(
379  const TIndex block_size,
380  const TIndex output_size,
381  const TIndex index_size,
382  const TIndex data_size,
383  const float* input,
384  const int32_t* indices,
385  const int* lengths,
386  const float* weights,
387  const float* scale_bias,
388  bool normalize_by_lengths,
389  float* out) {
390  EmbeddingLookup_int32_t_float_float__avx2_fma<false>(
391  block_size,
392  output_size,
393  index_size,
394  data_size,
395  input,
396  indices,
397  lengths,
398  weights,
399  scale_bias,
400  normalize_by_lengths,
401  out);
402 }
403 void EmbeddingLookup_int32_t_float_float_true__avx2_fma(
404  const TIndex block_size,
405  const TIndex output_size,
406  const TIndex index_size,
407  const TIndex data_size,
408  const float* input,
409  const int32_t* indices,
410  const int* lengths,
411  const float* weights,
412  const float* scale_bias,
413  bool normalize_by_lengths,
414  float* out) {
415  EmbeddingLookup_int32_t_float_float__avx2_fma<true>(
416  block_size,
417  output_size,
418  index_size,
419  data_size,
420  input,
421  indices,
422  lengths,
423  weights,
424  scale_bias,
425  normalize_by_lengths,
426  out);
427 }
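// Illustrative call of the non-positional float/int32_t wrapper above (the
// sizes and values here are invented for this example, not taken from Caffe2):
// pool two bags out of a 10 x 16 float embedding table, unweighted and
// unnormalized.
//
//   std::vector<float> table(10 * 16);            // data_size = 10, block_size = 16
//   std::vector<int32_t> indices = {0, 3, 7, 2};  // index_size = 4
//   std::vector<int> lengths = {3, 1};            // output_size = 2 (3 + 1 == 4)
//   std::vector<float> out(2 * 16);
//   EmbeddingLookup_int32_t_float_float_false__avx2_fma(
//       /*block_size=*/16, /*output_size=*/2, /*index_size=*/4, /*data_size=*/10,
//       table.data(), indices.data(), lengths.data(),
//       /*weights=*/nullptr, /*scale_bias=*/nullptr,
//       /*normalize_by_lengths=*/false, out.data());
//
// scale_bias must be nullptr for the float kernels in this file, as enforced by
// the CAFFE_ENFORCE at the top of the template.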
428 
429 template <bool IS_WEIGHT_POSITIONAL>
430 static void EmbeddingLookup_int64_t_float_float__avx2_fma(
431  const TIndex block_size,
432  const TIndex output_size,
433  const TIndex index_size,
434  const TIndex data_size,
435  const float* input,
436  const int64_t* indices,
437  const int* lengths,
438  const float* weights,
439  const float* scale_bias,
440  bool normalize_by_lengths,
441  float* out) {
442  const int64_t prefdist_T0 = 16;
443  const int64_t fused_block_size = block_size + 0;
444  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
445  if (block_size == 128) {
446  // unrolling 16 times
447  int64_t dataInd = 0;
448  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
449  float* op = &out[rangeIndex * block_size];
450  __m256 vop0 = _mm256_setzero_ps();
451  __m256 vop8 = _mm256_setzero_ps();
452  __m256 vop16 = _mm256_setzero_ps();
453  __m256 vop24 = _mm256_setzero_ps();
454  __m256 vop32 = _mm256_setzero_ps();
455  __m256 vop40 = _mm256_setzero_ps();
456  __m256 vop48 = _mm256_setzero_ps();
457  __m256 vop56 = _mm256_setzero_ps();
458  __m256 vop64 = _mm256_setzero_ps();
459  __m256 vop72 = _mm256_setzero_ps();
460  __m256 vop80 = _mm256_setzero_ps();
461  __m256 vop88 = _mm256_setzero_ps();
462  __m256 vop96 = _mm256_setzero_ps();
463  __m256 vop104 = _mm256_setzero_ps();
464  __m256 vop112 = _mm256_setzero_ps();
465  __m256 vop120 = _mm256_setzero_ps();
466  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
467  ++dataInd) {
468  const int64_t idx = indices[dataInd];
469  CAFFE_ENFORCE(
470  idx >= 0 && idx < data_size,
471  "Index ",
472  dataInd,
473  " is out of bounds: ",
474  idx,
475  ", range 0 to ",
476  data_size);
477  float wgt = 1.f;
478  if (weights) {
479  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
480  }
481  __m256 vwgt = _mm256_set1_ps(wgt);
482  const float* ip = &input[idx * fused_block_size];
483  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
484  ? (dataInd + prefdist_T0)
485  : dataInd;
486  const int64_t idx_pref_T0 = indices[next_T0];
487  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
488  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
489  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
490  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
491  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
492  // skip unnecessary prefetch of (&ip_next_T0[8])
493  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
494  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
495  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
496  // skip unnecessary prefetch of (&ip_next_T0[24])
497  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
498  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
499  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
500  // skip unnecessary prefetch of (&ip_next_T0[40])
501  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
502  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
503  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
504  // skip unnecessary prefetch of (&ip_next_T0[56])
505  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
506  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
507  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
508  // skip unnecessary prefetch of (&ip_next_T0[72])
509  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
510  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
511  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
512  // skip unnecessary prefetch of (&ip_next_T0[88])
513  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
514  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
515  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
516  // skip unnecessary prefetch of (&ip_next_T0[104])
517  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
518  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
519  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
520  // skip unnecessary prefetch of (&ip_next_T0[120])
521  }
522  if (normalize_by_lengths == false) {
523  _mm256_storeu_ps(&op[0], vop0);
524  _mm256_storeu_ps(&op[8], vop8);
525  _mm256_storeu_ps(&op[16], vop16);
526  _mm256_storeu_ps(&op[24], vop24);
527  _mm256_storeu_ps(&op[32], vop32);
528  _mm256_storeu_ps(&op[40], vop40);
529  _mm256_storeu_ps(&op[48], vop48);
530  _mm256_storeu_ps(&op[56], vop56);
531  _mm256_storeu_ps(&op[64], vop64);
532  _mm256_storeu_ps(&op[72], vop72);
533  _mm256_storeu_ps(&op[80], vop80);
534  _mm256_storeu_ps(&op[88], vop88);
535  _mm256_storeu_ps(&op[96], vop96);
536  _mm256_storeu_ps(&op[104], vop104);
537  _mm256_storeu_ps(&op[112], vop112);
538  _mm256_storeu_ps(&op[120], vop120);
539  } else if (lengths[rangeIndex]) {
540  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
541  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
542  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
543  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
544  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
545  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
546  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
547  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
548  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
549  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
550  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
551  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
552  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
553  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
554  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
555  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
556  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
557  }
558  }
559  } else if (block_size == 64) {
560  // unrolling 8 times
561  int64_t dataInd = 0;
562  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
563  float* op = &out[rangeIndex * block_size];
564  __m256 vop0 = _mm256_setzero_ps();
565  __m256 vop8 = _mm256_setzero_ps();
566  __m256 vop16 = _mm256_setzero_ps();
567  __m256 vop24 = _mm256_setzero_ps();
568  __m256 vop32 = _mm256_setzero_ps();
569  __m256 vop40 = _mm256_setzero_ps();
570  __m256 vop48 = _mm256_setzero_ps();
571  __m256 vop56 = _mm256_setzero_ps();
572  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
573  ++dataInd) {
574  const int64_t idx = indices[dataInd];
575  CAFFE_ENFORCE(
576  idx >= 0 && idx < data_size,
577  "Index ",
578  dataInd,
579  " is out of bounds: ",
580  idx,
581  ", range 0 to ",
582  data_size);
583  float wgt = 1.f;
584  if (weights) {
585  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
586  }
587  __m256 vwgt = _mm256_set1_ps(wgt);
588  const float* ip = &input[idx * fused_block_size];
589  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
590  ? (dataInd + prefdist_T0)
591  : dataInd;
592  const int64_t idx_pref_T0 = indices[next_T0];
593  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
594  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
595  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
596  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
597  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
598  // skip unnecessary prefetch of (&ip_next_T0[8])
599  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
600  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
601  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
602  // skip unnecessary prefetch of (&ip_next_T0[24])
603  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
604  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
605  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
606  // skip unnecessary prefetch of (&ip_next_T0[40])
607  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
608  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
609  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
610  // skip unnecessary prefetch of (&ip_next_T0[56])
611  }
612  if (normalize_by_lengths == false) {
613  _mm256_storeu_ps(&op[0], vop0);
614  _mm256_storeu_ps(&op[8], vop8);
615  _mm256_storeu_ps(&op[16], vop16);
616  _mm256_storeu_ps(&op[24], vop24);
617  _mm256_storeu_ps(&op[32], vop32);
618  _mm256_storeu_ps(&op[40], vop40);
619  _mm256_storeu_ps(&op[48], vop48);
620  _mm256_storeu_ps(&op[56], vop56);
621  } else if (lengths[rangeIndex]) {
622  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
623  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
624  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
625  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
626  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
627  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
628  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
629  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
630  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
631  }
632  }
633  } else if (block_size == 32) {
634  // unrolling 4 times
635  int64_t dataInd = 0;
636  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
637  float* op = &out[rangeIndex * block_size];
638  __m256 vop0 = _mm256_setzero_ps();
639  __m256 vop8 = _mm256_setzero_ps();
640  __m256 vop16 = _mm256_setzero_ps();
641  __m256 vop24 = _mm256_setzero_ps();
642  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
643  ++dataInd) {
644  const int64_t idx = indices[dataInd];
645  CAFFE_ENFORCE(
646  idx >= 0 && idx < data_size,
647  "Index ",
648  dataInd,
649  " is out of bounds: ",
650  idx,
651  ", range 0 to ",
652  data_size);
653  float wgt = 1.f;
654  if (weights) {
655  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
656  }
657  __m256 vwgt = _mm256_set1_ps(wgt);
658  const float* ip = &input[idx * fused_block_size];
659  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
660  ? (dataInd + prefdist_T0)
661  : dataInd;
662  const int64_t idx_pref_T0 = indices[next_T0];
663  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
664  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
665  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
666  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
667  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
668  // skip unnecessary prefetch of (&ip_next_T0[8])
669  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
670  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
671  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
672  // skip unnecessary prefetch of (&ip_next_T0[24])
673  }
674  if (normalize_by_lengths == false) {
675  _mm256_storeu_ps(&op[0], vop0);
676  _mm256_storeu_ps(&op[8], vop8);
677  _mm256_storeu_ps(&op[16], vop16);
678  _mm256_storeu_ps(&op[24], vop24);
679  } else if (lengths[rangeIndex]) {
680  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
681  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
682  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
683  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
684  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
685  }
686  }
687  } else if (block_size == 16) {
688  // unrolling 2 times
689  int64_t dataInd = 0;
690  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
691  float* op = &out[rangeIndex * block_size];
692  __m256 vop0 = _mm256_setzero_ps();
693  __m256 vop8 = _mm256_setzero_ps();
694  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
695  ++dataInd) {
696  const int64_t idx = indices[dataInd];
697  CAFFE_ENFORCE(
698  idx >= 0 && idx < data_size,
699  "Index ",
700  dataInd,
701  " is out of bounds: ",
702  idx,
703  ", range 0 to ",
704  data_size);
705  float wgt = 1.f;
706  if (weights) {
707  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
708  }
709  __m256 vwgt = _mm256_set1_ps(wgt);
710  const float* ip = &input[idx * fused_block_size];
711  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
712  ? (dataInd + prefdist_T0)
713  : dataInd;
714  const int64_t idx_pref_T0 = indices[next_T0];
715  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
716  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
717  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
718  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
719  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
720  // skip unnecessary prefetch of (&ip_next_T0[8])
721  }
722  if (normalize_by_lengths == false) {
723  _mm256_storeu_ps(&op[0], vop0);
724  _mm256_storeu_ps(&op[8], vop8);
725  } else if (lengths[rangeIndex]) {
726  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
727  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
728  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
729  }
730  }
731  } else {
732  // generic code
733  int64_t dataInd = 0;
734  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
735  float* op = &out[rangeIndex * block_size];
736  TIndex j = 0;
737  for (; j + 8 <= block_size; j += 8) {
738  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
739  }
740  for (; j < block_size; j++) {
741  op[j] = 0.0f;
742  }
743  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
744  ++dataInd) {
745  const int64_t idx = indices[dataInd];
746  CAFFE_ENFORCE(
747  idx >= 0 && idx < data_size,
748  "Index ",
749  dataInd,
750  " is out of bounds: ",
751  idx,
752  ", range 0 to ",
753  data_size);
754  float wgt = 1.f;
755  if (weights) {
756  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
757  }
758  __m256 vwgt = _mm256_set1_ps(wgt);
759  const float* ip = &input[idx * fused_block_size];
760  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
761  ? (dataInd + prefdist_T0)
762  : dataInd;
763  const int64_t idx_pref_T0 = indices[next_T0];
764  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
765  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
766  j = 0;
767  for (; j + 8 <= block_size; j += 8) {
768  _mm256_storeu_ps(
769  &op[j],
770  _mm256_fmadd_ps(
771  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
772  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
773  }
774  for (; j < block_size; j++) {
775  op[j] += wgt * ip[j];
776  }
777  }
778  if (normalize_by_lengths && lengths[rangeIndex]) {
779  float len_inv = 1.0f / lengths[rangeIndex];
780  __m256 vlen_inv = _mm256_set1_ps(len_inv);
781  j = 0;
782  for (; j + 8 <= block_size; j += 8) {
783  _mm256_storeu_ps(
784  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
785  }
786  for (; j < block_size; j++) {
787  op[j] = len_inv * op[j];
788  }
789  }
790  }
791  }
792 }
793 void EmbeddingLookup_int64_t_float_float_false__avx2_fma(
794  const TIndex block_size,
795  const TIndex output_size,
796  const TIndex index_size,
797  const TIndex data_size,
798  const float* input,
799  const int64_t* indices,
800  const int* lengths,
801  const float* weights,
802  const float* scale_bias,
803  bool normalize_by_lengths,
804  float* out) {
805  EmbeddingLookup_int64_t_float_float__avx2_fma<false>(
806  block_size,
807  output_size,
808  index_size,
809  data_size,
810  input,
811  indices,
812  lengths,
813  weights,
814  scale_bias,
815  normalize_by_lengths,
816  out);
817 }
818 void EmbeddingLookup_int64_t_float_float_true__avx2_fma(
819  const TIndex block_size,
820  const TIndex output_size,
821  const TIndex index_size,
822  const TIndex data_size,
823  const float* input,
824  const int64_t* indices,
825  const int* lengths,
826  const float* weights,
827  const float* scale_bias,
828  bool normalize_by_lengths,
829  float* out) {
830  EmbeddingLookup_int64_t_float_float__avx2_fma<true>(
831  block_size,
832  output_size,
833  index_size,
834  data_size,
835  input,
836  indices,
837  lengths,
838  weights,
839  scale_bias,
840  normalize_by_lengths,
841  out);
842 }
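// The int64_t/float kernels above are generated from the same template as the
// int32_t ones: apart from the 64-bit index and loop-counter types, the block
// size dispatch, unrolling, prefetching and length normalization are identical.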
843 
844 template <bool IS_WEIGHT_POSITIONAL>
845 static void EmbeddingLookup_int32_t_float16_float__avx2_fma(
846  const TIndex block_size,
847  const TIndex output_size,
848  const TIndex index_size,
849  const TIndex data_size,
850  const float16* input,
851  const int32_t* indices,
852  const int* lengths,
853  const float* weights,
854  const float* scale_bias,
855  bool normalize_by_lengths,
856  float* out) {
857  const int32_t prefdist_T0 = 16;
858  const int32_t fused_block_size = block_size + 0;
859  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
860  if (block_size == 128) {
861  // unrolling 16 times
862  int32_t dataInd = 0;
863  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
864  float* op = &out[rangeIndex * block_size];
865  __m256 vop0 = _mm256_setzero_ps();
866  __m256 vop8 = _mm256_setzero_ps();
867  __m256 vop16 = _mm256_setzero_ps();
868  __m256 vop24 = _mm256_setzero_ps();
869  __m256 vop32 = _mm256_setzero_ps();
870  __m256 vop40 = _mm256_setzero_ps();
871  __m256 vop48 = _mm256_setzero_ps();
872  __m256 vop56 = _mm256_setzero_ps();
873  __m256 vop64 = _mm256_setzero_ps();
874  __m256 vop72 = _mm256_setzero_ps();
875  __m256 vop80 = _mm256_setzero_ps();
876  __m256 vop88 = _mm256_setzero_ps();
877  __m256 vop96 = _mm256_setzero_ps();
878  __m256 vop104 = _mm256_setzero_ps();
879  __m256 vop112 = _mm256_setzero_ps();
880  __m256 vop120 = _mm256_setzero_ps();
881  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
882  ++dataInd) {
883  const int32_t idx = indices[dataInd];
884  CAFFE_ENFORCE(
885  idx >= 0 && idx < data_size,
886  "Index ",
887  dataInd,
888  " is out of bounds: ",
889  idx,
890  ", range 0 to ",
891  data_size);
892  float wgt = 1.f;
893  if (weights) {
894  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
895  }
896  __m256 vwgt = _mm256_set1_ps(wgt);
897  const float16* ip = &input[idx * fused_block_size];
898  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
899  ? (dataInd + prefdist_T0)
900  : dataInd;
901  const int32_t idx_pref_T0 = indices[next_T0];
902  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
903  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
904  vop0 = _mm256_fmadd_ps(
905  vwgt,
906  _mm256_cvtph_ps(
907  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
908  vop0);
909  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
910  vop8 = _mm256_fmadd_ps(
911  vwgt,
912  _mm256_cvtph_ps(
913  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
914  vop8);
915  // skip unnecessary prefetch of (&ip_next_T0[8])
916  vop16 = _mm256_fmadd_ps(
917  vwgt,
918  _mm256_cvtph_ps(
919  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
920  vop16);
921  // skip unnecessary prefetch of (&ip_next_T0[16])
922  vop24 = _mm256_fmadd_ps(
923  vwgt,
924  _mm256_cvtph_ps(
925  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
926  vop24);
927  // skip unnecessary prefetch of (&ip_next_T0[24])
928  vop32 = _mm256_fmadd_ps(
929  vwgt,
930  _mm256_cvtph_ps(
931  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
932  vop32);
933  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
934  vop40 = _mm256_fmadd_ps(
935  vwgt,
936  _mm256_cvtph_ps(
937  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
938  vop40);
939  // skip unnecessary prefetch of (&ip_next_T0[40])
940  vop48 = _mm256_fmadd_ps(
941  vwgt,
942  _mm256_cvtph_ps(
943  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
944  vop48);
945  // skip unnecessary prefetch of (&ip_next_T0[48])
946  vop56 = _mm256_fmadd_ps(
947  vwgt,
948  _mm256_cvtph_ps(
949  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
950  vop56);
951  // skip unnecessary prefetch of (&ip_next_T0[56])
952  vop64 = _mm256_fmadd_ps(
953  vwgt,
954  _mm256_cvtph_ps(
955  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
956  vop64);
957  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
958  vop72 = _mm256_fmadd_ps(
959  vwgt,
960  _mm256_cvtph_ps(
961  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
962  vop72);
963  // skip unnecessary prefetch of (&ip_next_T0[72])
964  vop80 = _mm256_fmadd_ps(
965  vwgt,
966  _mm256_cvtph_ps(
967  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
968  vop80);
969  // skip unnecessary prefetch of (&ip_next_T0[80])
970  vop88 = _mm256_fmadd_ps(
971  vwgt,
972  _mm256_cvtph_ps(
973  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
974  vop88);
975  // skip unnecessary prefetch of (&ip_next_T0[88])
976  vop96 = _mm256_fmadd_ps(
977  vwgt,
978  _mm256_cvtph_ps(
979  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
980  vop96);
981  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
982  vop104 = _mm256_fmadd_ps(
983  vwgt,
984  _mm256_cvtph_ps(
985  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
986  vop104);
987  // skip unnecessary prefetch of (&ip_next_T0[104])
988  vop112 = _mm256_fmadd_ps(
989  vwgt,
990  _mm256_cvtph_ps(
991  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
992  vop112);
993  // skip unnecessary prefetch of (&ip_next_T0[112])
994  vop120 = _mm256_fmadd_ps(
995  vwgt,
996  _mm256_cvtph_ps(
997  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
998  vop120);
999  // skip unnecessary prefetch of (&ip_next_T0[120])
1000  }
1001  if (normalize_by_lengths == false) {
1002  _mm256_storeu_ps(&op[0], vop0);
1003  _mm256_storeu_ps(&op[8], vop8);
1004  _mm256_storeu_ps(&op[16], vop16);
1005  _mm256_storeu_ps(&op[24], vop24);
1006  _mm256_storeu_ps(&op[32], vop32);
1007  _mm256_storeu_ps(&op[40], vop40);
1008  _mm256_storeu_ps(&op[48], vop48);
1009  _mm256_storeu_ps(&op[56], vop56);
1010  _mm256_storeu_ps(&op[64], vop64);
1011  _mm256_storeu_ps(&op[72], vop72);
1012  _mm256_storeu_ps(&op[80], vop80);
1013  _mm256_storeu_ps(&op[88], vop88);
1014  _mm256_storeu_ps(&op[96], vop96);
1015  _mm256_storeu_ps(&op[104], vop104);
1016  _mm256_storeu_ps(&op[112], vop112);
1017  _mm256_storeu_ps(&op[120], vop120);
1018  } else if (lengths[rangeIndex]) {
1019  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1020  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1021  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1022  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1023  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1024  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1025  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1026  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1027  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1028  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1029  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1030  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1031  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1032  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1033  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1034  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1035  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1036  }
1037  }
1038  } else if (block_size == 64) {
1039  // unrolling 8 times
1040  int32_t dataInd = 0;
1041  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1042  float* op = &out[rangeIndex * block_size];
1043  __m256 vop0 = _mm256_setzero_ps();
1044  __m256 vop8 = _mm256_setzero_ps();
1045  __m256 vop16 = _mm256_setzero_ps();
1046  __m256 vop24 = _mm256_setzero_ps();
1047  __m256 vop32 = _mm256_setzero_ps();
1048  __m256 vop40 = _mm256_setzero_ps();
1049  __m256 vop48 = _mm256_setzero_ps();
1050  __m256 vop56 = _mm256_setzero_ps();
1051  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1052  ++dataInd) {
1053  const int32_t idx = indices[dataInd];
1054  CAFFE_ENFORCE(
1055  idx >= 0 && idx < data_size,
1056  "Index ",
1057  dataInd,
1058  " is out of bounds: ",
1059  idx,
1060  ", range 0 to ",
1061  data_size);
1062  float wgt = 1.f;
1063  if (weights) {
1064  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1065  }
1066  __m256 vwgt = _mm256_set1_ps(wgt);
1067  const float16* ip = &input[idx * fused_block_size];
1068  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1069  ? (dataInd + prefdist_T0)
1070  : dataInd;
1071  const int32_t idx_pref_T0 = indices[next_T0];
1072  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1073  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1074  vop0 = _mm256_fmadd_ps(
1075  vwgt,
1076  _mm256_cvtph_ps(
1077  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1078  vop0);
1079  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1080  vop8 = _mm256_fmadd_ps(
1081  vwgt,
1082  _mm256_cvtph_ps(
1083  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1084  vop8);
1085  // skip unnecessary prefetch of (&ip_next_T0[8])
1086  vop16 = _mm256_fmadd_ps(
1087  vwgt,
1088  _mm256_cvtph_ps(
1089  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1090  vop16);
1091  // skip unnecessary prefetch of (&ip_next_T0[16])
1092  vop24 = _mm256_fmadd_ps(
1093  vwgt,
1094  _mm256_cvtph_ps(
1095  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1096  vop24);
1097  // skip unnecessary prefetch of (&ip_next_T0[24])
1098  vop32 = _mm256_fmadd_ps(
1099  vwgt,
1100  _mm256_cvtph_ps(
1101  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1102  vop32);
1103  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1104  vop40 = _mm256_fmadd_ps(
1105  vwgt,
1106  _mm256_cvtph_ps(
1107  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1108  vop40);
1109  // skip unnecessary prefetch of (&ip_next_T0[40])
1110  vop48 = _mm256_fmadd_ps(
1111  vwgt,
1112  _mm256_cvtph_ps(
1113  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1114  vop48);
1115  // skip unnecessary prefetch of (&ip_next_T0[48])
1116  vop56 = _mm256_fmadd_ps(
1117  vwgt,
1118  _mm256_cvtph_ps(
1119  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1120  vop56);
1121  // skip unnecessary prefetch of (&ip_next_T0[56])
1122  }
1123  if (normalize_by_lengths == false) {
1124  _mm256_storeu_ps(&op[0], vop0);
1125  _mm256_storeu_ps(&op[8], vop8);
1126  _mm256_storeu_ps(&op[16], vop16);
1127  _mm256_storeu_ps(&op[24], vop24);
1128  _mm256_storeu_ps(&op[32], vop32);
1129  _mm256_storeu_ps(&op[40], vop40);
1130  _mm256_storeu_ps(&op[48], vop48);
1131  _mm256_storeu_ps(&op[56], vop56);
1132  } else if (lengths[rangeIndex]) {
1133  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1134  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1135  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1136  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1137  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1138  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1139  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1140  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1141  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1142  }
1143  }
1144  } else if (block_size == 32) {
1145  // unrolling 4 times
1146  int32_t dataInd = 0;
1147  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1148  float* op = &out[rangeIndex * block_size];
1149  __m256 vop0 = _mm256_setzero_ps();
1150  __m256 vop8 = _mm256_setzero_ps();
1151  __m256 vop16 = _mm256_setzero_ps();
1152  __m256 vop24 = _mm256_setzero_ps();
1153  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1154  ++dataInd) {
1155  const int32_t idx = indices[dataInd];
1156  CAFFE_ENFORCE(
1157  idx >= 0 && idx < data_size,
1158  "Index ",
1159  dataInd,
1160  " is out of bounds: ",
1161  idx,
1162  ", range 0 to ",
1163  data_size);
1164  float wgt = 1.f;
1165  if (weights) {
1166  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1167  }
1168  __m256 vwgt = _mm256_set1_ps(wgt);
1169  const float16* ip = &input[idx * fused_block_size];
1170  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1171  ? (dataInd + prefdist_T0)
1172  : dataInd;
1173  const int32_t idx_pref_T0 = indices[next_T0];
1174  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1175  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1176  vop0 = _mm256_fmadd_ps(
1177  vwgt,
1178  _mm256_cvtph_ps(
1179  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1180  vop0);
1181  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1182  vop8 = _mm256_fmadd_ps(
1183  vwgt,
1184  _mm256_cvtph_ps(
1185  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1186  vop8);
1187  // skip unnecessary prefetch of (&ip_next_T0[8])
1188  vop16 = _mm256_fmadd_ps(
1189  vwgt,
1190  _mm256_cvtph_ps(
1191  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1192  vop16);
1193  // skip unnecessary prefetch of (&ip_next_T0[16])
1194  vop24 = _mm256_fmadd_ps(
1195  vwgt,
1196  _mm256_cvtph_ps(
1197  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1198  vop24);
1199  // skip unnecessary prefetch of (&ip_next_T0[24])
1200  }
1201  if (normalize_by_lengths == false) {
1202  _mm256_storeu_ps(&op[0], vop0);
1203  _mm256_storeu_ps(&op[8], vop8);
1204  _mm256_storeu_ps(&op[16], vop16);
1205  _mm256_storeu_ps(&op[24], vop24);
1206  } else if (lengths[rangeIndex]) {
1207  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1208  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1209  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1210  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1211  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1212  }
1213  }
1214  } else if (block_size == 16) {
1215  // unrolling 2 times
1216  int32_t dataInd = 0;
1217  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1218  float* op = &out[rangeIndex * block_size];
1219  __m256 vop0 = _mm256_setzero_ps();
1220  __m256 vop8 = _mm256_setzero_ps();
1221  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1222  ++dataInd) {
1223  const int32_t idx = indices[dataInd];
1224  CAFFE_ENFORCE(
1225  idx >= 0 && idx < data_size,
1226  "Index ",
1227  dataInd,
1228  " is out of bounds: ",
1229  idx,
1230  ", range 0 to ",
1231  data_size);
1232  float wgt = 1.f;
1233  if (weights) {
1234  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1235  }
1236  __m256 vwgt = _mm256_set1_ps(wgt);
1237  const float16* ip = &input[idx * fused_block_size];
1238  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1239  ? (dataInd + prefdist_T0)
1240  : dataInd;
1241  const int32_t idx_pref_T0 = indices[next_T0];
1242  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1243  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1244  vop0 = _mm256_fmadd_ps(
1245  vwgt,
1246  _mm256_cvtph_ps(
1247  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1248  vop0);
1249  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1250  vop8 = _mm256_fmadd_ps(
1251  vwgt,
1252  _mm256_cvtph_ps(
1253  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1254  vop8);
1255  // skip unnecessary prefetch of (&ip_next_T0[8])
1256  }
1257  if (normalize_by_lengths == false) {
1258  _mm256_storeu_ps(&op[0], vop0);
1259  _mm256_storeu_ps(&op[8], vop8);
1260  } else if (lengths[rangeIndex]) {
1261  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1262  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1263  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1264  }
1265  }
1266  } else {
1267  // generic code
1268  int32_t dataInd = 0;
1269  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1270  float* op = &out[rangeIndex * block_size];
1271  TIndex j = 0;
1272  for (; j + 8 <= block_size; j += 8) {
1273  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1274  }
1275  for (; j < block_size; j++) {
1276  op[j] = 0.0f;
1277  }
1278  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1279  ++dataInd) {
1280  const int32_t idx = indices[dataInd];
1281  CAFFE_ENFORCE(
1282  idx >= 0 && idx < data_size,
1283  "Index ",
1284  dataInd,
1285  " is out of bounds: ",
1286  idx,
1287  ", range 0 to ",
1288  data_size);
1289  float wgt = 1.f;
1290  if (weights) {
1291  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1292  }
1293  __m256 vwgt = _mm256_set1_ps(wgt);
1294  const float16* ip = &input[idx * fused_block_size];
1295  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1296  ? (dataInd + prefdist_T0)
1297  : dataInd;
1298  const int32_t idx_pref_T0 = indices[next_T0];
1299  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1300  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1301  j = 0;
1302  for (; j + 8 <= block_size; j += 8) {
1303  _mm256_storeu_ps(
1304  &op[j],
1305  _mm256_fmadd_ps(
1306  vwgt,
1307  _mm256_cvtph_ps(_mm_loadu_si128(
1308  reinterpret_cast<const __m128i*>(&ip[j]))),
1309  _mm256_loadu_ps(&op[j])));
1310  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1311  }
1312  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1313  for (; j < block_size; j++) {
1314  vtmp1[0] = ip[j];
1315  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1316  op[j] += wgt * ((float*)(&vtmp2))[0];
1317  }
1318  }
1319  if (normalize_by_lengths && lengths[rangeIndex]) {
1320  float len_inv = 1.0f / lengths[rangeIndex];
1321  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1322  j = 0;
1323  for (; j + 8 <= block_size; j += 8) {
1324  _mm256_storeu_ps(
1325  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1326  }
1327  for (; j < block_size; j++) {
1328  op[j] = len_inv * op[j];
1329  }
1330  }
1331  }
1332  }
1333 }
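// In the float16 kernels the embedding rows are stored as 16-bit halves and
// widened to single precision with _mm256_cvtph_ps (F16C) right before each
// FMA. A half takes 2 bytes, so 32 consecutive elements span one 64-byte cache
// line and prefetches are issued only at offsets 0, 32, 64 and 96. The scalar
// tail of the generic path widens one element at a time through the 8-element
// vtmp1 buffer, using lane 0 of the converted vector.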
1334 void EmbeddingLookup_int32_t_float16_float_false__avx2_fma(
1335  const TIndex block_size,
1336  const TIndex output_size,
1337  const TIndex index_size,
1338  const TIndex data_size,
1339  const float16* input,
1340  const int32_t* indices,
1341  const int* lengths,
1342  const float* weights,
1343  const float* scale_bias,
1344  bool normalize_by_lengths,
1345  float* out) {
1346  EmbeddingLookup_int32_t_float16_float__avx2_fma<false>(
1347  block_size,
1348  output_size,
1349  index_size,
1350  data_size,
1351  input,
1352  indices,
1353  lengths,
1354  weights,
1355  scale_bias,
1356  normalize_by_lengths,
1357  out);
1358 }
1359 void EmbeddingLookup_int32_t_float16_float_true__avx2_fma(
1360  const TIndex block_size,
1361  const TIndex output_size,
1362  const TIndex index_size,
1363  const TIndex data_size,
1364  const float16* input,
1365  const int32_t* indices,
1366  const int* lengths,
1367  const float* weights,
1368  const float* scale_bias,
1369  bool normalize_by_lengths,
1370  float* out) {
1371  EmbeddingLookup_int32_t_float16_float__avx2_fma<true>(
1372  block_size,
1373  output_size,
1374  index_size,
1375  data_size,
1376  input,
1377  indices,
1378  lengths,
1379  weights,
1380  scale_bias,
1381  normalize_by_lengths,
1382  out);
1383 }
1384 
1385 template <bool IS_WEIGHT_POSITIONAL>
1386 static void EmbeddingLookup_int64_t_float16_float__avx2_fma(
1387  const TIndex block_size,
1388  const TIndex output_size,
1389  const TIndex index_size,
1390  const TIndex data_size,
1391  const float16* input,
1392  const int64_t* indices,
1393  const int* lengths,
1394  const float* weights,
1395  const float* scale_bias,
1396  bool normalize_by_lengths,
1397  float* out) {
1398  const int64_t prefdist_T0 = 16;
1399  const int64_t fused_block_size = block_size + 0;
1400  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
1401  if (block_size == 128) {
1402  // unrolling 16 times
1403  int64_t dataInd = 0;
1404  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1405  float* op = &out[rangeIndex * block_size];
1406  __m256 vop0 = _mm256_setzero_ps();
1407  __m256 vop8 = _mm256_setzero_ps();
1408  __m256 vop16 = _mm256_setzero_ps();
1409  __m256 vop24 = _mm256_setzero_ps();
1410  __m256 vop32 = _mm256_setzero_ps();
1411  __m256 vop40 = _mm256_setzero_ps();
1412  __m256 vop48 = _mm256_setzero_ps();
1413  __m256 vop56 = _mm256_setzero_ps();
1414  __m256 vop64 = _mm256_setzero_ps();
1415  __m256 vop72 = _mm256_setzero_ps();
1416  __m256 vop80 = _mm256_setzero_ps();
1417  __m256 vop88 = _mm256_setzero_ps();
1418  __m256 vop96 = _mm256_setzero_ps();
1419  __m256 vop104 = _mm256_setzero_ps();
1420  __m256 vop112 = _mm256_setzero_ps();
1421  __m256 vop120 = _mm256_setzero_ps();
1422  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1423  ++dataInd) {
1424  const int64_t idx = indices[dataInd];
1425  CAFFE_ENFORCE(
1426  idx >= 0 && idx < data_size,
1427  "Index ",
1428  dataInd,
1429  " is out of bounds: ",
1430  idx,
1431  ", range 0 to ",
1432  data_size);
1433  float wgt = 1.f;
1434  if (weights) {
1435  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1436  }
1437  __m256 vwgt = _mm256_set1_ps(wgt);
1438  const float16* ip = &input[idx * fused_block_size];
1439  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1440  ? (dataInd + prefdist_T0)
1441  : dataInd;
1442  const int64_t idx_pref_T0 = indices[next_T0];
1443  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1444  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1445  vop0 = _mm256_fmadd_ps(
1446  vwgt,
1447  _mm256_cvtph_ps(
1448  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1449  vop0);
1450  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1451  vop8 = _mm256_fmadd_ps(
1452  vwgt,
1453  _mm256_cvtph_ps(
1454  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1455  vop8);
1456  // skip unnecessary prefetch of (&ip_next_T0[8])
1457  vop16 = _mm256_fmadd_ps(
1458  vwgt,
1459  _mm256_cvtph_ps(
1460  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1461  vop16);
1462  // skip unnecessary prefetch of (&ip_next_T0[16])
1463  vop24 = _mm256_fmadd_ps(
1464  vwgt,
1465  _mm256_cvtph_ps(
1466  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1467  vop24);
1468  // skip unnecessary prefetch of (&ip_next_T0[24])
1469  vop32 = _mm256_fmadd_ps(
1470  vwgt,
1471  _mm256_cvtph_ps(
1472  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1473  vop32);
1474  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1475  vop40 = _mm256_fmadd_ps(
1476  vwgt,
1477  _mm256_cvtph_ps(
1478  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1479  vop40);
1480  // skip unnecessary prefetch of (&ip_next_T0[40])
1481  vop48 = _mm256_fmadd_ps(
1482  vwgt,
1483  _mm256_cvtph_ps(
1484  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1485  vop48);
1486  // skip unnecessary prefetch of (&ip_next_T0[48])
1487  vop56 = _mm256_fmadd_ps(
1488  vwgt,
1489  _mm256_cvtph_ps(
1490  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1491  vop56);
1492  // skip unnecessary prefetch of (&ip_next_T0[56])
1493  vop64 = _mm256_fmadd_ps(
1494  vwgt,
1495  _mm256_cvtph_ps(
1496  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1497  vop64);
1498  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1499  vop72 = _mm256_fmadd_ps(
1500  vwgt,
1501  _mm256_cvtph_ps(
1502  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1503  vop72);
1504  // skip unnecessary prefetch of (&ip_next_T0[72])
1505  vop80 = _mm256_fmadd_ps(
1506  vwgt,
1507  _mm256_cvtph_ps(
1508  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1509  vop80);
1510  // skip unnecessary prefetch of (&ip_next_T0[80])
1511  vop88 = _mm256_fmadd_ps(
1512  vwgt,
1513  _mm256_cvtph_ps(
1514  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1515  vop88);
1516  // skip unnecessary prefetch of (&ip_next_T0[88])
1517  vop96 = _mm256_fmadd_ps(
1518  vwgt,
1519  _mm256_cvtph_ps(
1520  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1521  vop96);
1522  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
1523  vop104 = _mm256_fmadd_ps(
1524  vwgt,
1525  _mm256_cvtph_ps(
1526  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1527  vop104);
1528  // skip unnecessary prefetch of (&ip_next_T0[104])
1529  vop112 = _mm256_fmadd_ps(
1530  vwgt,
1531  _mm256_cvtph_ps(
1532  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1533  vop112);
1534  // skip unnecessary prefetch of (&ip_next_T0[112])
1535  vop120 = _mm256_fmadd_ps(
1536  vwgt,
1537  _mm256_cvtph_ps(
1538  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1539  vop120);
1540  // skip unnecessary prefetch of (&ip_next_T0[120])
1541  }
1542  if (normalize_by_lengths == false) {
1543  _mm256_storeu_ps(&op[0], vop0);
1544  _mm256_storeu_ps(&op[8], vop8);
1545  _mm256_storeu_ps(&op[16], vop16);
1546  _mm256_storeu_ps(&op[24], vop24);
1547  _mm256_storeu_ps(&op[32], vop32);
1548  _mm256_storeu_ps(&op[40], vop40);
1549  _mm256_storeu_ps(&op[48], vop48);
1550  _mm256_storeu_ps(&op[56], vop56);
1551  _mm256_storeu_ps(&op[64], vop64);
1552  _mm256_storeu_ps(&op[72], vop72);
1553  _mm256_storeu_ps(&op[80], vop80);
1554  _mm256_storeu_ps(&op[88], vop88);
1555  _mm256_storeu_ps(&op[96], vop96);
1556  _mm256_storeu_ps(&op[104], vop104);
1557  _mm256_storeu_ps(&op[112], vop112);
1558  _mm256_storeu_ps(&op[120], vop120);
1559  } else if (lengths[rangeIndex]) {
1560  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1561  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1562  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1563  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1564  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1565  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1566  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1567  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1568  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1569  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1570  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1571  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1572  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1573  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1574  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1575  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1576  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1577  }
1578  }
1579  } else if (block_size == 64) {
1580  // unrolling 8 times
1581  int64_t dataInd = 0;
1582  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1583  float* op = &out[rangeIndex * block_size];
1584  __m256 vop0 = _mm256_setzero_ps();
1585  __m256 vop8 = _mm256_setzero_ps();
1586  __m256 vop16 = _mm256_setzero_ps();
1587  __m256 vop24 = _mm256_setzero_ps();
1588  __m256 vop32 = _mm256_setzero_ps();
1589  __m256 vop40 = _mm256_setzero_ps();
1590  __m256 vop48 = _mm256_setzero_ps();
1591  __m256 vop56 = _mm256_setzero_ps();
1592  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1593  ++dataInd) {
1594  const int64_t idx = indices[dataInd];
1595  CAFFE_ENFORCE(
1596  idx >= 0 && idx < data_size,
1597  "Index ",
1598  dataInd,
1599  " is out of bounds: ",
1600  idx,
1601  ", range 0 to ",
1602  data_size);
1603  float wgt = 1.f;
1604  if (weights) {
1605  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1606  }
1607  __m256 vwgt = _mm256_set1_ps(wgt);
1608  const float16* ip = &input[idx * fused_block_size];
1609  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1610  ? (dataInd + prefdist_T0)
1611  : dataInd;
1612  const int64_t idx_pref_T0 = indices[next_T0];
1613  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1614  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1615  vop0 = _mm256_fmadd_ps(
1616  vwgt,
1617  _mm256_cvtph_ps(
1618  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1619  vop0);
1620  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1621  vop8 = _mm256_fmadd_ps(
1622  vwgt,
1623  _mm256_cvtph_ps(
1624  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1625  vop8);
1626  // skip unnecessary prefetch of (&ip_next_T0[8])
1627  vop16 = _mm256_fmadd_ps(
1628  vwgt,
1629  _mm256_cvtph_ps(
1630  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1631  vop16);
1632  // skip unnecessary prefetch of (&ip_next_T0[16])
1633  vop24 = _mm256_fmadd_ps(
1634  vwgt,
1635  _mm256_cvtph_ps(
1636  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1637  vop24);
1638  // skip unnecessary prefetch of (&ip_next_T0[24])
1639  vop32 = _mm256_fmadd_ps(
1640  vwgt,
1641  _mm256_cvtph_ps(
1642  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1643  vop32);
1644  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1645  vop40 = _mm256_fmadd_ps(
1646  vwgt,
1647  _mm256_cvtph_ps(
1648  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1649  vop40);
1650  // skip unnecessary prefetch of (&ip_next_T0[40])
1651  vop48 = _mm256_fmadd_ps(
1652  vwgt,
1653  _mm256_cvtph_ps(
1654  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1655  vop48);
1656  // skip unnecessary prefetch of (&ip_next_T0[48])
1657  vop56 = _mm256_fmadd_ps(
1658  vwgt,
1659  _mm256_cvtph_ps(
1660  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1661  vop56);
1662  // skip unnecessary prefetch of (&ip_next_T0[56])
1663  }
1664  if (normalize_by_lengths == false) {
1665  _mm256_storeu_ps(&op[0], vop0);
1666  _mm256_storeu_ps(&op[8], vop8);
1667  _mm256_storeu_ps(&op[16], vop16);
1668  _mm256_storeu_ps(&op[24], vop24);
1669  _mm256_storeu_ps(&op[32], vop32);
1670  _mm256_storeu_ps(&op[40], vop40);
1671  _mm256_storeu_ps(&op[48], vop48);
1672  _mm256_storeu_ps(&op[56], vop56);
1673  } else if (lengths[rangeIndex]) {
1674  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1675  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1676  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1677  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1678  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1679  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1680  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1681  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1682  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1683  }
1684  }
1685  } else if (block_size == 32) {
1686  // unrolling 4 times
1687  int64_t dataInd = 0;
1688  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1689  float* op = &out[rangeIndex * block_size];
1690  __m256 vop0 = _mm256_setzero_ps();
1691  __m256 vop8 = _mm256_setzero_ps();
1692  __m256 vop16 = _mm256_setzero_ps();
1693  __m256 vop24 = _mm256_setzero_ps();
1694  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1695  ++dataInd) {
1696  const int64_t idx = indices[dataInd];
1697  CAFFE_ENFORCE(
1698  idx >= 0 && idx < data_size,
1699  "Index ",
1700  dataInd,
1701  " is out of bounds: ",
1702  idx,
1703  ", range 0 to ",
1704  data_size);
1705  float wgt = 1.f;
1706  if (weights) {
1707  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1708  }
1709  __m256 vwgt = _mm256_set1_ps(wgt);
1710  const float16* ip = &input[idx * fused_block_size];
1711  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1712  ? (dataInd + prefdist_T0)
1713  : dataInd;
1714  const int64_t idx_pref_T0 = indices[next_T0];
1715  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1716  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1717  vop0 = _mm256_fmadd_ps(
1718  vwgt,
1719  _mm256_cvtph_ps(
1720  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1721  vop0);
1722  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1723  vop8 = _mm256_fmadd_ps(
1724  vwgt,
1725  _mm256_cvtph_ps(
1726  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1727  vop8);
1728  // skip unnecessary prefetch of (&ip_next_T0[8])
1729  vop16 = _mm256_fmadd_ps(
1730  vwgt,
1731  _mm256_cvtph_ps(
1732  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1733  vop16);
1734  // skip unnecessary prefetch of (&ip_next_T0[16])
1735  vop24 = _mm256_fmadd_ps(
1736  vwgt,
1737  _mm256_cvtph_ps(
1738  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1739  vop24);
1740  // skip unnecessary prefetch of (&ip_next_T0[24])
1741  }
1742  if (normalize_by_lengths == false) {
1743  _mm256_storeu_ps(&op[0], vop0);
1744  _mm256_storeu_ps(&op[8], vop8);
1745  _mm256_storeu_ps(&op[16], vop16);
1746  _mm256_storeu_ps(&op[24], vop24);
1747  } else if (lengths[rangeIndex]) {
1748  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1749  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1750  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1751  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1752  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1753  }
1754  }
1755  } else if (block_size == 16) {
1756  // unrolling 2 times
1757  int64_t dataInd = 0;
1758  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1759  float* op = &out[rangeIndex * block_size];
1760  __m256 vop0 = _mm256_setzero_ps();
1761  __m256 vop8 = _mm256_setzero_ps();
1762  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1763  ++dataInd) {
1764  const int64_t idx = indices[dataInd];
1765  CAFFE_ENFORCE(
1766  idx >= 0 && idx < data_size,
1767  "Index ",
1768  dataInd,
1769  " is out of bounds: ",
1770  idx,
1771  ", range 0 to ",
1772  data_size);
1773  float wgt = 1.f;
1774  if (weights) {
1775  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1776  }
1777  __m256 vwgt = _mm256_set1_ps(wgt);
1778  const float16* ip = &input[idx * fused_block_size];
1779  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1780  ? (dataInd + prefdist_T0)
1781  : dataInd;
1782  const int64_t idx_pref_T0 = indices[next_T0];
1783  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1784  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1785  vop0 = _mm256_fmadd_ps(
1786  vwgt,
1787  _mm256_cvtph_ps(
1788  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1789  vop0);
1790  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1791  vop8 = _mm256_fmadd_ps(
1792  vwgt,
1793  _mm256_cvtph_ps(
1794  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1795  vop8);
1796  // skip unnecessary prefetch of (&ip_next_T0[8])
1797  }
1798  if (normalize_by_lengths == false) {
1799  _mm256_storeu_ps(&op[0], vop0);
1800  _mm256_storeu_ps(&op[8], vop8);
1801  } else if (lengths[rangeIndex]) {
1802  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1803  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1804  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1805  }
1806  }
1807  } else {
1808  // generic code
1809  int64_t dataInd = 0;
1810  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1811  float* op = &out[rangeIndex * block_size];
1812  TIndex j = 0;
1813  for (; j + 8 <= block_size; j += 8) {
1814  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1815  }
1816  for (; j < block_size; j++) {
1817  op[j] = 0.0f;
1818  }
1819  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1820  ++dataInd) {
1821  const int64_t idx = indices[dataInd];
1822  CAFFE_ENFORCE(
1823  idx >= 0 && idx < data_size,
1824  "Index ",
1825  dataInd,
1826  " is out of bounds: ",
1827  idx,
1828  ", range 0 to ",
1829  data_size);
1830  float wgt = 1.f;
1831  if (weights) {
1832  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1833  }
1834  __m256 vwgt = _mm256_set1_ps(wgt);
1835  const float16* ip = &input[idx * fused_block_size];
1836  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1837  ? (dataInd + prefdist_T0)
1838  : dataInd;
1839  const int64_t idx_pref_T0 = indices[next_T0];
1840  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1841  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1842  j = 0;
1843  for (; j + 8 <= block_size; j += 8) {
1844  _mm256_storeu_ps(
1845  &op[j],
1846  _mm256_fmadd_ps(
1847  vwgt,
1848  _mm256_cvtph_ps(_mm_loadu_si128(
1849  reinterpret_cast<const __m128i*>(&ip[j]))),
1850  _mm256_loadu_ps(&op[j])));
1851  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1852  }
 1853  float16 vtmp1[8] CAFFE2_ALIGNED(64); // scratch for the scalar fp16->fp32 tail below; only lane 0 of each converted vector is used
1854  for (; j < block_size; j++) {
1855  vtmp1[0] = ip[j];
1856  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1857  op[j] += wgt * ((float*)(&vtmp2))[0];
1858  }
1859  }
1860  if (normalize_by_lengths && lengths[rangeIndex]) {
1861  float len_inv = 1.0f / lengths[rangeIndex];
1862  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1863  j = 0;
1864  for (; j + 8 <= block_size; j += 8) {
1865  _mm256_storeu_ps(
1866  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1867  }
1868  for (; j < block_size; j++) {
1869  op[j] = len_inv * op[j];
1870  }
1871  }
1872  }
1873  }
1874 }
1875 void EmbeddingLookup_int64_t_float16_float_false__avx2_fma(
1876  const TIndex block_size,
1877  const TIndex output_size,
1878  const TIndex index_size,
1879  const TIndex data_size,
1880  const float16* input,
1881  const int64_t* indices,
1882  const int* lengths,
1883  const float* weights,
1884  const float* scale_bias,
1885  bool normalize_by_lengths,
1886  float* out) {
1887  EmbeddingLookup_int64_t_float16_float__avx2_fma<false>(
1888  block_size,
1889  output_size,
1890  index_size,
1891  data_size,
1892  input,
1893  indices,
1894  lengths,
1895  weights,
1896  scale_bias,
1897  normalize_by_lengths,
1898  out);
1899 }
1900 void EmbeddingLookup_int64_t_float16_float_true__avx2_fma(
1901  const TIndex block_size,
1902  const TIndex output_size,
1903  const TIndex index_size,
1904  const TIndex data_size,
1905  const float16* input,
1906  const int64_t* indices,
1907  const int* lengths,
1908  const float* weights,
1909  const float* scale_bias,
1910  bool normalize_by_lengths,
1911  float* out) {
1912  EmbeddingLookup_int64_t_float16_float__avx2_fma<true>(
1913  block_size,
1914  output_size,
1915  index_size,
1916  data_size,
1917  input,
1918  indices,
1919  lengths,
1920  weights,
1921  scale_bias,
1922  normalize_by_lengths,
1923  out);
1924 }
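The two wrappers above only pin the IS_WEIGHT_POSITIONAL template argument; numerically, every block-size branch of the float16 kernel computes the same thing: a weighted sum of float16 rows converted to float, optionally divided by the segment length. A minimal scalar sketch of that contract follows (illustrative only, covering the non-positional weight case; fp16_to_fp32 is a hypothetical stand-in for the _mm256_cvtph_ps conversion used in the SIMD path, and bounds checks and prefetching are omitted):

// Scalar reference for the float16 kernels above -- a sketch, not the
// implementation Caffe2 uses.
#include <algorithm>
#include <cstdint>

static void EmbeddingLookup_fp16_scalar_ref(
    int64_t block_size,
    int64_t output_size,
    const uint16_t* input, // raw IEEE fp16 bits, data_size x block_size
    const int64_t* indices,
    const int* lengths,
    const float* weights, // may be nullptr
    bool normalize_by_lengths,
    float* out,
    float (*fp16_to_fp32)(uint16_t)) { // hypothetical half->float helper
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = &out[rangeIndex * block_size];
    std::fill(op, op + block_size, 0.0f);
    for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
         ++dataInd) {
      const float wgt = weights ? weights[dataInd] : 1.0f;
      const uint16_t* ip = &input[indices[dataInd] * block_size];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * fp16_to_fp32(ip[j]); // accumulate weighted row
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex]) {
      const float len_inv = 1.0f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv; // mean-pool the segment
      }
    }
  }
}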
1925 
1926 template <bool IS_WEIGHT_POSITIONAL>
1927 static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1928  const TIndex block_size,
1929  const TIndex output_size,
1930  const TIndex index_size,
1931  const TIndex data_size,
1932  const uint8_t* input,
1933  const int32_t* indices,
1934  const int* lengths,
1935  const float* weights,
1936  const float* scale_bias,
1937  bool normalize_by_lengths,
1938  float* out) {
1939  const int32_t prefdist_T0 = 16;
1940  const int32_t fused_block_size = block_size + 0;
1941  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
1942  if (block_size == 128) {
1943  // unrolling 16 times
1944  int32_t dataInd = 0;
1945  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1946  float* op = &out[rangeIndex * block_size];
1947  __m256 vop0 = _mm256_setzero_ps();
1948  __m256 vop8 = _mm256_setzero_ps();
1949  __m256 vop16 = _mm256_setzero_ps();
1950  __m256 vop24 = _mm256_setzero_ps();
1951  __m256 vop32 = _mm256_setzero_ps();
1952  __m256 vop40 = _mm256_setzero_ps();
1953  __m256 vop48 = _mm256_setzero_ps();
1954  __m256 vop56 = _mm256_setzero_ps();
1955  __m256 vop64 = _mm256_setzero_ps();
1956  __m256 vop72 = _mm256_setzero_ps();
1957  __m256 vop80 = _mm256_setzero_ps();
1958  __m256 vop88 = _mm256_setzero_ps();
1959  __m256 vop96 = _mm256_setzero_ps();
1960  __m256 vop104 = _mm256_setzero_ps();
1961  __m256 vop112 = _mm256_setzero_ps();
1962  __m256 vop120 = _mm256_setzero_ps();
1963  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1964  ++dataInd) {
1965  const int32_t idx = indices[dataInd];
1966  CAFFE_ENFORCE(
1967  idx >= 0 && idx < data_size,
1968  "Index ",
1969  dataInd,
1970  " is out of bounds: ",
1971  idx,
1972  ", range 0 to ",
1973  data_size);
1974  float wgt = 1.f;
1975  float bio;
1976  if (weights) {
1977  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1978  }
 1979  bio = wgt * scale_bias[2 * idx + 1]; // fold the row's dequantization bias into the weighted sum
 1980  wgt = wgt * scale_bias[2 * idx]; // fold the row's dequantization scale into the lookup weight
1981  __m256 vbio = _mm256_set1_ps(bio);
1982  __m256 vwgt = _mm256_set1_ps(wgt);
1983  const uint8_t* ip = &input[idx * fused_block_size];
1984  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1985  ? (dataInd + prefdist_T0)
1986  : dataInd;
1987  const int32_t idx_pref_T0 = indices[next_T0];
1988  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1989  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1990  vop0 = _mm256_fmadd_ps(
1991  vwgt,
1992  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1993  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1994  _mm256_add_ps(vop0, vbio));
1995  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1996  vop8 = _mm256_fmadd_ps(
1997  vwgt,
1998  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1999  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2000  _mm256_add_ps(vop8, vbio));
2001  // skip unnecessary prefetch of (&ip_next_T0[8])
2002  vop16 = _mm256_fmadd_ps(
2003  vwgt,
2004  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2005  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2006  _mm256_add_ps(vop16, vbio));
2007  // skip unnecessary prefetch of (&ip_next_T0[16])
2008  vop24 = _mm256_fmadd_ps(
2009  vwgt,
2010  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2011  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2012  _mm256_add_ps(vop24, vbio));
2013  // skip unnecessary prefetch of (&ip_next_T0[24])
2014  vop32 = _mm256_fmadd_ps(
2015  vwgt,
2016  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2017  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2018  _mm256_add_ps(vop32, vbio));
2019  // skip unnecessary prefetch of (&ip_next_T0[32])
2020  vop40 = _mm256_fmadd_ps(
2021  vwgt,
2022  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2023  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2024  _mm256_add_ps(vop40, vbio));
2025  // skip unnecessary prefetch of (&ip_next_T0[40])
2026  vop48 = _mm256_fmadd_ps(
2027  vwgt,
2028  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2029  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2030  _mm256_add_ps(vop48, vbio));
2031  // skip unnecessary prefetch of (&ip_next_T0[48])
2032  vop56 = _mm256_fmadd_ps(
2033  vwgt,
2034  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2035  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2036  _mm256_add_ps(vop56, vbio));
2037  // skip unnecessary prefetch of (&ip_next_T0[56])
2038  vop64 = _mm256_fmadd_ps(
2039  vwgt,
2040  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2041  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2042  _mm256_add_ps(vop64, vbio));
2043  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2044  vop72 = _mm256_fmadd_ps(
2045  vwgt,
2046  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2047  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2048  _mm256_add_ps(vop72, vbio));
2049  // skip unnecessary prefetch of (&ip_next_T0[72])
2050  vop80 = _mm256_fmadd_ps(
2051  vwgt,
2052  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2053  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2054  _mm256_add_ps(vop80, vbio));
2055  // skip unnecessary prefetch of (&ip_next_T0[80])
2056  vop88 = _mm256_fmadd_ps(
2057  vwgt,
2058  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2059  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2060  _mm256_add_ps(vop88, vbio));
2061  // skip unnecessary prefetch of (&ip_next_T0[88])
2062  vop96 = _mm256_fmadd_ps(
2063  vwgt,
2064  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2065  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2066  _mm256_add_ps(vop96, vbio));
2067  // skip unnecessary prefetch of (&ip_next_T0[96])
2068  vop104 = _mm256_fmadd_ps(
2069  vwgt,
2070  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2071  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2072  _mm256_add_ps(vop104, vbio));
2073  // skip unnecessary prefetch of (&ip_next_T0[104])
2074  vop112 = _mm256_fmadd_ps(
2075  vwgt,
2076  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2077  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2078  _mm256_add_ps(vop112, vbio));
2079  // skip unnecessary prefetch of (&ip_next_T0[112])
2080  vop120 = _mm256_fmadd_ps(
2081  vwgt,
2082  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2083  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2084  _mm256_add_ps(vop120, vbio));
2085  // skip unnecessary prefetch of (&ip_next_T0[120])
2086  }
2087  if (normalize_by_lengths == false) {
2088  _mm256_storeu_ps(&op[0], vop0);
2089  _mm256_storeu_ps(&op[8], vop8);
2090  _mm256_storeu_ps(&op[16], vop16);
2091  _mm256_storeu_ps(&op[24], vop24);
2092  _mm256_storeu_ps(&op[32], vop32);
2093  _mm256_storeu_ps(&op[40], vop40);
2094  _mm256_storeu_ps(&op[48], vop48);
2095  _mm256_storeu_ps(&op[56], vop56);
2096  _mm256_storeu_ps(&op[64], vop64);
2097  _mm256_storeu_ps(&op[72], vop72);
2098  _mm256_storeu_ps(&op[80], vop80);
2099  _mm256_storeu_ps(&op[88], vop88);
2100  _mm256_storeu_ps(&op[96], vop96);
2101  _mm256_storeu_ps(&op[104], vop104);
2102  _mm256_storeu_ps(&op[112], vop112);
2103  _mm256_storeu_ps(&op[120], vop120);
2104  } else if (lengths[rangeIndex]) {
2105  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2106  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2107  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2108  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2109  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2110  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2111  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2112  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2113  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2114  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2115  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2116  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2117  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2118  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2119  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2120  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2121  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2122  }
2123  }
2124  } else if (block_size == 64) {
2125  // unrolling 8 times
2126  int32_t dataInd = 0;
2127  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2128  float* op = &out[rangeIndex * block_size];
2129  __m256 vop0 = _mm256_setzero_ps();
2130  __m256 vop8 = _mm256_setzero_ps();
2131  __m256 vop16 = _mm256_setzero_ps();
2132  __m256 vop24 = _mm256_setzero_ps();
2133  __m256 vop32 = _mm256_setzero_ps();
2134  __m256 vop40 = _mm256_setzero_ps();
2135  __m256 vop48 = _mm256_setzero_ps();
2136  __m256 vop56 = _mm256_setzero_ps();
2137  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2138  ++dataInd) {
2139  const int32_t idx = indices[dataInd];
2140  CAFFE_ENFORCE(
2141  idx >= 0 && idx < data_size,
2142  "Index ",
2143  dataInd,
2144  " is out of bounds: ",
2145  idx,
2146  ", range 0 to ",
2147  data_size);
2148  float wgt = 1.f;
2149  float bio;
2150  if (weights) {
2151  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2152  }
2153  bio = wgt * scale_bias[2 * idx + 1];
2154  wgt = wgt * scale_bias[2 * idx];
2155  __m256 vbio = _mm256_set1_ps(bio);
2156  __m256 vwgt = _mm256_set1_ps(wgt);
2157  const uint8_t* ip = &input[idx * fused_block_size];
2158  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2159  ? (dataInd + prefdist_T0)
2160  : dataInd;
2161  const int32_t idx_pref_T0 = indices[next_T0];
2162  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2163  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2164  vop0 = _mm256_fmadd_ps(
2165  vwgt,
2166  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2167  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2168  _mm256_add_ps(vop0, vbio));
2169  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2170  vop8 = _mm256_fmadd_ps(
2171  vwgt,
2172  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2173  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2174  _mm256_add_ps(vop8, vbio));
2175  // skip unnecessary prefetch of (&ip_next_T0[8])
2176  vop16 = _mm256_fmadd_ps(
2177  vwgt,
2178  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2179  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2180  _mm256_add_ps(vop16, vbio));
2181  // skip unnecessary prefetch of (&ip_next_T0[16])
2182  vop24 = _mm256_fmadd_ps(
2183  vwgt,
2184  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2185  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2186  _mm256_add_ps(vop24, vbio));
2187  // skip unnecessary prefetch of (&ip_next_T0[24])
2188  vop32 = _mm256_fmadd_ps(
2189  vwgt,
2190  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2191  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2192  _mm256_add_ps(vop32, vbio));
2193  // skip unnecessary prefetch of (&ip_next_T0[32])
2194  vop40 = _mm256_fmadd_ps(
2195  vwgt,
2196  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2197  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2198  _mm256_add_ps(vop40, vbio));
2199  // skip unnecessary prefetch of (&ip_next_T0[40])
2200  vop48 = _mm256_fmadd_ps(
2201  vwgt,
2202  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2203  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2204  _mm256_add_ps(vop48, vbio));
2205  // skip unnecessary prefetch of (&ip_next_T0[48])
2206  vop56 = _mm256_fmadd_ps(
2207  vwgt,
2208  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2209  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2210  _mm256_add_ps(vop56, vbio));
2211  // skip unnecessary prefetch of (&ip_next_T0[56])
2212  }
2213  if (normalize_by_lengths == false) {
2214  _mm256_storeu_ps(&op[0], vop0);
2215  _mm256_storeu_ps(&op[8], vop8);
2216  _mm256_storeu_ps(&op[16], vop16);
2217  _mm256_storeu_ps(&op[24], vop24);
2218  _mm256_storeu_ps(&op[32], vop32);
2219  _mm256_storeu_ps(&op[40], vop40);
2220  _mm256_storeu_ps(&op[48], vop48);
2221  _mm256_storeu_ps(&op[56], vop56);
2222  } else if (lengths[rangeIndex]) {
2223  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2224  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2225  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2226  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2227  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2228  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2229  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2230  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2231  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2232  }
2233  }
2234  } else if (block_size == 32) {
2235  // unrolling 4 times
2236  int32_t dataInd = 0;
2237  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2238  float* op = &out[rangeIndex * block_size];
2239  __m256 vop0 = _mm256_setzero_ps();
2240  __m256 vop8 = _mm256_setzero_ps();
2241  __m256 vop16 = _mm256_setzero_ps();
2242  __m256 vop24 = _mm256_setzero_ps();
2243  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2244  ++dataInd) {
2245  const int32_t idx = indices[dataInd];
2246  CAFFE_ENFORCE(
2247  idx >= 0 && idx < data_size,
2248  "Index ",
2249  dataInd,
2250  " is out of bounds: ",
2251  idx,
2252  ", range 0 to ",
2253  data_size);
2254  float wgt = 1.f;
2255  float bio;
2256  if (weights) {
2257  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2258  }
2259  bio = wgt * scale_bias[2 * idx + 1];
2260  wgt = wgt * scale_bias[2 * idx];
2261  __m256 vbio = _mm256_set1_ps(bio);
2262  __m256 vwgt = _mm256_set1_ps(wgt);
2263  const uint8_t* ip = &input[idx * fused_block_size];
2264  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2265  ? (dataInd + prefdist_T0)
2266  : dataInd;
2267  const int32_t idx_pref_T0 = indices[next_T0];
2268  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2269  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2270  vop0 = _mm256_fmadd_ps(
2271  vwgt,
2272  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2273  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2274  _mm256_add_ps(vop0, vbio));
2275  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2276  vop8 = _mm256_fmadd_ps(
2277  vwgt,
2278  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2279  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2280  _mm256_add_ps(vop8, vbio));
2281  // skip unnecessary prefetch of (&ip_next_T0[8])
2282  vop16 = _mm256_fmadd_ps(
2283  vwgt,
2284  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2285  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2286  _mm256_add_ps(vop16, vbio));
2287  // skip unnecessary prefetch of (&ip_next_T0[16])
2288  vop24 = _mm256_fmadd_ps(
2289  vwgt,
2290  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2291  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2292  _mm256_add_ps(vop24, vbio));
2293  // skip unnecessary prefetch of (&ip_next_T0[24])
2294  }
2295  if (normalize_by_lengths == false) {
2296  _mm256_storeu_ps(&op[0], vop0);
2297  _mm256_storeu_ps(&op[8], vop8);
2298  _mm256_storeu_ps(&op[16], vop16);
2299  _mm256_storeu_ps(&op[24], vop24);
2300  } else if (lengths[rangeIndex]) {
2301  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2302  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2303  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2304  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2305  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2306  }
2307  }
2308  } else if (block_size == 16) {
2309  // unrolling 2 times
2310  int32_t dataInd = 0;
2311  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2312  float* op = &out[rangeIndex * block_size];
2313  __m256 vop0 = _mm256_setzero_ps();
2314  __m256 vop8 = _mm256_setzero_ps();
2315  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2316  ++dataInd) {
2317  const int32_t idx = indices[dataInd];
2318  CAFFE_ENFORCE(
2319  idx >= 0 && idx < data_size,
2320  "Index ",
2321  dataInd,
2322  " is out of bounds: ",
2323  idx,
2324  ", range 0 to ",
2325  data_size);
2326  float wgt = 1.f;
2327  float bio;
2328  if (weights) {
2329  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2330  }
2331  bio = wgt * scale_bias[2 * idx + 1];
2332  wgt = wgt * scale_bias[2 * idx];
2333  __m256 vbio = _mm256_set1_ps(bio);
2334  __m256 vwgt = _mm256_set1_ps(wgt);
2335  const uint8_t* ip = &input[idx * fused_block_size];
2336  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2337  ? (dataInd + prefdist_T0)
2338  : dataInd;
2339  const int32_t idx_pref_T0 = indices[next_T0];
2340  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2341  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2342  vop0 = _mm256_fmadd_ps(
2343  vwgt,
2344  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2345  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2346  _mm256_add_ps(vop0, vbio));
2347  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2348  vop8 = _mm256_fmadd_ps(
2349  vwgt,
2350  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2351  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2352  _mm256_add_ps(vop8, vbio));
2353  // skip unnecessary prefetch of (&ip_next_T0[8])
2354  }
2355  if (normalize_by_lengths == false) {
2356  _mm256_storeu_ps(&op[0], vop0);
2357  _mm256_storeu_ps(&op[8], vop8);
2358  } else if (lengths[rangeIndex]) {
2359  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2360  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2361  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2362  }
2363  }
2364  } else {
2365  // generic code
2366  int32_t dataInd = 0;
2367  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2368  float* op = &out[rangeIndex * block_size];
2369  TIndex j = 0;
2370  for (; j + 8 <= block_size; j += 8) {
2371  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2372  }
2373  for (; j < block_size; j++) {
2374  op[j] = 0.0f;
2375  }
2376  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2377  ++dataInd) {
2378  const int32_t idx = indices[dataInd];
2379  CAFFE_ENFORCE(
2380  idx >= 0 && idx < data_size,
2381  "Index ",
2382  dataInd,
2383  " is out of bounds: ",
2384  idx,
2385  ", range 0 to ",
2386  data_size);
2387  float wgt = 1.f;
2388  float bio;
2389  if (weights) {
2390  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2391  }
2392  assert(scale_bias);
2393  bio = wgt * scale_bias[2 * idx + 1];
2394  wgt = wgt * scale_bias[2 * idx];
2395  __m256 vbio = _mm256_set1_ps(bio);
2396  __m256 vwgt = _mm256_set1_ps(wgt);
2397  const uint8_t* ip = &input[idx * fused_block_size];
2398  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2399  ? (dataInd + prefdist_T0)
2400  : dataInd;
2401  const int32_t idx_pref_T0 = indices[next_T0];
2402  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2403  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2404  j = 0;
2405  for (; j + 8 <= block_size; j += 8) {
2406  _mm256_storeu_ps(
2407  &op[j],
2408  _mm256_fmadd_ps(
2409  vwgt,
2410  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2411  reinterpret_cast<const __m128i*>(&ip[j])))),
2412  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2413  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2414  }
2415  for (; j < block_size; j++) {
2416  op[j] += wgt * ((float)ip[j]) + bio;
2417  }
2418  }
2419  if (normalize_by_lengths && lengths[rangeIndex]) {
2420  float len_inv = 1.0f / lengths[rangeIndex];
2421  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2422  j = 0;
2423  for (; j + 8 <= block_size; j += 8) {
2424  _mm256_storeu_ps(
2425  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2426  }
2427  for (; j < block_size; j++) {
2428  op[j] = len_inv * op[j];
2429  }
2430  }
2431  }
2432  }
2433 }
2434 void EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
2435  const TIndex block_size,
2436  const TIndex output_size,
2437  const TIndex index_size,
2438  const TIndex data_size,
2439  const uint8_t* input,
2440  const int32_t* indices,
2441  const int* lengths,
2442  const float* weights,
2443  const float* scale_bias,
2444  bool normalize_by_lengths,
2445  float* out) {
2446  EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
2447  block_size,
2448  output_size,
2449  index_size,
2450  data_size,
2451  input,
2452  indices,
2453  lengths,
2454  weights,
2455  scale_bias,
2456  normalize_by_lengths,
2457  out);
2458 }
2459 void EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
2460  const TIndex block_size,
2461  const TIndex output_size,
2462  const TIndex index_size,
2463  const TIndex data_size,
2464  const uint8_t* input,
2465  const int32_t* indices,
2466  const int* lengths,
2467  const float* weights,
2468  const float* scale_bias,
2469  bool normalize_by_lengths,
2470  float* out) {
2471  EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
2472  block_size,
2473  output_size,
2474  index_size,
2475  data_size,
2476  input,
2477  indices,
2478  lengths,
2479  weights,
2480  scale_bias,
2481  normalize_by_lengths,
2482  out);
2483 }
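For the uint8_t kernels, scale_bias stores one (scale, bias) pair per embedding row: scale_bias[2 * idx] is the dequantization scale and scale_bias[2 * idx + 1] the bias, and both are pre-multiplied by the lookup weight before the FMA loop (the wgt/bio folding above). The per-element arithmetic then matches the scalar tail of the generic branch; a hedged, stand-alone restatement is given below (AccumulateUint8Row is an invented name, not a Caffe2 function):

// Scalar version of the fused dequantize-and-accumulate step performed by
// the uint8_t kernels above. Illustrative only.
#include <cstdint>

inline void AccumulateUint8Row(
    const uint8_t* row, // one quantized embedding row
    int64_t block_size,
    float scale, // scale_bias[2 * idx]
    float bias, // scale_bias[2 * idx + 1]
    float wgt, // lookup weight (1.f when weights == nullptr)
    float* op) { // accumulator for the current output row
  const float w = wgt * scale; // mirrors: wgt = wgt * scale_bias[2 * idx]
  const float b = wgt * bias; // mirrors: bio = wgt * scale_bias[2 * idx + 1]
  for (int64_t j = 0; j < block_size; ++j) {
    op[j] += w * static_cast<float>(row[j]) + b;
  }
}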
2484 
2485 template <bool IS_WEIGHT_POSITIONAL>
2486 static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2487  const TIndex block_size,
2488  const TIndex output_size,
2489  const TIndex index_size,
2490  const TIndex data_size,
2491  const uint8_t* input,
2492  const int64_t* indices,
2493  const int* lengths,
2494  const float* weights,
2495  const float* scale_bias,
2496  bool normalize_by_lengths,
2497  float* out) {
2498  const int64_t prefdist_T0 = 16;
2499  const int64_t fused_block_size = block_size + 0;
2500  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
2501  if (block_size == 128) {
2502  // unrolling 16 times
2503  int64_t dataInd = 0;
2504  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2505  float* op = &out[rangeIndex * block_size];
2506  __m256 vop0 = _mm256_setzero_ps();
2507  __m256 vop8 = _mm256_setzero_ps();
2508  __m256 vop16 = _mm256_setzero_ps();
2509  __m256 vop24 = _mm256_setzero_ps();
2510  __m256 vop32 = _mm256_setzero_ps();
2511  __m256 vop40 = _mm256_setzero_ps();
2512  __m256 vop48 = _mm256_setzero_ps();
2513  __m256 vop56 = _mm256_setzero_ps();
2514  __m256 vop64 = _mm256_setzero_ps();
2515  __m256 vop72 = _mm256_setzero_ps();
2516  __m256 vop80 = _mm256_setzero_ps();
2517  __m256 vop88 = _mm256_setzero_ps();
2518  __m256 vop96 = _mm256_setzero_ps();
2519  __m256 vop104 = _mm256_setzero_ps();
2520  __m256 vop112 = _mm256_setzero_ps();
2521  __m256 vop120 = _mm256_setzero_ps();
2522  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2523  ++dataInd) {
2524  const int64_t idx = indices[dataInd];
2525  CAFFE_ENFORCE(
2526  idx >= 0 && idx < data_size,
2527  "Index ",
2528  dataInd,
2529  " is out of bounds: ",
2530  idx,
2531  ", range 0 to ",
2532  data_size);
2533  float wgt = 1.f;
2534  float bio;
2535  if (weights) {
2536  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2537  }
2538  bio = wgt * scale_bias[2 * idx + 1];
2539  wgt = wgt * scale_bias[2 * idx];
2540  __m256 vbio = _mm256_set1_ps(bio);
2541  __m256 vwgt = _mm256_set1_ps(wgt);
2542  const uint8_t* ip = &input[idx * fused_block_size];
2543  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2544  ? (dataInd + prefdist_T0)
2545  : dataInd;
2546  const int64_t idx_pref_T0 = indices[next_T0];
2547  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2548  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2549  vop0 = _mm256_fmadd_ps(
2550  vwgt,
2551  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2552  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2553  _mm256_add_ps(vop0, vbio));
2554  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2555  vop8 = _mm256_fmadd_ps(
2556  vwgt,
2557  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2558  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2559  _mm256_add_ps(vop8, vbio));
2560  // skip unnecessary prefetch of (&ip_next_T0[8])
2561  vop16 = _mm256_fmadd_ps(
2562  vwgt,
2563  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2564  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2565  _mm256_add_ps(vop16, vbio));
2566  // skip unnecessary prefetch of (&ip_next_T0[16])
2567  vop24 = _mm256_fmadd_ps(
2568  vwgt,
2569  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2570  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2571  _mm256_add_ps(vop24, vbio));
2572  // skip unnecessary prefetch of (&ip_next_T0[24])
2573  vop32 = _mm256_fmadd_ps(
2574  vwgt,
2575  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2576  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2577  _mm256_add_ps(vop32, vbio));
2578  // skip unnecessary prefetch of (&ip_next_T0[32])
2579  vop40 = _mm256_fmadd_ps(
2580  vwgt,
2581  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2582  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2583  _mm256_add_ps(vop40, vbio));
2584  // skip unnecessary prefetch of (&ip_next_T0[40])
2585  vop48 = _mm256_fmadd_ps(
2586  vwgt,
2587  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2588  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2589  _mm256_add_ps(vop48, vbio));
2590  // skip unnecessary prefetch of (&ip_next_T0[48])
2591  vop56 = _mm256_fmadd_ps(
2592  vwgt,
2593  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2594  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2595  _mm256_add_ps(vop56, vbio));
2596  // skip unnecessary prefetch of (&ip_next_T0[56])
2597  vop64 = _mm256_fmadd_ps(
2598  vwgt,
2599  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2600  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2601  _mm256_add_ps(vop64, vbio));
2602  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2603  vop72 = _mm256_fmadd_ps(
2604  vwgt,
2605  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2606  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2607  _mm256_add_ps(vop72, vbio));
2608  // skip unnecessary prefetch of (&ip_next_T0[72])
2609  vop80 = _mm256_fmadd_ps(
2610  vwgt,
2611  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2612  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2613  _mm256_add_ps(vop80, vbio));
2614  // skip unnecessary prefetch of (&ip_next_T0[80])
2615  vop88 = _mm256_fmadd_ps(
2616  vwgt,
2617  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2618  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2619  _mm256_add_ps(vop88, vbio));
2620  // skip unnecessary prefetch of (&ip_next_T0[88])
2621  vop96 = _mm256_fmadd_ps(
2622  vwgt,
2623  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2624  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2625  _mm256_add_ps(vop96, vbio));
2626  // skip unnecessary prefetch of (&ip_next_T0[96])
2627  vop104 = _mm256_fmadd_ps(
2628  vwgt,
2629  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2630  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2631  _mm256_add_ps(vop104, vbio));
2632  // skip unnecessary prefetch of (&ip_next_T0[104])
2633  vop112 = _mm256_fmadd_ps(
2634  vwgt,
2635  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2636  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2637  _mm256_add_ps(vop112, vbio));
2638  // skip unnecessary prefetch of (&ip_next_T0[112])
2639  vop120 = _mm256_fmadd_ps(
2640  vwgt,
2641  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2642  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2643  _mm256_add_ps(vop120, vbio));
2644  // skip unnecessary prefetch of (&ip_next_T0[120])
2645  }
2646  if (normalize_by_lengths == false) {
2647  _mm256_storeu_ps(&op[0], vop0);
2648  _mm256_storeu_ps(&op[8], vop8);
2649  _mm256_storeu_ps(&op[16], vop16);
2650  _mm256_storeu_ps(&op[24], vop24);
2651  _mm256_storeu_ps(&op[32], vop32);
2652  _mm256_storeu_ps(&op[40], vop40);
2653  _mm256_storeu_ps(&op[48], vop48);
2654  _mm256_storeu_ps(&op[56], vop56);
2655  _mm256_storeu_ps(&op[64], vop64);
2656  _mm256_storeu_ps(&op[72], vop72);
2657  _mm256_storeu_ps(&op[80], vop80);
2658  _mm256_storeu_ps(&op[88], vop88);
2659  _mm256_storeu_ps(&op[96], vop96);
2660  _mm256_storeu_ps(&op[104], vop104);
2661  _mm256_storeu_ps(&op[112], vop112);
2662  _mm256_storeu_ps(&op[120], vop120);
2663  } else if (lengths[rangeIndex]) {
2664  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2665  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2666  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2667  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2668  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2669  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2670  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2671  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2672  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2673  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2674  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2675  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2676  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2677  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2678  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2679  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2680  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2681  }
2682  }
2683  } else if (block_size == 64) {
2684  // unrolling 8 times
2685  int64_t dataInd = 0;
2686  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2687  float* op = &out[rangeIndex * block_size];
2688  __m256 vop0 = _mm256_setzero_ps();
2689  __m256 vop8 = _mm256_setzero_ps();
2690  __m256 vop16 = _mm256_setzero_ps();
2691  __m256 vop24 = _mm256_setzero_ps();
2692  __m256 vop32 = _mm256_setzero_ps();
2693  __m256 vop40 = _mm256_setzero_ps();
2694  __m256 vop48 = _mm256_setzero_ps();
2695  __m256 vop56 = _mm256_setzero_ps();
2696  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2697  ++dataInd) {
2698  const int64_t idx = indices[dataInd];
2699  CAFFE_ENFORCE(
2700  idx >= 0 && idx < data_size,
2701  "Index ",
2702  dataInd,
2703  " is out of bounds: ",
2704  idx,
2705  ", range 0 to ",
2706  data_size);
2707  float wgt = 1.f;
2708  float bio;
2709  if (weights) {
2710  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2711  }
2712  bio = wgt * scale_bias[2 * idx + 1];
2713  wgt = wgt * scale_bias[2 * idx];
2714  __m256 vbio = _mm256_set1_ps(bio);
2715  __m256 vwgt = _mm256_set1_ps(wgt);
2716  const uint8_t* ip = &input[idx * fused_block_size];
2717  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2718  ? (dataInd + prefdist_T0)
2719  : dataInd;
2720  const int64_t idx_pref_T0 = indices[next_T0];
2721  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2722  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2723  vop0 = _mm256_fmadd_ps(
2724  vwgt,
2725  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2726  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2727  _mm256_add_ps(vop0, vbio));
2728  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2729  vop8 = _mm256_fmadd_ps(
2730  vwgt,
2731  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2732  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2733  _mm256_add_ps(vop8, vbio));
2734  // skip unnecessary prefetch of (&ip_next_T0[8])
2735  vop16 = _mm256_fmadd_ps(
2736  vwgt,
2737  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2738  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2739  _mm256_add_ps(vop16, vbio));
2740  // skip unnecessary prefetch of (&ip_next_T0[16])
2741  vop24 = _mm256_fmadd_ps(
2742  vwgt,
2743  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2744  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2745  _mm256_add_ps(vop24, vbio));
2746  // skip unnecessary prefetch of (&ip_next_T0[24])
2747  vop32 = _mm256_fmadd_ps(
2748  vwgt,
2749  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2750  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2751  _mm256_add_ps(vop32, vbio));
2752  // skip unnecessary prefetch of (&ip_next_T0[32])
2753  vop40 = _mm256_fmadd_ps(
2754  vwgt,
2755  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2756  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2757  _mm256_add_ps(vop40, vbio));
2758  // skip unnecessary prefetch of (&ip_next_T0[40])
2759  vop48 = _mm256_fmadd_ps(
2760  vwgt,
2761  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2762  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2763  _mm256_add_ps(vop48, vbio));
2764  // skip unnecessary prefetch of (&ip_next_T0[48])
2765  vop56 = _mm256_fmadd_ps(
2766  vwgt,
2767  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2768  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2769  _mm256_add_ps(vop56, vbio));
2770  // skip unnecessary prefetch of (&ip_next_T0[56])
2771  }
2772  if (normalize_by_lengths == false) {
2773  _mm256_storeu_ps(&op[0], vop0);
2774  _mm256_storeu_ps(&op[8], vop8);
2775  _mm256_storeu_ps(&op[16], vop16);
2776  _mm256_storeu_ps(&op[24], vop24);
2777  _mm256_storeu_ps(&op[32], vop32);
2778  _mm256_storeu_ps(&op[40], vop40);
2779  _mm256_storeu_ps(&op[48], vop48);
2780  _mm256_storeu_ps(&op[56], vop56);
2781  } else if (lengths[rangeIndex]) {
2782  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2783  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2784  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2785  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2786  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2787  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2788  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2789  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2790  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2791  }
2792  }
2793  } else if (block_size == 32) {
2794  // unrolling 4 times
2795  int64_t dataInd = 0;
2796  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2797  float* op = &out[rangeIndex * block_size];
2798  __m256 vop0 = _mm256_setzero_ps();
2799  __m256 vop8 = _mm256_setzero_ps();
2800  __m256 vop16 = _mm256_setzero_ps();
2801  __m256 vop24 = _mm256_setzero_ps();
2802  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2803  ++dataInd) {
2804  const int64_t idx = indices[dataInd];
2805  CAFFE_ENFORCE(
2806  idx >= 0 && idx < data_size,
2807  "Index ",
2808  dataInd,
2809  " is out of bounds: ",
2810  idx,
2811  ", range 0 to ",
2812  data_size);
2813  float wgt = 1.f;
2814  float bio;
2815  if (weights) {
2816  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2817  }
2818  bio = wgt * scale_bias[2 * idx + 1];
2819  wgt = wgt * scale_bias[2 * idx];
2820  __m256 vbio = _mm256_set1_ps(bio);
2821  __m256 vwgt = _mm256_set1_ps(wgt);
2822  const uint8_t* ip = &input[idx * fused_block_size];
2823  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2824  ? (dataInd + prefdist_T0)
2825  : dataInd;
2826  const int64_t idx_pref_T0 = indices[next_T0];
2827  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2828  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2829  vop0 = _mm256_fmadd_ps(
2830  vwgt,
2831  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2832  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2833  _mm256_add_ps(vop0, vbio));
2834  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2835  vop8 = _mm256_fmadd_ps(
2836  vwgt,
2837  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2838  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2839  _mm256_add_ps(vop8, vbio));
2840  // skip unnecessary prefetch of (&ip_next_T0[8])
2841  vop16 = _mm256_fmadd_ps(
2842  vwgt,
2843  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2844  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2845  _mm256_add_ps(vop16, vbio));
2846  // skip unnecessary prefetch of (&ip_next_T0[16])
2847  vop24 = _mm256_fmadd_ps(
2848  vwgt,
2849  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2850  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2851  _mm256_add_ps(vop24, vbio));
2852  // skip unnecessary prefetch of (&ip_next_T0[24])
2853  }
2854  if (normalize_by_lengths == false) {
2855  _mm256_storeu_ps(&op[0], vop0);
2856  _mm256_storeu_ps(&op[8], vop8);
2857  _mm256_storeu_ps(&op[16], vop16);
2858  _mm256_storeu_ps(&op[24], vop24);
2859  } else if (lengths[rangeIndex]) {
2860  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2861  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2862  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2863  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2864  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2865  }
2866  }
2867  } else if (block_size == 16) {
2868  // unrolling 2 times
2869  int64_t dataInd = 0;
2870  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2871  float* op = &out[rangeIndex * block_size];
2872  __m256 vop0 = _mm256_setzero_ps();
2873  __m256 vop8 = _mm256_setzero_ps();
2874  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2875  ++dataInd) {
2876  const int64_t idx = indices[dataInd];
2877  CAFFE_ENFORCE(
2878  idx >= 0 && idx < data_size,
2879  "Index ",
2880  dataInd,
2881  " is out of bounds: ",
2882  idx,
2883  ", range 0 to ",
2884  data_size);
2885  float wgt = 1.f;
2886  float bio;
2887  if (weights) {
2888  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2889  }
2890  bio = wgt * scale_bias[2 * idx + 1];
2891  wgt = wgt * scale_bias[2 * idx];
2892  __m256 vbio = _mm256_set1_ps(bio);
2893  __m256 vwgt = _mm256_set1_ps(wgt);
2894  const uint8_t* ip = &input[idx * fused_block_size];
2895  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2896  ? (dataInd + prefdist_T0)
2897  : dataInd;
2898  const int64_t idx_pref_T0 = indices[next_T0];
2899  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2900  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2901  vop0 = _mm256_fmadd_ps(
2902  vwgt,
2903  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2904  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2905  _mm256_add_ps(vop0, vbio));
2906  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2907  vop8 = _mm256_fmadd_ps(
2908  vwgt,
2909  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2910  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2911  _mm256_add_ps(vop8, vbio));
2912  // skip unnecessary prefetch of (&ip_next_T0[8])
2913  }
2914  if (normalize_by_lengths == false) {
2915  _mm256_storeu_ps(&op[0], vop0);
2916  _mm256_storeu_ps(&op[8], vop8);
2917  } else if (lengths[rangeIndex]) {
2918  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2919  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2920  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2921  }
2922  }
2923  } else {
2924  // generic code
2925  int64_t dataInd = 0;
2926  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2927  float* op = &out[rangeIndex * block_size];
2928  TIndex j = 0;
2929  for (; j + 8 <= block_size; j += 8) {
2930  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2931  }
2932  for (; j < block_size; j++) {
2933  op[j] = 0.0f;
2934  }
2935  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2936  ++dataInd) {
2937  const int64_t idx = indices[dataInd];
2938  CAFFE_ENFORCE(
2939  idx >= 0 && idx < data_size,
2940  "Index ",
2941  dataInd,
2942  " is out of bounds: ",
2943  idx,
2944  ", range 0 to ",
2945  data_size);
2946  float wgt = 1.f;
2947  float bio;
2948  if (weights) {
2949  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2950  }
2951  assert(scale_bias);
2952  bio = wgt * scale_bias[2 * idx + 1];
2953  wgt = wgt * scale_bias[2 * idx];
2954  __m256 vbio = _mm256_set1_ps(bio);
2955  __m256 vwgt = _mm256_set1_ps(wgt);
2956  const uint8_t* ip = &input[idx * fused_block_size];
2957  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2958  ? (dataInd + prefdist_T0)
2959  : dataInd;
2960  const int64_t idx_pref_T0 = indices[next_T0];
2961  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2962  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2963  j = 0;
2964  for (; j + 8 <= block_size; j += 8) {
2965  _mm256_storeu_ps(
2966  &op[j],
2967  _mm256_fmadd_ps(
2968  vwgt,
2969  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2970  reinterpret_cast<const __m128i*>(&ip[j])))),
2971  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2972  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2973  }
2974  for (; j < block_size; j++) {
2975  op[j] += wgt * ((float)ip[j]) + bio;
2976  }
2977  }
2978  if (normalize_by_lengths && lengths[rangeIndex]) {
2979  float len_inv = 1.0f / lengths[rangeIndex];
2980  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2981  j = 0;
2982  for (; j + 8 <= block_size; j += 8) {
2983  _mm256_storeu_ps(
2984  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2985  }
2986  for (; j < block_size; j++) {
2987  op[j] = len_inv * op[j];
2988  }
2989  }
2990  }
2991  }
2992 }
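Putting the branches together: with lookup weights w_i (taken as 1 when weights is nullptr, and indexed by position within the segment when IS_WEIGHT_POSITIONAL), per-row scale s_i = scale_bias[2 * idx_i] and bias b_i = scale_bias[2 * idx_i + 1], the r-th output row produced by these uint8_t kernels is, restating the code above rather than quoting any separate Caffe2 documentation,

  out_r[j] = (1 / L_r) * sum_{i in segment r} w_i * (s_i * q_{idx_i}[j] + b_i)

where L_r = lengths[r] and q_{idx_i} is the quantized row; the 1 / L_r factor is applied only when normalize_by_lengths is true and L_r is non-zero.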
2993 void EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
2994  const TIndex block_size,
2995  const TIndex output_size,
2996  const TIndex index_size,
2997  const TIndex data_size,
2998  const uint8_t* input,
2999  const int64_t* indices,
3000  const int* lengths,
3001  const float* weights,
3002  const float* scale_bias,
3003  bool normalize_by_lengths,
3004  float* out) {
3005  EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
3006  block_size,
3007  output_size,
3008  index_size,
3009  data_size,
3010  input,
3011  indices,
3012  lengths,
3013  weights,
3014  scale_bias,
3015  normalize_by_lengths,
3016  out);
3017 }
3018 void EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
3019  const TIndex block_size,
3020  const TIndex output_size,
3021  const TIndex index_size,
3022  const TIndex data_size,
3023  const uint8_t* input,
3024  const int64_t* indices,
3025  const int* lengths,
3026  const float* weights,
3027  const float* scale_bias,
3028  bool normalize_by_lengths,
3029  float* out) {
3030  EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
3031  block_size,
3032  output_size,
3033  index_size,
3034  data_size,
3035  input,
3036  indices,
3037  lengths,
3038  weights,
3039  scale_bias,
3040  normalize_by_lengths,
3041  out);
3042 }
3043 
3044 } // namespace caffe2
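In Caffe2 these entry points are normally reached through the perfkernels dispatcher rather than called by hand, but a small driver makes the expected argument layout concrete. The following is a usage sketch only: the sizes and values are invented, and it assumes a matching declaration of the function (with TIndex being a 64-bit index type) is in scope.

// Hypothetical driver for one of the uint8_t entry points defined above.
#include <cstdint>
#include <vector>

void ExampleEmbeddingLookupCall() {
  const int64_t block_size = 16; // columns per embedding row
  const int64_t output_size = 2; // number of pooled output rows (segments)
  const int64_t index_size = 3; // total number of lookups
  const int64_t data_size = 4; // rows in the quantized embedding table

  std::vector<uint8_t> input(data_size * block_size, 1); // quantized rows
  std::vector<float> scale_bias = { // one (scale, bias) pair per row
      0.5f, 0.0f, 0.5f, 0.0f, 0.5f, 0.0f, 0.5f, 0.0f};
  std::vector<int64_t> indices = {0, 2, 3}; // rows to gather
  std::vector<int> lengths = {2, 1}; // lookups per segment; sums to index_size
  std::vector<float> out(output_size * block_size, 0.0f);

  caffe2::EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
      block_size,
      output_size,
      index_size,
      data_size,
      input.data(),
      indices.data(),
      lengths.data(),
      /*weights=*/nullptr,
      scale_bias.data(),
      /*normalize_by_lengths=*/false,
      out.data());
}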