name: Trainer
description: |-
  A TFX component to train a TensorFlow model.

  The Trainer component is used to train and eval a model using given inputs and
  a user-supplied estimator. This component includes a custom driver to
  optionally grab a previous model to warm start from.

  ## Providing an estimator
  The TFX executor will use the estimator provided in the `module_file` file
  to train the model. The Trainer executor will look specifically for the
  `trainer_fn()` function within that file. Before training, the executor will
  call that function expecting the following returned as a dictionary:

  - estimator: The
    [estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
    to be used by TensorFlow to train the model.
  - train_spec: The
    [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/TrainSpec)
    to be used by the "train" part of the TensorFlow `train_and_evaluate()`
    call.
  - eval_spec: The
    [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec)
    to be used by the "eval" part of the TensorFlow `train_and_evaluate()` call.
  - eval_input_receiver_fn: The
    [configuration](https://www.tensorflow.org/tfx/model_analysis/get_started#modify_an_existing_model)
    to be used by the [ModelValidator](https://www.tensorflow.org/tfx/guide/modelval)
    component when validating the model.

  An example of `trainer_fn()` can be found in the [user-supplied
  code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py)
  of the TFX Chicago Taxi pipeline example.

  Args:
    examples: A Channel of 'Examples' type, serving as the source of examples
      that are used in training (required). May be raw or transformed.
    transform_graph: An optional Channel of 'TransformGraph' type, serving as
      the input transform graph if present.
    schema: A Channel of 'Schema' type, serving as the schema of training and
      eval data.
    module_file: A path to a python module file containing a UDF model
      definition. The module_file must implement a function named `trainer_fn`
      at its top level. The function must have the following signature.

        def trainer_fn(tf.contrib.training.HParams,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of tfma.export.EvalInputReceiver

      Exactly one of 'module_file' or 'trainer_fn' must be supplied.
    trainer_fn: A python path to a UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'trainer_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Currently only num_steps is available.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
      Currently only num_steps is available.
    custom_config: A dict which contains the training job parameters to be
      passed to Google Cloud ML Engine. For the full set of parameters
      supported by Google Cloud ML Engine, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job

  Returns:
    model: Optional 'Model' channel for the result of exported models.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and 'trainer_fn' is supplied.
inputs:
- {name: examples, type: Examples}
- {name: schema, type: Schema}
- name: train_args
  type:
    JsonObject: {data_type: 'proto:tfx.components.trainer.TrainArgs'}
- name: eval_args
  type:
    JsonObject: {data_type: 'proto:tfx.components.trainer.EvalArgs'}
- {name: module_file, type: String, optional: true}
- {name: trainer_fn, type: String, optional: true}
- {name: custom_config, type: JsonObject, optional: true}
- {name: transform_graph, type: TransformGraph, optional: true}
- {name: base_model, type: Model, optional: true}
- {name: hyperparameters, type: HyperParameters, optional: true}
outputs:
- {name: model, type: Model}
implementation:
  container:
    image: tensorflow/tfx:0.21.4
    command:
    - python3
    - -u
    - -c
    - |
      def _make_parent_dirs_and_return_path(file_path: str):
          import os
          os.makedirs(os.path.dirname(file_path), exist_ok=True)
          return file_path

      def Trainer(
          examples_path,
          schema_path,
          model_path,
          train_args,
          eval_args,
          module_file = None,
          trainer_fn = None,
          custom_config = None,
          transform_graph_path = None,
          base_model_path = None,
          hyperparameters_path = None,
      ):
          """A TFX component to train a TensorFlow model.

          The Trainer component is used to train and eval a model using given
          inputs and a user-supplied estimator. This component includes a custom
          driver to optionally grab a previous model to warm start from.

          See the component description above for the `trainer_fn` contract, the
          meaning of each argument, and the exceptions raised.
          """
          from tfx.components.trainer.component import Trainer as component_class

          # Generated code
          import json
          import os
          import tensorflow
          from google.protobuf import json_format, message
          from tfx.types import Artifact, channel_utils, artifact_utils

          arguments = locals().copy()

          component_class_args = {}

          for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
              argument_value_obj = argument_value = arguments.get(name, None)
              if argument_value is None:
                  continue
              parameter_type = execution_parameter.type
              if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message):
                  # Maybe FIX: execution_parameter.type can also be a tuple
                  argument_value_obj = parameter_type()
                  json_format.Parse(argument_value, argument_value_obj)
              component_class_args[name] = argument_value_obj

          for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
              artifact_path = arguments[name + '_path']
              if artifact_path:
                  artifact = channel_parameter.type()
                  artifact.uri = artifact_path + '/'  # ?
                  if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                      # Recovering splits
                      subdirs = tensorflow.io.gfile.listdir(artifact_path)
                      artifact.split_names = artifact_utils.encode_split_names(sorted(subdirs))
                  component_class_args[name] = channel_utils.as_channel([artifact])

          component_class_instance = component_class(**component_class_args)

          input_dict = {name: channel.get() for name, channel in component_class_instance.inputs.get_all().items()}
          output_dict = {name: channel.get() for name, channel in component_class_instance.outputs.get_all().items()}
          exec_properties = component_class_instance.exec_properties

          # Generating paths for output artifacts
          for name, artifacts in output_dict.items():
              base_artifact_path = arguments[name + '_path']
              # Are there still cases where an output channel has multiple artifacts?
              for idx, artifact in enumerate(artifacts):
                  subdir = str(idx + 1) if idx > 0 else ''
                  artifact.uri = os.path.join(base_artifact_path, subdir)  # Ends with '/'

          print('component instance: ' + str(component_class_instance))

          #executor = component_class.EXECUTOR_SPEC.executor_class()  # Same
          executor = component_class_instance.executor_spec.executor_class()
          executor.Do(
              input_dict=input_dict,
              output_dict=output_dict,
              exec_properties=exec_properties,
          )

      import json
      import argparse
      _parser = argparse.ArgumentParser(prog='Trainer', description='A TFX component to train a TensorFlow model.\n\nThe Trainer component is used to train and eval a model using given inputs and a user-supplied estimator. See the component description for details.')
      _parser.add_argument("--examples", dest="examples_path", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--schema", dest="schema_path", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--train-args", dest="train_args", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--eval-args", dest="eval_args", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--module-file", dest="module_file", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--trainer-fn", dest="trainer_fn", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--custom-config", dest="custom_config", type=json.loads, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--transform-graph", dest="transform_graph_path", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--base-model", dest="base_model_path", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--hyperparameters", dest="hyperparameters_path", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--model", dest="model_path", type=_make_parent_dirs_and_return_path, required=True, default=argparse.SUPPRESS)
      _parsed_args = vars(_parser.parse_args())
      _output_files = _parsed_args.pop("_output_paths", [])

      _outputs = Trainer(**_parsed_args)

      _output_serializers = [
      ]

      import os
      for idx, output_file in enumerate(_output_files):
          try:
              os.makedirs(os.path.dirname(output_file))
          except OSError:
              pass
          with open(output_file, 'w') as f:
              f.write(_output_serializers[idx](_outputs[idx]))
    args:
    - --examples
    - {inputPath: examples}
    - --schema
    - {inputPath: schema}
    - --train-args
    - {inputValue: train_args}
    - --eval-args
    - {inputValue: eval_args}
    - if:
        cond: {isPresent: module_file}
        then:
        - --module-file
        - {inputValue: module_file}
    - if:
        cond: {isPresent: trainer_fn}
        then:
        - --trainer-fn
        - {inputValue: trainer_fn}
    - if:
        cond: {isPresent: custom_config}
        then:
        - --custom-config
        - {inputValue: custom_config}
    - if:
        cond: {isPresent: transform_graph}
        then:
        - --transform-graph
        - {inputPath: transform_graph}
    - if:
        cond: {isPresent: base_model}
        then:
        - --base-model
        - {inputPath: base_model}
    - if:
        cond: {isPresent: hyperparameters}
        then:
        - --hyperparameters
        - {inputPath: hyperparameters}
    - --model
    - {outputPath: model}
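# Note: the `args` section maps `inputPath`/`outputPath` entries to file-path flags
# and `inputValue` entries to literal-value flags on the command line. A trimmed-down,
# stand-alone sketch of how the generated entrypoint parses those flags (only a few
# flags are reproduced, and the sample paths/values are placeholders, not real data):

```python
import argparse

# Reduced reproduction of the entrypoint's argparse setup: path-typed inputs
# use a *_path dest, while train_args/eval_args arrive as JSON-serialized
# TrainArgs/EvalArgs protos passed by value.
parser = argparse.ArgumentParser(prog='Trainer')
parser.add_argument('--examples', dest='examples_path', type=str, required=True)
parser.add_argument('--schema', dest='schema_path', type=str, required=True)
parser.add_argument('--train-args', dest='train_args', type=str, required=True)
parser.add_argument('--eval-args', dest='eval_args', type=str, required=True)
parser.add_argument('--model', dest='model_path', type=str, required=True)

# Placeholder argv, shaped like what the container would receive.
args = vars(parser.parse_args([
    '--examples', '/tmp/inputs/examples',
    '--schema', '/tmp/inputs/schema',
    '--train-args', '{"num_steps": 10000}',
    '--eval-args', '{"num_steps": 5000}',
    '--model', '/tmp/outputs/model',
]))
print(args['train_args'])  # {"num_steps": 10000}
```

# The dest names match the `Trainer(...)` function's parameters, which is why the
# entrypoint can call `Trainer(**_parsed_args)` directly after parsing.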
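# Note: the "Recovering splits" step in the generated code lists the artifact
# directory's subdirectories and records them as the artifact's split names. A
# stdlib-only sketch of that idea; `encode_split_names` below is a hypothetical
# stand-in for `tfx.types.artifact_utils.encode_split_names` (assumed here to
# JSON-encode the list), and the real code lists via `tensorflow.io.gfile.listdir`:

```python
import json
import os
import tempfile

def encode_split_names(split_names):
    # Hypothetical stand-in for tfx.types.artifact_utils.encode_split_names;
    # assumed to serialize the split list as a JSON string.
    return json.dumps(split_names)

# Simulate an Examples artifact directory with one subdirectory per split.
artifact_path = tempfile.mkdtemp()
for split in ('train', 'eval'):
    os.makedirs(os.path.join(artifact_path, split))

# Mirror of the entrypoint's logic: sorted subdirectory names become splits.
subdirs = sorted(os.listdir(artifact_path))
split_names = encode_split_names(subdirs)
print(split_names)  # ["eval", "train"]
```

# This is why a path-typed Examples input is enough for the wrapper to rebuild a
# Channel: the split structure is recovered from the directory layout itself.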