{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# Clustering: Picking the 'K' hyperparameter\n", "Clustering, the unsupervised machine learning technique of grouping similar data points together, can be useful and fairly efficient in most cases. The big trick is often choosing the number of clusters to make (the K hyperparameter).\n", "The right number of clusters may vary dramatically depending on the characteristics of the data, the types of variables (numeric or categorical), how the data is normalized/encoded, and the distance metric used.\n", "\n", "\n", "\n", "**For this notebook we're going to focus specifically on the following:**\n", "- Optimizing the number of clusters (K hyperparameter) using Silhouette Scoring\n", "- Utilizing an algorithm (DBSCAN) that automatically determines the number of clusters\n", "\n", "\n", "### Software\n", "- Bro Analysis Tools (BAT): https://github.com/SuperCowPowers/bat\n", "- Pandas: https://github.com/pandas-dev/pandas\n", "- Scikit-Learn: http://scikit-learn.org/stable/index.html\n", "\n", "\n", "\n", "### Techniques\n", "- [One Hot Encoding](http://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html)\n", "- [t-SNE](https://distill.pub/2016/misread-tsne/)\n", "- [K-Means](http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html)\n", "- [Silhouette Score](https://en.wikipedia.org/wiki/Silhouette_(clustering) )\n", "- [DBSCAN](https://en.wikipedia.org/wiki/DBSCAN)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BAT: 0.3.4\n", "Pandas: 0.23.4\n", "Scikit Learn Version: 0.20.0\n" ] } ], "source": [ "# Third Party Imports\n", "import pandas as pd\n", "import numpy as np\n", "import sklearn\n", "from sklearn.manifold import TSNE\n", "from sklearn.cluster import KMeans, DBSCAN\n", "\n", "# Local imports\n", "import bat\n", "from bat.log_to_dataframe import LogToDataFrame\n", "from bat.dataframe_to_matrix import DataFrameToMatrix\n", "\n", "# Good to print out versions of stuff\n", "print('BAT: {:s}'.format(bat.__version__))\n", "print('Pandas: {:s}'.format(pd.__version__))\n", "print('Scikit Learn Version:', sklearn.__version__)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Successfully monitoring data/http.log...\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " filename host id.orig_h id.orig_p \\\n", "ts \n", "2013-09-15 17:44:27.668082 - guyspy.com 192.168.33.10 1031 \n", "2013-09-15 17:44:27.731702 - www.guyspy.com 192.168.33.10 1032 \n", "2013-09-15 17:44:28.092922 - www.guyspy.com 192.168.33.10 1032 \n", "2013-09-15 17:44:28.150301 - www.guyspy.com 192.168.33.10 1040 \n", "2013-09-15 17:44:28.150602 - www.guyspy.com 192.168.33.10 1041 \n", "\n", " id.resp_h id.resp_p info_code info_msg \\\n", "ts \n", "2013-09-15 17:44:27.668082 54.245.228.191 80 0 - \n", "2013-09-15 17:44:27.731702 54.245.228.191 80 0 - \n", "2013-09-15 17:44:28.092922 54.245.228.191 80 0 - \n", "2013-09-15 17:44:28.150301 54.245.228.191 80 0 - \n", "2013-09-15 17:44:28.150602 54.245.228.191 80 0 - \n", "\n", " method orig_fuids ... resp_mime_types \\\n", "ts ... \n", "2013-09-15 17:44:27.668082 GET - ... text/html \n", "2013-09-15 17:44:27.731702 GET - ... text/html \n", "2013-09-15 17:44:28.092922 GET - ... text/html \n", "2013-09-15 17:44:28.150301 GET - ... text/plain \n", "2013-09-15 17:44:28.150602 GET - ... text/plain \n", "\n", " response_body_len status_code status_msg \\\n", "ts \n", "2013-09-15 17:44:27.668082 184 301 Moved Permanently \n", "2013-09-15 17:44:27.731702 100631 200 OK \n", "2013-09-15 17:44:28.092922 55817 404 Not Found \n", "2013-09-15 17:44:28.150301 887 200 OK \n", "2013-09-15 17:44:28.150602 10068 200 OK \n", "\n", " tags trans_depth uid \\\n", "ts \n", "2013-09-15 17:44:27.668082 (empty) 1 CyIaMO7IheOh38Zsi \n", "2013-09-15 17:44:27.731702 (empty) 1 CoyZrY2g74UvMMgp4a \n", "2013-09-15 17:44:28.092922 (empty) 2 CoyZrY2g74UvMMgp4a \n", "2013-09-15 17:44:28.150301 (empty) 1 CiCKTz4e0fkYYazBS3 \n", "2013-09-15 17:44:28.150602 (empty) 1 C1YBkC1uuO9bzndRvh \n", "\n", " uri \\\n", "ts \n", "2013-09-15 17:44:27.668082 / \n", "2013-09-15 17:44:27.731702 / \n", "2013-09-15 17:44:28.092922 /wp-content/plugins/slider-pro/css/advanced-sl... 
\n", "2013-09-15 17:44:28.150301 /wp-content/plugins/contact-form-7/includes/cs... \n", "2013-09-15 17:44:28.150602 /wp-content/plugins/slider-pro/css/slider/adva... \n", "\n", " user_agent \\\n", "ts \n", "2013-09-15 17:44:27.668082 Mozilla/4.0 (compatible; MSIE 8.0; Windows NT ... \n", "2013-09-15 17:44:27.731702 Mozilla/4.0 (compatible; MSIE 8.0; Windows NT ... \n", "2013-09-15 17:44:28.092922 Mozilla/4.0 (compatible; MSIE 8.0; Windows NT ... \n", "2013-09-15 17:44:28.150301 Mozilla/4.0 (compatible; MSIE 8.0; Windows NT ... \n", "2013-09-15 17:44:28.150602 Mozilla/4.0 (compatible; MSIE 8.0; Windows NT ... \n", "\n", " username \n", "ts \n", "2013-09-15 17:44:27.668082 - \n", "2013-09-15 17:44:27.731702 - \n", "2013-09-15 17:44:28.092922 - \n", "2013-09-15 17:44:28.150301 - \n", "2013-09-15 17:44:28.150602 - \n", "\n", "[5 rows x 26 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a Pandas dataframe from the Bro log\n", "http_df = LogToDataFrame('data/http.log')\n", "\n", "# Print out the head of the dataframe\n", "http_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Our HTTP features are numeric and categorical data\n", "When we look at the HTTP records, some of the data is numerical and some of it is categorical, so we need a way of handling both data types in a generalized way. BAT's DataFrameToMatrix class, which we'll use below, handles many of the details and mechanics of combining numerical and categorical data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "## Transformers\n", "**We'll now use a Scikit-Learn style transformer class to convert the Pandas DataFrame to a NumPy ndarray (matrix). The transformer class takes care of many low-level details:**\n", "* Applies 'one-hot' encoding to the categorical fields\n", "* Normalizes the numeric fields\n", "* The class can be serialized for use in training and evaluation\n", " * The categorical mappings are saved during training and applied at evaluation\n", " * The normalized field ranges are stored during training and applied at evaluation" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Changing column method to category...\n", "Changing column resp_mime_types to category...\n", "Normalizing column id.resp_p...\n", "Normalizing column request_body_len...\n", "\n", "NOTE: The resulting numpy matrix has 12 dimensions based on one-hot encoding\n", "(150, 12)\n" ] }, { "data": { "text/plain": [ "array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We're going to pick some features that might be interesting\n", "# Some of the features are numerical and some are categorical\n", "features = ['id.resp_p', 'method', 'resp_mime_types', 'request_body_len']\n", "\n", "# Use the DataFrameToMatrix class (handles categorical data)\n", "# As the output shows, it uses a heuristic to detect categorical columns. When doing\n", "# this for real, we should explicitly convert them to the category dtype before sending them to the transformer.\n", "to_matrix = DataFrameToMatrix()\n", "http_feature_matrix = to_matrix.fit_transform(http_df[features], normalize=True)\n", "\n", "print('\\nNOTE: The resulting numpy matrix has {:d} dimensions based on one-hot encoding'.format(http_feature_matrix.shape[1]))\n", "print(http_feature_matrix.shape)\n", "http_feature_matrix[:1]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Plotting defaults\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "plt.rcParams['font.size'] = 12.0\n", "plt.rcParams['figure.figsize'] = 14.0, 7.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# Silhouette Scoring\n", "\"The silhouette value is a measure of how similar an object is to its own cluster (cohesion) compared to other clusters (separation). The silhouette ranges from -1 to 1, where a high value indicates that the object is well matched to its own cluster and poorly matched to neighboring clusters. If most objects have a high value, then the clustering configuration is appropriate. If many points have a low or negative value, then the clustering configuration may have too many or too few clusters.\"\n", "- https://en.wikipedia.org/wiki/Silhouette_(clustering)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import silhouette_score\n", "\n", "# Fit KMeans for each candidate K and record the mean silhouette score\n", "scores = []\n", "clusters = range(2, 16)\n", "for K in clusters:\n", " clusterer = KMeans(n_clusters=K)\n", " cluster_labels = clusterer.fit_predict(http_feature_matrix)\n", " score = silhouette_score(http_feature_matrix, cluster_labels)\n", " scores.append(score)\n", "\n", "# Plot it out\n", "pd.DataFrame({'Num Clusters':clusters, 'score':scores}).plot(x='Num Clusters', y='score')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The silhouette graph shows that 10 is the 'optimal' number of clusters\n", "- 'Optimal': clustering involves human interpretation and pattern finding, so the choice is often partially subjective :)\n", "- For large datasets, running an exhaustive search can be time-consuming\n", "- For large datasets, using the maximum score can often give you a large K, so pick the 'knee' of the graph as your K instead" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# The highest (closest to 1) silhouette score is at 10 clusters\n", "kmeans = KMeans(n_clusters=10).fit_predict(http_feature_matrix)\n", "\n", "# t-SNE is a great projection algorithm. In this case we're going from 12 dimensions to 2\n", "projection = TSNE().fit_transform(http_feature_matrix)\n", "\n", "# Now we can put our ML results back onto our dataframe!\n", "http_df['cluster'] = kmeans\n", "http_df['x'] = projection[:, 0] # Projection X Column\n", "http_df['y'] = projection[:, 1] # Projection Y Column" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Now group the dataframe by cluster\n", "cluster_groups = http_df.groupby('cluster')\n", "\n", "# Plot the Machine Learning results\n", "colors = {-1:'black', 0:'green', 1:'blue', 2:'red', 3:'orange', 4:'purple', 5:'brown', 6:'pink', 7:'lightblue', 8:'grey', 9:'yellow'}\n", "fig, ax = plt.subplots()\n", "for key, group in cluster_groups:\n", " group.plot(ax=ax, kind='scatter', x='x', y='y', alpha=0.5, s=250,\n", " label='Cluster: {:d}'.format(key), color=colors[key])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Cluster 0: 7 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:48:03.495720 8080 GET text/plain 0\n", "2013-09-15 17:48:04.495720 8080 GET text/plain 0\n", "2013-09-15 17:48:04.495720 8080 GET text/plain 0\n", "\n", "Cluster 1: 40 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:28.150301 80 GET text/plain 0\n", "2013-09-15 17:44:28.150602 80 GET text/plain 0\n", "2013-09-15 17:44:28.192918 80 GET text/plain 0\n", "\n", "Cluster 2: 22 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:30.064238 80 GET image/jpeg 0\n", "2013-09-15 17:44:30.104156 80 GET image/jpeg 0\n", "2013-09-15 17:44:30.725123 80 GET image/jpeg 0\n", "\n", "Cluster 3: 15 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:30.061532 80 GET image/png 0\n", "2013-09-15 17:44:30.061532 80 GET image/png 0\n", "2013-09-15 17:44:30.063460 80 GET image/png 0\n", "\n", "Cluster 4: 14 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:31.386100 80 GET - 0\n", "2013-09-15 17:44:31.417193 80 GET - 0\n", "2013-09-15 17:44:31.471002 80 GET - 0\n", "\n", "Cluster 5: 10 
observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:48:10.495720 80 POST text/plain 69823\n", "2013-09-15 17:48:11.495720 80 POST text/plain 69993\n", "2013-09-15 17:48:12.495720 80 POST text/plain 71993\n", "\n", "Cluster 6: 14 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:27.668082 80 GET text/html 0\n", "2013-09-15 17:44:27.731702 80 GET text/html 0\n", "2013-09-15 17:44:28.092922 80 GET text/html 0\n", "\n", "Cluster 7: 8 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:47.464161 80 GET application/x-dosexec 0\n", "2013-09-15 17:44:47.464161 80 GET application/x-dosexec 0\n", "2013-09-15 17:44:49.221978 80 GET application/x-dosexec 0\n", "\n", "Cluster 8: 13 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:44:40.230550 80 GET application/pdf 0\n", "2013-09-15 17:44:40.230550 80 GET application/pdf 0\n", "2013-09-15 17:44:40.230550 80 GET application/pdf 0\n", "\n", "Cluster 9: 7 observations\n", " id.resp_p method resp_mime_types request_body_len\n", "ts \n", "2013-09-15 17:48:06.495720 80 OPTIONS text/plain 0\n", "2013-09-15 17:48:07.495720 80 OPTIONS text/plain 0\n", "2013-09-15 17:48:08.495720 80 OPTIONS text/plain 0\n" ] } ], "source": [ "# Now print out the details for each cluster\n", "pd.set_option('display.width', 1000)\n", "for key, group in cluster_groups:\n", " print('\\nCluster {:d}: {:d} observations'.format(key, len(group)))\n", " print(group[features].head(3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# Look Ma... No K!\n", "### DBSCAN\n", "Density-Based Spatial Clustering of Applications with Noise (DBSCAN) is a data clustering algorithm that, given a set of points in space, groups together points that are closely packed and marks points in low-density regions as outliers.\n", "\n", "- You don't have to pick K\n", "- There are other hyperparameters (eps and min_samples), but the defaults often work well\n", "- https://en.wikipedia.org/wiki/DBSCAN\n", "- Hierarchical version: https://github.com/scikit-learn-contrib/hdbscan" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of Clusters: 10\n" ] } ], "source": [ "# Now try DBSCAN\n", "# Note: DBSCAN labels noise points as -1, so nunique() also counts that label if any noise points exist\n", "http_df['cluster_db'] = DBSCAN().fit_predict(http_feature_matrix)\n", "print('Number of Clusters: {:d}'.format(http_df['cluster_db'].nunique()))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Now group the dataframe by DBSCAN cluster\n", "cluster_groups = http_df.groupby('cluster_db')\n", "\n", "# Plot the Machine Learning results\n", "fig, ax = plt.subplots()\n", "for key, group in cluster_groups:\n", " group.plot(ax=ax, kind='scatter', x='x', y='y', alpha=0.5, s=250,\n", " label='Cluster: {:d}'.format(key), color=colors[key])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# DBSCAN automagically determined 10 clusters!\n", "Obviously we got a bit lucky here: for datasets with different feature distributions, DBSCAN may not give you the optimal number of clusters right off the bat. There are two hyperparameters (eps and min_samples) that can be tweaked, but as we said, the defaults often work well. See the DBSCAN and Hierarchical DBSCAN links for more information.\n", "\n", "- https://en.wikipedia.org/wiki/DBSCAN\n", "- Hierarchical version: https://github.com/scikit-learn-contrib/hdbscan\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "## Wrap Up\n", "Well, that's it for this notebook. Given its usefulness and relative efficiency, clustering is a good technique to include in your toolset. 
Understanding the K hyperparameter and how to determine the optimal K (or how to avoid choosing it at all by using DBSCAN) is a good trick to know.\n", "\n", "If you liked this notebook, please visit [SCP Labs](https://github.com/SuperCowPowers/scp-labs) for more notebooks and examples, or visit our company page, [SuperCowPowers](https://www.supercowpowers.com), for consulting and development services." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# This cell is simply for adding some CSS (Ignore it :)\n", "from IPython.core.display import HTML\n", "def css_styling():\n", " styles = open(\"styles/custom.css\", \"r\").read()\n", " return HTML(styles)\n", "css_styling()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 1 }