// Caffe2 RecurrentNetworkExecutor header.
// NOTE(review): this text is a mangled extraction — original source line
// numbers (1, 2, 5, 8, ...) are fused into the lines and many interior
// lines are missing. Code below is reproduced byte-for-byte.
1 #ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ 5 #include <unordered_set> 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/logging.h" 10 #include "caffe2/core/operator.h" 11 #include "caffe2/core/timer.h" 12 #include "caffe2/operators/rnn/recurrent_network_executor_incl.h" 13 #include "caffe2/operators/rnn/rnn_capable_operator_observer.h" 35 const NetDef& step_net_def,
// Constructor fragment (the class head is not visible in this view):
// stores the step net definition, the map of recurrent input blobs, and
// the timestep blob name, then precomputes each operator's dependency
// outputs via op_deps(i).
36 std::map<string, string>& recurrent_input_map,
37 std::string timestep_blob)
38 : step_net_def_(step_net_def),
39 recurrent_input_map_(recurrent_input_map),
40 timestep_blob_(timestep_blob) {
41 for (
int i = 0; i < step_net_def_.op_size(); i++) {
// op_deps_[i] caches the blobs exposed by op i (outputs plus explicit
// "rnn_dependency" arguments — see op_deps() later in this file).
42 op_deps_.push_back(op_deps(i));
// NOTE(review): the following line presumably belongs to the destructor
// or a cleanup path — intervening lines 43-48 are missing from this view.
49 if (timestep_ops_.size() > 0) {
// Run the step net forward for T timesteps. Pure virtual; implemented by
// a concrete executor (e.g. the threaded executor declared below).
55 virtual bool Run(
int T) = 0;
// Run the backward (gradient) step net for T timesteps. Pure virtual.
57 virtual bool RunBackwards(
int T) = 0;
// EnsureTimestepInitialized(t, ws, observers_list) fragment: lazily builds
// the per-timestep operator list, creates a per-timestep copy of the
// timestep blob, and instantiates (or reuses) operators for timestep t.
// NOTE(review): many interior lines are missing from this extraction;
// code is reproduced byte-for-byte.
71 if (timestep_ops_template_.size() == 0) {
// First call: derive the intra-timestep dependency graph once.
73 CalculateInternalDependencies();
// Flag ops that consume the timestep blob so their input can later be
// redirected to the per-timestep copy created below.
77 for (
auto& rnn_op : timestep_ops_template_) {
78 rnn_op.has_timestep_blob =
false;
79 const OperatorDef& op = step_net_def_.op(rnn_op.order);
80 for (
int i = 0; i < op.input_size(); i++) {
81 if (op.input(i) == timestep_blob_) {
82 rnn_op.has_timestep_blob =
true;
// Presumably the tail of a CAFFE_ENFORCE: no op may write the timestep
// blob (its first line, original line 86, is missing from this view).
87 !HasOutput(op, timestep_blob_),
88 "Timestep cannot be output of an op: ",
90 " op=" + ProtoDebugString(op));
// Grow the per-timestep structures on demand up to timestep t.
95 if (timestep_ops_.size() <= t ||
96 (timestep_ops_.size() > t && timestep_ops_[t].size() == 0)) {
99 for (
int j = timestep_ops_.size(); j < t + 1; j++) {
100 timestep_ops_.push_back(std::vector<RNNNetOperator>());
101 timestep_ops_.back().reserve(timestep_ops_template_.size());
105 if (workspaces_.size() < t + 1) {
106 workspaces_.resize(t + 1);
// Each timestep gets its own named copy of the timestep blob holding the
// value t, so multiple timesteps can run without racing on one counter.
113 std::string this_timestep_blob =
114 timestep_blob_ +
"_rnnexec_t" + caffe2::to_string(t);
116 auto b = ws->
GetBlob(this_timestep_blob);
118 b->GetMutable<
TensorCPU>()->mutable_data<int32_t>()[0] = t;
// Instantiate operators for this timestep from the template list.
121 for (
auto& template_rnn_op : timestep_ops_template_) {
122 auto& rnn_op = template_rnn_op;
// Ops that read the timestep blob get a rewritten copy of their def
// pointing at the per-timestep blob created above.
128 if (rnn_op.has_timestep_blob) {
129 OperatorDef op_copy = step_net_def_.op(rnn_op.order);
131 for (
int i = 0; i < op_copy.input_size(); i++) {
132 if (op_copy.input(i) == timestep_blob_) {
133 op_copy.set_input(i, this_timestep_blob);
137 rnn_op.op = CreateOperator(op_copy, ws);
// Attach RNN-aware copies of each observer to the new operator; only
// observers implementing rnnCopy() can be attached (enforced below).
138 for (
const auto& observer : observers_list) {
140 dynamic_cast_if_rtti<const RNNCapableOperatorObserver*>(
143 std::unique_ptr<ObserverBase<OperatorBase>> rnn_observer_copy =
144 rnn_observer->rnnCopy(rnn_op.op.get(), rnn_op.order);
147 "Observers without rnnCopy() implemented cannot be attached " 148 "to RNN using RNNExecutor.");
149 rnn_op.op->AttachObserver(std::move(rnn_observer_copy));
// When the same workspace cycles back around (t is more than
// max_parallel_timesteps_ ahead of an earlier timestep using ws), the
// operator created for that earlier timestep is reused instead.
155 if (t > max_parallel_timesteps_ && max_parallel_timesteps_ > 0 &&
156 workspaces_[t - max_parallel_timesteps_] == ws) {
158 timestep_ops_[t - max_parallel_timesteps_][rnn_op.order].op;
// Otherwise create a fresh operator from the unmodified op def and
// attach observer copies, mirroring the branch above.
162 rnn_op.op = CreateOperator(step_net_def_.op(rnn_op.order), ws);
163 for (
const auto& observer : observers_list) {
165 dynamic_cast_if_rtti<const RNNCapableOperatorObserver*>(
168 std::unique_ptr<ObserverBase<OperatorBase>> rnn_observer_copy =
169 rnn_observer->rnnCopy(rnn_op.op.get(), rnn_op.order);
172 "Observers without rnnCopy() implemented cannot be attached " 173 "to RNN using RNNExecutor.");
174 rnn_op.op->AttachObserver(std::move(rnn_observer_copy));
// The executor schedules ops itself, so per-op events are disabled.
179 rnn_op.op->DisableEvent();
181 timestep_ops_[t].emplace_back(rnn_op);
// SetMaxParallelTimesteps(p) body fragment: limits how many timesteps may
// run concurrently (default member value is -1, presumably "unlimited").
192 max_parallel_timesteps_ = p;
// Total number of observers attached across all instantiated timestep
// operators. NOTE(review): the accumulator declaration ("num") and the
// return statement are missing from this extraction.
195 size_t NumObserversStepNet() {
197 for (
auto& ops_at_timestep_t : timestep_ops_) {
198 for (
auto& rnn_op : ops_at_timestep_t) {
199 num += rnn_op.op->NumObservers();
// Returns whether blob x is an input (or control input) of op #opidx in
// the step net. NOTE(review): the comparison/return statements inside
// both loops are missing from this extraction.
208 bool has_input(std::string x,
int opidx) {
209 for (
auto& inp : step_net_def_.op(opidx).input()) {
// Control inputs are also scanned (loop body not visible here).
214 for (
auto& inp : step_net_def_.op(opidx).control_input()) {
// Collects the blobs that other ops may depend on from op i: its outputs
// plus any blobs named by "rnn_dependency*" arguments on the op def.
// NOTE(review): the output-push and return lines are missing here.
// NOTE(review): `string o` iterates the output list by value (copies each
// string) — the reference form would avoid the copies; confirm against
// the full source before changing.
224 std::vector<string> op_deps(
int i) {
225 std::vector<string> outs;
226 auto& opdef = step_net_def_.op(i);
227 for (
string o : opdef.output()) {
230 for (
auto& arg : opdef.arg()) {
// Arguments whose name starts with "rnn_dependency" carry an extra
// dependency blob name in their string payload.
231 if (arg.name().find(
"rnn_dependency") == 0) {
232 outs.push_back(arg.s());
// Finds ops that consume any of the given output blobs, scanning from
// start_i (first parameter; its line is missing from this view) and
// wrapping around the op list; found op indices are added to *dep_ops.
// NOTE(review): fragment — several interior lines are missing.
242 void infer_dependencies(
244 std::unordered_set<string> outputs,
245 std::vector<RNNNetOperator>& rnn_ops,
246 std::unordered_set<int>* dep_ops) {
247 std::unordered_set<int> already_accounted_deps;
248 int num_ops = step_net_def_.op_size();
249 bool ignore_links = this->ignoreLinkDependencies();
// Scan at most num_ops - 1 ops (all but the starting op), stopping early
// once every tracked output has been claimed.
250 for (
int j = 0; j < num_ops - 1 && !outputs.empty(); j++) {
251 int i = (start_i + j) % num_ops;
// Link ops may be skipped entirely when the executor ignores them.
252 if (ignore_links && rnn_ops[i].link_op) {
255 for (
auto& outp : outputs) {
256 if (has_input(outp, i)) {
257 if (already_accounted_deps.find(i) == already_accounted_deps.end()) {
// Mark op i's own dependents as already covered so redundant edges
// are not added later in the scan.
263 for (
int odep : rnn_ops[i].dependencies) {
264 already_accounted_deps.insert(odep);
// Outputs re-produced by op i are presumably removed from the tracked
// set here (the erase line is missing from this view).
266 for (
string& dep_out : op_deps_[i]) {
267 auto oit = outputs.find(dep_out);
268 if (oit != outputs.end()) {
// Adds dependency edges between ops that read or write the same blobs
// (race conflicts), so they are not scheduled concurrently within a
// timestep. The first parameter (opidx, used below) is missing from this
// extraction, as are the lines that insert into *dep_ops.
287 void add_race_conflict_dependencies(
289 std::vector<RNNNetOperator>& rnn_ops,
290 std::unordered_set<int>* dep_ops) {
291 for (
int i = 0; i < rnn_ops.size(); i++) {
// Link ops may be skipped when the executor ignores link dependencies.
295 if (rnn_ops[i].link_op && this->ignoreLinkDependencies()) {
// Conflict if op #opidx reads a blob that op i exposes...
298 for (
auto& dep_blob : op_deps_[i]) {
299 for (
auto& inp : step_net_def_.op(opidx).input()) {
300 if (inp == dep_blob) {
// ...or writes that same blob.
306 for (
auto& outp : step_net_def_.op(opidx).output()) {
307 if (outp == dep_blob) {
// Builds the intra-timestep dependency graph from the step net: for each
// op it finds dependents of its outputs (including recurrent aliases and
// race conflicts), then derives per-op input counts, frontier flags, and
// parent links. NOTE(review): fragment — many interior lines missing.
322 void CalculateInternalDependencies() {
// One RNNNetOperator per step-net op, remembering its original order.
323 for (
int i = 0; i < step_net_def_.op_size(); i++) {
324 timestep_ops_template_.push_back(
RNNNetOperator(step_net_def_.op(i), i));
// Collect, per op, the set of blobs it exposes to dependents.
328 for (
auto& rnn_op : timestep_ops_template_) {
329 std::unordered_set<string> dep_outputs;
330 for (
auto& outp : op_deps_[rnn_op.order]) {
331 dep_outputs.insert(outp);
// Map recurrent outputs to their recurrent-input aliases as well.
// NOTE(review): this loop inserts into dep_outputs while range-iterating
// it — if an insert triggers a rehash of the unordered_set this is
// undefined behavior; verify against the full source.
335 for (
auto& outp : dep_outputs) {
336 auto rit = recurrent_input_map_.find(outp);
337 if (rit != recurrent_input_map_.end()) {
338 dep_outputs.insert(rit->second);
340 dep_outputs.insert(outp);
// Find ops consuming these outputs (link ops may be excluded). The
// infer_dependencies(...) call site is only partially visible here.
345 if (!rnn_op.link_op || !this->ignoreLinkDependencies()) {
346 std::unordered_set<int> dependent_ops;
350 timestep_ops_template_,
// Ops touching the same blobs must also be ordered to avoid races.
355 if (!this->ignoreLinkDependencies()) {
356 add_race_conflict_dependencies(
357 rnn_op.order, timestep_ops_template_, &dependent_ops);
360 for (
int i : dependent_ops) {
361 rnn_op.dependencies.push_back(i);
// Sort dependencies relative to rnn_op.order (same-timestep indices vs
// wrapped-around ones). NOTE(review): the comparator's return statements
// are missing, so the exact ordering cannot be confirmed from this view.
368 rnn_op.dependencies.begin(),
369 rnn_op.dependencies.end(),
370 [&](
const int& a,
const int& b) {
371 if (a < rnn_op.order && b < rnn_op.order) {
374 if (a >= rnn_op.order && b >= rnn_op.order) {
377 if (a >= rnn_op.order && b < rnn_op.order) {
// Tally how many dependencies feed each op; ops reached through a
// dependency edge lose frontier status, and wrap-around edges count
// toward the recurrent-input tally (exact polarity per the fragment).
386 for (
auto& rnn_op : timestep_ops_template_) {
387 for (
int i : rnn_op.dependencies) {
388 timestep_ops_template_[i].num_dynamic_inputs++;
390 if (i > rnn_op.order) {
391 timestep_ops_template_[i].frontier =
false;
393 timestep_ops_template_[i].num_recurrent_inputs++;
// Ops with no dynamic or recurrent inputs are anchored as dependencies
// of the last op so every op is reachable in the schedule.
401 for (
auto& rnn_op : timestep_ops_template_) {
402 if (rnn_op.num_dynamic_inputs == 0 && rnn_op.num_recurrent_inputs == 0) {
403 if (rnn_op.link_op && this->ignoreLinkDependencies()) {
406 timestep_ops_template_.back().dependencies.push_back(rnn_op.order);
// Invert the dependency edges to build parent links.
411 for (
auto& rnn_op : timestep_ops_template_) {
412 for (
int dep : rnn_op.dependencies) {
413 timestep_ops_template_[dep].parents.push_back(rnn_op.order);
// PrintInfo(t) body fragment: debug helper that logs, for timestep t,
// each operator with its dependency counts, inputs/outputs, dependency
// and parent edges, then the recurrent input map and full operator defs.
425 auto& rnn_ops = timestep_ops_[t];
427 LOG(INFO) <<
"Timestep: " << t;
428 for (
auto& rnn_op : rnn_ops) {
429 auto& op = rnn_op.op;
430 LOG(INFO) <<
"Operator " << rnn_op.order <<
": " << op->type()
431 <<
" dep inputs:" << rnn_op.num_dynamic_inputs
432 <<
" rec inputs:" << rnn_op.num_recurrent_inputs
433 <<
" frontier: " << rnn_op.frontier;
434 for (
auto& inp : rnn_op.op->debug_def().input()) {
435 LOG(INFO) <<
" ---- input: " << inp;
437 for (
auto& outp : rnn_op.op->debug_def().output()) {
438 LOG(INFO) <<
" ---- output: " << outp;
440 for (
auto j : rnn_op.dependencies) {
441 LOG(INFO) <<
" dep: " << j <<
": " << rnn_ops[j].op->type();
443 for (
auto j : rnn_op.parents) {
444 LOG(INFO) <<
" parent: " << j <<
": " << rnn_ops[j].op->type();
// Caffe2's logging provides a stream operator for std::map.
448 LOG(INFO) <<
"recurrent_inputs:" << recurrent_input_map_;
450 for (
auto& rnn_op : rnn_ops) {
451 LOG(INFO) <<
"Operator " << rnn_op.order;
452 LOG(INFO) << ProtoDebugString(rnn_op.op->debug_def());
// Optional hook for subclasses to analyze ops after setup; no-op here.
456 virtual void AnalyzeOps() {}
// Whether link (alias) ops can be ignored when building the dependency
// graph. Decided by the concrete executor.
458 virtual bool ignoreLinkDependencies() = 0;
// Per-timestep instantiated operators: timestep_ops_[t][k].
460 std::vector<std::vector<RNNNetOperator>> timestep_ops_;
461 std::vector<OperatorBase*> op_ptrs_;
// Timestep-independent template list carrying the dependency structure.
463 std::vector<RNNNetOperator> timestep_ops_template_;
465 NetDef step_net_def_;
// op_deps_[i]: blobs that other ops may depend on from op i.
466 std::vector<std::vector<string>> op_deps_;
// Workspace used by each timestep, indexed by t.
467 std::vector<Workspace*> workspaces_;
468 std::map<string, string> recurrent_input_map_;
469 std::string timestep_blob_;
// -1 presumably means "no limit" (see SetMaxParallelTimesteps).
471 int max_parallel_timesteps_ = -1;
// Factory: creates the RNN executor appropriate for the given Context
// (CPU/GPU). Declared here; defined per-context elsewhere. Trailing
// parameter(s) of the factory are missing from this extraction.
477 template <
class Context>
478 std::unique_ptr<RecurrentNetworkExecutorBase> createRNNExecutor(
479 const NetDef& step_net_def,
480 std::map<string, string>& recurrent_input_map,
481 std::string timestep_blob,
// NOTE(review): the class head of the threaded executor (original lines
// 482-486) is missing; the lines below are its constructor's parameter
// list, mirroring the base-class constructor.
487 const NetDef& step_net_def,
488 std::map<string, string>& recurrent_input_map,
489 std::string timestep_blob)
// Destructor fragment: signal the task queue that no more jobs will
// arrive, then join every worker thread before destruction completes.
494 task_queue_.NoMoreJobs();
495 VLOG(1) <<
"Joining workers.";
496 for (
auto& worker : workers_) {
// Forward/backward execution over T timesteps (defined out of line).
501 bool Run(
int T)
override;
503 bool RunBackwards(
int T)
override;
// Override of the base-class hook; the returned value is not visible in
// this fragment.
505 bool ignoreLinkDependencies()
override {
// Configure the number of worker threads (body not fully visible).
509 void setNumThreads(
int n) {
// Internal helpers: schedule the ops in [from, to), the worker thread
// loop, and execution of a single scheduled task on a given thread.
514 void _ExecRange(
int from,
int to);
518 void WorkerFunction();
520 void RunOp(
OpTask job,
int thread_id);
// Remaining-task counter, failure flag, and completed-timestep count,
// shared across worker threads via atomics.
523 std::atomic<int> countdown_;
524 std::atomic<bool> failed_;
525 std::atomic<int> finished_timesteps_;
// Mutex + condvar guarding/notifying the countdown above.
527 std::mutex countdown_mtx_;
528 std::condition_variable cv_;
529 std::vector<std::thread> workers_;
530 int num_threads_ = 4;
535 #endif // CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ Blob * CreateBlob(const string &name)
Creates a blob of the given name.
RecurrentNetworkExecutor is a specialized runtime for recurrent neural networks (RNNs).
Struct for an operator in a timestep and its dependencies.
Data structure for a scheduled task in the task queue.
void EnsureTimestepInitialized(int t, Workspace *ws, const std::vector< std::unique_ptr< ObserverBase< OperatorBase >>> &observers_list)
Callers must call EnsureTimestepInitialized before starting execution for each of the relevant timesteps.
void PrintInfo(int t)
For debug purposes, print the dependency structure.
A helper class to index into arguments.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
void SetMaxParallelTimesteps(int p)
Set limit for the number of timesteps that run in parallel.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.