{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering for dataset exploration\n", "> A Summary of lecture \"Unsupervised Learning with scikit-learn\", via datacamp\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Datacamp, Machine_Learning]\n", "- image: images/kmeans-centroid.png" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unsupervised Learning\n", "- Unsupervised Learning\n", " - Finds patterns in data (e.g., clustering customers by their purchases)\n", " - Compresses the data using purchase patterns (dimension reduction)\n", "- Supervised vs unsupervised learning\n", " - Supervised learning finds patterns for a prediction task \n", " \n", " e.g., classify tumors as benign or cancerous (labels)\n", " - Unsupervised learning finds patterns in data, but without a specific prediction task in mind\n", "- K-means clustering\n", " - Finds clusters of samples\n", " - Number of clusters must be specified\n", "- Cluster labels for new samples\n", " - New samples can be assigned to existing clusters\n", " - k-means remembers the mean of each cluster (the \"centroids\")\n", " - Finds the nearest centroid to each new sample" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How many clusters?" 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "points = np.array([[ 0.06544649, -0.76866376],\n", " [-1.52901547, -0.42953079],\n", " [ 1.70993371, 0.69885253],\n", " [ 1.16779145, 1.01262638],\n", " [-1.80110088, -0.31861296],\n", " [-1.63567888, -0.02859535],\n", " [ 1.21990375, 0.74643463],\n", " [-0.26175155, -0.62492939],\n", " [-1.61925804, -0.47983949],\n", " [-1.84329582, -0.16694431],\n", " [ 1.35999602, 0.94995827],\n", " [ 0.42291856, -0.7349534 ],\n", " [-1.68576139, 0.10686728],\n", " [ 0.90629995, 1.09105162],\n", " [-1.56478322, -0.84675394],\n", " [-0.0257849 , -1.18672539],\n", " [ 0.83027324, 1.14504612],\n", " [ 1.22450432, 1.35066759],\n", " [-0.15394596, -0.71704301],\n", " [ 0.86358809, 1.06824613],\n", " [-1.43386366, -0.2381297 ],\n", " [ 0.03844769, -0.74635022],\n", " [-1.58567922, 0.08499354],\n", " [ 0.6359888 , -0.58477698],\n", " [ 0.24417242, -0.53172465],\n", " [-2.19680359, 0.49473677],\n", " [ 1.0323503 , -0.55688 ],\n", " [-0.28858067, -0.39972528],\n", " [ 0.20597008, -0.80171536],\n", " [-1.2107308 , -0.34924109],\n", " [ 1.33423684, 0.7721489 ],\n", " [ 1.19480152, 1.04788556],\n", " [ 0.9917477 , 0.89202008],\n", " [-1.8356219 , -0.04839732],\n", " [ 0.08415721, -0.71564326],\n", " [-1.48970175, -0.19299604],\n", " [ 0.38782418, -0.82060119],\n", " [-0.01448044, -0.9779841 ],\n", " [-2.0521341 , -0.02129125],\n", " [ 0.10331194, -0.82162781],\n", " [-0.44189315, -0.65710974],\n", " [ 1.10390926, 1.02481182],\n", " [-1.59227759, -0.17374038],\n", " [-1.47344152, -0.02202853],\n", " [-1.35514704, 0.22971067],\n", " [ 0.0412337 , -1.23776622],\n", " [ 0.4761517 , -1.13672124],\n", " [ 1.04335676, 0.82345905],\n", " [-0.07961882, -0.85677394],\n", " [ 0.87065059, 1.08052841],\n", " [ 1.40267313, 1.07525119],\n", " [ 0.80111157, 1.28342825],\n", " [-0.16527516, -1.23583804],\n", " [-0.33779221, -0.59194323],\n", " [ 0.80610749, -0.73752159],\n", " [-1.43590032, -0.56384446],\n", " [ 
0.54868895, -0.95143829],\n", " [ 0.46803131, -0.74973907],\n", " [-1.5137129 , -0.83914323],\n", " [ 0.9138436 , 1.51126532],\n", " [-1.97233903, -0.41155375],\n", " [ 0.5213406 , -0.88654894],\n", " [ 0.62759494, -1.18590477],\n", " [ 0.94163014, 1.35399335],\n", " [ 0.56994768, 1.07036606],\n", " [-1.87663382, 0.14745773],\n", " [ 0.90612186, 0.91084011],\n", " [-1.37481454, 0.28428395],\n", " [-1.80564029, -0.96710574],\n", " [ 0.34307757, -0.79999275],\n", " [ 0.70380566, 1.00025804],\n", " [-1.68489862, -0.30564595],\n", " [ 1.31473221, 0.98614978],\n", " [ 0.26151216, -0.26069251],\n", " [ 0.9193121 , 0.82371485],\n", " [-1.21795929, -0.20219674],\n", " [-0.17722723, -1.02665245],\n", " [ 0.64824862, -0.66822881],\n", " [ 0.41206786, -0.28783784],\n", " [ 1.01568202, 1.13481667],\n", " [ 0.67900254, -0.91489502],\n", " [-1.05182747, -0.01062376],\n", " [ 0.61306599, 1.78210384],\n", " [-1.50219748, -0.52308922],\n", " [-1.72717293, -0.46173916],\n", " [-1.60995631, -0.1821007 ],\n", " [-1.09111021, -0.0781398 ],\n", " [-0.01046978, -0.80913034],\n", " [ 0.32782303, -0.80734754],\n", " [ 1.22038503, 1.1959793 ],\n", " [-1.33328681, -0.30001937],\n", " [ 0.87959517, 1.11566491],\n", " [-1.14829098, -0.30400762],\n", " [-0.58019755, -1.19996018],\n", " [-0.01161159, -0.78468854],\n", " [ 0.17359724, -0.63398145],\n", " [ 1.32738556, 0.67759969],\n", " [-1.93467327, 0.30572472],\n", " [-1.57761893, -0.27726365],\n", " [ 0.47639 , 1.21422648],\n", " [-1.65237509, -0.6803981 ],\n", " [-0.12609976, -1.04327457],\n", " [-1.89607082, -0.70085502],\n", " [ 0.57466899, 0.74878369],\n", " [-0.16660312, -0.83110295],\n", " [ 0.8013355 , 1.22244435],\n", " [ 1.18455426, 1.4346467 ],\n", " [ 1.08864428, 0.64667112],\n", " [-1.61158505, 0.22805725],\n", " [-1.57512205, -0.09612576],\n", " [ 0.0721357 , -0.69640328],\n", " [-1.40054298, 0.16390598],\n", " [ 1.09607713, 1.16804691],\n", " [-2.54346204, -0.23089822],\n", " [-1.34544875, 0.25151126],\n", " [-1.35478629, 
-0.19103317],\n", " [ 0.18368113, -1.15827725],\n", " [-1.31368677, -0.376357 ],\n", " [ 0.09990129, 1.22500491],\n", " [ 1.17225574, 1.30835143],\n", " [ 0.0865397 , -0.79714371],\n", " [-0.21053923, -1.13421511],\n", " [ 0.26496024, -0.94760742],\n", " [-0.2557591 , -1.06266022],\n", " [-0.26039757, -0.74774225],\n", " [-1.91787359, 0.16434571],\n", " [ 0.93021139, 0.49436331],\n", " [ 0.44770467, -0.72877918],\n", " [-1.63802869, -0.58925528],\n", " [-1.95712763, -0.10125137],\n", " [ 0.9270337 , 0.88251423],\n", " [ 1.25660093, 0.60828073],\n", " [-1.72818632, 0.08416887],\n", " [ 0.3499788 , -0.30490298],\n", " [-1.51696082, -0.50913109],\n", " [ 0.18763605, -0.55424924],\n", " [ 0.89609809, 0.83551508],\n", " [-1.54968857, -0.17114782],\n", " [ 1.2157457 , 1.23317728],\n", " [ 0.20307745, -1.03784906],\n", " [ 0.84589086, 1.03615273],\n", " [ 0.53237919, 1.47362884],\n", " [-0.05319044, -1.36150553],\n", " [ 1.38819743, 1.11729915],\n", " [ 1.00696304, 1.0367721 ],\n", " [ 0.56681869, -1.09637176],\n", " [ 0.86888296, 1.05248874],\n", " [-1.16286609, -0.55875245],\n", " [ 0.27717768, -0.83844015],\n", " [ 0.16563267, -0.80306607],\n", " [ 0.38263303, -0.42683241],\n", " [ 1.14519807, 0.89659026],\n", " [ 0.81455857, 0.67533667],\n", " [-1.8603152 , -0.09537561],\n", " [ 0.965641 , 0.90295579],\n", " [-1.49897451, -0.33254044],\n", " [-0.1335489 , -0.80727582],\n", " [ 0.12541527, -1.13354906],\n", " [ 1.06062436, 1.28816358],\n", " [-1.49154578, -0.2024641 ],\n", " [ 1.16189032, 1.28819877],\n", " [ 0.54282033, 0.75203524],\n", " [ 0.89221065, 0.99211624],\n", " [-1.49932011, -0.32430667],\n", " [ 0.3166647 , -1.34482915],\n", " [ 0.13972469, -1.22097448],\n", " [-1.5499724 , -0.10782584],\n", " [ 1.23846858, 1.37668804],\n", " [ 1.25558954, 0.72026098],\n", " [ 0.25558689, -1.28529763],\n", " [ 0.45168933, -0.55952093],\n", " [ 1.06202057, 1.03404604],\n", " [ 0.67451908, -0.54970299],\n", " [ 0.22759676, -1.02729468],\n", " [-1.45835281, -0.04951074],\n", " 
[ 0.23273501, -0.70849262],\n", " [ 1.59679589, 1.11395076],\n", " [ 0.80476105, 0.544627 ],\n", " [ 1.15492521, 1.04352191],\n", " [ 0.59632776, -1.19142897],\n", " [ 0.02839068, -0.43829366],\n", " [ 1.13451584, 0.5632633 ],\n", " [ 0.21576204, -1.04445753],\n", " [ 1.41048987, 1.02830719],\n", " [ 1.12289302, 0.58029441],\n", " [ 0.25200688, -0.82588436],\n", " [-1.28566081, -0.07390909],\n", " [ 1.52849815, 1.11822469],\n", " [-0.23907858, -0.70541972],\n", " [-0.25792784, -0.81825035],\n", " [ 0.59367818, -0.45239915],\n", " [ 0.07931909, -0.29233213],\n", " [-1.27256815, 0.11630577],\n", " [ 0.66930129, 1.00731481],\n", " [ 0.34791546, -1.20822877],\n", " [-2.11283993, -0.66897935],\n", " [-1.6293824 , -0.32718222],\n", " [-1.53819139, -0.01501972],\n", " [-0.11988545, -0.6036339 ],\n", " [-1.54418956, -0.30389844],\n", " [ 0.30026614, -0.77723173],\n", " [ 0.00935449, -0.53888192],\n", " [-1.33424393, -0.11560431],\n", " [ 0.47504489, 0.78421384],\n", " [ 0.59313264, 1.232239 ],\n", " [ 0.41370369, -1.35205857],\n", " [ 0.55840948, 0.78831053],\n", " [ 0.49855018, -0.789949 ],\n", " [ 0.35675809, -0.81038693],\n", " [-1.86197825, -0.59071305],\n", " [-1.61977671, -0.16076687],\n", " [ 0.80779295, -0.73311294],\n", " [ 1.62745775, 0.62787163],\n", " [-1.56993593, -0.08467567],\n", " [ 1.02558561, 0.89383302],\n", " [ 0.24293461, -0.6088253 ],\n", " [ 1.23130242, 1.00262186],\n", " [-1.9651013 , -0.15886289],\n", " [ 0.42795032, -0.70384432],\n", " [-1.58306818, -0.19431923],\n", " [-1.57195922, 0.01413469],\n", " [-0.98145373, 0.06132285],\n", " [-1.48637844, -0.5746531 ],\n", " [ 0.98745828, 0.69188053],\n", " [ 1.28619721, 1.28128821],\n", " [ 0.85850596, 0.95541481],\n", " [ 0.19028286, -0.82112942],\n", " [ 0.26561046, -0.04255239],\n", " [-1.61897897, 0.00862372],\n", " [ 0.24070183, -0.52664209],\n", " [ 1.15220993, 0.43916694],\n", " [-1.21967812, -0.2580313 ],\n", " [ 0.33412533, -0.86117761],\n", " [ 0.17131003, -0.75638965],\n", " [-1.19828397, 
-0.73744665],\n", " [-0.12245932, -0.45648879],\n", " [ 1.51200698, 0.88825741],\n", " [ 1.10338866, 0.92347479],\n", " [ 1.30972095, 0.59066989],\n", " [ 0.19964876, 1.14855889],\n", " [ 0.81460515, 0.84538972],\n", " [-1.6422739 , -0.42296206],\n", " [ 0.01224351, -0.21247816],\n", " [ 0.33709102, -0.74618065],\n", " [ 0.47301054, 0.72712075],\n", " [ 0.34706626, 1.23033757],\n", " [-0.00393279, -0.97209694],\n", " [-1.64303119, 0.05276337],\n", " [ 1.44649625, 1.14217033],\n", " [-1.93030087, -0.40026146],\n", " [-2.37296135, -0.72633645],\n", " [ 0.45860122, -1.06048953],\n", " [ 0.4896361 , -1.18928313],\n", " [-1.02335902, -0.17520578],\n", " [-1.32761107, -0.93963549],\n", " [-1.50987909, -0.09473658],\n", " [ 0.02723057, -0.79870549],\n", " [ 1.0169412 , 1.26461701],\n", " [ 0.47733527, -0.9898471 ],\n", " [-1.27784224, -0.547416 ],\n", " [ 0.49898802, -0.6237259 ],\n", " [ 1.06004731, 0.86870008],\n", " [ 1.00207501, 1.38293512],\n", " [ 1.31161394, 0.62833956],\n", " [ 1.13428443, 1.18346542],\n", " [ 1.27671346, 0.96632878],\n", " [-0.63342885, -0.97768251],\n", " [ 0.12698779, -0.93142317],\n", " [-1.34510812, -0.23754226],\n", " [-0.53162278, -1.25153594],\n", " [ 0.21959934, -0.90269938],\n", " [-1.78997479, -0.12115748],\n", " [ 1.23197473, -0.07453764],\n", " [ 1.4163536 , 1.21551752],\n", " [-1.90280976, -0.1638976 ],\n", " [-0.22440081, -0.75454248],\n", " [ 0.59559412, 0.92414553],\n", " [ 1.21930773, 1.08175284],\n", " [-1.99427535, -0.37587799],\n", " [-1.27818474, -0.52454551],\n", " [ 0.62352689, -1.01430108],\n", " [ 0.14024251, -0.428266 ],\n", " [-0.16145713, -1.16359731],\n", " [-1.74795865, -0.06033101],\n", " [-1.16659791, 0.0902393 ],\n", " [ 0.41110408, -0.8084249 ],\n", " [ 1.14757168, 0.77804528],\n", " [-1.65590748, -0.40105446],\n", " [-1.15306865, 0.00858699],\n", " [ 0.60892121, 0.68974833],\n", " [-0.08434138, -0.97615256],\n", " [ 0.19170053, -0.42331438],\n", " [ 0.29663162, -1.13357399],\n", " [-1.36893628, -0.25052124],\n", 
" [-0.08037807, -0.56784155],\n", " [ 0.35695011, -1.15064408],\n", " [ 0.02482179, -0.63594828],\n", " [-1.49075558, -0.2482507 ],\n", " [-1.408588 , 0.25635431],\n", " [-1.98274626, -0.54584475]])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Split the 2-D samples into their x and y coordinates\n", "xs, ys = points[:, 0], points[:, 1]\n", "\n", "# Scatter the raw points to eyeball how many clusters there are\n", "plt.scatter(xs, ys)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clustering 2D points\n", "From the scatter plot of the previous exercise, you saw that the points seem to separate into 3 clusters. You'll now create a KMeans model to find 3 clusters, and fit it to the data points from the previous exercise. After the model has been fit, you'll obtain the cluster labels for some new points using the ```.predict()``` method." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "new_points = np.array([[ 4.00233332e-01, -1.26544471e+00],\n", " [ 8.03230370e-01, 1.28260167e+00],\n", " [-1.39507552e+00, 5.57292921e-02],\n", " [-3.41192677e-01, -1.07661994e+00],\n", " [ 1.54781747e+00, 1.40250049e+00],\n", " [ 2.45032018e-01, -4.83442328e-01],\n", " [ 1.20706886e+00, 8.88752605e-01],\n", " [ 1.25132628e+00, 1.15555395e+00],\n", " [ 1.81004415e+00, 9.65530731e-01],\n", " [-1.66963401e+00, -3.08103509e-01],\n", " [-7.17482105e-02, -9.37939700e-01],\n", " [ 6.82631927e-01, 1.10258160e+00],\n", " [ 1.09039598e+00, 1.43899529e+00],\n", " [-1.67645414e+00, -5.04557049e-01],\n", " [-1.84447804e+00, 4.52539544e-02],\n", " [ 1.24234851e+00, 1.02088661e+00],\n", " [-1.86147041e+00, 6.38645811e-03],\n", " [-1.46044943e+00, 1.53252383e-01],\n", " [ 4.98981817e-01, 8.98006058e-01],\n", " [ 9.83962244e-01, 1.04369375e+00],\n", " [-1.83136742e+00, -1.63632835e-01],\n", " [ 1.30622617e+00, 1.07658717e+00],\n", " [ 3.53420328e-01, -7.51320218e-01],\n", " [ 1.13957970e+00, 1.54503860e+00],\n", " [ 2.93995694e-01, -1.26135005e+00],\n", " [-1.14558225e+00, -3.78709636e-02],\n", " [ 1.18716105e+00, 6.00240663e-01],\n", " [-2.23211946e+00, 2.30475094e-01],\n", " [-1.28320430e+00, -3.93314568e-01],\n", " [ 4.94296696e-01, -8.83972009e-01],\n", " [ 6.31834930e-02, -9.11952228e-01],\n", 
" [ 9.35759539e-01, 8.66820685e-01],\n", " [ 1.58014721e+00, 1.03788392e+00],\n", " [ 1.06304960e+00, 1.02706082e+00],\n", " [-1.39732536e+00, -5.05162249e-01],\n", " [-1.09935240e-01, -9.08113619e-01],\n", " [ 1.17346758e+00, 9.47501092e-01],\n", " [ 9.20084511e-01, 1.45767672e+00],\n", " [ 5.82658956e-01, -9.00086832e-01],\n", " [ 9.52772328e-01, 8.99042386e-01],\n", " [-1.37266956e+00, -3.17878215e-02],\n", " [ 2.12706760e-02, -7.07614194e-01],\n", " [ 3.27049052e-01, -5.55998107e-01],\n", " [-1.71590267e+00, 2.15222266e-01],\n", " [ 5.12516209e-01, -7.60128245e-01],\n", " [ 1.13023469e+00, 7.22451122e-01],\n", " [-1.43074310e+00, -3.42787511e-01],\n", " [-1.82724625e+00, 1.17657775e-01],\n", " [ 1.41801350e+00, 1.11455080e+00],\n", " [ 1.26897304e+00, 1.41925971e+00],\n", " [ 8.04076494e-01, 1.63988557e+00],\n", " [ 8.34567752e-01, 1.09956689e+00],\n", " [-1.24714732e+00, -2.23522320e-01],\n", " [-1.29422537e+00, 8.18770024e-02],\n", " [-2.27378316e-01, -4.13331387e-01],\n", " [ 2.18830387e-01, -4.68183120e-01],\n", " [-1.22593414e+00, 2.55599147e-01],\n", " [-1.31294033e+00, -4.28892070e-01],\n", " [-1.33532382e+00, 6.52053776e-01],\n", " [-3.01100233e-01, -1.25156451e+00],\n", " [ 2.02778356e-01, -9.05277445e-01],\n", " [ 1.01357784e+00, 1.12378981e+00],\n", " [ 8.18324394e-01, 8.60841257e-01],\n", " [ 1.26181556e+00, 1.46613744e+00],\n", " [ 4.64867724e-01, -7.97212459e-01],\n", " [ 3.60908898e-01, 8.44106720e-01],\n", " [-2.15098310e+00, -3.69583937e-01],\n", " [ 1.05005281e+00, 8.74181364e-01],\n", " [ 1.06580074e-01, -7.49268153e-01],\n", " [-1.73945723e+00, 2.52183577e-01],\n", " [-1.12017687e-01, -6.52469788e-01],\n", " [ 5.16618951e-01, -6.41267582e-01],\n", " [ 3.26621787e-01, -8.80608015e-01],\n", " [ 1.09017759e+00, 1.10952558e+00],\n", " [ 3.64459576e-01, -6.94215622e-01],\n", " [-1.90779318e+00, 1.87383674e-01],\n", " [-1.95601829e+00, 1.39959126e-01],\n", " [ 3.18541701e-01, -4.05271704e-01],\n", " [ 7.36512699e-01, 1.76416255e+00],\n", " 
[-1.44175162e+00, -5.72320429e-02],\n", " [ 3.21757168e-01, -5.34283821e-01],\n", " [-1.37317305e+00, 4.64484644e-02],\n", " [ 6.87225910e-02, -1.10522944e+00],\n", " [ 9.59314218e-01, 6.52316210e-01],\n", " [-1.62641919e+00, -5.62423280e-01],\n", " [ 1.06788305e+00, 7.29260482e-01],\n", " [-1.79643547e+00, -9.88307418e-01],\n", " [-9.88628377e-02, -6.81198092e-02],\n", " [-1.05135700e-01, 1.17022143e+00],\n", " [ 8.79964699e-01, 1.25340317e+00],\n", " [ 9.80753407e-01, 1.15486539e+00],\n", " [-8.33224966e-02, -9.24844368e-01],\n", " [ 8.48759673e-01, 1.09397425e+00],\n", " [ 1.32941649e+00, 1.13734563e+00],\n", " [ 3.23788068e-01, -7.49732451e-01],\n", " [-1.52610970e+00, -2.49016929e-01],\n", " [-1.48598116e+00, -2.68828608e-01],\n", " [-1.80479553e+00, 1.87052700e-01],\n", " [-2.01907347e+00, -4.49511651e-01],\n", " [ 2.87202402e-01, -6.55487415e-01],\n", " [ 8.22295102e-01, 1.38443234e+00],\n", " [-3.56997036e-02, -8.01825807e-01],\n", " [-1.66955440e+00, -1.38258505e-01],\n", " [-1.78226821e+00, 2.93353033e-01],\n", " [ 7.25837138e-01, -6.23374024e-01],\n", " [ 3.88432593e-01, -7.61283497e-01],\n", " [ 1.49002783e+00, 7.95678671e-01],\n", " [ 6.55423228e-04, -7.40580702e-01],\n", " [-1.34533116e+00, -4.75629937e-01],\n", " [-8.03845106e-01, -3.09943013e-01],\n", " [-2.49041295e-01, -1.00662418e+00],\n", " [-1.41095118e+00, -7.06744127e-02],\n", " [-1.75119594e+00, -3.00491336e-01],\n", " [-1.27942724e+00, 1.73774600e-01],\n", " [ 3.35028183e-01, 6.24761151e-01],\n", " [ 1.16819649e+00, 1.18902251e+00],\n", " [ 7.15210457e-01, 9.26077419e-01],\n", " [ 1.30057278e+00, 9.16349565e-01],\n", " [-1.21697008e+00, 1.10039477e-01],\n", " [-1.70707935e+00, -5.99659536e-02],\n", " [ 1.20730655e+00, 1.05480463e+00],\n", " [ 1.86896009e-01, -9.58047234e-01],\n", " [ 8.03463471e-01, 3.86133140e-01],\n", " [-1.73486790e+00, -1.49831913e-01],\n", " [ 1.31261499e+00, 1.11802982e+00],\n", " [ 4.04993148e-01, -5.10900347e-01],\n", " [-1.93267968e+00, 2.20764694e-01],\n", " [ 
6.56004799e-01, 9.61887161e-01],\n", " [-1.40588215e+00, 1.17134403e-01],\n", " [-1.74306264e+00, -7.47473959e-02],\n", " [ 5.43745412e-01, 1.47209224e+00],\n", " [-1.97331669e+00, -2.27124493e-01],\n", " [ 1.53901171e+00, 1.36049081e+00],\n", " [-1.48323452e+00, -4.90302063e-01],\n", " [ 3.86748484e-01, -1.26173400e+00],\n", " [ 1.17015716e+00, 1.18549415e+00],\n", " [-8.05381721e-02, -3.21923627e-01],\n", " [-6.82273156e-02, -8.52825887e-01],\n", " [ 7.13500028e-01, 1.27868520e+00],\n", " [-1.85014378e+00, -5.03490558e-01],\n", " [ 6.36085266e-02, -1.41257040e+00],\n", " [ 1.52966062e+00, 9.66056572e-01],\n", " [ 1.62165714e-01, -1.37374843e+00],\n", " [-3.23474497e-01, -7.06620269e-01],\n", " [-1.51768993e+00, 1.87658302e-01],\n", " [ 8.88895911e-01, 7.62237161e-01],\n", " [ 4.83164032e-01, 8.81931869e-01],\n", " [-5.52997766e-02, -7.11305016e-01],\n", " [-1.57966441e+00, -6.29220313e-01],\n", " [ 5.51308645e-02, -8.47206763e-01],\n", " [-2.06001582e+00, 5.87697787e-02],\n", " [ 1.11810855e+00, 1.30254175e+00],\n", " [ 4.87016164e-01, -9.90143937e-01],\n", " [-1.65518042e+00, -1.69386383e-01],\n", " [-1.44349738e+00, 1.90299243e-01],\n", " [-1.70074547e-01, -8.26736022e-01],\n", " [-1.82433979e+00, -3.07814626e-01],\n", " [ 1.03093485e+00, 1.26457691e+00],\n", " [ 1.64431169e+00, 1.27773115e+00],\n", " [-1.47617693e+00, 2.60783872e-02],\n", " [ 1.00953067e+00, 1.14270181e+00],\n", " [-1.45285636e+00, -2.55216207e-01],\n", " [-1.74092917e+00, -8.34443177e-02],\n", " [ 1.22038299e+00, 1.28699961e+00],\n", " [ 9.16925397e-01, 7.32070275e-01],\n", " [-1.60754185e-03, -7.26375571e-01],\n", " [ 8.93841238e-01, 8.41146643e-01],\n", " [ 6.33791961e-01, 1.00915134e+00],\n", " [-1.47927075e+00, -6.99781936e-01],\n", " [ 5.44799374e-02, -1.06441970e+00],\n", " [-1.51935568e+00, -4.89276929e-01],\n", " [ 2.89939026e-01, -7.73145523e-01],\n", " [-9.68154061e-03, -1.13302207e+00],\n", " [ 1.13474639e+00, 9.71541744e-01],\n", " [ 5.36421406e-01, -8.47906388e-01],\n", " [ 
1.14759864e+00, 6.89915205e-01],\n", " [ 5.73291902e-01, 7.90802710e-01],\n", " [ 2.12377397e-01, -6.07569808e-01],\n", " [ 5.26579548e-01, -8.15930264e-01],\n", " [-2.01831641e+00, 6.78650740e-02],\n", " [-2.35512624e-01, -1.08205132e+00],\n", " [ 1.59274780e-01, -6.00717261e-01],\n", " [ 2.28120356e-01, -1.16003549e+00],\n", " [-1.53658378e+00, 8.40798808e-02],\n", " [ 1.13954609e+00, 6.31782001e-01],\n", " [ 1.01119255e+00, 1.04360805e+00],\n", " [-1.42039867e-01, -4.81230337e-01],\n", " [-2.23120182e+00, 8.49162905e-02],\n", " [ 1.25554811e-01, -1.01794793e+00],\n", " [-1.72493509e+00, -6.94426177e-01],\n", " [-1.60434630e+00, 4.45550868e-01],\n", " [ 7.37153979e-01, 9.26560744e-01],\n", " [ 6.72905271e-01, 1.13366030e+00],\n", " [ 1.20066456e+00, 7.26273093e-01],\n", " [ 7.58747209e-02, -9.83378326e-01],\n", " [ 1.28783262e+00, 1.18088601e+00],\n", " [ 1.06521930e+00, 1.00714746e+00],\n", " [ 1.05871698e+00, 1.12956519e+00],\n", " [-1.12643410e+00, 1.66787744e-01],\n", " [-1.10157218e+00, -3.64137806e-01],\n", " [ 2.35118217e-01, -1.39769949e-01],\n", " [ 1.13853795e+00, 1.01018519e+00],\n", " [ 5.31205654e-01, -8.81990792e-01],\n", " [ 4.33085936e-01, -7.64059042e-01],\n", " [-4.48926156e-03, -1.30548411e+00],\n", " [-1.76348589e+00, -4.97430739e-01],\n", " [ 1.36485681e+00, 5.83404699e-01],\n", " [ 5.66923900e-01, 1.51391963e+00],\n", " [ 1.35736826e+00, 6.70915318e-01],\n", " [ 1.07173397e+00, 6.11990884e-01],\n", " [ 1.00106915e+00, 8.93815326e-01],\n", " [ 1.33091007e+00, 8.79773879e-01],\n", " [-1.79603740e+00, -3.53883973e-02],\n", " [-1.27222979e+00, 4.00156642e-01],\n", " [ 8.47480603e-01, 1.17032364e+00],\n", " [-1.50989129e+00, -7.12318330e-01],\n", " [-1.24953576e+00, -5.57859730e-01],\n", " [-1.27717973e+00, -5.99350550e-01],\n", " [-1.81946743e+00, 7.37057673e-01],\n", " [ 1.19949867e+00, 1.56969386e+00],\n", " [-1.25543847e+00, -2.33892826e-01],\n", " [-1.63052058e+00, 1.61455865e-01],\n", " [ 1.10611305e+00, 7.39698224e-01],\n", " [ 
6.70193192e-01, 8.70567001e-01],\n", " [ 3.69670156e-01, -6.94645306e-01],\n", " [-1.26362293e+00, -6.99249285e-01],\n", " [-3.66687507e-01, -1.35310260e+00],\n", " [ 2.44032147e-01, -6.59470793e-01],\n", " [-1.27679142e+00, -4.85453412e-01],\n", " [ 3.77473612e-02, -6.99251605e-01],\n", " [-2.19148539e+00, -4.91199500e-01],\n", " [-2.93277777e-01, -5.89488212e-01],\n", " [-1.65737397e+00, -2.98337786e-01],\n", " [ 7.36638861e-01, 5.78037057e-01],\n", " [ 1.13709081e+00, 1.30119754e+00],\n", " [-1.44146601e+00, 3.13934680e-02],\n", " [ 5.92360708e-01, 1.22545114e+00],\n", " [ 6.51719414e-01, 4.92674894e-01],\n", " [ 5.94559139e-01, 8.25637315e-01],\n", " [-1.87900722e+00, -5.21899626e-01],\n", " [ 2.15225041e-01, -1.28269851e+00],\n", " [ 4.99145965e-01, -6.70268634e-01],\n", " [-1.82954176e+00, -3.39269731e-01],\n", " [ 7.92721403e-01, 1.33785606e+00],\n", " [ 9.54363372e-01, 9.80396626e-01],\n", " [-1.35359846e+00, 1.03976340e-01],\n", " [ 1.05595062e+00, 8.07031927e-01],\n", " [-1.94311010e+00, -1.18976964e-01],\n", " [-1.39604137e+00, -3.10095976e-01],\n", " [ 1.28977624e+00, 1.01753365e+00],\n", " [-1.59503139e+00, -5.40574609e-01],\n", " [-1.41994046e+00, -3.81032569e-01],\n", " [-2.35569801e-02, -1.10133702e+00],\n", " [-1.26038568e+00, -6.93273886e-01],\n", " [ 9.60215981e-01, -8.11553694e-01],\n", " [ 5.51803308e-01, -1.01793176e+00],\n", " [ 3.70185085e-01, -1.06885468e+00],\n", " [ 8.25529207e-01, 8.77007060e-01],\n", " [-1.87032595e+00, 2.87507199e-01],\n", " [-1.56260769e+00, -1.89196712e-01],\n", " [-1.26346548e+00, -7.74725237e-01],\n", " [-6.33800421e-02, -7.59400611e-01],\n", " [ 8.85298280e-01, 8.85620519e-01],\n", " [-1.43324686e-01, -1.16083678e+00],\n", " [-1.83908725e+00, -3.26655515e-01],\n", " [ 2.74709229e-01, -1.04546829e+00],\n", " [-1.45703573e+00, -2.91842036e-01],\n", " [-1.59048842e+00, 1.66063031e-01],\n", " [ 9.25549284e-01, 7.41406406e-01],\n", " [ 1.97245469e-01, -7.80703225e-01],\n", " [ 2.88401697e-01, -8.32425551e-01],\n", " [ 
7.24141618e-01, -7.99149200e-01],\n", " [-1.62658639e+00, -1.80005543e-01],\n", " [ 5.84481588e-01, 1.13195640e+00],\n", " [ 1.02146732e+00, 4.59657799e-01],\n", " [ 8.65050554e-01, 9.57714887e-01],\n", " [ 3.98717766e-01, -1.24273147e+00],\n", " [ 8.62234892e-01, 1.10955561e+00],\n", " [-1.35999430e+00, 2.49942654e-02],\n", " [-1.19178505e+00, -3.82946323e-02],\n", " [ 1.29392424e+00, 1.10320509e+00],\n", " [ 1.25679630e+00, -7.79857582e-01],\n", " [ 9.38040302e-02, -5.53247258e-01],\n", " [-1.73512175e+00, -9.76271667e-02],\n", " [ 2.23153587e-01, -9.43474351e-01],\n", " [ 4.01989100e-01, -1.10963051e+00],\n", " [-1.42244158e+00, 1.81914703e-01],\n", " [ 3.92476267e-01, -8.78426277e-01],\n", " [ 1.25181875e+00, 6.93614996e-01],\n", " [ 1.77481317e-02, -7.20304235e-01],\n", " [-1.87752521e+00, -2.63870424e-01],\n", " [-1.58063602e+00, -5.50456344e-01],\n", " [-1.59589493e+00, -1.53932892e-01],\n", " [-1.01829770e+00, 3.88542370e-02],\n", " [ 1.24819659e+00, 6.60041803e-01],\n", " [-1.25551377e+00, -2.96172009e-02],\n", " [-1.41864559e+00, -3.58230179e-01],\n", " [ 5.25758326e-01, 8.70500543e-01],\n", " [ 5.55599988e-01, 1.18765072e+00],\n", " [ 2.81344439e-02, -6.99111314e-01]])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 1 2 0 1 0 1 1 1 2 0 1 1 2 2 1 2 2 1 1 2 1 0 1 0 2 1 2 2 0 0 1 1 1 2 0 1\n", " 1 0 1 2 0 0 2 0 1 2 2 1 1 1 1 2 2 0 0 2 2 2 0 0 1 1 1 0 1 2 1 0 2 0 0 0 1\n", " 0 2 2 0 1 2 0 2 0 1 2 1 2 0 1 1 1 0 1 1 0 2 2 2 2 0 1 0 2 2 0 0 1 0 2 2 0\n", " 2 2 2 1 1 1 1 2 2 1 0 1 2 1 0 2 1 2 2 1 2 1 2 0 1 0 0 1 2 0 1 0 0 2 1 1 0\n", " 2 0 2 1 0 2 2 0 2 1 1 2 1 2 2 1 1 0 1 1 2 0 2 0 0 1 0 1 1 0 0 2 0 0 0 2 1\n", " 1 0 2 0 2 2 1 1 1 0 1 1 1 2 2 0 1 0 0 0 2 1 1 1 1 1 1 2 2 1 2 2 2 2 1 2 2\n", " 1 1 0 2 0 0 2 0 2 0 2 1 1 2 1 1 1 2 0 0 2 1 1 2 1 2 2 1 2 2 0 2 0 0 0 1 2\n", " 2 2 0 1 0 2 0 2 2 1 0 0 0 2 1 1 1 0 1 2 2 1 0 0 2 0 0 2 0 1 0 2 2 2 2 1 2\n", " 2 1 1 0]\n" ] } ], 
"source": [ "from sklearn.cluster import KMeans\n", "\n", "# Create a KMeans instance with 3 clusters: model\n", "# random_state pins the centroid initialisation so the cluster numbering\n", "# (and therefore the printed labels) is the same on every run; n_init is set\n", "# explicitly because its default differs between scikit-learn releases.\n", "model = KMeans(n_clusters=3, n_init=10, random_state=42)\n", "\n", "# Fit model to points\n", "model.fit(points)\n", "\n", "# Determine the cluster labels of new_points: labels\n", "labels = model.predict(new_points)\n", "\n", "# Print cluster labels of new_points\n", "print(labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspect your clustering\n", "Let's now inspect the clustering you performed in the previous exercise!\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Coordinates of the new samples and of the fitted cluster centres\n", "xs, ys = new_points[:, 0], new_points[:, 1]\n", "centroids = model.cluster_centers_\n", "centroids_x, centroids_y = centroids[:, 0], centroids[:, 1]\n", "\n", "# Samples, coloured by their predicted cluster label\n", "plt.scatter(xs, ys, c=labels, alpha=0.5)\n", "\n", "# Centroids, overlaid as red diamonds\n", "plt.scatter(centroids_x, centroids_y, marker='D', s=50, color='red')\n", "\n", "# Save the figure that the post's front matter references as its image\n", "plt.savefig('../images/kmeans-centroid.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating a clustering\n", "- Evaluating a clustering\n", " - Can check correspondence with e.g. iris species, but what if there are no species to check against?\n", " - Measure quality of a clustering\n", " - Informs choice of how many clusters to look for\n", "- Cross-tabulation with pandas\n", " - Clusters vs species is a \"cross-tabulation\"\n", "- Measuring clustering quality\n", " - Using only samples and their cluster labels\n", " - A good clustering has tight clusters\n", " - Samples in each cluster bunched together\n", "- Inertia measures clustering quality\n", " - Measures how spread out the clusters are (lower is better)\n", " - Distance from each sample to centroid of its cluster\n", " - k-means attempts to minimize the inertia when choosing clusters\n", "- How many clusters to choose?\n", " - Choose an \"elbow\" in the inertia plot\n", " - Where inertia begins to decrease more slowly" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How many clusters of grain?\n", "In the video, you learned how to choose a good number of clusters for a dataset using the k-means inertia graph. 
You are given an array ```samples``` containing the measurements (such as area, perimeter, length, and several others) of samples of grain. What's a good number of clusters in this case?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Preprocess" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234567
015.2614.840.87105.7633.3122.2215.220Kama wheat
114.8814.570.88115.5543.3331.0184.956Kama wheat
214.2914.090.90505.2913.3372.6994.825Kama wheat
313.8413.940.89555.3243.3792.2594.805Kama wheat
416.1414.990.90345.6583.5621.3555.175Kama wheat
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7\n", "0 15.26 14.84 0.8710 5.763 3.312 2.221 5.220 Kama wheat\n", "1 14.88 14.57 0.8811 5.554 3.333 1.018 4.956 Kama wheat\n", "2 14.29 14.09 0.9050 5.291 3.337 2.699 4.825 Kama wheat\n", "3 13.84 13.94 0.8955 5.324 3.379 2.259 4.805 Kama wheat\n", "4 16.14 14.99 0.9034 5.658 3.562 1.355 5.175 Kama wheat" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Variety codes stored in column 7, mapped to readable names\n", "variety_names = {1: 'Kama wheat', 2: 'Rosa wheat', 3: 'Canadian wheat'}\n", "\n", "# The CSV is read without a header row, so columns are numbered 0-7\n", "df = pd.read_csv('./dataset/seeds.csv', header=None)\n", "df[7] = df[7].map(variety_names)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# All columns but the last are the measurements; the last is the variety name\n", "samples = df.iloc[:, :-1].to_numpy()\n", "varieties = df.iloc[:, -1].to_numpy()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([,\n", " ,\n", " ,\n", " ,\n", " ],\n", " )" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ks = range(1, 6)\n", "inertias = []\n", "\n", "for k in ks:\n", "    # Create a KMeans instance with k clusters: model\n", "    # (fixed random_state so the inertia curve is identical on every run;\n", "    # n_init is explicit because its default differs between scikit-learn releases)\n", "    model = KMeans(n_clusters=k, n_init=10, random_state=42)\n", "\n", "    # Fit model to samples\n", "    model.fit(samples)\n", "\n", "    # Append the inertia to the list of inertias\n", "    inertias.append(model.inertia_)\n", "\n", "# Plot ks vs inertias\n", "plt.plot(ks, inertias, '-o')\n", "plt.xlabel('number of clusters, k')\n", "plt.ylabel('inertia')\n", "plt.xticks(ks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the grain clustering\n", "In the previous exercise, you observed from the inertia plot that 3 is a good number of clusters for the grain data. In fact, the grain samples come from a mix of 3 different grain varieties: \"Kama\", \"Rosa\" and \"Canadian\". In this exercise, cluster the grain samples into three clusters, and compare the clusters to the grain varieties using a cross-tabulation.\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "varieties Canadian wheat Kama wheat Rosa wheat\n", "labels \n", "0 2 60 10\n", "1 0 1 60\n", "2 68 9 0\n" ] } ], "source": [ "# Create a KMeans model with 3 clusters: model\n", "# (seeded so the cluster numbering in the cross-tabulation is reproducible)\n", "model = KMeans(n_clusters=3, n_init=10, random_state=42)\n", "\n", "# Use fit_predict to fit model and obtain cluster labels: labels\n", "labels = model.fit_predict(samples)\n", "\n", "# Create a DataFrame with labels and varieties as columns: grain_clusters\n", "# (a dedicated name, so the seeds DataFrame `df` is not overwritten and the\n", "# cell that derives `samples` from it can still be re-run safely)\n", "grain_clusters = pd.DataFrame({'labels': labels, 'varieties': varieties})\n", "\n", "# Create crosstab: ct\n", "ct = pd.crosstab(grain_clusters['labels'], grain_clusters['varieties'])\n", "\n", "# Display ct\n", "print(ct)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transforming features for better clusterings\n", "- StandardScaler\n", " - In kmeans, feature variance = feature influence\n", " - ```StandardScaler``` transforms each 
feature to have mean 0 and variance 1\n", " - Features are said to be \"standardized\"\n", "- StandardScaler, then KMeans\n", " - Need to perform two steps: ```StandardScaler```, then ```KMeans```\n", " - Use ```sklearn``` pipeline to combine multiple steps\n", " - Data flows from one step into the next" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scaling fish data for clustering\n", "You are given an array ```samples``` giving measurements of fish. Each row represents an individual fish. The measurements, such as weight in grams, length in centimeters, and the percentage ratio of height to length, have very different scales. In order to cluster this data effectively, you'll need to standardize these features first. In this exercise, you'll build a pipeline to standardize and cluster the data.\n", "\n", "These fish measurement data were sourced from the [Journal of Statistics Education](http://ww2.amstat.org/publications/jse/jse_data_archive.htm)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Preprocess" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456
0Bream242.023.225.430.038.413.4
1Bream290.024.026.331.240.013.8
2Bream340.023.926.531.139.815.1
3Bream363.026.329.033.538.013.3
4Bream430.026.529.034.036.615.1
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6\n", "0 Bream 242.0 23.2 25.4 30.0 38.4 13.4\n", "1 Bream 290.0 24.0 26.3 31.2 40.0 13.8\n", "2 Bream 340.0 23.9 26.5 31.1 39.8 15.1\n", "3 Bream 363.0 26.3 29.0 33.5 38.0 13.3\n", "4 Bream 430.0 26.5 29.0 34.0 36.6 15.1" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('./dataset/fish.csv', header=None)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "samples = df.iloc[:, 1:].values\n", "species = df.iloc[:, 0].values" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.cluster import KMeans\n", "\n", "# Create scaler: scaler\n", "scaler = StandardScaler()\n", "\n", "# Create KMeans instance: kmeans\n", "kmeans = KMeans(n_clusters=4)\n", "\n", "# Create pipeline: pipeline\n", "pipeline = make_pipeline(scaler, kmeans)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clustering the fish data\n", "You'll now use your standardization and clustering pipeline from the previous exercise to cluster the fish by their measurements, and then create a cross-tabulation to compare the cluster labels with the fish species." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "species Bream Pike Roach Smelt\n", "labels \n", "0 33 0 1 0\n", "1 0 17 0 0\n", "2 0 0 0 13\n", "3 1 0 19 1\n" ] } ], "source": [ "# Fit the pipeline to samples\n", "pipeline.fit(samples)\n", "\n", "# Calculate the cluster labels: labels\n", "labels = pipeline.predict(samples)\n", "\n", "# Create a DataFrame with labels and species as columns: df\n", "df = pd.DataFrame({'labels': labels, 'species': species})\n", "\n", "# Create crosstab: ct\n", "ct = pd.crosstab(df['labels'], df['species'])\n", "\n", "# Display ct\n", "print(ct)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clustering stocks using KMeans\n", "In this exercise, you'll cluster companies using their daily stock price movements (i.e. the dollar difference between the closing and opening prices for each trading day). You are given a NumPy array ```movements``` of daily price movements from 2010 to 2015 (obtained from Yahoo! Finance), where each row corresponds to a company, and each column corresponds to a trading day.\n", "\n", "Some stocks are more expensive than others. To account for this, include a ```Normalizer``` at the beginning of your pipeline. The Normalizer will separately transform each company's stock price to a relative scale before the clustering begins.\n", "\n", "Note that ```Normalizer()``` is different to ```StandardScaler()```, which you used in the previous exercise. While ```StandardScaler()``` standardizes features (such as the features of the fish data from the previous exercise) by removing the mean and scaling to unit variance, ```Normalizer()``` rescales each sample - here, each company's stock price - independently of the others." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Preprocess" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
2010-01-042010-01-052010-01-062010-01-072010-01-082010-01-112010-01-122010-01-132010-01-142010-01-15...2013-10-162013-10-172013-10-182013-10-212013-10-222013-10-232013-10-242013-10-252013-10-282013-10-29
Apple0.580000-0.220005-3.409998-1.1700001.680011-2.689994-1.4699942.779997-0.680003-4.999995...0.3200084.5199972.8999879.590019-6.5400165.9599766.910011-5.3599620.840019-19.589981
AIG-0.640002-0.650000-0.210001-0.4200000.710001-0.200001-1.1300010.069999-0.119999-0.500000...0.9199980.7099990.119999-0.4800000.010002-0.279998-0.190003-0.040001-0.4000020.660000
Amazon-2.3500061.260009-2.350006-2.0099952.960006-2.309997-1.6400071.209999-1.790001-2.039994...2.1099853.6999829.570008-3.4500134.820008-4.0799862.5799864.790009-1.7600093.740021
American express0.1099970.0000000.2600020.7200020.190003-0.2700010.7500000.3000040.639999-0.130001...0.6800012.2900010.409996-0.0699990.1000060.0699990.1300051.8499990.0400010.540001
Boeing0.4599991.7700001.5499992.6900030.059997-1.0800020.3600000.5499990.530002-0.709999...1.5599972.4800030.019997-1.2200010.4800033.020004-0.0299991.9400021.1300050.309998
\n", "

5 rows × 963 columns

\n", "
" ], "text/plain": [ " 2010-01-04 2010-01-05 2010-01-06 2010-01-07 2010-01-08 \\\n", "Apple 0.580000 -0.220005 -3.409998 -1.170000 1.680011 \n", "AIG -0.640002 -0.650000 -0.210001 -0.420000 0.710001 \n", "Amazon -2.350006 1.260009 -2.350006 -2.009995 2.960006 \n", "American express 0.109997 0.000000 0.260002 0.720002 0.190003 \n", "Boeing 0.459999 1.770000 1.549999 2.690003 0.059997 \n", "\n", " 2010-01-11 2010-01-12 2010-01-13 2010-01-14 2010-01-15 \\\n", "Apple -2.689994 -1.469994 2.779997 -0.680003 -4.999995 \n", "AIG -0.200001 -1.130001 0.069999 -0.119999 -0.500000 \n", "Amazon -2.309997 -1.640007 1.209999 -1.790001 -2.039994 \n", "American express -0.270001 0.750000 0.300004 0.639999 -0.130001 \n", "Boeing -1.080002 0.360000 0.549999 0.530002 -0.709999 \n", "\n", " ... 2013-10-16 2013-10-17 2013-10-18 2013-10-21 \\\n", "Apple ... 0.320008 4.519997 2.899987 9.590019 \n", "AIG ... 0.919998 0.709999 0.119999 -0.480000 \n", "Amazon ... 2.109985 3.699982 9.570008 -3.450013 \n", "American express ... 0.680001 2.290001 0.409996 -0.069999 \n", "Boeing ... 
1.559997 2.480003 0.019997 -1.220001 \n", "\n", " 2013-10-22 2013-10-23 2013-10-24 2013-10-25 2013-10-28 \\\n", "Apple -6.540016 5.959976 6.910011 -5.359962 0.840019 \n", "AIG 0.010002 -0.279998 -0.190003 -0.040001 -0.400002 \n", "Amazon 4.820008 -4.079986 2.579986 4.790009 -1.760009 \n", "American express 0.100006 0.069999 0.130005 1.849999 0.040001 \n", "Boeing 0.480003 3.020004 -0.029999 1.940002 1.130005 \n", "\n", " 2013-10-29 \n", "Apple -19.589981 \n", "AIG 0.660000 \n", "Amazon 3.740021 \n", "American express 0.540001 \n", "Boeing 0.309998 \n", "\n", "[5 rows x 963 columns]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('./dataset/company-stock-movements-2010-2015-incl.csv', index_col=0)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "movements = df.values\n", "companies = df.index.values" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pipeline(memory=None,\n", " steps=[('normalizer', Normalizer(copy=True, norm='l2')),\n", " ('kmeans',\n", " KMeans(algorithm='auto', copy_x=True, init='k-means++',\n", " max_iter=300, n_clusters=10, n_init=10, n_jobs=None,\n", " precompute_distances='auto', random_state=None,\n", " tol=0.0001, verbose=0))],\n", " verbose=False)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.preprocessing import Normalizer\n", "\n", "# Create a normalizer: normalizer\n", "normalizer = Normalizer()\n", "\n", "# Create a KMeans model with 10 clusters: kmeans\n", "kmeans = KMeans(n_clusters=10)\n", "\n", "# Make a pipeline chaining normalizer and kmeans: pipeline\n", "pipeline = make_pipeline(normalizer, kmeans)\n", "\n", "# Fit pipeline to the daily price movements\n", "pipeline.fit(movements)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Which stocks move together?\n", "In the 
previous exercise, you clustered companies by their daily stock price movements. So which companies have stock prices that tend to change in the same way? You'll now inspect the cluster labels from your clustering to find out." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " labels companies\n", "42 0 Royal Dutch Shell\n", "39 0 Pfizer\n", "52 0 Unilever\n", "37 0 Novartis\n", "6 0 British American Tobacco\n", "46 0 Sanofi-Aventis\n", "49 0 Total\n", "19 0 GlaxoSmithKline\n", "22 1 HP\n", "14 1 Dell\n", "34 2 Mitsubishi\n", "45 2 Sony\n", "48 2 Toyota\n", "15 2 Ford\n", "21 2 Honda\n", "7 2 Canon\n", "55 3 Wells Fargo\n", "26 3 JPMorgan Chase\n", "3 3 American express\n", "5 3 Bank of America\n", "18 3 Goldman Sachs\n", "16 3 General Electrics\n", "1 3 AIG\n", "10 4 ConocoPhillips\n", "35 4 Navistar\n", "8 4 Caterpillar\n", "53 4 Valero Energy\n", "57 4 Exxon\n", "44 4 Schlumberger\n", "12 4 Chevron\n", "38 5 Pepsi\n", "40 5 Procter Gamble\n", "28 5 Coca Cola\n", "9 5 Colgate-Palmolive\n", "41 5 Philip Morris\n", "56 5 Wal-Mart\n", "27 5 Kimberly-Clark\n", "33 6 Microsoft\n", "24 6 Intel\n", "11 6 Cisco\n", "51 7 Texas instruments\n", "50 7 Taiwan Semiconductor Manufacturing\n", "47 7 Symantec\n", "59 7 Yahoo\n", "32 7 3M\n", "31 7 McDonalds\n", "30 7 MasterCard\n", "58 7 Xerox\n", "25 7 Johnson & Johnson\n", "23 7 IBM\n", "20 7 Home Depot\n", "13 7 DuPont de Nemours\n", "2 7 Amazon\n", "43 7 SAP\n", "54 7 Walgreen\n", "0 8 Apple\n", "17 8 Google/Alphabet\n", "4 9 Boeing\n", "36 9 Northrop Grumman\n", "29 9 Lookheed Martin\n" ] } ], "source": [ "# Predict the cluster labels: labels\n", "labels = pipeline.predict(movements)\n", "\n", "# Create a DataFrame aligning labels and companies: df\n", "df = pd.DataFrame({'labels': labels, 'companies': companies})\n", "\n", "# Display df sorted by cluster label\n", "print(df.sort_values('labels'))" ] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }