1 #ifndef CAFFE2_OPERATORS_SEQUENCE_OPS_H_ 2 #define CAFFE2_OPERATORS_SEQUENCE_OPS_H_ 4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/tensor.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 OperatorBase::GetSingleArgument<int>(
"padding_width", 1)),
19 OperatorBase::GetSingleArgument<int>(
"end_padding_width", -1)) {
20 CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
21 if (endPaddingWidth_ < 0) {
22 endPaddingWidth_ = startPaddingWidth_;
26 bool RunOnDevice()
override {
27 if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
28 Output(0)->Resize(std::vector<TIndex>(0));
29 Output(0)->template mutable_data<TIndex>();
30 if (OutputSize() == 2) {
31 Output(1)->Resize(std::vector<TIndex>(0));
32 Output(1)->template mutable_data<TIndex>();
41 bool DoRunWithType() {
42 const auto& in = Input(0);
43 CAFFE_ENFORCE_GE(in.ndim(), 1);
44 const int32_t outer_size = in.dims()[0];
45 const auto block_size = in.size_from_dim(1);
46 const auto pad_width = startPaddingWidth_ + endPaddingWidth_;
49 const int32_t* lengths_ptr = &outer_size;
50 int64_t lengths_size = 1;
51 if (InputSize() > 1) {
52 const auto& lengths = Input(1);
53 lengths_ptr = lengths.template data<int32_t>();
54 lengths_size = lengths.size();
56 std::vector<TIndex> padShape(in.dims().begin() + 1, in.dims().end());
58 Output(0)->Resize(padShape);
59 T* padding_start_ptr = Output(0)->template mutable_data<T>();
60 math::Set<T, Context>(block_size, 0.0, padding_start_ptr, &context_);
63 T* padding_end_ptr = padding_start_ptr;
64 if (OutputSize() == 2) {
65 Output(1)->Resize(padShape);
66 padding_end_ptr = Output(1)->template mutable_data<T>();
67 math::Set<T, Context>(block_size, 0.0, padding_end_ptr, &context_);
74 in.template data<T>(),
85 const int lengths_size,
89 const int* lengths_ptr,
93 int startPaddingWidth_;
100 template <
class Context>
103 USE_OPERATOR_CONTEXT_FUNCTIONS;
107 OperatorBase::GetSingleArgument<int>(
"padding_width", 1)),
109 OperatorBase::GetSingleArgument<int>(
"end_padding_width", -1)) {
110 CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
111 if (endPaddingWidth_ < 0) {
112 endPaddingWidth_ = startPaddingWidth_;
116 bool RunOnDevice()
override {
117 if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
118 Output(0)->CopyFrom(Input(0), &context_);
119 if (OutputSize() == 2) {
120 Output(1)->CopyFrom(Input(1), &context_);
128 template <
typename T>
129 bool DoRunWithType();
132 int startPaddingWidth_;
133 int endPaddingWidth_;
140 template <
class Context>
143 USE_OPERATOR_CONTEXT_FUNCTIONS;
147 OperatorBase::GetSingleArgument<int>(
"padding_width", 1)),
149 OperatorBase::GetSingleArgument<int>(
"end_padding_width", -1)) {
150 CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
151 if (endPaddingWidth_ < 0) {
152 endPaddingWidth_ = startPaddingWidth_;
156 bool RunOnDevice()
override {
157 if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
158 Output(0)->CopyFrom(Input(0), &context_);
159 if (OutputSize() == 2) {
160 Output(1)->CopyFrom(Input(1), &context_);
168 template <
typename T>
169 bool DoRunWithType() {
170 const auto& in = Input(0);
171 CAFFE_ENFORCE_GE(in.ndim(), 1);
172 const int32_t outer_size = in.dims()[0];
173 const auto block_size = in.size_from_dim(1);
176 const int32_t* lengths_ptr =
nullptr;
177 int32_t lengths_size = 1;
178 if (InputSize() > 1) {
179 const auto& lengths = Input(1);
180 lengths_ptr = lengths.template data<int32_t>();
181 lengths_size = lengths.size();
188 const T* padding_start_ptr =
nullptr;
189 const T* padding_end_ptr =
nullptr;
190 if (InputSize() >= 3) {
191 auto& padding_start = Input(2);
192 CAFFE_ENFORCE_EQ(block_size, padding_start.size());
193 padding_start_ptr = padding_start.template data<T>();
195 if (InputSize() == 4) {
196 auto& padding_end = Input(3);
197 CAFFE_ENFORCE_EQ(block_size, padding_end.size());
198 padding_end_ptr = padding_end.template data<T>();
200 padding_end_ptr = padding_start_ptr;
203 auto* out = Output(0);
205 auto out_dims = in.dims();
206 out_dims[0] += (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
207 out->Resize(std::move(out_dims));
209 const auto* in_ptr = in.template data<T>();
210 auto* out_ptr = out->template mutable_data<T>();
212 return MakePadding<T>(
224 template <
typename T>
228 const int32_t* lengths_ptr,
229 int32_t lengths_size,
231 const T* padding_start_ptr,
232 const T* padding_end_ptr,
235 int startPaddingWidth_;
236 int endPaddingWidth_;
243 template <
class Context>
246 USE_OPERATOR_CONTEXT_FUNCTIONS;
250 bool RunOnDevice()
override;
255 #endif // CAFFE2_OPERATORS_SEQUENCE_OPS_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...