1 #ifndef CAFFE2_CORE_OPERATOR_SCHEMA_H_ 2 #define CAFFE2_CORE_OPERATOR_SCHEMA_H_ 6 #include <initializer_list> 10 #include <unordered_map> 12 #include "caffe2/core/common.h" 13 #include "caffe2/core/logging.h" 14 #include "caffe2/core/registry.h" 15 #include "caffe2/proto/caffe2.pb.h" 21 constexpr
int kCannotComputeNumOutputs = -1;
40 OpSchema() : file_(
"unknown"), line_(0) {}
46 inline const string&
file()
const {
60 inline const char*
doc()
const {
61 return doc_.empty() ?
nullptr : doc_.c_str();
68 bool Verify(
const OperatorDef& def)
const;
128 OpSchema& AllowInplace(std::function<
bool(
int,
int)> inplace);
129 OpSchema& AllowInplace(
set<std::pair<int, int>> inplace);
132 OpSchema& EnforceInplace(std::function<
bool(
int,
int)> inplace);
133 OpSchema& EnforceInplace(
set<std::pair<int, int>> inplace);
140 typedef std::function<
141 vector<TensorShape>(
const OperatorDef&,
const vector<TensorShape>&)>
142 TensorInferenceFunctionType;
160 OpSchema& IdenticalTypeAndShapeOfInput(
int idx);
161 OpSchema& IdenticalTypeAndShapeOfInputDim(
int idx,
int dim);
162 OpSchema& ScalarType(::caffe2::TensorProto_DataType dt);
169 const OperatorDef& def,
170 const vector<TensorShape>& input_type_shape)
const {
171 return tensor_inference_function_(def, input_type_shape);
180 uint64_t bytes_moved;
181 uint64_t params_bytes;
188 typedef std::function<
189 struct Cost(const OperatorDef&,
const vector<TensorShape>&)>
197 #if 0 // def _MSC_VER 201 template <
typename T,
202 typename = std::enable_if<
203 std::is_same<CostInferenceFunctionType&&, T>:value
213 bool HasCostInferenceFunction()
const {
214 return !!cost_inference_function_;
217 inline struct Cost InferCost(
218 const OperatorDef& def,
219 const vector<TensorShape>& input_tensor_shape)
const {
221 cost_inference_function_,
"Cost inference function not defined.");
222 return (*cost_inference_function_)(def, input_tensor_shape);
229 Argument(
const char* name,
const char* description,
bool required)
230 : name_{name}, description_{description}, required_{required} {}
232 const char* name()
const {
236 const char* description()
const {
240 bool is_required()
const {
246 const char* description_;
247 const bool required_;
251 Arg(
const char* name,
const char* description,
bool required =
false);
253 #define DECLARE_STANDARD_ARG(name, str) \ 254 CAFFE2_API static const char* Arg_##name; \ 255 CAFFE2_API OpSchema& Arg##name(const char* description); 257 DECLARE_STANDARD_ARG(IsTest, is_test)
259 #undef DECLARE_STANDARD_ARG 261 OpSchema& Input(
const int n,
const char* name,
const char* description);
262 OpSchema& Output(
const int n,
const char* name,
const char* description);
279 const std::string& onnx_schema()
const {
283 int min_input()
const {
287 int max_input()
const {
291 int min_output()
const {
295 int max_output()
const {
299 bool num_inputs_allowed(
int x)
const {
300 return num_inputs_allowed_(x);
303 bool num_outputs_allowed(
int x)
const {
304 return num_outputs_allowed_(x);
307 bool num_inputs_outputs_allowed(
int x,
int y)
const {
308 return num_inputs_outputs_allowed_(x, y);
312 return std::numeric_limits<int>::max();
315 friend std::ostream& operator<<(std::ostream& out,
const OpSchema& schema);
317 const std::vector<Argument>& args()
const {
321 const std::vector<std::pair<const char*, const char*>>& input_desc()
const {
324 const std::vector<std::pair<const char*, const char*>>& output_desc()
const {
330 bool inputs_can_cross_devices()
const {
331 return inputs_can_cross_devices_;
338 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>(
339 const OperatorDef& def)>;
346 inline std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
348 return device_inference_function_(def);
355 std::vector<Argument> args_{};
356 std::vector<std::pair<const char*, const char*>> input_desc_{};
357 std::vector<std::pair<const char*, const char*>> output_desc_{};
360 int max_input_ = std::numeric_limits<int>::max();
362 int max_output_ = std::numeric_limits<int>::max();
363 bool private_ =
false;
364 bool inputs_can_cross_devices_ =
false;
365 std::function<bool(int)> num_inputs_allowed_ = [](int) {
return true; };
366 std::function<bool(int)> num_outputs_allowed_ = [](int) {
return true; };
367 std::function<bool(int, int)> num_inputs_outputs_allowed_ = [](int, int) {
370 std::function<int(int)> calculate_output_;
372 std::function<bool(int, int)> inplace_allowed_ = [](int, int) {
375 std::function<bool(int, int)> inplace_enforced_ = [](int, int) {
378 TensorInferenceFunctionType tensor_inference_function_ =
379 [](
const OperatorDef& def,
const vector<TensorShape>&) {
380 vector<TensorShape> out;
381 for (
int i = 0; i < def.output_size(); i++) {
383 ts.set_unknown_shape(
true);
388 std::unique_ptr<CostInferenceFunctionType> cost_inference_function_ =
nullptr;
390 [](
const OperatorDef& def) {
392 def.has_device_option() ? def.device_option() : DeviceOption();
393 vector<DeviceOption> in_dev(def.input_size(), op_device);
394 vector<DeviceOption> out_dev(def.output_size(), op_device);
395 return std::make_pair(in_dev, out_dev);
405 NewSchema(
const string& key,
const string& file,
const int line) {
407 auto it = m.find(key);
409 const auto& schema = it->second;
410 std::ios_base::Init init;
411 std::cerr <<
"Trying to register schema with name " << key
412 <<
" from file " << file <<
" line " << line
413 <<
", but it is already registered from file " << schema.file()
414 <<
" line " << schema.line();
417 m.emplace(std::make_pair(key,
OpSchema(file, line)));
421 static const OpSchema* Schema(
const string& key) {
423 auto it = m.find(key);
445 static CaffeMap<string, OpSchema>& map();
449 template <
typename T_I =
int>
450 inline TensorShape CreateTensorShape(
452 ::caffe2::TensorProto_DataType dt) {
457 ts.set_data_type(dt);
462 inline vector<TIndex> GetDimsVector(
const TensorShape& shape) {
464 for (
auto d : shape.dims()) {
471 inline std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
472 InferOpInputOutputDevice(
const OperatorDef& op) {
473 auto op_schema = OpSchemaRegistry::Schema(op.type());
475 op_schema,
"Device inference failed. No schema for: ", op.type());
477 return op_schema->InferDevice(op);
480 template <u
int64_t OpsPerPo
int>
483 const vector<TensorShape>& inputs) {
485 const TensorShape X = inputs[0];
488 for (
auto i = 0; i < X.dims().size(); ++i) {
492 c.flops = size * OpsPerPoint;
493 c.bytes_moved = size *
sizeof(X.data_type());
499 #ifndef CAFFE2_NO_OPERATOR_SCHEMA 501 #define OPERATOR_SCHEMA(name) \ 502 void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ 503 static OpSchema* CAFFE_ANONYMOUS_VARIABLE(name) = \ 504 &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 505 #define OPERATOR_SCHEMA_STR(name) \ 506 static OpSchema* CAFFE_ANONYMOUS_VARIABLE(schema_registration) = \ 507 &OpSchemaRegistry::NewSchema(name, __FILE__, __LINE__) 509 #else // CAFFE2_NO_OPERATOR_SCHEMA 511 #define OPERATOR_SCHEMA(name) \ 512 void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ 513 static OpSchema* CAFFE_ANONYMOUS_VARIABLE(name) = \ 514 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 515 #define OPERATOR_SCHEMA_STR(name) \ 516 static OpSchema* CAFFE_ANONYMOUS_VARIABLE(schema_registration) = \ 517 1 ? nullptr : &OpSchemaRegistry::NewSchema(name, __FILE__, __LINE__) 519 #endif // CAFFE2_NO_OPERATOR_SCHEMA 521 #endif // CAFFE2_CORE_OPERATOR_SCHEMA_H_ std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
OpSchema & NumInputs(int n)
A single input.
A class to record the schema of an op.
vector< TensorShape > InferTensor(const OperatorDef &def, const vector< TensorShape > &input_type_shape) const
A function to allow one to infer the type and shape from the op schema.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
A registry to hold all the operator schemas.
int line() const
Returns the line in the file that the op schema is registered from.
const char * doc() const
Returns the docstring of the op schema.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
std::pair< std::vector< DeviceOption >, std::vector< DeviceOption > > InferDevice(const OperatorDef &def) const
Infer required device location of an op's inputs and outputs.
OpSchema & InheritOnnxSchema(const std::string &onnx_schema_name)
Sets the corresponding onnx schema name.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
OpSchema & CostInferenceFunction(CostInferenceFunctionType function)
Register the Cost inference function.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
const string & file() const
Returns the file that the op schema is registered from.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & NumOutputs(int n)
A single output.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...