Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator_schema.cc
1 #include "caffe2/core/operator_schema.h"
2 #include "caffe2/core/logging.h"
3 
4 namespace caffe2 {
5 
6 bool OpSchema::Verify(const OperatorDef& def) const {
7  // Check the number of inputs.
8  if (def.input_size() < min_input_ || def.input_size() > max_input_) {
9  LOG(ERROR) << "Input size " << def.input_size()
10  << " not in range [min=" << min_input_ << ", max="
11  << max_input_ << "].";
12  return false;
13  }
14  if (!num_inputs_allowed_(def.input_size())) {
15  LOG(ERROR) << "Input size " << def.input_size()
16  << " not in allowed input sizes.";
17  return false;
18  }
19  // Check the number of outputs.
20  if (def.output_size() < min_output_ || def.output_size() > max_output_) {
21  LOG(ERROR) << "Output size " << def.output_size()
22  << " not in range [min=" << min_output_ << ", max="
23  << max_output_ << "].";
24  return false;
25  }
26  if (!num_outputs_allowed_(def.output_size())) {
27  LOG(ERROR) << "Output size " << def.output_size()
28  << " not in allowed output sizes.";
29  return false;
30  }
31  if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
32  LOG(ERROR) << "Combination of input size " << def.input_size()
33  << "and output size " << def.output_size() << " not in allowed.";
34  return false;
35  }
36  // If the number of outputs can be calculated, check if the number matches.
37  if (calculate_output_) {
38  int expected_nout = calculate_output_(def.input_size());
39  if (expected_nout != kCannotComputeNumOutputs &&
40  def.output_size() != expected_nout) {
41  LOG(ERROR) << "Output size " << def.output_size()
42  << " not matching expected output size, which is "
43  << expected_nout;
44  return false;
45  }
46  }
47 
48  // Check in-place settings.
49  for (int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
50  for (int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
51  // If an input is the same as an output but in-place is not opt-in
52  // either as allowed or enforced, we will fail the verification.
53  if (def.input(in_idx) == def.output(out_idx) &&
54  (!inplace_allowed_(in_idx, out_idx)
55  && !inplace_enforced_(in_idx, out_idx))) {
56  LOG(ERROR) << "Input index " << in_idx << " and output idx " << out_idx
57  << " (" << def.input(in_idx) << ")"
58  << " are set to be in-place but this is actually not "
59  << "supported by op " << def.type();
60  return false;
61  }
62  if (def.input(in_idx) != def.output(out_idx) &&
63  inplace_enforced_(in_idx, out_idx)) {
64  LOG(ERROR) << "Input index " << in_idx << " (" << def.input(in_idx) << ")"
65  << " and output idx " << out_idx
66  << " (" << def.output(in_idx) << ")"
67  << " are not in-place but should be as required by op "
68  << def.type();
69  return false;
70  }
71  }
72  }
73 
74  std::set<std::string> present_args{};
75  for (const auto& arg : def.arg()) {
76  present_args.insert(arg.name());
77  }
78 
79  for (const auto& arg : args()) {
80  if (arg.is_required() &&
81  present_args.find(arg.name()) == present_args.end()) {
82  LOG(ERROR) << "Argument '" << arg.name() << "' is required for Operator '"
83  << def.type() << "'.";
84  return false;
85  }
86  }
87 
88  // Phew. All verifications passed.
89  return true;
90 }
91 
92 OpSchema& OpSchema::NumInputs(int min, int max) {
93  min_input_ = min;
94  max_input_ = max;
95  return *this;
96 }
97 
99  return NumInputs(n, n);
100 }
101 
102 OpSchema& OpSchema::NumInputs(std::function<bool(int)> func) {
103  num_inputs_allowed_ = func;
104  return *this;
105 }
106 
107 OpSchema& OpSchema::NumInputs(set<int> allowed_input_nums) {
108  return NumInputs(
109  [allowed_input_nums](int n)->bool {
110  return allowed_input_nums.count(n);
111  });
112 }
113 
114 OpSchema& OpSchema::NumOutputs(int min, int max) {
115  min_output_ = min;
116  max_output_ = max;
117  return *this;
118 }
119 
121  return NumOutputs(n, n);
122 }
123 
124 OpSchema& OpSchema::NumOutputs(std::function<bool(int)> func) {
125  num_outputs_allowed_ = func;
126  return *this;
127 }
128 
129 OpSchema& OpSchema::NumOutputs(set<int> allowed_output_nums) {
130  return NumOutputs(
131  [allowed_output_nums](int n)->bool {
132  return allowed_output_nums.count(n);
133  });
134 }
135 
136 OpSchema& OpSchema::NumInputsOutputs(std::function<bool(int, int)> func) {
137  num_inputs_outputs_allowed_ = func;
138  return *this;
139 }
140 
141 OpSchema& OpSchema::OutputCalculator(std::function<int(int)> calc) {
142  calculate_output_ = calc;
143  return *this;
144 }
145 
147  return OutputCalculator([](int n)->int { return n; } );
148 }
149 
150 OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
151  inplace_allowed_ = inplace;
152  return *this;
153 }
154 
155 OpSchema& OpSchema::AllowInplace(set<std::pair<int, int>> inplace) {
156  return AllowInplace(
157  [inplace](int in, int out)->bool {
158  return inplace.count(std::make_pair(in, out));
159  });
160 }
161 
162 OpSchema& OpSchema::AllowOneToOneInplace() {
163  return AllowInplace([](int in, int out) { return in == out; });
164 }
165 
166 OpSchema& OpSchema::EnforceInplace(std::function<bool(int, int)> inplace) {
167  inplace_enforced_ = inplace;
168  return *this;
169 }
170 
171 OpSchema& OpSchema::EnforceInplace(set<std::pair<int, int>> inplace) {
172  return EnforceInplace(
173  [inplace](int in, int out)->bool {
174  return inplace.count(std::make_pair(in, out));
175  });
176 }
177 
178 OpSchema& OpSchema::EnforceOneToOneInplace() {
179  return EnforceInplace([](int in, int out) { return in == out; });
180 }
181 
182 OpSchema& OpSchema::Private() {
183  private_ = true;
184  return *this;
185 }
186 
187 OpSchema& OpSchema::InputsCanCrossDevices() {
188  inputs_can_cross_devices_ = true;
189  return *this;
190 }
191 
193  TensorInferenceFunctionType function) {
194  tensor_inference_function_ = function;
195  return *this;
196 }
197 
198 OpSchema& OpSchema::InheritOnnxSchema(const std::string& onnx_schema_name) {
199  onnx_schema_ = onnx_schema_name;
200  return *this;
201 }
202 
205  [](const OperatorDef&, const vector<TensorShape>& input_types) {
206  return vector<TensorShape>(input_types);
207  });
208 }
209 
210 OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(int idx) {
212  [idx](const OperatorDef&, const vector<TensorShape>& input_types) {
213  vector<TensorShape> out(1);
214  out[0] = input_types[idx];
215  return out;
216  });
217 }
218 
219 OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(int idx, int dim) {
221  [idx, dim](const OperatorDef&, const vector<TensorShape>& input_types) {
222  vector<TensorShape> out(1);
223  out[0].add_dims(input_types[idx].dims(dim));
224  out[0].set_data_type(input_types[idx].data_type());
225  return out;
226  });
227 }
228 
229 OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
231  [dt](const OperatorDef&, const vector<TensorShape>& /*input_types*/) {
232  vector<TensorShape> out(1);
233  out[0].set_data_type(dt);
234  return out;
235  });
236 }
237 
239  cost_inference_function_ =
240  caffe2::make_unique<CostInferenceFunctionType>(function);
241  return *this;
242 }
243 
244 OpSchema& OpSchema::DeviceInferenceFunction(
245  DeviceInferenceFunctionType function) {
246  device_inference_function_ = function;
247  return *this;
248 }
249 
250 OpSchema& OpSchema::SetDoc(const string& doc) {
251  doc_ = doc;
252  return *this;
253 }
254 
255 OpSchema&
256 OpSchema::Arg(const char* name, const char* description, bool required) {
257  args_.push_back(Argument(name, description, required));
258  return *this;
259 }
260 
261 #define DEFINE_STANDARG_ARG(name, str) \
262  CAFFE2_API const char* OpSchema::Arg_##name = #str; \
263  CAFFE2_API OpSchema& OpSchema::Arg##name(const char* description) { \
264  return Arg(#str, description, true); \
265  }
266 
267 DEFINE_STANDARG_ARG(IsTest, is_test)
268 
269 #undef DEFINE_STANDARG_ARG
270 
271 OpSchema& OpSchema::Input(const int n, const char* name, const char* description) {
272  if (input_desc_.size() <= n) {
273  input_desc_.resize(n + 1);
274  }
275  input_desc_[n] = std::make_pair(name, description);
276  return *this;
277 }
278 
279 OpSchema& OpSchema::Output(const int n, const char* name, const char* description) {
280  if (output_desc_.size() <= n) {
281  output_desc_.resize(n + 1);
282  }
283  output_desc_[n] = std::make_pair(name, description);
284  return *this;
285 }
286 
287 OpSchema& OpSchema::FillUsing(std::function<void(OpSchema&)> populator) {
288  if (populator) {
289  populator(*this);
290  }
291  return *this;
292 }
293 
294 int OpSchema::CalculateOutput(int num_input) const {
295  if (min_output_ == max_output_) {
296  return min_output_;
297  } else if (calculate_output_) {
298  return calculate_output_(num_input);
299  } else {
300  return kCannotComputeNumOutputs;
301  }
302 }
303 
304 std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
305  if (!schema.args().empty()) {
306  out << "Arguments:" << std::endl;
307  for (const auto& arg : schema.args()) {
308  out << " " << arg.name() << " : " << arg.description() << std::endl;
309  }
310  }
311  if (schema.max_input_ > 0) {
312  out << "Inputs:" << std::endl;
313  if (!schema.input_desc_.empty()) {
314  for (int i = 0; i < schema.input_desc_.size(); ++i) {
315  const auto& p = schema.input_desc_[i];
316  out << " " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
317  << (p.second ? p.second : "(no doc)") << std::endl;
318  }
319  } else {
320  out << " (no explicit description available)" << std::endl;
321  }
322  }
323  if (schema.max_output_ > 0) {
324  out << "Outputs:" << std::endl;
325  if (!schema.output_desc_.empty()) {
326  for (int i = 0; i < schema.output_desc_.size(); ++i) {
327  const auto& p = schema.output_desc_[i];
328  out << " " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
329  << (p.second ? p.second : "(no doc)") << std::endl;
330  }
331  } else {
332  out << " (no explicit description available)" << std::endl;
333  }
334  }
335  out << std::endl;
336  if (schema.doc()) {
337  out << schema.doc();
338  } else {
339  out << "(no documentation yet)" << std::endl;
340  }
341  out << std::endl;
342  if (schema.line_) {
343  out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
344  }
345  return out;
346 }
347 
348 CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
349  static CaffeMap<string, OpSchema> map;
350  return map;
351 }
352 
353 } // namespace caffe2
std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
OpSchema & NumInputs(int n)
A single input.
A class to record the schema of an op.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
const char * doc() const
Returns the docstring of the op schema.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
OpSchema & InheritOnnxSchema(const std::string &onnx_schema_name)
Sets the corresponding onnx schema name.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
OpSchema & CostInferenceFunction(CostInferenceFunctionType function)
Registers the cost inference function.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & NumOutputs(int n)
A single output.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...