Caffe2 - C++ API
A deep-learning, cross-platform ML framework
swish_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
7 template <class Context>
8 class SwishGradientOp final : public Operator<Context> {
9  public:
10  USE_SIMPLE_CTOR_DTOR(SwishGradientOp)
11  USE_OPERATOR_CONTEXT_FUNCTIONS;
12 
13  template <typename T>
14  bool DoRunWithType();
15 
16  bool RunOnDevice() override {
17  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(X));
18  }
19 
20  protected:
21  INPUT_TAGS(X, Y, DY);
22  OUTPUT_TAGS(DX);
23 };
24 
26  using GradientMakerBase::GradientMakerBase;
27  vector<OperatorDef> GetGradientDefs() override {
28  return SingleGradientDef(
29  "SwishGradient",
30  "",
31  vector<string>{I(0), O(0), GO(0)},
32  vector<string>{GI(0)});
33  }
34 };
35 
36 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...