1 #ifndef CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_ 2 #define CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/operators/fused_rowwise_8bit_conversion_ops.h" 8 #include "caffe2/operators/reducer_functors.h" 9 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" 10 #include "caffe2/utils/math.h" 14 template <
class Context,
bool with_weights = 0,
bool is_mean = 0>
18 !(with_weights && is_mean),
19 "Cannot have with_weights and is_mean a the same time");
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
24 bool RunOnDevice()
override {
26 this, Input(INDICES));
29 template <
typename IndexType>
30 bool DoRunWithType() {
31 const auto& data = Input(DATA);
32 const auto& indices = Input(INDICES);
33 const auto& lengths = Input(LENGTHS);
34 auto* output = Output(0);
36 CAFFE_ENFORCE_EQ(indices.ndim(), 1,
"INDICES must be a vector");
37 CAFFE_ENFORCE_EQ(lengths.ndim(), 1,
"LENGTHS must be a vector");
39 const float* weights =
nullptr;
41 const auto& weights_input = Input(WEIGHTS);
42 CAFFE_ENFORCE_EQ(weights_input.ndim(), 1,
"WEIGHTS must be a vector");
46 "WEIGHTS should have the same length as INDICES.");
47 weights = weights_input.template data<float>();
50 CAFFE_ENFORCE_GT(data.dim(1), 8,
"DATA must have more than 8 columns");
53 const std::vector<TIndex> shape = {lengths.dim(0), data.dim(1) - 8};
54 output->Resize(shape);
61 data.template data<uint8_t>(),
62 indices.template data<IndexType>(),
63 lengths.template data<int>(),
66 output->template mutable_data<float>());
75 INDICES = 1 + with_weights,
76 LENGTHS = 2 + with_weights,
82 #endif // CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void Fused8BitRowwiseEmbeddingLookup(const TIndex block_size, const TIndex output_size, const TIndex index_size, const TIndex data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.