// NOTE(review): this listing looks like a scraped source view. Many original
// lines (class header, braces, several statements) are missing, and the leading
// numbers on each line are not C++. The comments below describe only the
// lines that are visible here.
// HASH_MAGIC (defined on the next line) is written into the last slot of
// hash_data before every XXH64 call in RunOnDevice.
1 #ifndef CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_ 2 #define CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 10 #define HASH_MAGIC 0x9e3779b97f4a7c15 16 template <
typename T,
class Context>
19 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor arguments:
//   `num_outputs`  (TIndex, default -1): number of output columns.
//   `num_segments` (TIndex, default -1): -1 means "infer from the segment ids"
//                  (see the scan in RunOnDevice).
//   `seed`         (uint64_t, default 0): seed passed to XXH64.
23 OperatorBase::GetSingleArgument<TIndex>(
"num_outputs", -1)),
25 OperatorBase::GetSingleArgument<TIndex>(
"num_segments", -1)),
26 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
// Error message of a required-argument check on `num_outputs`. The check
// itself is not visible in this listing.
29 "Argument `num_outputs` is missing.");
// Adaptive mode: a fifth input (alpha, read as Input(4) below) is present.
31 adaptive_ = (InputSize() == 5);
// Forward pass. For each nonzero entry j (segment id seg[j], key key[j],
// value val[j]) and each output column i, hashed weights are accumulated
// into output[seg[j], i], scaled by val[j].
34 bool RunOnDevice()
override {
// Inputs: 0 = values, 1 = keys, 2 = segment ids, 3 = weight table,
// 4 = alpha (adaptive mode only).
35 const auto& val = Input(0);
36 const auto& key = Input(1);
37 const auto& seg = Input(2);
38 const auto& weight = Input(3);
// NOTE(review): the declaration of num_alpha and the guard around the next
// two lines are not visible. Presumably they are `TIndex num_alpha = 1;` and
// `if (adaptive_)` — confirm.
42 const auto& alpha = Input(4);
43 num_alpha = alpha.dim(0);
// Segment ids are read as int, whereas keys (below) are read as TIndex.
46 const auto* seg_data = seg.template data<int>();
48 TIndex num_weight = weight.dim(0);
49 TIndex num_nz_ent = seg.dim(0);
// When num_segments_ is -1, scan for the largest segment id.
// NOTE(review): the statement(s) after this scan are not visible. Presumably
// n_segments is incremented so the row count is max id + 1 — confirm.
51 TIndex n_segments = num_segments_;
52 if (num_segments_ == -1) {
53 for (TIndex i = 0; i < num_nz_ent; ++i) {
54 if (seg_data[i] > n_segments) {
55 n_segments = seg_data[i];
// Output is (n_segments, num_outputs_) and is zeroed before accumulation.
61 auto* output = Output(0);
62 output->Resize(n_segments, num_outputs_);
64 T* output_data = output->template mutable_data<T>();
66 memset(output_data, 0,
sizeof(T) * n_segments * num_outputs_);
68 const auto* weight_data = weight.template data<T>();
// alpha_data is a null pointer when not in adaptive mode.
69 const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
70 const auto* val_data = val.template data<T>();
71 const auto* key_data = key.template data<TIndex>();
73 for (TIndex j = 0; j < num_nz_ent; ++j) {
74 TIndex cur_seg = seg_data[j];
75 TIndex cur_key = key_data[j];
76 T cur_val = val_data[j];
// Offset of row cur_seg in the row-major (n_segments x num_outputs_) output.
77 TIndex output_stride = cur_seg * num_outputs_;
78 for (TIndex i = 0; i < num_outputs_; ++i) {
80 for (TIndex k = 0; k < num_alpha; ++k) {
// Build the hash input. The hash_data[1] / hash_data[2] assignments are
// not visible in this listing.
86 hash_data[0] = cur_key;
89 hash_data[3] = HASH_MAGIC;
// NOTE(review): per the xxHash API, XXH64's second argument is a length in
// BYTES, but hash_data.size() is the element count (4). If so, only the
// first 4 bytes of hash_data (part of cur_key) are hashed, and the other
// slots never affect the result. sizeof(hash_data) is likely what was
// intended. Confirm before changing: the fix would remap every hashed
// weight index (and so invalidate existing trained weights).
91 uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
// Sign-bit variant: the index comes from the bits above the lowest one.
// Presumably the low bit decides the negation below. Its condition, and
// how this variant is chosen over the next one (possibly a preprocessor
// switch), are not visible — confirm.
95 TIndex index = (hash >> 1) % num_weight;
96 T cur_weight = weight_data[index];
98 cur_weight = -cur_weight;
// Variant without a sign bit: the full hash selects the weight.
101 TIndex index = hash % num_weight;
102 T cur_weight = weight_data[index];
// Alpha-weighted accumulation. This uses alpha_data, which is null unless
// adaptive, so it is presumably the adaptive branch. `sum` is declared in a
// line not visible here.
106 sum += cur_weight * alpha_data[k];
111 output_data[output_stride + i] += sum * cur_val;
120 TIndex num_segments_;
// Scratch buffer for the hash input. RunOnDevice overwrites it on every inner
// iteration, so it is mutable operator state rather than a local.
122 std::array<uint64_t, 4> hash_data;
// Gradient counterpart of the sparse funhash op (fragment: the class header,
// braces and several statements are not visible in this listing). Outputs:
// a sparse gradient for the weight table as parallel value/index tensors, plus
// a dense alpha gradient in adaptive mode.
126 template <
typename T,
class Context>
129 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor arguments:
//   `num_outputs` (TIndex, default -1).
//   `seed` (uint64_t, default 0), passed to XXH64. Presumably this must equal
//   the forward op's seed so the hashed indices agree — confirm.
133 OperatorBase::GetSingleArgument<TIndex>(
"num_outputs", -1)),
134 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
// Adaptive mode: a sixth input (alpha, Input(5)) is present.
135 adaptive_ = (InputSize() == 6);
138 bool RunOnDevice()
override {
// Inputs: 0 = output gradient, 1 = values, 2 = keys, 3 = segment ids,
// 4 = weight table, 5 = alpha (adaptive mode only).
139 const auto& grad_out = Input(0);
140 const auto& val = Input(1);
141 const auto& key = Input(2);
142 const auto& seg = Input(3);
143 const auto& weight = Input(4);
// Defaults for the non-adaptive case: one "alpha" slot and no alpha gradient.
145 TIndex num_alpha = 1;
146 T* grad_alpha_data = 0;
// Presumably guarded by `if (adaptive_)` (guard not visible). Output(2) takes
// alpha's shape and is zeroed before the per-alpha accumulation below.
149 const auto& alpha = Input(5);
150 num_alpha = alpha.dim(0);
151 auto* grad_alpha = Output(2);
152 grad_alpha->ResizeLike(alpha);
153 grad_alpha_data = grad_alpha->template mutable_data<T>();
154 memset(grad_alpha_data, 0,
sizeof(T) * num_alpha);
// Segment ids are read as int, whereas keys (below) are read as TIndex.
157 const auto* seg_data = seg.template data<int>();
159 TIndex num_weight = weight.dim(0);
160 TIndex num_nz_ent = seg.dim(0);
// One gradient slot per (entry, output column, alpha index) triple.
// Output(0) receives the gradient values; Output(1) receives the
// weight-table index each value applies to.
162 TIndex grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;
163 auto* grad_weight_val = Output(0);
164 grad_weight_val->Resize(grad_weight_size);
165 T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();
167 auto* grad_weight_ind = Output(1);
168 grad_weight_ind->Resize(grad_weight_size);
169 auto* grad_weight_ind_data =
170 grad_weight_ind->template mutable_data<TIndex>();
172 const auto* grad_out_data = grad_out.template data<T>();
173 const auto* weight_data = weight.template data<T>();
// alpha_data is a null pointer when not in adaptive mode.
174 const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
175 const auto* val_data = val.template data<T>();
176 const auto* key_data = key.template data<TIndex>();
179 for (TIndex j = 0; j < num_nz_ent; ++j) {
180 TIndex cur_seg = seg_data[j];
181 TIndex cur_key = key_data[j];
182 T cur_val = val_data[j];
// Offset of row cur_seg in the row-major (segments x num_outputs_) grad_out.
183 TIndex grad_out_stride = cur_seg * num_outputs_;
184 for (TIndex i = 0; i < num_outputs_; ++i) {
// Upstream gradient for output[cur_seg, i], scaled by this entry's value.
185 T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
186 for (TIndex k = 0; k < num_alpha; ++k) {
// Build the hash input. The hash_data[1] / hash_data[2] assignments are
// not visible in this listing.
187 hash_data[0] = cur_key;
190 hash_data[3] = HASH_MAGIC;
// NOTE(review): per the xxHash API, XXH64's second argument is a length in
// BYTES, but hash_data.size() is the element count (4). If so, only the
// first 4 bytes of hash_data (part of cur_key) are hashed, and the other
// slots never affect the result. sizeof(hash_data) is likely what was
// intended. Any fix must be applied identically in the forward op, and it
// remaps every hashed index — confirm before changing.
192 uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
194 T cur_grad_out_scale = grad_out_scale;
// Sign-bit variant: the index comes from the bits above the lowest one, and
// the scale is negated under a condition not visible here (presumably the
// low hash bit) — confirm.
196 TIndex index = (hash >> 1) % num_weight;
198 cur_grad_out_scale = -cur_grad_out_scale;
// Variant without a sign bit: the full hash selects the weight.
201 TIndex index = hash % num_weight;
// The next two lines dereference alpha pointers that are null unless
// adaptive, so they are presumably the adaptive branch: accumulate the alpha
// gradient and scale the weight gradient by alpha[k]. The following line is
// the non-adaptive value. How w_ind is computed is not visible in this
// listing.
205 grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
206 grad_weight_val_data[w_ind] = alpha_data[k] * cur_grad_out_scale;
208 grad_weight_val_data[w_ind] = cur_grad_out_scale;
// Record which weight-table entry this gradient value belongs to.
210 grad_weight_ind_data[w_ind] = index;
// Scratch buffer for the hash input. RunOnDevice overwrites it on every inner
// iteration, so it is mutable operator state rather than a local.
221 std::array<uint64_t, 4> hash_data;
227 #endif // CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.