// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Revision History
// Version 0: Initial version.
// Version 1: Add subgraphs to schema.
// Version 2: Rename operators to conform to NN API.
// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
// Version 3a: Add new builtin op code field. Has backward compatibility with
//             version 3.

namespace tflite;

// This corresponds to the version.
file_identifier "TFL3";
// File extension of any written files.
file_extension "tflite";

// IMPORTANT: All new members of tables, enums and unions must be added at the
// end to ensure backwards compatibility.

// The type of data stored in a tensor.
enum TensorType : byte {
  FLOAT32 = 0,
  FLOAT16 = 1,
  INT32 = 2,
  UINT8 = 3,
  INT64 = 4,
  STRING = 5,
  BOOL = 6,
  INT16 = 7,
  COMPLEX64 = 8,
  INT8 = 9,
  FLOAT64 = 10,
  COMPLEX128 = 11,
}

// Custom quantization parameters for experimenting with new quantization
// techniques.
table CustomQuantization {
  custom:[ubyte] (force_align: 16);
}

// Represents a specific quantization technique's parameters.
union QuantizationDetails {
  CustomQuantization,
}

// Parameters for converting a quantized tensor back to float.
table QuantizationParameters {
  // These four parameters are the asymmetric linear quantization parameters.
  // Given a quantized value q, the corresponding float value f should be:
  //   f = scale * (q - zero_point)
  // For other quantization types, the QuantizationDetails below is used.
  min:[float];  // For importing back into tensorflow.
  max:[float];  // For importing back into tensorflow.
  scale:[float];  // For dequantizing the tensor's values.
  zero_point:[long];

  // If this is not none, the other quantization parameters (i.e. min, max,
  // scale, zero_point fields above) are ignored and the value of the
  // QuantizationDetails union should be used.
  details:QuantizationDetails;

  // Specifies the dimension of the Tensor's shape that the scales and
  // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
  // with quantization params:
  //   scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantized_dimension=1
  // will be quantized across the second dimension of t.
  //   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
  //   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
  //   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
  quantized_dimension:int;
}

// Sparse tensors.
// We use a modification of the TACO format.
// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf
//
// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1),
// potentially with a k-dimensional block (0 <= k <= n) with dims
// (dn, ..., dn+k-1), the format needs to specify:
//   1. In what order to traverse these dimensions. For example, to store a 2-D
//      matrix in row major order, the traversal order would be (d0, d1),
//      whereas to store it in column major order, the traversal order would be
//      (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order
//      could be (d0, d1, d2, d3).
//   2. How each block dimension in (dn, ..., dn+k-1) maps to the original
//      tensor dimension in (d0, ..., dn-1).
//   3. In the traversal order defined above, the format (dense vs. sparse) and
//      index metadata for each dimension. For a dense dimension, this is just
//      the size of that dimension. For a sparse dimension, it's the same as
//      the compressed index defined in the Compressed Sparse Row (CSR) format.
//      (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html)

// The storage type for a dimension. Currently we support:
//   1. DENSE: each coordinate in this dimension is stored implicitly.
//   2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The
//      compression technique is the same as what CSR uses.
// More types like a sparse dimension with a different compression technique
// could be added to the list in the future.
enum DimensionType : byte {
  DENSE = 0,
  SPARSE_CSR = 1,
}

// Vector wrapper tables of different element widths. They are the member
// types of the SparseIndexVector union below.
table Int32Vector {
  values:[int];
}

table Uint16Vector {
  values:[ushort] (force_align: 4);
}

table Uint8Vector {
  values:[ubyte] (force_align: 4);
}

// Variable-typed buffer to store the index metadata for a sparse dimension.
// The widest type is Int32 instead of UInt32 because tensor's shape is an
// int32 vector. We don't want the per-dimensional index to overflow that
// range.
union SparseIndexVector {
  Int32Vector,
  Uint16Vector,
  Uint8Vector
}

