Caffe2 - C++ API
A deep-learning, cross-platform ML framework
qtensor_serialization.h
1 #ifndef CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
2 #define CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
3 
4 #include "caffe2/core/blob_serialization.h"
5 #include "caffe2/core/qtensor.h"
6 
7 namespace caffe2 {
8 
// Type tag that Serialize stores on the outer BlobProto via set_type();
// presumably the read side uses it to route the blob to the QTensor
// deserializer — confirm against the deserializer registration.
// `constexpr` at namespace scope implies const and therefore internal linkage,
// so each including translation unit gets its own copy of this pointer.
9 constexpr auto kQTensorBlobQType = "QTensor";
10 
// Serializes a Blob holding a QTensor<Context> into a BlobProto and delivers
// the serialized bytes through an acceptor callback instead of returning them.
// NOTE(review): listing lines 12 and 16-18 were dropped from this capture.
// Line 12 is presumably the class head (the constructor names the class
// QTensorSerializer, Serialize is marked `override`, and the out-of-class
// definition refers to BlobSerializerBase::SerializationAcceptor). Lines 16-18
// are presumably a doc comment for Serialize. Confirm against the header.
11 template <class Context>
13  public:
14  QTensorSerializer() : context_() {}
15  ~QTensorSerializer() {}
// Writes the QTensor in `blob` under `name` and hands the result to
// `acceptor` together with that name.
19  void Serialize(
20  const Blob& blob,
21  const string& name,
22  SerializationAcceptor acceptor) override;
23 
24  private:
// Context passed to detail::CopyToProtoWithCast when the raw data is copied
// into the proto.
25  Context context_;
26 };
27 
// Deserializer counterpart of QTensorSerializer: rebuilds a QTensor<Context>
// from its proto form.
// NOTE(review): listing line 29 (the class head) was dropped from this capture,
// and the class name does not appear anywhere in the visible text. Given the
// `override` below, it presumably derives from BlobDeserializerBase. Confirm
// against the original header.
28 template <class Context>
30  public:
// Entry point for a whole BlobProto: unpacks its qtensor sub-message into the
// QTensor<Context> stored in `blob`.
31  void Deserialize(const BlobProto& proto, Blob* blob) override;
// Fills an existing QTensor<Context> (shape, quantization settings, raw data)
// from a QTensorProto.
32  void Deserialize(const QTensorProto& proto, QTensor<Context>* tensor);
33 };
34 
// Serializes the QTensor<Context> held by `blob` into a BlobProto named `name`.
// The BlobProto is tagged with kQTensorBlobQType, and its serialized string is
// passed to `acceptor` under the same name.
// NOTE(review): listing line 36 was dropped from this capture. Given the
// declaration inside QTensorSerializer, it is presumably the function head
// `void QTensorSerializer<Context>::Serialize(`. Confirm against the header.
35 template <class Context>
37  const Blob& blob,
38  const string& name,
39  BlobSerializerBase::SerializationAcceptor acceptor) {
// NOTE(review): presumably Blob::Get enforces that the blob really holds a
// QTensor<Context> — confirm in blob.h.
40  const auto& qtensor = blob.template Get<QTensor<Context>>();
41  BlobProto blob_proto;
42  blob_proto.set_name(name);
43  blob_proto.set_type(kQTensorBlobQType);
// The payload goes into the BlobProto's qtensor sub-message. The name is
// recorded on both the outer BlobProto and the inner QTensorProto.
44  QTensorProto& proto = *blob_proto.mutable_qtensor();
45  proto.set_name(name);
// One dims entry per dimension, read through dim32() (presumably a 32-bit view
// of each dimension).
46  for (int i = 0; i < qtensor.ndim(); ++i) {
47  proto.add_dims(qtensor.dim32(i));
48  }
// Precision, scale, bias and signedness travel with the data. The deserializer
// restores them with the matching setters.
49  proto.set_precision(qtensor.precision());
50  proto.set_scale(qtensor.scale());
51  proto.set_bias(qtensor.bias());
52  proto.set_is_signed(qtensor.is_signed());
// Copy qtensor.nbytes() elements starting at qtensor.data() into the proto's
// data field, using this serializer's member context_.
53  detail::CopyToProtoWithCast(
54  qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
55  acceptor(name, blob_proto.SerializeAsString());
56 }
57 
// Deserializes the qtensor sub-message of `blob_proto` into `blob`. It gets a
// mutable QTensor<Context> from the blob via GetMutable and delegates to the
// QTensorProto overload.
// NOTE(review): listing line 59 (the out-of-class function head) was dropped
// from this capture — confirm against the header.
// NOTE(review): blob_proto.has_qtensor() is not checked here. If the field is
// absent, the protobuf accessor presumably yields a default (empty)
// QTensorProto, which would be deserialized as-is. Confirm whether callers
// guarantee the field is set.
58 template <class Context>
60  const BlobProto& blob_proto,
61  Blob* blob) {
62  Deserialize(blob_proto.qtensor(), blob->GetMutable<QTensor<Context>>());
63 }
64 
// Restores `qtensor`'s shape, precision, scale, bias, signedness and raw data
// from `proto`, mirroring the fields QTensorSerializer::Serialize writes.
// NOTE(review): listing line 66 was dropped from this capture. It is presumably
// the out-of-class head of the QTensorProto overload declared in the
// deserializer class. Confirm against the header.
65 template <class Context>
67  const QTensorProto& proto,
68  QTensor<Context>* qtensor) {
// A fresh local context is used for the data copy below; the serializer uses
// its member context_ instead.
69  Context context{};
// Each stored dimension is converted to int before resizing.
// NOTE(review): if proto.dims() holds a wider integer type than int, this is a
// silent narrowing conversion — confirm the field type in the .proto schema.
70  vector<int> dims;
71  for (const int d : proto.dims()) {
72  dims.push_back(d);
73  }
74  qtensor->Resize(dims);
75  qtensor->SetPrecision(proto.precision());
76  qtensor->SetScale(proto.scale());
77  qtensor->SetBias(proto.bias());
78  qtensor->SetSigned(proto.is_signed());
79 
// nbytes() is evaluated only after Resize and SetPrecision have run, so it
// should reflect the restored shape and precision (presumably — confirm in
// qtensor.h).
// NOTE(review): nothing here checks that proto.data() holds at least
// qtensor->nbytes() entries. Unless CopyFromProtoWithCast bounds-checks, a
// truncated or malformed proto could be read past its end. Consider enforcing
// proto.data_size() == qtensor->nbytes() before copying.
80  detail::CopyFromProtoWithCast(
81  qtensor->nbytes(), proto.data(), qtensor->mutable_data(), &context);
82 }
83 
84 } // namespace caffe2
85 
86 #endif // CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.
Definition: blob.h:101
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
BlobSerializerBase is an abstract class that serializes a blob to a string.