// Include guard, project includes, and the opening of the template parameter
// list for the fallback operator wrapper.
// NOTE(review): the bare integers embedded in this listing (1, 2, 4 ... 39,
// 45, 46 ...) look like line numbers carried over from a rendered source view
// rather than program text, and the gaps in that numbering suggest that some
// original lines (class header, constructor signature, closing braces, a few
// statements) are absent here -- confirm against the original header.
1 #ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ 2 #define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/context_gpu.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/proto/caffe2.pb.h" 39 template <
// CPUOp: the operator type that is instantiated below with `new CPUOp(...)`
// and stored in base_op_.
class CPUOp,
// SkipOutputCopy: queried through SkipOutputCopy::Contains(i) in the output
// loop of RunOnDevice(); defaults to SkipIndices<> (declared elsewhere).
typename SkipOutputCopy = SkipIndices<>>
// --- Constructor body (signature not visible in this listing) ---
// The incoming operator definition must be assigned to a CUDA device.
45 CAFFE_ENFORCE_EQ(def.device_option().device_type(), CUDA);
// Work on a copy of the definition so the wrapped operator can be retargeted.
46 OperatorDef base_def_(def);
// Drop the original device option and mark the copied definition as CPU.
48 base_def_.clear_device_option();
49 base_def_.mutable_device_option()->set_device_type(CPU);
// Create one blob per input name in the local workspace and remember the
// pointers, in the same order as def.input().
51 for (
const string& name : def.input()) {
52 local_input_blobs_.push_back(local_ws_.
CreateBlob(name));
53 CHECK_NOTNULL(local_input_blobs_.back());
// Construct the wrapped operator from the CPU definition, bound to the local
// workspace; base_op_ (a std::unique_ptr) takes ownership.
55 base_op_.reset(
new CPUOp(base_def_, &local_ws_));
// Look up (not create) one blob per output name in the local workspace. This
// runs after the wrapped operator has been constructed.
// NOTE(review): presumably the wrapped operator's construction is what makes
// these blobs exist in local_ws_; CHECK_NOTNULL guards that -- confirm.
56 for (
const string& name : def.output()) {
57 local_output_blobs_.push_back(local_ws_.
GetBlob(name));
58 CHECK_NOTNULL(local_output_blobs_.back());
// --- RunOnDevice ---
// Stages the inputs into the local blobs, runs the wrapped operator, then
// handles each output that is not listed in SkipOutputCopy.
62 bool RunOnDevice()
override {
// Starts out false. NOTE(review): no assignment to need_sync is visible in
// this listing; it is presumably set in the TensorCUDA branch -- confirm.
63 bool need_sync =
false;
64 for (
int i = 0; i < InputSize(); ++i) {
// Input i holds a TensorCUDA: copy into the TensorCPU held by the matching
// local input blob. The CopyFrom arguments are not visible in this listing.
65 if (OperatorBase::InputIsType<TensorCUDA>(i)) {
66 local_input_blobs_[i]->template GetMutable<TensorCPU>()->CopyFrom(
// Any other input type: log at verbosity 1 and do not copy.
70 VLOG(1) <<
"Input " << i <<
" is not TensorCUDA. Skipping copy.";
// Instead, hand the input's raw pointer and type meta to the local blob via
// ShareExternal. The const_cast strips constness from the GetRaw() result.
74 local_input_blobs_[i]->ShareExternal(
75 const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
76 OperatorBase::Inputs()[i]->meta());
// Ask the context to finish outstanding device work before the wrapped
// operator runs. NOTE(review): any guarding condition (likely need_sync) is
// not visible in this listing -- confirm.
82 context_.FinishDeviceComputation();
// Run the wrapped operator; on failure, log the operator definition.
// NOTE(review): the statement following the log (presumably an early return)
// is not visible in this listing.
85 if (!base_op_->Run()) {
86 LOG(ERROR) <<
"Base op run failed in GPUFallbackOp. Def: " 87 << ProtoDebugString(this->debug_def());
90 for (
int i = 0; i < OutputSize(); ++i) {
// Output indices contained in SkipOutputCopy are only logged here.
91 if (SkipOutputCopy::Contains(i)) {
92 VLOG(1) <<
"Copy output: index " << i <<
" skipped.";
// Arguments of a check (the call head is not visible in this listing): the
// local output blob must hold a TensorCPU, with the message below on failure.
96 local_output_blobs_[i]->
template IsType<TensorCPU>(),
97 "GPU fallback op currently does not support non-TensorCPU " 98 "output type who needs copying.");
// Arguments of a call whose head is not visible in this listing: the local
// TensorCPU result together with this operator's context.
// NOTE(review): presumably a CopyFrom onto output i -- confirm.
100 local_output_blobs_[i]->
template Get<TensorCPU>(), &context_);
// Blob pointers obtained from local_ws_ in the constructor, one per input
// name and one per output name of the operator definition.
107 vector<Blob*> local_input_blobs_;
108 vector<Blob*> local_output_blobs_;
// The wrapped operator instance created in the constructor.
109 std::unique_ptr<CPUOp> base_op_;
114 #endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ Blob * CreateBlob(const string &name)
Creates a blob of the given name.
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A templated class to allow one to wrap a CPU operator as a CUDA operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...