// NOTE(review): this listing is an excerpt. Original source line numbers are
// embedded in the text, and many lines (the SliceImpl signature, closing
// braces, several branch conditions and call arguments) are not shown. The
// comments below describe only what the visible lines establish.
//
// SliceImpl<SIndex, Context>: shared implementation behind SliceOp (forward)
// and SliceGradientOp (backward).
//  - Forward mode (`output` non-null): copies the resolved sub-range
//    [start, end) of `data` into `output`, which is resized to the slice shape.
//  - Backward mode (`output == nullptr`): `gdata` is resized like `data` and
//    zero-filled. The incoming gradient `go` is then copied into the sliced
//    region of `gdata`.
// `starts` / `ends` must be 1-D, of equal length, and no longer than the rank
// of `data`. Only one dimension may actually be narrowed.
// Presumably `gdata` / `go` default to nullptr, since SliceOp calls this with
// only five arguments; the signature is elided, so confirm.
4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 12 template <
class SIndex,
class Context>
// A null output tensor selects backward (gradient) mode.
21 bool backward = output ==
nullptr;
// Typed pointers to the requested start/end indices.
23 auto* starts_data = starts.template data<SIndex>();
24 auto* ends_data = ends.template data<SIndex>();
// Validate the index tensors: both 1-D, the same length, and no longer than
// data's rank.
26 CAFFE_ENFORCE_EQ(starts.ndim(), 1);
27 CAFFE_ENFORCE_EQ(ends.ndim(), 1);
28 CAFFE_ENFORCE_GE(data.ndim(), starts.size());
29 CAFFE_ENFORCE_EQ(starts.size(), ends.size());
// Per-dimension resolved start, resolved end, and resulting output extent.
31 std::vector<SIndex> starts_idx(data.ndim());
32 std::vector<SIndex> ends_idx(data.ndim());
33 std::vector<SIndex> dst_sizes(data.ndim());
// Resolve the slice bounds for every dimension of data.
35 for (
int i = 0; i < data.ndim(); ++i) {
// Dimensions not covered by starts/ends keep their full extent. The matching
// start assignment for this case is on an elided line.
36 if (i >= starts.size()) {
38 ends_idx[i] = data.dims()[i];
// Bounds are resolved here only for non-empty dimensions. The zero-size case
// is handled on elided lines.
41 if (data.dims()[i] > 0) {
42 auto start = starts_data[i];
43 auto end = ends_data[i];
// Negative bounds are remapped as dims[i] + 1 + value, so -1 resolves to
// dims[i], i.e. "through the last element" for an exclusive end. The guarding
// conditions are on elided lines (presumably `< 0`).
45 start = data.dims()[i] + 1 + start;
48 end = data.dims()[i] + 1 + end;
// Clamp bounds that exceed the dimension size. The body of the end clamp is
// elided.
50 if (start > data.dims()[i]) {
51 start = data.dims()[i];
53 if (end > data.dims()[i]) {
// Resolved bounds must be non-negative and ordered.
56 CAFFE_ENFORCE_GE(start, 0);
57 CAFFE_ENFORCE_GE(end, 0);
58 CAFFE_ENFORCE_GE(end, start);
59 starts_idx[i] = start;
61 dst_sizes[i] = end - start;
// Empty input is handled separately; most of this branch is elided. The
// visible lines give the output the sliced shape and allocate it with data's
// element type.
69 if (data.size() <= 0) {
72 output->Resize(dst_sizes);
73 output->raw_mutable_data(data.meta());
// Find the single dimension that is actually narrowed. A second narrowed
// dimension is rejected with the message below (the enforce macro and the
// `dim` bookkeeping are on elided lines).
79 for (
int i = 0; i < data.ndim(); ++i) {
80 if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) {
82 dim, -1,
"Currently only possible to slice in 1 dimension.");
// Presumably the "no dimension narrowed" case: a plain copy of data to output
// (forward) or of go to gdata (backward).
88 output->CopyFrom(data, context);
90 gdata->CopyFrom(*go, context);
// unit: element count of one index step along `dim` (the product of the
// dimensions after `dim`). num_blocks: the number of outer blocks, which
// appears to be the product of the dimensions before `dim` (its begin iterator
// is on an elided line).
// NOTE(review): the products are accumulated with std::multiplies<SIndex>. A
// narrow SIndex (e.g. int) could overflow for very large tensors before the
// result is widened to size_t.
94 size_t unit = std::accumulate(
95 data.dims().begin() + dim + 1,
98 std::multiplies<SIndex>());
99 size_t num_blocks = std::accumulate(
101 data.dims().begin() + dim,
103 std::multiplies<SIndex>());
// Forward: the output takes the sliced shape.
105 output->Resize(dst_sizes);
// Backward: the gradient w.r.t. data has data's shape.
107 gdata->ResizeLike(data);
// All copies below operate on raw bytes, scaled by the element size.
110 size_t itemsize = data.meta().itemsize();
// Forward copy: for each outer block, copy the contiguous run of
// dst_block_size elements that starts at starts_idx[dim] * unit inside the
// corresponding source block.
113 char* src_bytes = (
char*)data.raw_data();
114 char* dst_bytes = (
char*)output->raw_mutable_data(data.meta());
116 size_t src_nbytes = data.nbytes();
117 size_t dst_nbytes = output->nbytes();
119 size_t src_block_size = unit * data.dims()[dim];
120 size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
121 size_t src_offset = unit * starts_idx[dim];
// Nothing to copy. The body of this early exit is elided.
123 if (num_blocks == 0 || dst_block_size == 0) {
127 size_t src_block_size_bytes = itemsize * src_block_size;
128 size_t dst_block_size_bytes = itemsize * dst_block_size;
130 char* src_offset_bytes = src_bytes + itemsize * src_offset;
131 char* dst_offset_bytes = dst_bytes;
// NOTE(review): `i` is int while num_blocks is size_t (signed/unsigned
// comparison); a size_t counter would be cleaner.
132 for (
int i = 0; i < num_blocks; ++i) {
133 char* local_src_offset_bytes =
134 src_offset_bytes + i * src_block_size_bytes;
135 char* local_dst_offset_bytes =
136 dst_offset_bytes + i * dst_block_size_bytes;
// Bounds checks: the end of each source and destination range must stay within
// its buffer (the check macro is on an elided line).
138 static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
139 static_cast<void*>(src_bytes + src_nbytes));
141 static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
142 static_cast<void*>(dst_bytes + dst_nbytes));
// Copy one block. The meta/item-count arguments are on elided lines.
143 context->template CopyItems<Context, Context>(
146 (
void*)local_src_offset_bytes,
147 (
void*)local_dst_offset_bytes);
// Backward copy: every block of go is copied into gdata at offset
// starts_idx[dim] * unit within the corresponding data-shaped block.
150 char* src_bytes = (
char*)go->raw_data();
151 char* dst_bytes = (
char*)gdata->raw_mutable_data(go->meta());
153 size_t src_nbytes = go->nbytes();
154 size_t dst_nbytes = gdata->nbytes();
156 size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
157 size_t dst_block_size = unit * data.dims()[dim];
158 size_t dst_offset = unit * starts_idx[dim];
// Nothing to copy. The body of this early exit is elided.
160 if (num_blocks == 0 || dst_block_size == 0) {
164 size_t src_block_size_bytes = itemsize * src_block_size;
165 size_t dst_block_size_bytes = itemsize * dst_block_size;
167 char* src_offset_bytes = src_bytes;
168 char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
// Zero-fill the whole gradient so elements outside the slice receive a zero
// gradient.
171 math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);
// NOTE(review): same int vs size_t loop-counter comparison as the forward loop.
178 for (
int i = 0; i < num_blocks; ++i) {
179 char* local_src_offset_bytes =
180 src_offset_bytes + i * src_block_size_bytes;
181 char* local_dst_offset_bytes =
182 dst_offset_bytes + i * dst_block_size_bytes;
// Bounds checks: each copied range (src_block_size_bytes long) must stay
// within its buffer (the check macro is on an elided line).
184 local_src_offset_bytes + src_block_size_bytes,
185 src_bytes + src_nbytes);
187 local_dst_offset_bytes + src_block_size_bytes,
188 dst_bytes + dst_nbytes);
// Copy one block of go into gdata. The meta/item-count arguments are on
// elided lines.
189 context->template CopyItems<Context, Context>(
192 (
void*)local_src_offset_bytes,
193 (
void*)local_dst_offset_bytes);
// SliceOp: forward Slice operator; delegates to SliceImpl in forward mode.
// Slice bounds come from one of two places:
//  - inputs: Input(1) = starts and Input(2) = ends, when more than one input
//    is given;
//  - the repeated "starts"/"ends" arguments, otherwise (the `else` is on an
//    elided line). These are copied into starts_host_/ends_host_ once and
//    cached via statically_inited_.
// NOTE(review): excerpt. The class header, the constructor signature and the
// starts_host_/ends_host_ member declarations are on elided lines.
201 template <class SIndex, class Context>
204 USE_OPERATOR_CONTEXT_FUNCTIONS;
// The argument-provided bounds are read at construction. They are copied into
// the index tensors lazily, on the first run that uses them.
207 starts_(OperatorBase::GetRepeatedArgument<SIndex>(
"starts")),
208 ends_(OperatorBase::GetRepeatedArgument<SIndex>(
"ends")),
209 statically_inited_(
false) {}
211 bool RunOnDevice()
override {
212 auto* output = Output(0);
213 auto& data = Input(0);
// Dynamic bounds: copy starts/ends from the input tensors into
// starts_host_/ends_host_ (declared on elided lines; presumably CPU tensors).
// NOTE(review): with exactly two inputs this would read Input(2) out of range;
// presumably the operator schema forbids that. Confirm.
215 if (InputSize() > 1) {
216 starts_host_.template CopyFrom<Context>(Input(1));
217 ends_host_.template CopyFrom<Context>(Input(2));
// Static bounds: validate the arguments, then copy them into the index tensors
// once. statically_inited_ skips this on later runs.
219 if (!statically_inited_) {
220 CAFFE_ENFORCE(HasArgument(
"starts"));
221 CAFFE_ENFORCE(HasArgument(
"ends"));
222 CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
224 starts_host_.Resize(starts_.size());
225 ends_host_.Resize(ends_.size());
// Byte copies of the argument vectors; the copy call itself is on elided
// lines.
228 starts_host_.template mutable_data<SIndex>(),
230 sizeof(SIndex) * starts_.size());
232 ends_host_.template mutable_data<SIndex>(),
234 sizeof(SIndex) * ends_.size());
235 statically_inited_ =
true;
// A non-null output selects forward mode in SliceImpl.
239 return SliceImpl<SIndex, Context>(
240 output, data, starts_host_, ends_host_, &context_);
243 DISABLE_COPY_AND_ASSIGN(
SliceOp);
// Bounds from the "starts"/"ends" arguments, and whether they have already
// been copied into the index tensors.
246 std::vector<SIndex> starts_;
247 std::vector<SIndex> ends_;
248 bool statically_inited_;
// SliceGradientOp: gradient of Slice. Output(0) is the gradient with respect
// to the data input (Input(0)). It is produced by SliceImpl in backward mode
// (null output pointer): the incoming gradient `go` is copied into a
// zero-filled tensor shaped like data.
// NOTE(review): excerpt. The declaration of `go` is on elided lines; it is
// presumably the last input (Input(3) with four inputs, otherwise Input(1)).
// Confirm.
253 template <
class SIndex,
class Context>
256 USE_OPERATOR_CONTEXT_FUNCTIONS;
// The argument-provided bounds are read at construction. They are copied into
// the index tensors lazily, on the first run that uses them.
259 starts_(OperatorBase::GetRepeatedArgument<SIndex>(
"starts")),
260 ends_(OperatorBase::GetRepeatedArgument<SIndex>(
"ends")),
261 statically_inited_(
false) {}
263 bool RunOnDevice()
override {
264 auto* gdata = Output(0);
265 auto& data = Input(0);
// Four inputs: the bounds come from Input(1) (starts) and Input(2) (ends).
267 if (InputSize() == 4) {
268 starts_host_.template CopyFrom<Context>(Input(1));
269 ends_host_.template CopyFrom<Context>(Input(2));
// A null output pointer selects backward mode in SliceImpl.
273 return SliceImpl<SIndex, Context>(
274 nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
// Otherwise use the "starts"/"ends" arguments, copied into the index tensors
// once and then cached.
// NOTE(review): SliceOp enforces HasArgument("starts"/"ends") at this point;
// the corresponding lines here are elided. Confirm the same checks exist.
276 if (!statically_inited_) {
279 CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
281 starts_host_.
Resize(starts_.size());
282 ends_host_.
Resize(ends_.size());
// Byte copies of the argument vectors; the copy call itself is on elided
// lines.
285 starts_host_.template mutable_data<SIndex>(),
287 sizeof(SIndex) * starts_.size());
289 ends_host_.template mutable_data<SIndex>(),
291 sizeof(SIndex) * ends_.size());
293 statically_inited_ =
true;
// Backward-mode call, as above.
297 return SliceImpl<SIndex, Context>(
298 nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
// Bounds from the "starts"/"ends" arguments, and whether they have already
// been copied into the index tensors.
305 std::vector<SIndex> starts_;
306 std::vector<SIndex> ends_;
307 bool statically_inited_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.