Caffe2 - C++ API
A deep-learning, cross-platform ML framework
half_float_ops.h
1 #ifndef CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
2 #define CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class FloatToHalfOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  USE_SIMPLE_CTOR_DTOR(FloatToHalfOp);
14 
15  bool RunOnDevice() override;
16 };
17 
18 template <class Context>
19 class HalfToFloatOp : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22  USE_SIMPLE_CTOR_DTOR(HalfToFloatOp);
23 
24  bool RunOnDevice() override;
25 };
26 
27 class Float16ConstantFillOp : public Operator<CPUContext> {
28  public:
29  Float16ConstantFillOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<CPUContext>(operator_def, ws),
31  shape_(
32  ToVectorTIndex(OperatorBase::GetRepeatedArgument<int>("shape"))) {}
33 
34  USE_OPERATOR_FUNCTIONS(CPUContext);
35  virtual ~Float16ConstantFillOp() {}
36 
37  bool RunOnDevice() override;
38 
39  private:
40  vector<TIndex> shape_;
41 };
42 
43 inline std::vector<TensorShape> Float16FillerTensorInference(
44  const OperatorDef& def,
45  const vector<TensorShape>& in) {
46  vector<TensorShape> out(1);
47  ArgumentHelper helper(def);
48  out[0].set_data_type(static_cast<TensorProto_DataType>(
49  helper.GetSingleArgument<int>("dtype", TensorProto_DataType_FLOAT)));
50  auto shape = helper.GetRepeatedArgument<int>("shape");
51  for (int d : shape) {
52  out[0].add_dims(d);
53  }
54  return out;
55 }
56 
57 } // namespace caffe2
58 
59 #endif // CAFFE2_OPERATORS_HALF_FLOAT_OPS_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
A helper class to index into arguments.
Definition: proto_utils.h:198
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
vector< TIndex > ToVectorTIndex(const std::vector< int > &src)
A utility function to convert vector<int> to vector<TIndex>.
Definition: tensor.h:33