1 #ifndef CAFFE2_CORE_WORKSPACE_H_ 2 #define CAFFE2_CORE_WORKSPACE_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/observer.h" 11 #include <unordered_set> 14 #include "caffe2/core/blob.h" 15 #include "caffe2/core/registry.h" 16 #include "caffe2/core/net.h" 17 #include "caffe2/proto/caffe2.pb.h" 18 #include "caffe2/utils/signal_handler.h" 19 #include "caffe2/utils/threadpool/ThreadPool.h" 21 CAFFE2_DECLARE_bool(caffe2_print_blob_sizes_at_exit);
29 : handler_(std::make_shared<SignalHandler>(
30 SignalHandler::Action::STOP,
31 SignalHandler::Action::STOP)) {}
35 bool operator()(
int ) {
36 return handler_->CheckForSignals() != SignalHandler::Action::STOP;
39 std::shared_ptr<SignalHandler> handler_;
// Callable type that takes an int and returns bool.
49 typedef std::function<bool(int)> ShouldContinue;
// CaffeMap from a string key to a uniquely-owned Blob.
50 typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
// CaffeMap from a string key to a uniquely-owned NetBase.
51 typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
// Initializes an empty workspace: the root folder is set to "." and no
// shared workspace is attached (shared_ is nullptr).
55 Workspace() : root_folder_(
"."), shared_(nullptr) {}
65 : root_folder_(root_folder), shared_(nullptr) {}
77 : root_folder_(
"."), shared_(shared) {}
86 const std::unordered_map<string, string>& forwarded_blobs)
87 : root_folder_(
"."), shared_(nullptr) {
88 CAFFE_ENFORCE(shared,
"Parent workspace must be specified");
89 for (
const auto& forwarded : forwarded_blobs) {
91 shared->
HasBlob(forwarded.second),
"Invalid parent workspace blob");
92 forwarded_blobs_[forwarded.first] =
93 std::make_pair(shared, forwarded.second);
101 : root_folder_(root_folder), shared_(shared) {}
104 if (FLAGS_caffe2_print_blob_sizes_at_exit) {
121 const std::unordered_map<string, string>& forwarded_blobs,
122 bool skip_defined_blobs =
false);
// CopyForwardedTensors<Context>: for each name in `blobs` that has an entry in
// forwarded_blobs_, looks up the mapped blob in the parent workspace, removes
// the forwarding entry, creates a blob under the same name via CreateBlob, and
// copies the parent's Tensor<Context> into it.
// NOTE(review): this listing omits several original lines (the function
// signature, the body of the "not forwarded" branch, part of the type-check
// enforce, and the closing braces) — confirm details against the full header.
128 template <
class Context>
130 for (
const auto& blob : blobs) {
// Guard for names that have no forwarding entry; the branch body is not
// visible in this listing (presumably it skips the name — TODO confirm).
131 if (!forwarded_blobs_.count(blob)) {
// ws_blob is a reference into forwarded_blobs_: .first is the parent
// workspace pointer, .second is the blob name inside that parent. It must not
// be used after the erase() further down.
134 const auto& ws_blob = forwarded_blobs_[blob];
135 const auto* parent_ws = ws_blob.first;
136 auto* from_blob = parent_ws->GetBlob(ws_blob.second);
// The parent workspace must actually contain the mapped blob.
137 CAFFE_ENFORCE(from_blob);
140 "Expected blob with tensor value",
// The forwarding entry is dropped before CreateBlob(blob) is called.
// NOTE(review): presumably so the name no longer resolves to the parent's
// blob when the new blob is created — confirm against CreateBlob.
142 forwarded_blobs_.erase(blob);
143 auto* to_blob = CreateBlob(blob);
144 CAFFE_ENFORCE(to_blob);
// Source tensor is read through Get; destination is obtained through
// GetMutable and then filled with CopyFrom.
145 const auto& from_tensor = from_blob->template Get<Tensor<Context>>();
146 auto* to_tensor = to_blob->template GetMutable<Tensor<Context>>();
147 to_tensor->CopyFrom(from_tensor);
155 vector<string> LocalBlobs()
const;
162 vector<string> Blobs()
const;
// Checks whether a blob with the given name is reachable from this workspace.
// Lookup order visible here:
//   1. the local blob_map_;
//   2. forwarded_blobs_, delegating to the recorded parent workspace under the
//      mapped parent name;
//   3. the shared_ workspace, if one is set.
// NOTE(review): this listing omits some original lines (the body of the first
// branch and the fall-through result after the last branch) — confirm against
// the full header.
171 inline bool HasBlob(
const string& name)
const {
// Local blobs are checked first.
174 if (blob_map_.count(name)) {
176 }
else if (forwarded_blobs_.count(name)) {
// Forwarded name: ask the parent workspace about the name it maps to.
177 const auto parent_ws = forwarded_blobs_.at(name).first;
178 const auto& parent_name = forwarded_blobs_.at(name).second;
179 return parent_ws->HasBlob(parent_name);
180 }
else if (shared_) {
// Otherwise defer to the shared workspace using the same name.
181 return shared_->HasBlob(name);
186 void PrintBlobSizes();
193 Blob* CreateBlob(
const string& name);
201 Blob* CreateLocalBlob(
const string& name);
207 bool RemoveBlob(
const string& name);
212 const Blob* GetBlob(
const string& name)
const;
217 Blob* GetBlob(
const string& name);
224 Blob* RenameBlob(
const string& old_name,
const string& new_name);
237 const std::shared_ptr<const NetDef>& net_def,
238 bool overwrite =
false);
243 NetBase* GetNet(
const string& net_name);
247 void DeleteNet(
const string& net_name);
253 bool RunNet(
const string& net_name);
259 vector<string> names;
260 for (
auto& entry : net_map_) {
261 names.push_back(entry.first);
269 bool RunPlan(
const PlanDef& plan_def,
284 bool RunOperatorOnce(
const OperatorDef& op_def);
285 bool RunNetOnce(
const NetDef& net_def);
288 std::atomic<int> last_failed_op_net_position;
293 const string root_folder_;
295 std::unordered_map<string, std::pair<const Workspace*, string>>
297 std::unique_ptr<ThreadPool> thread_pool_;
298 std::mutex thread_pool_creation_mutex_;
305 #endif // CAFFE2_CORE_WORKSPACE_H_ const string & RootFolder()
Return the root folder of the workspace.
Blob is a general container that hosts a typed pointer.
Workspace(const Workspace *shared, const std::unordered_map< string, string > &forwarded_blobs)
Initializes workspace with parent workspace, blob name remapping (new name -> parent blob name)...
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace(const Workspace *shared)
Initializes a workspace with a shared workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Workspace(const string &root_folder)
Initializes an empty workspace with the given root folder.
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
void CopyForwardedTensors(const std::unordered_set< std::string > &blobs)
Converts previously mapped tensor blobs to local blobs, copies values from parent workspace blobs into...
vector< string > Nets() const
Returns a list of names of the currently instantiated networks.
Workspace()
Initializes an empty workspace.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Workspace(const string &root_folder, Workspace *shared)
Initializes a workspace with a root folder and a shared workspace.