Caffe2 - C++ API
A deep learning, cross platform ML framework
cast_op.cc
1 #include "caffe2/operators/cast_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 template <typename DstType, typename SrcType>
7 bool CastOp<CPUContext>::DoRunWithType() {
8  auto& input = Input(0);
9  auto* output = Output(0);
10  output->ResizeLike(input);
11  const auto* data = input.template data<SrcType>();
12  auto* out = output->template mutable_data<DstType>();
13  auto N = input.size();
14  for (TIndex i = 0; i < N; ++i) {
15  out[i] = static_cast<DstType>(data[i]);
16  }
17  return true;
18 }
19 
20 template <>
21 void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
22  switch (to) {
23  case TensorProto_DataType_FLOAT:
24  // body_ = &CastOp::DoRunIncFp16WithDstType<float>;
25  body_ = &CastOp<CPUContext>::DoRunWithDstType<float>;
26  break;
27  case TensorProto_DataType_INT32:
28  body_ = &CastOp<CPUContext>::DoRunWithDstType<int>;
29  break;
30  case TensorProto_DataType_BYTE:
31  LOG(FATAL) << "BYTE is deprecated";
32  break;
33  case TensorProto_DataType_STRING:
34  CAFFE_THROW("Casting to and from strings is not supported yet");
35  // break;
36  case TensorProto_DataType_BOOL:
37  body_ = &CastOp<CPUContext>::DoRunWithDstType<bool>;
38  break;
39  case TensorProto_DataType_UINT8:
40  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
41  break;
42  case TensorProto_DataType_INT8:
43  body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
44  break;
45  case TensorProto_DataType_UINT16:
46  body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
47  break;
48  case TensorProto_DataType_INT16:
49  body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
50  break;
51  case TensorProto_DataType_INT64:
52  body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
53  break;
54  case TensorProto_DataType_FLOAT16:
55  CAFFE_THROW("Casting to and from float16 on CPU is not supported yet");
56  // break;
57  case TensorProto_DataType_DOUBLE:
58  //body_ = &CastOp::DoRunIncFp16WithDstType<double>;
59  body_ = &CastOp<CPUContext>::DoRunWithDstType<double>;
60  break;
61  case TensorProto_DataType_UNDEFINED:
62  CAFFE_THROW("Cast op must have 'to' argument of type DataType");
63  // break;
64  default:
65  CAFFE_THROW("Unexpected 'to' argument value: ", to);
66  }
67 }
68 
69 template <>
70 template <typename DstType>
71 bool CastOp<CPUContext>::DoRunWithDstType() {
72  return DispatchHelper<
73  TensorTypes<
74  float,
75  int32_t,
76  bool,
77  uint8_t,
78  int8_t,
79  uint16_t,
80  int16_t,
81  int64_t,
82  double>,
83  DstType>::call(this, Input(0));
84 }
85 
86 REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
87 
88 OPERATOR_SCHEMA(Cast)
89  .NumInputs(1)
90  .NumOutputs(1)
91  .TensorInferenceFunction(
92  [](const OperatorDef& def, const vector<TensorShape>& in) {
93  ArgumentHelper helper(def);
94  vector<TensorShape> out;
95  out.push_back(in[0]);
96  out[0].set_data_type(cast::GetCastDataType(helper, "to"));
97  return out;
98  })
99  .SetDoc(R"DOC(
100 The operator casts the elements of a given input tensor to a data type
101 specified by the 'to' argument and returns an output tensor of the same size in
102 the converted type. The 'to' argument must be one of the data types specified
103 in the 'DataType' enum field in the TensorProto message. If the 'to' argument
104 is not provided or is not one of the enumerated types in DataType, Caffe2
105 throws an Enforce error.
106 
107 NOTE: Casting to and from strings is not supported yet.
108 )DOC")
109  .Arg(
110  "to",
111  "The data type to which the elements of the input tensor are cast."
112  "Strictly must be one of the types from DataType enum in TensorProto")
113  .Input(0, "input", "Input tensor to be cast.")
114  .Output(
115  0,
116  "output",
117  "Output tensor with the same shape as input with type "
118  "specified by the 'to' argument");
119 
120 // Some Casts are compatible with gradients, but for now we don't support it
121 // GRADIENT_NOT_IMPLEMENTED_YET(Cast);
122 
123 class GetCastGradient : public GradientMakerBase {
124  using GradientMakerBase::GradientMakerBase;
125  vector<OperatorDef> GetGradientDefs() override {
126 
127  vector<OperatorDef> defs = SingleGradientDef("Cast", "", vector<string>{GO(0)}, vector<string>{GI(0)});
128 
129  // now modify the arguments in defs[0]
130  ArgumentHelper argsHelper(def_);
131 
132  auto to_name = cast::GetCastDataType(argsHelper, "to");
133 
134  CAFFE_ENFORCE(
135  argsHelper.HasSingleArgumentOfType<string>("from_type") ||
136  argsHelper.HasSingleArgumentOfType<int>("from_type"),
137  "Argument 'from_type' of type int or string"
138  " is required to get the gradient of CastOp");
139 
140  auto from_name = cast::GetCastDataType(argsHelper, "from_type");
141  Argument *to = defs[0].add_arg();
142  to->set_name("to");
143  to->set_i(from_name);
144 
145  Argument *from = defs[0].add_arg();
146  from->set_name("from_type");
147  from->set_i(to_name);
148 
149  return defs;
150  }
151 
152  bool CopyArguments() const override {
153  return false;
154  }
155 };
156 
157 REGISTER_GRADIENT(Cast, GetCastGradient);
158 
159 
160 
161 
162 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...