table DimensionMetadata {
  // Whether a dimension is dense or sparse.
  format:DimensionType;
  // Index metadata used for a dimension.
  //   - If format is DimensionType.DENSE then we use the dense_size field to
  //     store the size of that dimension. Each index in that dimension is
  //     stored implicitly.
  //   - If format is DimensionType.SPARSE_CSR then we use array_segments and
  //     array_indices to encode that dimension. array_segments represents how
  //     to segment the indices array, each segment corresponds to one element
  //     in the previous dimension. array_indices represents the index of the
  //     non-zero elements within this dimension (as those in the CSR matrix
  //     format, where the first array is row pointers and the second array is
  //     column indices).
  dense_size:int;
  array_segments:SparseIndexVector;
  array_indices:SparseIndexVector;
}

// Parameters to encode a sparse TfLite tensor.
table SparsityParameters {
  // The traversal order of the dimensions defined in the `shape` field of the
  // conceptual dense tensor. For an n-dimensional tensor with dims (d0, d1,
  // ..., dn-1),
  //   - if not block sparse, the traversal_order is just a permutation of (d0,
  //     ..., dn-1). For example, a 2-D matrix stored in row-major order would
  //     have traversal_order = (d0, d1).
  //   - if block sparse with a k-dimensional block (0 <= k <= n), the
  //     traversal_order has n + k elements. The first n elements are still a
  //     permutation of (d0, ..., dn-1). The last k elements are a permutation
  //     of (dn, ..., dn+k-1), defining how to traverse a block internally. For
  //     example, a 2-D matrix with 2-D blocks, both stored in row-major order
  //     would have traversal_order = (d0, d1, d2, d3).
  traversal_order:[int];
  // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
  // stores how a block dimension in (dn, ..., dn+k-1) maps to the original
  // tensor dimension in (d0, ..., dn-1).
  // It's stored in the order of (dn, ..., dn+k-1).
  // If not block-sparse, this field is NULL.
  block_map:[int];
  // In the traversal order defined above, the metadata needed for
  // each dimension to locate the non-zero values in the original dense tensor.
  // The size of the dim_metadata array = the size of the traversal_order array
  // = n + k.
  dim_metadata:[DimensionMetadata];
}

table Tensor {
  // The tensor shape. The meaning of each entry is operator-specific but
  // builtin ops use: [batch size, height, width, number of channels] (That's
  // Tensorflow's NHWC).
  shape:[int];
  type:TensorType;
  // An index that refers to the buffers table at the root of the model. Or,
  // if there is no data buffer associated (i.e. intermediate results), then
  // this is 0 (which refers to an always existent empty buffer).
  //
  // The data_buffer itself is an opaque container, with the assumption that the
  // target device is little-endian. In addition, all builtin operators assume
  // the memory is ordered such that if `shape` is [4, 3, 2], then index
  // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
  buffer:uint;
  name:string;  // For debugging and importing back into tensorflow.
  quantization:QuantizationParameters;  // Optional.

  is_variable:bool = false;

  // Parameters to encode a sparse tensor. See the example in
  // tensorflow/lite/testdata/sparse_tensor.json.
  sparsity:SparsityParameters;  // Optional.

  // Encodes `shape` with unknown dimensions. Unknown dimensions are
  // represented with -1.
  shape_signature:[int];  // Optional.
}

