Caffe2 - C++ API
A deep learning, cross platform ML framework
flatten_op.h
1 #ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
2 #define CAFFE2_OPERATORS_FLATTEN_OP_H_
3 
4 #include "caffe2/core/operator.h"
5 
6 namespace caffe2 {
7 
8 template <class Context>
9 class FlattenOp : public Operator<Context> {
10  public:
11  USE_OPERATOR_CONTEXT_FUNCTIONS;
12 
13  FlattenOp(const OperatorDef& operator_def, Workspace* ws)
14  : Operator<Context>(operator_def, ws),
15  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
16 
17  bool RunOnDevice() override {
18  auto& input = Input(0);
19  auto* output = Output(0);
20  CAFFE_ENFORCE_GE(
21  input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
22  output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
23  context_.template CopyItems<Context, Context>(
24  input.meta(),
25  input.size(),
26  input.raw_data(),
27  output->raw_mutable_data(input.meta()));
28  return true;
29  }
30 
31  private:
32  int axis_;
33 };
34 
35 } // namespace caffe2
36 
37 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...