1 #ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_ 2 #define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment: loads the repeated integer argument "dims" into dims_
// and normalizes it (sorted, duplicates removed, non-negative).
// NOTE(review): the constructor signature and closing braces are not visible
// in this excerpt.
15 dims_(OperatorBase::GetRepeatedArgument<int>(
"dims")) {
// Remember the size as supplied so duplicates can be detected afterwards.
16 auto originalSize = dims_.size();
// At least one dimension id must have been supplied.
17 CAFFE_ENFORCE(originalSize > 0,
"Parameter `dims` must be provided.");
// Sort first so that std::unique, which only collapses adjacent equal
// elements, removes every duplicate id.
18 std::sort(dims_.begin(), dims_.end());
19 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
// A shorter list after deduplication means an id was repeated; this is only
// reported as a warning.
20 if (dims_.size() < originalSize) {
21 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// dims_ is sorted ascending at this point, so checking the front element
// covers every entry.
23 CAFFE_ENFORCE(dims_.front() >= 0,
"Dimension ids must be non-negative.");
// Calls CopyFrom to copy Input(0) into Output(0), then reshapes the output to
// the input's dimensions with a 1 inserted at each index listed in dims_.
// NOTE(review): several lines of this method (including the opening of the
// size check and the return statement) are not visible in this excerpt.
26 bool RunOnDevice()
override {
27 auto& input = Input(0);
28 auto* output = Output(0);
29 output->CopyFrom(input, &context_);
// Start from a copy of the input's dimensions; entries are inserted below.
34 auto newDims = input.dims();
// Presumably the arguments of an enforce-style size check whose opening line
// is not shown; the message reports how many input dimensions `dims` needs.
// NOTE(review): the count mixes dims_.back() (int) with dims_.size()
// (unsigned) — confirm the arithmetic cannot wrap.
36 input.dims().size() + dims_.size(),
38 "Input needs at least ",
39 (1 + dims_.back() - dims_.size()),
40 " dimensions given `dims`.");
// Each iteration inserts a 1 at index `dim` of the growing newDims, so later
// ids refer to positions in the already-expanded shape.
41 for (
const auto dim : dims_) {
42 newDims.insert(newDims.begin() + dim, 1);
// Apply the expanded shape to the output.
44 output->Reshape(newDims);
52 template <
class Context>
55 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment: loads the repeated integer argument "dims" into dims_
// and normalizes it (sorted, duplicates removed, non-negative).
// NOTE(review): the constructor signature and closing braces are not visible
// in this excerpt.
58 dims_(OperatorBase::GetRepeatedArgument<int>(
"dims")) {
// Remember the size as supplied so duplicates can be detected afterwards.
59 auto originalSize = dims_.size();
// At least one dimension id must have been supplied.
60 CAFFE_ENFORCE(originalSize > 0,
"Parameter `dims` must be provided.");
// Sort first so that std::unique, which only collapses adjacent equal
// elements, removes every duplicate id.
62 std::sort(dims_.begin(), dims_.end());
63 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
// A shorter list after deduplication means an id was repeated; this is only
// reported as a warning.
64 if (dims_.size() < originalSize) {
65 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// dims_ is sorted ascending at this point, so checking the front element
// covers every entry.
67 CAFFE_ENFORCE(dims_.front() >= 0,
"Dimension ids must be non-negative.");
// Calls CopyFrom to copy Input(0) into Output(0), then reshapes the output to
// the dimensions returned by ComputeDims(input.dims(), dims_).
// NOTE(review): the size check and the return statement of this method are
// not visible in this excerpt.
70 bool RunOnDevice()
override {
71 auto& input = Input(0);
72 auto* output = Output(0);
73 output->CopyFrom(input, &context_);
// Fragment of an error message; the check it belongs to is not shown here.
78 "Input needs at least ",
// Derive the output shape from the input shape and the configured dims_.
82 std::vector<int> newDims = ComputeDims(input.dims(), dims_);
83 output->Reshape(newDims);
// Builds a new dimension list from `inputDims` and the dimension ids in `dims`.
// Both arguments are taken by value.
// NOTE(review): large parts of this function are not visible in this excerpt,
// including the declaration of `j`, the body of the matching branch, and the
// return statement.
87 static std::vector<int> ComputeDims(
88 std::vector<TIndex> inputDims,
89 std::vector<int> dims) {
91 std::vector<int> newDims;
// Walk every axis index of the input.
// NOTE(review): `int i` is compared against the unsigned inputDims.size().
92 for (
int i = 0; i < inputDims.size(); ++i) {
// True when axis i is the entry of `dims` currently indexed by j.
93 if (j < dims.size() && dims[j] == i) {
// Message fragment: presumably the matched axis is required to have size 1;
// the check itself is not shown.
99 " of input must be 1",
// Appends the size of axis i to the result; .at() is bounds-checked.
// NOTE(review): this stores a TIndex into a vector<int> — confirm it cannot
// narrow. Whether this line runs for every axis or only for non-matching
// ones cannot be determined from the visible lines.
106 newDims.push_back(inputDims.at(i));
118 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...