// A list of builtin operators. Builtin operators are slightly faster than custom
// ones, but not by much. Moreover, while custom operators accept an opaque
// object containing configuration parameters, builtins have a predetermined
// set of acceptable options.
enum BuiltinOperator : int32 {
  ADD = 0,
  AVERAGE_POOL_2D = 1,
  CONCATENATION = 2,
  CONV_2D = 3,
  DEPTHWISE_CONV_2D = 4,
  DEPTH_TO_SPACE = 5,
  DEQUANTIZE = 6,
  EMBEDDING_LOOKUP = 7,
  FLOOR = 8,
  FULLY_CONNECTED = 9,
  HASHTABLE_LOOKUP = 10,
  L2_NORMALIZATION = 11,
  L2_POOL_2D = 12,
  LOCAL_RESPONSE_NORMALIZATION = 13,
  LOGISTIC = 14,
  LSH_PROJECTION = 15,
  LSTM = 16,
  MAX_POOL_2D = 17,
  MUL = 18,
  RELU = 19,
  // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
  // since different model developers use RELU1 in different ways. Never
  // create another op called RELU1.
  RELU_N1_TO_1 = 20,
  RELU6 = 21,
  RESHAPE = 22,
  RESIZE_BILINEAR = 23,
  RNN = 24,
  SOFTMAX = 25,
  SPACE_TO_DEPTH = 26,
  SVDF = 27,
  TANH = 28,
  CONCAT_EMBEDDINGS = 29,
  SKIP_GRAM = 30,
  CALL = 31,
  CUSTOM = 32,
  EMBEDDING_LOOKUP_SPARSE = 33,
  PAD = 34,
  UNIDIRECTIONAL_SEQUENCE_RNN = 35,
  GATHER = 36,
  BATCH_TO_SPACE_ND = 37,
  SPACE_TO_BATCH_ND = 38,
  TRANSPOSE = 39,
  MEAN = 40,
  SUB = 41,
  DIV = 42,
  SQUEEZE = 43,
  UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
  STRIDED_SLICE = 45,
  BIDIRECTIONAL_SEQUENCE_RNN = 46,
  EXP = 47,
  TOPK_V2 = 48,
  SPLIT = 49,
  LOG_SOFTMAX = 50,
  // DELEGATE is a special op type for the operations which are delegated to
  // other backends.
  // WARNING: Experimental interface, subject to change
  DELEGATE = 51,
  BIDIRECTIONAL_SEQUENCE_LSTM = 52,
  CAST = 53,
  PRELU = 54,
  MAXIMUM = 55,
  ARG_MAX = 56,
  MINIMUM = 57,
  LESS = 58,
  NEG = 59,
  PADV2 = 60,
  GREATER = 61,
  GREATER_EQUAL = 62,
  LESS_EQUAL = 63,
  SELECT = 64,
  SLICE = 65,
  SIN = 66,
  TRANSPOSE_CONV = 67,
  SPARSE_TO_DENSE = 68,
  TILE = 69,
  EXPAND_DIMS = 70,
  EQUAL = 71,
  NOT_EQUAL = 72,
  LOG = 73,
  SUM = 74,
  SQRT = 75,
  RSQRT = 76,
  SHAPE = 77,
  POW = 78,
  ARG_MIN = 79,
  FAKE_QUANT = 80,
  REDUCE_PROD = 81,
  REDUCE_MAX = 82,
  PACK = 83,
  LOGICAL_OR = 84,
  ONE_HOT = 85,
  LOGICAL_AND = 86,
  LOGICAL_NOT = 87,
  UNPACK = 88,
  REDUCE_MIN = 89,
  FLOOR_DIV = 90,
  REDUCE_ANY = 91,
  SQUARE = 92,
  ZEROS_LIKE = 93,
  FILL = 94,
  FLOOR_MOD = 95,
  RANGE = 96,
  RESIZE_NEAREST_NEIGHBOR = 97,
  LEAKY_RELU = 98,
  SQUARED_DIFFERENCE = 99,
  MIRROR_PAD = 100,
  ABS = 101,
  SPLIT_V = 102,
  UNIQUE = 103,
  CEIL = 104,
  REVERSE_V2 = 105,
  ADD_N = 106,
  GATHER_ND = 107,
  COS = 108,
  WHERE = 109,
  RANK = 110,
  ELU = 111,
  REVERSE_SEQUENCE = 112,
  MATRIX_DIAG = 113,
  QUANTIZE = 114,
  MATRIX_SET_DIAG = 115,
  ROUND = 116,
  HARD_SWISH = 117,
  IF = 118,
  WHILE = 119,
  NON_MAX_SUPPRESSION_V4 = 120,
  NON_MAX_SUPPRESSION_V5 = 121,
  SCATTER_ND = 122,
  SELECT_V2 = 123,
  DENSIFY = 124,
  SEGMENT_SUM = 125,
  BATCH_MATMUL = 126,
  PLACEHOLDER_FOR_GREATER_OP_CODES = 127
}

