Caffe2 - C++ API
A deep learning, cross-platform ML framework
Related Pages
Modules
Data Structures
Files
C++ API
Python API
GitHub
File List
Globals
caffe2
operators
math_ops.cc
1
#include "caffe2/operators/math_ops.h"
2
#include "caffe2/utils/math.h"
3
4
namespace
caffe2
{
5
6
struct
SqrCPUFunctor
{
7
template
<
typename
T>
8
inline
void
9
operator()(
const
int
n,
const
T* x, T* y,
CPUContext
* device_context) {
10
math::Sqr<T, CPUContext>(n, x, y, device_context);
11
}
12
};
13
14
// Register the CPU implementation of the Sqr operator: a unary element-wise
// op instantiated for float tensors on CPUContext, driven by SqrCPUFunctor.
REGISTER_CPU_OPERATOR(
    Sqr,
    UnaryElementwiseOp<TensorTypes<float>, CPUContext, SqrCPUFunctor>);

// Schema for Sqr: one input, one output, with the (input 0, output 0) pair
// allowed to run in place and the output declared to share the input's type
// and shape.
OPERATOR_SCHEMA(Sqr)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShape()
    .SetDoc("Square (x^2) the elements of the input")
    .Input(0, "input", "Input tensor")
    .Output(0, "output", "Squared elements of the input");
26
27
class
GetSqrGradient
:
public
GradientMakerBase
{
28
using
GradientMakerBase::GradientMakerBase;
29
vector<OperatorDef> GetGradientDefs()
override
{
30
Argument scale_arg;
31
scale_arg.set_name(
"scale"
);
32
scale_arg.set_f(2.0);
33
return
vector<OperatorDef>{CreateOperatorDef(
34
"Scale"
,
35
""
,
36
std::vector<string>{GO(0)},
37
std::vector<string>{GO(0)},
38
std::vector<Argument>{scale_arg}),
39
CreateOperatorDef(
40
"Mul"
,
41
""
,
42
std::vector<string>{GO(0), I(0)},
43
std::vector<string>{GI(0)})};
44
}
45
};
46
// Register GetSqrGradient as the gradient maker for the Sqr operator.
REGISTER_GRADIENT(Sqr, GetSqrGradient);
47
48
struct
SignCPUFunctor
{
49
template
<
typename
T>
50
inline
void
51
operator()(
const
int
n,
const
T* x, T* y,
CPUContext
* device_context) {
52
for
(
int
i = 0; i < n; ++i) {
53
y[i] = (-T(1) * (x[i] < 0)) + (x[i] > 0);
54
}
55
}
56
};
57
58
// Register the CPU implementation of the Sign operator: a unary element-wise
// op instantiated for float tensors on CPUContext, driven by SignCPUFunctor.
REGISTER_CPU_OPERATOR(
    Sign,
    UnaryElementwiseOp<TensorTypes<float>, CPUContext, SignCPUFunctor>);

// Schema for Sign: one input, one output, with the output declared to share
// the input's type and shape.
OPERATOR_SCHEMA(Sign)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc("Computes sign for each element of the input: -1, 0 or 1.")
    .IdenticalTypeAndShape();

// Sign is explicitly registered as having no gradient.
SHOULD_NOT_DO_GRADIENT(Sign);
68
69
}
// namespace caffe2
caffe2::GradientMakerBase
Definition:
operator_gradient.h:47
caffe2::CPUContext
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition:
context.h:66
caffe2::SignCPUFunctor
Definition:
math_ops.cc:48
caffe2::SqrCPUFunctor
Definition:
math_ops.cc:6
caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition:
convert_encoded_to_raw_leveldb.cc:47
caffe2::TensorTypes
Definition:
operator.h:547
caffe2::UnaryElementwiseWithArgsOp
Definition:
elementwise_op.h:36
caffe2::GetSqrGradient
Definition:
math_ops.cc:27
Generated on Thu Apr 19 2018 13:03:55 for Caffe2 - C++ API by Doxygen 1.8.11