Caffe2 - C++ API
A deep learning, cross platform ML framework
workspace.h
1 #ifndef CAFFE2_CORE_WORKSPACE_H_
2 #define CAFFE2_CORE_WORKSPACE_H_
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/observer.h"
6 
7 #include <climits>
8 #include <cstddef>
9 #include <mutex>
10 #include <typeinfo>
11 #include <unordered_set>
12 #include <vector>
13 
14 #include "caffe2/core/blob.h"
15 #include "caffe2/core/registry.h"
16 #include "caffe2/core/net.h"
17 #include "caffe2/proto/caffe2.pb.h"
18 #include "caffe2/utils/signal_handler.h"
19 #include "caffe2/utils/threadpool/ThreadPool.h"
20 
21 CAFFE2_DECLARE_bool(caffe2_print_blob_sizes_at_exit);
22 
23 namespace caffe2 {
24 
25 class NetBase;
26 
27 struct StopOnSignal {
28  StopOnSignal()
29  : handler_(std::make_shared<SignalHandler>(
30  SignalHandler::Action::STOP,
31  SignalHandler::Action::STOP)) {}
32 
33  StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {}
34 
35  bool operator()(int /*iter*/) {
36  return handler_->CheckForSignals() != SignalHandler::Action::STOP;
37  }
38 
39  std::shared_ptr<SignalHandler> handler_;
40 };
41 
47 class Workspace {
48  public:
49  typedef std::function<bool(int)> ShouldContinue;
50  typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
51  typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
55  Workspace() : root_folder_("."), shared_(nullptr) {}
56 
64  explicit Workspace(const string& root_folder)
65  : root_folder_(root_folder), shared_(nullptr) {}
66 
76  explicit Workspace(const Workspace* shared)
77  : root_folder_("."), shared_(shared) {}
78 
85  const Workspace* shared,
86  const std::unordered_map<string, string>& forwarded_blobs)
87  : root_folder_("."), shared_(nullptr) {
88  CAFFE_ENFORCE(shared, "Parent workspace must be specified");
89  for (const auto& forwarded : forwarded_blobs) {
90  CAFFE_ENFORCE(
91  shared->HasBlob(forwarded.second), "Invalid parent workspace blob");
92  forwarded_blobs_[forwarded.first] =
93  std::make_pair(shared, forwarded.second);
94  }
95  }
96 
100  Workspace(const string& root_folder, Workspace* shared)
101  : root_folder_(root_folder), shared_(shared) {}
102 
103  ~Workspace() {
104  if (FLAGS_caffe2_print_blob_sizes_at_exit) {
105  PrintBlobSizes();
106  }
107  }
108 
119  void AddBlobMapping(
120  const Workspace* parent,
121  const std::unordered_map<string, string>& forwarded_blobs,
122  bool skip_defined_blobs = false);
123 
128  template <class Context>
129  void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
130  for (const auto& blob : blobs) {
131  if (!forwarded_blobs_.count(blob)) {
132  continue;
133  }
134  const auto& ws_blob = forwarded_blobs_[blob];
135  const auto* parent_ws = ws_blob.first;
136  auto* from_blob = parent_ws->GetBlob(ws_blob.second);
137  CAFFE_ENFORCE(from_blob);
138  CAFFE_ENFORCE(
139  from_blob->template IsType<Tensor<Context>>(),
140  "Expected blob with tensor value",
141  ws_blob.second);
142  forwarded_blobs_.erase(blob);
143  auto* to_blob = CreateBlob(blob);
144  CAFFE_ENFORCE(to_blob);
145  const auto& from_tensor = from_blob->template Get<Tensor<Context>>();
146  auto* to_tensor = to_blob->template GetMutable<Tensor<Context>>();
147  to_tensor->CopyFrom(from_tensor);
148  }
149  }
150 
155  vector<string> LocalBlobs() const;
156 
162  vector<string> Blobs() const;
163 
167  const string& RootFolder() { return root_folder_; }
171  inline bool HasBlob(const string& name) const {
172  // First, check the local workspace,
173  // Then, check the forwarding map, then the parent workspace
174  if (blob_map_.count(name)) {
175  return true;
176  } else if (forwarded_blobs_.count(name)) {
177  const auto parent_ws = forwarded_blobs_.at(name).first;
178  const auto& parent_name = forwarded_blobs_.at(name).second;
179  return parent_ws->HasBlob(parent_name);
180  } else if (shared_) {
181  return shared_->HasBlob(name);
182  }
183  return false;
184  }
185 
186  void PrintBlobSizes();
187 
193  Blob* CreateBlob(const string& name);
201  Blob* CreateLocalBlob(const string& name);
207  bool RemoveBlob(const string& name);
212  const Blob* GetBlob(const string& name) const;
217  Blob* GetBlob(const string& name);
218 
224  Blob* RenameBlob(const string& old_name, const string& new_name);
225 
235  NetBase* CreateNet(const NetDef& net_def, bool overwrite = false);
237  const std::shared_ptr<const NetDef>& net_def,
238  bool overwrite = false);
243  NetBase* GetNet(const string& net_name);
247  void DeleteNet(const string& net_name);
253  bool RunNet(const string& net_name);
254 
258  vector<string> Nets() const {
259  vector<string> names;
260  for (auto& entry : net_map_) {
261  names.push_back(entry.first);
262  }
263  return names;
264  }
265 
269  bool RunPlan(const PlanDef& plan_def,
270  ShouldContinue should_continue = StopOnSignal{});
271 
272  /*
273  * Returns a CPU threadpool instace for parallel execution of
274  * work. The threadpool is created lazily; if no operators use it,
275  * then no threadpool will be created.
276  */
277  ThreadPool* GetThreadPool();
278 
279  // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference
280  // between RunNet and RunNetOnce lies in the fact that RunNet allows you to
281  // have a persistent net object, while RunNetOnce creates a net and discards
282  // it on the fly - this may make things like database read and random number
283  // generators repeat the same thing over multiple calls.
284  bool RunOperatorOnce(const OperatorDef& op_def);
285  bool RunNetOnce(const NetDef& net_def);
286 
287  public:
288  std::atomic<int> last_failed_op_net_position;
289 
290  private:
291  BlobMap blob_map_;
292  NetMap net_map_;
293  const string root_folder_;
294  const Workspace* shared_;
295  std::unordered_map<string, std::pair<const Workspace*, string>>
296  forwarded_blobs_;
297  std::unique_ptr<ThreadPool> thread_pool_;
298  std::mutex thread_pool_creation_mutex_;
299 
300  DISABLE_COPY_AND_ASSIGN(Workspace);
301 };
302 
303 } // namespace caffe2
304 
305 #endif // CAFFE2_CORE_WORKSPACE_H_
const string & RootFolder()
Return the root folder of the workspace.
Definition: workspace.h:167
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
Workspace(const Workspace *shared, const std::unordered_map< string, string > &forwarded_blobs)
Initializes workspace with parent workspace, blob name remapping (new name -> parent blob name)...
Definition: workspace.h:84
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace(const Workspace *shared)
Initializes a workspace with a shared workspace.
Definition: workspace.h:76
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Workspace(const string &root_folder)
Initializes an empty workspace with the given root folder.
Definition: workspace.h:64
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:171
void CopyForwardedTensors(const std::unordered_set< std::string > &blobs)
Converts previously mapped tensor blobs to local blobs, copies values from parent workspace blobs into...
Definition: workspace.h:129
vector< string > Nets() const
Returns a list of names of the currently instantiated networks.
Definition: workspace.h:258
Workspace()
Initializes an empty workspace.
Definition: workspace.h:55
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:103
Workspace(const string &root_folder, Workspace *shared)
Initializes a workspace with a root folder and a shared workspace.
Definition: workspace.h:100