1 #include "caffe2/operators/cast_op.h" 6 template <
typename DstType,
typename SrcType>
7 bool CastOp<CPUContext>::DoRunWithType() {
8 auto& input = Input(0);
9 auto* output = Output(0);
10 output->ResizeLike(input);
11 const auto* data = input.template data<SrcType>();
12 auto* out = output->template mutable_data<DstType>();
13 auto N = input.size();
14 for (TIndex i = 0; i < N; ++i) {
15 out[i] =
static_cast<DstType
>(data[i]);
21 void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
23 case TensorProto_DataType_FLOAT:
25 body_ = &CastOp<CPUContext>::DoRunWithDstType<
float>;
27 case TensorProto_DataType_INT32:
28 body_ = &CastOp<CPUContext>::DoRunWithDstType<
int>;
30 case TensorProto_DataType_BYTE:
31 LOG(FATAL) <<
"BYTE is deprecated";
33 case TensorProto_DataType_STRING:
34 CAFFE_THROW(
"Casting to and from strings is not supported yet");
36 case TensorProto_DataType_BOOL:
37 body_ = &CastOp<CPUContext>::DoRunWithDstType<
bool>;
39 case TensorProto_DataType_UINT8:
40 body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
42 case TensorProto_DataType_INT8:
43 body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
45 case TensorProto_DataType_UINT16:
46 body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
48 case TensorProto_DataType_INT16:
49 body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
51 case TensorProto_DataType_INT64:
52 body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
54 case TensorProto_DataType_FLOAT16:
55 CAFFE_THROW(
"Casting to and from float16 on CPU is not supported yet");
57 case TensorProto_DataType_DOUBLE:
59 body_ = &CastOp<CPUContext>::DoRunWithDstType<
double>;
61 case TensorProto_DataType_UNDEFINED:
62 CAFFE_THROW(
"Cast op must have 'to' argument of type DataType");
65 CAFFE_THROW(
"Unexpected 'to' argument value: ", to);
70 template <
typename DstType>
71 bool CastOp<CPUContext>::DoRunWithDstType() {
72 return DispatchHelper<
83 DstType>::call(
this, Input(0));
86 REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
91 .TensorInferenceFunction(
92 [](
const OperatorDef& def,
const vector<TensorShape>& in) {
93 ArgumentHelper helper(def);
94 vector<TensorShape> out;
96 out[0].set_data_type(cast::GetCastDataType(helper,
"to"));
100 The operator casts the elements of a given input tensor to a data type 101 specified by the 'to' argument and returns an output tensor of the same size in 102 the converted type. The 'to' argument must be one of the data types specified 103 in the 'DataType' enum field in the TensorProto message. If the 'to' argument 104 is not provided or is not one of the enumerated types in DataType, Caffe2 105 throws an Enforce error. 107 NOTE: Casting to and from strings is not supported yet. 111 "The data type to which the elements of the input tensor are cast." 112 "Strictly must be one of the types from DataType enum in TensorProto")
113 .Input(0,
"input",
"Input tensor to be cast.")
117 "Output tensor with the same shape as input with type " 118 "specified by the 'to' argument");
124 using GradientMakerBase::GradientMakerBase;
125 vector<OperatorDef> GetGradientDefs()
override {
127 vector<OperatorDef> defs =
SingleGradientDef(
"Cast",
"", vector<string>{GO(0)}, vector<string>{GI(0)});
132 auto to_name = cast::GetCastDataType(argsHelper,
"to");
135 argsHelper.HasSingleArgumentOfType<
string>(
"from_type") ||
136 argsHelper.HasSingleArgumentOfType<
int>(
"from_type"),
137 "Argument 'from_type' of type int or string" 138 " is required to get the gradient of CastOp");
140 auto from_name = cast::GetCastDataType(argsHelper,
"from_type");
141 Argument *to = defs[0].add_arg();
143 to->set_i(from_name);
145 Argument *from = defs[0].add_arg();
146 from->set_name(
"from_type");
147 from->set_i(to_name);
152 bool CopyArguments()
const override {
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...