Caffe2 - C++ API
A deep learning, cross-platform ML framework
variable_length_sequence_padding.h
1 #pragma once
2 
#include <algorithm>
#include <cstddef>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
6 
7 namespace caffe2 {
8 namespace detail {
9 
/**
 * Overwrites the tail of every sequence in a contiguous N x B x M buffer with
 * a pad value.
 *
 * Element (i, j, k) of X lives at X[(i * B + j) * M + k]. For each batch
 * column j, every outer index i in [seqLengths[j], N) has its M inner elements
 * set to padValue. Entries with i < seqLengths[j] are left untouched.
 *
 * @param N          size of the outer dimension of X (compared against the
 *                   sequence lengths)
 * @param B          size of the middle dimension; seqLengths must hold at
 *                   least B entries
 * @param M          size of the innermost dimension
 * @param X          buffer of N * B * M elements, modified in place
 * @param seqLengths per-column lengths. A negative length is treated as 0
 *                   (the whole column is padded); a length >= N pads nothing.
 * @param padValue   value written into padded positions
 */
template <typename T, typename Context>
void VariableLengthSequencePadding(
    int N,
    int B,
    int M,
    T* X,
    const int32_t* seqLengths,
    const T padValue,
    Context* /*context*/) {
  // Offsets are computed in ptrdiff_t: N * B * M can exceed INT_MAX even when
  // each individual dimension fits in an int.
  const std::ptrdiff_t outerStride = static_cast<std::ptrdiff_t>(B) * M;
  for (int j = 0; j < B; ++j) {
    // Clamp so a negative length cannot index before the start of X.
    const int first = std::max<int>(seqLengths[j], 0);
    T* column = X + static_cast<std::ptrdiff_t>(j) * M;
    for (int i = first; i < N; ++i) {
      std::fill_n(column + i * outerStride, M, padValue);
    }
  }
}
25 
26 } // namespace detail
27 
28 template <typename T, typename Context>
29 class VariableLengthSequencePaddingOp : public Operator<Context> {
30  public:
32  const OperatorDef& operator_def,
33  Workspace* ws)
34  : Operator<Context>(operator_def, ws) {}
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36 
37  bool RunOnDevice() override {
38  const auto N = Input(INPUT).dim(0);
39  const auto B = Input(INPUT).dim(1);
40  const auto M = Input(INPUT).dim(2);
41 
42  auto X = Output(OUTPUT)->template mutable_data<T>();
43 
44  auto seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
45 
46  detail::VariableLengthSequencePadding<T, Context>(
47  N, B, M, X, seqLengths, 0, &context_);
48  return true;
49  }
50 
51  protected:
52  INPUT_TAGS(INPUT, SEQ_LENGTHS);
53  OUTPUT_TAGS(OUTPUT);
54 };
55 
56 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...