#!/usr/bin/env python3
import argparse
import gzip
import io
import pickle
import sys
import urllib.request

import boto3
from botocore.exceptions import ClientError
import numpy as np

DATA_URL = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'


def main():
    s3_bucket, s3_prefix = parse_args()
    for data_partition_name, csv_data in convert_data_to_csv(get_data_partitions(DATA_URL)):
        key = f'{s3_prefix}/{data_partition_name}/examples'
        url = f's3://{s3_bucket}/{key}'
        print(f'Uploading {csv_data.getbuffer().nbytes} bytes to {url}')
        try:
            boto3.Session().resource('s3').Bucket(s3_bucket).Object(key).upload_fileobj(csv_data)
        except ClientError as e:
            print(f'Unable to upload {url}: {e}')
            return 1


def parse_args():
    parser = argparse.ArgumentParser(
        description='Helper script that splits the MNIST dataset into train, validation, '
                    'and test sets, then uploads them in CSV format to the specified S3 bucket')
    parser.add_argument('--s3-bucket', type=str, required=True)
    parser.add_argument('--s3-prefix', type=str, required=True)
    args = parser.parse_args()
    return args.s3_bucket, args.s3_prefix


def get_data_partitions(url):
    ''' Download and unpickle the dataset, returning its train/validation/test partitions. '''
    print(f'Downloading dataset from {url}')
    with gzip.GzipFile(fileobj=urllib.request.urlopen(url), mode='r') as gzip_file:
        train_set, valid_set, test_set = pickle.load(gzip_file, encoding='latin1')
    return [('train', train_set), ('validation', valid_set), ('test', test_set)]


def convert_data_to_csv(data_partitions):
    ''' Convert the NumPy data partitions to CSV format, yielding in-memory file objects. '''
    for data_partition_name, data_partition in data_partitions:
        print(f'{data_partition_name}: {data_partition[0].shape} {data_partition[1].shape}')
        labels = [t.tolist() for t in data_partition[1]]
        features = [t.tolist() for t in data_partition[0]]
        if data_partition_name != 'test':
            # Train/validation rows carry the label as the first column.
            examples = np.insert(features, 0, labels, axis=1)
        else:
            # The test partition is uploaded without labels.
            examples = features
        with io.BytesIO() as f:
            np.savetxt(f, examples, delimiter=',')
            f.seek(0)
            yield data_partition_name, f


if __name__ == '__main__':
    sys.exit(main())
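
# Example invocation (a sketch only: the script filename, bucket name, and prefix
# below are placeholders, and AWS credentials are assumed to be available through
# boto3's default credential chain):
#
#   python3 upload_mnist_to_s3.py --s3-bucket my-example-bucket --s3-prefix mnist
#
# With those placeholder arguments, the script would write
# s3://my-example-bucket/mnist/train/examples,
# s3://my-example-bucket/mnist/validation/examples, and
# s3://my-example-bucket/mnist/test/examples as comma-separated files.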