Caffe2 - C++ API
A deep learning, cross platform ML framework
pack_segments.h
1 #ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_
2 #define CAFFE2_OPERATORS_PACK_SEGMENTS_H_
3 
4 #include <atomic>
5 #include <limits>
6 #include <mutex>
7 #include <unordered_map>
8 #include <vector>
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/tensor.h"
11 #include "caffe2/utils/math.h"
12 
13 namespace caffe2 {
14 
15 template <class Context>
16 class PackSegmentsOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19  // USE_SIMPLE_CTOR_DTOR(PackSegmentsOp)
20  USE_DISPATCH_HELPER;
21 
22  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
23  : Operator<Context>(operator_def, ws),
24  pad_minf_(OperatorBase::GetSingleArgument<bool>("pad_minf", false)),
25  return_presence_mask_(OperatorBase::GetSingleArgument<bool>(
26  "return_presence_mask",
27  false)) {
28  if (pad_minf_) {
29  padding_ = -1.0 * std::numeric_limits<float>::infinity();
30  } else {
31  padding_ = 0;
32  }
33  }
34 
35  bool RunOnDevice() {
36  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
37  }
38 
39  template <typename T>
40  bool DoRunWithType();
41 
42  template <typename T, typename Data_T>
43  bool DoRunWithType2();
44 
45  INPUT_TAGS(LENGTHS, DATA);
46 
47  private:
48  bool pad_minf_;
49  float padding_;
50  bool return_presence_mask_;
51 
52  // Scratch space required by the CUDA version
53  Tensor<Context> dev_buffer_;
54  Tensor<Context> dev_lengths_prefix_sum_;
55  Tensor<Context> dev_max_length_;
56  Tensor<CPUContext> host_max_length_;
57 };
58 
59 template <class Context>
60 class UnpackSegmentsOp final : public Operator<Context> {
61  public:
62  USE_OPERATOR_CONTEXT_FUNCTIONS;
63  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
64  USE_DISPATCH_HELPER;
65 
66  bool RunOnDevice() override {
67  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
68  }
69 
70  template <typename T>
71  bool DoRunWithType();
72 
73  template <typename T, typename Data_T>
74  bool DoRunWithType2();
75 
76  INPUT_TAGS(LENGTHS, DATA);
77 
78  private:
79  Tensor<Context> dev_buffer_;
80  Tensor<Context> dev_lengths_prefix_sum_;
81  Tensor<Context> dev_max_length_;
82  Tensor<Context> dev_num_cell_;
83  Tensor<CPUContext> host_max_length_;
84  Tensor<CPUContext> host_num_cell_;
85 };
86 
87 } // namespace caffe2
88 #endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...