2 #include "caffe2/core/context.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/perfkernels/embedding_lookup.h" 14 bool USE_POSITIONAL_WEIGHT = 0
24 !(USE_WEIGHT & USE_MEAN),
"Cannot both specify weight and mean.");
32 bool RunOnDevice()
override {
36 template <
typename InputType>
37 bool DoRunWithType() {
39 this, Input(INDICES));
42 template <
typename InputType,
typename IndexType>
43 bool DoRunWithType2() {
44 auto& dataInput = Input(DATA);
45 auto& indicesInput = Input(INDICES);
46 auto& lengthsInput = Input(LENGTHS);
48 CAFFE_ENFORCE_EQ(1, indicesInput.ndim(),
"INDICES must be a vector");
49 CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(),
"LENGTHS must be a vector");
50 const TIndex N = dataInput.dim(0);
51 const int D = dataInput.size_from_dim(1);
52 const TIndex M = lengthsInput.dim(0);
53 const TIndex indices_size = indicesInput.size();
55 auto* output = Output(0);
56 auto shape = dataInput.dims();
58 output->Resize(shape);
59 T* out_data = output->template mutable_data<T>();
61 const InputType* in_data = dataInput.template data<InputType>();
62 const IndexType* indices = indicesInput.template data<IndexType>();
63 const int* lengths = lengthsInput.template data<int>();
64 const T* in_weight =
nullptr;
68 auto& weightInput = Input(WEIGHT);
69 CAFFE_ENFORCE_EQ(1, weightInput.ndim(),
"WEIGHT must be a vector");
70 if (!USE_POSITIONAL_WEIGHT) {
74 "Weight should have the same length as indices.");
76 in_weight = weightInput.template data<T>();
80 EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
99 INDICES = 1 + USE_WEIGHT,
101 LENGTHS = 2 + USE_WEIGHT,
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...