3 #include "caffe2/core/common.h" 4 #include "caffe2/core/logging.h" 7 #include <condition_variable> 27 constexpr
size_t kGEMMLOWPCacheLineSize = 64;
32 template <
typename... Args>
33 static T* alloc(Args&&... args) {
36 #if defined(__ANDROID__) 37 p = memalign(kGEMMLOWPCacheLineSize,
sizeof(T));
38 #elif defined(_MSC_VER) 39 p = _aligned_malloc(
sizeof(T), kGEMMLOWPCacheLineSize);
41 posix_memalign((
void**)&p, kGEMMLOWPCacheLineSize,
sizeof(T));
45 return new (p) T(std::forward<Args>(args)...);
52 static void release(T* p) {
69 template <
typename... Args>
70 static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
71 return std::unique_ptr<T, AlignedDeleter<T>>(
76 const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
// Portable single-NOP building block: an intrinsic statement on MSVC, an
// inline-asm string fragment elsewhere.  The CONCAT macro stamps out a
// 64-NOP sequence by repeated 4x duplication.
#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

// Issues a burst of 64 NOP instructions and reports how many were executed,
// so the busy-wait loop can account its spinning budget.
inline int Do256NOPs() {
#if defined(_MSC_VER)
  GEMMLOWP_NOP64;
#else
  asm volatile(GEMMLOWP_NOP64);
#endif
  return 64;
}
98 #undef GEMMLOWP_STRING_CONCAT_4 99 #undef GEMMLOWP_NOP256 100 #undef GEMMLOWP_NOP64 101 #undef GEMMLOWP_NOP16 127 template <
typename T>
128 T WaitForVariableChange(std::atomic<T>* var,
130 std::condition_variable* cond,
136 T new_value = var->load(std::memory_order_relaxed);
137 if (new_value != initial_value) {
138 std::atomic_thread_fence(std::memory_order_acquire);
142 while (nops < kMaxBusyWaitNOPs) {
144 new_value = var->load(std::memory_order_relaxed);
145 if (new_value != initial_value) {
146 std::atomic_thread_fence(std::memory_order_acquire);
154 std::unique_lock<std::mutex> g(*mutex);
155 T new_value = var->load(std::memory_order_relaxed);
157 cond->wait(g, [&]() {
158 new_value = var->load(std::memory_order_relaxed);
159 return new_value != initial_value;
161 DCHECK_NE(static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
173 void Reset(std::size_t initial_count) {
174 std::lock_guard<std::mutex> g(mutex_);
175 DCHECK_EQ(count_, 0);
176 count_ = initial_count;
183 bool DecrementCount() {
184 const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
185 DCHECK_GE(count_value, 0);
186 if (count_value == 0) {
187 std::lock_guard<std::mutex> g(mutex_);
190 bool retval = count_value == 0;
197 while (
size_t count_value = count_.load(std::memory_order_relaxed)) {
198 WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
203 std::condition_variable cond_;
205 std::atomic<std::size_t> count_{0};
212 virtual void Run() = 0;
216 class alignas(kGEMMLOWPCacheLineSize)
Worker {
218 enum class State : uint8_t {
227 state_(State::ThreadStartup),
228 counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
229 thread_ = caffe2::make_unique<std::thread>([
this]() { this->ThreadFunc(); });
233 ChangeState(State::ExitAsSoonAsPossible);
240 void ChangeState(State new_state) {
241 std::lock_guard<std::mutex> g(state_mutex_);
242 DCHECK(new_state != state_.load(std::memory_order_relaxed));
243 switch (state_.load(std::memory_order_relaxed)) {
244 case State::ThreadStartup:
245 DCHECK(new_state == State::Ready);
248 DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
251 DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
256 state_.store(new_state, std::memory_order_relaxed);
257 state_cond_.notify_one();
258 if (new_state == State::Ready) {
259 counter_to_decrement_when_ready_->DecrementCount();
265 ChangeState(State::Ready);
272 State state_to_act_upon =
273 WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);
276 switch (state_to_act_upon) {
282 ChangeState(State::Ready);
284 case State::ExitAsSoonAsPossible:
292 static void* ThreadFunc(
void* arg) {
293 static_cast<Worker*
>(arg)->ThreadFunc();
299 void StartWork(
Task* task) {
302 DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
303 ChangeState(State::HasWork);
308 std::unique_ptr<std::thread> thread_;
315 std::condition_variable state_cond_;
316 std::mutex state_mutex_;
319 std::atomic<State> state_;
330 void Execute(
const std::vector<std::shared_ptr<Task>>& tasks) {
331 CAFFE_ENFORCE_GE(tasks.size(), 1);
333 int workers_count = tasks.size() - 1;
334 CreateWorkers(workers_count);
335 DCHECK_LE(workers_count, workers_.size());
336 counter_to_decrement_when_ready_.Reset(workers_count);
337 for (
auto task = 1; task < tasks.size(); ++task) {
338 workers_[task - 1]->StartWork(tasks[task].
get());
341 auto& task = tasks.front();
344 counter_to_decrement_when_ready_.Wait();
351 void CreateWorkers(std::size_t workers_count) {
352 if (workers_.size() >= workers_count) {
355 counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
356 while (workers_.size() < workers_count) {
359 counter_to_decrement_when_ready_.Wait();
363 std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.