1 #ifndef CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_ 2 #define CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 bool RunOnDevice()
override {
22 bool DoRunWithType() {
23 if (Input(LENGTHS).
template IsType<int>()) {
24 DoRunWithLengthType<T, int>();
26 DoRunWithLengthType<T, long>();
32 INPUT_TAGS(DATA, LENGTHS);
34 template <
typename T,
typename LengthType>
35 void DoRunWithLengthType() {
36 const auto& data = Input(DATA);
37 const auto& lengths = Input(LENGTHS);
41 "DATA should be 3-D tensor <lengths, " 42 "segments, embeddings>");
43 CAFFE_ENFORCE(lengths.ndim() == 1,
"LENGTH should be 1-D");
45 auto* output = Output(0);
46 const auto& shape = data.dims();
47 output->Resize(shape);
49 const auto& max_length = data.dims()[0];
50 const auto& batch_size = data.dims()[1];
51 const auto& block_size = data.dims()[2];
53 lengths.dims()[0] == batch_size,
54 "lenths size should be" 55 " equal to batch size");
57 const T* data_ptr = data.template data<T>();
58 const LengthType* lengths_ptr = lengths.template data<LengthType>();
60 vector<LengthType> lengths_host(batch_size);
61 context_.template Copy<LengthType, Context, CPUContext>(
62 batch_size, lengths_ptr, &lengths_host[0]);
63 context_.FinishDeviceComputation();
65 T* rev_data_ptr = output->template mutable_data<T>();
66 for (TIndex i = 0; i < batch_size; i++) {
67 const auto& seg_length = lengths_host[i];
68 CAFFE_ENFORCE_LE(seg_length, max_length);
70 for (; j < seg_length; j++) {
71 const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
72 T* rev_data_block_ptr =
73 rev_data_ptr + ((seg_length - 1 - j) * batch_size + i) * block_size;
74 context_.template Copy<T, Context, Context>(
75 block_size, data_block_ptr, rev_data_block_ptr);
77 for (; j < max_length; j++) {
78 const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
79 T* rev_data_block_ptr =
80 rev_data_ptr + (j * batch_size + i) * block_size;
81 context_.template Copy<T, Context, Context>(
82 block_size, data_block_ptr, rev_data_block_ptr);
90 #endif // CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_ A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...