// Options for the builtin operators.
union BuiltinOptions {
  Conv2DOptions,
  DepthwiseConv2DOptions,
  ConcatEmbeddingsOptions,
  LSHProjectionOptions,
  Pool2DOptions,
  SVDFOptions,
  RNNOptions,
  FullyConnectedOptions,
  SoftmaxOptions,
  ConcatenationOptions,
  AddOptions,
  L2NormOptions,
  LocalResponseNormalizationOptions,
  LSTMOptions,
  ResizeBilinearOptions,
  CallOptions,
  ReshapeOptions,
  SkipGramOptions,
  SpaceToDepthOptions,
  EmbeddingLookupSparseOptions,
  MulOptions,
  PadOptions,
  GatherOptions,
  BatchToSpaceNDOptions,
  SpaceToBatchNDOptions,
  TransposeOptions,
  ReducerOptions,
  SubOptions,
  DivOptions,
  SqueezeOptions,
  SequenceRNNOptions,
  StridedSliceOptions,
  ExpOptions,
  TopKV2Options,
  SplitOptions,
  LogSoftmaxOptions,
  CastOptions,
  DequantizeOptions,
  MaximumMinimumOptions,
  ArgMaxOptions,
  LessOptions,
  NegOptions,
  PadV2Options,
  GreaterOptions,
  GreaterEqualOptions,
  LessEqualOptions,
  SelectOptions,
  SliceOptions,
  TransposeConvOptions,
  SparseToDenseOptions,
  TileOptions,
  ExpandDimsOptions,
  EqualOptions,
  NotEqualOptions,
  ShapeOptions,
  PowOptions,
  ArgMinOptions,
  FakeQuantOptions,
  PackOptions,
  LogicalOrOptions,
  OneHotOptions,
  LogicalAndOptions,
  LogicalNotOptions,
  UnpackOptions,
  FloorDivOptions,
  SquareOptions,
  ZerosLikeOptions,
  FillOptions,
  BidirectionalSequenceLSTMOptions,
  BidirectionalSequenceRNNOptions,
  UnidirectionalSequenceLSTMOptions,
  FloorModOptions,
  RangeOptions,
  ResizeNearestNeighborOptions,
  LeakyReluOptions,
  SquaredDifferenceOptions,
  MirrorPadOptions,
  AbsOptions,
  SplitVOptions,
  UniqueOptions,
  ReverseV2Options,
  AddNOptions,
  GatherNdOptions,
  CosOptions,
  WhereOptions,
  RankOptions,
  ReverseSequenceOptions,
  MatrixDiagOptions,
  QuantizeOptions,
  MatrixSetDiagOptions,
  HardSwishOptions,
  IfOptions,
  WhileOptions,
  DepthToSpaceOptions,
  NonMaxSuppressionV4Options,
  NonMaxSuppressionV5Options,
  ScatterNdOptions,
  SelectV2Options,
  DensifyOptions,
  SegmentSumOptions,
  BatchMatMulOptions
}

// Padding scheme; referenced by the `padding` field of the convolution and
// pooling option tables.
enum Padding : byte { SAME, VALID }

// Activation selector; referenced by the `fused_activation_function` field of
// the option tables.
enum ActivationFunctionType : byte {
  NONE = 0,
  RELU = 1,
  RELU_N1_TO_1 = 2,
  RELU6 = 3,
  TANH = 4,
  SIGN_BIT = 5,
}

table Conv2DOptions {
  padding:Padding;
  stride_w:int;
  stride_h:int;
  fused_activation_function:ActivationFunctionType;
  dilation_w_factor:int = 1;
  dilation_h_factor:int = 1;
}

table Pool2DOptions {
  padding:Padding;
  stride_w:int;
  stride_h:int;
  filter_width:int;
  filter_height:int;
  fused_activation_function:ActivationFunctionType;
}

table DepthwiseConv2DOptions {
  // Parameters for DepthwiseConv version 1 or above.
  padding:Padding;
  stride_w:int;
  stride_h:int;
  // `depth_multiplier` is redundant. It's used by CPU kernels in
  // TensorFlow 2.0 or below, but ignored in versions above.
  // See comments in lite/c/builtin_op_data.h for more details.
  depth_multiplier:int;
  fused_activation_function:ActivationFunctionType;
  // Parameters for DepthwiseConv version 2 or above.
  dilation_w_factor:int = 1;
  dilation_h_factor:int = 1;
}

table ConcatEmbeddingsOptions {
  num_channels:int;
  num_columns_per_channel:[int];
  embedding_dim_per_channel:[int];  // This could be inferred from parameters.
}

enum LSHProjectionType: byte {
  UNKNOWN = 0,
  SPARSE = 1,
  DENSE = 2,
}

