Caffe2 - C++ API
A deep learning, cross-platform ML framework
DataTransfer.cc
1 
2 #include "DataTransfer.h"
3 #include "GLLogging.h"
4 
5 #include "caffe2/core/common.h"
6 
// De-interleaving load of 16 uint16 values (same semantics as vld4_u16):
// result.val[k] receives elements k, k + 4, k + 8 and k + 12.
// The caller guarantees `address` is 16-byte aligned. The compiler is told to
// assume this, so passing a misaligned pointer breaks that assumption.
inline uint16x4x4_t vld4_u16_aligned16(const uint16_t* address) {
  const void* hinted = __builtin_assume_aligned(address, 16);
  return vld4_u16(static_cast<const uint16_t*>(hinted));
}
10 
// Plain load of four consecutive uint16 values (same semantics as vld1_u16).
// The caller guarantees `address` is 8-byte aligned. The compiler is told to
// assume this, so passing a misaligned pointer breaks that assumption.
inline uint16x4_t vld1_u16_aligned8(const uint16_t* address) {
  const void* hinted = __builtin_assume_aligned(address, 8);
  return vld1_u16(static_cast<const uint16_t*>(hinted));
}
14 
// Interleaving store of four 4-lane vectors (same semantics as vst4_u16):
// element 4 * j + k of the destination receives data.val[k] lane j.
// The caller guarantees `address` is 16-byte aligned. The compiler is told to
// assume this, so passing a misaligned pointer breaks that assumption.
inline void vst4_u16_aligned16(uint16_t* address, uint16x4x4_t data) {
  void* hinted = __builtin_assume_aligned(address, 16);
  vst4_u16(static_cast<uint16_t*>(hinted), data);
}
18 
// Plain store of four uint16 lanes (same semantics as vst1_u16).
// The caller guarantees `address` is 8-byte aligned. The compiler is told to
// assume this, so passing a misaligned pointer breaks that assumption.
inline void vst1_u16_aligned8(uint16_t* address, uint16x4_t data) {
  void* hinted = __builtin_assume_aligned(address, 8);
  vst1_u16(static_cast<uint16_t*>(hinted), data);
}
22 
// Reads four consecutive floats from each of the first `input_channels`
// planes and converts them to half precision. Returns the raw 16-bit patterns
// as the {r, g, b, a} lanes consumed by vst4_u16. Lanes for channels beyond
// `input_channels` are zero. Each active plane pointer is advanced by 4;
// inactive ones are left untouched.
template <int input_channels>
static inline uint16x4x4_t loadQuadAsHalf(const float*& input_r,
                                          const float*& input_g,
                                          const float*& input_b,
                                          const float*& input_a) {
  const uint16x4_t zero = vdup_n_u16(0);
  uint16x4x4_t rgba = {{zero, zero, zero, zero}};
  rgba.val[0] = uint16x4_t(vcvt_f16_f32(vld1q_f32(input_r)));
  input_r += 4;
  if (input_channels >= 2) {
    rgba.val[1] = uint16x4_t(vcvt_f16_f32(vld1q_f32(input_g)));
    input_g += 4;
  }
  if (input_channels >= 3) {
    rgba.val[2] = uint16x4_t(vcvt_f16_f32(vld1q_f32(input_b)));
    input_b += 4;
  }
  if (input_channels >= 4) {
    rgba.val[3] = uint16x4_t(vcvt_f16_f32(vld1q_f32(input_a)));
    input_a += 4;
  }
  return rgba;
}

