1 #include "caffe2/operators/rnn/recurrent_network_executor.h" 3 #include "caffe2/core/timer.h" 14 const NetDef& step_net_def,
15 std::map<string, string>& recurrent_input_map,
16 std::string timestep_blob,
19 step_net_def, recurrent_input_map, timestep_blob);
21 rnn_args.GetSingleArgument<
int>(
"rnn_executor.num_threads", 0);
22 if (num_threads > 0) {
23 exec->setNumThreads(num_threads);
24 LOG(INFO) <<
"Set num threads: " << num_threads;
26 exec->debug_ = rnn_args.GetSingleArgument<
int>(
"rnn_executor_debug", 0);
27 return std::unique_ptr<RecurrentNetworkExecutorBase>(exec);
34 CAFFE_ENFORCE(timestep_ops_.size() >= T);
35 countdown_ = T * timestep_ops_[0].size();
36 finished_timesteps_ = 0;
38 CHECK(task_queue_.size() == 0);
40 for (
auto& rnn_op : timestep_ops_[0]) {
42 if (rnn_op.frontier) {
43 task_queue_.Push(
OpTask(0, rnn_op.order, T, 1));
55 CAFFE_ENFORCE(timestep_ops_.size() >= T);
56 countdown_ = T * timestep_ops_[0].size();
57 finished_timesteps_ = 0;
60 CHECK(task_queue_.size() == 0);
62 for (
auto& rnn_op : timestep_ops_[T - 1]) {
63 if (rnn_op.frontier) {
64 task_queue_.Push(
OpTask(T - 1, rnn_op.order, T, -1));
76 void ThreadedRecurrentNetworkExecutor::RunOp(
OpTask job,
int thread_id) {
78 ((job.forward() && job.timestep == 0) ||
79 (job.backward() && job.timestep == job.T - 1));
81 ((job.backward() && job.timestep == 0) ||
82 (job.forward() && job.timestep == job.T - 1));
83 auto& rnn_op = timestep_ops_[job.timestep][job.op_idx];
84 if (rnn_op.num_dynamic_inputs > 0 && !rnn_op.frontier) {
87 rnn_op.num_dynamic_inputs -
88 first_timestep * rnn_op.num_recurrent_inputs,
100 rnn_op.proc_inputs = 0;
107 for (
int depidx : rnn_op.dependencies) {
108 int t = job.timestep;
109 bool for_next_timestep = depidx <= rnn_op.order;
110 if (!last_timestep && for_next_timestep) {
112 }
else if (for_next_timestep) {
116 auto& dep_op = timestep_ops_[t][depidx];
117 int proc_inputs = dep_op.proc_inputs.fetch_add(1) + 1;
121 int num_req_inputs = dep_op.num_dynamic_inputs;
122 if (first_timestep && !for_next_timestep) {
123 num_req_inputs -= dep_op.num_recurrent_inputs;
126 if (proc_inputs == num_req_inputs || num_req_inputs == 0) {
127 task_queue_.Push(
OpTask(t, depidx, job.T, job.direction));
133 if (countdown_.fetch_sub(1) == 1) {
134 CAFFE_ENFORCE_EQ(0, task_queue_.size());
135 std::unique_lock<std::mutex> lk(countdown_mtx_);
144 void ThreadedRecurrentNetworkExecutor::WorkerFunction() {
146 static std::atomic<int> seq(0);
147 int id = seq.fetch_add(1);
151 if (!task_queue_.Pop(&job)) {
157 if (max_parallel_timesteps_ > 0) {
158 int t = (job.direction == 1 ? job.timestep : job.T - job.timestep + 1);
159 if (t - finished_timesteps_ >= max_parallel_timesteps_) {
161 task_queue_.Push(job);
168 if (job.op_idx == timestep_ops_template_.size() - 1) {
169 finished_timesteps_.fetch_add(1);
173 std::unique_lock<std::mutex> lk(countdown_mtx_);
174 LOG(ERROR) <<
"Crash at thread " <<
id <<
" timestep " << job.timestep
175 <<
" op:" << ProtoDebugString(step_net_def_.op(job.op_idx))
177 task_queue_.NoMoreJobs();
183 VLOG(1) <<
"Worker exiting, did run: " << num_jobs <<
" jobs";
190 void ThreadedRecurrentNetworkExecutor::_Exec() {
192 false, failed_,
"Tried to execute a previously failed RNN executor");
195 std::unique_lock<std::mutex> lk(countdown_mtx_);
196 while (workers_.size() < num_threads_) {
197 VLOG(1) <<
"Start RNN worker " << workers_.size() <<
" / " << num_threads_;
199 std::thread(&ThreadedRecurrentNetworkExecutor::WorkerFunction,
this));
204 while (!failed_ && countdown_ > 0) {
205 cv_.wait_for(lk, std::chrono::seconds(30), [&] {
209 LOG(INFO) <<
"RNN Executor still running, remaining ops: " 212 return failed_ || countdown_ == 0;
219 "RNN executor encountered failure. See prior error logs for details.");
Data structure for a scheduled task in the task queue.
std::unique_ptr< RecurrentNetworkExecutorBase > createRNNExecutor< CPUContext >(const NetDef &step_net_def, std::map< string, string > &recurrent_input_map, std::string timestep_blob, ArgumentHelper rnn_args)
Implementation of RecurrentNetworkExecutor that uses a thread pool for multithreaded execution of RNNs...
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float Seconds()
Returns the elapsed time in seconds.
bool RunBackwards(int T) override
Run backward pass with T timesteps.
A simple timer object for measuring time.
bool Run(int T) override
Run forward pass with T timesteps.