#include <caffe2/core/common.h>
#include <caffe2/core/types.h>
#include <immintrin.h>

namespace caffe2 {

template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int32_t_float_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int32_t prefdist_T0 = 16;
  // Rows are exactly block_size floats wide; no scale/bias columns are
  // fused into the data for this variant.
  const int32_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        // Software-prefetch the row that will be visited prefdist_T0
        // iterations from now (clamped at the end of the index array);
        // one prefetch covers a 64-byte line, i.e. two __m256 loads, so
        // only every other accumulator issues one.
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
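    // Generic fallback for other widths: accumulate 8 floats at a time,
    // then finish the remainder one float at a time.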
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ip[j];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
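// Non-template entry points; each one instantiates the kernel above with
// the IS_WEIGHT_POSITIONAL flag resolved at compile time.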
void EmbeddingLookup_int32_t_float_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_float_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
void EmbeddingLookup_int32_t_float_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_float_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
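// Usage sketch (hypothetical caller, not part of this generated file):
// pool `output_size` bags out of a `data_size`-row table of `block_size`
// floats, where bag i consumes lengths[i] consecutive entries of
// `indices` (so the lengths sum to index_size):
//
//   EmbeddingLookup_int32_t_float_float_false__avx2_fma(
//       /*block_size=*/64, /*output_size=*/num_bags, /*index_size=*/num_ids,
//       /*data_size=*/num_rows, table, ids, lengths, /*weights=*/nullptr,
//       /*scale_bias=*/nullptr, /*normalize_by_lengths=*/true, out);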
template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int64_t_float_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ip[j];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
void EmbeddingLookup_int64_t_float_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_float_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
void EmbeddingLookup_int64_t_float_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_float_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int32_t_float16_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int32_t prefdist_T0 = 16;
  const int32_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
            vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
            vop40);
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
            vop48);
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
            vop56);
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
            vop64);
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
            vop72);
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
            vop80);
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
            vop88);
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
            vop96);
        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
            vop104);
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
            vop112);
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
            vop120);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
            vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
            vop40);
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
            vop48);
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
            vop56);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
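    // Generic fallback: widen and accumulate 8 halfs at a time. The scalar
    // remainder below has no scalar half->float conversion to call, so it
    // routes one value at a time through an aligned 8-element buffer and
    // the same vcvtph2ps, then reads lane 0.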
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtph_ps(_mm_loadu_si128(
                      reinterpret_cast<const __m128i*>(&ip[j]))),
                  _mm256_loadu_ps(&op[j])));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        float16 vtmp1[8] CAFFE2_ALIGNED(64);
        for (; j < block_size; j++) {
          vtmp1[0] = ip[j];
          __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
          op[j] += wgt * ((float*)(&vtmp2))[0];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
void EmbeddingLookup_int32_t_float16_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_float16_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
void EmbeddingLookup_int32_t_float16_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_float16_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int64_t_float16_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
            vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
            vop40);
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
            vop48);
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
            vop56);
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
            vop64);
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
            vop72);
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
            vop80);
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
            vop88);
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
            vop96);
        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
            vop104);
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
            vop112);
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
            vop120);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
            vop32);
        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
            vop40);
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
            vop48);
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
            vop56);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
            vop16);
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
            vop24);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
            vop0);
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtph_ps(
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
            vop8);
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ", dataInd, " is out of bounds: ", idx,
            ", range 0 to ", data_size);
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float16* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtph_ps(_mm_loadu_si128(
                      reinterpret_cast<const __m128i*>(&ip[j]))),
                  _mm256_loadu_ps(&op[j])));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        // Scalar tail: convert one half at a time through an aligned
        // buffer (see the int32_t variant above for the rationale).
        float16 vtmp1[8] CAFFE2_ALIGNED(64);
        for (; j < block_size; j++) {
          vtmp1[0] = ip[j];
          __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
          op[j] += wgt * ((float*)(&vtmp2))[0];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
void EmbeddingLookup_int64_t_float16_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_float16_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
void EmbeddingLookup_int64_t_float16_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float16* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_float16_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int32_t prefdist_T0 = 16;
  const int32_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
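        // Each 8-lane step: _mm_loadl_epi64 reads 8 bytes,
        // _mm256_cvtepu8_epi32 zero-extends them to 8 x int32,
        // _mm256_cvtepi32_ps converts to 8 floats, and one FMA accumulates
        // them with the bias pre-added. The prefetches pull the row needed
        // prefdist_T0 = 16 indices from now toward L1 while the current
        // FMAs execute; a 64-byte cache line covers 64 uint8_t elements, so
        // only offsets 0 and 64 are touched.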
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
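
// Illustrative sketch, not part of the generated kernels: a scalar
// reference of what the uint8_t variants compute, with bounds checks and
// software prefetching omitted. The function name and shape are
// hypothetical; it only makes the contract explicit -- for each output
// segment, sum the (optionally weighted) dequantized rows, then optionally
// divide by the segment length.
template <typename IndexType, bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_uint8_reference_sketch(
    TIndex block_size,
    TIndex output_size,
    const uint8_t* input,
    const IndexType* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  TIndex dataInd = 0;
  for (TIndex r = 0; r < output_size; ++r) {
    float* op = &out[r * block_size];
    for (TIndex j = 0; j < block_size; ++j) {
      op[j] = 0.0f; // start each segment from zero
    }
    for (TIndex start = dataInd; dataInd < start + lengths[r]; ++dataInd) {
      const IndexType idx = indices[dataInd];
      const float wgt = weights
          ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
          : 1.f;
      const float bio = wgt * scale_bias[2 * idx + 1]; // folded bias
      const float w = wgt * scale_bias[2 * idx]; // folded scale
      const uint8_t* ip = &input[idx * block_size];
      for (TIndex j = 0; j < block_size; ++j) {
        op[j] += w * static_cast<float>(ip[j]) + bio;
      }
    }
    if (normalize_by_lengths && lengths[r]) {
      const float len_inv = 1.0f / lengths[r]; // mean pooling
      for (TIndex j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
}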

void EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

void EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
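
// As with the float16 pair above, these thin wrappers pin the
// IS_WEIGHT_POSITIONAL template parameter at compile time, so the choice of
// positional vs. per-index weights never costs a branch in the inner loop.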
template <bool IS_WEIGHT_POSITIONAL>
static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
  if (block_size == 128) {
    // unrolling 16 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
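
// Note: every normalization branch guards on lengths[rangeIndex], so an
// empty segment never causes a divide-by-zero when mean pooling is
// requested.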

void EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

void EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
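
// Hypothetical call site, for orientation only (all names and shapes here
// are invented, not part of this file):
//
//   // 10-row quantized table of 4-dim embeddings; segments of length 2, 1.
//   // uint8_t table[10 * 4]; float scale_bias[10 * 2];
//   // int64_t indices[3] = {0, 5, 2}; int lengths[2] = {2, 1};
//   // float out[2 * 4];
//   EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
//       /*block_size=*/4, /*output_size=*/2, /*index_size=*/3,
//       /*data_size=*/10, table, indices, lengths,
//       /*weights=*/nullptr, scale_bias,
//       /*normalize_by_lengths=*/false, out);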