// NOTE(review): this is a line-numbered extraction of a C++ header. The
// leading integer on each line is its original source line number, and the
// gaps in that numbering are lines that are missing here (among them the
// class declaration, several braces and branch conditions). It does not
// compile as-is; the comments below describe only the lines that are present.
1 #ifndef CAFFE2_OPERATORS_FUNHASH_OP_H_ 2 #define CAFFE2_OPERATORS_FUNHASH_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 10 #define SIGN_MAGIC 0x9e3779b97f4a7c15 11 #define INDEX_MAGIC 0xf39cc0605cedc834 17 template <
typename T,
class Context>
// FunHashOp -- forward pass of a hashed weight projection.
//
// Inputs (by position):
//   0 val    - per-entry values, read as T
//   1 key    - per-entry feature keys, read as TIndex
//   2 seg    - per-entry segment ids, read as int; seg.dim(0) is the number
//              of entries processed
//   3 weight - weight table, read as T; weight.dim(0) is the number of hash
//              buckets
//   4 alpha  - optional; its presence (5 inputs) selects "adaptive" mode,
//              where num_alpha = alpha.dim(0) hashed weights are mixed per
//              output using alpha as coefficients
//
// Output 0: an (n_segments x num_outputs) tensor of T. It is zero-filled and
// then accumulated into: output[seg[j]][i] += val[j] * sum over k of a
// hash-selected, possibly negated weight (times alpha[k] in adaptive mode).
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor arguments (defaults as read below):
//   num_outputs  = -1 : number of output columns. The constructor carries the
//                       message "Argument `num_outputs` is missing." (the
//                       check itself is on a line not shown here).
//   num_segments = -1 : number of output rows; -1 means "infer from seg".
//   seed         = 0  : seed passed to every XXH64 call.
24 OperatorBase::GetSingleArgument<TIndex>(
"num_outputs", -1)),
26 OperatorBase::GetSingleArgument<TIndex>(
"num_segments", -1)),
27 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
30 "Argument `num_outputs` is missing.");
// Adaptive mode is selected purely by input count: a fifth input (alpha).
32 adaptive_ = (InputSize() == 5);
35 bool RunOnDevice()
override {
36 const auto& val = Input(0);
37 const auto& key = Input(1);
38 const auto& seg = Input(2);
39 const auto& weight = Input(3);
// Only valid in adaptive mode (the guarding condition is on a line not shown
// here). num_alpha is the number of alpha coefficients mixed per output.
43 const auto& alpha = Input(4);
44 num_alpha = alpha.dim(0);
47 const auto* seg_data = seg.template data<int>();
49 TIndex num_weight = weight.dim(0);
50 TIndex num_nz_ent = seg.dim(0);
// Output row count. With num_segments == -1 it is inferred from the largest
// segment id in seg.
// NOTE(review): the visible loop leaves n_segments equal to the maximum id.
// Rows are addressed by id (output_stride = cur_seg * num_outputs_), so the
// tensor needs max + 1 rows. Confirm the increment exists on the lines not
// shown here; without it the largest segment would write past the end of
// the output buffer.
52 TIndex n_segments = num_segments_;
53 if (num_segments_ == -1) {
54 for (TIndex i = 0; i < num_nz_ent; ++i) {
55 if (seg_data[i] > n_segments) {
56 n_segments = seg_data[i];
62 auto* output = Output(0);
63 output->Resize(n_segments, num_outputs_);
65 T* output_data = output->template mutable_data<T>();
// Zero-fill: the loop below only accumulates into the output with +=.
67 memset(output_data, 0,
sizeof(T) * n_segments * num_outputs_);
69 const auto* weight_data = weight.template data<T>();
// alpha_data is a null pointer unless the op is adaptive.
70 const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
71 const auto* val_data = val.template data<T>();
72 const auto* key_data = key.template data<TIndex>();
// For each entry j and output column i: combine num_alpha hash-selected
// weights and add (combination * val[j]) into output[seg[j]][i].
74 for (TIndex j = 0; j < num_nz_ent; ++j) {
75 TIndex cur_seg = seg_data[j];
76 TIndex cur_key = key_data[j];
77 T cur_val = val_data[j];
78 TIndex output_stride = cur_seg * num_outputs_;
79 for (TIndex i = 0; i < num_outputs_; ++i) {
81 for (TIndex k = 0; k < num_alpha; ++k) {
// hash_data[0] is the feature key. Slot 3 holds a magic constant that keeps
// the index hash (INDEX_MAGIC) apart from the sign hash (SIGN_MAGIC). Slots
// 1-2 are filled on lines not shown here (presumably i and k -- confirm).
89 hash_data[0] = cur_key;
93 hash_data[3] = INDEX_MAGIC;
// NOTE(review): hash_data is std::array<uint64_t, 4>, so hash_data.size() is
// 4 -- an element count, not a byte count. If XXH64 is xxHash's
// XXH64(const void* input, size_t length, seed), whose length is in bytes,
// only the first 4 bytes of hash_data[0] are hashed. Slots 1-3 (including
// the INDEX_MAGIC / SIGN_MAGIC word) then never influence the result, so the
// index hash and the sign hash are always equal. Passing sizeof(hash_data)
// (32 bytes) would hash all four words. Changing it remaps every bucket and
// sign, so weights learned under the current hashing would no longer line up.
94 hash = XXH64(hash_data.data(), hash_data.size(), seed_);
// NOTE(review): hash % num_weight divides by zero when weight has no rows.
95 TIndex index = hash % num_weight;
97 T cur_weight = weight_data[index];
// Second hash (SIGN_MAGIC in slot 3) decides whether the selected weight is
// negated; the exact condition is on a line not shown here.
99 hash_data[3] = SIGN_MAGIC;
100 hash = XXH64(hash_data.data(), hash_data.size(), seed_);
102 cur_weight = -cur_weight;
// Scaled by alpha[k]; presumably only in adaptive mode, since alpha_data is
// null otherwise (the branch is on lines not shown here).
107 sum += cur_weight * alpha_data[k];
112 output_data[output_stride + i] += sum * cur_val;
121 TIndex num_segments_;
// Per-instance scratch buffer for the hash input. RunOnDevice writes it, so
// concurrent RunOnDevice calls on the same operator instance would race.
123 std::array<uint64_t, 4> hash_data;
127 template <
typename T,
class Context>
// FunHashGradientOp -- backward pass for FunHashOp.
//
// Inputs (by position):
//   0 grad_out - upstream gradient, read as T and addressed as
//                grad_out[seg * num_outputs + i]
//   1 val      - per-entry values (T)
//   2 key      - per-entry feature keys (TIndex)
//   3 seg      - per-entry segment ids (int); seg.dim(0) entries processed
//   4 weight   - weight table (T); weight.dim(0) hash buckets
//   5 alpha    - optional; its presence (6 inputs) selects adaptive mode
//
// Outputs:
//   0 grad_weight - shaped like weight, zero-filled, then accumulated
//   1 grad_alpha  - shaped like alpha, zero-filled, then accumulated
//                   (adaptive mode only)
130 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Arguments: num_outputs (default -1) and seed (default 0). The hashes are
// recomputed here, so seed must equal the forward op's seed for bucket
// indices and signs to line up. num_segments is not read in the lines shown.
134 OperatorBase::GetSingleArgument<TIndex>(
"num_outputs", -1)),
135 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
// Adaptive mode is selected purely by input count: a sixth input (alpha).
136 adaptive_ = (InputSize() == 6);
139 bool RunOnDevice()
override {
140 const auto& grad_out = Input(0);
141 const auto& val = Input(1);
142 const auto& key = Input(2);
143 const auto& seg = Input(3);
144 const auto& weight = Input(4);
// Non-adaptive defaults: one mixing term and no alpha-gradient buffer.
146 TIndex num_alpha = 1;
147 T* grad_alpha_data = 0;
// Adaptive branch. Its condition is on a line not shown here; Input(5) only
// exists when there are 6 inputs. It allocates grad_alpha shaped like alpha
// and zeroes its first num_alpha = alpha.dim(0) elements.
// NOTE(review): that clears all of grad_alpha only if alpha is 1-D -- confirm.
150 const auto& alpha = Input(5);
151 num_alpha = alpha.dim(0);
152 auto* grad_alpha = Output(1);
153 grad_alpha->ResizeLike(alpha);
154 grad_alpha_data = grad_alpha->template mutable_data<T>();
155 memset(grad_alpha_data, 0,
sizeof(T) * num_alpha);
158 const auto* seg_data = seg.template data<int>();
160 TIndex num_weight = weight.dim(0);
161 TIndex num_nz_ent = seg.dim(0);
163 auto* grad_weight = Output(0);
164 grad_weight->ResizeLike(weight);
165 T* grad_weight_data = grad_weight->template mutable_data<T>();
167 const auto* grad_out_data = grad_out.template data<T>();
168 const auto* weight_data = weight.template data<T>();
// alpha_data is a null pointer unless the op is adaptive.
169 const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
170 const auto* val_data = val.template data<T>();
171 const auto* key_data = key.template data<TIndex>();
// Zero grad_weight before accumulating. This clears num_weight =
// weight.dim(0) elements, i.e. all of it only if weight is 1-D -- confirm.
173 memset(grad_weight_data, 0,
sizeof(T) * num_weight);
175 for (TIndex j = 0; j < num_nz_ent; ++j) {
176 TIndex cur_seg = seg_data[j];
177 TIndex cur_key = key_data[j];
178 T cur_val = val_data[j];
179 TIndex grad_out_stride = cur_seg * num_outputs_;
180 for (TIndex i = 0; i < num_outputs_; ++i) {
// Upstream gradient for (segment, column i), scaled by this entry's value.
181 T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
182 for (TIndex k = 0; k < num_alpha; ++k) {
// Recompute the forward hashes. hash_data[0] is the key; slot 3 is
// INDEX_MAGIC for the bucket index and SIGN_MAGIC for the sign. Slots 1-2
// are filled on lines not shown here (presumably i and k -- confirm).
184 hash_data[0] = cur_key;
188 hash_data[3] = INDEX_MAGIC;
// NOTE(review): hash_data is std::array<uint64_t, 4>, so hash_data.size() is
// 4 -- an element count, not a byte count. If XXH64 is xxHash's
// XXH64(const void* input, size_t length, seed), whose length is in bytes,
// only the first 4 bytes of hash_data[0] are hashed. Slots 1-3 (including
// the INDEX_MAGIC / SIGN_MAGIC word) then never influence the result, so the
// index hash and the sign hash are always equal. Passing sizeof(hash_data)
// (32 bytes) would hash all four words. It must be changed together with the
// forward op, or the gradients would target different buckets.
189 hash = XXH64(hash_data.data(), hash_data.size(), seed_);
// NOTE(review): hash % num_weight divides by zero when weight has no rows.
190 TIndex index = hash % num_weight;
192 T cur_grad_out_scale = grad_out_scale;
// The sign hash conditionally negates the scale (the condition is on a line
// not shown here), presumably mirroring the forward op's weight negation.
194 hash_data[3] = SIGN_MAGIC;
195 hash = XXH64(hash_data.data(), hash_data.size(), seed_);
197 cur_grad_out_scale = -cur_grad_out_scale;
// Adaptive path (branch condition not shown):
//   grad_alpha[k]      += signed scale * weight[index]
//   grad_weight[index] += alpha[k] * signed scale
// Non-adaptive path: grad_weight[index] += signed scale.
// Different (entry, column, k) triples can land in the same bucket, hence +=.
202 grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
203 grad_weight_data[index] += alpha_data[k] * cur_grad_out_scale;
205 grad_weight_data[index] += cur_grad_out_scale;
// Per-instance scratch buffer for the hash input. RunOnDevice writes it, so
// concurrent RunOnDevice calls on the same operator instance would race.
216 std::array<uint64_t, 4> hash_data;
222 #endif // CAFFE2_OPERATORS_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.