Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor.cc
#include "caffe2/core/tensor.h"

#include "caffe2/core/blob_stats.h"
#include "caffe2/core/flags.h"

CAFFE2_DEFINE_bool(
    caffe2_keep_on_shrink,
    true,
    "If set, keeps memory when a tensor is shrinking its size.");

CAFFE2_DEFINE_int64(
    caffe2_max_keep_on_shrink_memory,
    LLONG_MAX,
    "The maximum memory in bytes to keep on shrink; if the difference between "
    "tensor sizes is bigger than this, the tensor will be reset.");

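// Usage sketch (illustrative, not part of the original tensor.cc): the two
// flags above are read as FLAGS_caffe2_keep_on_shrink and
// FLAGS_caffe2_max_keep_on_shrink_memory by the Resize() logic in tensor.h.
// A simplified, assumed shape of that logic, where freed_bytes stands in for
// the size difference between the old and new allocations:
//
//   if (FLAGS_caffe2_keep_on_shrink &&
//       freed_bytes <= FLAGS_caffe2_max_keep_on_shrink_memory) {
//     // Keep the existing buffer; capacity stays the same.
//   } else {
//     // Drop the buffer; it is reallocated on the next mutable_data() call.
//   }
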
namespace caffe2 {
// declaring it here instead of context.cc because tensor.h includes context.h
CAFFE_KNOWN_TYPE(Tensor<CPUContext>);

TensorPrinter::TensorPrinter(
    const std::string& tensor_name,
    const std::string& file_name,
    int limit)
    : to_file_(!file_name.empty()),
      limit_(limit ? limit : k_limit_default_),
      tensor_name_(tensor_name) {
  if (to_file_) {
    // We will output to file instead of printing on screen.
    // We will write each individual tensor to its individual file.
    log_file_.reset(new std::ofstream(
        file_name, std::ofstream::out | std::ofstream::trunc));
    CAFFE_ENFORCE(
        log_file_->good(),
        "Failed to open TensorPrinter file ",
        file_name,
        ". rdstate() = ",
        log_file_->rdstate());
  }
}

TensorPrinter::~TensorPrinter() {
  if (log_file_.get()) {
    log_file_->close();
  }
}

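// Usage sketch (illustrative, not part of the original file): an empty file
// name sends output to the log, while a non-empty one truncates and writes
// that file. Print<T> is the templated method declared in tensor.h; the
// tensor setup below is hypothetical.
//
//   TensorCPU probe(std::vector<TIndex>{2, 3});
//   float* data = probe.mutable_data<float>();
//   for (TIndex i = 0; i < probe.size(); ++i) {
//     data[i] = static_cast<float>(i);
//   }
//   TensorPrinter log_printer("probe");                     // LOG output
//   TensorPrinter file_printer("probe", "/tmp/probe.txt");  // per-tensor file
//   log_printer.Print<float>(probe);
//   file_printer.Print<float>(probe);
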
static CaffeMap<CaffeTypeId, TypeCall> type_call_registry_{
    {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorType<CPUContext>}};

TypeCall GetTypeCallFunction(CaffeTypeId id) {
  auto f = type_call_registry_.find(id);
  if (f == type_call_registry_.end()) {
    return nullptr;
  }
  return f->second;
}

void RegisterTypeCallFunction(CaffeTypeId id, TypeCall c) {
  type_call_registry_[id] = c;
}

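// Usage sketch (illustrative, not part of the original file): given a Blob
// whose meta() identifies a Tensor<CPUContext>, the registered TypeCall
// reports the element type of the tensor stored inside it. This assumes the
// TypeCall signature TypeMeta(const void*) declared in tensor.h.
//
//   Blob blob;
//   auto* t = blob.GetMutable<TensorCPU>();
//   t->Resize(4);
//   t->mutable_data<float>();
//   TypeCall call = GetTypeCallFunction(blob.meta().id());
//   if (call != nullptr) {
//     TypeMeta element_meta = call(t);  // describes float for this tensor
//   }
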
static CaffeMap<CaffeTypeId, TensorInfoCall> tensor_info_call_registry_{
    {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorInfo<CPUContext>}};

TensorInfoCall GetTensorInfoFunction(CaffeTypeId id) {
  auto f = tensor_info_call_registry_.find(id);
  if (f == tensor_info_call_registry_.end()) {
    return nullptr;
  }
  return f->second;
}

void RegisterTensorInfoFunction(CaffeTypeId id, TensorInfoCall c) {
  tensor_info_call_registry_[id] = c;
}

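// Registration sketch (illustrative, not part of the original file): other
// device contexts are expected to add their own entries to both registries at
// static-initialization time, e.g. for a hypothetical MyContext:
//
//   RegisterTypeCallFunction(
//       TypeMeta::Id<Tensor<MyContext>>(), GetTensorType<MyContext>);
//   RegisterTensorInfoFunction(
//       TypeMeta::Id<Tensor<MyContext>>(), GetTensorInfo<MyContext>);
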
namespace {

struct TensorCPUStatGetter : BlobStatGetter {
  size_t sizeBytes(const Blob& blob) const override {
    const auto& tensor = blob.Get<TensorCPU>();
    auto nbytes = tensor.nbytes();
    if (nbytes > 0 && tensor.IsType<std::string>()) {
      const auto* data = tensor.data<std::string>();
      for (size_t i = 0; i < tensor.size(); ++i) {
        nbytes += data[i].size();
      }
    }
    return nbytes;
  }
};
REGISTER_BLOB_STAT_GETTER(TensorCPU, TensorCPUStatGetter);
} // namespace

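// Usage sketch (illustrative, not part of the original file): with the getter
// registered above, a blob holding a TensorCPU can report its approximate
// memory footprint. BlobStat::sizeBytes is assumed to be the query entry
// point declared in blob_stats.h.
//
//   Blob blob;
//   auto* t = blob.GetMutable<TensorCPU>();
//   t->Resize(10);
//   t->mutable_data<float>();
//   size_t bytes = BlobStat::sizeBytes(blob);  // 10 * sizeof(float) here
//   // For std::string tensors, each string's payload is added on top of
//   // nbytes(), as implemented by TensorCPUStatGetter above.
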
} // namespace caffe2