3 #include "caffe2/core/context.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/operators/filler_op.h" 7 #include "caffe2/utils/cast.h" 8 #include "caffe2/utils/math.h" 12 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
23 if (!std::is_same<T, float>::value || !helper.HasArgument(
"dtype")) {
26 auto dtype = cast::GetCastDataType(helper,
"dtype");
28 case TensorProto_DataType_FLOAT:
29 ExtractValues<float>();
31 case TensorProto_DataType_DOUBLE:
32 ExtractValues<double>();
34 case TensorProto_DataType_BOOL:
35 ExtractValues<bool>();
37 case TensorProto_DataType_INT32:
40 case TensorProto_DataType_INT64:
41 ExtractValues<int64_t>();
43 case TensorProto_DataType_STRING:
44 ExtractValues<std::string>();
46 case TensorProto_DataType_UNDEFINED:
47 CAFFE_THROW(
"Cannot have undefined 'dtype' argument");
49 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
55 return (this->*body_)(output);
59 template <
typename Type>
60 void ExtractValues() {
62 OperatorBase::template GetRepeatedArgument<Type>(
"values");
63 values_.
Resize(source_values.size());
64 Type* values_data = values_.template mutable_data<Type>();
65 for (
int i = 0; i < source_values.size(); i++) {
66 values_data[i] =
static_cast<Type
>(source_values[i]);
68 body_ = &GivenTensorFillOp::FillWithType<Type>;
71 template <
typename Type>
73 DCHECK_EQ(output->
size(), values_.
size())
74 <<
"output size: " << output->
size()
75 <<
" given size: " << values_.
size();
76 auto* data = output->template mutable_data<Type>();
77 const Type* values_data = values_.template data<Type>();
79 context_.template Copy<Type, CPUContext, Context>(
80 output->
size(), values_data, data);
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
TIndex size() const
Returns the size (i.e.
A helper class to index into arguments.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...