Caffe2 - C++ API
A deep learning, cross platform ML framework
Related Pages
Modules
Data Structures
Files
C++ API
Python API
GitHub
File List
Globals
caffe2
operators
swish_op.cc
1
#include "swish_op.h"
2
#include "caffe2/core/types.h"
3
#include "caffe2/operators/elementwise_op.h"
4
#include "caffe2/utils/math.h"
5
6
namespace
caffe2
{
7
struct
SwishCPUFunctor
{
8
template
<
typename
T>
9
inline
void
10
operator()(
const
int
n,
const
T* x, T* y,
CPUContext
*
/*device_context*/
) {
11
ConstEigenVectorArrayMap<T> xM(x, n);
12
EigenVectorArrayMap<T>(y, n) = xM / (1. + (-xM).exp());
13
}
14
};
15
16
template
<>
17
template
<
typename
T>
18
bool
SwishGradientOp<CPUContext>::DoRunWithType
() {
19
auto
& Xin = Input(X);
20
auto
& Yin = Input(Y);
21
auto
& DYin = Input(DY);
22
auto
* DXout = Output(DX);
23
CAFFE_ENFORCE_EQ(Xin.size(), Yin.size());
24
CAFFE_ENFORCE_EQ(DYin.size(), Yin.size());
25
DXout->ResizeLike(Yin);
26
27
const
float
* Xdata = Xin.template data<float>();
28
const
float
* Ydata = Yin.template data<float>();
29
const
float
* dYdata = DYin.template data<float>();
30
float
* dXdata = DXout->template mutable_data<float>();
31
32
EigenVectorArrayMap<float> dXvec(dXdata, DXout->size());
33
ConstEigenVectorArrayMap<float> Xvec(Xdata, Xin.size());
34
ConstEigenVectorArrayMap<float> Yvec(Ydata, Yin.size());
35
ConstEigenVectorArrayMap<float> dYvec(dYdata, DYin.size());
36
37
// dx = dy * (y + sigmoid(x)*(1-y))
38
dXvec = dYvec * (Yvec + (1. / (1. + (-Xvec).exp())) * (1. - Yvec));
39
return
true
;
40
}
41
42
REGISTER_CPU_OPERATOR(
    Swish,
    UnaryElementwiseOp<
        TensorTypes<float, double>,
        CPUContext,
        SwishCPUFunctor>);
REGISTER_CPU_OPERATOR(SwishGradient, SwishGradientOp<CPUContext>);

// Input: X, output: Y
// The op is elementwise and declares IdenticalTypeAndShape, so X may have any
// shape; the I/O descriptions say so rather than "1D".
OPERATOR_SCHEMA(Swish)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(
Swish takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the swish function, y = x / (1 + exp(-x)), is applied to the
tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor of any shape")
    .Output(0, "Y", "Output tensor with the same shape and type as X");
// Input: X, Y, dY, output: dX
OPERATOR_SCHEMA(SwishGradient)
    .NumInputs(3)
    .NumOutputs(1)
    .AllowInplace({{2, 0}})
    .SetDoc(R"DOC(
SwishGradient takes X, Y and dY and uses them to compute dX according to the
chain rule and derivatives of the swish function.
)DOC")
    .Input(0, "X", "Input tensor of the forward Swish op")
    .Input(1, "Y", "Output tensor of the forward Swish op")
    .Input(2, "dY", "Gradient of the loss with respect to Y")
    .Output(0, "dX", "Gradient of the loss with respect to X");

REGISTER_GRADIENT(Swish, GetSwishGradient);
}
// namespace caffe2
caffe2::SwishGradientOp
Definition:
swish_op.h:8
caffe2::GetSwishGradient
Definition:
swish_op.h:25
caffe2::CPUContext
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition:
context.h:66
caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition:
convert_encoded_to_raw_leveldb.cc:47
caffe2::TensorTypes
Definition:
operator.h:547
caffe2::UnaryElementwiseWithArgsOp
Definition:
elementwise_op.h:36
caffe2::SwishCPUFunctor
Definition:
swish_op.cc:7
Generated on Thu Apr 19 2018 13:03:56 for Caffe2 - C++ API by
1.8.11