// Packs `input_channels` planar float planes of a width x height slice into an
// interleaved RGBA half-float buffer.
//
// input:  `input_channels` consecutive planes of height * width floats each,
//         with rows stored back to back (no padding).
// output: height rows of row_stride pixels. Each pixel is four uint16 values
//         (half-float bit patterns) in r, g, b, a order. Only the first
//         `width` pixels of each row are written. Channels the input does not
//         have are written as 0.
// Preconditions: row_stride >= width (the per-row padding is
//         row_stride - width). For width >= 4, the first store of every row is
//         an aligned 16-byte store, so `output` must be 16-byte aligned and
//         row_stride must be even. For width < 4, 8-byte alignment suffices.
template <int input_channels>
static void interleaveSlice(
    void* output, const float* input, size_t width, size_t height, size_t row_stride) {
  const size_t plane = height * width;
  // Form pointers only for planes that exist. Offsetting past the end of the
  // input buffer is undefined behavior even if the pointer is never read.
  const float* input_r = input;
  const float* input_g = input_channels >= 2 ? input + plane : nullptr;
  const float* input_b = input_channels >= 3 ? input + 2 * plane : nullptr;
  const float* input_a = input_channels >= 4 ? input + 3 * plane : nullptr;
  uint16_t* output_f16 = static_cast<uint16_t*>(output);
  // uint16 elements to skip after the last written pixel of each output row.
  const size_t row_padding = (row_stride - width) * 4;

  if (width >= 4) {
    for (size_t y = 0; y < height; y++) {
      size_t nx = width;
      while (nx >= 4) {
        vst4_u16_aligned16(
            output_f16,
            loadQuadAsHalf<input_channels>(input_r, input_g, input_b, input_a));
        output_f16 += 4 * 4;
        nx -= 4;
      }
      if (nx != 0) {
        // Rewind so the final quad ends exactly at the row end. The
        // overlapping pixels are rewritten with identical values.
        const size_t overlap = 4 - nx;
        output_f16 -= overlap * 4;
        input_r -= overlap;
        if (input_channels >= 2) {
          input_g -= overlap;
        }
        if (input_channels >= 3) {
          input_b -= overlap;
        }
        if (input_channels >= 4) {
          input_a -= overlap;
        }
        // The rewound address is only 8-byte aligned when nx is odd, so the
        // 16-byte-aligned store helper must not be used here.
        vst4_u16(
            output_f16,
            loadQuadAsHalf<input_channels>(input_r, input_g, input_b, input_a));
        output_f16 += 4 * 4;
      }
      output_f16 += row_padding;
    }
  } else {
    // Rows narrower than one quad: convert pixel by pixel.
    for (size_t y = 0; y < height; y++) {
      for (size_t x = 0; x < width; x++) {
        // Start from zero so channels the input lacks come out as 0,
        // matching the vectorized path above.
        float32x4_t rgba = vdupq_n_f32(0.0f);
        rgba = vld1q_lane_f32(input_r++, rgba, 0);
        if (input_channels >= 2) {
          rgba = vld1q_lane_f32(input_g++, rgba, 1);
        }
        if (input_channels >= 3) {
          rgba = vld1q_lane_f32(input_b++, rgba, 2);
        }
        if (input_channels >= 4) {
          rgba = vld1q_lane_f32(input_a++, rgba, 3);
        }
        vst1_u16_aligned8(output_f16, uint16x4_t(vcvt_f16_f32(rgba)));
        output_f16 += 4;
      }
      output_f16 += row_padding;
    }
  }
}
114 
115 void interleaveSlice(void* output,
116  const float* input,
117  size_t width,
118  size_t height,
119  size_t row_stride,
120  uint16_t input_channels) {
121  switch (input_channels) {
122  case 1:
123  interleaveSlice<1>(output, input, width, height, row_stride);
124  break;
125  case 2:
126  interleaveSlice<2>(output, input, width, height, row_stride);
127  break;
128  case 3:
129  interleaveSlice<3>(output, input, width, height, row_stride);
130  break;
131  case 4:
132  interleaveSlice<4>(output, input, width, height, row_stride);
133  break;
134  }
135 }
136 
// Converts the four half-float pixels in `rgba` (lane layout as produced by
// vld4_u16: val[k] holds channel k of four pixels) to float. Writes four
// values to each of the first `output_channels` planes; channels beyond
// `output_channels` are discarded. Each active plane pointer is advanced by
// 4; inactive ones are left untouched.
template <int output_channels>
static inline void storeQuadAsFloat(const uint16x4x4_t& rgba,
                                    float*& output_r,
                                    float*& output_g,
                                    float*& output_b,
                                    float*& output_a) {
  vst1q_f32(output_r, vcvt_f32_f16(float16x4_t(rgba.val[0])));
  output_r += 4;
  if (output_channels >= 2) {
    vst1q_f32(output_g, vcvt_f32_f16(float16x4_t(rgba.val[1])));
    output_g += 4;
  }
  if (output_channels >= 3) {
    vst1q_f32(output_b, vcvt_f32_f16(float16x4_t(rgba.val[2])));
    output_b += 4;
  }
  if (output_channels >= 4) {
    vst1q_f32(output_a, vcvt_f32_f16(float16x4_t(rgba.val[3])));
    output_a += 4;
  }
}

