{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Diabetes prediction using synthesized health records\n", "\n", "This notebook explores how to train a machine learning model to predict type 2 diabetes using synthesized patient health records. The use of synthesized data allows us to learn about building a model without any concern about the privacy issues surrounding the use of real patient health records.\n", "\n", "## Prerequisites\n", "\n", "This project is part of a series of code patterns pertaining to a fictional health care company called Example Health. This company stores electronic health records in a database on a z/OS server. Before running the notebook, the synthesized health records must be created and loaded into this database. Another project, https://github.com/IBM/example-health-synthea, provides the steps for doing this. The records are created using a tool called Synthea (https://github.com/synthetichealth/synthea), transformed and loaded into the database." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load and prepare the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set up the information needed for a JDBC connection to your database below\n", "The database must be set up by following the instructions in https://github.com/IBM/example-health-synthea." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "credentials_1 = {\n", " 'host':'xxx.yyy.com',\n", " 'port':'nnnn',\n", " 'username':'user',\n", " 'password':'password',\n", " 'database':'location',\n", " 'schema':'SMHEALTH'\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a function to load data from a database table into a Spark dataframe\n", "\n", "The partitionColumn, lowerBound, upperBound, and numPartitions options are used to load the data more quickly\n", "using multiple JDBC connections. The data is partitioned by patient id. 
It is assumed that there are approximately\n", "5000 patients in the database. If there are more or fewer patients, adjust the upperBound value appropriately." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def load_data_from_database(table_name):\n", " return (\n", " spark.read.format(\"jdbc\").options(\n", " driver = \"com.ibm.db2.jcc.DB2Driver\",\n", " url = \"jdbc:db2://\" + credentials_1[\"host\"] + \":\" + credentials_1[\"port\"] + \"/\" + credentials_1[\"database\"],\n", " user = credentials_1[\"username\"], \n", " password = credentials_1[\"password\"], \n", " dbtable = credentials_1[\"schema\"] + \".\" + table_name,\n", " partitionColumn = \"patientid\",\n", " lowerBound = 1,\n", " upperBound = 5000,\n", " numPartitions = 10\n", " ).load()\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read patient observations from the database\n", "\n", "The observations include things like blood pressure and cholesterol readings, which are potential features for our model."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------------+--------+--------------------+------------+--------------+-------+\n", "|PATIENTID|DATEOFOBSERVATION| CODE| DESCRIPTION|NUMERICVALUE|CHARACTERVALUE| UNITS|\n", "+---------+-----------------+--------+--------------------+------------+--------------+-------+\n", "| 222| 2019-01-26|8302-2 | Body Height| 49.00| | cm|\n", "| 222| 2019-01-26|72514-3 |Pain severity - 0...| 1.70| |{score}|\n", "| 222| 2019-01-26|29463-7 | Body Weight| 4.50| | kg|\n", "| 222| 2019-01-26|6690-2 |Leukocytes [#/vol...| 5.10| |10*3/uL|\n", "| 222| 2019-01-26|789-8 |Erythrocytes [#/v...| 5.10| |10*6/uL|\n", "+---------+-----------------+--------+--------------------+------------+--------------+-------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "observations_df = load_data_from_database(\"OBSERVATIONS\")\n", "\n", "observations_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The observations table has a generalized format with a separate row per observation\n", "\n", "Let's collect the observations that may be of interest in making a diabetes prediction.\n", "First, select systolic blood pressure readings from the observations. These have code 8480-6." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------------+--------+\n", "|patientid|dateofobservation|systolic|\n", "+---------+-----------------+--------+\n", "| 222| 2019-03-02| 101.30|\n", "| 72| 2009-05-16| 122.70|\n", "| 72| 2010-05-22| 129.10|\n", "| 72| 2011-05-28| 109.00|\n", "| 72| 2012-06-02| 135.40|\n", "+---------+-----------------+--------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import col\n", "\n", "systolic_observations_df = (\n", " observations_df.select(\"patientid\", \"dateofobservation\", \"numericvalue\")\n", " .withColumnRenamed(\"numericvalue\", \"systolic\")\n", " .filter((col(\"code\") == \"8480-6\"))\n", " )\n", "\n", "\n", "systolic_observations_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Select other observations of potential interest\n", "\n", "* Select diastolic blood pressure readings (code 8462-4).\n", "* Select HDL cholesterol readings (code 2085-9).\n", "* Select LDL cholesterol readings (code 18262-6).\n", "* Select BMI (body mass index) readings (code 39156-5)." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "diastolic_observations_df = (\n", " observations_df.select(\"patientid\", \"dateofobservation\", \"numericvalue\")\n", " .withColumnRenamed('numericvalue', 'diastolic')\n", " .filter((col(\"code\") == \"8462-4\"))\n", " )\n", "\n", "hdl_observations_df = (\n", " observations_df.select(\"patientid\", \"dateofobservation\", \"numericvalue\")\n", " .withColumnRenamed('numericvalue', 'hdl')\n", " .filter((col(\"code\") == \"2085-9\"))\n", " )\n", "\n", "ldl_observations_df = (\n", " observations_df.select(\"patientid\", \"dateofobservation\", \"numericvalue\")\n", " .withColumnRenamed('numericvalue', 'ldl')\n", " .filter((col(\"code\") == \"18262-6\"))\n", " )\n", "\n", "bmi_observations_df = (\n", " observations_df.select(\"patientid\", \"dateofobservation\", \"numericvalue\")\n", " .withColumnRenamed('numericvalue', 'bmi')\n", " .filter((col(\"code\") == \"39156-5\"))\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Join the observations for each patient by date into one dataframe" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------------+--------+---------+-----+------+-----+\n", "|patientid|dateofobservation|systolic|diastolic| hdl| ldl| bmi|\n", "+---------+-----------------+--------+---------+-----+------+-----+\n", "| 4| 2011-12-17| 105.10| 77.10|71.00| 86.50|57.70|\n", "| 157| 2014-07-16| 138.00| 83.70|21.10|181.40|37.90|\n", "| 230| 2010-04-23| 164.70| 117.90|26.20|147.90|35.20|\n", "| 244| 2015-04-01| 119.00| 84.30|77.60| 96.20|25.50|\n", "| 290| 2018-08-21| 130.60| 70.90|73.90| 77.80|47.10|\n", "+---------+-----------------+--------+---------+-----+------+-----+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "merged_observations_df = (\n", " systolic_observations_df.join(diastolic_observations_df, [\"patientid\", 
\"dateofobservation\"])\n", " .join(hdl_observations_df, [\"patientid\", \"dateofobservation\"])\n", " .join(ldl_observations_df, [\"patientid\", \"dateofobservation\"])\n", " .join(bmi_observations_df, [\"patientid\", \"dateofobservation\"])\n", ")\n", "\n", "merged_observations_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Another possible feature is the patient's age at the time of observation\n", "\n", "Load the patients' birth dates from the database into a dataframe." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------+\n", "|patientid|dateofbirth|\n", "+---------+-----------+\n", "| 1| 2017-07-04|\n", "| 2| 1965-04-14|\n", "| 3| 1996-09-14|\n", "| 4| 1958-11-29|\n", "| 5| 1979-01-28|\n", "+---------+-----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "patients_df = load_data_from_database(\"PATIENT\").select(\"patientid\", \"dateofbirth\")\n", "\n", "patients_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add a column containing the patient's age to the merged observations." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+\n", "|patientid|dateofobservation|systolic|diastolic| hdl| ldl| bmi| age|\n", "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+\n", "| 463| 2016-02-13| 136.90| 81.10|66.60|76.20|35.80|55.57808219178082|\n", "| 463| 2013-01-26| 113.40| 77.50|77.30|91.40|35.80|52.52876712328767|\n", "| 463| 2019-03-02| 123.60| 71.60|73.80|95.50|35.80|58.62739726027397|\n", "| 463| 2010-01-09| 113.50| 70.60|71.20|76.00|35.80|49.47945205479452|\n", "| 471| 2017-07-12| 155.60| 99.00|59.00|83.70|38.30|35.19178082191781|\n", "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import datediff\n", "\n", "merged_observations_with_age_df = (\n", " merged_observations_df.join(patients_df, \"patientid\")\n", " .withColumn(\"age\", datediff(col(\"dateofobservation\"), col(\"dateofbirth\"))/365)\n", " .drop(\"dateofbirth\")\n", " )\n", "\n", "merged_observations_with_age_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Find the patients that have been diagnosed with type 2 diabetes\n", "\n", "The conditions table contains the conditions that patients have and the date they were diagnosed.\n", "Load the patient conditions table and select the patients that have been diagnosed with type 2 diabetes.\n", "Keep the date they were diagnosed (\"start\" column)." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+----------+\n", "|patientid| start|\n", "+---------+----------+\n", "| 66|2003-06-28|\n", "| 281|2012-07-20|\n", "| 230|2008-04-18|\n", "| 157|1994-12-28|\n", "| 251|2011-02-11|\n", "+---------+----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "diabetics_df = (\n", " load_data_from_database(\"CONDITIONS\")\n", " .select(\"patientid\", \"start\")\n", " .filter(col(\"description\") == \"Diabetes\")\n", ")\n", "\n", "diabetics_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a \"diabetic\" column which is the \"label\" for the model to predict\n", "\n", "Join the merged observations with the diabetic patients.\n", "This is a left join so that we keep all observations for both diabetic and non-diabetic patients.\n", "Create a new column with a binary value, 1=diabetic, 0=non-diabetic.\n", "This will be the label for the model (the value it is trying to predict)." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+-----+--------+\n", "|patientid|dateofobservation|systolic|diastolic| hdl| ldl| bmi| age|start|diabetic|\n", "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+-----+--------+\n", "| 463| 2013-01-26| 113.40| 77.50|77.30|91.40|35.80|52.52876712328767| null| 0|\n", "| 463| 2010-01-09| 113.50| 70.60|71.20|76.00|35.80|49.47945205479452| null| 0|\n", "| 463| 2016-02-13| 136.90| 81.10|66.60|76.20|35.80|55.57808219178082| null| 0|\n", "| 463| 2019-03-02| 123.60| 71.60|73.80|95.50|35.80|58.62739726027397| null| 0|\n", "| 471| 2017-07-12| 155.60| 99.00|59.00|83.70|38.30|35.19178082191781| null| 0|\n", "+---------+-----------------+--------+---------+-----+-----+-----+-----------------+-----+--------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import when\n", "\n", "observations_and_condition_df = (\n", " merged_observations_with_age_df.join(diabetics_df, \"patientid\", \"left_outer\")\n", " .withColumn(\"diabetic\", when(col(\"start\").isNotNull(), 1).otherwise(0))\n", ")\n", "\n", "observations_and_condition_df.show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Filter the observations for diabetics to remove those taken before diagnosis\n", "\n", "This is driven by the way that the diabetes simulation works in Synthea. The impact of the condition (diabetes) is not reflected in the observations until the patient is diagnosed with the condition in a wellness visit. Prior to that the patient's observations won't be any different from a non-diabetic patient. Therefore we want only the observations at the time the patients were diabetic." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "observations_and_condition_df = (\n", " observations_and_condition_df.filter((col(\"diabetic\") == 0) | ((col(\"dateofobservation\") >= col(\"start\"))))\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reduce the observations to a single observation per patient (the earliest available observation)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.window import Window\n", "from pyspark.sql.functions import row_number\n", "\n", "w = Window.partitionBy(observations_and_condition_df[\"patientid\"]).orderBy(merged_observations_df[\"dateofobservation\"].asc())\n", "\n", "first_observation_df = observations_and_condition_df.withColumn(\"rn\", row_number().over(w)).where(col(\"rn\") == 1).drop(\"rn\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize data\n", "\n", "At this point we have collected some observations which might be relevant to making a diabetes prediction. The next step is to look for relationships between those observations and having diabetes. There are many tools that help visualize data to look for relationships. One of the easiest ones to use is called Pixiedust (https://github.com/pixiedust/pixiedust).\n", "\n", "Install the pixiedust visualization tool." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# !pip install --upgrade pixiedust" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Use Pixiedust to visualize whether observations correlate with diabetes\n", "\n", "The PixieDust interactive widget appears when you run this cell.\n", "* Click the chart button and choose Scatter Plot.\n", "* Click the chart options button. Drag \"ldl\" into the Keys box and drag \"hdl\" into the Values box.\n", "Set the # of Rows to Display to 5000. 
Click OK to close the chart options.\n", "* Select bokeh from the Renderer dropdown menu.\n", "* Select diabetic from the Color dropdown menu.\n", "\n", "The scatter plot chart appears.\n", "\n", "Click Options and try replacing \"ldl\" and \"hdl\" with other attributes." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "pixiedust": { "displayParams": { "chartsize": "100", "color": "diabetic", "handlerId": "scatterPlot", "keyFields": "ldl", "rendererId": "bokeh", "rowCount": "1000", "valueFields": "hdl" } } }, "outputs": [ { "data": { "text/html": [ "