// Include guard, Caffe2 core includes, and the start of an operator templated
// on an element type T and a Context class.
// NOTE(review): this excerpt is fragmentary -- the class declaration, the
// definitions of `mat` and `w`, and the closing braces are not visible here.
1 #ifndef CAFFE2_OPERATORS_ROW_MUL_H_ 2 #define CAFFE2_OPERATORS_ROW_MUL_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 13 template <
typename T,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Fills output 0 with the elements of `mat`, each block of `block_size`
// consecutive elements multiplied by the corresponding entry of `w`.
19 bool RunOnDevice()
override {
22 auto* output = Output(0);
// The output is sized from `mat` before its buffer is requested.
24 output->ResizeLike(mat);
25 T* output_data = output->template mutable_data<T>();
26 const T* mat_data = mat.template data<T>();
27 const T* w_data = w.template data<T>();
// Tail of a check whose call is not visible here; per its message, w must
// have one entry per index of mat's first dimension.
33 "Length of w should be equal to the first dim of mat");
// block_size: presumably the element count of one first-dimension slice of
// mat (product of dims 1..end) -- TODO confirm against size_from_dim.
35 auto block_size = mat.size_from_dim(1);
// NOTE(review): `i` is an int compared against w.size(); confirm the types
// agree for very large inputs.
36 for (
int i = 0; i < w.size(); i++) {
// Start of the i-th block in the flat mat/output buffers.
37 size_t offset = i * block_size;
38 for (
int j = 0; j < block_size; j++) {
// Scale every element of block i by the single weight w_data[i].
39 output_data[offset + j] = mat_data[offset + j] * w_data[i];
// Second operator templated on an element type T and a Context class.
// NOTE(review): this excerpt is fragmentary -- the class declaration, the
// definitions of `mat` and `N`, the sizing of the output, and the closing
// braces are not visible here.
48 template <
typename T,
class Context>
51 USE_OPERATOR_CONTEXT_FUNCTIONS;
// For each i in [0, N), adds the `block_size` consecutive elements of `mat`
// starting at i * block_size into output_data[i].
54 bool RunOnDevice()
override {
56 auto* output = Output(0);
// block_size: presumably the element count of one first-dimension slice of
// mat (product of dims 1..end) -- TODO confirm against size_from_dim.
59 int block_size = mat.size_from_dim(1);
62 T* output_data = output->template mutable_data<T>();
63 const T* mat_data = mat.template data<T>();
65 for (
int i = 0; i < N; i++) {
// NOTE(review): `i * block_size` is evaluated in int before widening to
// size_t, so it can overflow for large inputs.
67 size_t offset = i * block_size;
68 for (
int j = 0; j < block_size; j++) {
// Accumulates into output_data[i]; its initialisation is not visible in
// this excerpt -- verify it is zeroed before this loop.
69 output_data[i] += mat_data[offset + j];
78 #endif // CAFFE2_OPERATORS_ROW_MUL_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...