table LSHProjectionOptions {
  type: LSHProjectionType;
}

table SVDFOptions {
  rank:int;
  fused_activation_function:ActivationFunctionType;
  // For weights-only quantization, use asymmetric quantization for non
  // constant inputs at evaluation time.
  asymmetric_quantize_inputs:bool;
}

// An implementation of TensorFlow RNNCell.
table RNNOptions {
  fused_activation_function:ActivationFunctionType;
  asymmetric_quantize_inputs:bool;
}

// An implementation of TensorFlow dynamic_rnn with RNNCell.
table SequenceRNNOptions {
  time_major:bool;
  fused_activation_function:ActivationFunctionType;
  asymmetric_quantize_inputs:bool;
}

// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
table BidirectionalSequenceRNNOptions {
  time_major:bool;
  fused_activation_function:ActivationFunctionType;
  merge_outputs: bool;
  asymmetric_quantize_inputs:bool;
}

enum FullyConnectedOptionsWeightsFormat: byte {
  DEFAULT = 0,
  SHUFFLED4x16INT8 = 1,
}

// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
table FullyConnectedOptions {
  // Parameters for FullyConnected version 1 or above.
  fused_activation_function:ActivationFunctionType;

  // Parameters for FullyConnected version 2 or above.
  weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;

  // Parameters for FullyConnected version 5 or above.
  // If set to true, then the number of dimension is preserved. Furthermore,
  // all but the last dimension of the input and output shapes will be equal.
  keep_num_dims: bool;

  // Parameters for FullyConnected version 7 or above.
  // If set to true, then weights-only op will use asymmetric quantization for
  // inputs.
  asymmetric_quantize_inputs: bool;
}

table SoftmaxOptions {
  beta: float;
}

// An implementation of TensorFlow concat.
table ConcatenationOptions {
  axis:int;
  fused_activation_function:ActivationFunctionType;
}

table AddOptions {
  fused_activation_function:ActivationFunctionType;
  // Parameters supported by version 4.
  pot_scale_int16:bool = true;
}

table MulOptions {
  fused_activation_function:ActivationFunctionType;
}

table L2NormOptions {
  fused_activation_function:ActivationFunctionType;
}

table LocalResponseNormalizationOptions {
  radius:int;
  bias:float;
  alpha:float;
  beta:float;
}

enum LSTMKernelType : byte {
  // Full LSTM kernel which supports peephole and projection.
  FULL = 0,
  // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
  BASIC = 1,
}

// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
table LSTMOptions {
  // Parameters for LSTM version 1 or above.
  fused_activation_function:ActivationFunctionType;
  cell_clip: float;  // Optional, 0.0 means no clipping
  proj_clip: float;  // Optional, 0.0 means no clipping

  // Parameters for LSTM version 2 or above.
  // Basic kernel is only supported in version 2 or above.
  kernel_type: LSTMKernelType = FULL;

  // Parameters for LSTM version 4 or above.
  asymmetric_quantize_inputs: bool;
}

// An implementation of TensorFlow dynamic_rnn with LSTMCell.
table UnidirectionalSequenceLSTMOptions {
  fused_activation_function:ActivationFunctionType;
  cell_clip: float;  // Optional, 0.0 means no clipping
  proj_clip: float;  // Optional, 0.0 means no clipping

  // If true then first dimension is sequence, otherwise batch.
  time_major:bool;

  // Parameter for Unidirectional Sequence LSTM version 4.
  asymmetric_quantize_inputs:bool;
}

table BidirectionalSequenceLSTMOptions {
  // Parameters supported by version 1:
  fused_activation_function:ActivationFunctionType;
  cell_clip: float;  // Optional, 0.0 means no clipping
  proj_clip: float;  // Optional, 0.0 means no clipping

  // If true, store the outputs of both directions into the first output.
  merge_outputs: bool;

  // Parameters supported by version 2:
  // If true then first dimension is sequence, otherwise batch.
  // Version 1 implementations assumed time_major to be true, so this default
  // value should never change.
  time_major: bool = true;

  // Parameters for version 3 or above.
  asymmetric_quantize_inputs:bool;
}

table ResizeBilinearOptions {
  new_height: int (deprecated);
  new_width: int (deprecated);
  align_corners: bool;
  half_pixel_centers: bool;
}

