// CreateMutexOp: CPU operator that stores a newly constructed std::mutex,
// owned by a std::unique_ptr, into output blob 0.
// NOTE(review): the leading numbers and the gaps in this listing look like
// original line numbers from a rendered source view; lines not shown here
// (the rest of RunOnDevice and the closing braces) must be checked against
// the real file.
2 #include "caffe2/core/context.h" 3 #include "caffe2/core/operator.h" 9 class CreateMutexOp final :
public Operator<CPUContext> {
11 CreateMutexOp(
const OperatorDef& operator_def, Workspace* ws)
12 : Operator<CPUContext>(operator_def, ws) {}
14 bool RunOnDevice()
override {
// Assigning a fresh unique_ptr replaces whatever the output blob held; a
// previously owned mutex, if any, is destroyed by this assignment.
15 *OperatorBase::Output<std::unique_ptr<std::mutex>>(0) =
16 std::unique_ptr<std::mutex>(
new std::mutex);
// AtomicFetchAddOp: while holding the mutex passed in input 0, writes the sum
// of two int32 scalars (*aPtr + *bPtr) into output tensor c. Output d is also
// prepared as an int32 scalar.
// NOTE(review): the leading numbers are extraction artifacts; several lines
// of this class (declarations, the d write, closing braces) are not shown.
21 class AtomicFetchAddOp final :
public Operator<CPUContext> {
23 AtomicFetchAddOp(
const OperatorDef& operator_def, Workspace* ws)
24 : Operator<CPUContext>(operator_def, ws) {}
26 bool RunOnDevice()
override {
// The unique_ptr is dereferenced below without a null check; this assumes
// input 0 was populated (e.g. by CreateMutex) — confirm.
27 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
// NOTE(review): the declarations of a, b, c, d (original lines ~28-31) are not
// shown; per the schema they are presumably Input(1), Input(2), Output(0) and
// Output(1) respectively — confirm.
// Both outputs are resized with an empty dimension list, i.e. to scalars.
32 c->Resize(std::vector<TIndex>());
33 d->Resize(std::vector<TIndex>());
34 auto* aPtr = a.data<int32_t>();
35 auto* bPtr = b.data<int32_t>();
36 auto* cPtr = c->mutable_data<int32_t>();
37 auto* dPtr = d->mutable_data<int32_t>();
// The lock is taken only after output buffers are prepared; lock_guard holds
// it until it goes out of scope, so the reads of *aPtr / *bPtr and the writes
// that follow are serialized across ops sharing this mutex.
38 std::lock_guard<std::mutex> lg(*mutex);
// NOTE(review): the write through dPtr (presumably the pre-sum value, per the
// schema's "fetched_value" output) is on a line not shown here.
40 *cPtr = *aPtr + *bPtr;
// CreateAtomicBoolOp: stores a newly allocated std::atomic<bool>, initialized
// to false and owned by a std::unique_ptr, into output blob 0.
// NOTE(review): the leading numbers are extraction artifacts; the end of
// RunOnDevice and the closing braces are not shown.
45 class CreateAtomicBoolOp final :
public Operator<CPUContext> {
// Inherits the base-class constructors.
47 using Operator::Operator;
49 bool RunOnDevice()
override {
50 *OperatorBase::Output<std::unique_ptr<std::atomic<bool>>>(0) =
51 std::unique_ptr<std::atomic<bool>>(
new std::atomic<bool>(
false));
// ConditionalSetAtomicBoolOp: reads the atomic<bool> held by input ATOMIC_BOOL
// and branches on the first element of the bool tensor at input CONDITION.
// NOTE(review): the body of the if-branch is not shown in this listing; per
// the schema it sets the flag to true — confirm against the real file. The
// left-hand side of the Input<...>(ATOMIC_BOOL) expression (original line ~61)
// is also not shown.
56 class ConditionalSetAtomicBoolOp final :
public Operator<CPUContext> {
// Inherits the base-class constructors.
58 using Operator::Operator;
60 bool RunOnDevice()
override {
62 OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(ATOMIC_BOOL);
// Only element [0] of the condition tensor is consulted.
63 if (Input(CONDITION).data<
bool>()[0]) {
// Declares the named input indices ATOMIC_BOOL and CONDITION used above;
// presumably 0 and 1 in declaration order, matching Input(0)/Input(1) in the
// schema — confirm.
70 INPUT_TAGS(ATOMIC_BOOL, CONDITION);
// CheckAtomicBoolOp: copies the current value of the atomic<bool> held by
// input 0 into the bool data of output 0.
// NOTE(review): any Resize of Output(0) (original line ~79), the end of
// RunOnDevice and the closing braces are not shown in this listing.
73 class CheckAtomicBoolOp final :
public Operator<CPUContext> {
// Inherits the base-class constructors.
75 using Operator::Operator;
77 bool RunOnDevice()
override {
78 auto& ptr = OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(0);
// load() with no argument uses std::memory_order_seq_cst.
80 *Output(0)->mutable_data<
bool>() = ptr->load();
// Registration of the five operator classes via
// REGISTER_CPU_OPERATOR(name, class) — presumably binding each operator name
// to its CPU implementation.
// NOTE(review): leading numbers are extraction artifacts.
85 REGISTER_CPU_OPERATOR(CreateMutex, CreateMutexOp);
86 REGISTER_CPU_OPERATOR(AtomicFetchAdd, AtomicFetchAddOp);
88 REGISTER_CPU_OPERATOR(CreateAtomicBool, CreateAtomicBoolOp);
89 REGISTER_CPU_OPERATOR(ConditionalSetAtomicBool, ConditionalSetAtomicBoolOp);
90 REGISTER_CPU_OPERATOR(CheckAtomicBool, CheckAtomicBoolOp);
// Schema for CreateMutex: a doc string and one documented output, "mutex_ptr".
// NOTE(review): original lines ~93-94 (presumably input/output count
// declarations) are not shown; leading numbers are extraction artifacts.
92 OPERATOR_SCHEMA(CreateMutex)
95 .SetDoc(
"Creates an unlocked mutex and returns it in a unique_ptr blob.")
96 .Output(0,
"mutex_ptr",
"Blob containing a std::unique_ptr<mutex>.");
// Schema for AtomicFetchAdd: input 0 is the mutex blob, inputs 1 and 2 are the
// int32 operands; output 0 is the updated value and output 1 the value before
// the update. Input 1 may be updated in place as output 0.
// NOTE(review): the SetDoc(R"DOC( ... )DOC") delimiters and original lines
// ~99-101 are missing from this listing; leading numbers are extraction
// artifacts. Only the input-0 description string was corrected here
// ("containing to a" -> "containing a").
98 OPERATOR_SCHEMA(AtomicFetchAdd)
102 Given a mutex and two int32 scalar tensors, performs an atomic fetch add 103 by mutating the first argument and adding it to the second input 104 argument. Returns the updated integer and the value prior to the update. 106 .Input(0, "mutex_ptr",
"Blob containing a unique_ptr<mutex>")
107 .Input(1,
"mut_value",
"Value to be mutated after the sum.")
108 .Input(2,
"increment",
"Value to add to the first operand.")
109 .Output(0,
"mut_value",
"Mutated value after sum. Usually same as input 1.")
110 .Output(1,
"fetched_value",
"Value of the first operand before sum.")
// Input 1 (mut_value) may share storage with output 0.
111 .AllowInplace({{1, 0}});
// Schema for CreateAtomicBool: a doc string and one documented output,
// "atomic_bool". The SetDoc article was corrected ("an unique_ptr" ->
// "a unique_ptr").
// NOTE(review): original lines ~114-115 (presumably input/output count
// declarations) are not shown; leading numbers are extraction artifacts.
113 OPERATOR_SCHEMA(CreateAtomicBool)
116 .SetDoc(
"Create a unique_ptr blob to hold an atomic<bool>")
117 .Output(0,
"atomic_bool",
"Blob containing a unique_ptr<atomic<bool>>");
// Schema for ConditionalSetAtomicBool: input 0 is the atomic<bool> blob and
// input 1 the bool condition.
// NOTE(review): the SetDoc(R"DOC( ... )DOC") delimiters and original lines
// ~120-122 and ~124 are missing from this listing; leading numbers are
// extraction artifacts.
119 OPERATOR_SCHEMA(ConditionalSetAtomicBool)
123 Set an atomic<bool> to true if the given condition bool variable is true 125 .Input(0, "atomic_bool",
"Blob containing a unique_ptr<atomic<bool>>")
126 .Input(1,
"condition",
"Blob containing a bool");
// Schema for CheckAtomicBool: input 0 is the atomic<bool> blob and output 0
// receives a plain bool copy of its value.
// NOTE(review): original lines ~129-130 (presumably input/output count
// declarations) are not shown; leading numbers are extraction artifacts.
128 OPERATOR_SCHEMA(CheckAtomicBool)
131 .SetDoc(
"Copy the value of an atomic<bool> to a bool")
132 .Input(0,
"atomic_bool",
"Blob containing a unique_ptr<atomic<bool>>")
133 .Output(0,
"value",
"Copy of the value for the atomic<bool>");
// Applies SHOULD_NOT_DO_GRADIENT to each of the five operators — presumably
// declaring that no gradient is defined for these synchronization ops.
// NOTE(review): leading numbers are extraction artifacts.
135 SHOULD_NOT_DO_GRADIENT(CreateMutex);
136 SHOULD_NOT_DO_GRADIENT(AtomicFetchAdd);
137 SHOULD_NOT_DO_GRADIENT(CreateAtomicBool);
138 SHOULD_NOT_DO_GRADIENT(ConditionalSetAtomicBool);
139 SHOULD_NOT_DO_GRADIENT(CheckAtomicBool);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...