Caffe2 - C++ API
A deep learning, cross platform ML framework
operator_fallback_gpu.h
1 #ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
2 #define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/context_gpu.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/proto/caffe2.pb.h"
9 
10 namespace caffe2 {
11 
39 template <class CPUOp, typename SkipOutputCopy = SkipIndices<>>
40 class GPUFallbackOp final : public Operator<CUDAContext> {
41  public:
42  USE_OPERATOR_FUNCTIONS(CUDAContext);
43  GPUFallbackOp(const OperatorDef& def, Workspace* ws)
44  : Operator<CUDAContext>(def, ws) {
45  CAFFE_ENFORCE_EQ(def.device_option().device_type(), CUDA);
46  OperatorDef base_def_(def);
47  // base_def_ runs on CPU, so we will set its device option to CPU.
48  base_def_.clear_device_option();
49  base_def_.mutable_device_option()->set_device_type(CPU);
50  // Set up the symbols for the local workspace.
51  for (const string& name : def.input()) {
52  local_input_blobs_.push_back(local_ws_.CreateBlob(name));
53  CHECK_NOTNULL(local_input_blobs_.back());
54  }
55  base_op_.reset(new CPUOp(base_def_, &local_ws_));
56  for (const string& name : def.output()) {
57  local_output_blobs_.push_back(local_ws_.GetBlob(name));
58  CHECK_NOTNULL(local_output_blobs_.back());
59  }
60  }
61 
62  bool RunOnDevice() override {
63  bool need_sync = false;
64  for (int i = 0; i < InputSize(); ++i) {
65  if (OperatorBase::InputIsType<TensorCUDA>(i)) {
66  local_input_blobs_[i]->template GetMutable<TensorCPU>()->CopyFrom(
67  Input(i), &context_);
68  need_sync = true;
69  } else {
70  VLOG(1) << "Input " << i << " is not TensorCUDA. Skipping copy.";
71  // Note(jiayq): This removes a const but conceptually
72  // local_input_blobs will only be used as const blob input for the
73  // base op so we are still fine.
74  local_input_blobs_[i]->ShareExternal(
75  const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
76  OperatorBase::Inputs()[i]->meta());
77  }
78  }
79 
80  // Sync to make sure copies are done.
81  if (need_sync) {
82  context_.FinishDeviceComputation();
83  }
84 
85  if (!base_op_->Run()) {
86  LOG(ERROR) << "Base op run failed in GPUFallbackOp. Def: "
87  << ProtoDebugString(this->debug_def());
88  return false;
89  }
90  for (int i = 0; i < OutputSize(); ++i) {
91  if (SkipOutputCopy::Contains(i)) {
92  VLOG(1) << "Copy output: index " << i << " skipped.";
93  continue;
94  }
95  CAFFE_ENFORCE(
96  local_output_blobs_[i]->template IsType<TensorCPU>(),
97  "GPU fallback op currently does not support non-TensorCPU "
98  "output type who needs copying.");
99  Output(i)->CopyFrom(
100  local_output_blobs_[i]->template Get<TensorCPU>(), &context_);
101  }
102  return true;
103  }
104 
105  protected:
106  Workspace local_ws_;
107  vector<Blob*> local_input_blobs_;
108  vector<Blob*> local_output_blobs_;
109  std::unique_ptr<CPUOp> base_op_;
110 };
111 
112 } // namespace caffe2
113 
114 #endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:104
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:166
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:164
A templated class to allow one to wrap a CPU operator as a CUDA operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...