Caffe2 - C++ API
A deep learning, cross platform ML framework
typed_axpy_avx2.cc
#include "caffe2/core/types.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include "caffe2/perfkernels/typed_axpy.h"
#include "caffe2/utils/math.h"

#include <emmintrin.h>
#include <immintrin.h>

#include <cstdint>
8 
9 namespace caffe2 {
10 
11 void TypedAxpy_float16_float__avx2_fma(
12  int N,
13  const float a,
14  const float16* x,
15  float* y) {
16  // if x does not start at the 16 byte boundary, we will process the first few.
17  // before we get to a real one.
18  while (((unsigned long)x % 16) && N) {
19  *(y++) += _cvtsh_ss((*(x++)).x) * a;
20  --N;
21  }
22 
23  // From now on we can do vectorized additions using __m256, which is 8 floats,
24  // so we will vectorize every 8 element and then resort to cvtsh_ss.
25  __m256 mma = _mm256_set1_ps(a);
26  int current = 0;
27  const int bound = (N % 8) ? N - 8 : N;
28 
29  for (; current < bound; current += 8) {
30  __m128i mmx_16 =
31  _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current));
32  __m256 mmx_32 = _mm256_cvtph_ps(mmx_16);
33  __m256 mmy = _mm256_loadu_ps(y + current);
34  mmy = _mm256_fmadd_ps(mmx_32, mma, mmy);
35  _mm256_storeu_ps(y + current, mmy);
36  }
37 
38  if (bound != N) {
39  while (current < N) {
40  y[current] += _cvtsh_ss(x[current].x) * a;
41  ++current;
42  }
43  }
44 }
45 
46 void TypedAxpy_uint8_float__avx2_fma(
47  int N,
48  const float a,
49  const std::uint8_t* x,
50  float* y) {
51  // if x does not start at the 16 byte boundary, we will process the first few.
52  // before we get to a real one.
53  while (((unsigned long)x % 16) && N) {
54  *(y++) += (float)(*(x++)) * a;
55  --N;
56  }
57 
58  // From now on we can do vectorized additions using __m256, which is 8 floats,
59  // so we will vectorize every 8 element and then resort to cvtsh_ss.
60  __m256 mma = _mm256_set1_ps(a);
61  int current = 0;
62  const int bound = (N % 8) ? N - 8 : N;
63 
64  for (; current < bound; current += 8) {
65  __m256i mmx_int32 = _mm256_cvtepi8_epi32(
66  _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current)));
67  __m256 mmx_fp32 = _mm256_cvtepi32_ps(mmx_int32);
68 
69  __m256 mmy = _mm256_loadu_ps(y + current);
70  mmy = _mm256_fmadd_ps(mmx_fp32, mma, mmy);
71  _mm256_storeu_ps(y + current, mmy);
72  }
73 
74  if (bound != N) {
75  while (current < N) {
76  y[current] += (float)(x[current]) * a;
77  ++current;
78  }
79  }
80 }
81 
82 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...