// Unpacks an interleaved RGBA half-float slice into `output_channels` planar
// float planes. This is the inverse of interleaveSlice.
//
// input:  height rows of row_stride pixels. Each pixel is four uint16 values
//         (half-float bit patterns) in r, g, b, a order. Only the first
//         `width` pixels of each row are read.
// output: `output_channels` consecutive planes of height * width floats each,
//         with rows stored back to back (no padding).
// Preconditions: row_stride >= width (the per-row padding is
//         row_stride - width). For width >= 4, the first load of every row is
//         an aligned 16-byte load, so `input` must be 16-byte aligned and
//         row_stride must be even. For width < 4, 8-byte alignment suffices.
template <int output_channels>
static void deInterleaveSlice(
    float* output, const void* input, size_t width, size_t height, size_t row_stride) {
  const size_t plane = height * width;
  // Form pointers only for planes that exist. Offsetting past the end of the
  // output buffer is undefined behavior even if the pointer is never written.
  float* output_r = output;
  float* output_g = output_channels >= 2 ? output + plane : nullptr;
  float* output_b = output_channels >= 3 ? output + 2 * plane : nullptr;
  float* output_a = output_channels >= 4 ? output + 3 * plane : nullptr;
  const uint16_t* input_f16 = static_cast<const uint16_t*>(input);
  // uint16 elements to skip after the last read pixel of each input row.
  const size_t row_padding = (row_stride - width) * 4;

  if (width >= 4) {
    for (size_t y = 0; y < height; y++) {
      size_t nx = width;
      while (nx >= 4) {
        storeQuadAsFloat<output_channels>(
            vld4_u16_aligned16(input_f16), output_r, output_g, output_b, output_a);
        input_f16 += 4 * 4;
        nx -= 4;
      }
      if (nx != 0) {
        // Rewind so the final quad ends exactly at the row end. The
        // overlapping pixels are rewritten with identical values.
        const size_t overlap = 4 - nx;
        input_f16 -= overlap * 4;
        output_r -= overlap;
        if (output_channels >= 2) {
          output_g -= overlap;
        }
        if (output_channels >= 3) {
          output_b -= overlap;
        }
        if (output_channels >= 4) {
          output_a -= overlap;
        }
        // The rewound address is only 8-byte aligned when nx is odd, so the
        // 16-byte-aligned load helper must not be used here.
        storeQuadAsFloat<output_channels>(
            vld4_u16(input_f16), output_r, output_g, output_b, output_a);
        input_f16 += 4 * 4;
      }
      input_f16 += row_padding;
    }
  } else {
    // Rows narrower than one quad: convert pixel by pixel.
    for (size_t y = 0; y < height; y++) {
      for (size_t x = 0; x < width; x++) {
        const float32x4_t rgba =
            vcvt_f32_f16(float16x4_t(vld1_u16_aligned8(input_f16)));
        input_f16 += 4;
        vst1q_lane_f32(output_r++, rgba, 0);
        if (output_channels >= 2) {
          vst1q_lane_f32(output_g++, rgba, 1);
        }
        if (output_channels >= 3) {
          vst1q_lane_f32(output_b++, rgba, 2);
        }
        if (output_channels >= 4) {
          vst1q_lane_f32(output_a++, rgba, 3);
        }
      }
      input_f16 += row_padding;
    }
  }
}
228 
229 void deInterleaveSlice(float* output,
230  const void* input,
231  size_t width,
232  size_t height,
233  size_t row_stride,
234  uint32_t output_channels) {
235  switch (output_channels) {
236  case 1:
237  deInterleaveSlice<1>(output, input, width, height, row_stride);
238  break;
239  case 2:
240  deInterleaveSlice<2>(output, input, width, height, row_stride);
241  break;
242  case 3:
243  deInterleaveSlice<3>(output, input, width, height, row_stride);
244  break;
245  case 4:
246  deInterleaveSlice<4>(output, input, width, height, row_stride);
247  break;
248  }
249 }