Caffe2 - C++ API
A deep learning, cross platform ML framework
funhash_op.h
1 #ifndef CAFFE2_OPERATORS_FUNHASH_OP_H_
2 #define CAFFE2_OPERATORS_FUNHASH_OP_H_
3 
4 #include <xxhash.h>
5 #include <array>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 #define SIGN_MAGIC 0x9e3779b97f4a7c15
11 #define INDEX_MAGIC 0xf39cc0605cedc834
12 
13 #define USE_SIGN
14 
15 namespace caffe2 {
16 
17 template <typename T, class Context>
18 class FunHashOp : public Operator<Context> {
19  public:
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21  FunHashOp(const OperatorDef& operator_def, Workspace* ws)
22  : Operator<Context>(operator_def, ws),
23  num_outputs_(
24  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
25  num_segments_(
26  OperatorBase::GetSingleArgument<TIndex>("num_segments", -1)),
27  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
28  CAFFE_ENFORCE(
29  OperatorBase::HasArgument("num_outputs"),
30  "Argument `num_outputs` is missing.");
31  // If alpha is provided, use adaptive hashing parameterized by alpha.
32  adaptive_ = (InputSize() == 5);
33  }
34 
35  bool RunOnDevice() override {
36  const auto& val = Input(0);
37  const auto& key = Input(1);
38  const auto& seg = Input(2);
39  const auto& weight = Input(3);
40 
41  TIndex num_alpha = 1;
42  if (adaptive_) {
43  const auto& alpha = Input(4);
44  num_alpha = alpha.dim(0);
45  }
46 
47  const auto* seg_data = seg.template data<int>();
48 
49  TIndex num_weight = weight.dim(0);
50  TIndex num_nz_ent = seg.dim(0);
51 
52  TIndex n_segments = num_segments_;
53  if (num_segments_ == -1) {
54  for (TIndex i = 0; i < num_nz_ent; ++i) {
55  if (seg_data[i] > n_segments) {
56  n_segments = seg_data[i];
57  }
58  }
59  ++n_segments;
60  }
61 
62  auto* output = Output(0);
63  output->Resize(n_segments, num_outputs_);
64 
65  T* output_data = output->template mutable_data<T>();
66 
67  memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);
68 
69  const auto* weight_data = weight.template data<T>();
70  const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
71  const auto* val_data = val.template data<T>();
72  const auto* key_data = key.template data<TIndex>();
73 
74  for (TIndex j = 0; j < num_nz_ent; ++j) {
75  TIndex cur_seg = seg_data[j];
76  TIndex cur_key = key_data[j];
77  T cur_val = val_data[j];
78  TIndex output_stride = cur_seg * num_outputs_;
79  for (TIndex i = 0; i < num_outputs_; ++i) {
80  T sum = 0;
81  for (TIndex k = 0; k < num_alpha; ++k) {
82  uint64_t hash;
83  // The hash function takes as input four integers:
84  // 1. feature index
85  // 2. output index
86  // 3. alpha index
87  // 4. magic number: SIGN_MAGIC for sign (-1/+1)
88  // INDEX_MAGIC for weight index
89  hash_data[0] = cur_key;
90  hash_data[1] = i;
91  hash_data[2] = k;
92 
93  hash_data[3] = INDEX_MAGIC;
94  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
95  TIndex index = hash % num_weight;
96 
97  T cur_weight = weight_data[index];
98 #ifdef USE_SIGN
99  hash_data[3] = SIGN_MAGIC;
100  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
101  if (hash % 2) {
102  cur_weight = -cur_weight;
103  }
104 #endif // USE_SIGN
105 
106  if (adaptive_) {
107  sum += cur_weight * alpha_data[k];
108  } else {
109  sum += cur_weight;
110  }
111  }
112  output_data[output_stride + i] += sum * cur_val;
113  }
114  }
115 
116  return true;
117  }
118 
119  protected:
120  TIndex num_outputs_;
121  TIndex num_segments_;
122  uint64_t seed_;
123  std::array<uint64_t, 4> hash_data;
124  bool adaptive_;
125 };
126 
127 template <typename T, class Context>
128 class FunHashGradientOp : public Operator<Context> {
129  public:
130  USE_OPERATOR_CONTEXT_FUNCTIONS;
131  FunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
132  : Operator<Context>(operator_def, ws),
133  num_outputs_(
134  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
135  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
136  adaptive_ = (InputSize() == 6);
137  }
138 
139  bool RunOnDevice() override {
140  const auto& grad_out = Input(0);
141  const auto& val = Input(1);
142  const auto& key = Input(2);
143  const auto& seg = Input(3);
144  const auto& weight = Input(4);
145 
146  TIndex num_alpha = 1;
147  T* grad_alpha_data = 0;
148 
149  if (adaptive_) {
150  const auto& alpha = Input(5);
151  num_alpha = alpha.dim(0);
152  auto* grad_alpha = Output(1);
153  grad_alpha->ResizeLike(alpha);
154  grad_alpha_data = grad_alpha->template mutable_data<T>();
155  memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
156  }
157 
158  const auto* seg_data = seg.template data<int>();
159 
160  TIndex num_weight = weight.dim(0);
161  TIndex num_nz_ent = seg.dim(0);
162 
163  auto* grad_weight = Output(0);
164  grad_weight->ResizeLike(weight);
165  T* grad_weight_data = grad_weight->template mutable_data<T>();
166 
167  const auto* grad_out_data = grad_out.template data<T>();
168  const auto* weight_data = weight.template data<T>();
169  const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
170  const auto* val_data = val.template data<T>();
171  const auto* key_data = key.template data<TIndex>();
172 
173  memset(grad_weight_data, 0, sizeof(T) * num_weight);
174 
175  for (TIndex j = 0; j < num_nz_ent; ++j) {
176  TIndex cur_seg = seg_data[j];
177  TIndex cur_key = key_data[j];
178  T cur_val = val_data[j];
179  TIndex grad_out_stride = cur_seg * num_outputs_;
180  for (TIndex i = 0; i < num_outputs_; ++i) {
181  T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
182  for (TIndex k = 0; k < num_alpha; ++k) {
183  uint64_t hash;
184  hash_data[0] = cur_key;
185  hash_data[1] = i;
186  hash_data[2] = k;
187 
188  hash_data[3] = INDEX_MAGIC;
189  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
190  TIndex index = hash % num_weight;
191 
192  T cur_grad_out_scale = grad_out_scale;
193 #ifdef USE_SIGN
194  hash_data[3] = SIGN_MAGIC;
195  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
196  if (hash % 2) {
197  cur_grad_out_scale = -cur_grad_out_scale;
198  }
199 #endif // USE_SIGN
200 
201  if (adaptive_) {
202  grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
203  grad_weight_data[index] += alpha_data[k] * cur_grad_out_scale;
204  } else {
205  grad_weight_data[index] += cur_grad_out_scale;
206  }
207  }
208  }
209  }
210  return true;
211  }
212 
213  protected:
214  TIndex num_outputs_;
215  uint64_t seed_;
216  std::array<uint64_t, 4> hash_data;
217  bool adaptive_;
218 };
219 
220 } // namespace caffe2
221 
222 #endif // CAFFE2_OPERATORS_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37