Caffe2 - C++ API
A deep learning, cross-platform ML framework
pack_rnn_sequence_op.h
1 #ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
2 #define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
3 
4 #include <algorithm>
5 #include <vector>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context, bool Forward>
13 class PackRNNSequenceOpBase : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  PackRNNSequenceOpBase(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws) {}
18 
19  bool RunOnDevice() override {
21  this, Input(0));
22  }
23 
24  template <typename ValT>
25  bool DoRunWithType() {
26  // The value is copied from the sequence to the pack
27  // if Forward is true, and vice versa
28  int dim_offset = Forward ? 1 : 2;
29  auto& values = Input(0);
30  CAFFE_ENFORCE_GT(values.ndim(), dim_offset);
31 
32  // block_size is the size for each individual feature
33  TIndex block_size = values.size_from_dim(dim_offset);
34  auto values_vec = values.template data<ValT>();
35 
36  auto& lengths = Input(LENGTHS);
37  CAFFE_ENFORCE_EQ(lengths.ndim(), 1);
38  const auto cols = lengths.size();
39  const int32_t* lengths_vec = lengths.template data<int32_t>();
40  // the total number of rows is defined as the max number from lengths
41  // if when the lengths is empty, we set rows = 0 to support zero lengths
42  const auto rows =
43  cols ? *std::max_element(lengths_vec, lengths_vec + cols) : 0;
44  CAFFE_ENFORCE_GE(rows, 0);
45  int length_sum = 0;
46  if (cols > 0) {
47  math::Sum<int, Context>(cols, lengths_vec, &length_sum, &context_);
48  }
49 
50  vector<TIndex> shape;
51  // the output shape is rows * cols for the pack,
52  // or length_sum for the sequence
53  if (Forward) {
54  shape.push_back(rows);
55  shape.push_back(cols);
56  } else {
57  shape.push_back(length_sum);
58  }
59  // insert the dim for the feature
60  shape.insert(
61  shape.end(), values.dims().begin() + dim_offset, values.dims().end());
62 
63  auto* output = Output(OUTPUTVALUE);
64  output->Resize(shape);
65 
66  auto output_data = output->template mutable_data<ValT>();
67  // initialize output_data with zero, as it is the default value for padding
68  // when certain length is smaller than rows
69  math::Set<ValT, Context>(output->size(), 0, output_data, &context_);
70 
71  int32_t offset = 0;
72  for (int c = 0; c < cols; c++) {
73  for (int r = 0; r < lengths_vec[c]; r++) {
74  auto input_offset = Forward ? (offset + r) : (r * cols + c);
75  auto output_offset = Forward ? (r * cols + c) : (offset + r);
76  context_.template CopyItems<Context, Context>(
77  values.meta(),
78  block_size,
79  values_vec + input_offset * block_size,
80  output_data + output_offset * block_size);
81  }
82  offset += lengths_vec[c];
83  }
84  return true;
85  }
86 
87  private:
88  INPUT_TAGS(INPUTVALUE, LENGTHS);
89  OUTPUT_TAGS(OUTPUTVALUE);
90 };
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...