{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bOChJSNXtC9g" }, "source": [ "# 逻辑回归" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OLIxEDq6VhvZ" }, "source": [ "\n", "\n", "在上一节中,我们看到线性回归可以很好的拟合出一条线后者一个超平面来做出对连续变量的预测。但是在分类问题中我们希望的输出是类别的概率,线性回归就不能做的很好了。\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VoMq0eFRvugb" }, "source": [ "# 概述" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qWro5T5qTJJL" }, "source": [ "\n", "\n", "$ \\hat{y} = \\frac{1}{1 + e^{-XW}} $ \n", "\n", "*where*:\n", "* $\\hat{y}$ = 预测值 | $\\in \\mathbb{R}^{NX1}$ ($N$ 是样本的个数)\n", "* $X$ = 输入 | $\\in \\mathbb{R}^{NXD}$ ($D$ 是特征的个数)\n", "* $W$ = 权重 | $\\in \\mathbb{R}^{DX1}$ \n", "\n", "这个是二项式逻辑回归。主要思想是用线性回归的输出值($z=XW$)经过一个sigmoid函数($\\frac{1}{1+e^{-z}}$)来映射到(0, 1)之间。" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YcFvkklZSZr9" }, "source": [ "当我们有多于两个分类类别,我们就需要使用多项式逻辑回归(softmax分类器)。softmax分类器将会用线性方程($z=XW$)并且归一化它,来产生对应的类别y的概率。\n", "\n", "$ \\hat{y} = \\frac{e^{XW_y}}{\\sum e^{XW}} $ \n", "\n", "*where*:\n", "* $\\hat{y}$ = 预测值 | $\\in \\mathbb{R}^{NX1}$ ($N$ 是样本的个数)\n", "* $X$ = 输入 | $\\in \\mathbb{R}^{NXD}$ ($D$ 是特征的个数)\n", "* $W$ = 权重 | $\\in \\mathbb{R}^{DXC}$ ($C$ 是类别的个数)\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "T4Y55tpzIjOa" }, "source": [ "* **目标:** 通过输入值$X$来预测$y$的类别概率。softmax分类器将根据归一化线性输出来计算类别概率。 \n", "* **优点:**\n", " * 可以预测与输入对应的类别概率。\n", "* **缺点:**\n", " * 因为使用的损失函数是要最小化交叉熵损失,所以对离群点很敏感。(支持向量机([SVMs](https://towardsdatascience.com/support-vector-machine-vs-logistic-regression-94cc2975433f)) 是对处理离群点一个很好的选择).\n", "* **其他:** Softmax分类器在神经网络结构中广泛用于最后一层,因为它会计算出类别的概率。" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Jq65LZJbSpzd" }, "source": [ "# 训练" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-HBPn8zPTQfZ" }, "source": [ "*步骤*:\n", "\n", "1. 随机初始化模型权重$W$.\n", "2. 将输入值 $X$ 传入模型并且得到logits ($z=XW$). 在logits上使用softmax操作得到独热编码后的类别概率$\\hat{y}$。 比如, 如果有三个类别, 预测出的类别概率可能为[0.3, 0.3, 0.4]. \n", "3. 使用损失函数将预测值$\\hat{y}$ (例如[0.3, 0.3, 0.4]])和真实值$y$(例如属于第二个类别应该写作[0, 0, 1])做对比,并且计算出损失值$J$。一个很常用的逻辑回归损失函数是交叉熵函数。 \n", " * $J(\\theta) = - \\sum_i y_i ln (\\hat{y_i}) = - \\sum_i y_i ln (\\frac{e^{X_iW_y}}{\\sum e^{X_iW}}) $\n", " * $y$ = [0, 0, 1]\n", " * $\\hat{y}$ = [0.3, 0.3, 0.4]]\n", " * $J(\\theta) = - \\sum_i y_i ln (\\hat{y_i}) = - \\sum_i y_i ln (\\frac{e^{X_iW_y}}{\\sum e^{X_iW}}) = - \\sum_i [0 * ln(0.3) + 0 * ln(0.3) + 1 * ln(0.4)] = -ln(0.4) $\n", " * 简化我们的交叉熵函数: $J(\\theta) = - ln(\\hat{y_i})$ (负的最大似然).\n", " * $J(\\theta) = - ln(\\hat{y_i}) = - ln (\\frac{e^{X_iW_y}}{\\sum_i e^{X_iW}}) $\n", "4. 根据模型权重计算损失梯度$J(\\theta)$。让我们假设类别的分类是互斥的(一种输入仅仅对应一个输出类别).\n", " * $\\frac{\\partial{J}}{\\partial{W_j}} = \\frac{\\partial{J}}{\\partial{y}}\\frac{\\partial{y}}{\\partial{W_j}} = - \\frac{1}{y}\\frac{\\partial{y}}{\\partial{W_j}} = - \\frac{1}{\\frac{e^{W_yX}}{\\sum e^{XW}}}\\frac{\\sum e^{XW}e^{W_yX}0 - e^{W_yX}e^{W_jX}X}{(\\sum e^{XW})^2} = \\frac{Xe^{W_j}X}{\\sum e^{XW}} = XP$\n", " * $\\frac{\\partial{J}}{\\partial{W_y}} = \\frac{\\partial{J}}{\\partial{y}}\\frac{\\partial{y}}{\\partial{W_y}} = - \\frac{1}{y}\\frac{\\partial{y}}{\\partial{W_y}} = - \\frac{1}{\\frac{e^{W_yX}}{\\sum e^{XW}}}\\frac{\\sum e^{XW}e^{W_yX}X - e^{W_yX}e^{W_yX}X}{(\\sum e^{XW})^2} = \\frac{1}{P}(XP - XP^2) = X(P-1)$\n", "5. 
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "r_hKrjzdtTgM" }, "source": [ "# 数据" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PyccHrQztVEu" }, "source": [ "我们来加载在第三节课中用到的titanic数据集。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "H385V4VUtWOv" }, "outputs": [], "source": [ "from argparse import Namespace\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import urllib.request" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "pL67TlZO6Zg4" }, "outputs": [], "source": [ "# 参数\n", "args = Namespace(\n", "    seed=1234,\n", "    data_file=\"titanic.csv\",\n", "    train_size=0.75,\n", "    test_size=0.25,\n", "    num_epochs=100,\n", ")\n", "\n", "# 设置随机种子来保证实验结果的可重复性。\n", "np.random.seed(args.seed)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "7sp_tSyItf1_" }, "outputs": [], "source": [ "# 从GitHub下载数据,保存到notebook所在的本地目录\n", "url = \"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/data/titanic.csv\"\n", "response = urllib.request.urlopen(url)\n", "data = response.read()\n", "with open(args.data_file, 'wb') as f:\n", "    f.write(data)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 272 }, "colab_type": "code", "id": "7alqmyzXtgE8", "outputId": "353702e3-76f7-479d-df7a-5effcc8a7461" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclassnamesexagesibspparchticketfarecabinembarkedsurvived
01Allen, Miss. Elisabeth Waltonfemale29.00000024160211.3375B5S1
11Allison, Master. Hudson Trevormale0.916712113781151.5500C22 C26S1
21Allison, Miss. Helen Lorainefemale2.000012113781151.5500C22 C26S0
31Allison, Mr. Hudson Joshua Creightonmale30.000012113781151.5500C22 C26S0
41Allison, Mrs. Hudson J C (Bessie Waldo Daniels)female25.000012113781151.5500C22 C26S0
\n", "
" ], "text/plain": [ " pclass name sex age \\\n", "0 1 Allen, Miss. Elisabeth Walton female 29.0000 \n", "1 1 Allison, Master. Hudson Trevor male 0.9167 \n", "2 1 Allison, Miss. Helen Loraine female 2.0000 \n", "3 1 Allison, Mr. Hudson Joshua Creighton male 30.0000 \n", "4 1 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female 25.0000 \n", "\n", " sibsp parch ticket fare cabin embarked survived \n", "0 0 0 24160 211.3375 B5 S 1 \n", "1 1 2 113781 151.5500 C22 C26 S 1 \n", "2 1 2 113781 151.5500 C22 C26 S 0 \n", "3 1 2 113781 151.5500 C22 C26 S 0 \n", "4 1 2 113781 151.5500 C22 C26 S 0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 把CSV文件内容读到DataFrame中\n", "df = pd.read_csv(args.data_file, header=0)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "k-5Y4zLIoE6s" }, "source": [ "# Scikit-learn实现" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ILkbyBHQoIwE" }, "source": [ "**注意**: Scikit-learn中`LogisticRegression`类使用的是坐标下降法(coordinate descent)来做的拟合。然而,我们会使用Scikit-learn中的`SGDClassifier`类来做随机梯度下降。我们使用这个优化方法是因为在未来的几节课程中我们也会使用到它。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "W1MJODStIu8V" }, "outputs": [], "source": [ "# 调包\n", "from sklearn.linear_model import SGDClassifier\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "kItBIOOCTi6p" }, "outputs": [], "source": [ "# 预处理\n", "def preprocess(df):\n", " \n", " # 删除掉含有空值的行\n", " df = df.dropna()\n", "\n", " # 删除基于文本的特征 (我们以后的课程将会学习怎么使用它们)\n", " features_to_drop = [\"name\", \"cabin\", \"ticket\"]\n", " df = df.drop(features_to_drop, axis=1)\n", "\n", " # pclass, sex, 和 embarked 是类别变量\n", " categorical_features = [\"pclass\",\"embarked\",\"sex\"]\n", " df = pd.get_dummies(df, columns=categorical_features)\n", "\n", " return df" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 224 }, "colab_type": "code", "id": "QwQHDh4xuYTB", "outputId": "153ea757-b817-406d-dbde-d1fba88f194b" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agesibspparchfaresurvivedpclass_1pclass_2pclass_3embarked_Cembarked_Qembarked_Ssex_femalesex_male
029.000000211.3375110000110
10.916712151.5500110000101
22.000012151.5500010000110
330.000012151.5500010000101
425.000012151.5500010000110
\n", "
" ], "text/plain": [ " age sibsp parch fare survived pclass_1 pclass_2 pclass_3 \\\n", "0 29.0000 0 0 211.3375 1 1 0 0 \n", "1 0.9167 1 2 151.5500 1 1 0 0 \n", "2 2.0000 1 2 151.5500 0 1 0 0 \n", "3 30.0000 1 2 151.5500 0 1 0 0 \n", "4 25.0000 1 2 151.5500 0 1 0 0 \n", "\n", " embarked_C embarked_Q embarked_S sex_female sex_male \n", "0 0 0 1 1 0 \n", "1 0 0 1 0 1 \n", "2 0 0 1 1 0 \n", "3 0 0 1 0 1 \n", "4 0 0 1 1 0 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 数据预处理\n", "df = preprocess(df)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "wsGRZNNiUTqj", "outputId": "c9364be7-3cae-487f-9d96-3210b3129199" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train size: 199, test size: 71\n" ] } ], "source": [ "# 划分数据到训练集和测试集\n", "mask = np.random.rand(len(df)) < args.train_size\n", "train_df = df[mask]\n", "test_df = df[~mask]\n", "print (\"Train size: {0}, test size: {1}\".format(len(train_df), len(test_df)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oZKxFmATU95M" }, "source": [ "**注意**: 如果你有类似标准化的预处理步骤,你需要在划分完训练集和测试集之后再使用它们。这是因为我们不可能从测试集中学到任何有用的信息。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "cLzL_LJd4vQ-" }, "outputs": [], "source": [ "# 分离 X 和 y\n", "X_train = train_df.drop([\"survived\"], axis=1)\n", "y_train = train_df[\"survived\"]\n", "X_test = test_df.drop([\"survived\"], axis=1)\n", "y_test = test_df[\"survived\"]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "id": "AdTYbV472UNJ", "outputId": "214a8114-3fd3-407f-cd6e-5f5d07294f50" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mean: [-1.78528326e-17 7.14113302e-17 -5.80217058e-17 -5.35584977e-17\n", " 3.57056651e-17 -8.92641628e-17 3.57056651e-17 -3.79372692e-17\n", " 0.00000000e+00 3.79372692e-17 1.04885391e-16 -6.69481221e-17]\n", "std: [1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 
1.]\n" ] } ], "source": [ "# 标准化训练数据 (mean=0, std=1)\n", "X_scaler = StandardScaler().fit(X_train)\n", "\n", "# 标准化训练和测试数据 (不要标准化标签分类y)\n", "standardized_X_train = X_scaler.transform(X_train)\n", "standardized_X_test = X_scaler.transform(X_test)\n", "\n", "# 检查\n", "print (\"mean:\", np.mean(standardized_X_train, axis=0)) # mean 应该为 ~0\n", "print (\"std:\", np.std(standardized_X_train, axis=0)) # std 应该为 1" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "id": "7-vm9AZm1_f9" }, "outputs": [], "source": [ "# 初始化模型\n", "log_reg = SGDClassifier(loss=\"log\", penalty=\"none\", max_iter=args.num_epochs, \n", " random_state=args.seed)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 102 }, "colab_type": "code", "id": "0e8U9NNluYVp", "outputId": "c5f22ade-bb8c-479b-d300-98758a82d396" }, "outputs": [ { "data": { "text/plain": [ "SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n", " eta0=0.0, fit_intercept=True, l1_ratio=0.15,\n", " learning_rate='optimal', loss='log', max_iter=100, n_iter=None,\n", " n_jobs=1, penalty='none', power_t=0.5, random_state=1234,\n", " shuffle=True, tol=None, verbose=0, warm_start=False)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 训练\n", "log_reg.fit(X=standardized_X_train, y=y_train)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 102 }, "colab_type": "code", "id": "hA7Oz97NAe8A", "outputId": "ab8a878a-6012-4727-8cd1-40bc5c69245b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.60319594 0.39680406]\n", " [0.00374908 0.99625092]\n", " [0.81886302 0.18113698]\n", " [0.01082253 0.98917747]\n", " [0.93508814 0.06491186]]\n" ] } ], "source": [ "# 概率\n", "pred_test = log_reg.predict_proba(standardized_X_test)\n", "print (pred_test[:5])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "-jZtTd7F6_ps", "outputId": "d2306e4c-88a4-4ac4-9ad5-879fa461617f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 1 0 1 0 1 0 0 1 1 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 1 0 0 1 0\n", " 1 0 0 1 1 1 0 1 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]\n" ] } ], "source": [ "# 预测 (未标准化)\n", "pred_train = log_reg.predict(standardized_X_train) \n", "pred_test = log_reg.predict(standardized_X_test)\n", "print (pred_test)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dM7iYW8ANYjy" }, "source": [ "# 评估指标" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", "id": "uFXbczqu8Rno" }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "sEjansj78Rqe", "outputId": "f5bfbe87-12c9-4aa5-fc61-e615ad4e63d4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train acc: 0.77, test acc: 0.82\n" ] } ], "source": [ "# 正确率\n", "train_acc = accuracy_score(y_train, pred_train)\n", "test_acc = accuracy_score(y_test, pred_test)\n", "print (\"train acc: {0:.2f}, test acc: {1:.2f}\".format(train_acc, test_acc))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": 
"WijzY-vDNbE9" }, "source": [ "到目前为止我们用的是正确率作为我们的评价指标来评定模型的好坏程度。但是我们还有很多的评价指标来对模型进行评价。\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "80MwyE0yOr-k" }, "source": [ "评价指标的选择真的要看应用的情景。\n", "positive - true, 1, tumor, issue, 等等, negative - false, 0, not tumor, not issue, 等等。\n", "\n", "$\\text{accuracy}(正确率) = \\frac{TP+TN}{TP+TN+FP+FN}$ \n", "\n", "$\\text{recall}(召回率)= \\frac{TP}{TP+FN}$ → (有多个正例被我分为正例)\n", "\n", "$\\text{precision} (精确率)= \\frac{TP}{TP+FP}$ → (在所有我预测为正例的样本下,有多少是对的)\n", "\n", "$F_1 = 2 * \\frac{\\text{precision } * \\text{ recall}}{\\text{precision } + \\text{ recall}}$\n", "\n", "where: \n", "* TP: 将正类预测为正类数\n", "* TN: 将负类预测为负类数\n", "* FP: 将负类预测为正类数\n", "* FN: 将正类预测为负类数" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": {}, "colab_type": "code", "id": "opmu3hJm9LXA" }, "outputs": [], "source": [ "import itertools\n", "from sklearn.metrics import classification_report, confusion_matrix" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": {}, "colab_type": "code", "id": "wAzOL8h29m82" }, "outputs": [], "source": [ "# 绘制混淆矩阵\n", "def plot_confusion_matrix(cm, classes):\n", " cmap=plt.cm.Blues\n", " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", " plt.title(\"Confusion Matrix\")\n", " plt.colorbar()\n", " tick_marks = np.arange(len(classes))\n", " plt.xticks(tick_marks, classes, rotation=45)\n", " plt.yticks(tick_marks, classes)\n", " plt.grid(False)\n", "\n", " fmt = 'd'\n", " thresh = cm.max() / 2.\n", " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", " plt.text(j, i, format(cm[i, j], 'd'),\n", " horizontalalignment=\"center\",\n", " color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", " plt.ylabel('True label')\n", " plt.xlabel('Predicted label')\n", " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 520 }, "colab_type": "code", "id": "KqUVzahQ-5ic", "outputId": "bff8819e-3d5b-45b9-c221-179c873140b1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.74 0.91 0.82 32\n", " 1 0.91 0.74 0.82 39\n", "\n", "avg / total 0.83 0.82 0.82 71\n", "\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAT8AAAEYCAYAAAAqD/ElAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3XmcXvP5//HXexIiJBISiZ3WErWGWGOL2kKpfY2KpZHavlpUU1qipa1WUUv5Sak9dpXaIqVBECSaIBVRS2pJZFFiCbJcvz/OGW5jZu57Zu65t/N+9nEec9/nnPtzrntirn628zmKCMzMsqau3AGYmZWDk5+ZZZKTn5llkpOfmWWSk5+ZZZKTn5llkpOftYikzpL+LulDSXe0oZxBkh4uZmzlIOlBSYPLHYe1nJNfjZJ0uKQJkj6WNCP9I92uCEUfCPQGekTEQa0tJCJujojdihDP10gaICkk3d1g/ybp/rEFljNc0k35zouIPSLi+laGa2Xk5FeDJJ0KXAL8hiRRrQ78GdinCMWvAUyLiIVFKKu9zAb6S+qRs28wMK1YF1DCfz/VLCK81dAGdAM+Bg5q5pxOJMnx3XS7BOiUHhsAvA2cBswCZgBHp8fOBb4AFqTXOBYYDtyUU/aaQAAd0/dHAa8DHwFvAINy9o/L+Vx/4Dngw/Rn/5xjY4FfA0+m5TwM9Gziu9XHfxVwYrqvQ7rvbGBszrl/At4C5gETge3T/QMbfM/JOXGcn8YxH1g73ffD9PiVwJ055V8APAKo3P9dePvm5v/nqj3bAEsB9zRzzlnA1kBfYBNgS+AXOcdXJEmiq5AkuCskLRcR55DUJm+LiC4RcU1zgUhaBrgU2CMiupIkuEmNnLc8cH96bg/gIuD+BjW3w4GjgV7AksDpzV0buAE4Mn29OzCFJNHneo7kd7A8cAtwh6SlIuKhBt9zk5zP/AA4DugKTG9Q3mnAxpKOkrQ9ye9ucKSZ0CqLk1/t6QHMieabpYOAX0XErIiYTVKj+0HO8QXp8QUR8QBJ7adPK+NZDGwoqXNEzIiIKY2c8z3g1Yi4MSIWRsRIYCqwd845f42IaRExH7idJGk1KSKeApaX1IckCd7QyDk3RcTc9Jp/JKkR5/ue10XElPQzCxqU9ylwBEnyvgk4OSLezlOelYmTX+2ZC/SU1LGZc1bm67WW6em+L8tokDw/Bbq0NJCI+AQ4BPgRMEPS/ZLWKyCe+phWyXk/sxXx3AicBOxEIzVhSadJejkduf6ApLbbM0+ZbzV3MCKeJWnmiyRJW4Vy8qs9TwOfAfs2c867JAMX9Vbnm03CQn0CLJ3zfsXcgxExOiJ2BVYiqc2NKCCe+pjeaWVM9W4ETgAeSGtlX0qbpT8DDgaWi4juJP2Nqg+9iTKbbcJKOpGkBvkucEbrQ7f25uRXYyLiQ5KO/Ssk7StpaUlLSNpD0u/T00YCv5C0gqSe6fl5p3U0YRKwg6TVJXUDfl5/QFJvSd9P+/4+J2k+L2qkjAeAddPpOR0lHQKsD9zXypgAiIg3gB1J+jgb6gosJBkZ7ijpbGDZnOPvAWu2ZERX0rrAeSRN3x8AZ0hqtnlu5ePkV4Mi4iLgVJJBjNkkTbWTgL+lp5wHTABeAF4Enk/3teZaY4Db0rIm8vWEVUcyCPAu8D5JIjqhkTLmAnul584lqTHtFRFzWhNTg7LHRURjtdrRwIMk01+mk9SWc5u09RO450p6Pt910m6Gm4ALImJyRLwKnAncKKlTW76DtQ95IMrMssg1PzPLJCc/M8skJz8zyyQnPzPLpOYmwmaeOnYOLdm13GFk1qbfWb3cIWTa9OlvMmfOHOU/szAdll0jYuH8vOfF/NmjI2Jgsa7bFCe/ZmjJrnTqc3C5w8isJ5+5vNwhZNq2W21e1PJi4fyC/p4+m3RFvrtsisLJz8xKQ4K6DuWO4ktOfmZWOhW0BKKTn5mVjorWhdhmTn5mViJu9ppZFgk3e80si+Rmr5lllJu9ZpY9crPXzDJIuNlrZlkkqKuclFM5kZhZ7atzzc/MssZTXcwsmzzJ2cyyygMeZpZJbvaaWeZ4SSszyyw3e80se3yHh5llkXCz18yyyDU/M8sq9/mZWSa52WtmmSM3e80sq9zsNbOsEVBXVzk1v8qJxMxqmwrcmitCWk3SPyW9LGmKpFPS/cMlvSNpUrrtmS8c1/zMrESE2t7sXQicFhHPS+oKTJQ0Jj12cURcWGhBTn5mVjJtbfZGxAxgRvr6I0kvA6u0KpY2RWJm1gKS8m5AT0kTcrbjmihrTWBT4Jl010mSXpB0raTl8sXi5GdmpVF4n9+ciNg8Z7v6G0VJXYC7gB9HxDzgSmAtoC9JzfCP+cJxs9fMSkKoKKO9kpYgSXw3R8TdABHxXs7xEcB9+cpxzc/MSqbAZm9znxdwDfByRFyUs3+lnNP2A17KF4trfmZWMkUY7d0W+AHwoqRJ6b4zgcMk9QUCeBMYmq8gJz8zKw2B2vjoyogYR+OzAR9oaVlOfmZWEirOPL+icfIzs5Jx8jOz7ClCs7eYnPzMrGRc8zOzTHLys6JatXd3/vLrI+ndY1kWR3DtXU9yxcixbLTuKlx21qEs07kT09+dy9FnXc9Hn3xW7nBr3meffcYuO+3AF59/zsJFC9lv/wP55TnnljusshNys9eKa+GixQy76G4mTX2bLkt34qlbfsYjz0zlyrMPZ9jF9zBu4n84cp+t+cngnfnVn+8vd7g1r1OnTjw05lG6dOnCggUL+O6O27Hb7nuw1dZblzu08lJl1fx8h0cNmDlnHpOmvg3Ax59+ztQ3ZrLyCt1ZZ41ejJv4HwAeHT+VfXfuW84wM0MSXbp0AWDBggUsXLCgov7oy6mtd3gUk5NfjVl9peXp22dVnnvpTf792gz2GrARAPvvuhmr9s670IUVyaJFi9iqX19WX7kX391lV7bcaqtyh1QRVKe8W6lUdfJLV289XdKvJO3Sgs+tKSnvvX/VZpnOSzLywh/y0wvv4qNPPmPo8JsZevAOPHnzGXRZuhNfLFhU7hAzo0OHDjwzcRL/efNtJjz3LFNeqrn/3Fqlkmp+NdHnFxFnlzuGcuvYsY6RFw7htgcncO+jkwGY9uZ77H3CFQCsvXov9th+g3KGmEndu3dnhx0H8PDDD7HBhhuWO5yyKnVyy6fqan6SzpL0iqR/AH3SfddJOjB93U/SY5ImShpdv9pDun+ypKeBE8v3DdrHVecM4pU3ZnLpTY9+uW+F5ZJ+J0kMG7I7I+4cV67wMmX27Nl88MEHAMyfP59HH/kHffqsV+aoKkNdXV3erVSqquYnqR9wKMnqrR2B54GJOceXAC4D9omI2ZIOAc4HjgH+CpwcEY9J+kMz1zgOSFaOXaJLO32T4urf99sM2msrXpz2DuNvHQbAOZePYu3VejH0kB0AuPfRSdxw7/hyhpkZM2fMYMgxg1m0aBGLYzEHHHgwe35vr3KHVRkqp+JXXckP2B64JyI+BZA0qsHxPsCGwJi0et0BmCGpG9A9Ih5Lz7sR2KOxC6
Srxl4NULd0ryj6N2gHT016nc6bnvSN/aP5N1eMHFv6gDJuo403ZvyEf5U7jIpUSc3eakt+kKzX1RQBUyJim6/tlLrn+ZyZtTMJ6ipoknO19fk9DuwnqXP62Lq9Gxx/BVhB0jaQNIMlbRARHwAfStouPW9Q6UI2s0T+kV6P9jYhfVbnbcAkYDrwRIPjX6QDH5emTd2OwCXAFOBo4FpJnwKjSxu5mUFS+6sUVZX8ACLifJJBjKaOTwJ2aGT/RGCTnF3Dix6cmTWtwpq9VZf8zKw6CSc/M8soN3vNLHvc7DWzLBKe52dmmVRZ9/Y6+ZlZybjZa2bZIw94mFkGuc/PzDLLzV4zy6QKqvhV3cIGZlat1PZl7CWtJumfkl6WNEXSKen+5SWNkfRq+jPvA2uc/MysJISoq8u/5bEQOC0ivgNsDZwoaX1gGPBIRKwDPJK+b5aTn5mVjJR/a05EzIiI59PXHwEvA6sA+wDXp6ddD+ybLxb3+ZlZyRQ42ttT0oSc91enK6w3LGtNkkdaPAP0jogZkCRISb3yXcTJz8xKogUrOc+JiM2bL0tdgLuAH0fEvNZMoXGz18xKphgrOacPKrsLuDki7k53v5fzpMaVgFn5ynHyM7OSaWufn5LseA3wckRclHNoFDA4fT0YuDdfLG72mllpFGdJq22BHwAvSpqU7jsT+B1wu6Rjgf8CB+UryMnPzEpCRVjVJSLG0fTTf3duSVlNJj9Jy+YJYl5LLmRmVkl3eDRX85tC8qzb3HDr3wewejvGZWY1qEM13NsbEauVMhAzq21SZa3qUtBor6RDJZ2Zvl5VUr/2DcvMalGd8m8liyXfCZIuB3YiGWEB+BS4qj2DMrPaVIR7e4umkNHe/hGxmaR/AUTE+5KWbOe4zKzGiGTEt1IUkvwWSKojGeRAUg9gcbtGZWY1qYLGOwpKfleQ3EqygqRzgYOBc9s1KjOrPSptszafvMkvIm6QNBHYJd11UES81L5hmVmtEVBXQaO9hd7h0QFYQNL09f3AZtYqFZT7ChrtPQsYCawMrArcIunn7R2YmdWW+iWtqmm09wigX0R8CiDpfGAi8Nv2DMzMak+1NXunNzivI/B6+4RjZrWsclJf8wsbXEzSx/cpMEXS6PT9bsC40oRnZrVCVMm9vUD9iO4U4P6c/ePbLxwzq1kFrtRcKs0tbHBNKQMxs9pXQbkvf5+fpLWA84H1gaXq90fEuu0Yl5nVmEpr9hYyZ+864K8kse8B3A7c2o4xmVmNKsYDjIqlkOS3dESMBoiI1yLiFySrvJiZtYgK2EqlkKkun6dPTHpN0o+Ad4C8DwQ2M8slVVazt5Dk9xOgC/B/JH1/3YBj2jMoM6tNVTHaWy8inklffsRXC5qambVYBeW+Zic530O6hl9jImL/domogqzz7ZW5+tZflTuMzFpui5PKHUKmff7Kf4tanqSqafZeXrIozCwTqqLZGxGPlDIQM6t9lbQeXqHr+ZmZtUmlTXJ28jOzkqmg3Fd48pPUKSI+b89gzKx2Vd1DyyVtKelF4NX0/SaSLmv3yMys5nSoy7/lI+laSbMkvZSzb7ikdyRNSrc985VTSP/jpcBewFyAiJiMb28zsxaqf4BRvq0A1wEDG9l/cUT0TbcH8hVSSPKri4jpDfYtKuBzZmZfU1fAlk9EPA68X4xY8nlL0pZASOog6cfAtLZe2MyypX6Sc74N6ClpQs52XIGXOEnSC2mzeLl8JxeS/I4HTgVWB94Dtk73mZm1SDLo0fwGzImIzXO2qwso+kpgLaAvMAP4Y74PFHJv7yzg0AIubmbWrPaa6hIR79W/ljQCuC/fZwpZyXkEjdzjGxGFVkXNzNp1krOklSJiRvp2P756BlGTCpnn94+c10ulBb/V8vDMLNNUnJqfpJHAAJK+wbeBc4ABkvqSVNTeBIbmK6eQZu9tDS58IzCm5SGbWdapCGs1R8Rhjexu8QPXWnN727eANVrxOTPLMAEdK2hlg0L6/P7HV31+dSTza4a1Z1BmVpsq6fa2ZpNf+uyOTUie2wGwOCKaXODUzKwpyR0e5Y7iK80mv4gISfdERL9SBWRmNarCHmBUSAv8WUmbtXskZlbT6mt++bZSae4ZHh0jYiGwHTBE0mvAJyTfISLCCdHMWqSCuvyabfY+C2wG7FuiWMyshgnRoYKyX3PJTwAR8VqJYjGzWlbiZm0+zSW/FSSd2tTBiLioHeIxsxpW4Hp9JdFc8usAdIEiTMk2s8yrpgcYzYgIP7HbzIqmgip++fv8zMyKQVTPc3t3LlkUZlb7VCV9fhHR5jXyzczq1T/AqFL4oeVmVjKVk/qc/MysZERdlYz2mpkVTTUNeJiZFVXVrOdnZlY01TLaa2ZWTG72mllmudlrZplUQYO9Tn5mVhpJs7dysp+Tn5mVTAW1ep38zKxU5NFeM8seN3vNLJvkZq+1gwvOPJmnxz5M9x49ue7vTwIw74P/ce6pxzLznbdYcZXVGH7xtXTt1r3MkdaeVXt35y+/PpLePZZlcQTX3vUkV4wcy0brrsJlZx3KMp07Mf3duRx91vV89Mln5Q63rCqp2VtJcw6tDQbudxi/H3H71/bdMuJPbLb1Dtw8+jk223oHbhlxSZmiq20LFy1m2EV3s+kB57HjkRcy9JAdWO/bK3Ll2Yfzi0vvZYuDf8Oof07mJ4OzvURmsZ7bK+laSbMkvZSzb3lJYyS9mv5cLl85Tn41YpMt+tO129f/vZ985AEG7nsoAAP3PZRx/3igHKHVvJlz5jFp6tsAfPzp50x9YyYrr9CdddboxbiJ/wHg0fFT2XfnvuUMsyKogP8V4DpgYIN9w4BHImId4JH0fbOc/GrY+3Nn06PXigD06LUi/3t/Tpkjqn2rr7Q8ffusynMvvcm/X5vBXgM2AmD/XTdj1d55KyM1r07Ku+UTEY8DDRdb3ge4Pn19PQU8b7yqkp+k70vKm9ELLOvjYpRjVm+Zzksy8sIf8tML7+KjTz5j6PCbGXrwDjx58xl0WboTXyxYVO4Qy6oFzd6ekibkbMcVUHzviJgBkP7sle8DFTfgIaljRCxs7FhEjAJGlTikqrV8jxWYO2smPXqtyNxZM1lu+Z7lDqlmdexYx8gLh3DbgxO499HJAEx78z32PuEKANZevRd7bL9BOUOsAAU3a+dExObtHU271fwkLSPpfkmTJb0k6RBJb0rqmR7fXNLY9PVwSVdLehi4QdIzkjbIKWuspH6SjpJ0uaRuaVl16fGlJb0laQlJa0l6SNJESU9IWi8951uSnpb0nKRft9f3riT9v7sHD/3tVgAe+tutbLvznmWOqHZddc4gXnljJpfe9OiX+1ZYrguQ3Mw/bMjujLhzXLnCqwwF1PracO/ve5JWAkh/zsr3gfZs9g4E3o2ITSJiQ+ChPOf3A/aJiMOBW4GD4csvsnJETKw/MSI+BCYDO6a79gZGR8QC4Grg5IjoB5wO/Dk950/AlRGxBTCzqSAkHVdf3f7wf3Nb9o3L6FenDuHEwwby1hv/4cAdN+T+O2/i8CGnMPGpsQzafQsmPjWWw
4ecUu4wa1L/vt9m0F5bseMW6zL+1mGMv3UYu2+3PgcP3JwX/nY2k+/5JTNmf8gN944vd6hlVf8Ao7b2+TVhFDA4fT0YuDffB9qz2fsicKGkC4D7IuKJPMvZjIqI+enr24ExwDkkSfCORs6/DTgE+CdwKPBnSV2A/sAdOdfqlP7cFjggfX0jcEFjQUTE1SQJlD4b9o0837FinH3RiEb3X3Td30ocSfY8Nel1Om960jf2j+bfXDFybOkDqmDFmOUnaSQwgKRv8G2SPPE74HZJxwL/BQ7KV067Jb+ImCapH7An8Nu0SbuQr2qbSzX4yCc5n31H0lxJG5MkuKGNXGJUWu7yJLXGR4FlgA8ioqk5BVWTzMxqUTHW84uIw5o41KKJlO3Z57cy8GlE3ARcCGwGvEmSqOCrWlhTbgXOALpFxIsND0bEx8CzJM3Z+yJiUUTMA96QdFAagyRtkn7kSZIaIsCgVn8xM2s1Kf9WKu3Z57cR8KykScBZwHnAucCfJD0B5Bv3v5MkWd3ezDm3AUekP+sNAo6VNBmYQjL/B+AU4ERJzwHdWvhdzKwIVMBWKu3Z7B0NjG7k0LqNnDu8kX3v0SC+iLiOZHZ3/fs7afD7iog3+Obs7/r92+Ts+l0z4ZtZkQkvY29mWeRVXcwsqyoo9zn5mVmpyM1eM8umCsp9Tn5mVhqlHs3Nx8nPzErGzV4zy6QKyn1OfmZWOhWU+5z8zKxE5GavmWVQcodHuaP4ipOfmZVMBeU+Jz8zKx03e80skyoo9zn5mVnpVFDuc/Izs9LwklZmlk1e0srMsqqCcp+Tn5mVipe0MrOMqqDc5+RnZqXhJa3MLLPc7DWzTKqg3OfkZ2alU0G5z8nPzErES1qZWRYVa0krSW8CHwGLgIURsXlrynHyM7OSKWK9b6eImNOWApz8zKxk6iqo2VtX7gDMLENUwJZfAA9LmijpuNaG4pqfmZVMgfW+npIm5Ly/OiKuznm/bUS8K6kXMEbS1Ih4vKWxOPmZWUlIBTd75zQ3iBER76Y/Z0m6B9gSaHHyc7PXzEqnjc1eSctI6lr/GtgNeKk1objmZ2YlU4Thjt7APel8wY7ALRHxUGsKcvIzsxJRm0d7I+J1YJNiROPkZ2YlUWnP7XWfn5llkmt+ZlYylTTJ2cnPzErDDzAysyzySs5mllle0srMMqmCcp+Tn5mVTgXlPic/MyudSmr2KiLKHUPFkjQbmF7uONqgJ9CmBR+tTar9979GRKxQrMIkPUTyO8lnTkQMLNZ1m4zHya92SZrQ2iW+re38+69svsPDzDLJyc/MMsnJr7Zdnf8Ua0f+/Vcw9/mZWSa55mdmmeTkZ2aZ5OSXIZL8712BlM78VSXNAM4A/zFkgKRNACJisRNgRVoHICLCCbB0/IdQ4yQtAfxc0v3gBFhJlOgEPCLpCnACLCX/EdQwSXURsQA4Avhc0vXgBFhB6iLic2A9YG9J54ATYKn4D6CGRcTi9OVBwH+B/pJuqj/mBFheEbEofbkFMIqkhn5+eswJsJ35P/4aJ2lfYDhwJTA02aXbwAmwEkg6DLgCuBzYHxgk6ffgBNjevKRVjZGk+PrM9cXArRHxiqTXgWnA3ZLujIgDc2qHVh51wI0RMRWYKmkn4DlJRMQZ4bsQ2o3/X7+G5Ca+nBrDO8BxkjaLiAUR8TbwGLCUpFXKFWsWNVGL+wA4uP5NRLwB3AIcLGkF1/zaj2t+NSQn8R0PbCrpY+A64CfAHZJOB3oBawODI2JuuWLNopx/nxOBVYGuwBnAeEnPASeS9P91ADaPiGpeC7Di+d7eGiPpRyQDHKcCFwEvR8RJkg4HdgWWAc6LiBfKGGZmSTqBpG/veODvwF0RcVY60LEs0Ac4LSJeLGOYmeDkV+UkrQMsGRFT0vc/I1lNZBCwF7A3SffG4ohYkE5/cT9fidR3ReT8/C1wAXAssBNwIPA5fDnA0Smd/mLtzH1+VUzSt0jm8L0mafl0dw9gIrBzRAxM5/kdBRwrqaMTX2nlDFjskP7sBdwP9AP2j4jPgJOAY9L+vS9KH2U2OflVKUmrAkOA94BNgTMlbUDS1J0GvJaedwxwCvDPiFhYpnAzLb3LZoSk7wO/J+lzHRsRX0gaTNIEfixS5Yw1S9zsrVJpLWEwsC7wIcldAm8CDwKfAJcA80hqGkMj4t/lidQAJB0IbBAR50oaAFwFjCe5r/e4+m4LKx2P9lahnP6jxUBfktHB0cAGwO7AnRGxq6QlgU4R8VEZw80USX2AmRHxoaQDgDERMQ+YDJwqaXREjJW0PUlf3xIedS8PN3urUJr4BgEnA8NImrndgFeAlYEhkjaNiC+c+Eon7XfdA+ggqSNJP99dkk4FFpE0eX8mqXtEzI6IeU585ePkV736ALenU1ZOA/4HbAu8DSwBvFXG2DIpIt4nuY1wJeC3wOnAmcAC4BFgK2D19LiVmZNf9Xoe2FbSBmkN7xKS/r0FwHBPkC2b5UiSWxeSxDctIi4DDgAWAkvhEd2K4AGPKiWpO/BTIIBHgc7Aj0g6z2eUM7asktQPOA/4PrA+yYDU58DFETErXbtP6fQWKzMnvyomaWWSuwX2J6lV+M6AEmpkEQkk3UFyV83Z6aDG94ClgXPdv1dZnPxqgKRlSP4tPy53LFmUzrmcFxHz0kcGHAMMi4j5knYDtgMui4jZZQ3UvsbJz6yFGqyecwTwf8AYYDpwDfAAyXSjEek5nSNifrnitcY5+Zm1UjrdaCfgBpLBw9+RJL7OJFNe9omI6eWL0Jrj0V6zVpDUn2QE95KIeDwixpIkwo9IVmdZk+ROG6tQTn5mrbMRsAZwUHonDRExPyIujogTgT6eblTZfHubWQtI2hPYMiKGp4vFbgPsL+mOiFgkqUP6YKJZ5Y3U8nHyM2tGI9NZZgFbSzojIn6fzt3bhuSxADfWP5HNq7NUPjd7zZqRM6pbv17iv4Cfk9xdc0ZEXEtyb/X6JKtkW5XwaK9ZIxpMZ9kJ+Cuwb0RMShct6EeybNjdEfEHSd0i4sMyhmwt5JqfWQMNEt8JwCrArcCNkjaOiIUR8QwwFRggaXknvurjPj+zBnIS31CSuzX2iYibJM0DrkmXqNoAWBI4Kl3NxaqMk59ZIyTVT1Q+C1iQJsIlSe7TPZDk0QEn+pa16uU+P7MmSDqOZKWct0gGNaaTPG/3N8AC37JW3VzzM2vaDSSju69FxPvp7WwHAAud+Kqfa35meUiqA44GfgwcFhEvlTkkKwLX/MzyWwpYDBwcES+XOxgrDtf8zArQ2MKlVt2c/MwskzzJ2cwyycnPzDLJyc/MMsnJz8wyycnPviRpkaRJkl6SdIekpdtQ1gBJ96Wvvy9pWDPndk8XEGjpNYZLOr3Q/Q3OuU7SgS241pqSPL+vhjj5Wa75EdE3IjYEviC5tetLSrT4v5mIGBURv2vmlO5Ai5OfWVs4+VlTngDWTms8L0v6M/A8sJqk3SQ9Len5tIbYBUDSQElTJY0jeZA66f6jJF2evu4t6R5Jk9OtP8lTz9ZKa51/SM/7qaTnJL0g6dycss6S9Iqk
fwB98n0JSUPSciZLuqtBbXYXSU9ImiZpr/T8DpL+kHPtoW39RVplcvKzb0gX69wDeDHd1Qe4ISI2JXki2S+AXSJiM2ACcKqkpYARwN7A9sCKTRR/KfBYRGwCbAZMAYaR3D/bNyJ+mj7oex1gS6Av0E/SDpL6AYeSrKiyP7BFAV/n7ojYIr3ey8CxOcfWBHYEvgdclX6HY4EPI2KLtPwhkr5VwHWsyvj2NsvVWdKk9PUTJA/gXhmYHhHj0/1bkyzZ/qQkSJZ5ehpYD3gjIl4FkHQTcFwj1/gucCRA+ryLDyUt1+Cc3dLtX+n7LiTJsCtwT0R8ml5jVAHfaUNJ55E0rbsAo3OO3R4Ri4FXJb2efofdgI1z+gO7pdeeVsC1rIo4+Vmu+RHRN3dHmuBynz8rYExEHNbgvL5AsW4XEvDbiPh/Da7x41Zc4zqS5ecnSzoKGJBzrGGJ1si5AAABFElEQVRZkV775IjITZJIWrOF17UK52avtdR4kof3rA0gaWlJ65Is6f4tSWul5x3WxOcfAY5PP9tB0rIkD/rumnPOaOCYnL7EVST1Ah4H9pPUWVJXkiZ2Pl2BGZKWAAY1OHaQpLo05m8Dr6TXPj49H0nrSvKDiWqQa37WIhExO61BjUwf2wjwi4iYli7+eb+kOcA4YMNGijgFuFrSscAi4PiIeFrSk+lUkgfTfr/vAE+nNc+PgSMi4nlJtwGTSBYWfaKAkH8JPJOe/yJfT7KvAI8BvYEfRcRnkv5C0hf4vJKLzwb2Ley3Y9XECxuYWSa52WtmmeTkZ2aZ5ORnZpnk5GdmmeTkZ2aZ5ORnZpnk5GdmmfT/ASD5obfWzV1hAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 混淆矩阵\n", "cm = confusion_matrix(y_test, pred_test)\n", "plot_confusion_matrix(cm=cm, classes=[\"died\", \"survived\"])\n", "print (classification_report(y_test, pred_test))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iMk7tN1h98x9" }, "source": [ "当我们有大于两个标签(二分类)的时候,我们可以选择在微观/宏观层面计算评估指标(每个clas标签)、权重等。 更详细内容可以参考[offical docs](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9v6zc1_1PWnz" }, "source": [ "# 推论" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Zl9euDuMPYTN" }, "source": [ "现在我们来看看你是否会在Titanic中存活下来" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 80 }, "colab_type": "code", "id": "kX9428-EPUzx", "outputId": "ef100af7-9861-4900-e9c7-ed6d93c69069" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agecabinembarkedfarenameparchpclasssexsibspticket
024EC100Goku Mohandas21male1E44
\n", "
" ], "text/plain": [ " age cabin embarked fare name parch pclass sex sibsp ticket\n", "0 24 E C 100 Goku Mohandas 2 1 male 1 E44" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 输入你自己的信息\n", "X_infer = pd.DataFrame([{\"name\": \"Goku Mohandas\", \"cabin\": \"E\", \"ticket\": \"E44\", \n", " \"pclass\": 1, \"age\": 24, \"sibsp\": 1, \"parch\": 2, \n", " \"fare\": 100, \"embarked\": \"C\", \"sex\": \"male\"}])\n", "X_infer.head()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 80 }, "colab_type": "code", "id": "c6OAAQoaWxAb", "outputId": "85eb1c6d-6f53-4bd4-bcc3-90d9ebca74c8" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agefareparchsibsppclass_1embarked_Csex_male
02410021111
\n", "
" ], "text/plain": [ " age fare parch sibsp pclass_1 embarked_C sex_male\n", "0 24 100 2 1 1 1 1" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 进行预处理\n", "X_infer = preprocess(X_infer)\n", "X_infer.head()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 80 }, "colab_type": "code", "id": "48sj5A0mX5Yw", "outputId": "d9571238-70ab-427d-f80c-7b13b00efc95" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agesibspparchfarepclass_1pclass_2pclass_3embarked_Cembarked_Qembarked_Ssex_femalesex_male
0241210010010001
\n", "
" ], "text/plain": [ " age sibsp parch fare pclass_1 pclass_2 pclass_3 embarked_C \\\n", "0 24 1 2 100 1 0 0 1 \n", "\n", " embarked_Q embarked_S sex_female sex_male \n", "0 0 0 0 1 " ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 添加缺失列向量\n", "missing_features = set(X_test.columns) - set(X_infer.columns)\n", "for feature in missing_features:\n", " X_infer[feature] = 0\n", "\n", "# 重整title\n", "X_infer = X_infer[X_train.columns]\n", "X_infer.head()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": {}, "colab_type": "code", "id": "rP_i8w9IXFiM" }, "outputs": [], "source": [ "# 标准化\n", "standardized_X_infer = X_scaler.transform(X_infer)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "7O5PbOAvXTzF", "outputId": "f1c3597e-1676-476f-e970-168e5c3fca6c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looks like I would've survived with about 57% probability on the Titanic expedition!\n" ] } ], "source": [ "# 预测\n", "y_infer = log_reg.predict_proba(standardized_X_infer)\n", "classes = {0: \"died\", 1: \"survived\"}\n", "_class = np.argmax(y_infer)\n", "print (\"Looks like I would've {0} with about {1:.0f}% probability on the Titanic expedition!\".format(\n", " classes[_class], y_infer[0][_class]*100.0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8PLPFFP67tvL" }, "source": [ "# 可解释性" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jv6LKNXO7uch" }, "source": [ "哪些特征是最有影响力的?" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "KTSpxbwy7ugl", "outputId": "b37bf39c-f35d-4793-a479-6e61179fc5e5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.02155712 0.39758992 0.78341184 -0.0070509 -2.71953415 2.01530102\n", " 3.50708962 0.11008796 0. 
-0.11008796  2.94675085 -2.94675085]]\n", "[5.10843738]\n" ] } ], "source": [ "# 未标准化系数\n", "coef = log_reg.coef_ / X_scaler.scale_\n", "intercept = log_reg.intercept_ - np.sum((coef * X_scaler.mean_))\n", "print (coef)\n", "print (intercept)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xJgiIupyE0Hd" }, "source": [ "正系数表示与阳性类的相关性(1 = 存活),负系数表示与阴性类的相关性(0 = 死亡)。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "RKRB0er2C5l-", "outputId": "39ad0cf3-13b1-4aa8-9a6b-4456b8975a39" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Features correlated with death: ['sex_male', 'pclass_1', 'embarked_S']\n", "Features correlated with survival: ['pclass_2', 'sex_female', 'pclass_3']\n" ] } ], "source": [ "indices = np.argsort(coef)\n", "features = list(X_train.columns)\n", "print (\"Features correlated with death:\", [features[i] for i in indices[0][:3]])\n", "print (\"Features correlated with survival:\", [features[i] for i in indices[0][-3:]])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RhhFw3Kg-4aL" }, "source": [ "### 非标准化系数的推导:\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ER0HFHXj-4h8" }, "source": [ "注意我们只对X做了标准化(y是类别标签,没有被标准化)。\n", "\n", "$\\mathbb{E}[y] = W_0 + \\sum_{j=1}^{k}W_jz_j$\n", "\n", "$z_j = \\frac{x_j - \\bar{x}_j}{\\sigma_j}$\n", "\n", "$ \\hat{y} = \\hat{W_0} + \\sum_{j=1}^{k}\\hat{W_j}z_j $\n", "\n", "$\\hat{y} = \\hat{W_0} + \\sum_{j=1}^{k} \\hat{W}_j (\\frac{x_j - \\bar{x}_j}{\\sigma_j}) $\n", "\n", "$\\hat{y} = (\\hat{W_0} - \\sum_{j=1}^{k} \\hat{W}_j\\frac{\\bar{x}_j}{\\sigma_j}) + \\sum_{j=1}^{k} (\\frac{\\hat{W}_j}{\\sigma_j})x_j$" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5yBZLVHwGKSj" }, "source": [ "# K折交叉验证" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fHyLTMAAGJ_x" }, "source": [ "交叉验证是一种基于重采样的模型评估方法。与只在一开始划分一次训练集和验证集不同,交叉验证会进行k次(通常 k=5 或者 10)不同的训练集/验证集划分。\n", "\n", "步骤:\n", "1. 随机打乱训练数据集*train*。\n", "2. 将数据集分割成k个不同的片段。\n", "3. 在k次循环中的每一次,选择一个片段当作验证集,其余的所有片段当成训练集。\n", "4. 重复这个过程,使每个片段都恰好做过一次验证集。\n", "5. 随机初始化权重来训练模型。\n", "6. 在k次循环中,每次都用同一组随机初始化的权重重新初始化模型,在训练片段上训练,并在该次的验证片段上验证。\n", "\n", "下一个单元先给出按这些步骤手动实现的一个最小示意,随后再用`cross_val_score`直接计算。\n" ] },
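{ "cell_type": "markdown", "metadata": {}, "source": [ "下面是一个按上述步骤、用`KFold`手动实现交叉验证的最小示意(仅作说明:沿用上文的`standardized_X_train`、`y_train`和`args`,`n_splits=10`等取值只是假设,并非唯一选择):\n", "\n", "```python\n", "from sklearn.model_selection import KFold\n", "\n", "kf = KFold(n_splits=10, shuffle=True, random_state=args.seed)    # 步骤1-2: 打乱并切成k个片段\n", "fold_scores = []\n", "for train_index, val_index in kf.split(standardized_X_train):    # 步骤3-4: 轮流用一个片段做验证集\n", "    X_tr, X_val = standardized_X_train[train_index], standardized_X_train[val_index]\n", "    y_tr, y_val = y_train.values[train_index], y_train.values[val_index]\n", "    model = SGDClassifier(loss=\"log\", penalty=\"none\", max_iter=args.num_epochs,\n", "                          random_state=args.seed)                # 步骤5-6: 每一折都重新初始化模型\n", "    model.fit(X_tr, y_tr)\n", "    fold_scores.append(model.score(X_val, y_val))                 # 在该折的验证片段上计算正确率\n", "print(\"CV accuracy:\", np.mean(fold_scores))\n", "```" ] },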
{ "cell_type": "code", "execution_count": 29, "metadata": { "colab": {}, "colab_type": "code", "id": "6XB6X1b0KcvJ" }, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": {}, "colab_type": "code", "id": "UIqKmAEtVWMg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Scores: [0.66666667 0.7        0.7        0.4        0.7        0.7\n", " 0.85       0.7        0.68421053 0.78947368]\n", "Mean: 0.6890350877192982\n", "Standard Deviation: 0.10984701440790533\n" ] } ], "source": [ "# K折交叉验证\n", "log_reg = SGDClassifier(loss=\"log\", penalty=\"none\", max_iter=args.num_epochs)\n", "scores = cross_val_score(log_reg, standardized_X_train, y_train, cv=10, scoring=\"accuracy\")\n", "print(\"Scores:\", scores)\n", "print(\"Mean:\", scores.mean())\n", "print(\"Standard Deviation:\", scores.std())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L0aQUomQoni1" }, "source": [ "# TODO" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jCpKSu53EA9-" }, "source": [ "- interaction terms\n", "- interpreting odds ratio\n", "- simple example with coordinate descent method (sklearn.linear_model.LogisticRegression)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "05_Logistic_Regression", "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 1 }