1 #include "caffe2/core/net_async_gpu_thread_pool.h" 3 #include "caffe2/core/context_gpu.h" 5 CAFFE2_DEFINE_int(caffe2_threads_per_gpu, 1,
"Number of CPU threads per GPU");
10 std::shared_ptr<TaskThreadPool> AsyncNetGPUThreadPoolCreator(
11 const DeviceOption& device_option) {
13 device_option.device_type(),
15 "Unexpected device type for CUDA thread pool");
16 return GetAsyncNetGPUThreadPool(device_option.cuda_gpu_id());
20 CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CUDA, AsyncNetGPUThreadPoolCreator);
22 std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(
int gpu_id) {
23 static std::unordered_map<int, std::weak_ptr<TaskThreadPool>> pools;
24 static std::mutex pool_mutex;
25 std::lock_guard<std::mutex> lock(pool_mutex);
27 std::shared_ptr<TaskThreadPool> shared_pool =
nullptr;
28 if (pools.count(gpu_id)) {
29 shared_pool = pools.at(gpu_id).lock();
33 std::make_shared<TaskThreadPool>(FLAGS_caffe2_threads_per_gpu);
34 pools[gpu_id] = shared_pool;
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...