# Kubeflow Pipelines component wrapping the TFX SchemaGen component.
# Inputs and outputs are exchanged as URIs; the container runs the inline
# Python program below and writes the resulting schema URI to the output path.
name: SchemaGen
inputs:
- {name: statistics_uri, type: ExampleStatisticsUri}
- {name: output_schema_uri, type: SchemaUri}
- {name: infer_feature_shape, type: Boolean, optional: true}
- {name: beam_pipeline_args, type: JsonArray, optional: true}
outputs:
- {name: schema_uri, type: SchemaUri}
implementation:
  container:
    image: tensorflow/tfx:0.21.4
    command:
    - python3
    - -u
    - -c
    - |
      def SchemaGen(
          statistics_uri,
          output_schema_uri,
          infer_feature_shape = None,
          beam_pipeline_args = None,
      ):
          """Instantiate the TFX SchemaGen component from URIs and run its executor.

          Args:
              statistics_uri: Location of the input artifact; matched to the
                  component input whose name plus '_uri' equals this parameter name.
              output_schema_uri: Base location assigned to the output artifact whose
                  name, wrapped as 'output_<name>_uri', equals this parameter name.
              infer_feature_shape: Optional value forwarded to the component when it
                  declares an execution parameter of the same name.
              beam_pipeline_args: Optional list handed to the executor context.

          Returns:
              A one-element tuple containing output_schema_uri.
          """
          from tfx.components import SchemaGen as component_class

          #Generated code
          import json
          import os
          import tempfile
          import tensorflow
          from google.protobuf import json_format, message
          from tfx.types import channel_utils, artifact_utils
          from tfx.components.base import base_executor

          # Snapshot of the function arguments (and the names imported above) so
          # they can be looked up by name below. Do not rename locals or move this
          # statement: the lookups depend on exactly what is in scope here.
          arguments = locals().copy()

          component_class_args = {}

          # Execution parameters: pass through any argument that was supplied,
          # parsing JSON into a protobuf message when the parameter type is one.
          for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
              argument_value = arguments.get(name, None)
              if argument_value is None:
                  continue
              parameter_type = execution_parameter.type
              if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message):
                  argument_value_obj = parameter_type()
                  json_format.Parse(argument_value, argument_value_obj)
              else:
                  argument_value_obj = argument_value
              component_class_args[name] = argument_value_obj

          # Input channels: wrap each supplied '<name>_uri' / '<name>_path' argument
          # in an artifact of the declared type.
          for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
              artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
              if artifact_path:
                  artifact = channel_parameter.type()
                  artifact.uri = artifact_path.rstrip('/') + '/'  # Some TFX components require that the artifact URIs end with a slash
                  if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                      # Recovering splits
                      subdirs = tensorflow.io.gfile.listdir(artifact_path)
                      # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
                      subdirs = [subdir.rstrip('/') for subdir in subdirs]
                      artifact.split_names = artifact_utils.encode_split_names(sorted(subdirs))
                  component_class_args[name] = channel_utils.as_channel([artifact])

          component_class_instance = component_class(**component_class_args)

          input_dict = channel_utils.unwrap_channel_dict(component_class_instance.inputs.get_all())
          output_dict = channel_utils.unwrap_channel_dict(component_class_instance.outputs.get_all())
          exec_properties = component_class_instance.exec_properties

          # Generating paths for output artifacts
          for name, artifacts in output_dict.items():
              base_artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(name + '_path')
              if base_artifact_path:
                  # Are there still cases where output channel has multiple artifacts?
                  for idx, artifact in enumerate(artifacts):
                      # First artifact goes to the base path, later ones to numbered subdirs.
                      subdir = str(idx + 1) if idx > 0 else ''
                      artifact.uri = os.path.join(base_artifact_path, subdir)  # Ends with '/'

          print('component instance: ' + str(component_class_instance))

          # Workaround for a TFX+Beam bug to make DataflowRunner work.
          # Remove after the next release that has https://github.com/tensorflow/tfx/commit/ddb01c02426d59e8bd541e3fd3cbaaf68779b2df
          import tfx
          tfx.version.__version__ += 'dev'

          executor_context = base_executor.BaseExecutor.Context(
              beam_pipeline_args=beam_pipeline_args,
              tmp_dir=tempfile.gettempdir(),
              unique_id='tfx_component',
          )
          executor = component_class_instance.executor_spec.executor_class(executor_context)
          executor.Do(
              input_dict=input_dict,
              output_dict=output_dict,
              exec_properties=exec_properties,
          )

          return (output_schema_uri, )

      def _deserialize_bool(s) -> bool:
          """Convert a command-line string such as 'true' or '0' to a bool."""
          # NOTE(review): distutils is removed in Python 3.12 (PEP 632); revisit
          # this import if the container image is ever upgraded.
          from distutils.util import strtobool
          return strtobool(s) == 1

      import json
      import argparse
      _parser = argparse.ArgumentParser(prog='SchemaGen', description='')
      _parser.add_argument("--statistics-uri", dest="statistics_uri", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--output-schema-uri", dest="output_schema_uri", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--infer-feature-shape", dest="infer_feature_shape", type=_deserialize_bool, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--beam-pipeline-args", dest="beam_pipeline_args", type=json.loads, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("----output-paths", dest="_output_paths", type=str, nargs=1)
      _parsed_args = vars(_parser.parse_args())
      # This option has no SUPPRESS default, so argparse always stores the key
      # (as None when the flag is omitted) and a pop() fallback is never used.
      # Normalise None to an empty list so the loop below is simply skipped.
      _output_files = _parsed_args.pop("_output_paths", None) or []

      _outputs = SchemaGen(**_parsed_args)

      # One serializer per declared output, in the same order as the output paths.
      _output_serializers = [
          str,

      ]

      import os
      for idx, output_file in enumerate(_output_files):
          # Create the parent directory if needed; an already-existing directory
          # is fine, while genuine failures surface here instead of being hidden.
          _output_dir = os.path.dirname(output_file)
          if _output_dir:
              os.makedirs(_output_dir, exist_ok=True)
          with open(output_file, 'w') as f:
              f.write(_output_serializers[idx](_outputs[idx]))
    args:
    - --statistics-uri
    - {inputValue: statistics_uri}
    - --output-schema-uri
    - {inputValue: output_schema_uri}
    - if:
        cond: {isPresent: infer_feature_shape}
        then:
        - --infer-feature-shape
        - {inputValue: infer_feature_shape}
    - if:
        cond: {isPresent: beam_pipeline_args}
        then:
        - --beam-pipeline-args
        - {inputValue: beam_pipeline_args}
    - '----output-paths'
    - {outputPath: schema_uri}