// NOTE(review): this file appears to be a lossy extraction. Original line
// numbers are embedded inline and several source lines are missing. For
// example: the head of the overridden method's signature (only the trailing
// `SerializationAcceptor acceptor) override` is visible), the declaration of
// `blob_proto`, the declaration of `total` (used below, never declared), the
// macro heads wrapping the empty-buffer checks, and closing braces. Only
// comments were added here; recover the full file before changing code.
//
// MKLMemorySerializer: serializes a blob holding an MKLMemory<float> or
// MKLMemory<double> into a BlobProto. The tensor is tagged with device type
// MKLDNN, and the serialized string is passed to `acceptor` under `name`.
1 #include "caffe2/core/blob.h" 2 #include "caffe2/core/blob_serialization.h" 3 #include "caffe2/mkl/mkl_utils.h" 5 #ifdef CAFFE2_HAS_MKL_DNN 15 class MKLMemorySerializer :
public BlobSerializerBase {
17 MKLMemorySerializer() {}
18 ~MKLMemorySerializer() {}
23 SerializationAcceptor acceptor)
override {
// Outer BlobProto: record the blob name and mark it as a tensor blob.
25 blob_proto.set_name(name);
26 blob_proto.set_type(kTensorBlobType);
27 TensorProto* proto = blob_proto.mutable_tensor();
// Tag the tensor's device detail as MKLDNN so it is routed to the MKLDNN
// deserializer on load.
28 auto* device_detail = proto->mutable_device_detail();
29 device_detail->set_device_type(MKLDNN);
30 proto->set_name(name);
// float path.
31 if (blob.IsType<MKLMemory<float>>()) {
32 const MKLMemory<float>& src = blob.Get<MKLMemory<float>>();
// Reject an MKLMemory with no buffer. The enforce-style macro head that
// wraps these two arguments is missing from this extraction.
34 src.buffer(),
"Cannot serialize an empty MKLMemory object.");
// Copy each dimension into the proto and accumulate the element count.
// NOTE(review): `int i` is compared against `size()`, which is likely
// unsigned (signed/unsigned comparison). Consider size_t or a range-for.
36 for (
int i = 0; i < src.dims().size(); ++i) {
37 proto->add_dims(src.dims()[i]);
38 total *= src.dims()[i];
// Reserve, then append zeros. The loop header that repeats the append
// `total` times is missing here; presumably it runs once per element.
// CopyTo below writes through mutable_data(), so the repeated field must
// already contain `total` elements.
40 proto->mutable_float_data()->Reserve(total);
42 proto->add_float_data(0);
44 src.CopyTo(proto->mutable_float_data()->mutable_data());
45 }
// double path: mirrors the float path using double_data.
else if (blob.IsType<MKLMemory<double>>()) {
46 const MKLMemory<double>& src = blob.Get<MKLMemory<double>>();
48 src.buffer(),
"Cannot serialize an empty MKLMemory object.");
50 for (
int i = 0; i < src.dims().size(); ++i) {
51 proto->add_dims(src.dims()[i]);
52 total *= src.dims()[i];
54 proto->mutable_double_data()->Reserve(total);
56 proto->add_double_data(0);
58 src.CopyTo(proto->mutable_double_data()->mutable_data());
// Message for the unsupported-type branch. The `else` and the
// throw-macro head around it are missing from this extraction.
61 "MKLMemory could only be either float or double. " 62 "Encountered unsupported type.");
// Hand the serialized BlobProto to the acceptor under the blob's name.
64 acceptor(name, blob_proto.SerializeAsString());
// MKLMemoryDeserializer: rebuilds an MKLMemory<float> or MKLMemory<double>
// from blob_proto.tensor() and stores it in `blob`. Only FLOAT and DOUBLE
// data types are accepted, and segmented tensors are rejected.
// NOTE(review): lines are missing in this extraction (validation macro heads,
// the declaration of `dims`, the loop body that fills it, `break;` statements
// and the `default:` label). Comments only; confirm against the full file.
77 class MKLMemoryDeserializer :
public BlobDeserializerBase {
79 void Deserialize(
const BlobProto& blob_proto, Blob* blob)
override {
80 const TensorProto& proto = blob_proto.tensor();
// Validate the data type: only FLOAT or DOUBLE.
82 proto.data_type() == TensorProto_DataType_FLOAT ||
83 proto.data_type() == TensorProto_DataType_DOUBLE,
84 "MKLMemory only supports either float or double formats.");
// Validate: segmented tensors are not supported.
86 !proto.has_segment(),
"MKLMemory does not support segment right now.");
// Collect the tensor dims. The loop body and the `dims` declaration are
// not visible; presumably each `d` is appended to `dims`.
88 for (
const TIndex d : proto.dims()) {
// Allocate an MKLMemory of the matching element type, copy the payload
// in, and transfer ownership to the blob via release().
// NOTE(review): CopyFrom reads from float_data()/double_data() with no
// visible check that its size equals the product of `dims`. A malformed
// or truncated proto could cause an out-of-bounds read. Confirm CopyFrom's
// contract and consider enforcing the element count before copying.
93 switch (proto.data_type()) {
94 case TensorProto_DataType_FLOAT: {
95 auto dst = make_unique<MKLMemory<float>>(dims);
96 dst->CopyFrom(proto.float_data().data());
97 blob->Reset(dst.release());
// NOTE(review): the `break;` and closing brace for this case are not
// visible here. If they are truly absent, FLOAT would fall through into
// the DOUBLE case. Confirm.
100 case TensorProto_DataType_DOUBLE: {
101 auto dst = make_unique<MKLMemory<double>>(dims);
102 dst->CopyFrom(proto.double_data().data());
103 blob->Reset(dst.release());
// Default branch (label not visible): unreachable given the data-type check above.
107 CAFFE_THROW(
"This should not happen, we guarded things above already.");
// Register mkl::MKLMemorySerializer twice, under two type ids. The type-id
// arguments sit on lines missing from this extraction; presumably they are
// the ids for MKLMemory<float> and MKLMemory<double> (TODO confirm).
114 REGISTER_BLOB_SERIALIZER(
116 mkl::MKLMemorySerializer);
117 REGISTER_BLOB_SERIALIZER(
119 mkl::MKLMemorySerializer);
// Register the deserializer under the TensorMKLDNN key.
120 REGISTER_BLOB_DESERIALIZER(TensorMKLDNN, mkl::MKLMemoryDeserializer);
// Closes the CAFFE2_HAS_MKL_DNN guard. Namespace-closing lines, if any, are
// not visible in this extraction.
123 #endif // CAFFE2_HAS_MKL_DNN
// NOTE(review): stray, truncated text that appears unrelated to this file.
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...