In [None]:
# Install the PyAthena DB-API driver used by `from pyathena import connect` in the next cell.
# Invoking pip through sys.executable installs into the same interpreter the kernel runs on.
# NOTE(review): the version is unpinned; consider pinning (e.g. PyAthena==X.Y.Z) for reproducible runs.
import sys
!{sys.executable} -m pip install PyAthena

In [None]:
from pyathena import connect 
import pandas as pd
import sagemaker
import matplotlib.pyplot as plt

# TODO: set `bucket` to an S3 bucket you own. It receives Athena query results below,
# and later the model's training data and output artifacts.
bucket = 'athena-federation-test'

# S3 prefix where Athena writes the result files of every query run on this connection.
output_location = 's3://{}/athena-ml/'.format(bucket)

connection = connect(
    region_name='us-east-1',
    s3_staging_dir=output_location,
)

In [None]:
# DDL for an external Athena table over the CSV data under s3://athena-examples-us-east-1/workshop-ml/.
# `time` is kept as a raw string; `number` is the value the model is trained on.
# IF NOT EXISTS makes the statement idempotent, so the notebook survives Restart & Run All.
# The table is qualified with `default` to match the `default.taxi_ridership_data` query below.
create_table = \
"""
CREATE EXTERNAL TABLE IF NOT EXISTS `default`.`taxi_ridership_data`(
 `time` string , 
 `number` int)
ROW FORMAT SERDE 
 'org.apache.hadoop.hive.serde2.OpenCSVSerde' 
WITH SERDEPROPERTIES ( 
 'separatorChar'=',') 
STORED AS INPUTFORMAT 
 'org.apache.hadoop.mapred.TextInputFormat' 
OUTPUTFORMAT 
 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION
 's3://athena-examples-us-east-1/workshop-ml/'
"""

In [None]:
## Create the Athena table holding the data we will use to detect anomalies.
# CREATE TABLE is DDL and returns no rows, so run it on a plain DB-API cursor
# rather than pd.read_sql, which is meant for row-returning queries.
ddl_cursor = connection.cursor()
try:
    ddl_cursor.execute(create_table)
finally:
    # Release the cursor even if Athena rejects the statement.
    ddl_cursor.close()

In [None]:
## Load the whole table into a DataFrame; it is the training input for the model fit further down.
results = pd.read_sql(
    sql="SELECT * FROM default.taxi_ridership_data",
    con=connection,
)

In [None]:
## Let's preview the data we are working with (first rows only, rather than the whole table).
results.head()

In [None]:
from sagemaker import RandomCutForest

prefix = 'athena-ml/anomalydetection'
execution_role = sagemaker.get_execution_role()
session = sagemaker.Session()

# Random Cut Forest estimator. Training data is uploaded under data_location and model
# artifacts are written under output_path, both inside the bucket configured earlier.
# NOTE(review): train_instance_count / train_instance_type are SageMaker Python SDK v1 names;
# SDK v2 renamed them to instance_count / instance_type. Confirm the installed SDK version
# before renaming.
rcf = RandomCutForest(role=execution_role,
                      train_instance_count=1,
                      train_instance_type='ml.c4.8xlarge',
                      data_location='s3://{}/{}/'.format(bucket, prefix),
                      output_path='s3://{}/{}/output'.format(bucket, prefix),
                      num_samples_per_tree=512,
                      num_trees=50,
                      # Use the session created above instead of an implicit default one.
                      sagemaker_session=session)

# Train on the single `number` column from the Athena query above.
# reshape(-1, 1) turns it into an (n_rows, 1) feature matrix.
training_features = results['number'].to_numpy().reshape(-1, 1)
rcf.fit(rcf.record_set(training_features))

print('Training job name: {}'.format(rcf.latest_training_job.job_name))

# Deploy the trained model to a real-time endpoint that Athena can call for inference.
# NOTE(review): the endpoint keeps running (and billing) until deleted, e.g. with
# rcf_inference.delete_endpoint(), once it is no longer needed.
rcf_inference = rcf.deploy(
    initial_instance_count=1,
    instance_type='ml.c4.8xlarge',
)

# NOTE(review): SDK v2 exposes this as `endpoint_name` (`endpoint` is a deprecated alias there).
print('\nEndpoint name (used by Athena): {}'.format(rcf_inference.endpoint))