# Component interface: trains a classifier using a Keras sequential model.
# The file layout appears to follow the Kubeflow Pipelines component spec
# (TODO confirm against the consuming tool).
name: Keras - Train classifier
description: Trains classifier using Keras sequential model

inputs:
  - name: training_set_features_path
    type:
      GcsPath:
        data_type: TSV
    description: 'Local or GCS path to the training set features table.'

  - name: training_set_labels_path
    type:
      GcsPath:
        data_type: TSV
    description: >-
      Local or GCS path to the training set labels (each label is a class
      index from 0 to num-classes - 1).

  # TODO: Remove this URI input and move it to outputs once artifact
  # passing support is checked in.
  - name: output_model_uri
    type:
      GcsPath:
        data_type: Keras model
    description: >-
      Local or GCS path specifying where to save the trained model. The
      model (topology + weights + optimizer state) is saved in HDF5 format
      and can be loaded back by calling keras.models.load_model

  # NOTE(review): typed as GcsPath, yet the description calls it a JSON
  # string and it is passed by value to --model-config-json. Confirm which
  # of the two is intended before changing the type.
  - name: model_config
    type:
      GcsPath:
        data_type: Keras model config json
    description: >-
      JSON string containing the serialized model structure. Can be
      obtained by calling model.to_json() on a Keras model.

  - name: number_of_classes
    type: Integer
    description: 'Number of classifier classes.'

  # Defaults below are deliberately kept as quoted strings.
  - name: number_of_epochs
    type: Integer
    default: '100'
    description: >-
      Number of epochs to train the model. An epoch is an iteration over
      the entire `x` and `y` data provided.

  - name: batch_size
    type: Integer
    default: '32'
    description: 'Number of samples per gradient update.'

outputs:
  # TODO: Remove the URI-style output and make it a proper output once
  # artifact passing support is checked in. It intentionally shares its
  # name with the output_model_uri input until then.
  - name: output_model_uri
    type:
      GcsPath:
        data_type: Keras model
    description: >-
      GCS path where the trained model has been saved. The model
      (topology + weights + optimizer state) is saved in HDF5 format and
      can be loaded back by calling keras.models.load_model
# Container invocation: runs train.py with one CLI flag per component input,
# plus a final flag that receives the path for the output_model_uri output.
implementation:
  container:
    # NOTE(review): the image reference carries no tag or digest; consider
    # pinning one so runs are reproducible.
    image: gcr.io/ml-pipeline/sample/keras/train_classifier
    command:
      - python3
      - /pipelines/component/src/train.py
    # Each flag is immediately followed by the placeholder that supplies
    # its value; order is preserved from the original argument list.
    args:
      - --training-set-features-path
      - {inputValue: training_set_features_path}

      - --training-set-labels-path
      - {inputValue: training_set_labels_path}

      - --output-model-path
      - {inputValue: output_model_uri}

      - --model-config-json
      - {inputValue: model_config}

      - --num-classes
      - {inputValue: number_of_classes}

      - --num-epochs
      - {inputValue: number_of_epochs}

      - --batch-size
      - {inputValue: batch_size}

      # Presumably train.py writes the saved model's URI into this file so
      # it can be surfaced as the output_model_uri output (TODO confirm).
      - --output-model-path-file
      - {outputPath: output_model_uri}