{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# Clustering: Picking the 'K' hyperparameter\n", "Clustering data into similar groups is an unsupervised machine learning technique that is useful and fairly efficient in most cases. The big trick is often picking the number of clusters to make (the K hyperparameter). \n", "The appropriate number of clusters can vary dramatically depending on the characteristics of the data, the types of variables (numeric or categorical), how the data is normalized/encoded, and the distance metric used.\n", "\n", "
\n", "\n", "**For this notebook we're going to focus specifically on the following:**\n", "- Optimizing the number of clusters (K hyperparameter) using Silhouette Scoring\n", "- Utilizing an algorithm (DBSCAN) that automatically determines the number of clusters\n", "\n", "\n", "### Software\n", "- Zeek Analysis Tools (ZAT): https://github.com/SuperCowPowers/zat\n", "- Pandas: https://github.com/pandas-dev/pandas\n", "- Scikit-Learn: http://scikit-learn.org/stable/index.html\n", "\n", "
\n", "\n", "### Techniques\n", "- One Hot Encoding: http://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html\n", "- t-SNE: https://distill.pub/2016/misread-tsne/\n", "- Kmeans: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html\n", "- Silhouette Score: https://en.wikipedia.org/wiki/Silhouette_(clustering)\n", "- DBSCAN: https://en.wikipedia.org/wiki/DBSCAN" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ZAT: 0.3.6\n", "Pandas: 0.25.1\n", "Scikit Learn Version: 0.21.2\n" ] } ], "source": [ "# Third Party Imports\n", "import pandas as pd\n", "import numpy as np\n", "import sklearn\n", "from sklearn.manifold import TSNE\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "from sklearn.cluster import KMeans, DBSCAN\n", "\n", "# Local imports\n", "import zat\n", "from zat.log_to_dataframe import LogToDataFrame\n", "from zat.dataframe_to_matrix import DataFrameToMatrix\n", "\n", "# Good to print out versions of stuff\n", "print('ZAT: {:s}'.format(zat.__version__))\n", "print('Pandas: {:s}'.format(pd.__version__))\n", "print('Scikit Learn Version:', sklearn.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " uid id.orig_h id.orig_p \\\n", "ts \n", "2013-09-15 23:44:27.668081999 CyIaMO7IheOh38Zsi 192.168.33.10 1031 \n", "2013-09-15 23:44:27.731701851 CoyZrY2g74UvMMgp4a 192.168.33.10 1032 \n", "2013-09-15 23:44:28.092921972 CoyZrY2g74UvMMgp4a 192.168.33.10 1032 \n", "2013-09-15 23:44:28.150300980 CiCKTz4e0fkYYazBS3 192.168.33.10 1040 \n", "2013-09-15 23:44:28.150601864 C1YBkC1uuO9bzndRvh 192.168.33.10 1041 \n", "\n", " id.resp_h id.resp_p trans_depth method \\\n", "ts \n", "2013-09-15 23:44:27.668081999 54.245.228.191 80 1 GET \n", "2013-09-15 23:44:27.731701851 54.245.228.191 80 1 GET \n", "2013-09-15 23:44:28.092921972 54.245.228.191 80 2 GET \n", "2013-09-15 23:44:28.150300980 54.245.228.191 80 1 GET \n", "2013-09-15 23:44:28.150601864 54.245.228.191 80 1 GET \n", "\n", " host \\\n", "ts \n", "2013-09-15 23:44:27.668081999 guyspy.com \n", "2013-09-15 23:44:27.731701851 www.guyspy.com \n", "2013-09-15 23:44:28.092921972 www.guyspy.com \n", "2013-09-15 23:44:28.150300980 www.guyspy.com \n", "2013-09-15 23:44:28.150601864 www.guyspy.com \n", "\n", " uri \\\n", "ts \n", "2013-09-15 23:44:27.668081999 / \n", "2013-09-15 23:44:27.731701851 / \n", "2013-09-15 23:44:28.092921972 /wp-content/plugins/slider-pro/css/advanced-sl... \n", "2013-09-15 23:44:28.150300980 /wp-content/plugins/contact-form-7/includes/cs... \n", "2013-09-15 23:44:28.150601864 /wp-content/plugins/slider-pro/css/slider/adva... \n", "\n", " referrer ... info_msg filename \\\n", "ts ... \n", "2013-09-15 23:44:27.668081999 NaN ... NaN NaN \n", "2013-09-15 23:44:27.731701851 NaN ... NaN NaN \n", "2013-09-15 23:44:28.092921972 http://www.guyspy.com/ ... NaN NaN \n", "2013-09-15 23:44:28.150300980 http://www.guyspy.com/ ... NaN NaN \n", "2013-09-15 23:44:28.150601864 http://www.guyspy.com/ ... 
NaN NaN \n", "\n", " tags username password proxied orig_fuids \\\n", "ts \n", "2013-09-15 23:44:27.668081999 (empty) NaN NaN NaN NaN \n", "2013-09-15 23:44:27.731701851 (empty) NaN NaN NaN NaN \n", "2013-09-15 23:44:28.092921972 (empty) NaN NaN NaN NaN \n", "2013-09-15 23:44:28.150300980 (empty) NaN NaN NaN NaN \n", "2013-09-15 23:44:28.150601864 (empty) NaN NaN NaN NaN \n", "\n", " orig_mime_types resp_fuids \\\n", "ts \n", "2013-09-15 23:44:27.668081999 NaN Fnjq3r4R0VGmHVWiN5 \n", "2013-09-15 23:44:27.731701851 NaN FCQ5aX37YzsjAKpcv8 \n", "2013-09-15 23:44:28.092921972 NaN FD9Xu815Hwui3sniSf \n", "2013-09-15 23:44:28.150300980 NaN FMZXWm1yCdsCAU3K9d \n", "2013-09-15 23:44:28.150601864 NaN FA4NM039Rf9Y8Sn2Rh \n", "\n", " resp_mime_types \n", "ts \n", "2013-09-15 23:44:27.668081999 text/html \n", "2013-09-15 23:44:27.731701851 text/html \n", "2013-09-15 23:44:28.092921972 text/html \n", "2013-09-15 23:44:28.150300980 text/plain \n", "2013-09-15 23:44:28.150601864 text/plain \n", "\n", "[5 rows x 26 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a Pandas dataframe from the Zeek log\n", "log_to_df = LogToDataFrame()\n", "http_df = log_to_df.create_dataframe('../data/http.log')\n", "\n", "# Print out the head of the dataframe\n", "http_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Our HTTP features are a mix of numeric and categorical data\n", "Looking at the HTTP records, some of the fields are numerical and some are categorical, so we need a way of handling both data types in a generalized way. The DataFrameToMatrix class, which we'll use below, handles many of the details and mechanics of combining numerical and categorical data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
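As a rough illustration of combining numeric and categorical data, here is a minimal sketch using plain pandas on a few made-up rows (hypothetical values, not taken from `http.log`): `pd.get_dummies` one-hot encodes the categorical columns, and the numeric columns are min-max scaled to [0, 1]. Min-max scaling is an assumption here, since DataFrameToMatrix may normalize differently. Also, unlike the transformer, this sketch does not save its mappings for reuse at evaluation time.

```python
import pandas as pd

# A few made-up HTTP records (hypothetical values) using the notebook's feature names
df = pd.DataFrame({
    'id.resp_p': [80, 80, 8080, 80],
    'method': ['GET', 'POST', 'GET', 'OPTIONS'],
    'resp_mime_types': ['text/html', 'text/plain', 'text/plain', 'text/plain'],
    'request_body_len': [0, 69823, 0, 0],
})
categorical = ['method', 'resp_mime_types']
numeric = ['id.resp_p', 'request_body_len']

# One-hot encoding: one 0/1 column per distinct category value
one_hot = pd.get_dummies(df[categorical], dtype=float)

# Min-max scaling into [0, 1] (an assumed normalization scheme)
num = df[numeric].astype(float)
spread = (num.max() - num.min()).replace(0, 1)  # avoid division by zero on constant columns
scaled = (num - num.min()) / spread

# Numeric columns first, then the one-hot indicator columns
combined = pd.concat([scaled, one_hot], axis=1)
matrix = combined.to_numpy(dtype='float32')
print(list(combined.columns))
print(matrix.shape)
```

The 4 input columns become 7 matrix columns here: 2 scaled numeric columns, 3 `method` indicators and 2 `resp_mime_types` indicators. This is the same kind of growth that turns the notebook's 4 features into 12 dimensions.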
\n", "\n", "## Transformers\n", "**We'll now use the Scikit-Learn transformer class to convert the Pandas DataFrame to a numpy ndarray (matrix). The transformer class takes care of many low-level details:**\n", "* Applies 'one-hot' encoding to the Categorical fields\n", "* Normalizes the Numeric fields\n", "* The class can be serialized for use in training and evaluation\n", "    * The categorical mappings are saved during training and applied at evaluation\n", "    * The normalized field ranges are stored during training and applied at evaluation" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Normalizing column id.resp_p...\n", "Normalizing column request_body_len...\n", "\n", "NOTE: The resulting numpy matrix has 12 dimensions based on one-hot encoding\n", "(150, 12)\n" ] }, { "data": { "text/plain": [ "array([[0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We're going to pick some features that might be interesting\n", "# Some of the features are numerical and some are categorical\n", "features = ['id.resp_p', 'method', 'resp_mime_types', 'request_body_len']\n", "\n", "# Use the DataFrameToMatrix class (handles categorical data)\n", "# It uses a heuristic to detect categorical data; when doing this for real,\n", "# we should explicitly convert the categorical columns before sending them to the transformer.\n", "to_matrix = DataFrameToMatrix()\n", "http_feature_matrix = to_matrix.fit_transform(http_df[features], normalize=True)\n", "\n", "print('\\nNOTE: The resulting numpy matrix has 12 dimensions based on one-hot encoding')\n", "print(http_feature_matrix.shape)\n", "http_feature_matrix[:1]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Plotting defaults\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['font.size'] = 12.0\n", "plt.rcParams['figure.figsize'] = 14.0, 7.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# Silhouette Scoring\n", "\"The silhouette value is a measure of how similar an object is to its own cluster (cohesion) compared to other clusters (separation). The silhouette ranges from -1 to 1, where a high value indicates that the object is well matched to its own cluster and poorly matched to neighboring clusters. If most objects have a high value, then the clustering configuration is appropriate. If many points have a low or negative value, then the clustering configuration may have too many or too few clusters.\"\n", "- https://en.wikipedia.org/wiki/Silhouette_(clustering)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import silhouette_score\n", "\n", "scores = []\n", "clusters = range(2, 16)\n", "for K in clusters:\n", "    clusterer = KMeans(n_clusters=K)\n", "    cluster_labels = clusterer.fit_predict(http_feature_matrix)\n", "    score = silhouette_score(http_feature_matrix, cluster_labels)\n", "    scores.append(score)\n", "\n", "# Plot it out\n", "pd.DataFrame({'Num Clusters': clusters, 'score': scores}).plot(x='Num Clusters', y='score')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The silhouette graph shows that 10 is the 'optimal' number of clusters\n", "- 'Optimal' is in quotes because clustering involves human interpretation and pattern finding, so it is often partially subjective :)\n", "- For large datasets, running an exhaustive search over K can be time-consuming\n", "- For large datasets, the maximum score can land on a large K, so pick the 'knee' of the graph as your K instead" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# So we know that the highest (closest to 1) silhouette score is at 10 clusters\n", "kmeans = KMeans(n_clusters=10).fit_predict(http_feature_matrix)\n", "\n", "# TSNE is a great projection algorithm. In this case we're going from 12 dimensions to 2\n", "projection = TSNE().fit_transform(http_feature_matrix)\n", "\n", "# Now we can put our ML results back onto our dataframe!\n", "http_df['cluster'] = kmeans\n", "http_df['x'] = projection[:, 0] # Projection X Column\n", "http_df['y'] = projection[:, 1] # Projection Y Column" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Now use dataframe group by cluster\n", "cluster_groups = http_df.groupby('cluster')\n", "\n", "# Plot the Machine Learning results\n", "colors = {-1:'black', 0:'green', 1:'blue', 2:'red', 3:'orange', 4:'purple', 5:'brown', 6:'pink', 7:'lightblue', 8:'grey', 9:'yellow'}\n", "fig, ax = plt.subplots()\n", "for key, group in cluster_groups:\n", " group.plot(ax=ax, kind='scatter', x='x', y='y', alpha=0.5, s=250,\n", " label='Cluster: {:d}'.format(key), color=colors[key])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Cluster 0: 13 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:40.230550051 80 GET application/pdf 0\n", "2013-09-15 23:44:40.230550051 80 GET application/pdf 0\n", "2013-09-15 23:44:40.230550051 80 GET application/pdf 0\n", "\n", "Cluster 1: 40 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:28.150300980 80 GET text/plain 0\n", "2013-09-15 23:44:28.150601864 80 GET text/plain 0\n", "2013-09-15 23:44:28.192918062 80 GET text/plain 0\n", "\n", "Cluster 2: 22 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:30.064238070 80 GET image/jpeg 0\n", "2013-09-15 23:44:30.104156017 80 GET image/jpeg 0\n", "2013-09-15 23:44:30.725122929 80 GET image/jpeg 0\n", "\n", "Cluster 3: 14 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:31.386100054 80 GET NaN 0\n", "2013-09-15 23:44:31.417192936 80 GET NaN 0\n", "2013-09-15 23:44:31.471001863 80 GET NaN 0\n", "\n", "Cluster 4: 10 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:48:10.495719910 80 POST text/plain 69823\n", "2013-09-15 23:48:11.495719910 80 POST text/plain 69993\n", 
"2013-09-15 23:48:12.495719910 80 POST text/plain 71993\n", "\n", "Cluster 5: 15 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:30.061532021 80 GET image/png 0\n", "2013-09-15 23:44:30.061532021 80 GET image/png 0\n", "2013-09-15 23:44:30.063459873 80 GET image/png 0\n", "\n", "Cluster 6: 14 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:27.668081999 80 GET text/html 0\n", "2013-09-15 23:44:27.731701851 80 GET text/html 0\n", "2013-09-15 23:44:28.092921972 80 GET text/html 0\n", "\n", "Cluster 7: 8 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:44:47.464160919 80 GET application/x-dosexec 0\n", "2013-09-15 23:44:47.464160919 80 GET application/x-dosexec 0\n", "2013-09-15 23:44:49.221977949 80 GET application/x-dosexec 0\n", "\n", "Cluster 8: 7 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:48:06.495719910 80 OPTIONS text/plain 0\n", "2013-09-15 23:48:07.495719910 80 OPTIONS text/plain 0\n", "2013-09-15 23:48:08.495719910 80 OPTIONS text/plain 0\n", "\n", "Cluster 9: 7 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 23:48:03.495719910 8080 GET text/plain 0\n", "2013-09-15 23:48:04.495719910 8080 GET text/plain 0\n", "2013-09-15 23:48:04.495719910 8080 GET text/plain 0\n" ] } ], "source": [ "# Now print out the details for each cluster\n", "pd.set_option('display.width', 1000)\n", "for key, group in cluster_groups:\n", " print('\\nCluster {:d}: {:d} observations'.format(key, len(group)))\n", " print(group[features].head(3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
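Earlier we suggested picking the 'knee' of the silhouette curve rather than its maximum on large datasets. One simple way to locate a knee is to rescale both axes to [0, 1] and choose the K whose point lies farthest from the straight line joining the first and last points of the curve. This is sketched here as a common heuristic, not something ZAT or scikit-learn provides. The `knee_k` helper and the scores below are hypothetical, made up for illustration rather than taken from this notebook's run.

```python
import numpy as np

def knee_k(ks, scores):
    # Rescale K and score to [0, 1] so both axes carry equal weight
    x = np.asarray(list(ks), dtype=float)
    y = np.asarray(scores, dtype=float)
    x = (x - x.min()) / (x.max() - x.min())
    y = (y - y.min()) / (y.max() - y.min())
    # Perpendicular distance of every point from the chord (first point to last point)
    dx, dy = x[-1] - x[0], y[-1] - y[0]
    dist = np.abs(dx * (y - y[0]) - dy * (x - x[0])) / np.hypot(dx, dy)
    return list(ks)[int(np.argmax(dist))]

# Made-up silhouette scores that rise quickly and then flatten out
ks = range(2, 10)
scores = [0.20, 0.45, 0.60, 0.64, 0.66, 0.67, 0.675, 0.68]
print(knee_k(ks, scores))
```

With these toy scores the helper returns 4, where the gains start to flatten out. The heuristic assumes a curve that rises and then levels off. When the curve has a clear peak, the maximum itself is the natural pick.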
\n", "\n", "# Look Ma... No K!\n", "\n", "### DBSCAN\n", "Density-based spatial clustering of applications with noise (DBSCAN) is a data clustering algorithm that, given a set of points in space, groups together points that are closely packed and marks points in low-density regions as outliers.\n", "\n", "- You don't have to pick K\n", "- There are other hyperparameters (`eps` and `min_samples`), but the defaults often work well\n", "- https://en.wikipedia.org/wiki/DBSCAN\n", "- Hierarchical version: https://github.com/scikit-learn-contrib/hdbscan" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of Clusters: 10\n" ] } ], "source": [ "# Now try DBSCAN\n", "http_df['cluster_db'] = DBSCAN().fit_predict(http_feature_matrix)\n", "print('Number of Clusters: {:d}'.format(http_df['cluster_db'].nunique()))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Now use dataframe group by cluster\n", "cluster_groups = http_df.groupby('cluster_db')\n", "\n", "# Plot the Machine Learning results\n", "fig, ax = plt.subplots()\n", "for key, group in cluster_groups:\n", "    group.plot(ax=ax, kind='scatter', x='x', y='y', alpha=0.5, s=250,\n", "               label='Cluster: {:d}'.format(key), color=colors[key])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
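One caveat when reading the cluster count printed above: DBSCAN gives noise points the label -1, and `nunique()` counts that label like any other. The printed number therefore includes a noise group whenever DBSCAN flags outliers, which is also why the `colors` dictionary has an entry for -1. The small sketch below runs scikit-learn's `DBSCAN` with its default `eps=0.5` and `min_samples=5` on made-up 2-D points to show one way of counting real clusters and noise separately.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Two tight groups of five points plus one isolated point (made-up toy data)
points = np.array([
    [0.0, 0.0], [0.0, 0.1], [0.1, 0.0], [0.1, 0.1], [0.05, 0.05],
    [5.0, 5.0], [5.0, 5.1], [5.1, 5.0], [5.1, 5.1], [5.05, 5.05],
    [10.0, 0.0],
])

# eps is the neighborhood radius, min_samples the number of neighbors
# (including the point itself) needed for a dense core point; these are the defaults
labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(points)

# DBSCAN labels noise points -1, so count them separately from real clusters
n_noise = int(np.sum(labels == -1))
n_clusters = len(set(labels)) - (1 if n_noise else 0)
print('distinct labels:', len(set(labels)))
print('clusters:', n_clusters, 'noise points:', n_noise)
```

Here the isolated point comes back as -1, so there are three distinct labels but only two clusters. The same count can be applied to the notebook's `cluster_db` column to check whether any of the 10 reported groups is noise.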
\n", "\n", "# DBSCAN automagically determined 10 clusters!\n", "We obviously got a bit lucky here: for other datasets with different feature distributions, DBSCAN may not give you the optimal number of clusters right off the bat. There are two hyperparameters (`eps` and `min_samples`) that can be tweaked, but as we said, the defaults often work well. See the DBSCAN and Hierarchical DBSCAN links for more information.\n", "\n", "- https://en.wikipedia.org/wiki/DBSCAN\n", "- Hierarchical version: https://github.com/scikit-learn-contrib/hdbscan\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wrap Up\n", "Well, that's it for this notebook. Given the usefulness and relative efficiency of clustering, it's a good technique to include in your toolset. Understanding the K hyperparameter and how to determine the optimal K (or skip that step if you're using DBSCAN) is a good trick to know.\n", "\n", "If you liked this notebook, please visit the [zat](https://github.com/SuperCowPowers/zat) project for more notebooks and examples." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }