Caffe2 - C++ API
A deep-learning, cross-platform ML framework
sparse_funhash_op.h
1 #ifndef CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
2 #define CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
3 
4 #include <xxhash.h>
5 #include <array>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
// Constant stored in the fourth word of the tuple that the operators below
// fill before calling XXH64; the in-code comment there describes it as a
// "magic number to improve hashing".
10 #define HASH_MAGIC 0x9e3779b97f4a7c15
11 
// Compile-time switch. When USE_SIGN is defined, the least significant bit of
// each hash decides whether the selected weight (or gradient scale) is negated
// and the remaining bits choose the weight index; when it is not defined, the
// whole hash chooses the index and no sign flip is applied.
12 #define USE_SIGN
13 
14 namespace caffe2 {
15 
16 template <typename T, class Context>
17 class SparseFunHashOp : public Operator<Context> {
18  public:
19  USE_OPERATOR_CONTEXT_FUNCTIONS;
20  SparseFunHashOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws),
22  num_outputs_(
23  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
24  num_segments_(
25  OperatorBase::GetSingleArgument<TIndex>("num_segments", -1)),
26  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
27  CAFFE_ENFORCE(
28  OperatorBase::HasArgument("num_outputs"),
29  "Argument `num_outputs` is missing.");
30  // If alpha is provided, use adaptive hashing parameterized by alpha.
31  adaptive_ = (InputSize() == 5);
32  }
33 
34  bool RunOnDevice() override {
35  const auto& val = Input(0);
36  const auto& key = Input(1);
37  const auto& seg = Input(2);
38  const auto& weight = Input(3);
39 
40  TIndex num_alpha = 1;
41  if (adaptive_) {
42  const auto& alpha = Input(4);
43  num_alpha = alpha.dim(0);
44  }
45 
46  const auto* seg_data = seg.template data<int>();
47 
48  TIndex num_weight = weight.dim(0);
49  TIndex num_nz_ent = seg.dim(0);
50 
51  TIndex n_segments = num_segments_;
52  if (num_segments_ == -1) {
53  for (TIndex i = 0; i < num_nz_ent; ++i) {
54  if (seg_data[i] > n_segments) {
55  n_segments = seg_data[i];
56  }
57  }
58  ++n_segments;
59  }
60 
61  auto* output = Output(0);
62  output->Resize(n_segments, num_outputs_);
63 
64  T* output_data = output->template mutable_data<T>();
65 
66  memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);
67 
68  const auto* weight_data = weight.template data<T>();
69  const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
70  const auto* val_data = val.template data<T>();
71  const auto* key_data = key.template data<TIndex>();
72 
73  for (TIndex j = 0; j < num_nz_ent; ++j) {
74  TIndex cur_seg = seg_data[j];
75  TIndex cur_key = key_data[j];
76  T cur_val = val_data[j];
77  TIndex output_stride = cur_seg * num_outputs_;
78  for (TIndex i = 0; i < num_outputs_; ++i) {
79  T sum = 0;
80  for (TIndex k = 0; k < num_alpha; ++k) {
81  // The hash function takes as input three integers:
82  // 1. feature index
83  // 2. output index
84  // 3. alpha index
85  // 4. magic number to improve hashing
86  hash_data[0] = cur_key;
87  hash_data[1] = i;
88  hash_data[2] = k;
89  hash_data[3] = HASH_MAGIC;
90 
91  uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
92 
93 #ifdef USE_SIGN
94  // Use the least significant bit for sign, the rest for weights.
95  TIndex index = (hash >> 1) % num_weight;
96  T cur_weight = weight_data[index];
97  if (hash & 1) {
98  cur_weight = -cur_weight;
99  }
100 #else
101  TIndex index = hash % num_weight;
102  T cur_weight = weight_data[index];
103 #endif
104 
105  if (adaptive_) {
106  sum += cur_weight * alpha_data[k];
107  } else {
108  sum += cur_weight;
109  }
110  }
111  output_data[output_stride + i] += sum * cur_val;
112  }
113  }
114 
115  return true;
116  }
117 
118  protected:
119  TIndex num_outputs_;
120  TIndex num_segments_;
121  uint64_t seed_;
122  std::array<uint64_t, 4> hash_data;
123  bool adaptive_;
124 };
125 
126 template <typename T, class Context>
127 class SparseFunHashGradientOp : public Operator<Context> {
128  public:
129  USE_OPERATOR_CONTEXT_FUNCTIONS;
130  SparseFunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
131  : Operator<Context>(operator_def, ws),
132  num_outputs_(
133  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
134  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
135  adaptive_ = (InputSize() == 6);
136  }
137 
138  bool RunOnDevice() override {
139  const auto& grad_out = Input(0);
140  const auto& val = Input(1);
141  const auto& key = Input(2);
142  const auto& seg = Input(3);
143  const auto& weight = Input(4);
144 
145  TIndex num_alpha = 1;
146  T* grad_alpha_data = 0;
147 
148  if (adaptive_) {
149  const auto& alpha = Input(5);
150  num_alpha = alpha.dim(0);
151  auto* grad_alpha = Output(2);
152  grad_alpha->ResizeLike(alpha);
153  grad_alpha_data = grad_alpha->template mutable_data<T>();
154  memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
155  }
156 
157  const auto* seg_data = seg.template data<int>();
158 
159  TIndex num_weight = weight.dim(0);
160  TIndex num_nz_ent = seg.dim(0);
161 
162  TIndex grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;
163  auto* grad_weight_val = Output(0);
164  grad_weight_val->Resize(grad_weight_size);
165  T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();
166 
167  auto* grad_weight_ind = Output(1);
168  grad_weight_ind->Resize(grad_weight_size);
169  auto* grad_weight_ind_data =
170  grad_weight_ind->template mutable_data<TIndex>();
171 
172  const auto* grad_out_data = grad_out.template data<T>();
173  const auto* weight_data = weight.template data<T>();
174  const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
175  const auto* val_data = val.template data<T>();
176  const auto* key_data = key.template data<TIndex>();
177 
178  TIndex w_ind = 0;
179  for (TIndex j = 0; j < num_nz_ent; ++j) {
180  TIndex cur_seg = seg_data[j];
181  TIndex cur_key = key_data[j];
182  T cur_val = val_data[j];
183  TIndex grad_out_stride = cur_seg * num_outputs_;
184  for (TIndex i = 0; i < num_outputs_; ++i) {
185  T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
186  for (TIndex k = 0; k < num_alpha; ++k) {
187  hash_data[0] = cur_key;
188  hash_data[1] = i;
189  hash_data[2] = k;
190  hash_data[3] = HASH_MAGIC;
191 
192  uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
193 
194  T cur_grad_out_scale = grad_out_scale;
195 #ifdef USE_SIGN
196  TIndex index = (hash >> 1) % num_weight;
197  if (hash & 1) {
198  cur_grad_out_scale = -cur_grad_out_scale;
199  }
200 #else
201  TIndex index = hash % num_weight;
202 #endif
203 
204  if (adaptive_) {
205  grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
206  grad_weight_val_data[w_ind] = alpha_data[k] * cur_grad_out_scale;
207  } else {
208  grad_weight_val_data[w_ind] = cur_grad_out_scale;
209  }
210  grad_weight_ind_data[w_ind] = index;
211  ++w_ind;
212  }
213  }
214  }
215  return true;
216  }
217 
218  protected:
219  TIndex num_outputs_;
220  uint64_t seed_;
221  std::array<uint64_t, 4> hash_data;
222  bool adaptive_;
223 };
224 
225 } // namespace caffe2
226 
227 #endif // CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37