Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_dag.h
1 #ifndef CAFFE2_CORE_NET_DAG_H_
2 #define CAFFE2_CORE_NET_DAG_H_
3 
#include <atomic>
#include <climits>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <string>
#include <thread> // NOLINT
#include <typeinfo>
#include <unordered_map>
#include <vector>

#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/net_dag_utils.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/registry.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/simple_queue.h"
25 
26 namespace caffe2 {
27 
28 class DAGNetBase : public NetBase {
29  public:
30  DAGNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
31  ~DAGNetBase() override;
32 
33  // WorkerFunction() is a function wrapper to allow us to run worker threads.
34  // It checks out one ready-to-run operator from the job queue, runs it,
35  // notifies all its children, and for any children that is ready, enqueues
36  // it to the job queue.
37  void WorkerFunction();
38  vector<float> TEST_Benchmark(
39  const int warmup_runs,
40  const int main_runs,
41  const bool run_individual) override;
42 
43  const dag_utils::ExecutionChains& TEST_execution_chains() const {
44  return execution_chains_;
45  }
46 
47  vector<OperatorBase*> GetOperators() const override {
48  return operators_;
49  }
50 
51  protected:
52  bool DoRunAsync() override;
53 
54  virtual bool RunAt(int chain_id, const std::vector<int>& chain) = 0;
55  void HandleException(int operator_idx, const std::string& exception_str);
56 
57  vector<dag_utils::OperatorNode> operator_nodes_;
58  vector<OperatorBase*> operators_;
59  dag_utils::ExecutionChains execution_chains_;
60  vector<int> initial_frontier_;
61  std::unique_ptr<SimpleQueue<int>> job_queue_;
62  std::vector<std::thread> workers_;
63  int num_workers_;
64  int remaining_ops_;
65 
66  bool success_;
67  // Use an atomic to guard caught_exception_ so it is written to only once
68  std::atomic<bool> caught_exception_yet_;
69 #ifdef CAFFE2_USE_EXCEPTION_PTR
70  std::exception_ptr caught_exception_;
71 #endif // CAFFE2_USE_EXCEPTION_PTR
72  int iter_;
73  std::mutex remaining_ops_mutex_;
74  std::condition_variable cv_;
75  std::mutex run_in_progress_;
76 
77  struct DAGNetStats {
78  CAFFE_STAT_CTOR(DAGNetStats);
79  CAFFE_AVG_EXPORTED_STAT(task_pool_wait_time_us);
80  CAFFE_AVG_EXPORTED_STAT(task_time_to_scheduled_us);
81  CAFFE_AVG_EXPORTED_STAT(task_time_to_succeeded_ms);
82  CAFFE_AVG_EXPORTED_STAT(task_wait_time_us);
83  };
84  mutable std::vector<DAGNetStats> stats_;
85  std::unordered_map<int, std::unique_ptr<Timer>> task_timers_;
86 
87  DISABLE_COPY_AND_ASSIGN(DAGNetBase);
88 };
89 
90 class DAGNet : public DAGNetBase {
91  public:
92  using DAGNetBase::DAGNetBase;
93 
94  protected:
95  bool RunAt(int chain_id, const std::vector<int>& chain) override;
96  bool SupportsAsync() override {
97  return false;
98  }
99 };
100 
101 } // namespace caffe2
102 
103 #endif // CAFFE2_CORE_NET_DAG_H_
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
Definition: net_dag.cc:299
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...