Caffe2 - C++ API
A deep learning, cross-platform ML framework
boolean_mask_ops.h
1 #ifndef BOOLEAN_MASK_OPS_H
2 #define BOOLEAN_MASK_OPS_H
3 
#include <limits>
#include <string>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/conversions.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class BooleanMaskOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws) {}
17 
18  bool RunOnDevice() override;
19 };
20 
21 template <class Context>
22 class SequenceMaskOp final : public Operator<Context> {
23  public:
24  USE_OPERATOR_CONTEXT_FUNCTIONS;
25  explicit SequenceMaskOp(const OperatorDef& operator_def, Workspace* ws)
26  : Operator<Context>(operator_def, ws),
27  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
28  radius_(OperatorBase::GetSingleArgument<int>("radius", 10)),
29  grad_(OperatorBase::GetSingleArgument<bool>("grad", false)),
30  fill_val_(OperatorBase::GetSingleArgument<float>(
31  "fill_val",
32  -1.0f * std::numeric_limits<float>::infinity())) {
33  // Mode argument is required
34  mode_ = GetArgument(operator_def, "mode").s();
35  // batch argument is optional, but if not given, we don't want a default val
36  if (HasArgument("batch")) {
37  batch_ = GetArgument(operator_def, "batch").i();
38  }
39 
40  if (HasArgument("repeat_from_axis")) {
41  CAFFE_ENFORCE(
42  mode_ == "sequence",
43  "repeat_from_axis currently only supported in sequence mode.");
44  CAFFE_ENFORCE(
45  !HasArgument("batch"),
46  "repeat_from_axis and batch not currently supported together.");
47  repeat_from_ =
48  OperatorBase::GetSingleArgument<int>("repeat_from_axis", -1);
49  }
50  }
51 
52  bool RunOnDevice() override;
53 
54  template <typename T>
55  bool DoRunWithType();
56 
57  private:
58  int axis_;
59  int radius_;
60  std::string mode_;
61  bool grad_;
62  float fill_val_;
63  int batch_;
64  int repeat_from_;
65 };
66 
67 } // namespace caffe2
68 
69 #endif
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37