Caffe2 - C++ API
A deep learning, cross-platform ML framework
workspace.cc
1 #include "caffe2/core/workspace.h"
2 
3 #include <algorithm>
4 #include <ctime>
5 #include <mutex>
6 
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/net.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/plan_executor.h"
11 #include "caffe2/core/tensor.h"
12 #include "caffe2/proto/caffe2.pb.h"
13 
14 CAFFE2_DEFINE_bool(
15  caffe2_print_blob_sizes_at_exit,
16  false,
17  "If true, workspace destructor will print all blob shapes");
18 
19 namespace caffe2 {
20 
21 void Workspace::PrintBlobSizes() {
22  vector<string> blobs = LocalBlobs();
23  size_t cumtotal = 0;
24 
25  // First get total sizes and sort
26  vector<std::pair<size_t, std::string>> blob_sizes;
27  for (const auto& s : blobs) {
28  Blob* b = this->GetBlob(s);
29  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
30  if (shape_fun) {
31  bool shares_data = false;
32  size_t capacity;
33  DeviceOption _device;
34  auto shape = shape_fun(b->GetRaw(), &shares_data, &capacity, &_device);
35  if (shares_data) {
36  // Blobs sharing data do not actually take any memory
37  capacity = 0;
38  }
39  cumtotal += capacity;
40  blob_sizes.push_back(make_pair(capacity, s));
41  }
42  }
43  std::sort(
44  blob_sizes.begin(),
45  blob_sizes.end(),
46  [](const std::pair<size_t, std::string>& a,
47  const std::pair<size_t, std::string>& b) {
48  return b.first < a.first;
49  });
50 
51  // Then print in descending order
52  LOG(INFO) << "---- Workspace blobs: ---- ";
53  LOG(INFO) << "name;current shape;capacity bytes;percentage";
54  for (const auto& sb : blob_sizes) {
55  Blob* b = this->GetBlob(sb.second);
56  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
57  CHECK(shape_fun != nullptr);
58  bool _shares_data = false;
59  size_t capacity;
60  DeviceOption _device;
61 
62  auto shape = shape_fun(b->GetRaw(), &_shares_data, &capacity, &_device);
63  std::stringstream ss;
64  ss << sb.second << ";";
65  for (const auto d : shape) {
66  ss << d << ",";
67  }
68  LOG(INFO) << ss.str() << ";" << sb.first << ";" << std::setprecision(3)
69  << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
70  << "%";
71  }
72  LOG(INFO) << "Total;;" << cumtotal << ";100%";
73 }
74 
75 vector<string> Workspace::LocalBlobs() const {
76  vector<string> names;
77  names.reserve(blob_map_.size());
78  for (auto& entry : blob_map_) {
79  names.push_back(entry.first);
80  }
81  return names;
82 }
83 
84 vector<string> Workspace::Blobs() const {
85  vector<string> names;
86  names.reserve(blob_map_.size());
87  for (auto& entry : blob_map_) {
88  names.push_back(entry.first);
89  }
90  for (const auto& forwarded : forwarded_blobs_) {
91  const auto parent_ws = forwarded.second.first;
92  const auto& parent_name = forwarded.second.second;
93  if (parent_ws->HasBlob(parent_name)) {
94  names.push_back(forwarded.first);
95  }
96  }
97  if (shared_) {
98  const auto& shared_blobs = shared_->Blobs();
99  names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
100  }
101  return names;
102 }
103 
104 Blob* Workspace::CreateBlob(const string& name) {
105  if (HasBlob(name)) {
106  VLOG(1) << "Blob " << name << " already exists. Skipping.";
107  } else if (forwarded_blobs_.count(name)) {
108  // possible if parent workspace deletes forwarded blob
109  VLOG(1) << "Blob " << name << " is already forwarded from parent workspace "
110  << "(blob " << forwarded_blobs_[name].second << "). Skipping.";
111  } else {
112  VLOG(1) << "Creating blob " << name;
113  blob_map_[name] = unique_ptr<Blob>(new Blob());
114  }
115  return GetBlob(name);
116 }
117 
118 Blob* Workspace::CreateLocalBlob(const string& name) {
119  if (blob_map_.count(name)) {
120  VLOG(1) << "Blob " << name << " already exists. Skipping.";
121  } else {
122  VLOG(1) << "Creating blob " << name;
123  blob_map_[name] = unique_ptr<Blob>(new Blob());
124  }
125  return GetBlob(name);
126 }
127 
128 Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
129  // We allow renaming only local blobs for API clarity purpose
130  auto it = blob_map_.find(old_name);
131  CAFFE_ENFORCE(
132  it != blob_map_.end(),
133  "Blob ",
134  old_name,
135  " is not in the local blob list");
136 
137  // New blob can't be in any parent either, otherwise it will hide a parent
138  // blob
139  CAFFE_ENFORCE(
140  !HasBlob(new_name), "Blob ", new_name, "is already in the workspace");
141 
142  // First delete the old record
143  auto value = std::move(it->second);
144  blob_map_.erase(it);
145 
146  auto* raw_ptr = value.get();
147  blob_map_[new_name] = std::move(value);
148  return raw_ptr;
149 }
150 
151 bool Workspace::RemoveBlob(const string& name) {
152  auto it = blob_map_.find(name);
153  if (it != blob_map_.end()) {
154  VLOG(1) << "Removing blob " << name << " from this workspace.";
155  blob_map_.erase(it);
156  return true;
157  }
158 
159  // won't go into shared_ here
160  VLOG(1) << "Blob " << name << " not exists. Skipping.";
161  return false;
162 }
163 
164 const Blob* Workspace::GetBlob(const string& name) const {
165  if (blob_map_.count(name)) {
166  return blob_map_.at(name).get();
167  } else if (forwarded_blobs_.count(name)) {
168  const auto parent_ws = forwarded_blobs_.at(name).first;
169  const auto& parent_name = forwarded_blobs_.at(name).second;
170  return parent_ws->GetBlob(parent_name);
171  } else if (shared_ && shared_->HasBlob(name)) {
172  return shared_->GetBlob(name);
173  }
174  LOG(WARNING) << "Blob " << name << " not in the workspace.";
175  // TODO(Yangqing): do we want to always print out the list of blobs here?
176  // LOG(WARNING) << "Current blobs:";
177  // for (const auto& entry : blob_map_) {
178  // LOG(WARNING) << entry.first;
179  // }
180  return nullptr;
181 }
182 
184  const Workspace* parent,
185  const std::unordered_map<string, string>& forwarded_blobs,
186  bool skip_defined_blobs) {
187  CAFFE_ENFORCE(parent, "Parent workspace must be specified");
188  for (const auto& forwarded : forwarded_blobs) {
189  CAFFE_ENFORCE(
190  parent->HasBlob(forwarded.second),
191  "Invalid parent workspace blob " + forwarded.second);
192  if (forwarded_blobs_.count(forwarded.first)) {
193  const auto& ws_blob = forwarded_blobs_[forwarded.first];
194  CAFFE_ENFORCE_EQ(
195  ws_blob.first, parent, "Redefinition of blob " + forwarded.first);
196  CAFFE_ENFORCE_EQ(
197  ws_blob.second,
198  forwarded.second,
199  "Redefinition of blob " + forwarded.first);
200  } else {
201  if (skip_defined_blobs && HasBlob(forwarded.first)) {
202  continue;
203  }
204  CAFFE_ENFORCE(
205  !HasBlob(forwarded.first), "Redefinition of blob " + forwarded.first);
206  // Lazy blob resolution - store the parent workspace and
207  // blob name, blob value might change in the parent workspace
208  forwarded_blobs_[forwarded.first] =
209  std::make_pair(parent, forwarded.second);
210  }
211  }
212 }
213 
214 Blob* Workspace::GetBlob(const string& name) {
215  return const_cast<Blob*>(static_cast<const Workspace*>(this)->GetBlob(name));
216 }
217 
218 NetBase* Workspace::CreateNet(const NetDef& net_def, bool overwrite) {
219  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
220  return CreateNet(tmp_net_def, overwrite);
221 }
222 
224  const std::shared_ptr<const NetDef>& net_def,
225  bool overwrite) {
226  CAFFE_ENFORCE(net_def->has_name(), "Net definition should have a name.");
227  if (net_map_.count(net_def->name()) > 0) {
228  if (!overwrite) {
229  CAFFE_THROW(
230  "I respectfully refuse to overwrite an existing net of the same "
231  "name \"",
232  net_def->name(),
233  "\", unless you explicitly specify overwrite=true.");
234  }
235  VLOG(1) << "Deleting existing network of the same name.";
236  // Note(Yangqing): Why do we explicitly erase it here? Some components of
237  // the old network, such as an opened LevelDB, may prevent us from creating
238  // a new network before the old one is deleted. Thus we will need to first
239  // erase the old one before the new one can be constructed.
240  net_map_.erase(net_def->name());
241  }
242  // Create a new net with its name.
243  VLOG(1) << "Initializing network " << net_def->name();
244  net_map_[net_def->name()] =
245  unique_ptr<NetBase>(caffe2::CreateNet(net_def, this));
246  if (net_map_[net_def->name()].get() == nullptr) {
247  LOG(ERROR) << "Error when creating the network."
248  << "Maybe net type: [" << net_def->type() << "] does not exist";
249  net_map_.erase(net_def->name());
250  return nullptr;
251  }
252  return net_map_[net_def->name()].get();
253 }
254 
255 NetBase* Workspace::GetNet(const string& name) {
256  if (!net_map_.count(name)) {
257  return nullptr;
258  } else {
259  return net_map_[name].get();
260  }
261 }
262 
263 void Workspace::DeleteNet(const string& name) {
264  if (net_map_.count(name)) {
265  net_map_.erase(name);
266  }
267 }
268 
269 bool Workspace::RunNet(const string& name) {
270  if (!net_map_.count(name)) {
271  LOG(ERROR) << "Network " << name << " does not exist yet.";
272  return false;
273  }
274  return net_map_[name]->Run();
275 }
276 
277 bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
278  std::unique_ptr<OperatorBase> op(CreateOperator(op_def, this));
279  if (op.get() == nullptr) {
280  LOG(ERROR) << "Cannot create operator of type " << op_def.type();
281  return false;
282  }
283  if (!op->Run()) {
284  LOG(ERROR) << "Error when running operator " << op_def.type();
285  return false;
286  }
287  return true;
288 }
289 bool Workspace::RunNetOnce(const NetDef& net_def) {
290  std::unique_ptr<NetBase> net(caffe2::CreateNet(net_def, this));
291  if (net == nullptr) {
292  CAFFE_THROW(
293  "Could not create net: " + net_def.name() + " of type " +
294  net_def.type());
295  }
296  if (!net->Run()) {
297  LOG(ERROR) << "Error when running network " << net_def.name();
298  return false;
299  }
300  return true;
301 }
302 
303 bool Workspace::RunPlan(const PlanDef& plan, ShouldContinue shouldContinue) {
304  return RunPlanOnWorkspace(this, plan, shouldContinue);
305 }
306 
307 ThreadPool* Workspace::GetThreadPool() {
308  std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
309  if (!thread_pool_) {
310  thread_pool_ = ThreadPool::defaultThreadPool();
311  }
312  return thread_pool_.get();
313 }
314 
315 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:104
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
Definition: workspace.cc:263
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
Definition: workspace.cc:303
Blob * CreateLocalBlob(const string &name)
Similar to CreateBlob(), but it creates a blob in the local workspace even if another blob with the same name is already visible through a parent (forwarded) or shared workspace.
Definition: workspace.cc:118
bool RemoveBlob(const string &name)
Remove the blob of the given name.
Definition: workspace.cc:151
vector< string > LocalBlobs() const
Return the list of blobs owned by this Workspace, not including blobs shared from parent workspaces.
Definition: workspace.cc:75
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:164
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:171
void AddBlobMapping(const Workspace *parent, const std::unordered_map< string, string > &forwarded_blobs, bool skip_defined_blobs=false)
Adds blob mappings from workspace to the blobs from parent workspace.
Definition: workspace.cc:183
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:103
Blob * RenameBlob(const string &old_name, const string &new_name)
Renames a local workspace blob.
Definition: workspace.cc:128
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
Definition: workspace.cc:269
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
Definition: workspace.cc:255
vector< string > Blobs() const
Return a list of blob names: local blobs, forwarded blobs whose parent blob still exists, and blobs of the shared workspace.
Definition: workspace.cc:84
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.
Definition: workspace.cc:218