Caffe2 - C++ API
A deep learning, cross-platform ML framework
leaky_relu_op.h
1 #pragma once
2 
3 #include "caffe2/core/context.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class LeakyReluOp : public Operator<Context> {
11  public:
12  LeakyReluOp(const OperatorDef& operator_def, Workspace* ws)
13  : Operator<Context>(operator_def, ws), alpha_(0.01) {
14  if (HasArgument("alpha")) {
15  alpha_ =
16  static_cast<T>(OperatorBase::GetSingleArgument<float>("alpha", 0.01));
17  }
18  }
19 
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21 
22  bool RunOnDevice() override;
23 
24  protected:
25  T alpha_;
26 };
27 
28 template <typename T, class Context>
29 class LeakyReluGradientOp final : public Operator<Context> {
30  public:
31  LeakyReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws), alpha_(0.01) {
33  if (HasArgument("alpha")) {
34  alpha_ =
35  static_cast<T>(OperatorBase::GetSingleArgument<float>("alpha", 0.01));
36  }
37  }
38 
39  USE_OPERATOR_CONTEXT_FUNCTIONS;
40 
41  bool RunOnDevice() override;
42 
43  protected:
44  T alpha_;
45 };
46 
47 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37