table ResizeNearestNeighborOptions {
  align_corners: bool;
  half_pixel_centers: bool;
}

// A call operation options
table CallOptions {
  // The subgraph index that needs to be called.
  subgraph:uint;
}

table PadOptions {
}

table PadV2Options {
}

table ReshapeOptions {
  new_shape:[int];
}

table SpaceToBatchNDOptions {
}

table BatchToSpaceNDOptions {
}

table SkipGramOptions {
  ngram_size: int;
  max_skip_size: int;
  include_all_ngrams: bool;
}

table SpaceToDepthOptions {
  block_size: int;
}

table DepthToSpaceOptions {
  block_size: int;
}

table SubOptions {
  fused_activation_function:ActivationFunctionType;
  // Parameters supported by version 5
  pot_scale_int16:bool = true;
}

table DivOptions {
  fused_activation_function:ActivationFunctionType;
}

table TopKV2Options {
}

enum CombinerType : byte {
  SUM = 0,
  MEAN = 1,
  SQRTN = 2,
}

table EmbeddingLookupSparseOptions {
  combiner:CombinerType;
}

table GatherOptions {
  axis: int;
}

table TransposeOptions {
}

table ExpOptions {
}

table CosOptions {
}

table ReducerOptions {
  keep_dims: bool;
}

table SqueezeOptions {
  squeeze_dims:[int];
}

table SplitOptions {
  num_splits: int;
}

table SplitVOptions {
  num_splits: int;
}

table StridedSliceOptions {
  begin_mask: int;
  end_mask: int;
  ellipsis_mask: int;
  new_axis_mask: int;
  shrink_axis_mask: int;
}

table LogSoftmaxOptions {
}

table CastOptions {
  in_data_type: TensorType;
  out_data_type: TensorType;
}

table DequantizeOptions {
}

table MaximumMinimumOptions {
}

table TileOptions {
}

table ArgMaxOptions {
  output_type : TensorType;
}

table ArgMinOptions {
  output_type : TensorType;
}

table GreaterOptions {
}

table GreaterEqualOptions {
}

table LessOptions {
}

table LessEqualOptions {
}

table NegOptions {
}

table SelectOptions {
}

table SliceOptions {
}

table TransposeConvOptions {
  padding:Padding;
  stride_w:int;
  stride_h:int;
}

table ExpandDimsOptions {
}

table SparseToDenseOptions {
  validate_indices:bool;
}

table EqualOptions {
}

table NotEqualOptions {
}

table ShapeOptions {
  // Optional output type of the operation (int32 or int64). Defaults to int32.
  out_type : TensorType;
}

table RankOptions {
}

table PowOptions {
}

table FakeQuantOptions {
  // Parameters supported by version 1:
  min:float;
  max:float;
  num_bits:int;

  // Parameters supported by version 2:
  narrow_range:bool;
}

table PackOptions {
  values_count:int;
  axis:int;
}

table LogicalOrOptions {
}

table OneHotOptions {
  axis:int;
}

table AbsOptions {
}

table HardSwishOptions {
}

table LogicalAndOptions {
}

table LogicalNotOptions {
}

table UnpackOptions {
  num:int;
  axis:int;
}

table FloorDivOptions {
}

table SquareOptions {
}

table ZerosLikeOptions {
}

table FillOptions {
}

table FloorModOptions {
}

table RangeOptions {
}

table LeakyReluOptions {
  alpha:float;
}

table SquaredDifferenceOptions {
}

enum MirrorPadMode : byte {
  // Doesn't include borders.
  REFLECT = 0,
  // Includes borders.
  SYMMETRIC = 1,
}

table MirrorPadOptions {
  mode:MirrorPadMode;
}

table UniqueOptions {
  idx_out_type:TensorType = INT32;
}

table ReverseV2Options {
}

table AddNOptions {
}

table GatherNdOptions {
}

table WhereOptions {
}

table ReverseSequenceOptions {
  seq_dim:int;
  batch_dim:int = 0;
}

table MatrixDiagOptions {
}

table QuantizeOptions {
}

table MatrixSetDiagOptions {
}

table IfOptions {
  then_subgraph_index:int;
  else_subgraph_index:int;
}

