# Container component that runs the TFX `Trainer` component.
# Artifact locations are exchanged as URI strings (the `*_uri` inputs) and the
# model location is echoed back through the `model_uri` output.
# NOTE(review): the layout (inputValue / outputPath / isPresent placeholders)
# looks like the Kubeflow Pipelines component spec format - confirm.
name: Trainer
inputs:
- {name: examples_uri, type: ExamplesUri}
- {name: schema_uri, type: SchemaUri}
- {name: output_model_uri, type: ModelUri}
# train_args / eval_args are JSON strings; the script parses them into the
# protobuf message type required by the matching Trainer execution parameter.
- name: train_args
  type:
    JsonObject: {data_type: 'proto:tfx.components.trainer.TrainArgs'}
- name: eval_args
  type:
    JsonObject: {data_type: 'proto:tfx.components.trainer.EvalArgs'}
- {name: transform_graph_uri, type: TransformGraphUri, optional: true}
- {name: base_model_uri, type: ModelUri, optional: true}
- {name: hyperparameters_uri, type: HyperParametersUri, optional: true}
- {name: module_file, type: String, optional: true}
- {name: run_fn, type: String, optional: true}
- {name: trainer_fn, type: String, optional: true}
- {name: custom_config, type: JsonObject, optional: true}
- {name: beam_pipeline_args, type: JsonArray, optional: true}
outputs:
- {name: model_uri, type: ModelUri}
implementation:
  container:
    image: tensorflow/tfx:0.21.4
    command:
    - python3
    - -u
    - -c
    - |
      # Builds the TFX Trainer component from the given arguments, invokes its
      # executor and returns the output model URI as a 1-tuple.
      # NOTE: parameter names are load-bearing - the body looks arguments up by
      # name ('<name>', '<name>_uri', 'output_<name>_uri') via locals().
      def Trainer(
          examples_uri,
          schema_uri,
          output_model_uri,
          train_args,
          eval_args,
          transform_graph_uri = None,
          base_model_uri = None,
          hyperparameters_uri = None,
          module_file = None,
          run_fn = None,
          trainer_fn = None,
          custom_config = None,
          beam_pipeline_args = None,
      ):
          from tfx.components import Trainer as component_class

          #Generated code
          import json
          import os
          import tempfile
          import tensorflow
          from google.protobuf import json_format, message
          from tfx.types import channel_utils, artifact_utils
          from tfx.components.base import base_executor

          # Snapshot of every local name defined so far (the function
          # parameters plus the names imported above). Do not rename locals or
          # move statements above this line without checking the lookups below.
          arguments = locals().copy()

          component_class_args = {}

          # Execution parameters: forward every non-None argument whose name
          # matches a parameter declared on the component spec.
          for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
              argument_value = arguments.get(name, None)
              if argument_value is None:
                  continue
              parameter_type = execution_parameter.type
              if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message):
                  # Protobuf-typed parameters arrive as JSON text and are
                  # parsed into a fresh message instance.
                  argument_value_obj = parameter_type()
                  json_format.Parse(argument_value, argument_value_obj)
              else:
                  argument_value_obj = argument_value
              component_class_args[name] = argument_value_obj

          # Input artifacts: for each input declared on the spec, look for a
          # '<name>_uri' (or '<name>_path') argument and wrap the resulting
          # artifact in a single-artifact channel. Inputs without a value are
          # simply not passed to the component.
          for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
              artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
              if artifact_path:
                  artifact = channel_parameter.type()
                  artifact.uri = artifact_path.rstrip('/') + '/'  # Some TFX components require that the artifact URIs end with a slash
                  if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                      # Recovering splits
                      # The split names are taken from the (sorted) entries
                      # listed directly under the artifact location.
                      subdirs = tensorflow.io.gfile.listdir(artifact_path)
                      # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
                      subdirs = [subdir.rstrip('/') for subdir in subdirs]
                      artifact.split_names = artifact_utils.encode_split_names(sorted(subdirs))
                  component_class_args[name] = channel_utils.as_channel([artifact])

          component_class_instance = component_class(**component_class_args)

          input_dict = channel_utils.unwrap_channel_dict(component_class_instance.inputs.get_all())
          output_dict = channel_utils.unwrap_channel_dict(component_class_instance.outputs.get_all())
          exec_properties = component_class_instance.exec_properties

          # Generating paths for output artifacts
          # The base location comes from 'output_<name>_uri' (or '<name>_path').
          # The first artifact uses the base location itself; any further
          # artifact gets a numeric sub-directory ('2', '3', ...).
          for name, artifacts in output_dict.items():
              base_artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(name + '_path')
              if base_artifact_path:
                  # Are there still cases where output channel has multiple artifacts?
                  for idx, artifact in enumerate(artifacts):
                      subdir = str(idx + 1) if idx > 0 else ''
                      artifact.uri = os.path.join(base_artifact_path, subdir)  # Ends with '/'

          print('component instance: ' + str(component_class_instance))

          # Workaround for a TFX+Beam bug to make DataflowRunner work.
          # Remove after the next release that has https://github.com/tensorflow/tfx/commit/ddb01c02426d59e8bd541e3fd3cbaaf68779b2df
          import tfx
          tfx.version.__version__ += 'dev'

          # Instantiate the component's executor class and invoke it, using the
          # system temporary directory as its tmp_dir.
          executor_context = base_executor.BaseExecutor.Context(
              beam_pipeline_args=beam_pipeline_args,
              tmp_dir=tempfile.gettempdir(),
              unique_id='tfx_component',
          )
          executor = component_class_instance.executor_spec.executor_class(executor_context)
          executor.Do(
              input_dict=input_dict,
              output_dict=output_dict,
              exec_properties=exec_properties,
          )

          # The output value is the same URI that was passed in as the
          # destination for the model.
          return (output_model_uri, )

      # Command-line entry point. Every option is parsed into a keyword
      # argument for Trainer(). default=argparse.SUPPRESS keeps omitted
      # optional flags out of the namespace, so Trainer()'s own defaults apply.
      import json
      import argparse
      _parser = argparse.ArgumentParser(prog='Trainer', description='')
      _parser.add_argument("--examples-uri", dest="examples_uri", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--schema-uri", dest="schema_uri", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--output-model-uri", dest="output_model_uri", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--train-args", dest="train_args", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--eval-args", dest="eval_args", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--transform-graph-uri", dest="transform_graph_uri", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--base-model-uri", dest="base_model_uri", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--hyperparameters-uri", dest="hyperparameters_uri", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--module-file", dest="module_file", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--run-fn", dest="run_fn", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--trainer-fn", dest="trainer_fn", type=str, required=False, default=argparse.SUPPRESS)
      # custom_config and beam_pipeline_args are JSON-decoded while parsing.
      _parser.add_argument("--custom-config", dest="custom_config", type=json.loads, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--beam-pipeline-args", dest="beam_pipeline_args", type=json.loads, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("----output-paths", dest="_output_paths", type=str, nargs=1)
      _parsed_args = vars(_parser.parse_args())
      # Removed from the kwargs before calling Trainer(); holds the file
      # path(s) that the returned values are written to.
      _output_files = _parsed_args.pop("_output_paths", [])

      _outputs = Trainer(**_parsed_args)

      # One serializer per returned value, matched to the output files by index.
      _output_serializers = [
          str,

      ]

      import os
      # Write each returned value to its output file. The parent directory is
      # created first; any OSError from makedirs (e.g. the directory already
      # exists) is ignored.
      for idx, output_file in enumerate(_output_files):
          try:
              os.makedirs(os.path.dirname(output_file))
          except OSError:
              pass
          with open(output_file, 'w') as f:
              f.write(_output_serializers[idx](_outputs[idx]))
    args:
    # Required inputs are always passed on the command line.
    - --examples-uri
    - {inputValue: examples_uri}
    - --schema-uri
    - {inputValue: schema_uri}
    - --output-model-uri
    - {inputValue: output_model_uri}
    - --train-args
    - {inputValue: train_args}
    - --eval-args
    - {inputValue: eval_args}
    # Optional inputs are only added to the command line when a value is
    # present.
    - if:
        cond: {isPresent: transform_graph_uri}
        then:
        - --transform-graph-uri
        - {inputValue: transform_graph_uri}
    - if:
        cond: {isPresent: base_model_uri}
        then:
        - --base-model-uri
        - {inputValue: base_model_uri}
    - if:
        cond: {isPresent: hyperparameters_uri}
        then:
        - --hyperparameters-uri
        - {inputValue: hyperparameters_uri}
    - if:
        cond: {isPresent: module_file}
        then:
        - --module-file
        - {inputValue: module_file}
    - if:
        cond: {isPresent: run_fn}
        then:
        - --run-fn
        - {inputValue: run_fn}
    - if:
        cond: {isPresent: trainer_fn}
        then:
        - --trainer-fn
        - {inputValue: trainer_fn}
    - if:
        cond: {isPresent: custom_config}
        then:
        - --custom-config
        - {inputValue: custom_config}
    - if:
        cond: {isPresent: beam_pipeline_args}
        then:
        - --beam-pipeline-args
        - {inputValue: beam_pipeline_args}
    # File path for the model_uri output; the script writes the returned URI
    # string to it.
    - '----output-paths'
    - {outputPath: model_uri}