2 #include "rewrite_net.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/utils/proto_utils.h" 5 #include <unordered_map> 6 #include <unordered_set> 9 #include "../android/AndroidGLContext.h" 16 using BlobVersions = std::unordered_map<std::string, size_t>;
// Blob name -> version of that blob read by one operator (filled per operator
// in analyzeNet and stored in each Analysis::SSA entry).
17 BlobVersions inVersions;
// Blob name -> version of that blob written by the same operator.
18 BlobVersions outVersions;
// Blob name -> blob version -> indices of the operators that read that
// version (analyzeNet appends the operator index for every input it sees).
21 std::unordered_map<std::string, std::unordered_map<size_t, std::vector<size_t>>> inUsages;
// Walks the operators of `net` and records, per operator, which version of
// each blob it reads and writes (one Analysis::SSA entry pushed onto
// analysis.ssa per call of `play`), plus which operator indices consume each
// blob version (analysis.inUsages).
// NOTE(review): this view skips embedded source lines 26-27, 33, 37-38, 40,
// 42-43 and 45-49. The declaration of `analysis`, the body of the `if` on
// frontier membership, the body of the final loop (presumably the call to
// `play`) and the return statement are therefore not visible here.
24 static Analysis analyzeNet(
const NetDef& net) {
// Current version of every blob seen so far. Reads through operator[] below
// default-insert 0 for a blob that has no entry yet.
25 Analysis::SSA::BlobVersions frontier;
// Processes operator `op` located at index `i`; captures `frontier` and
// `analysis` by reference and mutates both.
28 auto play = [&](
size_t i,
const OperatorDef& op) {
29 Analysis::SSA::BlobVersions inVersions;
30 for (
const auto& s : op.input()) {
// An input reads whatever version of the blob is current, and this operator
// is registered as a consumer of exactly that version.
31 inVersions[s] = frontier[s];
32 analysis.inUsages[s][frontier[s]].push_back(i);
34 Analysis::SSA::BlobVersions outVersions;
35 for (
const auto& s : op.output()) {
// Taken when the blob already has a frontier entry; the statement(s) inside
// this branch are not visible in this view.
36 if (frontier.find(s) != frontier.end()) {
39 outVersions[s] = frontier[s];
41 analysis.ssa.push_back(Analysis::SSA{inVersions, outVersions});
// Iterates over every operator index of the net (loop body not visible).
44 for (
auto i = 0; i < net.op_size(); ++i) {
50 static void insertCopyToGPUOp(NetDef& predictNet,
const std::string& cpu_blob) {
51 auto* op = predictNet.add_op();
52 op->set_name(
"CopyToOpenGL");
53 op->set_type(
"CopyToOpenGL");
54 op->add_input(cpu_blob);
55 op->add_output(cpu_blob +
"_M");
// Appends a "CopyFromOpenGL" operator to `predictNet` that reads the blob
// named `cpu_blob` + "_M" and writes `cpu_blob`. Before appending, it adds an
// argument named "is_last" to the operator that is currently last in the net.
// The last operator is fetched with index op_size() - 1 and no bounds check,
// so `predictNet` must already contain at least one operator.
// NOTE(review): embedded source lines 59-60 and 64-65 are not visible in this
// view, so the value assigned to the "is_last" argument cannot be seen here.
58 static void insertCopyFromGPUOp(NetDef& predictNet,
const std::string& cpu_blob) {
// Tag the operator that precedes the copy-back.
61 auto* last_op = predictNet.mutable_op(predictNet.op_size() - 1);
62 auto* arg = last_op->add_arg();
63 arg->set_name(
"is_last");
// The copy-back operator itself: "<blob>_M" -> "<blob>".
66 auto* op = predictNet.add_op();
67 op->set_name(
"CopyFromOpenGL");
68 op->set_type(
"CopyFromOpenGL");
69 op->add_input(cpu_blob +
"_M");
70 op->add_output(cpu_blob);
// Builds a new net from `def` in which blobs are moved between the CPU side
// and the OpenGL side where needed. Operators whose type is contained in
// `glOps` get their GPU-resident inputs and all of their outputs renamed with
// an "_M" suffix, and CopyToOpenGL / CopyFromOpenGL operators are inserted
// (via insertCopyToGPUOp / insertCopyFromGPUOp) where a blob crosses over.
//
// Preconditions enforced below (CAFFE_ENFORCE*):
//   - at least one external input, one external output and one operator;
//   - version 0 of the first external input is read by operator 0 only;
//   - the first external output is written by the last operator, and that
//     version is never read by any operator.
//
// NOTE(review): this view skips many embedded source lines (74-78, 82, 84,
// 86, 88, 93, 96-100, 103, 107-108, 116-119, 121-126, 129-130, 136-139,
// 145-146, 149-152, 158-159, 166-171). The declaration of `mdef`, the body of
// the "OpenGLConv" branch, the header of the non-OpenGL branch and the return
// statement are not visible here.
// NOTE(review): std::set is used below, but <set> is not among the includes
// visible at the top of this view - confirm it is included.
73 static NetDef insertInputOutputCopyOps(
const NetDef& def, std::unordered_set<std::string>& glOps) {
79 CAFFE_ENFORCE_GE(def.external_input_size(), 1);
80 CAFFE_ENFORCE_GE(def.external_output_size(), 1);
81 auto analysis = analyzeNet(def);
83 CAFFE_ENFORCE_GE(def.op_size(), 1);
85 const auto& inputBlob = def.external_input(0);
// The only consumer of version 0 of the input blob must be operator 0.
87 CAFFE_ENFORCE(analysis.inUsages[inputBlob][0] == (std::vector<size_t>{0}));
89 const auto& outputBlob = def.external_output(0);
// The last operator must write the output blob ...
90 CAFFE_ENFORCE(analysis.ssa.back().outVersions.find(outputBlob) !=
91 analysis.ssa.back().outVersions.end());
92 const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob];
// ... and nothing may read the version it writes.
94 CAFFE_ENFORCE(analysis.inUsages[outputBlob].find(outputBlobVersion) ==
95 analysis.inUsages[outputBlob].end());
// Blob name -> set of blob versions currently tracked as living on the CPU
// side / on the OpenGL side.
101 std::unordered_map<std::string, std::set<size_t>> cpu_blobs, gpu_blobs;
// The external input starts out on the CPU side at version 0.
102 cpu_blobs[def.external_input(0)].insert(0);
104 for (
auto i = 0; i < def.op_size(); i++) {
105 const auto& currentOp = def.op(i);
// Branch for operators that are OpenGL operators.
106 if (glOps.count(currentOp.type()) > 0) {
109 for (
auto j = 0; j < currentOp.input_size(); j++) {
110 auto& input = currentOp.input(j);
111 auto version = analysis.ssa[i].inVersions[input];
// The input version is on the CPU side: emit a CopyToOpenGL for it and move
// the version from the CPU set to the GPU set.
112 if (cpu_blobs[input].count(version) > 0) {
113 insertCopyToGPUOp(mdef, input);
114 gpu_blobs[input].insert(version);
115 cpu_blobs[input].erase(version);
// find(...) == 0 holds when the operator type starts with "OpenGLConv".
// The body of this branch is not visible in this view.
120 if (currentOp.type().find(
"OpenGLConv") == 0) {
// Copy the operator into the output net, then patch its blob names.
127 auto* op = mdef.add_op();
128 op->CopyFrom(currentOp);
131 for (
auto j = 0; j < currentOp.input_size(); j++) {
132 auto& input = currentOp.input(j);
133 auto version = analysis.ssa[i].inVersions[input];
// Inputs whose version is tracked on the GPU side are read under the "_M"
// name.
134 if (gpu_blobs[input].count(version) > 0) {
135 op->set_input(j, input +
"_M");
// Every output of an OpenGL operator is renamed to "<name>_M" and its version
// is tracked on the GPU side.
140 for (
auto j = 0; j < currentOp.output_size(); j++) {
141 auto& output = currentOp.output(j);
142 auto version = analysis.ssa[i].outVersions[output];
143 op->set_output(j, output +
"_M");
144 gpu_blobs[output].insert(version);
// For the final operator of the net, copy its first output back from OpenGL.
147 if (i == def.op_size() - 1) {
148 insertCopyFromGPUOp(mdef, currentOp.output(0));
// NOTE(review): embedded lines 149-152 are not visible; the code from here on
// appears to be the branch for operators that are NOT in `glOps` - confirm.
153 for (
auto j = 0; j < currentOp.input_size(); j++) {
154 auto& input = currentOp.input(j);
155 auto version = analysis.ssa[i].inVersions[input];
// The input version is on the GPU side: emit a CopyFromOpenGL for it.
156 if (gpu_blobs[input].count(version) > 0) {
157 insertCopyFromGPUOp(mdef, input);
// The operator itself is copied unchanged.
160 auto* op = mdef.add_op();
161 op->CopyFrom(currentOp);
// Its outputs are tracked on the CPU side.
162 for (
auto j = 0; j < currentOp.output_size(); j++) {
163 auto& output = currentOp.output(j);
164 auto version = analysis.ssa[i].outVersions[output];
165 cpu_blobs[output].insert(version);
// Attempts to fuse two adjacent operators into one, writing the result to
// `*fusedOp`.
//
// Conditions tested below:
//   - both operators have exactly one output;
//   - currentOp's output 0 is nextOp's input 0, and currentOp's input 0 is
//     not nextOp's output 0;
//   - the pair (currentOp.type(), nextOp.type()) has an entry in the
//     `fusionOpportunities` table.
// When the table lookup succeeds, the fused type is inserted into `glOps` and
// `*fusedOp` becomes a copy of `currentOp` with its output 0 replaced by
// nextOp's output 0, its type replaced by the fused type, and nextOp's inputs
// from index 1 onward appended.
//
// NOTE(review): embedded source lines 176, 178-180, 182-184, 192-194 and
// 201-204 are not visible in this view, so none of the return statements can
// be seen. Presumably the guard branches return false and the end of the
// function returns true - confirm.
// NOTE(review): input(0) is read on both operators without a visible check
// that either has at least one input.
// NOTE(review): std::map is used below, but <map> is not among the includes
// visible at the top of this view - confirm it is included.
172 static bool tryFuseAdjacentOps(
const OperatorDef& currentOp,
173 const OperatorDef& nextOp,
174 OperatorDef* fusedOp,
175 std::unordered_set<std::string>& glOps) {
177 if (currentOp.output_size() != 1 || nextOp.output_size() != 1) {
181 if (currentOp.output(0) != nextOp.input(0) || currentOp.input(0) == nextOp.output(0)) {
// (first operator type, second operator type) -> fused operator type.
185 static const std::map<std::pair<std::string, std::string>, std::string> fusionOpportunities = {
186 {{
"OpenGLInstanceNorm",
"OpenGLPRelu"},
"OpenGLInstanceNormPRelu"},
187 {{
"OpenGLConv",
"OpenGLPRelu"},
"OpenGLConvPRelu"},
188 {{
"OpenGLConv",
"OpenGLRelu"},
"OpenGLConvRelu"},
189 {{
"OpenGLConvTranspose",
"OpenGLPRelu"},
"OpenGLConvTransposePRelu"}};
190 auto it = fusionOpportunities.find({currentOp.type(), nextOp.type()});
191 if (it == fusionOpportunities.end()) {
// Register the fused type as an OpenGL operator type for the caller's set.
195 glOps.insert(it->second);
196 fusedOp->CopyFrom(currentOp);
197 fusedOp->set_output(0, nextOp.output(0));
198 fusedOp->set_type(it->second);
// nextOp's input 0 is currentOp's output, so only inputs from index 1 onward
// are carried over to the fused operator.
199 for (
auto i = 1; i < nextOp.input_size(); i++) {
200 fusedOp->add_input(nextOp.input(i));
// Produces a net from `def` in which adjacent operator pairs accepted by
// tryFuseAdjacentOps are replaced by the fused operator, and all other
// operators are copied unchanged. `glOps` is forwarded to tryFuseAdjacentOps,
// which adds the fused operator types to it.
// NOTE(review): embedded source lines 207-211, 217-220, 223, 226, 229-231,
// 233 and 236-240 are not visible in this view. The declarations of `mdef`,
// `i` and `fusedOp`, the statements that advance `i`, and the return
// statement cannot be seen here.
// NOTE(review): this function uses CHECK_GE while the other functions in this
// file use CAFFE_ENFORCE_GE for the same kind of precondition.
205 static NetDef runOpenGLFusion(
const NetDef& def, std::unordered_set<std::string>& glOps) {
206 CHECK_GE(def.op_size(), 1);
212 while (i < def.op_size()) {
// The last operator has no successor to fuse with; copy it as is.
213 if (i == def.op_size() - 1) {
214 VLOG(2) <<
"Last operator, skipping";
215 auto* op = mdef.add_op();
216 op->CopyFrom(def.op(i));
221 const auto& currentOp = def.op(i);
222 const auto& nextOp = def.op(i + 1);
// Fusion succeeded: emit the fused operator.
224 if (tryFuseAdjacentOps(currentOp, nextOp, &fusedOp, glOps)) {
225 VLOG(2) <<
"Found an adjacent fusion for: " << currentOp.type() <<
", " << nextOp.type();
227 auto* op = mdef.add_op();
228 op->CopyFrom(fusedOp);
// No fusion: emit the current operator unchanged.
232 VLOG(2) <<
"No fusion available for: " << currentOp.type() <<
", " << nextOp.type();
234 auto* op = mdef.add_op();
235 op->CopyFrom(currentOp);
241 void dumpDefForOpenGL(
const NetDef& d) {
242 for (
const auto& op : d.op()) {
243 LOG(INFO) << op.input(0) <<
" -> " << op.type() <<
" -> " << op.output(0);
// Rewrites `predictNet` so that operators with a registered "OpenGL"-prefixed
// counterpart use that counterpart, and returns the rewritten net.
//
// Steps visible below:
//   1. Enforce that the net has at least one operator and copy it into `net`.
//   2. For each operator, form the candidate name "OpenGL" + type, substitute
//      it through `replacements` (whose targets depend on `useTextureInput`),
//      and, if the candidate is among the registry keys, switch the operator
//      to it and record the type in `openGLOps`.
//   3. Throw "OpenGL operator missing" when `useTextureInput` and
//      `needCopyOps` are both true.
//   4. Run runOpenGLFusion on the net.
//   5. Append a CopyFromOpenGL operator after the last operator when that
//      operator is not the deprocess replacement, and call
//      insertInputOutputCopyOps.
//
// NOTE(review): this view skips many embedded source lines (263, 265, 273,
// 276, 279, 281-282, 286-287, 293-294, 298-299, 302-308, 311-313, 315-316,
// 318, 328, 330-331, 333-339, 341-345). The declaration of `net`, how
// `context` is obtained, the value given to the "tiling" argument, where
// `needCopyOps` is set to true, how `useTiling` and `runFusion` gate the
// steps above, the body of the preprocess check, and the return statement are
// not visible here.
261 NetDef rewritePredictNetForOpenGL(
const NetDef& predictNet,
bool useTextureInput,
bool useTiling,
bool runFusion) {
262 CAFFE_ENFORCE_GE(predictNet.op_size(), 1);
264 net.CopyFrom(predictNet);
// Candidate OpenGL type -> type to use instead; the target depends on whether
// the input is supplied as a texture.
266 std::unordered_map<std::string, std::string> replacements(
267 {{
"OpenGLPackedInt8BGRANHWCToNCHWCStylizerPreprocess",
268 useTextureInput ?
"OpenGLTextureToTextureStylizerPreprocess" 269 :
"OpenGLTensorToTextureStylizerPreprocess"},
270 {
"OpenGLBRGNCHWCToPackedInt8BGRAStylizerDeprocess",
271 useTextureInput ?
"OpenGLTextureToTextureStylizerDeprocess" 272 :
"OpenGLTextureToTensorStylizerDeprocess"}});
// Types of the operators that were switched to an OpenGL implementation.
274 std::unordered_set<std::string> openGLOps;
275 bool needCopyOps =
false;
// Set of all keys of the CPU operator registry, used as the list of available
// operator types.
277 const auto& opKeyList = CPUOperatorRegistry()->Keys();
278 auto opKeySet = std::set<std::string>(opKeyList.begin(), opKeyList.end());
// Android builds only: when the platform is Mali, the two instance-norm
// operator types are removed from the available set, so they are never
// selected in the loop below.
280 #ifdef CAFFE2_ANDROID 283 if (context->get_platform() == Mali) {
284 opKeySet.erase(
"OpenGLInstanceNorm");
285 opKeySet.erase(
"OpenGLInstanceNormPRelu");
288 for (
auto i = 0; i < net.op_size(); ++i) {
289 auto* op = net.mutable_op(i);
290 string openGLOp = std::string(
"OpenGL") + op->type();
291 if (replacements.count(openGLOp) > 0) {
292 openGLOp = replacements[openGLOp];
// The candidate type is available: switch the operator to it.
295 if (opKeySet.find(openGLOp) != opKeySet.end()) {
296 op->set_type(openGLOp);
297 openGLOps.insert(openGLOp);
// Adds an argument named "tiling" to the operator; its value is not visible
// in this view.
300 auto* arg = op->add_arg();
301 arg->set_name(
"tiling");
309 if (useTextureInput && needCopyOps) {
310 CAFFE_THROW(
"OpenGL operator missing");
314 net = runOpenGLFusion(net, openGLOps);
// True when the first operator is the preprocess replacement type (body not
// visible in this view).
317 if (net.op(0).type() == replacements[
"OpenGLPackedInt8BGRANHWCToNCHWCStylizerPreprocess"]) {
// The last operator is not the deprocess replacement: rename its first output
// and append a CopyFromOpenGL operator that reads the renamed blob and writes
// the net's first external output.
// NOTE(review): the suffix appended here is "M", whereas the copy helpers in
// this file use "_M" - confirm this difference is intended.
319 if (net.op(net.op_size() - 1).type() !=
320 replacements[
"OpenGLBRGNCHWCToPackedInt8BGRAStylizerDeprocess"]) {
321 auto* last_op = net.mutable_op(net.op_size() - 1);
322 auto output = last_op->output(0) +
"M";
323 last_op->set_output(0, output);
324 auto* copy_op = net.add_op();
325 copy_op->set_name(
"CopyFromOpenGL");
326 copy_op->set_type(
"CopyFromOpenGL");
327 copy_op->add_input(output);
329 copy_op->add_output(net.external_output(0));
// NOTE(review): embedded lines 333-339 are not visible, so it cannot be seen
// from here whether the call below is inside this branch.
332 if (!useTextureInput) {
340 net = insertInputOutputCopyOps(net, openGLOps);
// Rewrites `predictNet` for OpenGL via rewritePredictNetForOpenGL, stores the
// result in `*glPredictNet` and logs it with dumpDefForOpenGL. It then runs
// `initNet` once and creates the rewritten net in the workspace `ws`. A
// std::exception is caught and reported through LOG(ERROR).
// NOTE(review): embedded source lines 350-353, 356-357, 361 and everything
// after 363 are not visible in this view. The remaining parameters
// (presumably `useTiling` and `runFusion`, which are passed on below), the
// opening of the try block, the declaration of `ws` and both return
// statements cannot be seen here.
346 bool tryConvertToOpenGL(
const NetDef& initNet,
347 const NetDef& predictNet,
348 NetDef* glPredictNet,
349 bool useTextureInput,
354 *glPredictNet = rewritePredictNetForOpenGL(predictNet, useTextureInput, useTiling, runFusion);
355 dumpDefForOpenGL(*glPredictNet);
// Run the init net once and create the rewritten predict net in `ws`.
358 ws.RunNetOnce(initNet);
359 ws.CreateNet(*glPredictNet);
360 LOG(INFO) <<
"OpenGL is successfully enabled";
362 }
catch (
const std::exception& e) {
363 LOG(ERROR) <<
"Caught exception trying to convert NetDef to OpenGL: " << e.what();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...