Caffe2 - C++ API
A deep learning, cross platform ML framework
lengths_tile_op.h
1 #ifndef CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
2 #define CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
3 
4 #include "caffe2/core/operator.h"
5 #include "caffe2/utils/math.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class LengthsTileOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  USE_SIMPLE_CTOR_DTOR(LengthsTileOp);
14 
15  bool RunOnDevice() override {
16  auto& data = Input(DATA);
17  auto& lengths = Input(LENGTHS);
18  auto* output = Output(0);
19 
20  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be 1-D");
21  CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
22  CAFFE_ENFORCE_EQ(lengths.size(), data.dim(0));
23 
24  // Context::CopyFrom and math::Sum need the same context to avoid race
25  // conditions
26  CPUContext cpuContext;
27  lengths_host_.CopyFrom(lengths, &cpuContext);
28  auto lengths_size = lengths_host_.size();
29  auto* lengths_data = lengths_host_.data<int32_t>();
30 
31  int32_t total_length = 0;
32  math::Sum<int32_t, CPUContext>(
33  lengths_size, lengths_data, &total_length, &cpuContext);
34 
35  auto shape = data.dims();
36  shape[0] = total_length;
37  output->Resize(shape);
38 
39  auto block_bytesize = data.size_from_dim(1) * data.meta().itemsize();
40  auto src = static_cast<const char*>(data.raw_data());
41  auto out = static_cast<char*>(output->raw_mutable_data(data.meta()));
42 
43  for (TIndex i = 0; i < lengths_size; ++i) {
44  auto length = lengths_data[i];
45  CAFFE_ENFORCE_GE(length, 0);
46  for (int32_t j = 0; j < length; ++j) {
47  context_.template CopyBytes<Context, Context>(block_bytesize, src, out);
48  out += block_bytesize;
49  }
50  src += block_bytesize;
51  }
52  return true;
53  }
54 
55  INPUT_TAGS(DATA, LENGTHS);
56 
57  private:
58  TensorCPU lengths_host_;
59 };
60 
61 } // namespace caffe2
62 
63 #endif // CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
const T * data() const
Returns a typed pointer of the underlying storage.
Definition: tensor.h:484
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:166
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.