1 #ifndef CAFFE_OPERATORS_ONE_HOT_OPS_H_ 2 #define CAFFE_OPERATORS_ONE_HOT_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 bool RunOnDevice()
override {
20 auto& indices = Input(0);
24 "indices input must be 1D tensor of data type TIndex");
27 auto& index_size_tensor = OperatorBase::Input<Tensor<CPUContext>>(1);
29 index_size_tensor.size(),
31 "index_size_tensor input must be scalar of data type TIndex");
33 auto batch_size = indices.size();
34 auto index_size = *index_size_tensor.template data<TIndex>();
35 auto one_hots = Output(0);
36 one_hots->Resize(batch_size, index_size);
37 auto output_size = one_hots->size();
38 if (output_size == 0) {
42 DoOneHotOp(batch_size, index_size, indices, one_hots);
54 template <
class Context>
57 USE_OPERATOR_CONTEXT_FUNCTIONS;
61 bool RunOnDevice()
override {
69 INPUT_TAGS(X, LENS, VALS);
74 std::vector<TIndex> valsOffsets_;
77 template <
class Context>
80 USE_OPERATOR_CONTEXT_FUNCTIONS;
84 bool RunOnDevice()
override;
87 INPUT_TAGS(X, LENS, BOUNDARIES);
93 #endif // CAFFE_OPERATORS_ONE_HOT_OPS_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...