1 #ifndef CAFFE2_OPERATORS_FIND_OP_H_ 2 #define CAFFE2_OPERATORS_FIND_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 8 #include <unordered_map> 12 template <
class Context>
18 OperatorBase::GetSingleArgument<int>(
"missing_value", -1)) {}
19 USE_OPERATOR_CONTEXT_FUNCTIONS;
28 bool DoRunWithType() {
30 auto& needles = Input(1);
31 auto* res_indices = Output(0);
32 res_indices->ResizeLike(needles);
34 const T* idx_data = idx.template data<T>();
35 const T* needles_data = needles.template data<T>();
36 T* res_data = res_indices->template mutable_data<T>();
37 auto idx_size = idx.size();
42 if (needles.size() < 16) {
44 for (
int i = 0; i < needles.size(); i++) {
45 T x = needles_data[i];
46 T res =
static_cast<T
>(missing_value_);
47 for (
int j = idx_size - 1; j >= 0; j--) {
48 if (idx_data[j] == x) {
57 std::unordered_map<T, int> idx_map;
58 for (
int j = 0; j < idx_size; j++) {
59 idx_map[idx_data[j]] = j;
61 for (
int i = 0; i < needles.size(); i++) {
62 T x = needles_data[i];
63 auto it = idx_map.find(x);
64 res_data[i] = (it == idx_map.end() ? missing_value_ : it->second);
77 #endif // CAFFE2_OPERATORS_FIND_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...