{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression with daru and statsample-glm\n", "\n", "In this notebook we'll work through some examples of predicting the probability of a binary outcome with logistic regression, using daru and statsample-glm." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "if(window['d3'] === undefined ||\n", " window['Nyaplot'] === undefined){\n", " var path = {\"d3\":\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\",\"downloadable\":\"https://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\"};\n", "\n", "\n", "\n", " var shim = {\"d3\":{\"exports\":\"d3\"},\"downloadable\":{\"exports\":\"downloadable\"}};\n", "\n", " require.config({paths: path, shim:shim});\n", "\n", "\n", "require(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\n", "\n", "\tvar script = d3.select(\"head\")\n", "\t .append(\"script\")\n", "\t .attr(\"src\", \"https://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\")\n", "\t .attr(\"async\", true);\n", "\n", "\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\n", "\n", "\n", "\t var event = document.createEvent(\"HTMLEvents\");\n", "\t event.initEvent(\"load_nyaplot\",false,false);\n", "\t window.dispatchEvent(event);\n", "\t console.log('Finished loading Nyaplotjs');\n", "\n", "\t};\n", "\n", "\n", "});});\n", "}\n" ], "text/plain": [ "\"if(window['d3'] === undefined ||\\n window['Nyaplot'] === undefined){\\n var path = {\\\"d3\\\":\\\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\\\",\\\"downloadable\\\":\\\"https://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\\\"};\\n\\n\\n\\n var shim = {\\\"d3\\\":{\\\"exports\\\":\\\"d3\\\"},\\\"downloadable\\\":{\\\"exports\\\":\\\"downloadable\\\"}};\\n\\n require.config({paths: path, shim:shim});\\n\\n\\nrequire(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\\n\\n\\tvar script = d3.select(\\\"head\\\")\\n\\t .append(\\\"script\\\")\\n\\t .attr(\\\"src\\\", \\\"https://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\\\")\\n\\t .attr(\\\"async\\\", true);\\n\\n\\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\\n\\n\\n\\t var event = document.createEvent(\\\"HTMLEvents\\\");\\n\\t event.initEvent(\\\"load_nyaplot\\\",false,false);\\n\\t window.dispatchEvent(event);\\n\\t console.log('Finished loading Nyaplotjs');\\n\\n\\t};\\n\\n\\n});});\\n}\\n\"" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "true" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'daru'\n", "require 'statsample-glm'\n", "require 'open-uri'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this notebook, we will use [this dataset](http://www.ats.ucla.edu/stat/data/binary.csv), which records whether students were admitted to a graduate degree program along with their GRE score, GPA and the rank (from 1 to 4, with 1 the highest) of the institution where they completed their undergraduate degree.\n", "\n", "Note that statsample-glm does not yet support categorical predictors, so rank will be treated as a continuous variable." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
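<table><tr><th colspan=5>Daru::DataFrame:27633020 rows: 400 cols: 4</th></tr>\n",
"<tr><th></th><th>admit</th><th>gpa</th><th>gre</th><th>rank</th></tr>\n",
"<tr><td>0</td><td>0</td><td>3.61</td><td>380</td><td>3</td></tr>\n",
"<tr><td>1</td><td>1</td><td>3.67</td><td>660</td><td>3</td></tr>\n",
"<tr><td>2</td><td>1</td><td>4</td><td>800</td><td>1</td></tr>\n",
"<tr><td>3</td><td>1</td><td>3.19</td><td>640</td><td>4</td></tr>\n",
"<tr><td>4</td><td>0</td><td>2.93</td><td>520</td><td>4</td></tr>\n",
"<tr><td>5</td><td>1</td><td>3</td><td>760</td><td>2</td></tr>\n",
"<tr><td>6</td><td>1</td><td>2.98</td><td>560</td><td>1</td></tr>\n",
"<tr><td>7</td><td>0</td><td>3.08</td><td>400</td><td>2</td></tr>\n",
"<tr><td>8</td><td>1</td><td>3.39</td><td>540</td><td>3</td></tr>\n",
"<tr><td>9</td><td>0</td><td>3.92</td><td>700</td><td>2</td></tr>\n",
"<tr><td>10</td><td>0</td><td>4</td><td>800</td><td>4</td></tr>\n",
"<tr><td>11</td><td>0</td><td>3.22</td><td>440</td><td>1</td></tr>\n",
"<tr><td>12</td><td>1</td><td>4</td><td>760</td><td>1</td></tr>\n",
"<tr><td>13</td><td>0</td><td>3.08</td><td>700</td><td>2</td></tr>\n",
"<tr><td>14</td><td>1</td><td>4</td><td>700</td><td>1</td></tr>\n",
"<tr><td>15</td><td>0</td><td>3.44</td><td>480</td><td>3</td></tr>\n",
"<tr><td>16</td><td>0</td><td>3.87</td><td>780</td><td>4</td></tr>\n",
"<tr><td>17</td><td>0</td><td>2.56</td><td>360</td><td>3</td></tr>\n",
"<tr><td>18</td><td>0</td><td>3.75</td><td>800</td><td>2</td></tr>\n",
"<tr><td>19</td><td>1</td><td>3.81</td><td>540</td><td>1</td></tr>\n",
"<tr><td>20</td><td>0</td><td>3.17</td><td>500</td><td>3</td></tr>\n",
"<tr><td>21</td><td>1</td><td>3.63</td><td>660</td><td>2</td></tr>\n",
"<tr><td>22</td><td>0</td><td>2.82</td><td>600</td><td>4</td></tr>\n",
"<tr><td>23</td><td>0</td><td>3.19</td><td>680</td><td>4</td></tr>\n",
"<tr><td>24</td><td>1</td><td>3.35</td><td>760</td><td>2</td></tr>\n",
"<tr><td>25</td><td>1</td><td>3.66</td><td>800</td><td>1</td></tr>\n",
"<tr><td>26</td><td>1</td><td>3.61</td><td>620</td><td>1</td></tr>\n",
"<tr><td>27</td><td>1</td><td>3.74</td><td>520</td><td>4</td></tr>\n",
"<tr><td>28</td><td>1</td><td>3.22</td><td>780</td><td>2</td></tr>\n",
"<tr><td>29</td><td>0</td><td>3.29</td><td>520</td><td>1</td></tr>\n",
"<tr><td>30</td><td>0</td><td>3.78</td><td>540</td><td>4</td></tr>\n",
"<tr><td>31</td><td>0</td><td>3.35</td><td>760</td><td>3</td></tr>\n",
"<tr><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><td>399</td><td>0</td><td>3.89</td><td>600</td><td>3</td></tr></table>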
" ], "text/plain": [ "\n", "#\n", " admit gpa gre rank \n", " 0 0 3.61 380 3 \n", " 1 1 3.67 660 3 \n", " 2 1 4 800 1 \n", " 3 1 3.19 640 4 \n", " 4 0 2.93 520 4 \n", " 5 1 3 760 2 \n", " 6 1 2.98 560 1 \n", " 7 0 3.08 400 2 \n", " 8 1 3.39 540 3 \n", " 9 0 3.92 700 2 \n", " 10 0 4 800 4 \n", " 11 0 3.22 440 1 \n", " 12 1 4 760 1 \n", " 13 0 3.08 700 2 \n", " 14 1 4 700 1 \n", " ... ... ... ... ... \n" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download the admissions data and save a local copy (open-uri lets Kernel#open read URLs)\n", "content = open('http://www.ats.ucla.edu/stat/data/binary.csv')\n", "File.write('binary.csv', content.read)\n", "\n", "df = Daru::DataFrame.from_csv \"binary.csv\"\n", "# Use Symbols as the vector (column) names\n", "df.vectors = Daru::Index.new([:admit, :gpa, :gre, :rank])\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use the `Statsample::GLM.compute` method for logistic regression analysis.\n", "\n", "The first argument to `compute` is the DataFrame, followed by the name of the vector to be used as the dependent variable, and then the type of model to fit: `:logistic`, `:probit`, `:poisson` or `:normal`. The `constant: 1` option adds an intercept term, reported under the key `:constant`.\n", "\n", "Called with `:hash`, the `coefficients` method returns the estimated coefficients of the GLM as a Hash keyed by predictor name." 
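, "\n",
"\n",
"For reference, with the `:logistic` method the fitted model links the probability $p$ of the outcome to the predictors $x_1, \\dots, x_k$ through the log odds (the logit link):\n",
"\n",
"$$\\log\\frac{p}{1-p} = \\beta_0 + \\beta_1 x_1 + \\cdots + \\beta_k x_k \\quad\\Longleftrightarrow\\quad p = \\frac{1}{1 + e^{-(\\beta_0 + \\beta_1 x_1 + \\cdots + \\beta_k x_k)}}$$\n",
"\n",
"Here $\\beta_0$ is the value stored under `:constant` in the coefficient Hash, and each $\\beta_j$ is the value stored under the corresponding predictor's name. The right-hand form is the formula that the later cells use to turn the coefficients into predicted probabilities."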
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{:gpa=>0.777013573719857, :gre=>0.0022939595044433273, :rank=>-0.5600313868499897, :constant=>-3.4495483976684773}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Regress :admit on the remaining vectors; constant: 1 adds an intercept\n", "glm = Statsample::GLM.compute df, :admit, :logistic, constant: 1\n", "c = glm.coefficients :hash" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The logistic regression coefficients give the change in the log odds of the outcome for a one-unit increase in the predictor variable, holding the other predictors constant.\n", "\n", "Therefore, to interpret each of the above coefficients:\n", "* For every one-unit increase in GRE score, the log odds of admission (versus non-admission) increase by **0.002**.\n", "* For a one-unit increase in GPA, the log odds of being admitted to graduate school increase by **0.777**.\n", "* For every one-step increase in the rank number of the undergraduate institution (i.e. a lower-ranked institution), the log odds of being admitted to graduate school decrease by **0.56**.\n", "\n", "Log odds are hard to interpret directly, so we'll exponentiate each coefficient. The result is an odds ratio: the factor by which the odds of admission are multiplied for a one-unit increase in that predictor." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
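<table><tr><th colspan=2>Daru::Vector:17552980 size: 4</th></tr>\n",
"<tr><th></th><th>nil</th></tr>\n",
"<tr><td>gpa</td><td>2.174967177712439</td></tr>\n",
"<tr><td>gre</td><td>1.0022965926425997</td></tr>\n",
"<tr><td>rank</td><td>0.571191135676971</td></tr>\n",
"<tr><td>constant</td><td>0.03175997601913591</td></tr></table>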
" ], "text/plain": [ "\n", "#\n", " nil\n", " gpa 2.174967177712439\n", " gre 1.0022965926425997\n", " rank 0.571191135676971\n", " constant 0.03175997601913591\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Daru::Vector.new(c).exp # Calling `#exp` on Daru::Vector exponentiates each element of the Vector." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, each one-point increase in GPA multiplies the odds of admission by about 2.17, while each step down in institution rank (a higher rank number) multiplies them by about 0.57.\n", "\n", "We can now compute the predicted probability of admission for each rank of undergraduate institution, holding GRE score and GPA fixed at their sample means.\n", "\n", "In the result below, the `rankp` vector holds the predicted probability of admission for each rank. A student from a top-ranked undergraduate institution (rank 1) has a predicted probability of about **0.49** of being admitted to graduate school. For rank 4 this falls to about **0.15**." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
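<table><tr><th colspan=5>Daru::DataFrame:16947240 rows: 4 cols: 4</th></tr>\n",
"<tr><th></th><th>gpa</th><th>gre</th><th>rank</th><th>rankp</th></tr>\n",
"<tr><td>1</td><td>3.3899000000000017</td><td>587.7</td><td>1</td><td>0.4931450619837156</td></tr>\n",
"<tr><td>3</td><td>3.3899000000000017</td><td>587.7</td><td>2</td><td>0.357219500353945</td></tr>\n",
"<tr><td>0</td><td>3.3899000000000017</td><td>587.7</td><td>3</td><td>0.240948896129993</td></tr>\n",
"<tr><td>2</td><td>3.3899000000000017</td><td>587.7</td><td>4</td><td>0.1534862275970381</td></tr></table>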
" ], "text/plain": [ "\n", "#\n", " gpa gre rank rankp \n", " 1 3.38990000 587.7 1 0.49314506 \n", " 3 3.38990000 587.7 2 0.35721950 \n", " 0 3.38990000 587.7 3 0.24094889 \n", " 2 3.38990000 587.7 4 0.15348622 \n" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e = Math::E\n", "# One row per distinct rank, with GRE and GPA held at their means\n", "new_data = Daru::DataFrame.new({\n", " gre: [df[:gre].mean]*4,\n", " gpa: [df[:gpa].mean]*4,\n", " rank: df[:rank].factors\n", " })\n", "\n", "# Predicted probability: the logistic function applied to the linear predictor\n", "new_data[:rankp] = new_data.collect(:row) do |x|\n", " 1 / (1 + e ** -(c[:constant] + x[:gre] * c[:gre] + x[:gpa] * c[:gpa] + x[:rank] * c[:rank]))\n", "end\n", "\n", "new_data.sort! [:rank]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a second example, let's create a hypothetical dataset of the body weights of 20 people and whether or not each of them survived.\n", "\n", "For this example we will simply assume that people with lower body weight are less likely to survive. The weights are drawn at random without a fixed seed, so re-running the notebook will give numbers that differ from those shown." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
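<table><tr><th colspan=3>Daru::DataFrame:27686700 rows: 20 cols: 2</th></tr>\n",
"<tr><th></th><th>body_weight</th><th>survive</th></tr>\n",
"<tr><td>0</td><td>27.044364650712417</td><td>0</td></tr>\n",
"<tr><td>1</td><td>27.47159312927552</td><td>0</td></tr>\n",
"<tr><td>2</td><td>27.898166592007826</td><td>0</td></tr>\n",
"<tr><td>3</td><td>28.124290182202703</td><td>0</td></tr>\n",
"<tr><td>4</td><td>28.130559736750975</td><td>0</td></tr>\n",
"<tr><td>5</td><td>28.75865262632871</td><td>1</td></tr>\n",
"<tr><td>6</td><td>28.893891170932104</td><td>0</td></tr>\n",
"<tr><td>7</td><td>29.379537173488142</td><td>1</td></tr>\n",
"<tr><td>8</td><td>29.387455746614265</td><td>0</td></tr>\n",
"<tr><td>9</td><td>29.73011654672403</td><td>0</td></tr>\n",
"<tr><td>10</td><td>29.732865590281754</td><td>1</td></tr>\n",
"<tr><td>11</td><td>29.804600640385086</td><td>1</td></tr>\n",
"<tr><td>12</td><td>30.854286396908076</td><td>0</td></tr>\n",
"<tr><td>13</td><td>31.10670541344917</td><td>1</td></tr>\n",
"<tr><td>14</td><td>31.466802603305748</td><td>1</td></tr>\n",
"<tr><td>15</td><td>31.520641425410044</td><td>1</td></tr>\n",
"<tr><td>16</td><td>31.933197567214524</td><td>0</td></tr>\n",
"<tr><td>17</td><td>32.11397962791281</td><td>1</td></tr>\n",
"<tr><td>18</td><td>32.760606649719776</td><td>1</td></tr>\n",
"<tr><td>19</td><td>34.33739385108647</td><td>1</td></tr></table>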
" ], "text/plain": [ "\n", "#\n", " body_weigh survive \n", " 0 27.0443646 0 \n", " 1 27.4715931 0 \n", " 2 27.8981665 0 \n", " 3 28.1242901 0 \n", " 4 28.1305597 0 \n", " 5 28.7586526 1 \n", " 6 28.8938911 0 \n", " 7 29.3795371 1 \n", " 8 29.3874557 0 \n", " 9 29.7301165 0 \n", " 10 29.7328655 1 \n", " 11 29.8046006 1 \n", " 12 30.8542863 0 \n", " 13 31.1067054 1 \n", " 14 31.4668026 1 \n", " ... ... ... \n" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'distribution'\n", "\n", "# Draw 20 body weights from a normal distribution (mean 30, standard deviation 2) and sort them\n", "rng = Distribution::Normal.rng(30,2)\n", "body_weight = Daru::Vector.new(20.times.map { rng.call }.sort)\n", "\n", "# Survival outcomes (1 = survived), chosen so that people with lower body weight\n", "# are on average less likely to survive.\n", "survive = Daru::Vector.new [0,0,0,0,0,1,0,1,0,0,1,1,0,1,1,1,0,1,1,1]\n", "\n", "df = Daru::DataFrame.new({\n", " body_weight: body_weight,\n", " survive: survive\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the logistic regression coefficients." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{:body_weight=>0.8433486251123171, :constant=>-25.24920458377614}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "glm = Statsample::GLM.compute df, :survive, :logistic, constant: 1\n", "coeffs = glm.coefficients :hash" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using these coefficients, we compute the predicted probability of survival for each value in the `:body_weight` vector and store the results in a new vector, `:survive_pred`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
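<table><tr><th colspan=4>Daru::DataFrame:27686700 rows: 20 cols: 3</th></tr>\n",
"<tr><th></th><th>body_weight</th><th>survive</th><th>survive_pred</th></tr>\n",
"<tr><td>0</td><td>27.044364650712417</td><td>0</td><td>0.08007143558819431</td></tr>\n",
"<tr><td>1</td><td>27.47159312927552</td><td>0</td><td>0.11094995452363857</td></tr>\n",
"<tr><td>2</td><td>27.898166592007826</td><td>0</td><td>0.15170068399992506</td></tr>\n",
"<tr><td>3</td><td>28.124290182202703</td><td>0</td><td>0.17790253325703076</td></tr>\n",
"<tr><td>4</td><td>28.130559736750975</td><td>0</td><td>0.1786771529208482</td></tr>\n",
"<tr><td>5</td><td>28.75865262632871</td><td>1</td><td>0.26980060957631496</td></tr>\n",
"<tr><td>6</td><td>28.893891170932104</td><td>0</td><td>0.2928502245475736</td></tr>\n",
"<tr><td>7</td><td>29.379537173488142</td><td>1</td><td>0.38414006941637974</td></tr>\n",
"<tr><td>8</td><td>29.387455746614265</td><td>0</td><td>0.3857211724501716</td></tr>\n",
"<tr><td>9</td><td>29.73011654672403</td><td>0</td><td>0.456025989208083</td></tr>\n",
"<tr><td>10</td><td>29.732865590281754</td><td>1</td><td>0.4566011649897577</td></tr>\n",
"<tr><td>11</td><td>29.804600640385086</td><td>1</td><td>0.4716465476624143</td></tr>\n",
"<tr><td>12</td><td>30.854286396908076</td><td>0</td><td>0.6838918583579029</td></tr>\n",
"<tr><td>13</td><td>31.10670541344917</td><td>1</td><td>0.7280185490554567</td></tr>\n",
"<tr><td>14</td><td>31.466802603305748</td><td>1</td><td>0.7838559408058121</td></tr>\n",
"<tr><td>15</td><td>31.520641425410044</td><td>1</td><td>0.7914495278564925</td></tr>\n",
"<tr><td>16</td><td>31.933197567214524</td><td>0</td><td>0.843118090723654</td></tr>\n",
"<tr><td>17</td><td>32.11397962791281</td><td>1</td><td>0.8622465766953867</td></tr>\n",
"<tr><td>18</td><td>32.760606649719776</td><td>1</td><td>0.9152435218371247</td></tr>\n",
"<tr><td>19</td><td>34.33739385108647</td><td>1</td><td>0.9760883965278441</td></tr></table>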
" ], "text/plain": [ "\n", "#\n", " body_weigh survive survive_pr \n", " 0 27.0443646 0 0.08007143 \n", " 1 27.4715931 0 0.11094995 \n", " 2 27.8981665 0 0.15170068 \n", " 3 28.1242901 0 0.17790253 \n", " 4 28.1305597 0 0.17867715 \n", " 5 28.7586526 1 0.26980060 \n", " 6 28.8938911 0 0.29285022 \n", " 7 29.3795371 1 0.38414006 \n", " 8 29.3874557 0 0.38572117 \n", " 9 29.7301165 0 0.45602598 \n", " 10 29.7328655 1 0.45660116 \n", " 11 29.8046006 1 0.47164654 \n", " 12 30.8542863 0 0.68389185 \n", " 13 31.1067054 1 0.72801854 \n", " 14 31.4668026 1 0.78385594 \n", " ... ... ... ... \n" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e = Math::E\n", "# Predicted probability of survival for each body weight: logistic function of the linear predictor\n", "df[:survive_pred] = df[:body_weight].map { |x| 1 / (1 + e ** -(coeffs[:constant] + x*coeffs[:body_weight])) }\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results can then be plotted with the `plot` method.\n", "\n", "The predicted probabilities follow the S-shaped (sigmoid) curve characteristic of logistic regression." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "Nyaplot plot: scatter and line of survive_pred against body_weight (x label: Body Weight, y label: Probability of Survival)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Overlay a scatter plot and a line plot of the predicted probabilities\n", "df.plot type: [:scatter,:line], x: [:body_weight]*2, y: [:survive_pred]*2 do |plot, diagram|\n", " plot.x_label \"Body Weight\"\n", " plot.y_label \"Probability of Survival\"\n", "end" ] } ], "metadata": { "kernelspec": { "display_name": "Ruby 2.2.1", "language": "ruby", "name": "ruby" }, "language_info": { "file_extension": ".rb", "mimetype": "application/x-ruby", "name": "ruby", "version": "2.2.1" } }, "nbformat": 4, "nbformat_minor": 0 }