table WhileOptions {
  cond_subgraph_index:int;
  body_subgraph_index:int;
}

table NonMaxSuppressionV4Options {
}

table NonMaxSuppressionV5Options {
}

table ScatterNdOptions {
}

table SelectV2Options {
}

table DensifyOptions {
}

table SegmentSumOptions {
}

table BatchMatMulOptions {
  adj_x:bool;
  adj_y:bool;
}

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
  // This field is for backward compatibility. This field will be used when
  // the value of the extended builtin_code field is less than
  // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
  deprecated_builtin_code:byte;
  custom_code:string;

  // The version of the operator. The version needs to be bumped whenever new
  // parameters are introduced into an op.
  version:int = 1;

  // This field is introduced for resolving op builtin code shortage problem
  // (the original BuiltinOperator enum field was represented as a byte).
  // This field will be used when the value of the extended builtin_code field
  // is greater than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
  builtin_code:BuiltinOperator;
}

enum CustomOptionsFormat : byte {
  FLEXBUFFERS = 0,
}

// An operator takes tensors as inputs and outputs. The type of operation being
// performed is determined by an index into the list of valid OperatorCodes,
// while the specifics of each operation are configured using builtin_options
// or custom_options.
table Operator {
  // Index into the operator_codes array. Using an integer here avoids
  // complicated map lookups.
  opcode_index:uint;

  // Optional inputs are indicated by -1.
  inputs:[int];
  outputs:[int];

  builtin_options:BuiltinOptions;
  custom_options:[ubyte];
  custom_options_format:CustomOptionsFormat;

  // A list of booleans indicating the input tensors which are being mutated by
  // this operator (e.g. used by RNN and LSTM).
  // For example, if the "inputs" array refers to 5 tensors and the second and
  // fifth are mutable variables, then this list will contain
  // [false, true, false, false, true].
  //
  // If the list is empty, no variable is mutated in this operator.
  // The list either has the same length as `inputs`, or is empty.
  mutating_variable_inputs:[bool];

  // A list of indices to the subgraph's "tensors" that are internal to an Op.
  // Internal tensors are those that do not flow in or out of the operation,
  // but instead are part of internal computation. As such, the operation's
  // implementation may manage its memory more efficiently. They are needed
  // however (i.e. not just an implementation detail) since they are part of the
  // computation, which may require relevant metadata such as quantization
  // parameters.
  intermediates:[int];
}

// Defines a subgraph, which typically represents an entire model.
// (The root type of this schema is Model, declared at the end of the file;
// a Model holds a list of SubGraphs.)
table SubGraph {
  // A list of all tensors used in this subgraph.
  tensors:[Tensor];

  // Indices of the tensors that are inputs into this subgraph. Note this is
  // the list of non-static tensors that feed into the subgraph for inference.
  inputs:[int];

  // Indices of the tensors that are outputs out of this subgraph. Note this is
  // the list of output tensors that are considered the product of the
  // subgraph's inference.
  outputs:[int];

  // All operators, in execution order.
  operators:[Operator];

  // Name of this subgraph (used for debugging).
  name:string;
}

// Table of raw data buffers (used for constant tensors). Referenced by tensors
// by index. The generous alignment accommodates mmap-friendly data structures.
table Buffer {
  data:[ubyte] (force_align: 16);
}

table Metadata {
  // A human readable string to uniquely identify a Metadata.
  name:string;
  // An index to the buffers table.
  buffer:uint;
}

table Model {
  // Version of the schema.
  version:uint;

  // A list of all operator codes used in this model. This is
  // kept in order because operators carry an index into this
  // vector.
  operator_codes:[OperatorCode];

  // All the subgraphs of the model. The 0th is assumed to be the main
  // model.
  subgraphs:[SubGraph];

  // A description of the model.
  description:string;

  // Buffers of the model.
  // Note the 0th entry of this array must be an empty buffer (sentinel).
  // This is a convention so that tensors without a buffer can provide 0 as
  // their buffer.
  buffers:[Buffer];

  // Metadata about the model. Indirects into the existing buffers list.
  // Deprecated, prefer to use the metadata field.
  metadata_buffer:[int];

  // Metadata about the model.
  metadata:[Metadata];
}

root_type Model;