1 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_ 2 #define CAFFE2_CORE_NET_ASYNC_BASE_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/net.h" 6 #include "caffe2/core/net_dag_utils.h" 7 #include "caffe2/core/registry.h" 8 #include "caffe2/core/stats.h" 9 #include "caffe2/core/timer.h" 10 #include "caffe2/core/workspace.h" 11 #include "caffe2/proto/caffe2.pb.h" 12 #include "caffe2/utils/proto_utils.h" 13 #include "caffe2/utils/thread_pool.h" 22 bool SupportsAsync()
override {
26 vector<OperatorBase*> GetOperators()
const override {
33 const std::vector<EventStatus>* status =
nullptr);
// Returns the Event associated with task `task_id`.
// NOTE(review): declared const yet returns a non-const Event&, so callers
// can mutate the event through a const object — confirm this is intended.
36 Event& event(
int task_id)
const;
// Returns the EventStatus of task `task_id`; presumably the status of
// event(task_id) — confirm against the implementation.
37 EventStatus query(
int task_id)
const;
// Returns the ids of the child (downstream) tasks of `task_id`.
// Returned by const reference; presumably refers to storage in chain_nodes_
// (confirm), in which case it must not be held past this object's lifetime.
38 const std::vector<int>& children(
int task_id)
const;
// Returns the ids of the parent (upstream) tasks of `task_id`.
// Returned by const reference; presumably refers to storage in chain_nodes_
// (confirm), in which case it must not be held past this object's lifetime.
39 const std::vector<int>& parents(
int task_id)
const;
43 const std::vector<int>& wait_task_ids)
const;
// Runs task `task_id` on stream `stream_id` (semantics inferred from the
// parameter names — confirm in the implementation).
44 void run(
int task_id,
int stream_id);
// Returns the stream id to use for task `task_id`. Non-const, so it may
// update state; presumably it advances stream_counters_ — confirm.
45 int stream(
int task_id);
// Returns the thread pool on which work for `device_option` should run.
// Presumably it selects among cpu_pool_, cpu_pools_ and gpu_pools_ while
// holding pools_mutex_ — confirm in the implementation.
46 std::shared_ptr<TaskThreadPool> pool(
const DeviceOption& device_option);
// Takes a set of task ids; presumably marks those tasks (or their events)
// as finished — confirm in the implementation.
48 void finishTasks(
const std::unordered_set<int>& task_ids);
// Takes no arguments and returns nothing; presumably finalizes the per-task
// events at the end of a run — confirm in the implementation.
49 void finalizeEvents();
// Returns whether stream `stream_id` is free for task `task_id` to use.
// What "free" means is not visible here — confirm in the implementation.
51 bool isStreamFree(
int task_id,
int stream_id)
const;
// Operators of the net, held as raw OperatorBase* pointers.
// NOTE(review): ownership is not visible here — presumably owned elsewhere;
// confirm before relying on their lifetime.
54 std::vector<OperatorBase*> operators_;
// Presumably one dag_utils::OperatorNode per entry of operators_ — confirm.
55 std::vector<dag_utils::OperatorNode> operator_nodes_;
// Each inner vector is one chain; presumably a list of operator indices.
56 std::vector<std::vector<int>> chains_;
// Graph node per chain (presumably indexed like chains_ — confirm).
57 std::vector<dag_utils::OpGraphNode> chain_nodes_;
// Presumably guards the thread-pool members below — confirm in the .cc.
60 std::mutex pools_mutex_;
// A single CPU thread pool held by shared_ptr.
61 std::shared_ptr<TaskThreadPool> cpu_pool_;
// CPU thread pools; presumably indexed by NUMA node id — confirm.
62 std::vector<std::shared_ptr<TaskThreadPool>> cpu_pools_;
// GPU thread pools; presumably indexed by GPU id — confirm.
63 std::vector<std::shared_ptr<TaskThreadPool>> gpu_pools_;
// static thread_local: each thread has its own copy, shared by every
// instance of this class used on that thread.
64 static thread_local std::vector<int> stream_counters_;
69 std::shared_ptr<TaskThreadPool> pool_getter(
70 std::vector<std::shared_ptr<TaskThreadPool>>& pools,
72 const DeviceOption& device_option);
75 CAFFE_DECLARE_SHARED_REGISTRY(
// Returns a CPU thread pool for NUMA node `numa_node_id`. It is returned
// by shared_ptr; presumably the pool is shared across async nets — confirm
// in the implementation.
80 std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(
int numa_node_id);
84 #endif // CAFFE2_CORE_NET_ASYNC_BASE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...