{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Decision Tree를 활용한 Iris 데이터 분류" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1) scikit 활용한 Iris(붓꽃) Data Set 로드" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 요구되는 패키지\n", " - pydot2\n", "- 요구되는 프로그램\n", " - GraphViz: http://www.graphviz.org/\n", " - MAC에서 설치하는 방법\n", " - brew install graphviz" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "from sklearn.datasets import load_iris\n", "from sklearn import tree\n", "from sklearn.externals.six import StringIO\n", "\n", "iris = load_iris()\n", "print type(iris)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 5.1, 3.5, 1.4, 0.2],\n", " [ 4.9, 3. , 1.4, 0.2],\n", " [ 4.7, 3.2, 1.3, 0.2],\n", " [ 4.6, 3.1, 1.5, 0.2],\n", " [ 5. , 3.6, 1.4, 0.2],\n", " [ 5.4, 3.9, 1.7, 0.4],\n", " [ 4.6, 3.4, 1.4, 0.3],\n", " [ 5. , 3.4, 1.5, 0.2],\n", " [ 4.4, 2.9, 1.4, 0.2],\n", " [ 4.9, 3.1, 1.5, 0.1],\n", " [ 5.4, 3.7, 1.5, 0.2],\n", " [ 4.8, 3.4, 1.6, 0.2],\n", " [ 4.8, 3. , 1.4, 0.1],\n", " [ 4.3, 3. , 1.1, 0.1],\n", " [ 5.8, 4. , 1.2, 0.2],\n", " [ 5.7, 4.4, 1.5, 0.4],\n", " [ 5.4, 3.9, 1.3, 0.4],\n", " [ 5.1, 3.5, 1.4, 0.3],\n", " [ 5.7, 3.8, 1.7, 0.3],\n", " [ 5.1, 3.8, 1.5, 0.3],\n", " [ 5.4, 3.4, 1.7, 0.2],\n", " [ 5.1, 3.7, 1.5, 0.4],\n", " [ 4.6, 3.6, 1. , 0.2],\n", " [ 5.1, 3.3, 1.7, 0.5],\n", " [ 4.8, 3.4, 1.9, 0.2],\n", " [ 5. , 3. , 1.6, 0.2],\n", " [ 5. , 3.4, 1.6, 0.4],\n", " [ 5.2, 3.5, 1.5, 0.2],\n", " [ 5.2, 3.4, 1.4, 0.2],\n", " [ 4.7, 3.2, 1.6, 0.2],\n", " [ 4.8, 3.1, 1.6, 0.2],\n", " [ 5.4, 3.4, 1.5, 0.4],\n", " [ 5.2, 4.1, 1.5, 0.1],\n", " [ 5.5, 4.2, 1.4, 0.2],\n", " [ 4.9, 3.1, 1.5, 0.1],\n", " [ 5. , 3.2, 1.2, 0.2],\n", " [ 5.5, 3.5, 1.3, 0.2],\n", " [ 4.9, 3.1, 1.5, 0.1],\n", " [ 4.4, 3. , 1.3, 0.2],\n", " [ 5.1, 3.4, 1.5, 0.2],\n", " [ 5. , 3.5, 1.3, 0.3],\n", " [ 4.5, 2.3, 1.3, 0.3],\n", " [ 4.4, 3.2, 1.3, 0.2],\n", " [ 5. , 3.5, 1.6, 0.6],\n", " [ 5.1, 3.8, 1.9, 0.4],\n", " [ 4.8, 3. , 1.4, 0.3],\n", " [ 5.1, 3.8, 1.6, 0.2],\n", " [ 4.6, 3.2, 1.4, 0.2],\n", " [ 5.3, 3.7, 1.5, 0.2],\n", " [ 5. , 3.3, 1.4, 0.2],\n", " [ 7. , 3.2, 4.7, 1.4],\n", " [ 6.4, 3.2, 4.5, 1.5],\n", " [ 6.9, 3.1, 4.9, 1.5],\n", " [ 5.5, 2.3, 4. , 1.3],\n", " [ 6.5, 2.8, 4.6, 1.5],\n", " [ 5.7, 2.8, 4.5, 1.3],\n", " [ 6.3, 3.3, 4.7, 1.6],\n", " [ 4.9, 2.4, 3.3, 1. ],\n", " [ 6.6, 2.9, 4.6, 1.3],\n", " [ 5.2, 2.7, 3.9, 1.4],\n", " [ 5. , 2. , 3.5, 1. ],\n", " [ 5.9, 3. , 4.2, 1.5],\n", " [ 6. , 2.2, 4. , 1. ],\n", " [ 6.1, 2.9, 4.7, 1.4],\n", " [ 5.6, 2.9, 3.6, 1.3],\n", " [ 6.7, 3.1, 4.4, 1.4],\n", " [ 5.6, 3. , 4.5, 1.5],\n", " [ 5.8, 2.7, 4.1, 1. ],\n", " [ 6.2, 2.2, 4.5, 1.5],\n", " [ 5.6, 2.5, 3.9, 1.1],\n", " [ 5.9, 3.2, 4.8, 1.8],\n", " [ 6.1, 2.8, 4. , 1.3],\n", " [ 6.3, 2.5, 4.9, 1.5],\n", " [ 6.1, 2.8, 4.7, 1.2],\n", " [ 6.4, 2.9, 4.3, 1.3],\n", " [ 6.6, 3. , 4.4, 1.4],\n", " [ 6.8, 2.8, 4.8, 1.4],\n", " [ 6.7, 3. , 5. , 1.7],\n", " [ 6. , 2.9, 4.5, 1.5],\n", " [ 5.7, 2.6, 3.5, 1. ],\n", " [ 5.5, 2.4, 3.8, 1.1],\n", " [ 5.5, 2.4, 3.7, 1. ],\n", " [ 5.8, 2.7, 3.9, 1.2],\n", " [ 6. , 2.7, 5.1, 1.6],\n", " [ 5.4, 3. , 4.5, 1.5],\n", " [ 6. , 3.4, 4.5, 1.6],\n", " [ 6.7, 3.1, 4.7, 1.5],\n", " [ 6.3, 2.3, 4.4, 1.3],\n", " [ 5.6, 3. , 4.1, 1.3],\n", " [ 5.5, 2.5, 4. , 1.3],\n", " [ 5.5, 2.6, 4.4, 1.2],\n", " [ 6.1, 3. , 4.6, 1.4],\n", " [ 5.8, 2.6, 4. , 1.2],\n", " [ 5. , 2.3, 3.3, 1. ],\n", " [ 5.6, 2.7, 4.2, 1.3],\n", " [ 5.7, 3. , 4.2, 1.2],\n", " [ 5.7, 2.9, 4.2, 1.3],\n", " [ 6.2, 2.9, 4.3, 1.3],\n", " [ 5.1, 2.5, 3. , 1.1],\n", " [ 5.7, 2.8, 4.1, 1.3],\n", " [ 6.3, 3.3, 6. , 2.5],\n", " [ 5.8, 2.7, 5.1, 1.9],\n", " [ 7.1, 3. , 5.9, 2.1],\n", " [ 6.3, 2.9, 5.6, 1.8],\n", " [ 6.5, 3. , 5.8, 2.2],\n", " [ 7.6, 3. , 6.6, 2.1],\n", " [ 4.9, 2.5, 4.5, 1.7],\n", " [ 7.3, 2.9, 6.3, 1.8],\n", " [ 6.7, 2.5, 5.8, 1.8],\n", " [ 7.2, 3.6, 6.1, 2.5],\n", " [ 6.5, 3.2, 5.1, 2. ],\n", " [ 6.4, 2.7, 5.3, 1.9],\n", " [ 6.8, 3. , 5.5, 2.1],\n", " [ 5.7, 2.5, 5. , 2. ],\n", " [ 5.8, 2.8, 5.1, 2.4],\n", " [ 6.4, 3.2, 5.3, 2.3],\n", " [ 6.5, 3. , 5.5, 1.8],\n", " [ 7.7, 3.8, 6.7, 2.2],\n", " [ 7.7, 2.6, 6.9, 2.3],\n", " [ 6. , 2.2, 5. , 1.5],\n", " [ 6.9, 3.2, 5.7, 2.3],\n", " [ 5.6, 2.8, 4.9, 2. ],\n", " [ 7.7, 2.8, 6.7, 2. ],\n", " [ 6.3, 2.7, 4.9, 1.8],\n", " [ 6.7, 3.3, 5.7, 2.1],\n", " [ 7.2, 3.2, 6. , 1.8],\n", " [ 6.2, 2.8, 4.8, 1.8],\n", " [ 6.1, 3. , 4.9, 1.8],\n", " [ 6.4, 2.8, 5.6, 2.1],\n", " [ 7.2, 3. , 5.8, 1.6],\n", " [ 7.4, 2.8, 6.1, 1.9],\n", " [ 7.9, 3.8, 6.4, 2. ],\n", " [ 6.4, 2.8, 5.6, 2.2],\n", " [ 6.3, 2.8, 5.1, 1.5],\n", " [ 6.1, 2.6, 5.6, 1.4],\n", " [ 7.7, 3. , 6.1, 2.3],\n", " [ 6.3, 3.4, 5.6, 2.4],\n", " [ 6.4, 3.1, 5.5, 1.8],\n", " [ 6. , 3. , 4.8, 1.8],\n", " [ 6.9, 3.1, 5.4, 2.1],\n", " [ 6.7, 3.1, 5.6, 2.4],\n", " [ 6.9, 3.1, 5.1, 2.3],\n", " [ 5.8, 2.7, 5.1, 1.9],\n", " [ 6.8, 3.2, 5.9, 2.3],\n", " [ 6.7, 3.3, 5.7, 2.5],\n", " [ 6.7, 3. , 5.2, 2.3],\n", " [ 6.3, 2.5, 5. , 1.9],\n", " [ 6.5, 3. , 5.2, 2. ],\n", " [ 6.2, 3.4, 5.4, 2.3],\n", " [ 5.9, 3. , 5.1, 1.8]])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "iris.data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "iris.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2) DecisionTreeClassifier를 활용한 결정 트리 분류" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- classifier (clf) 객체 생성 및 학습\n", " - DecisionTreeClassifier 객체 활용\n", " - criterion\n", " - 'gini': Gini impurity\n", " - 'entropy': information gain" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "clf = tree.DecisionTreeClassifier(criterion='entropy')\n", "clf = clf.fit(iris.data, iris.target)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with open(\"iris.dot\", 'w') as f:\n", " tree.export_graphviz(clf, out_file=f)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "- 위 코드까지 수행하면 동일 디렉토리에 iris.dot 파일이 생성됨\n", "- 콘솔에서 다음 명령어로 각종 이미지 파일을 만들 수 있음\n", " - dot -Tpdf iris.dot > iris.pdf\n", " - dot -Tpng iris.dot > iris.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- classifier (clf) 객체를 활용한 새로운 데이터에 대한 분류 추론\n", " - predict 함수 활용" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 5.1, 3.5, 1.4, 0.2]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "iris.data[:1]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([0])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.predict(iris.data[:1])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array(['setosa'], \n", " dtype='|S10')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "iris.target_names[clf.predict(iris.data[:1])]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 1., 0., 0.]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.predict_proba(iris.data[:1])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 1., 0., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 1., 0.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.],\n", " [ 0., 0., 1.]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.predict_proba(iris.data[40:110])" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python [Root]", "language": "python", "name": "Python [Root]" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 0 }