1 #ifndef CAFFE2_OPERATORS_FILLER_OP_H_ 2 #define CAFFE2_OPERATORS_FILLER_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 20 template <
class Context>
25 shape_(
ToVectorTIndex(OperatorBase::GetRepeatedArgument<int>(
"shape"))),
27 OperatorBase::GetRepeatedArgument<int>(
"extra_shape"))),
29 OperatorBase::GetSingleArgument<bool>(
"input_as_shape",
false)) {
31 if (shape_.size() != 0) {
33 "Cannot set the shape argument and pass in an input at " 37 if (!extra_shape_.empty()) {
38 CAFFE_THROW(
"Cannot set extra_shape when there is no input");
40 if (input_as_shape_) {
41 CAFFE_THROW(
"An input must be given if input_as_shape is true");
43 if (shape_.size() == 0 &&
44 OperatorBase::HasSingleArgumentOfType<int>(
"shape")) {
45 CAFFE_THROW(
"Fill 'shape' argument was a scalar, list expected");
51 USE_OPERATOR_CONTEXT_FUNCTIONS;
53 bool RunOnDevice()
override {
56 auto shape = vector<TIndex>{};
57 if (input_as_shape_) {
59 auto& input = OperatorBase::Input<Tensor<CPUContext>>(0);
63 "When input_as_shape is true, the input must be a 1D tensor of " 65 auto* shape_data = input.template data<TIndex>();
66 shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
68 auto& input = Input(0);
69 shape.insert(shape.end(), input.dims().begin(), input.dims().end());
71 shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
72 output->Resize(shape);
74 output->Resize(shape_);
82 vector<TIndex> shape_;
83 vector<TIndex> extra_shape_;
87 template <
typename T,
class Context>
90 USE_OPERATOR_CONTEXT_FUNCTIONS;
93 min_(OperatorBase::template GetSingleArgument<T>(
"min", 0)),
94 max_(OperatorBase::template GetSingleArgument<T>(
"max", 1)) {
95 if (InputSize() == 3) {
97 !OperatorBase::HasSingleArgumentOfType<T>(
"min"),
98 "Cannot set both min arg and min input blob");
100 !OperatorBase::HasSingleArgumentOfType<T>(
"max"),
101 "Cannot set both max arg and max input blob");
104 min_, max_,
"Max value should be bigger than min value.");
111 if (InputSize() == 3) {
112 CAFFE_ENFORCE_EQ(1, Input(1).size(),
"min blob must be scalar");
113 CAFFE_ENFORCE_EQ(1, Input(2).size(),
"max blob must be scalar");
114 min = *Input(1).template data<T>();
115 max = *Input(2).template data<T>();
117 auto shape = output->
dims();
120 output->template mutable_data<T>();
124 math::RandUniform<T, Context>(
128 output->template mutable_data<T>(),
138 template <
class Context>
141 USE_OPERATOR_CONTEXT_FUNCTIONS;
144 TensorProto_DataType dtype =
145 static_cast<TensorProto_DataType
>(OperatorBase::GetSingleArgument<int>(
146 "dtype", TensorProto_DataType_INT32));
149 case TensorProto_DataType_INT32:
151 body_ = &UniqueUniformFillOp::FillWithType<int>;
153 case TensorProto_DataType_INT64:
154 CheckRange<int64_t>();
155 body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
157 case TensorProto_DataType_UNDEFINED:
159 "UniqueUniformFill op cannot have undefined 'dtype' argument");
162 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
167 return (this->*body_)(output);
171 template <
typename T>
173 CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<T>(
"min"));
174 CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<T>(
"max"));
176 OperatorBase::GetSingleArgument<T>(
"min", 0),
177 OperatorBase::GetSingleArgument<T>(
"max", 0),
178 "Max value should be bigger than min value.");
181 template <
typename T>
183 T min = OperatorBase::GetSingleArgument<T>(
"min", 0);
184 T max = OperatorBase::GetSingleArgument<T>(
"max", 0);
186 const T* avoid_data =
nullptr;
187 size_t avoid_size = 0;
188 if (InputSize() >= 2) {
189 auto& avoid = Input(1);
190 avoid_data = avoid.template data<T>();
191 avoid_size = avoid.size();
193 math::RandUniformUnique<T, Context>(
197 output->template mutable_data<T>(),
207 template <
class Context>
210 USE_OPERATOR_CONTEXT_FUNCTIONS;
213 TensorProto_DataType dtype =
214 static_cast<TensorProto_DataType
>(OperatorBase::GetSingleArgument<int>(
215 "dtype", TensorProto_DataType_FLOAT));
221 if (OperatorBase::HasSingleArgumentOfType<float>(
"value")) {
222 dtype = TensorProto_DataType_FLOAT;
223 }
else if (OperatorBase::HasSingleArgumentOfType<int64_t>(
"value")) {
224 dtype = TensorProto_DataType_INT64;
226 CAFFE_THROW(
"Argument 'value' is of unexpected type");
228 VLOG(1) <<
"Argument 'dtype' is not provided. Assume the data type is " 229 <<
"the same as that of argument 'value': " << dtype;
233 case TensorProto_DataType_FLOAT:
234 body_ = &ConstantFillOp::FillWithType<float>;
236 case TensorProto_DataType_DOUBLE:
237 body_ = &ConstantFillOp::FillWithType<double>;
239 case TensorProto_DataType_BOOL:
240 body_ = &ConstantFillOp::FillWithType<bool>;
242 case TensorProto_DataType_INT8:
243 body_ = &ConstantFillOp::FillWithType<int8_t>;
245 case TensorProto_DataType_INT16:
246 body_ = &ConstantFillOp::FillWithType<int16_t>;
248 case TensorProto_DataType_INT32:
249 body_ = &ConstantFillOp::FillWithType<int>;
251 case TensorProto_DataType_INT64:
252 body_ = &ConstantFillOp::FillWithType<int64_t>;
254 case TensorProto_DataType_UINT8:
255 body_ = &ConstantFillOp::FillWithType<uint8_t>;
257 case TensorProto_DataType_UINT16:
258 body_ = &ConstantFillOp::FillWithType<uint16_t>;
260 case TensorProto_DataType_STRING:
261 body_ = &ConstantFillOp::FillWithString;
263 case TensorProto_DataType_UNDEFINED:
264 CAFFE_THROW(
"ConstantFill op cannot have undefined 'dtype' argument");
267 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
272 return (this->*body_)(output);
275 template <
typename T>
277 T value = OperatorBase::GetSingleArgument<T>(
"value", 0);
278 auto* data = output->template mutable_data<T>();
279 if (output->
size()) {
280 math::Set<T, Context>(output->
size(), value, data, &context_);
286 auto value = OperatorBase::GetSingleArgument<std::string>(
"value",
"");
287 auto* data = output->template mutable_data<std::string>();
288 for (
int i = 0; i < output->
size(); ++i) {
298 template <
class Context>
301 USE_OPERATOR_CONTEXT_FUNCTIONS;
304 TensorProto_DataType dtype =
305 static_cast<TensorProto_DataType
>(OperatorBase::GetSingleArgument<int>(
306 "dtype", TensorProto_DataType_FLOAT));
312 if (OperatorBase::HasSingleArgumentOfType<float>(
"value")) {
313 dtype = TensorProto_DataType_FLOAT;
314 }
else if (OperatorBase::HasSingleArgumentOfType<int64_t>(
"value")) {
315 dtype = TensorProto_DataType_INT64;
317 CAFFE_THROW(
"Argument 'value' is of unexpected type");
319 VLOG(1) <<
"Argument 'dtype' is not provided. Assume the data type is " 320 <<
"the same as that of argument 'value': " << dtype;
324 case TensorProto_DataType_FLOAT:
325 body_ = &DiagonalFillOp::FillWithType<float>;
327 case TensorProto_DataType_DOUBLE:
328 body_ = &DiagonalFillOp::FillWithType<double>;
330 case TensorProto_DataType_BOOL:
331 body_ = &DiagonalFillOp::FillWithType<bool>;
333 case TensorProto_DataType_INT8:
334 body_ = &DiagonalFillOp::FillWithType<int8_t>;
336 case TensorProto_DataType_INT16:
337 body_ = &DiagonalFillOp::FillWithType<int16_t>;
339 case TensorProto_DataType_INT32:
340 body_ = &DiagonalFillOp::FillWithType<int>;
342 case TensorProto_DataType_INT64:
343 body_ = &DiagonalFillOp::FillWithType<int64_t>;
345 case TensorProto_DataType_UINT8:
346 body_ = &DiagonalFillOp::FillWithType<uint8_t>;
348 case TensorProto_DataType_UINT16:
349 body_ = &DiagonalFillOp::FillWithType<uint16_t>;
351 case TensorProto_DataType_UNDEFINED:
352 CAFFE_THROW(
"Cannot have undefined 'dtype' argument");
354 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
359 return (this->*body_)(output);
362 template <
typename T>
367 CAFFE_ENFORCE(output->
ndim() >= 2,
"Input shape must be >= 2D");
372 if (output->
ndim() == 2) {
373 step = output->
dim(1) + 1;
375 TIndex prev_i = output->
dim(0);
376 for (
auto i : output->
dims()) {
378 CAFFE_THROW(
"All dimensions of input must be of equal length");
381 vector<TIndex> cumprod(output->
ndim());
382 auto dims = output->
dims();
387 std::multiplies<TIndex>());
390 cumprod.begin(), cumprod.end(),
static_cast<TIndex
>(0));
399 template <
typename T,
class Context>
402 USE_OPERATOR_CONTEXT_FUNCTIONS;
405 mean_(OperatorBase::template GetSingleArgument<float>(
"mean", 0)),
406 std_(OperatorBase::template GetSingleArgument<float>(
"std", 1)) {
407 DCHECK_GT(std_, 0) <<
"Standard deviation should be nonnegative.";
411 math::RandGaussian<T, Context>(
415 output->template mutable_data<T>(),
425 template <
typename T,
class Context>
428 USE_OPERATOR_CONTEXT_FUNCTIONS;
433 const int fan_in = output->
size() / output->
dim32(0);
434 T scale = std::sqrt(T(3) / fan_in);
435 math::RandUniform<T, Context>(
439 output->template mutable_data<T>(),
445 template <
typename T,
class Context>
448 USE_OPERATOR_CONTEXT_FUNCTIONS;
453 const int fan_out = output->
size() / output->
dim32(1);
454 T scale = std::sqrt(T(2) / fan_out);
455 math::RandGaussian<T, Context>(
459 output->template mutable_data<T>(),
468 template <
typename T,
class Context>
471 USE_OPERATOR_CONTEXT_FUNCTIONS;
478 template <
class Context>
481 USE_OPERATOR_CONTEXT_FUNCTIONS;
484 bool RunOnDevice()
override {
485 auto& input = Input(0);
486 auto* output = Output(0);
487 auto* input_data = input.template data<int32_t>();
489 CAFFE_ENFORCE_EQ(input.ndim(), 1,
"Input must be a vector.");
491 auto len_sum = std::accumulate(input_data, input_data + input.size(), 0);
493 output->Resize(len_sum);
494 auto* output_data = output->template mutable_data<int32_t>();
497 for (
int i = 0; i < input.size(); ++i) {
498 auto len = input_data[i];
499 auto start = output_data + offset;
510 template <
int VALUE_TYPE = TensorProto_DataType_FLOAT>
511 inline std::vector<TensorShape> FillerTensorInference(
512 const OperatorDef& def,
513 const vector<TensorShape>& in) {
514 vector<TensorShape> out(1);
516 out[0].set_data_type(static_cast<TensorProto_DataType>(
517 helper.GetSingleArgument<
int>(
"dtype", VALUE_TYPE)));
521 bool input_as_shape =
522 helper.GetSingleArgument<
bool>(
"input_as_shape",
false);
523 if (input_as_shape) {
524 out[0].set_unknown_shape(
true);
527 for (
int d : in[0].dims()) {
531 auto shape = helper.GetRepeatedArgument<
int>(
"shape");
532 for (
int d : shape) {
541 #endif // CAFFE2_OPERATORS_FILLER_OP_H_
TIndex dim(const int i) const
Returns the i-th dimension of the tensor.
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
TIndex size() const
Returns the size (i.e.
A helper class to index into arguments.
const vector< TIndex > & dims() const
Returns the dimensions of the tensor as a vector.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
vector< TIndex > ToVectorTIndex(const std::vector< int > &src)
A utility function to convert vector<int> to vector<TIndex>.
int ndim() const
Returns the number of dimensions of the data.