1 #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" 3 #include "caffe2/core/tensor.h" 5 #if defined(CAFFE2_USE_MPSCNN) && CAFFE2_MOBILE 6 #include "caffe2/mobile/contrib/ios/mpscnn/mpscnn.h" 9 CAFFE2_DECLARE_bool(caffe2_force_shared_col_buffer);
12 const caffe2::NetDef& predict_net,
13 bool disableMultithreadProcessing,
14 bool allowMetalOperators) {
15 caffe2::NetDef metal_predict_net;
16 bool usingMetalOperators =
false;
17 #if defined(CAFFE2_USE_MPSCNN) && CAFFE2_MOBILE 18 if (allowMetalOperators) {
19 caffe2::dumpDef(predict_net);
20 if (caffe2::tryConvertToMPSCNN(init_net, predict_net, &metal_predict_net)) {
21 LOG(INFO) <<
"Successfully converted to MPSCNN";
22 caffe2::dumpDef(metal_predict_net);
23 usingMetalOperators =
true;
25 LOG(ERROR) <<
"Failed converting model to MPSCNN";
31 usingMetalOperators ? metal_predict_net : predict_net,
32 disableMultithreadProcessing,
36 Caffe2IOSPredictor::Caffe2IOSPredictor(
const caffe2::NetDef& init_net,
37 const caffe2::NetDef& predict_net,
38 bool disableMultithreadProcessing,
39 bool usingMetalOperators)
40 : usingMetalOperators(usingMetalOperators), predictor_(init_net, predict_net) {
42 if (disableMultithreadProcessing) {
44 if (threadpool !=
nullptr) {
45 threadpool->setMinWorkSize(std::numeric_limits<size_t>::max());
51 void Caffe2IOSPredictor::run(
const Tensor& inData,
Tensor& outData, std::string& errorMessage) {
52 caffe2::FLAGS_caffe2_force_shared_col_buffer =
true;
56 caffe2::Predictor::TensorVector input_vec{&input};
57 caffe2::Predictor::TensorVector output_vec;
59 predictor_.run(input_vec, &output_vec);
61 std::string error = e.msg();
62 errorMessage.swap(error);
64 }
catch (
const std::exception& e) {
65 std::string error = e.what();
66 errorMessage.swap(error);
71 outData.dims = output->
dims();
void ShareExternalPointer(T *src, size_t capacity=0, Deleter d=nullptr)
Shares the data with an externally managed pointer.
T * mutable_data()
Returns a typed pointer to the underlying storage.
const vector< TIndex > & dims() const
Returns the dimensions of the tensor as a vector.
void Resize(Ts...dim_source)
Resizes a tensor.
Command-line flag support for Caffe2.
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to operators accelerated by the Metal GPU framework. ...