Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_base.h
1 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
2 #define CAFFE2_CORE_NET_ASYNC_BASE_H_
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/net.h"
6 #include "caffe2/core/net_dag_utils.h"
7 #include "caffe2/core/registry.h"
8 #include "caffe2/core/stats.h"
9 #include "caffe2/core/timer.h"
10 #include "caffe2/core/workspace.h"
11 #include "caffe2/proto/caffe2.pb.h"
12 #include "caffe2/utils/proto_utils.h"
13 #include "caffe2/utils/thread_pool.h"
14 
15 namespace caffe2 {
16 
// AsyncNetBase is a NetBase subclass that reports asynchronous support
// (SupportsAsync() returns true) and declares the operator/chain graph and
// thread-pool members that subclasses can access. Apart from the two inline
// overrides below, every member function is only declared in this header;
// the definitions live elsewhere.
17 class AsyncNetBase : public NetBase {
18  public:
    // Constructs the net from a shared, immutable NetDef and a Workspace
    // pointer. Definition not visible in this header.
19  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
20  ~AsyncNetBase() override;
21 
    // Always returns true for this class.
22  bool SupportsAsync() override {
23  return true;
24  }
25 
    // Returns operators_ by value, i.e. the caller receives a copy of the
    // vector of (non-owning, raw) operator pointers.
26  vector<OperatorBase*> GetOperators() const override {
27  return operators_;
28  }
29 
30  protected:
    // Helpers available to subclasses. Only the signatures are visible here.
    // NOTE(review): the names suggest tasks/chains are addressed by integer
    // ids and executed on integer-numbered streams -- confirm against the
    // implementation file before relying on that.

    // `status` is optional and defaults to nullptr.
31  bool canSchedule(
32  int chain_id,
33  const std::vector<EventStatus>* status = nullptr);
34 
35  int tasksNum() const;
    // Declared const yet returns a non-const Event&.
36  Event& event(int task_id) const;
37  EventStatus query(int task_id) const;
38  const std::vector<int>& children(int task_id) const;
39  const std::vector<int>& parents(int task_id) const;
40  void asyncWait(
41  int task_id,
42  int stream_id,
43  const std::vector<int>& wait_task_ids) const;
44  void run(int task_id, int stream_id);
45  int stream(int task_id);
    // Returns a shared TaskThreadPool for the given device option.
    // NOTE(review): presumably selects among the pool members declared
    // below -- confirm in the implementation.
46  std::shared_ptr<TaskThreadPool> pool(const DeviceOption& device_option);
47 
48  void finishTasks(const std::unordered_set<int>& task_ids);
49  void finalizeEvents();
50 
51  bool isStreamFree(int task_id, int stream_id) const;
52 
53  // Operator/task graph
    // Raw OperatorBase pointers; ownership is not established by this header.
54  std::vector<OperatorBase*> operators_;
55  std::vector<dag_utils::OperatorNode> operator_nodes_;
    // Each chain is a vector of ints.
    // NOTE(review): presumably operator indices -- confirm.
56  std::vector<std::vector<int>> chains_;
57  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
58 
59  // Pools and streams
    // NOTE(review): presumably guards the pool members below -- confirm
    // which accesses actually take this lock.
60  std::mutex pools_mutex_;
61  std::shared_ptr<TaskThreadPool> cpu_pool_;
62  std::vector<std::shared_ptr<TaskThreadPool>> cpu_pools_;
63  std::vector<std::shared_ptr<TaskThreadPool>> gpu_pools_;
    // static thread_local: one vector per thread, shared by every
    // AsyncNetBase instance used on that thread (not per-object state).
64  static thread_local std::vector<int> stream_counters_;
65 
    // Project macro (defined outside this file); by its name it disables
    // copy construction and copy assignment for this class.
66  DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
67 
68  private:
    // Private helper taking a mutable reference to a vector of pools, an
    // index into it, and a device option.
    // NOTE(review): presumably lazily creates pools[pool_idx] -- confirm.
69  std::shared_ptr<TaskThreadPool> pool_getter(
70  std::vector<std::shared_ptr<TaskThreadPool>>& pools,
71  int pool_idx,
72  const DeviceOption& device_option);
73 };
74 
// Declares ThreadPoolRegistry via the project's shared-registry macro
// (expected to come from caffe2/core/registry.h, included above).
// Arguments: registry name, registered object type, creator argument types.
// The object type is TaskThreadPool, matching the
// std::shared_ptr<TaskThreadPool> returned by AsyncNetBase::pool(); creators
// take a const DeviceOption&.
75 CAFFE_DECLARE_SHARED_REGISTRY(
76  ThreadPoolRegistry,
77  TaskThreadPool,
78  const DeviceOption&);
79 
// Returns a shared TaskThreadPool for CPU execution. Declaration only; the
// definition is not in this header.
// NOTE(review): numa_node_id presumably selects the NUMA node the pool is
// associated with -- confirm accepted values (e.g. a "no NUMA" sentinel) in
// the implementation.
80 std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(int numa_node_id);
81 
82 } // namespace caffe2
83 
84 #endif // CAFFE2_CORE_NET_ASYNC_BASE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...