// NOTE(review): this view of the file is a lossy extraction. Original line
// numbers are fused into the code and several source lines are missing (the
// numbering jumps, e.g. 2 -> 8, 42 -> 49, 67 -> 74), so it does not compile
// as shown. Only the two defects marked "Fixed" below are changed; missing
// lines are deliberately not reconstructed.
//
// The checks below appear to be the body of OpSchema::Verify(def). Each one
// tests a constraint recorded in this schema against the OperatorDef and
// logs an error via LOG(ERROR) when it is violated. The statements that
// follow each LOG(ERROR) (presumably early `return false;`s) are among the
// missing lines.
//
// Check 1: the input count lies within [min_input_, max_input_].
1 #include "caffe2/core/operator_schema.h" 2 #include "caffe2/core/logging.h" 8 if (def.input_size() < min_input_ || def.input_size() > max_input_) {
9 LOG(ERROR) <<
"Input size " << def.input_size()
10 <<
" not in range [min=" << min_input_ <<
", max=" 11 << max_input_ <<
"].";
// Check 2: the input count satisfies the num_inputs_allowed_ predicate.
14 if (!num_inputs_allowed_(def.input_size())) {
15 LOG(ERROR) <<
"Input size " << def.input_size()
16 <<
" not in allowed input sizes.";
// Check 3: the output count lies within [min_output_, max_output_].
20 if (def.output_size() < min_output_ || def.output_size() > max_output_) {
21 LOG(ERROR) <<
"Output size " << def.output_size()
22 <<
" not in range [min=" << min_output_ <<
", max=" 23 << max_output_ <<
"].";
// Check 4: the output count satisfies the num_outputs_allowed_ predicate.
26 if (!num_outputs_allowed_(def.output_size())) {
27 LOG(ERROR) <<
"Output size " << def.output_size()
28 <<
" not in allowed output sizes.";
// Check 5: the (input count, output count) pair satisfies
// num_inputs_outputs_allowed_.
31 if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
32 LOG(ERROR) <<
"Combination of input size " << def.input_size()
33 <<
// Fixed: leading space added; the message used to read
// "...input size 2and output size 3...".
" and output size " << def.output_size() <<
" not in allowed.";
// Check 6: when an output calculator is set and it returns something other
// than kCannotComputeNumOutputs, the output count must equal that value.
37 if (calculate_output_) {
38 int expected_nout = calculate_output_(def.input_size());
39 if (expected_nout != kCannotComputeNumOutputs &&
40 def.output_size() != expected_nout) {
41 LOG(ERROR) <<
"Output size " << def.output_size()
42 <<
// NOTE(review): the tail of this message and the end of check 6 (numbered
// 43-48) are missing. From "49 for (" on, the code checks the in-place
// rules for every (input index, output index) pair.
" not matching expected output size, which is " 49 for (
int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
50 for (
int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
// An input and an output with the same name are rejected unless in-place
// is allowed or enforced for that index pair.
53 if (def.input(in_idx) == def.output(out_idx) &&
54 (!inplace_allowed_(in_idx, out_idx)
55 && !inplace_enforced_(in_idx, out_idx))) {
56 LOG(ERROR) <<
"Input index " << in_idx <<
" and output idx " << out_idx
57 <<
" (" << def.input(in_idx) <<
")" 58 <<
" are set to be in-place but this is actually not " 59 <<
"supported by op " << def.type();
// Conversely, an index pair for which in-place is enforced must use the
// same name for the input and the output.
62 if (def.input(in_idx) != def.output(out_idx) &&
63 inplace_enforced_(in_idx, out_idx)) {
64 LOG(ERROR) <<
"Input index " << in_idx <<
" (" << def.input(in_idx) <<
")" 65 <<
" and output idx " << out_idx
66 <<
// Fixed: print the output at out_idx. Indexing the outputs with in_idx
// named the wrong output, and could read past the last output when the op
// has more inputs than outputs.
" (" << def.output(out_idx) <<
")" 67 <<
" are not in-place but should be as required by op " 74 std::set<std::string> present_args{};
// Check 7: every argument declared in this schema with is_required() must
// appear (by name) among def.arg(); present_args collects those names.
75 for (
const auto& arg : def.arg()) {
76 present_args.insert(arg.name());
79 for (
const auto& arg : args()) {
80 if (arg.is_required() &&
81 present_args.find(arg.name()) == present_args.end()) {
82 LOG(ERROR) <<
"Argument '" << arg.name() <<
"' is required for Operator '" 83 << def.type() <<
"'.";
// Fragment, presumably of the NumInputs overload that takes a predicate (its
// signature line is not visible here). It stores the predicate in
// num_inputs_allowed_, which the verification code calls with
// def.input_size().
103 num_inputs_allowed_ = func;
// Fragment, presumably of the NumInputs overload that takes a set of counts.
// The lambda captures the set by value and accepts n only when n is one of
// the listed input counts (count() > 0 converts to true).
109 [allowed_input_nums](
int n)->
bool {
110 return allowed_input_nums.count(n);
// Fragment, presumably of the NumOutputs overload that takes a predicate
// (its signature line is not visible here). It stores the predicate in
// num_outputs_allowed_, which the verification code calls with
// def.output_size().
125 num_outputs_allowed_ = func;
// Fragment, presumably of the NumOutputs overload that takes a set of
// counts. The lambda captures the set by value and accepts n only when n is
// one of the listed output counts.
131 [allowed_output_nums](
int n)->
bool {
132 return allowed_output_nums.count(n);
// Fragment, presumably of NumInputsOutputs. It stores a predicate over the
// (input count, output count) pair; the verification code calls it with
// def.input_size() and def.output_size().
137 num_inputs_outputs_allowed_ = func;
// Fragment, presumably of OutputCalculator. It stores a function mapping an
// input count to an expected output count (or kCannotComputeNumOutputs).
// CalculateOutput() and the verification code both call it.
142 calculate_output_ = calc;
// Stores the predicate that decides whether input `in` may share its name
// with output `out` (i.e. run in place). Replaces any previously stored
// predicate. The verification code rejects a shared name unless this
// predicate, or the enforce-in-place predicate, returns true for that pair.
150 OpSchema& OpSchema::AllowInplace(std::function<
bool(
int,
int)> inplace) {
151 inplace_allowed_ = inplace;
// Set-based overload: in-place is allowed exactly for the listed
// (input index, output index) pairs. The lambda captures the set by value.
// The line that passes the lambda on (numbered 156) is not visible here;
// presumably it forwards to the predicate overload of AllowInplace, as the
// set-based EnforceInplace does with EnforceInplace.
155 OpSchema& OpSchema::AllowInplace(
set<std::pair<int, int>> inplace) {
157 [inplace](
int in,
int out)->
bool {
158 return inplace.count(std::make_pair(in, out));
// Allows input i to run in place with output i (same index only) by handing
// an `in == out` predicate to AllowInplace.
162 OpSchema& OpSchema::AllowOneToOneInplace() {
163 return AllowInplace([](
int in,
int out) {
return in == out; });
// Stores the predicate that marks (input, output) index pairs which MUST
// run in place. Replaces any previously stored predicate. The verification
// code reports an error when this predicate returns true for a pair whose
// input and output names differ. A true result also permits a shared name,
// just like the allow-in-place predicate.
166 OpSchema& OpSchema::EnforceInplace(std::function<
bool(
int,
int)> inplace) {
167 inplace_enforced_ = inplace;
// Set-based overload: in-place is enforced exactly for the listed
// (input index, output index) pairs. The set is captured by value in a
// lambda that is forwarded to the predicate overload of EnforceInplace.
171 OpSchema& OpSchema::EnforceInplace(
set<std::pair<int, int>> inplace) {
172 return EnforceInplace(
173 [inplace](
int in,
int out)->
bool {
174 return inplace.count(std::make_pair(in, out));
// Requires input i to run in place with output i (same index) by handing an
// `in == out` predicate to EnforceInplace.
178 OpSchema& OpSchema::EnforceOneToOneInplace() {
179 return EnforceInplace([](
int in,
int out) {
return in == out; });
// Builder setter: sets this schema's inputs_can_cross_devices_ flag to true.
187 OpSchema& OpSchema::InputsCanCrossDevices() {
188 inputs_can_cross_devices_ =
true;
// Fragment of TensorInferenceFunction (the line naming the method is not
// visible here). It stores the caller-supplied tensor inference function.
193 TensorInferenceFunctionType
function) {
194 tensor_inference_function_ =
function;
// Fragment of InheritOnnxSchema: records the name of the corresponding ONNX
// schema.
199 onnx_schema_ = onnx_schema_name;
// Fragment, presumably of IdenticalTypeAndShape. The inference lambda
// ignores the OperatorDef and returns a copy of the input shapes, so output
// i is given the type and shape of input i.
205 [](
const OperatorDef&,
const vector<TensorShape>& input_types) {
206 return vector<TensorShape>(input_types);
// Inference helper. The lambda returns a one-element vector whose only
// TensorShape is a copy of input `idx` (its type and shape).
// NOTE(review): input_types[idx] is std::vector::operator[], which is not
// bounds-checked. An op with fewer than idx + 1 inputs would read past the
// end of input_types.
210 OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(
int idx) {
212 [idx](
const OperatorDef&,
const vector<TensorShape>& input_types) {
213 vector<TensorShape> out(1);
214 out[0] = input_types[idx];
// Inference helper. The lambda returns a one-element vector whose shape has
// the data type of input `idx` and gains one dimension: dimension `dim` of
// input `idx`. The result is rank 1, assuming a default TensorShape starts
// with no dims.
// NOTE(review): neither idx nor dim is range-checked here.
219 OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(
int idx,
int dim) {
221 [idx, dim](
const OperatorDef&,
const vector<TensorShape>& input_types) {
222 vector<TensorShape> out(1);
223 out[0].add_dims(input_types[idx].dims(dim));
224 out[0].set_data_type(input_types[idx].data_type());
// Inference helper. Regardless of the inputs, the lambda returns one output
// shape whose data type is `dt` and to which no dims are added. Assuming a
// default TensorShape has no dims, this describes a scalar.
229 OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
231 [dt](
const OperatorDef&,
const vector<TensorShape>& ) {
232 vector<TensorShape> out(1);
233 out[0].set_data_type(dt);
// Fragment of CostInferenceFunction. It stores a heap-allocated copy of the
// cost function, owned through the pointer returned by caffe2::make_unique.
239 cost_inference_function_ =
240 caffe2::make_unique<CostInferenceFunctionType>(
function);
// Fragment of DeviceInferenceFunction (its parameter line is not visible).
// It stores the function that reports the required device location of
// inputs and outputs.
244 OpSchema& OpSchema::DeviceInferenceFunction(
246 device_inference_function_ =
function;
// Fragment of Arg (its return-type line is not visible). It appends an
// Argument with the given name, description and `required` flag to args_.
// The verification code logs an error for every required argument whose
// name is absent from the OperatorDef's args.
256 OpSchema::Arg(
const char* name,
const char* description,
bool required) {
257 args_.push_back(
Argument(name, description, required));
// DEFINE_STANDARG_ARG(name, str) defines, for one standard argument:
//   - OpSchema::Arg_<name>: the argument's string name (#str), and
//   - OpSchema::Arg<name>(description): registers that argument as required,
//     via Arg(#str, description, true).
// It is instantiated here for IsTest ("is_test").
// NOTE(review): "STANDARG" looks like a typo for "STANDARD". It is harmless
// as long as the macro stays local to this file.
261 #define DEFINE_STANDARG_ARG(name, str) \ 262 CAFFE2_API const char* OpSchema::Arg_##name = #str; \ 263 CAFFE2_API OpSchema& OpSchema::Arg##name(const char* description) { \ 264 return Arg(#str, description, true); \ 267 DEFINE_STANDARG_ARG(IsTest, is_test)
269 #undef DEFINE_STANDARG_ARG 271 OpSchema& OpSchema::Input(
const int n,
const char* name,
const char* description) {
272 if (input_desc_.size() <= n) {
273 input_desc_.resize(n + 1);
275 input_desc_[n] = std::make_pair(name, description);
// Output(n, name, description) records the name and description of output
// `n`, growing output_desc_ so that index n exists. Slots skipped over by a
// larger n stay value-initialized (null pointers, assuming output_desc_
// holds const char* pairs). operator<< prints such slots as "(unnamed)" /
// "(no doc)". The lines after 283 are not visible in this view.
279 OpSchema& OpSchema::Output(
const int n,
const char* name,
const char* description) {
// Reject negative indices up front. Otherwise `n` is converted to a huge
// size_t in the size comparison, and resize(n + 1) then either clears the
// vector (n == -1, followed by an out-of-bounds store) or requests an
// enormous size (n < -1).
CAFFE_ENFORCE_GE(n, 0, "Output index must be non-negative, got ", n);
280 if (output_desc_.size() <= static_cast<size_t>(n)) {
281 output_desc_.resize(n + 1);
283 output_desc_[n] = std::make_pair(name, description);
// Fragment of CalculateOutput(int num_input). It returns the number of
// outputs implied by num_input when the schema can determine it.
// NOTE(review): the statement inside this first branch (numbered 296) is
// not visible. Presumably it returns the fixed output count when
// min_output_ == max_output_.
295 if (min_output_ == max_output_) {
297 }
// Otherwise, defer to the user-supplied output calculator, if one is set.
else if (calculate_output_) {
298 return calculate_output_(num_input);
// Fallback: the output count cannot be computed for this schema.
300 return kCannotComputeNumOutputs;
// Writes a human-readable description of `schema` to `out`. The sections
// are:
//   - "Arguments:" with one "name : description" line per argument, if any;
//   - "Inputs:" / "Outputs:" (only when max_input_ / max_output_ > 0), each
//     listing "index, name : doc" with "(unnamed)" / "(no doc)" fallbacks
//     for null entries;
//   - a "Defined at file:line" trailer.
// Several connecting lines (e.g. the else-branches that print
// "(no explicit description available)" and the condition guarding
// "(no documentation yet)") are not visible in this view.
304 std::ostream& operator<<(std::ostream& out,
const OpSchema& schema) {
305 if (!schema.args().empty()) {
306 out <<
"Arguments:" << std::endl;
307 for (
const auto& arg : schema.args()) {
308 out <<
" " << arg.name() <<
" : " << arg.description() << std::endl;
311 if (schema.max_input_ > 0) {
312 out <<
"Inputs:" << std::endl;
313 if (!schema.input_desc_.empty()) {
314 for (
// size_t index: matches input_desc_.size() and avoids a signed/unsigned
// comparison.
size_t i = 0; i < schema.input_desc_.size(); ++i) {
315 const auto& p = schema.input_desc_[i];
316 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 317 << (p.second ? p.second :
"(no doc)") << std::endl;
320 out <<
" (no explicit description available)" << std::endl;
323 if (schema.max_output_ > 0) {
324 out <<
"Outputs:" << std::endl;
325 if (!schema.output_desc_.empty()) {
326 for (
// size_t index: matches output_desc_.size() and avoids a signed/unsigned
// comparison.
size_t i = 0; i < schema.output_desc_.size(); ++i) {
327 const auto& p = schema.output_desc_[i];
328 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 329 << (p.second ? p.second :
"(no doc)") << std::endl;
332 out <<
" (no explicit description available)" << std::endl;
339 out <<
"(no documentation yet)" << std::endl;
343 out <<
"Defined at " << schema.file_ <<
":" << schema.line_ << std::endl;
// Fragment of OpSchemaRegistry::map(): the registry of schemas keyed by a
// string (presumably the operator type name). The map is a function-local
// static, so it is constructed on first call (initialization of such
// statics is thread-safe since C++11). Presumably this construct-on-first-
// use form is chosen to sidestep static-initialization-order issues with
// schemas registered from other translation units. The return statement is
// not visible in this view.
348 CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
349 static CaffeMap<string, OpSchema> map;
std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
OpSchema & NumInputs(int n)
Sets the number of inputs to exactly n.
A class to record the schema of an op.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
const char * doc() const
Returns the docstring of the op schema.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function so that each output has the same type and shape as the corresponding input.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
OpSchema & InheritOnnxSchema(const std::string &onnx_schema_name)
Sets the corresponding onnx schema name.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
OpSchema & CostInferenceFunction(CostInferenceFunctionType function)
Register the Cost inference function.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & NumOutputs(int n)
Sets the number of outputs to exactly n.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the operator's total cost as a `struct Cost`.