1 #ifndef QUANT_DECODE_OP_H_ 2 #define QUANT_DECODE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/core/tensor.h" 7 #include "caffe2/core/typeid.h"
// Typed codebook decode kernel.
//
// NOTE(review): this listing is partial. Numbered lines are elided, including
// the line carrying the function name, the last parameter, and the closing
// braces. The comments below describe only the statements that are visible.
//
// Template parameters:
//   CodebookT - element type of `codebook`; enforced with IsType<CodebookT>().
//   CodeT     - element type of `codes`; enforced with IsType<CodeT>(). Each
//               code is used as an index into the codebook.
//
// Parameters:
//   codebook     - table of decoded values, read through data<CodebookT>().
//   codes        - indices into `codebook`.
//   decoded_grad - nullable; selects the forward path (nullptr) or the
//                  gradient-accumulation path (non-null).
//   output       - forward path: resized like `codes` and written.
//                  gradient path: must already hold codebook.size() elements
//                  and is accumulated into.
13 template <
class CodebookT,
class CodeT>
15 const TensorCPU& codebook,
16 const TensorCPU& codes,
17 const TensorCPU*
const decoded_grad,
18 TensorCPU*
const output,
// Both tensors must hold exactly the element types this instantiation expects.
20 CAFFE_ENFORCE(codebook.IsType<CodebookT>());
22 auto* cb_ptr = codebook.data<CodebookT>();
23 int cb_size = codebook.size();
25 CAFFE_ENFORCE(codes.IsType<CodeT>());
26 auto* code_ptr = codes.data<CodeT>();
// No incoming gradient: forward decode.
28 if (decoded_grad ==
nullptr) {
30 output->ResizeLike(codes);
31 auto* out_ptr = output->mutable_data<CodebookT>();
// NOTE(review): listing lines 32-35 are not visible here. Callers pass a
// trailing resize-only flag, so presumably an early exit lives there - confirm.
36 int sz = output->size();
37 for (
int i = 0; i < sz; i++) {
// A code indexes cb_ptr, so it must be strictly below cb_size; a code equal
// to cb_size would read one element past the end of the codebook.
38 DCHECK_LT(*code_ptr, cb_size);
39 *out_ptr++ = cb_ptr[*code_ptr++];
// Gradient path: one incoming gradient value per code.
43 CAFFE_ENFORCE_EQ(codes.size(), decoded_grad->size());
44 auto* gradient_ptr = decoded_grad->data<CodebookT>();
45 auto*
const gradient_end = gradient_ptr + decoded_grad->size();
// The accumulator must have one slot per codebook entry.
47 CAFFE_ENFORCE_EQ(cb_size, output->size());
48 auto* out_ptr = output->mutable_data<CodebookT>();
49 while (gradient_ptr < gradient_end) {
// Same strict bound as the forward path: out_ptr has exactly cb_size slots,
// so index cb_size would write out of bounds.
50 DCHECK_LT(*code_ptr, cb_size);
51 out_ptr[*code_ptr++] += *gradient_ptr++;
// REGISTER_DECODER(codebookType, codesType) expands to one map entry: the key
// is the pair {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()} and the
// value is a lambda that forwards its arguments to
// Decode<codebookType, codesType>.
//
// DecodeGeneral picks the Decode instantiation that matches the runtime element
// types of `codebook` and `codes` and calls it with the arguments it received.
//
// NOTE(review): this listing is partial. The map's value type, the trailing
// parameter(s), and the closing braces sit on lines that are not visible here.
56 #define REGISTER_DECODER(codebookType, codesType) \ 58 {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()}, \ 59 [](const TensorCPU& codebook_, \ 60 const TensorCPU& codes_, \ 61 const TensorCPU* gradient_, \ 62 TensorCPU* outDecoded_, \ 64 Decode<codebookType, codesType>( \ 65 codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \ 69 inline void DecodeGeneral(
70 const TensorCPU& codebook,
71 const TensorCPU& codes,
72 const TensorCPU* gradient,
73 TensorCPU* outDecoded,
// Function-local static dispatch table keyed by (codebook type id, codes type id).
75 const static std::map<
76 std::pair<CaffeTypeId, CaffeTypeId>,
78 const TensorCPU& codebook,
79 const TensorCPU& codes,
80 const TensorCPU* gradient,
81 TensorCPU* outDecoded,
// Registered combinations: float codebooks with uint8_t, uint16_t or int32_t codes.
83 gDecoderMapper = {REGISTER_DECODER(
float, uint8_t),
84 REGISTER_DECODER(
float, uint16_t),
85 REGISTER_DECODER(
float, int32_t)};
// std::map::at throws std::out_of_range when the (codebook, codes) type pair
// has no registered decoder.
87 gDecoderMapper.at({codebook.meta().id(), codes.meta().id()})(
88 codebook, codes, gradient, outDecoded, resizeOnly);
// Scoped enum used as the non-type template argument of the decode operator.
// NOTE(review): the enumerator list is not visible in this listing; RUN_ONCE is
// the only enumerator referenced by the visible code.
96 enum class QuantDecodeRunTy {
// Forward operator entry point, parameterized on the run mode QuantDecodeRun.
// NOTE(review): the class header (listing lines 102-109) and most of the loop
// body (121-126, 128 onward) are not visible here.
101 template <QuantDecodeRunTy QuantDecodeRun>
110 bool RunOnDevice()
override {
// Needs the codebook plus at least one codes tensor.
111 CAFFE_ENFORCE_GT(InputSize(), 1);
// One output per codes tensor: inputs = codebook + N codes, outputs = N.
113 CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
// Input 0 is the codebook; only float codebooks are accepted. The type name
// is passed as the failure message.
115 const auto& codebook = Input(0);
116 CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());
// Pair codes tensor Input(i + 1) with Output(i).
118 for (
int i = 0; i < OutputSize(); i++) {
119 auto& ci = Input(i + 1);
120 auto* co = Output(i);
// Fragment of a condition that is specific to the RUN_ONCE mode; the
// rest of the expression is not visible in this listing.
127 QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
// Gradient operator entry point: accumulates the gradient with respect to the
// codebook into the single output.
// Input layout, as read below: Input(0) is the codebook, Input(1..N) are the
// codes tensors and Input(N+1..2N) are the matching incoming gradients.
// NOTE(review): the end of this method (listing lines 163 onward) is not
// visible here.
145 bool RunOnDevice()
override {
// Codebook + N codes + N gradients means an odd input count of at least 3.
147 CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
148 const int num_code_tensors = (InputSize() - 1) / 2;
149 CAFFE_ENFORCE_EQ(OutputSize(), 1);
// Only float codebooks are accepted; the type name is the failure message.
151 const auto& codebook = Input(0);
152 CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.meta().name());
// The output has the codebook's shape and starts at zero so that the calls
// below can accumulate into it.
154 auto* gradient = Output(0);
155 gradient->ResizeLike(codebook);
156 auto* gradient_ptr = gradient->mutable_data<
float>();
157 std::fill(gradient_ptr, gradient_ptr + gradient->size(), 0);
// Each codes tensor is paired with the gradient tensor num_code_tensors
// slots after it.
159 for (
int i = 0; i < num_code_tensors; i++) {
160 auto& codes_i = Input(i + 1);
161 auto& output_gradient_i = Input(i + num_code_tensors + 1);
// A non-null gradient argument is passed, with `false` as the last argument.
162 DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient,
false);
169 #endif // QUANT_DECODE_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...