1 #include "caffe2/core/workspace.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/net.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/plan_executor.h" 11 #include "caffe2/core/tensor.h" 12 #include "caffe2/proto/caffe2.pb.h" 15 caffe2_print_blob_sizes_at_exit,
17 "If true, workspace destructor will print all blob shapes");
// Flag defined above (its defining macro and default value are on elided
// lines). Per its help string, it makes the workspace destructor print all
// blob shapes — presumably through PrintBlobSizes; confirm in the destructor.
// Logs a table of this workspace's blobs. Each row holds the blob name, its
// current shape, its capacity in bytes and its percentage of the total.
// Rows are presumably sorted by capacity, largest first (the sort call itself
// is on an elided line). Many lines of this function are elided in this
// excerpt; the comments describe only what is visible.
21 void Workspace::PrintBlobSizes() {
// (capacity bytes, blob name) pairs, filled in by the first pass below.
26 vector<std::pair<size_t, std::string>> blob_sizes;
// First pass: query each blob's tensor info to obtain its capacity.
// NOTE(review): `blobs`, `b`, `capacity` and `_device` are declared on elided
// lines. The second pass CHECKs that shape_fun != nullptr before calling it;
// confirm that the elided lines here perform the same check.
27 for (
const auto& s : blobs) {
29 TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
31 bool shares_data =
false;
34 auto shape = shape_fun(b->GetRaw(), &shares_data, &capacity, &_device);
40 blob_sizes.push_back(make_pair(capacity, s));
// Comparator ordering entries by capacity in descending order
// (b.first < a.first); presumably passed to a sort on an elided line.
46 [](
const std::pair<size_t, std::string>& a,
47 const std::pair<size_t, std::string>& b) {
48 return b.first < a.first;
52 LOG(INFO) <<
"---- Workspace blobs: ---- ";
53 LOG(INFO) <<
"name;current shape;capacity bytes;percentage";
// Second pass: look each blob up again by name and emit one
// semicolon-separated row (name;shape;capacity;percentage).
54 for (
const auto& sb : blob_sizes) {
55 Blob* b = this->
GetBlob(sb.second);
56 TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
57 CHECK(shape_fun !=
nullptr);
58 bool _shares_data =
false;
62 auto shape = shape_fun(b->GetRaw(), &_shares_data, &capacity, &_device);
64 ss << sb.second <<
";";
65 for (
const auto d : shape) {
// The percentage is guarded so that a zero total prints 0.0 instead of
// dividing by zero. `cumtotal` is computed on elided lines (presumably the
// sum of all capacities).
68 LOG(INFO) << ss.str() <<
";" << sb.first <<
";" << std::setprecision(3)
69 << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
72 LOG(INFO) <<
"Total;;" << cumtotal <<
";100%";
// Fragment (signature and return are elided). Collects the names of blobs
// owned directly by this workspace, i.e. the keys of blob_map_. No forwarded
// or shared blobs are added in the lines shown.
77 names.reserve(blob_map_.size());
78 for (
auto& entry : blob_map_) {
79 names.push_back(entry.first);
// Fragment (signature and return are elided). Builds the list of every blob
// name visible from this workspace, in this order: local blobs (blob_map_),
// then forwarded names whose target still exists in the parent workspace,
// then the names reported by the shared workspace.
86 names.reserve(blob_map_.size());
87 for (
auto& entry : blob_map_) {
88 names.push_back(entry.first);
// A forwarded name is listed only while its parent-workspace blob exists.
90 for (
const auto& forwarded : forwarded_blobs_) {
91 const auto parent_ws = forwarded.second.first;
92 const auto& parent_name = forwarded.second.second;
93 if (parent_ws->HasBlob(parent_name)) {
94 names.push_back(forwarded.first);
// NOTE(review): GetBlob guards its shared-workspace lookup with `shared_ &&`.
// Confirm that an elided line null-checks shared_ before this dereference.
98 const auto& shared_blobs = shared_->
Blobs();
99 names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
// Fragment of blob creation (the signature and the first `if` condition are
// elided). Creation is skipped, with a VLOG message, when the name already
// exists (condition elided) or is already forwarded from a parent workspace.
// Otherwise a new empty Blob is stored in blob_map_.
106 VLOG(1) <<
"Blob " << name <<
" already exists. Skipping.";
107 }
else if (forwarded_blobs_.count(name)) {
// operator[] below cannot insert: count(name) above guarantees the key exists.
109 VLOG(1) <<
"Blob " << name <<
" is already forwarded from parent workspace " 110 <<
"(blob " << forwarded_blobs_[name].second <<
"). Skipping.";
112 VLOG(1) <<
"Creating blob " << name;
113 blob_map_[name] = unique_ptr<Blob>(
new Blob());
// Fragment of local blob creation. Only blob_map_ (this workspace's own
// blobs) is consulted, so a forwarded or shared blob with the same name does
// not prevent a local blob from being created.
119 if (blob_map_.count(name)) {
120 VLOG(1) <<
"Blob " << name <<
" already exists. Skipping.";
122 VLOG(1) <<
"Creating blob " << name;
123 blob_map_[name] = unique_ptr<Blob>(
new Blob());
// Fragment of renaming a local blob. old_name must be in blob_map_, and
// new_name must not already be visible through HasBlob. Ownership of the
// Blob's unique_ptr then moves to the new key.
130 auto it = blob_map_.find(old_name);
132 it != blob_map_.end(),
135 " is not in the local blob list");
// NOTE(review): if the enforce macro concatenates its message arguments
// without separators (presumably it does), this produces
// "Blob <name>is already in the workspace". Consider
// " is already in the workspace".
140 !
HasBlob(new_name),
"Blob ", new_name,
"is already in the workspace");
// The Blob is moved out of the old entry (presumably erased on an elided
// line). raw_ptr is taken before the move into the map and is presumably the
// return value.
143 auto value = std::move(it->second);
146 auto* raw_ptr = value.get();
147 blob_map_[new_name] = std::move(value);
// Fragment of blob removal. Only local blobs (blob_map_) are looked up. If
// the blob is found it is removed (the erase itself is on an elided line);
// otherwise a VLOG message is emitted.
152 auto it = blob_map_.find(name);
153 if (it != blob_map_.end()) {
154 VLOG(1) <<
"Removing blob " << name <<
" from this workspace.";
160 VLOG(1) <<
"Blob " << name <<
" not exists. Skipping.";
// Fragment of blob lookup. Resolution order:
//   1. local blobs in blob_map_;
//   2. forwarded names, resolved by calling GetBlob on the parent workspace,
//      which applies its own lookup order;
//   3. the shared workspace, if any.
// If none match, a warning is logged. The body of the shared-workspace branch
// and the final return are on elided lines.
165 if (blob_map_.count(name)) {
166 return blob_map_.at(name).get();
167 }
else if (forwarded_blobs_.count(name)) {
168 const auto parent_ws = forwarded_blobs_.at(name).first;
169 const auto& parent_name = forwarded_blobs_.at(name).second;
170 return parent_ws->GetBlob(parent_name);
171 }
else if (shared_ && shared_->
HasBlob(name)) {
174 LOG(WARNING) <<
"Blob " << name <<
" not in the workspace.";
// Fragment of AddBlobMapping (the leading `parent` parameter is on an elided
// line). Maps each name in this workspace (forwarded.first) to a blob name in
// `parent` (forwarded.second):
// - parent must be non-null, and every target must exist in it.
// - A name that is already forwarded is checked against its existing mapping
//   (presumably equality with `parent`; the second check's head is elided).
// - With skip_defined_blobs, names already visible here are skipped (the skip
//   statement is elided). Otherwise redefining a visible name is an error.
185 const std::unordered_map<string, string>& forwarded_blobs,
186 bool skip_defined_blobs) {
187 CAFFE_ENFORCE(parent,
"Parent workspace must be specified");
188 for (
const auto& forwarded : forwarded_blobs) {
190 parent->
HasBlob(forwarded.second),
191 "Invalid parent workspace blob " + forwarded.second);
192 if (forwarded_blobs_.count(forwarded.first)) {
193 const auto& ws_blob = forwarded_blobs_[forwarded.first];
195 ws_blob.first, parent,
"Redefinition of blob " + forwarded.first);
199 "Redefinition of blob " + forwarded.first);
201 if (skip_defined_blobs &&
HasBlob(forwarded.first)) {
205 !
HasBlob(forwarded.first),
"Redefinition of blob " + forwarded.first);
// Stores (parent workspace, blob name) rather than a Blob pointer, so GetBlob
// resolves the target in the parent at lookup time.
208 forwarded_blobs_[forwarded.first] =
209 std::make_pair(parent, forwarded.second);
// Convenience overload (signature elided). Copies the NetDef into a
// heap-allocated shared_ptr and delegates to the CreateNet overload taking
// `const std::shared_ptr<const NetDef>&`.
219 std::shared_ptr<NetDef> tmp_net_def(
new NetDef(net_def));
220 return CreateNet(tmp_net_def, overwrite);
// Fragment of CreateNet taking a shared NetDef. Behavior in the visible lines:
// - The NetDef must have a name.
// - An existing net with the same name is refused unless overwrite=true (the
//   check on `overwrite` is partly elided); if allowed, it is erased first.
// - The new net is then created. If creation yields a null net, an error is
//   logged and the map entry is removed.
// Returns a non-owning pointer; the net remains owned by net_map_.
224 const std::shared_ptr<const NetDef>& net_def,
226 CAFFE_ENFORCE(net_def->has_name(),
"Net definition should have a name.");
227 if (net_map_.count(net_def->name()) > 0) {
230 "I respectfully refuse to overwrite an existing net of the same " 233 "\", unless you explicitly specify overwrite=true.");
235 VLOG(1) <<
"Deleting existing network of the same name.";
240 net_map_.erase(net_def->name());
243 VLOG(1) <<
"Initializing network " << net_def->name();
// The right-hand side of this assignment (the net construction) is on an
// elided line.
// NOTE(review): the error message below streams "...the network." directly
// before "Maybe net type...", with no separating space.
244 net_map_[net_def->name()] =
246 if (net_map_[net_def->name()].get() ==
nullptr) {
247 LOG(ERROR) <<
"Error when creating the network." 248 <<
"Maybe net type: [" << net_def->type() <<
"] does not exist";
249 net_map_.erase(net_def->name());
252 return net_map_[net_def->name()].get();
// Fragment of GetNet. Returns a non-owning pointer to the named net. The
// early return for an unknown name is on an elided line (presumably nullptr).
256 if (!net_map_.count(name)) {
259 return net_map_[name].get();
// Fragment of DeleteNet. Erases the named net from net_map_ if present;
// unknown names are silently ignored in the lines shown.
264 if (net_map_.count(name)) {
265 net_map_.erase(name);
// Fragment of RunNet. Logs an error for an unknown net name; the early
// return is on an elided line, presumably returning false. Otherwise returns
// the result of running the stored net.
270 if (!net_map_.count(name)) {
271 LOG(ERROR) <<
"Network " << name <<
" does not exist yet.";
274 return net_map_[name]->Run();
// Creates a one-off operator from op_def, bound to this workspace, and runs
// it. An error is logged if the operator cannot be created or if running it
// fails (the run call and the return statements are on elided lines). The
// operator is held by a local unique_ptr, so it is destroyed when the
// function returns.
277 bool Workspace::RunOperatorOnce(
const OperatorDef& op_def) {
278 std::unique_ptr<OperatorBase> op(CreateOperator(op_def,
this));
279 if (op.get() ==
nullptr) {
280 LOG(ERROR) <<
"Cannot create operator of type " << op_def.type();
284 LOG(ERROR) <<
"Error when running operator " << op_def.type();
// Creates a temporary net from net_def (the creation is on an elided line)
// and runs it once. A null net produces a "Could not create net" message
// naming the net (the message continues on elided lines, and presumably it
// is raised rather than logged). A failed run is logged.
289 bool Workspace::RunNetOnce(
const NetDef& net_def) {
291 if (net ==
nullptr) {
293 "Could not create net: " + net_def.name() +
" of type " +
297 LOG(ERROR) <<
"Error when running network " << net_def.name();
// Fragment of RunPlan (signature elided). Delegates plan execution to
// RunPlanOnWorkspace, passing this workspace and forwarding `plan` and
// `shouldContinue` unchanged.
304 return RunPlanOnWorkspace(
this, plan, shouldContinue);
// Fragment of GetThreadPool (signature elided). While holding
// thread_pool_creation_mutex_, assigns ThreadPool::defaultThreadPool() to
// thread_pool_. A "not yet created" guard on an elided line presumably makes
// this lazy. Returns a non-owning pointer to the pool.
308 std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
310 thread_pool_ = ThreadPool::defaultThreadPool();
312 return thread_pool_.get();
Blob is a general container that hosts a typed pointer.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
Blob * CreateLocalBlob(const string &name)
Similar to CreateBlob(), but it creates a blob in the local workspace even if a blob with the same name is already visible through a parent or shared workspace.
bool RemoveBlob(const string &name)
Removes the blob of the given name.
vector< string > LocalBlobs() const
Returns the list of blobs owned by this Workspace, not including blobs shared from the parent workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
void AddBlobMapping(const Workspace *parent, const std::unordered_map< string, string > &forwarded_blobs, bool skip_defined_blobs=false)
Adds mappings from blob names in this workspace to blobs in the parent workspace.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Blob * RenameBlob(const string &old_name, const string &new_name)
Renames a local workspace blob.
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
vector< string > Blobs() const
Returns a list of blob names, including names forwarded from a parent workspace and blobs of the shared workspace.
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.