Caffe2 - C++ API
A deep-learning, cross-platform ML framework
quant_decode_op.h
1 #ifndef QUANT_DECODE_OP_H_
2 #define QUANT_DECODE_OP_H_
3 
#include <algorithm>
#include <functional>
#include <map>
#include <utility>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/typeid.h"
8 
9 namespace caffe2 {
10 
11 namespace {
12 
13 template <class CodebookT, class CodeT>
14 void Decode(
15  const TensorCPU& codebook,
16  const TensorCPU& codes,
17  /* optional */ const TensorCPU* const decoded_grad,
18  TensorCPU* const output,
19  bool resizeOnly) {
20  CAFFE_ENFORCE(codebook.IsType<CodebookT>());
21 
22  auto* cb_ptr = codebook.data<CodebookT>();
23  int cb_size = codebook.size();
24 
25  CAFFE_ENFORCE(codes.IsType<CodeT>());
26  auto* code_ptr = codes.data<CodeT>();
27 
28  if (decoded_grad == nullptr) {
29  // Forward pass: decode and store codebook values in output.
30  output->ResizeLike(codes);
31  auto* out_ptr = output->mutable_data<CodebookT>();
32  if (resizeOnly) {
33  return;
34  }
35 
36  int sz = output->size();
37  for (int i = 0; i < sz; i++) {
38  DCHECK_LE(*code_ptr, cb_size);
39  *out_ptr++ = cb_ptr[*code_ptr++];
40  }
41  } else {
42  // Backward pass: decode and accumulate gradient w.r.t. codebook values.
43  CAFFE_ENFORCE_EQ(codes.size(), decoded_grad->size());
44  auto* gradient_ptr = decoded_grad->data<CodebookT>();
45  auto* const gradient_end = gradient_ptr + decoded_grad->size();
46 
47  CAFFE_ENFORCE_EQ(cb_size, output->size());
48  auto* out_ptr = output->mutable_data<CodebookT>();
49  while (gradient_ptr < gradient_end) {
50  DCHECK_LE(*code_ptr, cb_size);
51  out_ptr[*code_ptr++] += *gradient_ptr++;
52  }
53  }
54 }
55 
// Builds one entry of the dispatch map used by DecodeGeneral:
//   key   = {type id of codebook elements, type id of code elements}
//   value = a lambda forwarding to the matching Decode<codebookType, codesType>
//           instantiation.
// (Comments cannot appear inside the macro body because of the line
// continuations, so the whole contract is documented here.)
#define REGISTER_DECODER(codebookType, codesType) \
  { \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()}, \
        [](const TensorCPU& codebook_, \
           const TensorCPU& codes_, \
           const TensorCPU* gradient_, \
           TensorCPU* outDecoded_, \
           bool resizeOnly_) { \
          Decode<codebookType, codesType>( \
              codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \
        } \
  }
68 
69 inline void DecodeGeneral(
70  const TensorCPU& codebook,
71  const TensorCPU& codes,
72  const TensorCPU* gradient,
73  TensorCPU* outDecoded,
74  bool resizeOnly) {
75  const static std::map<
76  std::pair<CaffeTypeId, CaffeTypeId>,
77  std::function<void(
78  const TensorCPU& codebook,
79  const TensorCPU& codes,
80  const TensorCPU* gradient,
81  TensorCPU* outDecoded,
82  bool resizeOnly)>>
83  gDecoderMapper = {REGISTER_DECODER(float, uint8_t),
84  REGISTER_DECODER(float, uint16_t),
85  REGISTER_DECODER(float, int32_t)};
86 
87  gDecoderMapper.at({codebook.meta().id(), codes.meta().id()})(
88  codebook, codes, gradient, outDecoded, resizeOnly);
89 }
90 
91 } // namespace
92 
// Decode tensors based on a given codebook.
// The codebook is generated by model_quantize.py.

// Compile-time policy controlling how often QuantDecodeOp performs the
// full decode.
enum class QuantDecodeRunTy {
  RUN_ALWAYS, // decode inputs on every RunOnDevice() call
  RUN_ONCE, // decode on the first call only; later calls just resize outputs
            // (presumably for codes that are constant across runs — confirm)
};
100 
101 template <QuantDecodeRunTy QuantDecodeRun>
102 class QuantDecodeOp final : public Operator<CPUContext> {
103  public:
104  USE_OPERATOR_FUNCTIONS(CPUContext);
105  QuantDecodeOp(const OperatorDef& operator_def, Workspace* ws)
106  : Operator<CPUContext>(operator_def, ws) {}
107 
108  ~QuantDecodeOp() {}
109 
110  bool RunOnDevice() override {
111  CAFFE_ENFORCE_GT(InputSize(), 1);
112  // first input is the codebook
113  CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
114 
115  const auto& codebook = Input(0);
116  CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());
117 
118  for (int i = 0; i < OutputSize(); i++) {
119  auto& ci = Input(i + 1);
120  auto* co = Output(i);
121 
122  DecodeGeneral(
123  codebook,
124  ci,
125  nullptr,
126  co,
127  /*resizeOnly=*/QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
128  hasRun_);
129  }
130  hasRun_ = true;
131  return true;
132  }
133 
134  private:
135  bool hasRun_{false};
136 };
137 
138 class QuantDecodeGradientOp final : public Operator<CPUContext> {
139  public:
140  USE_OPERATOR_FUNCTIONS(CPUContext);
141  QuantDecodeGradientOp(const OperatorDef& operator_def, Workspace* ws)
142  : Operator<CPUContext>(operator_def, ws) {}
144 
145  bool RunOnDevice() override {
146  // Inputs: 1 codebook, n tensors of codes, and n corresponding gradients.
147  CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
148  const int num_code_tensors = (InputSize() - 1) / 2;
149  CAFFE_ENFORCE_EQ(OutputSize(), 1);
150 
151  const auto& codebook = Input(0);
152  CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());
153 
154  auto* gradient = Output(0);
155  gradient->ResizeLike(codebook);
156  auto* gradient_ptr = gradient->mutable_data<float>();
157  std::fill(gradient_ptr, gradient_ptr + gradient->size(), 0);
158 
159  for (int i = 0; i < num_code_tensors; i++) {
160  auto& codes_i = Input(i + 1);
161  auto& output_gradient_i = Input(i + num_code_tensors + 1);
162  DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient, false);
163  }
164  return true;
165  }
166 };
167 
168 } // namespace caffe2
169 #endif // QUANT_DECODE_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
Definition: context.h:66
Workspace is a class that holds all the related objects created during runtime: (1) all blobs.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.