2 #include "rewrite_net.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/utils/proto_utils.h" 5 #include <unordered_map> 11 using BlobVersions = std::unordered_map<std::string, size_t>;
// Blob name -> version of each blob an operator reads (filled in analyzeNet
// from the running `frontier` map).
12 BlobVersions inVersions;
// Blob name -> version of each blob the same operator writes.
13 BlobVersions outVersions;
// Blob name -> blob version -> indices of the operators that take that
// version as an input (appended to in analyzeNet).
16 std::unordered_map<std::string, std::unordered_map<size_t, std::vector<size_t>>> inUsages;
// Walks the operators of `net` in order and records, for each operator, the
// version of every input and output blob (analysis.ssa) and, for each blob
// version, which operator indices consume it (analysis.inUsages).
//
// NOTE(review): this listing omits several original lines -- the declaration
// of `analysis`, the body of the `if` on output blobs, the body of the final
// loop and the return statement. Confirm against the complete file.
19 static Analysis analyzeNet(
const NetDef& net) {
// Current version of every blob seen so far; operator[] default-inserts 0
// for a name that has not been seen yet.
20 Analysis::SSA::BlobVersions frontier;
// Records operator `i`: snapshot the versions it reads, register it as a
// consumer of those versions, then snapshot the versions it writes.
23 auto play = [&](
size_t i,
const OperatorDef& op) {
24 Analysis::SSA::BlobVersions inVersions;
25 for (
const auto& s : op.input()) {
26 inVersions[s] = frontier[s];
27 analysis.inUsages[s][frontier[s]].push_back(i);
29 Analysis::SSA::BlobVersions outVersions;
30 for (
const auto& s : op.output()) {
// Taken only for blobs already known to `frontier`; presumably the version
// is advanced here -- the body is not visible, TODO confirm.
31 if (frontier.find(s) != frontier.end()) {
34 outVersions[s] = frontier[s];
36 analysis.ssa.push_back(
Analysis::SSA{inVersions, outVersions});
// Iterates over every operator index of the net; presumably invokes `play`
// for each one (loop body not visible).
39 for (
auto i = 0; i < net.op_size(); ++i) {
45 static void insertCopyFromGLOp(NetDef& predictNet,
const std::string& cpu_blob) {
46 auto* op = predictNet.add_op();
47 op->set_name(
"CopyFromGL");
48 op->set_type(
"CopyFromGL");
49 op->add_input(cpu_blob +
"_M");
50 op->add_output(cpu_blob);
// Rebuilds the operator list of `def` into `mdef`, tracking for every blob
// version whether it was produced by a CPU operator (type listed in `cpuOp`)
// or by any other operator. Inputs that were produced by a non-CPU operator
// are either copied back with a CopyFromGL operator (when a CPU operator
// reads them) or referenced under the "_M"-suffixed name (when a non-CPU
// operator reads them).
//
// Enforces: at least one external input, one external output and one
// operator; the first external input (version 0) is consumed by operator 0
// only; the first external output is written by the last operator and that
// final version has no consumer.
//
// NOTE(review): this listing omits original lines -- the construction of
// `mdef`, the `else` separating the two branches, the flag that guards the
// output rename, and the return statement. Confirm against the full file.
53 static NetDef insertInputOutputCopyOps(
const NetDef& def, std::unordered_set<std::string>& cpuOp) {
59 CAFFE_ENFORCE_GE(def.external_input_size(), 1);
60 CAFFE_ENFORCE_GE(def.external_output_size(), 1);
61 auto analysis = analyzeNet(def);
63 CAFFE_ENFORCE_GE(def.op_size(), 1);
65 const auto& inputBlob = def.external_input(0);
// Version 0 of the network input must be used by operator 0 and nothing else.
67 CAFFE_ENFORCE(analysis.inUsages[inputBlob][0] == (std::vector<size_t>{0}));
69 const auto& outputBlob = def.external_output(0);
// The last operator must be the one that writes the network output.
70 CAFFE_ENFORCE(analysis.ssa.back().outVersions.find(outputBlob) !=
71 analysis.ssa.back().outVersions.end());
72 const auto& outputBlobVersion = analysis.ssa.back().outVersions[outputBlob];
// No operator may consume the final version of the network output.
74 CAFFE_ENFORCE(analysis.inUsages[outputBlob].find(outputBlobVersion) ==
75 analysis.inUsages[outputBlob].end());
// Blob name -> set of versions produced by CPU operators / by other operators.
81 std::unordered_map<std::string, std::set<size_t>> cpu_blobs, gpu_blobs;
// The network input (version 0) is registered as a CPU blob.
82 cpu_blobs[def.external_input(0)].insert(0);
84 for (
auto i = 0; i < def.op_size(); i++) {
85 const auto& currentOp = def.op(i);
// Operators whose type is listed in `cpuOp` take the CPU path.
86 if (cpuOp.count(currentOp.type()) > 0) {
89 for (
auto j = 0; j < currentOp.input_size(); j++) {
90 auto& input = currentOp.input(j);
91 auto version = analysis.ssa[i].inVersions[input];
// This exact version was written by a non-CPU operator: emit a CopyFromGL
// ahead of the CPU operator so the un-suffixed blob exists.
92 if (gpu_blobs[input].count(version) > 0) {
93 insertCopyFromGLOp(mdef, input);
96 auto* op = mdef.add_op();
97 op->CopyFrom(currentOp);
98 for (
auto j = 0; j < currentOp.output_size(); j++) {
99 auto& output = currentOp.output(j);
100 auto version = analysis.ssa[i].outVersions[output];
101 cpu_blobs[output].insert(version);
// NOTE(review): presumably the start of the non-CPU (else) branch; the
// `else` line itself is not visible here.
105 auto* op = mdef.add_op();
106 op->CopyFrom(currentOp);
108 for (
auto j = 0; j < op->input_size(); j++) {
109 auto* input = op->mutable_input(j);
110 auto version = analysis.ssa[i].inVersions[*input];
// Inputs written by an earlier non-CPU operator are read under the "_M" name.
111 if (gpu_blobs[*input].count(version) > 0) {
112 *input = *input +
"_M";
116 for (
auto j = 0; j < currentOp.output_size(); j++) {
117 auto& output = currentOp.output(j);
118 auto version = analysis.ssa[i].outVersions[output];
119 gpu_blobs[output].insert(version);
121 auto* output_ = op->mutable_output(j);
// Compare this output against every external output of the net.
123 for(
auto k = 0; k < def.external_output_size(); k++) {
124 if (*output_ == def.external_output(k)) {
// Presumably applied only when the output matched no external output (the
// guarding condition is not visible) -- TODO confirm.
129 *output_ = *output_ +
"_M";
137 static bool tryFuseAdjacentOps(
const OperatorDef& currentOp,
138 const OperatorDef& nextOp,
139 OperatorDef* fusedOp,
140 std::unordered_set<std::string>& glOps) {
142 if (currentOp.output_size() != 1 || nextOp.output_size() != 1) {
146 if (currentOp.output(0) != nextOp.input(0) || currentOp.input(0) == nextOp.output(0)) {
150 static const std::map<std::pair<std::string, std::string>, std::string> fusionOpportunities = {
151 {{
"OpenGLInstanceNorm",
"OpenGLPRelu"},
"OpenGLInstanceNormPRelu"},
152 {{
"OpenGLConv",
"OpenGLPRelu"},
"OpenGLConvPRelu"},
153 {{
"OpenGLConv",
"OpenGLRelu"},
"OpenGLConvRelu"},
154 {{
"OpenGLConvTranspose",
"OpenGLPRelu"},
"OpenGLConvTransposePRelu"}};
155 auto it = fusionOpportunities.find({currentOp.type(), nextOp.type()});
156 if (it == fusionOpportunities.end()) {
160 glOps.insert(it->second);
161 fusedOp->CopyFrom(currentOp);
162 fusedOp->set_output(0, nextOp.output(0));
163 fusedOp->set_type(it->second);
164 for (
auto i = 1; i < nextOp.input_size(); i++) {
165 fusedOp->add_input(nextOp.input(i));
// Scans the operators of `def` pairwise and emits into `mdef` either the
// fused operator produced by tryFuseAdjacentOps or an unmodified copy of the
// current operator. The last operator has no successor and is always copied.
// Requires at least one operator (CHECK_GE).
//
// NOTE(review): this listing omits original lines -- the declarations of
// `mdef`, `i` and `fusedOp`, the statements that advance `i`, and the return.
// Presumably `i` advances by two after a fusion and by one otherwise; confirm
// against the full file.
170 static NetDef runOpenGLFusion(
const NetDef& def, std::unordered_set<std::string>& glOps) {
171 CHECK_GE(def.op_size(), 1);
177 while (i < def.op_size()) {
// The final operator cannot be paired with a successor: copy it through.
178 if (i == def.op_size() - 1) {
179 VLOG(2) <<
"Last operator, skipping";
180 auto* op = mdef.add_op();
181 op->CopyFrom(def.op(i));
186 const auto& currentOp = def.op(i);
187 const auto& nextOp = def.op(i + 1);
// Fusion succeeded: emit the single fused operator built in `fusedOp`.
189 if (tryFuseAdjacentOps(currentOp, nextOp, &fusedOp, glOps)) {
190 VLOG(2) <<
"Found an adjacent fusion for: " << currentOp.type() <<
", " << nextOp.type();
192 auto* op = mdef.add_op();
193 op->CopyFrom(fusedOp);
// No fusion for this pair: emit the current operator unchanged.
197 VLOG(2) <<
"No fusion available for: " << currentOp.type() <<
", " << nextOp.type();
199 auto* op = mdef.add_op();
200 op->CopyFrom(currentOp);
206 void dumpDefForOpenGL(
const NetDef& d) {
207 for (
const auto& op : d.op()) {
208 LOG(INFO) << op.input(0) <<
" -> " << op.type() <<
" -> " << op.output(0);
// Builds the OpenGL variant of `predictNet`: copies the net, runs
// insertInputOutputCopyOps on the copy with `cpuOps`, sets the net type to
// "opengl", and sets the device type of every operator whose type is not in
// `cpuOps` to OPENGL. Enforces that `predictNet` has at least one operator.
//
// NOTE(review): `runFusion` is not referenced in the lines visible here, and
// the declaration of `net`, original lines 230-234 and the return statement
// are missing from this listing -- confirm against the full file.
226 NetDef rewritePredictNetForOpenGL(
const NetDef& predictNet,
bool runFusion, std::unordered_set<std::string> cpuOps) {
227 CAFFE_ENFORCE_GE(predictNet.op_size(), 1);
229 net.CopyFrom(predictNet);
235 net = insertInputOutputCopyOps(net, cpuOps);
236 net.set_type(
"opengl");
238 for (
auto i = 0; i < net.op().size(); ++i) {
239 auto op = net.mutable_op(i);
// NOTE(review): std::find walks the unordered_set linearly;
// cpuOps.count(op->type()) == 0 would be the direct hash lookup.
240 if (std::find(cpuOps.begin(), cpuOps.end(), op->type()) == cpuOps.end()) {
241 op->mutable_device_option()->set_device_type(OPENGL);
// Rewrites `predictNet` for OpenGL into `*glPredictNet` via
// rewritePredictNetForOpenGL, dumps the result with dumpDefForOpenGL and logs
// success. A std::exception raised along the way is caught and logged as an
// error instead of propagating.
//
// NOTE(review): the declaration of the `runFusion` parameter, the opening
// `try` and the return statements are missing from this listing; presumably
// the function returns true on success and false from the catch -- confirm.
248 bool tryConvertToOpenGL(
const NetDef& predictNet,
249 NetDef* glPredictNet,
251 std::unordered_set<std::string> cpuOps) {
254 *glPredictNet = rewritePredictNetForOpenGL(predictNet, runFusion, cpuOps);
255 dumpDefForOpenGL(*glPredictNet);
257 LOG(INFO) <<
"OpenGL is successfully enabled";
259 }
catch (
const std::exception& e) {
260 LOG(ERROR) <<
"Caught exception trying to convert NetDef to OpenGL: " << e.what();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...