Caffe2 - C++ API
A deep-learning, cross-platform ML framework
mklmemory_serialization.cc
1 #include "caffe2/core/blob.h"
2 #include "caffe2/core/blob_serialization.h"
3 #include "caffe2/mkl/mkl_utils.h"
4 
5 #ifdef CAFFE2_HAS_MKL_DNN
6 
7 namespace caffe2 {
8 namespace mkl {
15 class MKLMemorySerializer : public BlobSerializerBase {
16  public:
17  MKLMemorySerializer() {}
18  ~MKLMemorySerializer() {}
19 
20  void Serialize(
21  const Blob& blob,
22  const string& name,
23  SerializationAcceptor acceptor) override {
24  BlobProto blob_proto;
25  blob_proto.set_name(name);
26  blob_proto.set_type(kTensorBlobType);
27  TensorProto* proto = blob_proto.mutable_tensor();
28  auto* device_detail = proto->mutable_device_detail();
29  device_detail->set_device_type(MKLDNN);
30  proto->set_name(name);
31  if (blob.IsType<MKLMemory<float>>()) {
32  const MKLMemory<float>& src = blob.Get<MKLMemory<float>>();
33  CAFFE_ENFORCE(
34  src.buffer(), "Cannot serialize an empty MKLMemory object.");
35  size_t total = 1;
36  for (int i = 0; i < src.dims().size(); ++i) {
37  proto->add_dims(src.dims()[i]);
38  total *= src.dims()[i];
39  }
40  proto->mutable_float_data()->Reserve(total);
41  while (total--) {
42  proto->add_float_data(0);
43  }
44  src.CopyTo(proto->mutable_float_data()->mutable_data());
45  } else if (blob.IsType<MKLMemory<double>>()) {
46  const MKLMemory<double>& src = blob.Get<MKLMemory<double>>();
47  CAFFE_ENFORCE(
48  src.buffer(), "Cannot serialize an empty MKLMemory object.");
49  size_t total = 1;
50  for (int i = 0; i < src.dims().size(); ++i) {
51  proto->add_dims(src.dims()[i]);
52  total *= src.dims()[i];
53  }
54  proto->mutable_double_data()->Reserve(total);
55  while (total--) {
56  proto->add_double_data(0);
57  }
58  src.CopyTo(proto->mutable_double_data()->mutable_data());
59  } else {
60  CAFFE_THROW(
61  "MKLMemory could only be either float or double. "
62  "Encountered unsupported type.");
63  }
64  acceptor(name, blob_proto.SerializeAsString());
65  }
66 };
67 
77 class MKLMemoryDeserializer : public BlobDeserializerBase {
78  public:
79  void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
80  const TensorProto& proto = blob_proto.tensor();
81  CAFFE_ENFORCE(
82  proto.data_type() == TensorProto_DataType_FLOAT ||
83  proto.data_type() == TensorProto_DataType_DOUBLE,
84  "MKLMemory only supports either float or double formats.");
85  CAFFE_ENFORCE(
86  !proto.has_segment(), "MKLMemory does not support segment right now.");
87  vector<TIndex> dims;
88  for (const TIndex d : proto.dims()) {
89  dims.push_back(d);
90  }
91  // TODO: right now, every time we do a deserializer we create a new MKL
92  // Memory object. Optionally, we can change that.
93  switch (proto.data_type()) {
94  case TensorProto_DataType_FLOAT: {
95  auto dst = make_unique<MKLMemory<float>>(dims);
96  dst->CopyFrom(proto.float_data().data());
97  blob->Reset(dst.release());
98  break;
99  }
100  case TensorProto_DataType_DOUBLE: {
101  auto dst = make_unique<MKLMemory<double>>(dims);
102  dst->CopyFrom(proto.double_data().data());
103  blob->Reset(dst.release());
104  break;
105  }
106  default:
107  CAFFE_THROW("This should not happen, we guarded things above already.");
108  }
109  }
110 };
111 
112 } // namespace mkl
113 
114 REGISTER_BLOB_SERIALIZER(
115  (TypeMeta::Id<mkl::MKLMemory<float>>()),
116  mkl::MKLMemorySerializer);
117 REGISTER_BLOB_SERIALIZER(
118  (TypeMeta::Id<mkl::MKLMemory<double>>()),
119  mkl::MKLMemorySerializer);
120 REGISTER_BLOB_DESERIALIZER(TensorMKLDNN, mkl::MKLMemoryDeserializer);
121 } // namespace caffe2
122 
123 #endif // CAFFE2_HAS_MKL_DNN
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...