1 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" 3 #include "caffe2/core/types.h" 4 #include "caffe2/perfkernels/common.h" 5 #include "caffe2/perfkernels/typed_axpy.h" 6 #include "caffe2/utils/cpuid.h" 7 #include "caffe2/utils/math.h" 16 bool IS_WEIGHT_POSITIONAL =
false>
// Generic lookup-and-reduce over rows that carry a fused scale/bias trailer.
// For each of the `output_size` segments it zeroes a `block_size`-wide output
// row, then walks the next `lengths[m]` entries of `indices`; for every
// selected row it reads the two floats stored right after the row's
// `block_size` data elements, multiplies them by the (optional) weight, and
// hands the scale and the row data to TypedAxpy together with `out`.
// When `normalize_by_lengths` is set, non-empty segments are scaled by
// 1 / lengths[m] afterwards.
//
// NOTE(review): this listing is incomplete -- the embedded numbering skips
// lines, so the declarations of `input`, `lengths`, `weights`, `out` and
// `current`, the use of `bias`, and the closing braces are not visible here.
// Confirm every remark below against the full file.
17 static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
18 const TIndex block_size,
19 const TIndex output_size,
20 const TIndex index_size,
21 const TIndex data_size,
23 const IndexType* indices,
26 bool normalize_by_lengths,
// The per-row trailer is 8 bytes, expressed here in units of InType; it is
// read further down as two floats (scale at [0], bias at [1]).
30 const auto scale_bias_offset = 8 /
sizeof(InType);
// Full row stride in InType elements: data elements plus the trailer.
31 const TIndex fused_block_size = block_size + scale_bias_offset;
// One iteration per output segment.
33 for (
int m = 0; m < output_size; ++m) {
// Start the segment's accumulator from zero.
34 memset(out, 0,
sizeof(OutType) * block_size);
35 EigenVectorArrayMap<OutType> out_vector(out, block_size);
// Consume lengths[m] consecutive entries of `indices`.
36 for (
int i = 0; i < lengths[m]; ++i) {
// Fail instead of reading past the end of `indices`.
37 CAFFE_ENFORCE_LT(current, index_size);
38 TIndex idx = indices[current];
// Range check on the looked-up row index; the call that wraps this condition
// and the rest of its message are missing from this listing.
40 0 <= idx && idx < data_size,
43 " is out of bounds: ",
47 CAFFE_ENFORCE_LT(idx, data_size);
// Only look at the following index when one exists. The callee receiving the
// address below is not visible -- presumably a prefetch hint; TODO confirm.
49 if (current + 1 < index_size) {
51 input + fused_block_size * indices[current + 1], 0, 1);
// The scale/bias pair sits immediately after the row's block_size elements.
55 const float* scale_bias =
reinterpret_cast<const float*
>(
56 input + fused_block_size * indices[current] + block_size);
// IS_WEIGHT_POSITIONAL selects the weight by position inside the segment (i)
// rather than by position in the flat index list (current).
60 weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
// Fold the weight into both dequantization parameters.
62 const float scale = weight * scale_bias[0];
63 const float bias = weight * scale_bias[1];
// Presumably accumulates scale * row into out (axpy) -- confirm against
// typed_axpy.h. How `bias` is applied is not visible in this listing.
65 TypedAxpy<InType, OutType>(
66 block_size, scale, input + fused_block_size * indices[current], out);
// Skip empty segments so the reciprocal below never divides by zero.
72 if (normalize_by_lengths && lengths[m]) {
74 math::Scale<OutType, CPUContext>(
75 block_size, 1.f / lengths[m], out, out,
nullptr);
// Message of a final consistency check between the lengths and the indices;
// the check itself is not visible in this listing.
82 "Your input seems to be incorrect: the sum of lengths values should be " 83 "the size of the indices tensor, but it appears not.");
// For one (IndexType, InType, OutType) combination this macro emits:
//   1. a function named Fused8BitRowwiseEmbeddingLookup_<Index>_<In>_<Out>_false__base
//      whose visible body refers to Fused8BitRowwiseEmbeddingLookupGenericSlow, and
//   2. a definition of Fused8BitRowwiseEmbeddingLookup<IndexType, InType, OutType, false>
//      with the same parameter list.
// The second function inspects the first byte of an int32_t holding 1 (non-zero
// only on a little-endian host) and pairs it with the "not supported on this
// platform" message; the call wrapping that check is not visible here.
// NOTE(review): many lines of the macro are missing from this listing (the
// embedded numbering skips), including whatever consumes the two `..._false`
// argument lists near the end -- presumably CPU-feature dispatch helpers from
// perfkernels/common.h; confirm against the full file.
// The invocations that follow the definition instantiate it for int32_t and
// int64_t indices, each with uint8_t input and float output.
87 #define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \ 88 IndexType, InType, OutType) \ 90 Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \ 91 const TIndex block_size, \ 92 const TIndex output_size, \ 93 const TIndex index_size, \ 94 const TIndex data_size, \ 95 const InType* input, \ 96 const IndexType* indices, \ 98 const float* weights, \ 99 bool normalize_by_lengths, \ 101 Fused8BitRowwiseEmbeddingLookupGenericSlow< \ 114 normalize_by_lengths, \ 118 void Fused8BitRowwiseEmbeddingLookup<IndexType, InType, OutType, false>( \ 119 const TIndex block_size, \ 120 const TIndex output_size, \ 121 const TIndex index_size, \ 122 const TIndex data_size, \ 123 const InType* input, \ 124 const IndexType* indices, \ 125 const int* lengths, \ 126 const float* weights, \ 127 bool normalize_by_lengths, \ 129 const int32_t one = 1; \ 131 reinterpret_cast<const uint8_t*>(&one)[0], \ 133 "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ 135 Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ 144 normalize_by_lengths, \ 147 Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ 156 normalize_by_lengths, \ 160 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t,
float);
161 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t,
float);
163 #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...