{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multivariate Linear Regression Demo\n", "\n", "_Source: 🤖[Homemade Machine Learning](https://github.com/trekhleb/homemade-machine-learning) repository_\n", "\n", "> ☝Before moving on with this demo you might want to take a look at:\n", "> - 📗[Math behind the Linear Regression](https://github.com/trekhleb/homemade-machine-learning/tree/master/homemade/linear_regression)\n", "> - ⚙️[Linear Regression Source Code](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/linear_regression/linear_regression.py)\n", "\n", "**Linear regression** is a linear model, i.e. a model that assumes a linear relationship between the input variables `(x)` and the single output variable `(y)`. More specifically, the output variable `(y)` can be calculated as a linear combination of the input variables `(x)`.\n", "\n", "**Multivariate Linear Regression** is linear regression that has _more than one_ input parameter and a single output label.\n", "\n", "> **Demo Project:** In this demo we will build a model that predicts each country's `Happiness.Score` based on its `Economy..GDP.per.Capita.` and `Freedom` parameters."
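, "\n", "\n", "The linear combination described above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the repository's `LinearRegression` implementation: the `predict` helper and all numbers below are hypothetical, and it assumes that the parameter vector `theta` stores the bias (intercept) term first.
```python
import numpy as np


def predict(features, theta):
    # Hypothetical helper (not the repository's API).
    # Prepend a column of ones so that theta[0] acts as the bias (intercept),
    # then compute theta_0 + theta_1 * x_1 + theta_2 * x_2 for every row.
    ones = np.ones((features.shape[0], 1))
    return np.hstack((ones, features)) @ theta


# Two made-up countries described by (GDP per capita, freedom) scores.
x = np.array([[1.5, 0.6],
              [0.5, 0.3]])

# Made-up parameters [bias, weight for GDP, weight for freedom], for illustration only.
theta = np.array([[2.0], [2.0], [1.0]])

# One predicted happiness score per row: a linear combination of the inputs.
print(predict(x, theta).ravel())
```
With these made-up parameters the first country gets a predicted score of `2.0 + 2.0 * 1.5 + 1.0 * 0.6 = 5.6`. Prepending a column of ones lets the bias be treated as just another weight, so training only has to find one parameter vector."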
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Enable autoreloading of imported modules to make debugging the linear_regression module easier.\n", "# This way any changes you make to the linear_regression library become available here without restarting the kernel.\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Add the project root folder to the module search paths.\n", "import sys\n", "sys.path.append('../..')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Dependencies\n", "\n", "- [pandas](https://pandas.pydata.org/) - library that we will use for loading and displaying the data in a table\n", "- [numpy](http://www.numpy.org/) - library that we will use for linear algebra operations\n", "- [matplotlib](https://matplotlib.org/) - library that we will use for plotting the data\n", "- [plotly](https://plot.ly/python/) - library that we will use for plotting interactive 3D scatter plots\n", "- [linear_regression](https://github.com/trekhleb/homemade-machine-learning/blob/master/src/linear_regression/linear_regression.py) - custom implementation of linear regression" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/vnd.plotly.v1+html": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Import 3rd party dependencies.\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import plotly\n", "import plotly.graph_objs as go\n", "\n", "# Configure Plotly to be rendered inline in the notebook.\n", "plotly.offline.init_notebook_mode()\n", "\n", "# Import custom linear regression implementation.\n", "from homemade.linear_regression import LinearRegression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the Data\n", "\n", "In this demo we will use the [World Happiness Dataset](https://www.kaggle.com/unsdsn/world-happiness#2017.csv) for 2017."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Country Happiness.Rank Happiness.Score Whisker.high Whisker.low \\\n", "0 Norway 1 7.537 7.594445 7.479556 \n", "1 Denmark 2 7.522 7.581728 7.462272 \n", "2 Iceland 3 7.504 7.622030 7.385970 \n", "3 Switzerland 4 7.494 7.561772 7.426227 \n", "4 Finland 5 7.469 7.527542 7.410458 \n", "5 Netherlands 6 7.377 7.427426 7.326574 \n", "6 Canada 7 7.316 7.384403 7.247597 \n", "7 New Zealand 8 7.314 7.379510 7.248490 \n", "8 Sweden 9 7.284 7.344095 7.223905 \n", "9 Australia 10 7.284 7.356651 7.211349 \n", "\n", " Economy..GDP.per.Capita. Family Health..Life.Expectancy. Freedom \\\n", "0 1.616463 1.533524 0.796667 0.635423 \n", "1 1.482383 1.551122 0.792566 0.626007 \n", "2 1.480633 1.610574 0.833552 0.627163 \n", "3 1.564980 1.516912 0.858131 0.620071 \n", "4 1.443572 1.540247 0.809158 0.617951 \n", "5 1.503945 1.428939 0.810696 0.585384 \n", "6 1.479204 1.481349 0.834558 0.611101 \n", "7 1.405706 1.548195 0.816760 0.614062 \n", "8 1.494387 1.478162 0.830875 0.612924 \n", "9 1.484415 1.510042 0.843887 0.601607 \n", "\n", " Generosity Trust..Government.Corruption. Dystopia.Residual \n", "0 0.362012 0.315964 2.277027 \n", "1 0.355280 0.400770 2.313707 \n", "2 0.475540 0.153527 2.322715 \n", "3 0.290549 0.367007 2.276716 \n", "4 0.245483 0.382612 2.430182 \n", "5 0.470490 0.282662 2.294804 \n", "6 0.435540 0.287372 2.187264 \n", "7 0.500005 0.382817 2.046456 \n", "8 0.385399 0.384399 2.097538 \n", "9 0.477699 0.301184 2.065211 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the data.\n", "data = pd.read_csv('../../data/world-happiness-report-2017.csv')\n", "\n", "# Print the data table.\n", "data.head(10)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot a histogram for each feature to see how its values are distributed.\n", "histograms = data.hist(grid=False, figsize=(10, 10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split the Data Into Training and Test Sets\n", "\n", "In this step we will split our dataset into _training_ and _testing_ subsets in an 80/20 proportion.\n", "\n", "The training set will be used to train our linear model. The test set will be used to validate the model: its data is never seen during training, so it lets us check how accurate the model's predictions are on new data." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split the dataset into training and test sets in an 80/20 proportion.\n", "# DataFrame.sample() returns a random sample of rows; the remaining rows form the test set.\n", "train_data = data.sample(frac=0.8)\n", "test_data = data.drop(train_data.index)\n", "\n", "# Choose the columns we want to use.\n", "input_param_name_1 = 'Economy..GDP.per.Capita.'\n", "input_param_name_2 = 'Freedom'\n", "output_param_name = 'Happiness.Score'\n", "\n", "# Split the training set into inputs and outputs.\n", "x_train = train_data[[input_param_name_1, input_param_name_2]].values\n", "y_train = train_data[[output_param_name]].values\n", "\n", "# Split the test set into inputs and outputs.\n", "x_test = test_data[[input_param_name_1, input_param_name_2]].values\n", "y_test = test_data[[output_param_name]].values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize the training and test datasets to see the shape of the data."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "data": [ { "marker": { "line": { "color": "rgb(255, 255, 255)", "width": 1 }, "opacity": 1, "size": 10 }, "mode": "markers", "name": "Training Set", "type": "scatter3d", "uid": "b2ccc8e3-f5b5-4efb-af5c-e982a479b71b", "x": [ 1.1027104854583702, 1.3208793401718102, 0.7922212481498722, 0.39724862575531, 0.9097844958305359, 1.1893955469131499, 0.401477217674255, 1.06457793712616, 0.479820191860199, 0.368745893239975, 1.00985014438629, 0.381430715322495, 1.07062232494354, 1.12786877155304, 0.6364067792892459, 0.564305365085602, 0.932537317276001, 1.28948748111725, 1.29178786277771, 1.00726580619812, 0.8584281802177429, 0.30544471740722695, 1.1018030643463101, 1.63295245170593, 0.7288706302642819, 1.49438726902008, 1.8707656860351598, 1.3469113111496, 0, 1.05469870567322, 0.951484382152557, 0.5916834473609921, 0.982409417629242, 1.15687310695648, 1.4637807607650801, 1.48238301277161, 1.69227766990662, 0.43108540773391707, 1.3753824234008798, 0.0921023488044739, 1.12112903594971, 0.786441087722778, 1.00082039833069, 0.24454993009567302, 0.36711055040359497, 0.511135876178741, 1.41691517829895, 1.4870972633361799, 1.1531838178634601, 0.7268835306167599, 1.1073532104492199, 0.11904179304838199, 1.2860119342803997, 1.15655755996704, 0.716249227523804, 1.22255623340607, 0.667224824428558, 0.368610262870789, 1.0352252721786501, 0.995538592338562, 1.74194359779358, 1.6263433694839498, 1.12209415435791, 0.6595166921615601, 1.5649795532226598, 1.4844149351120002, 0.925579309463501, 1.2175596952438401, 1.36135590076447, 1.23374843597412, 1.12843120098114, 1.28177809715271, 1.29121541976929, 1.07498753070831, 0.85769921541214, 1.480633020401, 1.1536017656326298, 0.6484572887420649, 0.89465194940567, 0.872001945972443, 0.950612664222717, 1.43092346191406, 1.1614590883255, 1.10970628261566, 1.21768391132355, 0.8089642524719242, 0.9910123944282528, 
0.305808693170547, 1.3412059545516999, 0.37584653496742204, 0.339233845472336, 1.40570604801178, 1.0272358655929599, 0.900596737861633, 1.1982743740081798, 1.40167844295502, 1.3559380769729599, 0.560479462146759, 1.07937383651733, 1.48792338371277, 1.3145823478698702, 0.996192753314972, 1.3950666189193701, 0.964434325695038, 1.25278460979462, 1.55167484283447, 1.50394463539124, 1.44163393974304, 1.18529546260834, 1.54625928401947, 1.5357066392898602, 1.1307767629623402, 0.6017650961875921, 0.524713635444641, 0.7885475754737851, 0.0226431842893362, 1.44357192516327, 1.53062355518341, 0.730573117733002, 1.1982102394104, 0.737299203872681, 1.08116579055786, 1.2845562696456898, 0.233442038297653 ], "y": [ 0.288555532693863, 0.479131430387497, 0.469987004995346, 0.147062435746193, 0.432452529668808, 0.491247326135635, 0.106179520487785, 0.325905978679657, 0.44030594825744607, 0.5818438529968261, 0.561213254928589, 0.443185955286026, 0.47748741507530196, 0.580200731754303, 0.461603492498398, 0.430388748645782, 0.473507791757584, 0.0957312509417534, 0.520342111587524, 0.289680689573288, 0, 0.38042613863945, 0.465733230113983, 0.49633759260177596, 0.24072904884815197, 0.612924098968506, 0.604130983352661, 0.47120362520217896, 0.270842045545578, 0.479246735572815, 0.260287940502167, 0.24946372210979503, 0.204403176903725, 0.24932260811328896, 0.5397707223892211, 0.626006722450256, 0.549840569496155, 0.42596277594566295, 0.40598860383033797, 0.235961347818375, 0.194989055395126, 0.6582486629486078, 0.4551981985569, 0.348587512969971, 0.514492034912109, 0.390017777681351, 0.505625545978546, 0.567766189575195, 0.412730008363724, 0.23521526157856, 0.437453746795654, 0.33288118243217496, 0.17586351931095098, 0.295400261878967, 0.25471106171608, 0.255772292613983, 0.423026293516159, 0.0303698573261499, 0.45000287890434293, 0.443323463201523, 0.59662789106369, 0.60834527015686, 0.505196332931519, 0.0149958552792668, 0.620070576667786, 0.601607382297516, 0.474307239055634, 
0.5793922543525699, 0.518630743026733, 0.550026834011078, 0.15399712324142498, 0.373783111572266, 0.40226498246192893, 0.28851598501205394, 0.585214674472809, 0.6271626353263849, 0.39815583825111406, 0.0960980430245399, 0.12297477573156401, 0.5313106179237371, 0.309410035610199, 0.470222115516663, 0.28923171758651695, 0.580131649971008, 0.457003742456436, 0.4350258708000179, 0.418421149253845, 0.18919676542282102, 0.572575807571411, 0.336384207010269, 0.408842742443085, 0.6140621304512021, 0.39414396882057207, 0.198303267359734, 0.300740599632263, 0.257921665906906, 0.35511153936386103, 0.45276376605033897, 0.55258983373642, 0.562511384487152, 0.234231784939766, 0.381498634815216, 0.256450712680817, 0.520303547382355, 0.376895278692245, 0.490968644618988, 0.5853844881057739, 0.508190035820007, 0.49451920390129106, 0.505740523338318, 0.5731103420257571, 0.41827192902565, 0.633375823497772, 0.47156670689582797, 0.571055591106415, 0.602126955986023, 0.6179508566856379, 0.449750572443008, 0.348079860210419, 0.31232857704162603, 0.447551846504211, 0.47278770804405196, 0.43745428323745705, 0.466914653778076 ], "z": [ 4.49700021743774, 5.61100006103516, 4.31500005722046, 3.5910000801086404, 6.002999782562259, 5.62900018692017, 3.7939999103546103, 5.175000190734861, 4.961999893188481, 3.47099995613098, 4.44000005722046, 4.08099985122681, 6.35699987411499, 6.4239997863769505, 4.513999938964839, 4.69500017166138, 5.493000030517581, 5.2270002365112305, 5.97300004959106, 4.80499982833862, 3.79500007629395, 3.4949998855590803, 5.525000095367429, 6.10500001907349, 5.837999820709231, 7.28399991989136, 6.375, 5.80999994277954, 2.69300007820129, 4.8289999961853, 5.27899980545044, 3.59299993515015, 5.18200016021729, 4.69199991226196, 6.89099979400635, 7.52199983596802, 6.57200002670288, 3.65700006484985, 7.212999820709231, 4.2800002098083505, 5.23699998855591, 5.97100019454956, 6.00799989700317, 3.50699996948242, 4.54500007629395, 3.34899997711182, 5.92000007629395, 
7.00600004196167, 6.57800006866455, 5.26900005340576, 6.6350002288818395, 3.5329999923706095, 5.32399988174438, 5.5689997673034695, 4.77500009536743, 5.2930002212524405, 4.11999988555908, 3.6029999256133998, 5.71500015258789, 5.26200008392334, 6.86299991607666, 6.6479997634887695, 3.7660000324249303, 4.138999938964839, 7.49399995803833, 7.28399991989136, 5.31099987030029, 6.4539999961853, 6.1680002212524405, 6.4520001411438, 5.25, 5.962999820709231, 6.08400011062622, 5.22499990463257, 5.42999982833862, 7.50400018692017, 5.234000205993651, 4.29199981689453, 4.09600019454956, 6.4539999961853, 4.28599977493286, 6.44199991226196, 4.71400022506714, 7.0789999961853, 5.824999809265139, 4.29099988937378, 5.33599996566772, 3.64400005340576, 5.75799989700317, 3.875, 4.46000003814697, 7.31400012969971, 4.95499992370605, 4.37599992752075, 5.5, 5.837999820709231, 5.62099981307983, 4.55299997329712, 5.230000019073491, 6.9510002136230495, 5.90199995040894, 4.64400005340576, 5.96400022506714, 4.57399988174438, 6.65199995040894, 5.47200012207031, 7.3769998550415, 6.71400022506714, 6.59899997711182, 6.993000030517581, 6.9770002365112305, 5.82200002670288, 4.1680002212524405, 5.04099988937378, 5.07399988174438, 5.151000022888179, 7.468999862670901, 6.343999862670901, 5.1810002326965305, 4.46500015258789, 6.07100009918213, 5.2729997634887695, 5.8189997673034695, 3.97000002861023 ] }, { "marker": { "line": { "color": "rgb(255, 255, 255)", "width": 1 }, "opacity": 1, "size": 10 }, "mode": "markers", "name": "Test Set", "type": "scatter3d", "uid": "297d492d-6698-40a5-843b-cf6b403a4442", "x": [ 1.6164631843566901, 1.47920441627502, 1.35268235206604, 1.34327983856201, 1.4336265325546298, 1.38439786434174, 1.32539355754852, 1.4884122610092199, 0.907975316047668, 1.09186446666718, 1.26074862480164, 0.833756566047668, 1.06931757926941, 0.8781145811080929, 1.3151752948761002, 0.783756256103516, 0.885416388511658, 0.5962200760841371, 0.989701807498932, 0.36842092871666, 0.586682975292206, 
0.23430564999580397, 0.479309022426605, 0.476180493831635, 0.6030489206314089, 0.35022771358490007, 0.16192533075809498, 0.43801298737525896, 0.521021246910095, 0.777153134346008, 0.09162256866693501 ], "y": [ 0.635422587394714, 0.611100912094116, 0.49094617366790794, 0.588767051696777, 0.361466586589813, 0.408781230449677, 0.295817464590073, 0.536746919155121, 0.5475093722343439, 0.233335807919502, 0.32570791244506797, 0.5587329268455511, 0.20871552824974102, 0.408158332109451, 0.498465299606323, 0.394952565431595, 0.5015376806259161, 0.454943388700485, 0.28211015462875394, 0.318697690963745, 0.478356659412384, 0.48079109191894503, 0.37792226672172496, 0.306613743305206, 0.4477061927318571, 0.32436785101890603, 0.36365869641304, 0.16234202682972002, 0.390661299228668, 0.0815394446253777, 0.0599007532000542 ], "z": [ 7.537000179290769, 7.31599998474121, 6.60900020599365, 6.52699995040894, 6.42199993133545, 6.40299987792969, 6.09800004959106, 6.08699989318848, 5.955999851226809, 5.872000217437741, 5.849999904632571, 5.82299995422363, 5.39499998092651, 5.2350001335144, 5.19500017166138, 5.07399988174438, 5.01100015640259, 5.0040001869201705, 4.7350001335144, 4.70900011062622, 4.60799980163574, 4.550000190734861, 4.53499984741211, 4.19000005722046, 4.17999982833862, 4.03200006484985, 4.02799987792969, 3.9360001087188703, 3.80800008773804, 3.46199989318848, 2.90499997138977 ] } ], "layout": { "margin": { "b": 0, "l": 0, "r": 0, "t": 0 }, "scene": { "xaxis": { "title": "Economy..GDP.per.Capita." }, "yaxis": { "title": "Freedom" }, "zaxis": { "title": "Happiness.Score" } }, "title": "Date Sets" } }, "text/html": [ "
" ], "text/vnd.plotly.v1+html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Configure the plot with the training dataset.\n", "plot_training_trace = go.Scatter3d(\n", " x=x_train[:, 0].flatten(),\n", " y=x_train[:, 1].flatten(),\n", " z=y_train.flatten(),\n", " name='Training Set',\n", " mode='markers',\n", " marker={\n", " 'size': 10,\n", " 'opacity': 1,\n", " 'line': {\n", " 'color': 'rgb(255, 255, 255)',\n", " 'width': 1\n", " },\n", " }\n", ")\n", "\n", "# Configure the plot with the test dataset.\n", "plot_test_trace = go.Scatter3d(\n", " x=x_test[:, 0].flatten(),\n", " y=x_test[:, 1].flatten(),\n", " z=y_test.flatten(),\n", " name='Test Set',\n", " mode='markers',\n", " marker={\n", " 'size': 10,\n", " 'opacity': 1,\n", " 'line': {\n", " 'color': 'rgb(255, 255, 255)',\n", " 'width': 1\n", " },\n", " }\n", ")\n", "\n", "# Configure the layout.\n", "plot_layout = go.Layout(\n", " title='Data Sets',\n", " scene={\n", " 'xaxis': {'title': input_param_name_1},\n", " 'yaxis': {'title': input_param_name_2},\n", " 'zaxis': {'title': output_param_name}\n", " },\n", " margin={'l': 0, 'r': 0, 'b': 0, 't': 0}\n", ")\n", "\n", "plot_data = [plot_training_trace, plot_test_trace]\n", "\n", "plot_figure = go.Figure(data=plot_data, layout=plot_layout)\n", "\n", "# Render the 3D scatter plot.\n", "plotly.offline.iplot(plot_figure)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Init and Train Linear Regression Model\n", "\n", "> ☝🏻This is the place where you might want to play with the model configuration.\n", "\n", "- `num_iterations` - the number of iterations that the gradient descent algorithm will run to find the minimum of the cost function. Too few iterations may stop gradient descent before it reaches the minimum; too many will make the algorithm run longer without improving its accuracy.\n", "- `learning_rate` - the size of the gradient descent step. A small learning rate will make the algorithm run longer and will probably require more iterations to reach the minimum of the cost function. A large learning rate may overshoot the minimum, so that the cost function value grows with each new iteration.\n", "- `regularization_param` - a parameter that fights overfitting. The higher the parameter, the simpler the model will be.\n", "- `polynomial_degree` - the degree of additional polynomial features (`x1^2 * x2, x1^2 * x2^2, ...`). Higher degrees add more features and allow you to curve the predictions.\n", "- `sinusoid_degree` - the degree of sinusoid parameter multipliers of additional features (`sin(x), sin(2*x), ...`). This will allow you to curve the predictions by adding a sinusoidal component to the prediction curve." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial cost: 224369.59\n", "Optimized cost: 2761.58\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Model Parameters\n", "0 5.393174\n", "1 0.765539\n", "2 0.357765" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set up linear regression parameters.\n", "num_iterations = 500 # Number of gradient descent iterations.\n", "regularization_param = 0 # Helps to fight model overfitting.\n", "learning_rate = 0.01 # The size of the gradient descent step.\n", "polynomial_degree = 0 # The degree of additional polynomial features.\n", "sinusoid_degree = 0 # The degree of sinusoid parameter multipliers of additional features.\n", "\n", "# Initialize the linear regression instance.\n", "linear_regression = LinearRegression(x_train, y_train, polynomial_degree, sinusoid_degree)\n", "\n", "# Train the linear regression model.\n", "(theta, cost_history) = linear_regression.train(\n", " learning_rate,\n", " regularization_param,\n", " num_iterations\n", ")\n", "\n", "# Print training results.\n", "print('Initial cost: {:.2f}'.format(cost_history[0]))\n", "print('Optimized cost: {:.2f}'.format(cost_history[-1]))\n", "\n", "# Print the model parameters.\n", "theta_table = pd.DataFrame({'Model Parameters': theta.flatten()})\n", "theta_table.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze Gradient Descent Progress\n", "\n", "The plot below illustrates how the cost function value changes with each iteration. You should see it decreasing.\n", "\n", "If the cost function value increases instead, gradient descent has probably overshot the minimum of the cost function and moves further away from it with each step. In this case you might want to reduce the learning rate parameter (the size of the gradient step).\n", "\n", "This plot also helps you estimate how many iterations are needed to get close to the optimal value of the cost function. In the current example there is little point in increasing the number of gradient descent iterations beyond 500, since doing so will not reduce the cost function significantly.
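\n", "\n", "To see both behaviors described above on a small scale, the sketch below runs plain batch gradient descent on synthetic data and records the cost after every step, much like `cost_history` in the training cell. It is a hypothetical stand-in, not the repository's `LinearRegression.train`: the `gradient_descent` helper, the synthetic data and the two learning rates are assumptions, and it assumes the common mean squared error cost `J(theta) = 1/(2m) * sum((X @ theta - y)^2)`, which may be scaled differently from the cost the repository reports.
```python
import numpy as np


def gradient_descent(x, y, learning_rate, num_iterations):
    # Hypothetical helper (not the repository's LinearRegression.train).
    # Batch gradient descent on the mean squared error cost
    # J(theta) = 1 / (2m) * sum((X @ theta - y) ** 2), recording J after every step.
    m = x.shape[0]
    x = np.hstack((np.ones((m, 1)), x))  # bias column
    theta = np.zeros((x.shape[1], 1))
    cost_history = []
    for _ in range(num_iterations):
        gradient = x.T @ (x @ theta - y) / m
        theta = theta - learning_rate * gradient
        error = x @ theta - y
        cost_history.append((error.T @ error).item() / (2 * m))
    return theta, cost_history


# Synthetic, noise-free data: y = 1 + 2 * x1 + 3 * x2.
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=(50, 2))
y = 1 + x @ np.array([[2.0], [3.0]])

# A small step: the cost should shrink steadily.
_, small_step_costs = gradient_descent(x, y, learning_rate=0.1, num_iterations=500)
# A step that is too large for this data: the cost grows with each iteration.
_, large_step_costs = gradient_descent(x, y, learning_rate=3.0, num_iterations=50)

print(small_step_costs[0], small_step_costs[-1])
print(large_step_costs[0], large_step_costs[-1])
```
With the small learning rate the recorded cost never increases; with the large one each step overshoots the minimum, so plotting `large_step_costs` would show a growing curve, and the usual fix is to lower `learning_rate`.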
" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot gradient descent progress.\n", "plt.plot(range(num_iterations), cost_history)\n", "plt.xlabel('Iterations')\n", "plt.ylabel('Cost')\n", "plt.title('Gradient Descent Progress')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the Model Predictions\n", "\n", "Since our model is trained now we may plot its predictions over the training and test datasets to see how well it fits the data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "data": [ { "marker": { "line": { "color": "rgb(255, 255, 255)", "width": 1 }, "opacity": 1, "size": 10 }, "mode": "markers", "name": "Training Set", "type": "scatter3d", "uid": "5121ac59-e6e4-4c30-8d0c-c3482182891a", "x": [ 1.1027104854583702, 1.3208793401718102, 0.7922212481498722, 0.39724862575531, 0.9097844958305359, 1.1893955469131499, 0.401477217674255, 1.06457793712616, 0.479820191860199, 0.368745893239975, 1.00985014438629, 0.381430715322495, 1.07062232494354, 1.12786877155304, 0.6364067792892459, 0.564305365085602, 0.932537317276001, 1.28948748111725, 1.29178786277771, 1.00726580619812, 0.8584281802177429, 0.30544471740722695, 1.1018030643463101, 1.63295245170593, 0.7288706302642819, 1.49438726902008, 1.8707656860351598, 1.3469113111496, 0, 1.05469870567322, 0.951484382152557, 0.5916834473609921, 0.982409417629242, 1.15687310695648, 1.4637807607650801, 1.48238301277161, 1.69227766990662, 0.43108540773391707, 1.3753824234008798, 0.0921023488044739, 1.12112903594971, 0.786441087722778, 1.00082039833069, 0.24454993009567302, 0.36711055040359497, 0.511135876178741, 1.41691517829895, 1.4870972633361799, 1.1531838178634601, 0.7268835306167599, 1.1073532104492199, 0.11904179304838199, 1.2860119342803997, 1.15655755996704, 0.716249227523804, 1.22255623340607, 0.667224824428558, 0.368610262870789, 1.0352252721786501, 
0.995538592338562, 1.74194359779358, 1.6263433694839498, 1.12209415435791, 0.6595166921615601, 1.5649795532226598, 1.4844149351120002, 0.925579309463501, 1.2175596952438401, 1.36135590076447, 1.23374843597412, 1.12843120098114, 1.28177809715271, 1.29121541976929, 1.07498753070831, 0.85769921541214, 1.480633020401, 1.1536017656326298, 0.6484572887420649, 0.89465194940567, 0.872001945972443, 0.950612664222717, 1.43092346191406, 1.1614590883255, 1.10970628261566, 1.21768391132355, 0.8089642524719242, 0.9910123944282528, 0.305808693170547, 1.3412059545516999, 0.37584653496742204, 0.339233845472336, 1.40570604801178, 1.0272358655929599, 0.900596737861633, 1.1982743740081798, 1.40167844295502, 1.3559380769729599, 0.560479462146759, 1.07937383651733, 1.48792338371277, 1.3145823478698702, 0.996192753314972, 1.3950666189193701, 0.964434325695038, 1.25278460979462, 1.55167484283447, 1.50394463539124, 1.44163393974304, 1.18529546260834, 1.54625928401947, 1.5357066392898602, 1.1307767629623402, 0.6017650961875921, 0.524713635444641, 0.7885475754737851, 0.0226431842893362, 1.44357192516327, 1.53062355518341, 0.730573117733002, 1.1982102394104, 0.737299203872681, 1.08116579055786, 1.2845562696456898, 0.233442038297653 ], "y": [ 0.288555532693863, 0.479131430387497, 0.469987004995346, 0.147062435746193, 0.432452529668808, 0.491247326135635, 0.106179520487785, 0.325905978679657, 0.44030594825744607, 0.5818438529968261, 0.561213254928589, 0.443185955286026, 0.47748741507530196, 0.580200731754303, 0.461603492498398, 0.430388748645782, 0.473507791757584, 0.0957312509417534, 0.520342111587524, 0.289680689573288, 0, 0.38042613863945, 0.465733230113983, 0.49633759260177596, 0.24072904884815197, 0.612924098968506, 0.604130983352661, 0.47120362520217896, 0.270842045545578, 0.479246735572815, 0.260287940502167, 0.24946372210979503, 0.204403176903725, 0.24932260811328896, 0.5397707223892211, 0.626006722450256, 0.549840569496155, 0.42596277594566295, 0.40598860383033797, 0.235961347818375, 
0.194989055395126, 0.6582486629486078, 0.4551981985569, 0.348587512969971, 0.514492034912109, 0.390017777681351, 0.505625545978546, 0.567766189575195, 0.412730008363724, 0.23521526157856, 0.437453746795654, 0.33288118243217496, 0.17586351931095098, 0.295400261878967, 0.25471106171608, 0.255772292613983, 0.423026293516159, 0.0303698573261499, 0.45000287890434293, 0.443323463201523, 0.59662789106369, 0.60834527015686, 0.505196332931519, 0.0149958552792668, 0.620070576667786, 0.601607382297516, 0.474307239055634, 0.5793922543525699, 0.518630743026733, 0.550026834011078, 0.15399712324142498, 0.373783111572266, 0.40226498246192893, 0.28851598501205394, 0.585214674472809, 0.6271626353263849, 0.39815583825111406, 0.0960980430245399, 0.12297477573156401, 0.5313106179237371, 0.309410035610199, 0.470222115516663, 0.28923171758651695, 0.580131649971008, 0.457003742456436, 0.4350258708000179, 0.418421149253845, 0.18919676542282102, 0.572575807571411, 0.336384207010269, 0.408842742443085, 0.6140621304512021, 0.39414396882057207, 0.198303267359734, 0.300740599632263, 0.257921665906906, 0.35511153936386103, 0.45276376605033897, 0.55258983373642, 0.562511384487152, 0.234231784939766, 0.381498634815216, 0.256450712680817, 0.520303547382355, 0.376895278692245, 0.490968644618988, 0.5853844881057739, 0.508190035820007, 0.49451920390129106, 0.505740523338318, 0.5731103420257571, 0.41827192902565, 0.633375823497772, 0.47156670689582797, 0.571055591106415, 0.602126955986023, 0.6179508566856379, 0.449750572443008, 0.348079860210419, 0.31232857704162603, 0.447551846504211, 0.47278770804405196, 0.43745428323745705, 0.466914653778076 ], "z": [ 4.49700021743774, 5.61100006103516, 4.31500005722046, 3.5910000801086404, 6.002999782562259, 5.62900018692017, 3.7939999103546103, 5.175000190734861, 4.961999893188481, 3.47099995613098, 4.44000005722046, 4.08099985122681, 6.35699987411499, 6.4239997863769505, 4.513999938964839, 4.69500017166138, 5.493000030517581, 5.2270002365112305, 5.97300004959106, 
4.80499982833862, 3.79500007629395, 3.4949998855590803, 5.525000095367429, 6.10500001907349, 5.837999820709231, 7.28399991989136, 6.375, 5.80999994277954, 2.69300007820129, 4.8289999961853, 5.27899980545044, 3.59299993515015, 5.18200016021729, 4.69199991226196, 6.89099979400635, 7.52199983596802, 6.57200002670288, 3.65700006484985, 7.212999820709231, 4.2800002098083505, 5.23699998855591, 5.97100019454956, 6.00799989700317, 3.50699996948242, 4.54500007629395, 3.34899997711182, 5.92000007629395, 7.00600004196167, 6.57800006866455, 5.26900005340576, 6.6350002288818395, 3.5329999923706095, 5.32399988174438, 5.5689997673034695, 4.77500009536743, 5.2930002212524405, 4.11999988555908, 3.6029999256133998, 5.71500015258789, 5.26200008392334, 6.86299991607666, 6.6479997634887695, 3.7660000324249303, 4.138999938964839, 7.49399995803833, 7.28399991989136, 5.31099987030029, 6.4539999961853, 6.1680002212524405, 6.4520001411438, 5.25, 5.962999820709231, 6.08400011062622, 5.22499990463257, 5.42999982833862, 7.50400018692017, 5.234000205993651, 4.29199981689453, 4.09600019454956, 6.4539999961853, 4.28599977493286, 6.44199991226196, 4.71400022506714, 7.0789999961853, 5.824999809265139, 4.29099988937378, 5.33599996566772, 3.64400005340576, 5.75799989700317, 3.875, 4.46000003814697, 7.31400012969971, 4.95499992370605, 4.37599992752075, 5.5, 5.837999820709231, 5.62099981307983, 4.55299997329712, 5.230000019073491, 6.9510002136230495, 5.90199995040894, 4.64400005340576, 5.96400022506714, 4.57399988174438, 6.65199995040894, 5.47200012207031, 7.3769998550415, 6.71400022506714, 6.59899997711182, 6.993000030517581, 6.9770002365112305, 5.82200002670288, 4.1680002212524405, 5.04099988937378, 5.07399988174438, 5.151000022888179, 7.468999862670901, 6.343999862670901, 5.1810002326965305, 4.46500015258789, 6.07100009918213, 5.2729997634887695, 5.8189997673034695, 3.97000002861023 ] }, { "marker": { "line": { "color": "rgb(255, 255, 255)", "width": 1 }, "opacity": 1, "size": 10 }, "mode": 
"markers", "name": "Test Set", "type": "scatter3d", "uid": "973829dd-d770-41f6-b456-046962af6151", "x": [ 1.6164631843566901, 1.47920441627502, 1.35268235206604, 1.34327983856201, 1.4336265325546298, 1.38439786434174, 1.32539355754852, 1.4884122610092199, 0.907975316047668, 1.09186446666718, 1.26074862480164, 0.833756566047668, 1.06931757926941, 0.8781145811080929, 1.3151752948761002, 0.783756256103516, 0.885416388511658, 0.5962200760841371, 0.989701807498932, 0.36842092871666, 0.586682975292206, 0.23430564999580397, 0.479309022426605, 0.476180493831635, 0.6030489206314089, 0.35022771358490007, 0.16192533075809498, 0.43801298737525896, 0.521021246910095, 0.777153134346008, 0.09162256866693501 ], "y": [ 0.635422587394714, 0.611100912094116, 0.49094617366790794, 0.588767051696777, 0.361466586589813, 0.408781230449677, 0.295817464590073, 0.536746919155121, 0.5475093722343439, 0.233335807919502, 0.32570791244506797, 0.5587329268455511, 0.20871552824974102, 0.408158332109451, 0.498465299606323, 0.394952565431595, 0.5015376806259161, 0.454943388700485, 0.28211015462875394, 0.318697690963745, 0.478356659412384, 0.48079109191894503, 0.37792226672172496, 0.306613743305206, 0.4477061927318571, 0.32436785101890603, 0.36365869641304, 0.16234202682972002, 0.390661299228668, 0.0815394446253777, 0.0599007532000542 ], "z": [ 7.537000179290769, 7.31599998474121, 6.60900020599365, 6.52699995040894, 6.42199993133545, 6.40299987792969, 6.09800004959106, 6.08699989318848, 5.955999851226809, 5.872000217437741, 5.849999904632571, 5.82299995422363, 5.39499998092651, 5.2350001335144, 5.19500017166138, 5.07399988174438, 5.01100015640259, 5.0040001869201705, 4.7350001335144, 4.70900011062622, 4.60799980163574, 4.550000190734861, 4.53499984741211, 4.19000005722046, 4.17999982833862, 4.03200006484985, 4.02799987792969, 3.9360001087188703, 3.80800008773804, 3.46199989318848, 2.90499997138977 ] }, { "marker": { "size": 1 }, "mode": "markers", "name": "Prediction Plane", "opacity": 0.8, 
"surfaceaxis": 2, "type": "scatter3d", "uid": "b9286cc7-12bf-4571-b516-f8bda6c3c119", "x": [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.20786285400390664, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.4157257080078133, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.6235885620117199, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 0.8314514160156266, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.0393142700195332, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.2471771240234397, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.4550399780273464, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.662902832031253, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598, 1.8707656860351598 ], "y": [ 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 
0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078, 0, 0.07313874032762309, 0.14627748065524618, 0.21941622098286928, 0.29255496131049236, 0.36569370163811543, 0.43883244196573856, 0.5119711822933616, 0.5851099226209847, 0.6582486629486078 ], "z": [ 3.6332939796459245, 3.757851868566506, 3.8824097574870877, 4.00696764640767, 4.131525535328251, 4.256083424248833, 4.380641313169415, 4.5051992020899965, 
4.629757091010578, 4.75431497993116, 3.8998205174683562, 4.024378406388938, 4.14893629530952, 4.273494184230102, 4.398052073150683, 4.5226099620712645, 4.647167850991846, 4.771725739912428, 4.89628362883301, 5.020841517753592, 4.166347055290788, 4.29090494421137, 4.415462833131952, 4.5400207220525335, 4.664578610973114, 4.789136499893696, 4.913694388814278, 5.03825227773486, 5.162810166655442, 5.287368055576024, 4.432873593113219, 4.557431482033801, 4.6819893709543825, 4.806547259874964, 4.931105148795545, 5.055663037716127, 5.180220926636709, 5.304778815557291, 5.429336704477873, 5.5538945933984545, 4.699400130935651, 4.823958019856232, 4.948515908776814, 5.073073797697396, 5.197631686617977, 5.322189575538559, 5.446747464459141, 5.5713053533797225, 5.695863242300304, 5.820421131220886, 4.965926668758082, 5.090484557678664, 5.215042446599246, 5.339600335519828, 5.464158224440409, 5.588716113360991, 5.7132740022815725, 5.837831891202154, 5.962389780122736, 6.086947669043318, 5.232453206580514, 5.357011095501096, 5.481568984421678, 5.60612687334226, 5.7306847622628405, 5.855242651183422, 5.979800540104004, 6.104358429024586, 6.228916317945168, 6.35347420686575, 5.498979744402945, 5.623537633323527, 5.748095522244109, 5.87265341116469, 5.997211300085271, 6.121769189005853, 6.246327077926435, 6.370884966847017, 6.495442855767599, 6.620000744688181, 5.765506282225377, 5.8900641711459585, 6.01462206006654, 6.139179948987122, 6.263737837907703, 6.388295726828285, 6.512853615748867, 6.637411504669449, 6.7619693935900305, 6.886527282510612, 6.032032820047808, 6.15659070896839, 6.281148597888972, 6.405706486809554, 6.530264375730135, 6.654822264650717, 6.7793801535712985, 6.90393804249188, 7.028495931412462, 7.153053820333044 ] } ], "layout": { "margin": { "b": 0, "l": 0, "r": 0, "t": 0 }, "scene": { "xaxis": { "title": "Economy..GDP.per.Capita." }, "yaxis": { "title": "Freedom" }, "zaxis": { "title": "Happiness.Score" } }, "title": "Date Sets" } }, "text/html": [ "
" ], "text/vnd.plotly.v1+html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate combinations of X and Y values to build a prediction plane.\n", "predictions_num = 10\n", "\n", "# Find the min and max values along the X and Y axes.\n", "x_min = x_train[:, 0].min()\n", "x_max = x_train[:, 0].max()\n", "\n", "y_min = x_train[:, 1].min()\n", "y_max = x_train[:, 1].max()\n", "\n", "# Generate a predefined number of evenly spaced values for each axis between the corresponding min and max values.\n", "x_axis = np.linspace(x_min, x_max, predictions_num)\n", "y_axis = np.linspace(y_min, y_max, predictions_num)\n", "\n", "# Create empty column vectors for the X and Y coordinates of the prediction points.\n", "# They will hold the Cartesian product of all possible X and Y values.\n", "x_predictions = np.zeros((predictions_num * predictions_num, 1))\n", "y_predictions = np.zeros((predictions_num * predictions_num, 1))\n", "\n", "# Fill in the Cartesian product of all X and Y values.\n", "x_y_index = 0\n", "for x_value in x_axis:\n", "    for y_value in y_axis:\n", "        x_predictions[x_y_index] = x_value\n", "        y_predictions[x_y_index] = y_value\n", "        x_y_index += 1\n", "\n", "# Predict the Z value for every (X, Y) pair.\n", "z_predictions = linear_regression.predict(np.hstack((x_predictions, y_predictions)))\n", "\n", "# Plot the training and test data together with the prediction plane.\n", "\n", "# Configure the prediction plane trace.\n", "plot_predictions_trace = go.Scatter3d(\n", "    x=x_predictions.flatten(),\n", "    y=y_predictions.flatten(),\n", "    z=z_predictions.flatten(),\n", "    name='Prediction Plane',\n", "    mode='markers',\n", "    marker={\n", "        'size': 1,\n", "    },\n", "    opacity=0.8,\n", "    surfaceaxis=2,\n", ")\n", "\n", "plot_data = [plot_training_trace, plot_test_trace, plot_predictions_trace]\n", "plot_figure = go.Figure(data=plot_data, layout=plot_layout)\n", "plotly.offline.iplot(plot_figure)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calculate the value of the cost function for the training and test datasets. The lower this value, the better." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train cost: 2761.58\n", "Test cost: 105.85\n" ] } ], "source": [ "train_cost = linear_regression.get_cost(x_train, y_train, regularization_param)\n", "test_cost = linear_regression.get_cost(x_test, y_test, regularization_param)\n", "\n", "print('Train cost: {:.2f}'.format(train_cost))\n", "print('Test cost: {:.2f}'.format(test_cost))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now render a table of the predictions that our trained model makes for unseen data (the test dataset). The predicted happiness scores should be fairly close to the known happiness scores from the test dataset." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>Economy GDP per Capita</th><th>Freedom</th><th>Test Happiness Score</th><th>Predicted Happiness Score</th><th>Prediction Diff</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>1.616463</td><td>0.635423</td><td>7.537</td><td>7.302078</td><td>0.234922</td></tr>\n",
"    <tr><th>1</th><td>1.479204</td><td>0.611101</td><td>7.316</td><td>7.000683</td><td>0.315317</td></tr>\n",
"    <tr><th>2</th><td>1.352682</td><td>0.490946</td><td>6.609</td><td>6.479139</td><td>0.129861</td></tr>\n",
"    <tr><th>3</th><td>1.343280</td><td>0.588767</td><td>6.527</td><td>6.706585</td><td>-0.179585</td></tr>\n",
"    <tr><th>4</th><td>1.433627</td><td>0.361467</td><td>6.422</td><td>6.298226</td><td>0.123774</td></tr>\n",
"    <tr><th>5</th><td>1.384398</td><td>0.408781</td><td>6.403</td><td>6.329871</td><td>0.073129</td></tr>\n",
"    <tr><th>6</th><td>1.325394</td><td>0.295817</td><td>6.098</td><td>5.944681</td><td>0.153319</td></tr>\n",
"    <tr><th>7</th><td>1.488412</td><td>0.536747</td><td>6.087</td><td>6.831415</td><td>-0.744415</td></tr>\n",
"    <tr><th>8</th><td>0.907975</td><td>0.547509</td><td>5.956</td><td>5.840201</td><td>0.115799</td></tr>\n",
"    <tr><th>9</th><td>1.091864</td><td>0.233336</td><td>5.872</td><td>5.379273</td><td>0.492727</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Economy GDP per Capita Freedom Test Happiness Score \\\n", "0 1.616463 0.635423 7.537 \n", "1 1.479204 0.611101 7.316 \n", "2 1.352682 0.490946 6.609 \n", "3 1.343280 0.588767 6.527 \n", "4 1.433627 0.361467 6.422 \n", "5 1.384398 0.408781 6.403 \n", "6 1.325394 0.295817 6.098 \n", "7 1.488412 0.536747 6.087 \n", "8 0.907975 0.547509 5.956 \n", "9 1.091864 0.233336 5.872 \n", "\n", " Predicted Happiness Score Prediction Diff \n", "0 7.302078 0.234922 \n", "1 7.000683 0.315317 \n", "2 6.479139 0.129861 \n", "3 6.706585 -0.179585 \n", "4 6.298226 0.123774 \n", "5 6.329871 0.073129 \n", "6 5.944681 0.153319 \n", "7 6.831415 -0.744415 \n", "8 5.840201 0.115799 \n", "9 5.379273 0.492727 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_predictions = linear_regression.predict(x_test)\n", "\n", "test_predictions_table = pd.DataFrame({\n", " 'Economy GDP per Capita': x_test[:, 0].flatten(),\n", " 'Freedom': x_test[:, 1].flatten(),\n", " 'Test Happiness Score': y_test.flatten(),\n", " 'Predicted Happiness Score': test_predictions.flatten(),\n", " 'Prediction Diff': (y_test - test_predictions).flatten()\n", "})\n", "\n", "test_predictions_table.head(10)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }