8 #include <caffe2/core/common.h> 9 #include <caffe2/core/types.h> 10 #include <immintrin.h> 14 template <
bool IS_WEIGHT_POSITIONAL>
15 static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
16 const TIndex block_size,
17 const TIndex output_size,
18 const TIndex index_size,
19 const TIndex data_size,
21 const int32_t* indices,
24 bool normalize_by_lengths,
26 const int32_t prefdist_T0 = 16;
27 const int32_t fused_block_size = block_size + 2;
28 if (block_size == 128) {
31 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
32 float* op = &out[rangeIndex * block_size];
33 __m256 vop0 = _mm256_setzero_ps();
34 __m256 vop8 = _mm256_setzero_ps();
35 __m256 vop16 = _mm256_setzero_ps();
36 __m256 vop24 = _mm256_setzero_ps();
37 __m256 vop32 = _mm256_setzero_ps();
38 __m256 vop40 = _mm256_setzero_ps();
39 __m256 vop48 = _mm256_setzero_ps();
40 __m256 vop56 = _mm256_setzero_ps();
41 __m256 vop64 = _mm256_setzero_ps();
42 __m256 vop72 = _mm256_setzero_ps();
43 __m256 vop80 = _mm256_setzero_ps();
44 __m256 vop88 = _mm256_setzero_ps();
45 __m256 vop96 = _mm256_setzero_ps();
46 __m256 vop104 = _mm256_setzero_ps();
47 __m256 vop112 = _mm256_setzero_ps();
48 __m256 vop120 = _mm256_setzero_ps();
49 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
51 const int32_t idx = indices[dataInd];
53 idx >= 0 && idx < data_size,
56 " is out of bounds: ",
62 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
64 __m256 vwgt = _mm256_set1_ps(wgt);
65 const float* ip = &input[idx * fused_block_size];
66 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
67 ? (dataInd + prefdist_T0)
69 const int32_t idx_pref_T0 = indices[next_T0];
70 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
71 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
72 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
73 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
74 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
76 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
77 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
78 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
80 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
81 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
82 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
84 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
85 _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
86 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
88 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
89 _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
90 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
92 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
93 _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
94 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
96 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
97 _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
98 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
100 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
101 _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
102 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
105 if (normalize_by_lengths ==
false) {
106 _mm256_storeu_ps(&op[0], vop0);
107 _mm256_storeu_ps(&op[8], vop8);
108 _mm256_storeu_ps(&op[16], vop16);
109 _mm256_storeu_ps(&op[24], vop24);
110 _mm256_storeu_ps(&op[32], vop32);
111 _mm256_storeu_ps(&op[40], vop40);
112 _mm256_storeu_ps(&op[48], vop48);
113 _mm256_storeu_ps(&op[56], vop56);
114 _mm256_storeu_ps(&op[64], vop64);
115 _mm256_storeu_ps(&op[72], vop72);
116 _mm256_storeu_ps(&op[80], vop80);
117 _mm256_storeu_ps(&op[88], vop88);
118 _mm256_storeu_ps(&op[96], vop96);
119 _mm256_storeu_ps(&op[104], vop104);
120 _mm256_storeu_ps(&op[112], vop112);
121 _mm256_storeu_ps(&op[120], vop120);
122 }
else if (lengths[rangeIndex]) {
123 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
124 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
125 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
126 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
127 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
128 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
129 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
130 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
131 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
132 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
133 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
134 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
135 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
136 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
137 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
138 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
139 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
142 }
else if (block_size == 64) {
145 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
146 float* op = &out[rangeIndex * block_size];
147 __m256 vop0 = _mm256_setzero_ps();
148 __m256 vop8 = _mm256_setzero_ps();
149 __m256 vop16 = _mm256_setzero_ps();
150 __m256 vop24 = _mm256_setzero_ps();
151 __m256 vop32 = _mm256_setzero_ps();
152 __m256 vop40 = _mm256_setzero_ps();
153 __m256 vop48 = _mm256_setzero_ps();
154 __m256 vop56 = _mm256_setzero_ps();
155 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
157 const int32_t idx = indices[dataInd];
159 idx >= 0 && idx < data_size,
162 " is out of bounds: ",
168 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
170 __m256 vwgt = _mm256_set1_ps(wgt);
171 const float* ip = &input[idx * fused_block_size];
172 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
173 ? (dataInd + prefdist_T0)
175 const int32_t idx_pref_T0 = indices[next_T0];
176 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
177 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
178 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
179 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
180 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
182 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
183 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
184 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
186 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
187 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
188 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
190 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
191 _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
192 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
195 if (normalize_by_lengths ==
false) {
196 _mm256_storeu_ps(&op[0], vop0);
197 _mm256_storeu_ps(&op[8], vop8);
198 _mm256_storeu_ps(&op[16], vop16);
199 _mm256_storeu_ps(&op[24], vop24);
200 _mm256_storeu_ps(&op[32], vop32);
201 _mm256_storeu_ps(&op[40], vop40);
202 _mm256_storeu_ps(&op[48], vop48);
203 _mm256_storeu_ps(&op[56], vop56);
204 }
else if (lengths[rangeIndex]) {
205 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
206 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
207 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
208 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
209 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
210 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
211 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
212 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
213 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
216 }
else if (block_size == 32) {
219 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
220 float* op = &out[rangeIndex * block_size];
221 __m256 vop0 = _mm256_setzero_ps();
222 __m256 vop8 = _mm256_setzero_ps();
223 __m256 vop16 = _mm256_setzero_ps();
224 __m256 vop24 = _mm256_setzero_ps();
225 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
227 const int32_t idx = indices[dataInd];
229 idx >= 0 && idx < data_size,
232 " is out of bounds: ",
238 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
240 __m256 vwgt = _mm256_set1_ps(wgt);
241 const float* ip = &input[idx * fused_block_size];
242 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
243 ? (dataInd + prefdist_T0)
245 const int32_t idx_pref_T0 = indices[next_T0];
246 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
247 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
248 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
249 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
250 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
252 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
253 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
254 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
257 if (normalize_by_lengths ==
false) {
258 _mm256_storeu_ps(&op[0], vop0);
259 _mm256_storeu_ps(&op[8], vop8);
260 _mm256_storeu_ps(&op[16], vop16);
261 _mm256_storeu_ps(&op[24], vop24);
262 }
else if (lengths[rangeIndex]) {
263 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
264 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
265 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
266 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
267 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
270 }
else if (block_size == 16) {
273 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
274 float* op = &out[rangeIndex * block_size];
275 __m256 vop0 = _mm256_setzero_ps();
276 __m256 vop8 = _mm256_setzero_ps();
277 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
279 const int32_t idx = indices[dataInd];
281 idx >= 0 && idx < data_size,
284 " is out of bounds: ",
290 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
292 __m256 vwgt = _mm256_set1_ps(wgt);
293 const float* ip = &input[idx * fused_block_size];
294 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
295 ? (dataInd + prefdist_T0)
297 const int32_t idx_pref_T0 = indices[next_T0];
298 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
299 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
300 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
301 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
302 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
305 if (normalize_by_lengths ==
false) {
306 _mm256_storeu_ps(&op[0], vop0);
307 _mm256_storeu_ps(&op[8], vop8);
308 }
else if (lengths[rangeIndex]) {
309 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
310 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
311 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
317 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
318 float* op = &out[rangeIndex * block_size];
320 for (; j + 8 <= block_size; j += 8) {
321 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
323 for (; j < block_size; j++) {
326 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
328 const int32_t idx = indices[dataInd];
330 idx >= 0 && idx < data_size,
333 " is out of bounds: ",
339 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
341 __m256 vwgt = _mm256_set1_ps(wgt);
342 const float* ip = &input[idx * fused_block_size];
343 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
344 ? (dataInd + prefdist_T0)
346 const int32_t idx_pref_T0 = indices[next_T0];
347 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
348 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
350 for (; j + 8 <= block_size; j += 8) {
354 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
355 _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
357 for (; j < block_size; j++) {
358 op[j] += wgt * ip[j];
361 if (normalize_by_lengths && lengths[rangeIndex]) {
362 float len_inv = 1.0f / lengths[rangeIndex];
363 __m256 vlen_inv = _mm256_set1_ps(len_inv);
365 for (; j + 8 <= block_size; j += 8) {
367 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
369 for (; j < block_size; j++) {
370 op[j] = len_inv * op[j];
376 void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
377 const TIndex block_size,
378 const TIndex output_size,
379 const TIndex index_size,
380 const TIndex data_size,
382 const int32_t* indices,
384 const float* weights,
385 bool normalize_by_lengths,
387 Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<false>(
396 normalize_by_lengths,
399 void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma(
400 const TIndex block_size,
401 const TIndex output_size,
402 const TIndex index_size,
403 const TIndex data_size,
405 const int32_t* indices,
407 const float* weights,
408 bool normalize_by_lengths,
410 Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<true>(
419 normalize_by_lengths,
423 template <
bool IS_WEIGHT_POSITIONAL>
424 static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
425 const TIndex block_size,
426 const TIndex output_size,
427 const TIndex index_size,
428 const TIndex data_size,
430 const int64_t* indices,
432 const float* weights,
433 bool normalize_by_lengths,
435 const int64_t prefdist_T0 = 16;
436 const int64_t fused_block_size = block_size + 2;
437 if (block_size == 128) {
440 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
441 float* op = &out[rangeIndex * block_size];
442 __m256 vop0 = _mm256_setzero_ps();
443 __m256 vop8 = _mm256_setzero_ps();
444 __m256 vop16 = _mm256_setzero_ps();
445 __m256 vop24 = _mm256_setzero_ps();
446 __m256 vop32 = _mm256_setzero_ps();
447 __m256 vop40 = _mm256_setzero_ps();
448 __m256 vop48 = _mm256_setzero_ps();
449 __m256 vop56 = _mm256_setzero_ps();
450 __m256 vop64 = _mm256_setzero_ps();
451 __m256 vop72 = _mm256_setzero_ps();
452 __m256 vop80 = _mm256_setzero_ps();
453 __m256 vop88 = _mm256_setzero_ps();
454 __m256 vop96 = _mm256_setzero_ps();
455 __m256 vop104 = _mm256_setzero_ps();
456 __m256 vop112 = _mm256_setzero_ps();
457 __m256 vop120 = _mm256_setzero_ps();
458 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
460 const int64_t idx = indices[dataInd];
462 idx >= 0 && idx < data_size,
465 " is out of bounds: ",
471 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
473 __m256 vwgt = _mm256_set1_ps(wgt);
474 const float* ip = &input[idx * fused_block_size];
475 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
476 ? (dataInd + prefdist_T0)
478 const int64_t idx_pref_T0 = indices[next_T0];
479 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
480 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
481 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
482 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
483 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
485 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
486 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
487 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
489 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
490 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
491 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
493 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
494 _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
495 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
497 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
498 _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
499 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
501 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
502 _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
503 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
505 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
506 _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
507 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
509 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
510 _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
511 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
514 if (normalize_by_lengths ==
false) {
515 _mm256_storeu_ps(&op[0], vop0);
516 _mm256_storeu_ps(&op[8], vop8);
517 _mm256_storeu_ps(&op[16], vop16);
518 _mm256_storeu_ps(&op[24], vop24);
519 _mm256_storeu_ps(&op[32], vop32);
520 _mm256_storeu_ps(&op[40], vop40);
521 _mm256_storeu_ps(&op[48], vop48);
522 _mm256_storeu_ps(&op[56], vop56);
523 _mm256_storeu_ps(&op[64], vop64);
524 _mm256_storeu_ps(&op[72], vop72);
525 _mm256_storeu_ps(&op[80], vop80);
526 _mm256_storeu_ps(&op[88], vop88);
527 _mm256_storeu_ps(&op[96], vop96);
528 _mm256_storeu_ps(&op[104], vop104);
529 _mm256_storeu_ps(&op[112], vop112);
530 _mm256_storeu_ps(&op[120], vop120);
531 }
else if (lengths[rangeIndex]) {
532 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
533 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
534 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
535 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
536 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
537 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
538 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
539 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
540 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
541 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
542 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
543 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
544 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
545 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
546 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
547 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
548 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
551 }
else if (block_size == 64) {
554 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
555 float* op = &out[rangeIndex * block_size];
556 __m256 vop0 = _mm256_setzero_ps();
557 __m256 vop8 = _mm256_setzero_ps();
558 __m256 vop16 = _mm256_setzero_ps();
559 __m256 vop24 = _mm256_setzero_ps();
560 __m256 vop32 = _mm256_setzero_ps();
561 __m256 vop40 = _mm256_setzero_ps();
562 __m256 vop48 = _mm256_setzero_ps();
563 __m256 vop56 = _mm256_setzero_ps();
564 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
566 const int64_t idx = indices[dataInd];
568 idx >= 0 && idx < data_size,
571 " is out of bounds: ",
577 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
579 __m256 vwgt = _mm256_set1_ps(wgt);
580 const float* ip = &input[idx * fused_block_size];
581 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
582 ? (dataInd + prefdist_T0)
584 const int64_t idx_pref_T0 = indices[next_T0];
585 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
586 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
587 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
588 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
589 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
591 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
592 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
593 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
595 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
596 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
597 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
599 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
600 _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
601 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
604 if (normalize_by_lengths ==
false) {
605 _mm256_storeu_ps(&op[0], vop0);
606 _mm256_storeu_ps(&op[8], vop8);
607 _mm256_storeu_ps(&op[16], vop16);
608 _mm256_storeu_ps(&op[24], vop24);
609 _mm256_storeu_ps(&op[32], vop32);
610 _mm256_storeu_ps(&op[40], vop40);
611 _mm256_storeu_ps(&op[48], vop48);
612 _mm256_storeu_ps(&op[56], vop56);
613 }
else if (lengths[rangeIndex]) {
614 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
615 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
616 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
617 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
618 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
619 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
620 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
621 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
622 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
625 }
else if (block_size == 32) {
628 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
629 float* op = &out[rangeIndex * block_size];
630 __m256 vop0 = _mm256_setzero_ps();
631 __m256 vop8 = _mm256_setzero_ps();
632 __m256 vop16 = _mm256_setzero_ps();
633 __m256 vop24 = _mm256_setzero_ps();
634 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
636 const int64_t idx = indices[dataInd];
638 idx >= 0 && idx < data_size,
641 " is out of bounds: ",
647 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
649 __m256 vwgt = _mm256_set1_ps(wgt);
650 const float* ip = &input[idx * fused_block_size];
651 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
652 ? (dataInd + prefdist_T0)
654 const int64_t idx_pref_T0 = indices[next_T0];
655 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
656 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
657 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
658 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
659 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
661 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
662 _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
663 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
666 if (normalize_by_lengths ==
false) {
667 _mm256_storeu_ps(&op[0], vop0);
668 _mm256_storeu_ps(&op[8], vop8);
669 _mm256_storeu_ps(&op[16], vop16);
670 _mm256_storeu_ps(&op[24], vop24);
671 }
else if (lengths[rangeIndex]) {
672 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
673 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
674 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
675 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
676 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
679 }
else if (block_size == 16) {
682 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
683 float* op = &out[rangeIndex * block_size];
684 __m256 vop0 = _mm256_setzero_ps();
685 __m256 vop8 = _mm256_setzero_ps();
686 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
688 const int64_t idx = indices[dataInd];
690 idx >= 0 && idx < data_size,
693 " is out of bounds: ",
699 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
701 __m256 vwgt = _mm256_set1_ps(wgt);
702 const float* ip = &input[idx * fused_block_size];
703 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
704 ? (dataInd + prefdist_T0)
706 const int64_t idx_pref_T0 = indices[next_T0];
707 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
708 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
709 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
710 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
711 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
714 if (normalize_by_lengths ==
false) {
715 _mm256_storeu_ps(&op[0], vop0);
716 _mm256_storeu_ps(&op[8], vop8);
717 }
else if (lengths[rangeIndex]) {
718 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
719 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
720 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
726 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
727 float* op = &out[rangeIndex * block_size];
729 for (; j + 8 <= block_size; j += 8) {
730 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
732 for (; j < block_size; j++) {
735 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
737 const int64_t idx = indices[dataInd];
739 idx >= 0 && idx < data_size,
742 " is out of bounds: ",
748 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
750 __m256 vwgt = _mm256_set1_ps(wgt);
751 const float* ip = &input[idx * fused_block_size];
752 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
753 ? (dataInd + prefdist_T0)
755 const int64_t idx_pref_T0 = indices[next_T0];
756 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
757 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
759 for (; j + 8 <= block_size; j += 8) {
763 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
764 _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
766 for (; j < block_size; j++) {
767 op[j] += wgt * ip[j];
770 if (normalize_by_lengths && lengths[rangeIndex]) {
771 float len_inv = 1.0f / lengths[rangeIndex];
772 __m256 vlen_inv = _mm256_set1_ps(len_inv);
774 for (; j + 8 <= block_size; j += 8) {
776 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
778 for (; j < block_size; j++) {
779 op[j] = len_inv * op[j];
785 void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma(
786 const TIndex block_size,
787 const TIndex output_size,
788 const TIndex index_size,
789 const TIndex data_size,
791 const int64_t* indices,
793 const float* weights,
794 bool normalize_by_lengths,
796 Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<false>(
805 normalize_by_lengths,
808 void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma(
809 const TIndex block_size,
810 const TIndex output_size,
811 const TIndex index_size,
812 const TIndex data_size,
814 const int64_t* indices,
816 const float* weights,
817 bool normalize_by_lengths,
819 Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<true>(
828 normalize_by_lengths,
832 template <
bool IS_WEIGHT_POSITIONAL>
833 static void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma(
834 const TIndex block_size,
835 const TIndex output_size,
836 const TIndex index_size,
837 const TIndex data_size,
838 const float16* input,
839 const int32_t* indices,
841 const float* weights,
842 bool normalize_by_lengths,
844 const int32_t prefdist_T0 = 16;
845 const int32_t fused_block_size = block_size + 4;
846 if (block_size == 128) {
849 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
850 float* op = &out[rangeIndex * block_size];
851 __m256 vop0 = _mm256_setzero_ps();
852 __m256 vop8 = _mm256_setzero_ps();
853 __m256 vop16 = _mm256_setzero_ps();
854 __m256 vop24 = _mm256_setzero_ps();
855 __m256 vop32 = _mm256_setzero_ps();
856 __m256 vop40 = _mm256_setzero_ps();
857 __m256 vop48 = _mm256_setzero_ps();
858 __m256 vop56 = _mm256_setzero_ps();
859 __m256 vop64 = _mm256_setzero_ps();
860 __m256 vop72 = _mm256_setzero_ps();
861 __m256 vop80 = _mm256_setzero_ps();
862 __m256 vop88 = _mm256_setzero_ps();
863 __m256 vop96 = _mm256_setzero_ps();
864 __m256 vop104 = _mm256_setzero_ps();
865 __m256 vop112 = _mm256_setzero_ps();
866 __m256 vop120 = _mm256_setzero_ps();
867 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
869 const int32_t idx = indices[dataInd];
871 idx >= 0 && idx < data_size,
874 " is out of bounds: ",
880 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
882 __m256 vwgt = _mm256_set1_ps(wgt);
883 const float16* ip = &input[idx * fused_block_size];
884 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
885 ? (dataInd + prefdist_T0)
887 const int32_t idx_pref_T0 = indices[next_T0];
888 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
889 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
890 vop0 = _mm256_fmadd_ps(
893 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
895 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
896 vop8 = _mm256_fmadd_ps(
899 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
902 vop16 = _mm256_fmadd_ps(
905 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
908 vop24 = _mm256_fmadd_ps(
911 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
914 vop32 = _mm256_fmadd_ps(
917 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
919 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
920 vop40 = _mm256_fmadd_ps(
923 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
926 vop48 = _mm256_fmadd_ps(
929 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
932 vop56 = _mm256_fmadd_ps(
935 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
938 vop64 = _mm256_fmadd_ps(
941 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
943 _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
944 vop72 = _mm256_fmadd_ps(
947 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
950 vop80 = _mm256_fmadd_ps(
953 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
956 vop88 = _mm256_fmadd_ps(
959 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
962 vop96 = _mm256_fmadd_ps(
965 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
967 _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
968 vop104 = _mm256_fmadd_ps(
971 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
974 vop112 = _mm256_fmadd_ps(
977 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
980 vop120 = _mm256_fmadd_ps(
983 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
987 if (normalize_by_lengths ==
false) {
988 _mm256_storeu_ps(&op[0], vop0);
989 _mm256_storeu_ps(&op[8], vop8);
990 _mm256_storeu_ps(&op[16], vop16);
991 _mm256_storeu_ps(&op[24], vop24);
992 _mm256_storeu_ps(&op[32], vop32);
993 _mm256_storeu_ps(&op[40], vop40);
994 _mm256_storeu_ps(&op[48], vop48);
995 _mm256_storeu_ps(&op[56], vop56);
996 _mm256_storeu_ps(&op[64], vop64);
997 _mm256_storeu_ps(&op[72], vop72);
998 _mm256_storeu_ps(&op[80], vop80);
999 _mm256_storeu_ps(&op[88], vop88);
1000 _mm256_storeu_ps(&op[96], vop96);
1001 _mm256_storeu_ps(&op[104], vop104);
1002 _mm256_storeu_ps(&op[112], vop112);
1003 _mm256_storeu_ps(&op[120], vop120);
1004 }
else if (lengths[rangeIndex]) {
1005 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1006 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1007 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1008 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1009 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1010 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1011 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1012 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1013 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1014 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1015 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1016 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1017 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1018 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1019 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1020 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1021 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1024 }
else if (block_size == 64) {
1026 int32_t dataInd = 0;
1027 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1028 float* op = &out[rangeIndex * block_size];
1029 __m256 vop0 = _mm256_setzero_ps();
1030 __m256 vop8 = _mm256_setzero_ps();
1031 __m256 vop16 = _mm256_setzero_ps();
1032 __m256 vop24 = _mm256_setzero_ps();
1033 __m256 vop32 = _mm256_setzero_ps();
1034 __m256 vop40 = _mm256_setzero_ps();
1035 __m256 vop48 = _mm256_setzero_ps();
1036 __m256 vop56 = _mm256_setzero_ps();
1037 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1039 const int32_t idx = indices[dataInd];
1041 idx >= 0 && idx < data_size,
1044 " is out of bounds: ",
1050 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1052 __m256 vwgt = _mm256_set1_ps(wgt);
1053 const float16* ip = &input[idx * fused_block_size];
1054 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1055 ? (dataInd + prefdist_T0)
1057 const int32_t idx_pref_T0 = indices[next_T0];
1058 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1059 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1060 vop0 = _mm256_fmadd_ps(
1063 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1065 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1066 vop8 = _mm256_fmadd_ps(
1069 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1072 vop16 = _mm256_fmadd_ps(
1075 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1078 vop24 = _mm256_fmadd_ps(
1081 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1084 vop32 = _mm256_fmadd_ps(
1087 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1089 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1090 vop40 = _mm256_fmadd_ps(
1093 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1096 vop48 = _mm256_fmadd_ps(
1099 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1102 vop56 = _mm256_fmadd_ps(
1105 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1109 if (normalize_by_lengths ==
false) {
1110 _mm256_storeu_ps(&op[0], vop0);
1111 _mm256_storeu_ps(&op[8], vop8);
1112 _mm256_storeu_ps(&op[16], vop16);
1113 _mm256_storeu_ps(&op[24], vop24);
1114 _mm256_storeu_ps(&op[32], vop32);
1115 _mm256_storeu_ps(&op[40], vop40);
1116 _mm256_storeu_ps(&op[48], vop48);
1117 _mm256_storeu_ps(&op[56], vop56);
1118 }
else if (lengths[rangeIndex]) {
1119 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1120 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1121 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1122 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1123 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1124 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1125 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1126 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1127 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1130 }
else if (block_size == 32) {
1132 int32_t dataInd = 0;
1133 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1134 float* op = &out[rangeIndex * block_size];
1135 __m256 vop0 = _mm256_setzero_ps();
1136 __m256 vop8 = _mm256_setzero_ps();
1137 __m256 vop16 = _mm256_setzero_ps();
1138 __m256 vop24 = _mm256_setzero_ps();
1139 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1141 const int32_t idx = indices[dataInd];
1143 idx >= 0 && idx < data_size,
1146 " is out of bounds: ",
1152 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1154 __m256 vwgt = _mm256_set1_ps(wgt);
1155 const float16* ip = &input[idx * fused_block_size];
1156 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1157 ? (dataInd + prefdist_T0)
1159 const int32_t idx_pref_T0 = indices[next_T0];
1160 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1161 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1162 vop0 = _mm256_fmadd_ps(
1165 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1167 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1168 vop8 = _mm256_fmadd_ps(
1171 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1174 vop16 = _mm256_fmadd_ps(
1177 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1180 vop24 = _mm256_fmadd_ps(
1183 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1187 if (normalize_by_lengths ==
false) {
1188 _mm256_storeu_ps(&op[0], vop0);
1189 _mm256_storeu_ps(&op[8], vop8);
1190 _mm256_storeu_ps(&op[16], vop16);
1191 _mm256_storeu_ps(&op[24], vop24);
1192 }
else if (lengths[rangeIndex]) {
1193 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1194 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1195 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1196 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1197 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1200 }
else if (block_size == 16) {
1202 int32_t dataInd = 0;
1203 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1204 float* op = &out[rangeIndex * block_size];
1205 __m256 vop0 = _mm256_setzero_ps();
1206 __m256 vop8 = _mm256_setzero_ps();
1207 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1209 const int32_t idx = indices[dataInd];
1211 idx >= 0 && idx < data_size,
1214 " is out of bounds: ",
1220 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1222 __m256 vwgt = _mm256_set1_ps(wgt);
1223 const float16* ip = &input[idx * fused_block_size];
1224 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1225 ? (dataInd + prefdist_T0)
1227 const int32_t idx_pref_T0 = indices[next_T0];
1228 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1229 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1230 vop0 = _mm256_fmadd_ps(
1233 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1235 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1236 vop8 = _mm256_fmadd_ps(
1239 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1243 if (normalize_by_lengths ==
false) {
1244 _mm256_storeu_ps(&op[0], vop0);
1245 _mm256_storeu_ps(&op[8], vop8);
1246 }
else if (lengths[rangeIndex]) {
1247 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1248 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1249 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1254 int32_t dataInd = 0;
1255 for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1256 float* op = &out[rangeIndex * block_size];
1258 for (; j + 8 <= block_size; j += 8) {
1259 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1261 for (; j < block_size; j++) {
1264 for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1266 const int32_t idx = indices[dataInd];
1268 idx >= 0 && idx < data_size,
1271 " is out of bounds: ",
1277 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1279 __m256 vwgt = _mm256_set1_ps(wgt);
1280 const float16* ip = &input[idx * fused_block_size];
1281 const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1282 ? (dataInd + prefdist_T0)
1284 const int32_t idx_pref_T0 = indices[next_T0];
1285 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1286 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1288 for (; j + 8 <= block_size; j += 8) {
1293 _mm256_cvtph_ps(_mm_loadu_si128(
1294 reinterpret_cast<const __m128i*>(&ip[j]))),
1295 _mm256_loadu_ps(&op[j])));
1296 _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1298 float16 vtmp1[8] CAFFE2_ALIGNED(64);
1299 for (; j < block_size; j++) {
1301 __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1302 op[j] += wgt * ((
float*)(&vtmp2))[0];
1305 if (normalize_by_lengths && lengths[rangeIndex]) {
1306 float len_inv = 1.0f / lengths[rangeIndex];
1307 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1309 for (; j + 8 <= block_size; j += 8) {
1311 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1313 for (; j < block_size; j++) {
1314 op[j] = len_inv * op[j];
1320 void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float_false__avx2_fma(
1321 const TIndex block_size,
1322 const TIndex output_size,
1323 const TIndex index_size,
1324 const TIndex data_size,
1325 const float16* input,
1326 const int32_t* indices,
1328 const float* weights,
1329 bool normalize_by_lengths,
1331 Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma<false>(
1340 normalize_by_lengths,
1343 void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float_true__avx2_fma(
1344 const TIndex block_size,
1345 const TIndex output_size,
1346 const TIndex index_size,
1347 const TIndex data_size,
1348 const float16* input,
1349 const int32_t* indices,
1351 const float* weights,
1352 bool normalize_by_lengths,
1354 Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma<true>(
1363 normalize_by_lengths,
1367 template <
bool IS_WEIGHT_POSITIONAL>
1368 static void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma(
1369 const TIndex block_size,
1370 const TIndex output_size,
1371 const TIndex index_size,
1372 const TIndex data_size,
1373 const float16* input,
1374 const int64_t* indices,
1376 const float* weights,
1377 bool normalize_by_lengths,
1379 const int64_t prefdist_T0 = 16;
1380 const int64_t fused_block_size = block_size + 4;
1381 if (block_size == 128) {
1383 int64_t dataInd = 0;
1384 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1385 float* op = &out[rangeIndex * block_size];
1386 __m256 vop0 = _mm256_setzero_ps();
1387 __m256 vop8 = _mm256_setzero_ps();
1388 __m256 vop16 = _mm256_setzero_ps();
1389 __m256 vop24 = _mm256_setzero_ps();
1390 __m256 vop32 = _mm256_setzero_ps();
1391 __m256 vop40 = _mm256_setzero_ps();
1392 __m256 vop48 = _mm256_setzero_ps();
1393 __m256 vop56 = _mm256_setzero_ps();
1394 __m256 vop64 = _mm256_setzero_ps();
1395 __m256 vop72 = _mm256_setzero_ps();
1396 __m256 vop80 = _mm256_setzero_ps();
1397 __m256 vop88 = _mm256_setzero_ps();
1398 __m256 vop96 = _mm256_setzero_ps();
1399 __m256 vop104 = _mm256_setzero_ps();
1400 __m256 vop112 = _mm256_setzero_ps();
1401 __m256 vop120 = _mm256_setzero_ps();
1402 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1404 const int64_t idx = indices[dataInd];
1406 idx >= 0 && idx < data_size,
1409 " is out of bounds: ",
1415 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1417 __m256 vwgt = _mm256_set1_ps(wgt);
1418 const float16* ip = &input[idx * fused_block_size];
1419 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1420 ? (dataInd + prefdist_T0)
1422 const int64_t idx_pref_T0 = indices[next_T0];
1423 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1424 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1425 vop0 = _mm256_fmadd_ps(
1428 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1430 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1431 vop8 = _mm256_fmadd_ps(
1434 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1437 vop16 = _mm256_fmadd_ps(
1440 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1443 vop24 = _mm256_fmadd_ps(
1446 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1449 vop32 = _mm256_fmadd_ps(
1452 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1454 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1455 vop40 = _mm256_fmadd_ps(
1458 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1461 vop48 = _mm256_fmadd_ps(
1464 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1467 vop56 = _mm256_fmadd_ps(
1470 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1473 vop64 = _mm256_fmadd_ps(
1476 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1478 _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1479 vop72 = _mm256_fmadd_ps(
1482 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1485 vop80 = _mm256_fmadd_ps(
1488 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1491 vop88 = _mm256_fmadd_ps(
1494 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1497 vop96 = _mm256_fmadd_ps(
1500 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1502 _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
1503 vop104 = _mm256_fmadd_ps(
1506 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1509 vop112 = _mm256_fmadd_ps(
1512 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1515 vop120 = _mm256_fmadd_ps(
1518 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1522 if (normalize_by_lengths ==
false) {
1523 _mm256_storeu_ps(&op[0], vop0);
1524 _mm256_storeu_ps(&op[8], vop8);
1525 _mm256_storeu_ps(&op[16], vop16);
1526 _mm256_storeu_ps(&op[24], vop24);
1527 _mm256_storeu_ps(&op[32], vop32);
1528 _mm256_storeu_ps(&op[40], vop40);
1529 _mm256_storeu_ps(&op[48], vop48);
1530 _mm256_storeu_ps(&op[56], vop56);
1531 _mm256_storeu_ps(&op[64], vop64);
1532 _mm256_storeu_ps(&op[72], vop72);
1533 _mm256_storeu_ps(&op[80], vop80);
1534 _mm256_storeu_ps(&op[88], vop88);
1535 _mm256_storeu_ps(&op[96], vop96);
1536 _mm256_storeu_ps(&op[104], vop104);
1537 _mm256_storeu_ps(&op[112], vop112);
1538 _mm256_storeu_ps(&op[120], vop120);
1539 }
else if (lengths[rangeIndex]) {
1540 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1541 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1542 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1543 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1544 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1545 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1546 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1547 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1548 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1549 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1550 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1551 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1552 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1553 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1554 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1555 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1556 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1559 }
else if (block_size == 64) {
1561 int64_t dataInd = 0;
1562 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1563 float* op = &out[rangeIndex * block_size];
1564 __m256 vop0 = _mm256_setzero_ps();
1565 __m256 vop8 = _mm256_setzero_ps();
1566 __m256 vop16 = _mm256_setzero_ps();
1567 __m256 vop24 = _mm256_setzero_ps();
1568 __m256 vop32 = _mm256_setzero_ps();
1569 __m256 vop40 = _mm256_setzero_ps();
1570 __m256 vop48 = _mm256_setzero_ps();
1571 __m256 vop56 = _mm256_setzero_ps();
1572 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1574 const int64_t idx = indices[dataInd];
1576 idx >= 0 && idx < data_size,
1579 " is out of bounds: ",
1585 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1587 __m256 vwgt = _mm256_set1_ps(wgt);
1588 const float16* ip = &input[idx * fused_block_size];
1589 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1590 ? (dataInd + prefdist_T0)
1592 const int64_t idx_pref_T0 = indices[next_T0];
1593 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1594 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1595 vop0 = _mm256_fmadd_ps(
1598 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1600 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1601 vop8 = _mm256_fmadd_ps(
1604 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1607 vop16 = _mm256_fmadd_ps(
1610 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1613 vop24 = _mm256_fmadd_ps(
1616 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1619 vop32 = _mm256_fmadd_ps(
1622 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1624 _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1625 vop40 = _mm256_fmadd_ps(
1628 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1631 vop48 = _mm256_fmadd_ps(
1634 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1637 vop56 = _mm256_fmadd_ps(
1640 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1644 if (normalize_by_lengths ==
false) {
1645 _mm256_storeu_ps(&op[0], vop0);
1646 _mm256_storeu_ps(&op[8], vop8);
1647 _mm256_storeu_ps(&op[16], vop16);
1648 _mm256_storeu_ps(&op[24], vop24);
1649 _mm256_storeu_ps(&op[32], vop32);
1650 _mm256_storeu_ps(&op[40], vop40);
1651 _mm256_storeu_ps(&op[48], vop48);
1652 _mm256_storeu_ps(&op[56], vop56);
1653 }
else if (lengths[rangeIndex]) {
1654 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1655 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1656 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1657 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1658 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1659 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1660 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1661 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1662 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1665 }
else if (block_size == 32) {
1667 int64_t dataInd = 0;
1668 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1669 float* op = &out[rangeIndex * block_size];
1670 __m256 vop0 = _mm256_setzero_ps();
1671 __m256 vop8 = _mm256_setzero_ps();
1672 __m256 vop16 = _mm256_setzero_ps();
1673 __m256 vop24 = _mm256_setzero_ps();
1674 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1676 const int64_t idx = indices[dataInd];
1678 idx >= 0 && idx < data_size,
1681 " is out of bounds: ",
1687 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1689 __m256 vwgt = _mm256_set1_ps(wgt);
1690 const float16* ip = &input[idx * fused_block_size];
1691 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1692 ? (dataInd + prefdist_T0)
1694 const int64_t idx_pref_T0 = indices[next_T0];
1695 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1696 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1697 vop0 = _mm256_fmadd_ps(
1700 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1702 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1703 vop8 = _mm256_fmadd_ps(
1706 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1709 vop16 = _mm256_fmadd_ps(
1712 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1715 vop24 = _mm256_fmadd_ps(
1718 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1722 if (normalize_by_lengths ==
false) {
1723 _mm256_storeu_ps(&op[0], vop0);
1724 _mm256_storeu_ps(&op[8], vop8);
1725 _mm256_storeu_ps(&op[16], vop16);
1726 _mm256_storeu_ps(&op[24], vop24);
1727 }
else if (lengths[rangeIndex]) {
1728 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1729 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1730 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1731 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1732 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1735 }
else if (block_size == 16) {
1737 int64_t dataInd = 0;
1738 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1739 float* op = &out[rangeIndex * block_size];
1740 __m256 vop0 = _mm256_setzero_ps();
1741 __m256 vop8 = _mm256_setzero_ps();
1742 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1744 const int64_t idx = indices[dataInd];
1746 idx >= 0 && idx < data_size,
1749 " is out of bounds: ",
1755 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1757 __m256 vwgt = _mm256_set1_ps(wgt);
1758 const float16* ip = &input[idx * fused_block_size];
1759 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1760 ? (dataInd + prefdist_T0)
1762 const int64_t idx_pref_T0 = indices[next_T0];
1763 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1764 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1765 vop0 = _mm256_fmadd_ps(
1768 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1770 _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1771 vop8 = _mm256_fmadd_ps(
1774 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1778 if (normalize_by_lengths ==
false) {
1779 _mm256_storeu_ps(&op[0], vop0);
1780 _mm256_storeu_ps(&op[8], vop8);
1781 }
else if (lengths[rangeIndex]) {
1782 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1783 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1784 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1789 int64_t dataInd = 0;
1790 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1791 float* op = &out[rangeIndex * block_size];
1793 for (; j + 8 <= block_size; j += 8) {
1794 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1796 for (; j < block_size; j++) {
1799 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1801 const int64_t idx = indices[dataInd];
1803 idx >= 0 && idx < data_size,
1806 " is out of bounds: ",
1812 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1814 __m256 vwgt = _mm256_set1_ps(wgt);
1815 const float16* ip = &input[idx * fused_block_size];
1816 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1817 ? (dataInd + prefdist_T0)
1819 const int64_t idx_pref_T0 = indices[next_T0];
1820 CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1821 const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1823 for (; j + 8 <= block_size; j += 8) {
1828 _mm256_cvtph_ps(_mm_loadu_si128(
1829 reinterpret_cast<const __m128i*>(&ip[j]))),
1830 _mm256_loadu_ps(&op[j])));
1831 _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1833 float16 vtmp1[8] CAFFE2_ALIGNED(64);
1834 for (; j < block_size; j++) {
1836 __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1837 op[j] += wgt * ((
float*)(&vtmp2))[0];
1840 if (normalize_by_lengths && lengths[rangeIndex]) {
1841 float len_inv = 1.0f / lengths[rangeIndex];
1842 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1844 for (; j + 8 <= block_size; j += 8) {
1846 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1848 for (; j < block_size; j++) {
1849 op[j] = len_inv * op[j];
1855 void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float_false__avx2_fma(
1856 const TIndex block_size,
1857 const TIndex output_size,
1858 const TIndex index_size,
1859 const TIndex data_size,
1860 const float16* input,
1861 const int64_t* indices,
1863 const float* weights,
1864 bool normalize_by_lengths,
1866 Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma<false>(
1875 normalize_by_lengths,
1878 void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float_true__avx2_fma(
1879 const TIndex block_size,
1880 const TIndex output_size,
1881 const TIndex index_size,
1882 const TIndex data_size,
1883 const float16* input,
1884 const int64_t* indices,
1886 const float* weights,
1887 bool normalize_by_lengths,
1889 Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma<true>(
1898 normalize_by_lengths,
1902 template <
bool IS_WEIGHT_POSITIONAL>
1903 static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1904 const TIndex block_size,
1905 const TIndex output_size,
1906 const TIndex index_size,
1907 const TIndex data_size,
1908 const uint8_t* input,
1909 const int32_t* indices,
1911 const float* weights,
1912 bool normalize_by_lengths,
1914 const int32_t prefdist_T0 = 16;
1915 const int32_t fused_block_size = block_size + 8;
  if (block_size == 128) {
    // unrolling 16 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
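        // Accumulation with the dequantization folded in:
        //   out += wgt * (scale * q + bias)
        //        == (wgt * scale) * q + (wgt * bias),
        // so vwgt holds wgt * scale and vbio holds wgt * bias. The row that
        // will be visited prefdist_T0 = 16 indices from now is software
        // prefetched to hide memory latency.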
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[80])
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[96])
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[112])
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
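      // With normalize_by_lengths the accumulated sum becomes a mean: each
      // lane is scaled by 1 / lengths[rangeIndex]. Empty segments are skipped
      // so there is no division by zero.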
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int32_t dataInd = 0;
    for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int32_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int32_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
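        // Generic block_size: process 8 floats per iteration with AVX2 and
        // fall back to scalar code for the remaining (block_size % 8) tail.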
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}

void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
template <bool IS_WEIGHT_POSITIONAL>
static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 8;
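  // This variant is identical to the int32_t one above except that indices
  // and loop counters are 64-bit.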
  if (block_size == 128) {
    // unrolling 16 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[80])
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[96])
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[112])
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (normalize_by_lengths == false) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else if (lengths[rangeIndex]) {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    int64_t dataInd = 0;
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      TIndex j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        CAFFE_ENFORCE(
            idx >= 0 && idx < data_size,
            "Index ",
            dataInd,
            " is out of bounds: ",
            idx,
            ", range 0 to ",
            data_size);
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
}
void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}

void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
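// Usage sketch (illustrative only, not part of the generated kernel; the
// packing loop below is an assumption about one reasonable rowwise scheme):
//
//   // Pack each row as block_size uint8 codes + float scale + float bias,
//   // e.g. scale = (max - min) / 255, bias = min,
//   //      code[i] = round((x[i] - bias) / scale).
//   std::vector<uint8_t> fused(data_size * (block_size + 8));
//   // ... fill fused row by row as above ...
//
//   std::vector<float> out(output_size * block_size);
//   Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
//       block_size, output_size, index_size, data_size,
//       fused.data(), indices, lengths, /*weights=*/nullptr,
//       /*normalize_by_lengths=*/true, out.data());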