// NOTE(review): this text is a fragment of a line-numbered source listing, not
// compilable C++. The integer at the start of most lines is the listing's own
// line number, and the numbering skips lines. Missing from this view:
//   - the class header and the rest of the constructor signature;
//   - the definition of `N`;
//   - the loop-body statements that derive and count `k`;
//   - the reset of `counts`;
//   - `return true;` and the closing braces.
// Restore from the original header before relying on this code.
//
// Visible contents: a Caffe2 operator templated on key type T and device
// Context. Judging from the visible statements, it reads keys from Input(0).
// For each category k in [0, categorical_limit_), it writes into Output(k) the
// positions i whose key falls in category k. The statement deriving k from
// keys_data[i] is among the missing lines; confirm.
5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
// Brings Input()/Output() and related context helpers into scope (presumably;
// the macro comes from caffe2/core/operator.h).
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor tail. Reads the "categorical_limit" argument (default 0),
// presumably into categorical_limit_, and rejects non-positive values.
18 OperatorBase::GetSingleArgument<int>(
"categorical_limit", 0)) {
19 CAFFE_ENFORCE_GT(categorical_limit_, 0);
// Splits the keys into one index list per category, one output per category.
22 bool RunOnDevice()
override {
23 auto& keys = Input(0);
// NOTE(review): `N` is used below, but its definition is missing from this
// listing. Presumably it is the element count of `keys`; confirm.
25 const T* keys_data = keys.template data<T>();
// counts[k] holds the number of keys in category k. It is later reused as the
// write cursor into Output(k). std::vector<int>(n) value-initializes every
// element to 0.
26 std::vector<int> counts(categorical_limit_);
// eids[k] is a raw pointer into Output(k)'s int buffer, filled in below.
27 std::vector<int*> eids(categorical_limit_);
// This loop's body is missing from the listing. It presumably zeroes
// counts[k], which is redundant given the value-initialization above.
28 for (
int k = 0; k < categorical_limit_; k++) {
// First pass: validate every key. On a missing line it is presumably also
// counted into counts[k].
31 for (
int i = 0; i < N; i++) {
// Every key's category must lie in [0, categorical_limit_). Otherwise these
// enforce checks fail before any output is resized or written.
33 CAFFE_ENFORCE_GT(categorical_limit_, k);
34 CAFFE_ENFORCE_GE(k, 0);
// Size Output(k) to exactly counts[k] int elements and cache its data pointer.
37 for (
int k = 0; k < categorical_limit_; k++) {
38 auto* eid = Output(k);
39 eid->Resize(counts[k]);
40 eids[k] = eid->template mutable_data<int>();
// NOTE(review): the second pass below uses counts[k] as a write cursor, so
// counts[k] must be reset to 0 after the Resize above. That reset is not in
// this listing; confirm it exists, otherwise the writes overrun each buffer.
// Second pass: append each position i to its category's output, keeping input
// order within a category. There is no bounds re-check here; it relies on the
// first pass having validated every key.
43 for (
int i = 0; i < N; i++) {
45 eids[k][counts[k]++] = i;
// Number of categories, and therefore of outputs written. Set from the
// "categorical_limit" argument and enforced to be > 0.
51 int categorical_limit_;
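Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.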
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.