{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bOChJSNXtC9g"
},
"source": [
"# Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OLIxEDq6VhvZ"
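},
"source": [
"In this lesson we will learn about linear regression. We will first go over the math behind it and then implement it in Python. We will also look at ways to interpret the linear model."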
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VoMq0eFRvugb"
},
"source": [
"# Overview"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-qHciBsX93ej"
},
"source": [
"$\\hat{y} = XW$\n",
"\n",
"*where*:\n",
"* $\\hat{y}$ = predictions | $\\in \\mathbb{R}^{N \\times 1}$ ($N$ is the number of samples)\n",
"* $X$ = inputs | $\\in \\mathbb{R}^{N \\times D}$ ($D$ is the number of features)\n",
"* $W$ = weights | $\\in \\mathbb{R}^{D \\times 1}$"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QAgr7Grv9pb6"
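},
"source": [
"* **Objective:** Use the inputs $X$ to predict the output $\\hat{y}$ with a linear model. The model finds the line of best fit, the one that brings the predictions closest to the target values. The training data $(X, y)$ is used to train the model, and the weights $W$ are learned with stochastic gradient descent (SGD).\n",
"* **Advantages:**\n",
"  * Computationally simple.\n",
"  * Highly interpretable.\n",
"  * Can handle both continuous and unordered categorical features.\n",
"* **Disadvantages:**\n",
"  * When used for classification, a linear model only performs well if the data is linearly separable.\n",
"  * In practice it is usually reserved for regression rather than classification.\n",
"* **Miscellaneous:** You can still use linear regression for binary classification: if the predicted continuous value is above a threshold, the sample is assigned to a particular class. Later lessons cover models better suited to binary classification, so this lesson focuses on using linear regression for regression tasks.\n"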
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xP7XD24-09Io"
},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "476yPgTM1BKJ"
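},
"source": [
"*Steps*:\n",
"1. Randomly initialize the model's weights $W$.\n",
"2. Feed the inputs $X$ into the model to get the predictions $\\hat{y}$.\n",
"3. Compare the predictions $\\hat{y}$ with the true values $y$ using a loss function to obtain the loss $J$. A common loss for linear regression is the mean squared error (MSE), which is based on the squared difference between the predicted and true values. The $\\frac{1}{2}$ has no mathematical significance; it simply cancels the exponent when we differentiate. Note that the expression below is written as a sum, without the $\\frac{1}{N}$ averaging factor.\n",
"   * $MSE = J(\\theta) = \\frac{1}{2}\\sum_{i}(\\hat{y}_i - y_i)^2$\n",
"4. Compute the gradient of the loss $J(\\theta)$ with respect to the model weights.\n",
"   * $J(\\theta) = \\frac{1}{2}\\sum_{i}(\\hat{y}_i - y_i)^2 = \\frac{1}{2}\\sum_{i}(X_iW - y_i)^2 $\n",
"   * $\\frac{\\partial{J}}{\\partial{W}} = X^T(\\hat{y} - y)$\n",
"5. Update the weights $W$ via backpropagation, using the learning rate $\\alpha$ and an optimization method such as stochastic gradient descent. Intuitively, the gradient points in the direction that increases the loss, so subtracting it moves the weights in the direction that decreases $J(\\theta)$.\n",
"   * $W = W- \\alpha\\frac{\\partial{J}}{\\partial{W}}$\n",
"6. Repeat steps 2 - 5 until the model performs well (that is, until the loss converges)."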
]
},
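{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training steps above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the method used in the rest of the lesson: it uses full-batch gradient descent instead of SGD, averages the gradient over the $N$ samples to keep the step size stable, and appends a column of ones so that the intercept is learned as part of $W$. The helper name `fit_linear_gd`, the toy data, and the hyperparameter values are made up for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def fit_linear_gd(X, y, lr=0.1, num_epochs=1000, seed=0):\n",
"    # Hypothetical helper: full-batch gradient descent for y_hat = XW.\n",
"    rng = np.random.default_rng(seed)\n",
"    N = X.shape[0]\n",
"    # Append a column of ones so the last weight acts as the intercept.\n",
"    Xb = np.hstack([X, np.ones((N, 1))])\n",
"    # Step 1: randomly initialize the weights.\n",
"    W = rng.normal(scale=0.01, size=(Xb.shape[1], 1))\n",
"    for _ in range(num_epochs):\n",
"        # Step 2: forward pass.\n",
"        y_hat = Xb @ W\n",
"        # Steps 3-4: gradient of the squared-error loss, averaged over N.\n",
"        grad = Xb.T @ (y_hat - y) / N\n",
"        # Step 5: gradient descent update.\n",
"        W = W - lr * grad\n",
"    return W\n",
"\n",
"\n",
"# Toy data with a known slope (3.65) and intercept (10).\n",
"gd_X = np.linspace(-1.0, 1.0, 50).reshape(-1, 1)\n",
"gd_y = 3.65 * gd_X + 10\n",
"gd_W = fit_linear_gd(gd_X, gd_y)\n",
"print(gd_W.ravel())  # expected to be close to [3.65, 10.0]"
]
},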
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jvJKjkMeJP4Q"
},
"source": [
"# Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "RuPl9qlSJTIY"
},
"source": [
"We will create some simple synthetic data to apply linear regression to."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "HRXD7LqVJZ43"
},
"outputs": [],
"source": [
"from argparse import Namespace\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "NFsKg-Z6IWqG"
},
"outputs": [],
"source": [
"# Arguments\n",
"args = Namespace(\n",
"    seed=1234,\n",
"    data_file=\"sample_data.csv\",\n",
"    num_samples=100,\n",
"    train_size=0.75,\n",
"    test_size=0.25,\n",
"    num_epochs=100,\n",
")\n",
"\n",
"# Set the random seed so that results are reproducible\n",
"np.random.seed(args.seed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "NWux2lcoIWss"
},
"outputs": [],
"source": [
"# Generate synthetic data\n",
"def generate_data(num_samples):\n",
"    # Noise-free linear data: y = 3.65 * X + 10\n",
"    X = np.array(range(num_samples))\n",
"    y = 3.65*X + 10\n",
"    return X, y"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"colab_type": "code",
"id": "2mb2SjSQIWvF",
"outputId": "3aa66ef6-3c88-40fd-f53a-a77f93c8e052"
},
"outputs": [
{
"data": {
"text/plain": [
"     X      y\n",
"0  0.0  10.00\n",
"1  1.0  13.65\n",
"2  2.0  17.30\n",
"3  3.0  20.95\n",
"4  4.0  24.60"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Generate the data and load it into a DataFrame\n",
"X, y = generate_data(args.num_samples)\n",
"data = np.vstack([X, y]).T\n",
"df = pd.DataFrame(data, columns=['X', 'y'])\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 362
},
"colab_type": "code",
"id": "6LwVVOkiLfBN",
"outputId": "ab2ecddb-bbeb-4117-d1cf-29312846da16"
},
"outputs": [
{
"data": {
"text/plain": [
"[Figure: scatter plot titled \"Generated data\", showing y against X]"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Scatter plot of the generated data\n",
"plt.title(\"Generated data\")\n",
"plt.scatter(x=df[\"X\"], y=df[\"y\"])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qwn29SjK-XCg"
},
"source": [
"# Scikit-learn implementation"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-kSEp8MY-y9C"
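},
"source": [
"**Note**: The `LinearRegression` class in Scikit-learn fits the model with a closed-form least-squares solution (the normal equation approach). Here, however, we will fit the data with Scikit-learn's `SGDRegressor` class, which uses stochastic gradient descent. We use this optimization method because we will rely on it again in future lessons."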
]
},
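{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the closed-form (normal equation) fit mentioned in the note above can be sketched with NumPy's least-squares solver. This is only an illustration of the idea on made-up data; it is not claimed to be how `LinearRegression` is implemented internally, and the variable names (prefixed `ne_`) are assumptions chosen to avoid clashing with the notebook's own variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy data (assumed for illustration): y = 3.65 * x + 10.\n",
"ne_X = np.arange(10, dtype=float).reshape(-1, 1)\n",
"ne_y = 3.65 * ne_X.ravel() + 10\n",
"\n",
"# Add a bias column so the intercept is estimated together with the slope.\n",
"ne_Xb = np.hstack([ne_X, np.ones((ne_X.shape[0], 1))])\n",
"\n",
"# Solve min_W ||Xb W - y||^2, i.e. the normal equation (Xb^T Xb) W = Xb^T y.\n",
"ne_W, *_ = np.linalg.lstsq(ne_Xb, ne_y, rcond=None)\n",
"print(ne_W)  # expected to be close to [3.65, 10.0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike gradient descent, this solves for the weights in a single step and needs no learning rate or feature scaling, at the cost of solving a linear system whose size grows with the number of features."
]
},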
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uKmBKodpgHEE"
},
"outputs": [],
"source": [
"# Import packages\n",
"from sklearn.linear_model import SGDRegressor\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"colab_type": "code",
"id": "WuUQwD72NVAE",
"outputId": "c46cd73c-7ca4-4d57-ee4b-0fc636f5cde5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train: (75, 1)\n",
"y_train: (75,)\n",
"X_test: (25, 1)\n",
"y_test: (25,)\n"
]
}
],
"source": [
"# Split the data into train and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    df[\"X\"].values.reshape(-1, 1), df[\"y\"], test_size=args.test_size,\n",
"    random_state=args.seed)\n",
"print(\"X_train:\", X_train.shape)\n",
"print(\"y_train:\", y_train.shape)\n",
"print(\"X_test:\", X_test.shape)\n",
"print(\"y_test:\", y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MJVs6JF7trja"
},
"source": [
"We need to standardize our data (zero mean and unit variance) so that SGD behaves well and converges quickly."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "VlOYPD5GRjRC",
"outputId": "9d63c0d3-da44-487c-fb8c-c9879f2d97d3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean: [8.22952817e-17] -1.5617137213060536e-16\n",
"std: [1.] 0.9999999999999999\n"
]
}
],
"source": [
"# Fit the scalers on the training data (mean=0, std=1)\n",
"X_scaler = StandardScaler().fit(X_train)\n",
"y_scaler = StandardScaler().fit(y_train.values.reshape(-1,1))\n",
"\n",
"# Apply the scalers to the train and test sets\n",
"standardized_X_train = X_scaler.transform(X_train)\n",
"standardized_y_train = y_scaler.transform(y_train.values.reshape(-1,1)).ravel()\n",
"standardized_X_test = X_scaler.transform(X_test)\n",
"standardized_y_test = y_scaler.transform(y_test.values.reshape(-1,1)).ravel()\n",
"\n",
"# Check\n",
"print(\"mean:\", np.mean(standardized_X_train, axis=0),\n",
"      np.mean(standardized_y_train, axis=0)) # mean should be ~0\n",
"print(\"std:\", np.std(standardized_X_train, axis=0),\n",
"      np.std(standardized_y_train, axis=0)) # std should be 1"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "CiE3oLCkOCEa"
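},
"outputs": [],
"source": [
"# Initialize the model\n",
"# Note: scikit-learn 1.0 renamed the \"squared_loss\" option to \"squared_error\".\n",
"lm = SGDRegressor(loss=\"squared_loss\", penalty=\"none\", max_iter=args.num_epochs)"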
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"colab_type": "code",
"id": "sGH_pQaDOb49",
"outputId": "ddfff892-314a-4d19-d8a3-549b5e6c555a"
},
"outputs": [
{
"data": {
"text/plain": [
"SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01,\n",
" fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',\n",
" loss='squared_loss', max_iter=100, n_iter=None, penalty='none',\n",
" power_t=0.25, random_state=None, shuffle=True, tol=None, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train\n",
"lm.fit(X=standardized_X_train, y=standardized_y_train)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fA8VkVVGPkTr"
},
"outputs": [],
"source": [
"# Predictions, converted back to the original (unstandardized) scale\n",
"pred_train = (lm.predict(standardized_X_train) * np.sqrt(y_scaler.var_)) + y_scaler.mean_\n",
"pred_test = (lm.predict(standardized_X_test) * np.sqrt(y_scaler.var_)) + y_scaler.mean_"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "T8Ws-khqJuNr"
},
"source": [
"# Evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Y2pha3VRWd2D"
},
"source": [
"There are several ways to evaluate how well our model performs."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "abGgfBbLVjJ_"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "RKm8IiP7O66e",
"outputId": "4fd58928-36da-4d96-9222-9658e1a0bd58"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_MSE: 0.00, test_MSE: 0.00\n"
]
}
],
"source": [
"# Train and test MSE\n",
"train_mse = np.mean((y_train - pred_train) ** 2)\n",
"test_mse = np.mean((y_test - pred_test) ** 2)\n",
"print(\"train_MSE: {0:.2f}, test_MSE: {1:.2f}\".format(train_mse, test_mse))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "TegkJM2-YKEq"
},
"source": [
"Besides MSE, when we have only one feature, we can also plot the model's predictions against the data to evaluate it visually."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 335
},
"colab_type": "code",
"id": "gH5N-U7YQVgn",
"outputId": "4a1ce429-18c5-45e8-e701-2cc7affcb94b"
},
"outputs": [
{
"data": {
"text/plain": [
"[Figure: two panels titled \"Train\" and \"Test\", each showing the data points and the fitted line (\"lm\") in red]"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Figure size\n",
"plt.figure(figsize=(15,5))\n",
"\n",
"# Plot the train data\n",
"plt.subplot(1, 2, 1)\n",
"plt.title(\"Train\")\n",
"plt.scatter(X_train, y_train, label=\"y_train\")\n",
"plt.plot(X_train, pred_train, color=\"red\", linewidth=1, linestyle=\"-\", label=\"lm\")\n",
"plt.legend(loc='lower right')\n",
"\n",
"# Plot the test data\n",
"plt.subplot(1, 2, 2)\n",
"plt.title(\"Test\")\n",
"plt.scatter(X_test, y_test, label=\"y_test\")\n",
"plt.plot(X_test, pred_test, color=\"red\", linewidth=1, linestyle=\"-\", label=\"lm\")\n",
"plt.legend(loc='lower right')\n",
"\n",
"# Show the plots\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xAP1EoQi86XB"
},
"source": [
"# Inference"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 159
},
"colab_type": "code",
"id": "K2yfNk3d8-Vj",
"outputId": "aef548b1-d0a1-4481-a9fe-4c975f278a3f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[10.00362276 13.65354177 17.30346078]\n"
]
},
{
"data": {
"text/plain": [
"     X      y\n",
"0  0.0  10.00\n",
"1  1.0  13.65\n",
"2  2.0  17.30"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Feed in our own inputs\n",
"X_infer = np.array((0, 1, 2), dtype=np.float32)\n",
"standardized_X_infer = X_scaler.transform(X_infer.reshape(-1, 1))\n",
"pred_infer = (lm.predict(standardized_X_infer) * np.sqrt(y_scaler.var_)) + y_scaler.mean_\n",
"print(pred_infer)\n",
"df.head(3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PHH0fYp_BYC5"
},
"source": [
"# Interpretability"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OhXo8CbPBZ-G"
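},
"source": [
"Linear regression is highly interpretable. Each feature has a coefficient that controls its impact on the output $y$. We can interpret the coefficient as follows: increasing $X$ by 1 increases $y$ by $W$ (~3.65).\n",
"\n",
"**Note**: Because we standardized our inputs and outputs for gradient descent, we need to unstandardize the coefficient and intercept before interpreting them. The procedure is shown below."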
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "JZxnrDuCBbK9",
"outputId": "adbe06d5-449a-4aa0-ebbb-ef5fdd5ef9e9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[3.64992205]\n",
"[10.00362489]\n"
]
}
],
"source": [
"# Unstandardize the coefficients\n",
"coef = lm.coef_ * (y_scaler.scale_/X_scaler.scale_)\n",
"intercept = lm.intercept_ * y_scaler.scale_ + y_scaler.mean_ - np.sum(coef*X_scaler.mean_)\n",
"print(coef) # ~3.65\n",
"print(intercept) # ~10"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yVmIP13u9s33"
},
"source": [
"### Proof for unstandardizing the coefficients:\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ViDPSLbR9v4B"
},
"source": [
"Note that both X and y were standardized.\n",
"\n",
"$\\frac{\\hat{y} - \\mathbb{E}[y]}{\\sigma_y} = W_0 + \\sum_{j=1}^{k}W_jz_j$\n",
"\n",
"$z_j = \\frac{x_j - \\bar{x}_j}{\\sigma_j}$\n",
"\n",
"$ \\hat{y}_{scaled} = \\frac{\\hat{y}_{unscaled} - \\bar{y}}{\\sigma_y} = \\hat{W_0} + \\sum_{j=1}^{k} \\hat{W}_j (\\frac{x_j - \\bar{x}_j}{\\sigma_j}) $\n",
"\n",
"$\\hat{y}_{unscaled} = \\hat{W}_0\\sigma_y + \\bar{y} - \\sum_{j=1}^{k} \\hat{W}_j(\\frac{\\sigma_y}{\\sigma_j})\\bar{x}_j + \\sum_{j=1}^{k}\\hat{W}_j(\\frac{\\sigma_y}{\\sigma_j})x_j $\n",
"\n",
"The first three terms form the unstandardized intercept, and $\\hat{W}_j\\frac{\\sigma_y}{\\sigma_j}$ is the unstandardized coefficient of feature $j$."
]
},
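{
"cell_type": "markdown",
"metadata": {},
"source": [
"The unstandardization step can be checked numerically on a small noise-free example. The sketch below is an illustration only: it fits the standardized data with `np.polyfit` instead of `SGDRegressor`, and the toy data and variable names are assumptions chosen so they do not clash with the notebook's own variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy data (assumed for illustration): y = 3.65 * x + 10.\n",
"us_x = np.arange(20, dtype=float)\n",
"us_y = 3.65 * us_x + 10\n",
"\n",
"# Standardize inputs and outputs (population std, as StandardScaler uses).\n",
"x_mean, x_std = us_x.mean(), us_x.std()\n",
"y_mean, y_std = us_y.mean(), us_y.std()\n",
"z_x = (us_x - x_mean) / x_std\n",
"z_y = (us_y - y_mean) / y_std\n",
"\n",
"# Fit a line in standardized space: z_y ~ W1 * z_x + W0.\n",
"W1, W0 = np.polyfit(z_x, z_y, deg=1)\n",
"\n",
"# Map the coefficients back to the original scale.\n",
"us_coef = W1 * (y_std / x_std)\n",
"us_intercept = W0 * y_std + y_mean - us_coef * x_mean\n",
"print(us_coef, us_intercept)  # expected to be close to 3.65 and 10.0"
]
},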
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rToCXKqeJcvj"
},
"source": [
"# Regularization"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "L4GFv8xRJmOZ"
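},
"source": [
"Regularization helps reduce overfitting. Below is L2 regularization (ridge regression). There are many forms of regularization, and all of them work to reduce overfitting. With L2 regularization, we penalize weights with large magnitudes. Large weights make the model rely heavily on the corresponding features, whereas we want the model to draw on all features rather than on a few heavily weighted ones. There are other types of regularization, such as L1 (lasso regression), which is useful for creating sparse models because it drives some weights to zero; we can also combine L2 and L1 regularization (elastic net).\n",
"\n",
"**Note**: Regularization is not just for linear regression. It can be applied to the weights of any model, including the ones we will learn about in future lessons."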
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "D_OcpRxF-Oj7"
},
"source": [
"* $ J(\\theta) = \\frac{1}{2}\\sum_{i}(X_iW - y_i)^2 + \\frac{\\lambda}{2}\\sum_{j} W_j^2$\n",
"* $ \\frac{\\partial{J}}{\\partial{W}} = X^T (\\hat{y} - y) + \\lambda W $\n",
"* $W = W- \\alpha\\frac{\\partial{J}}{\\partial{W}}$\n",
"\n",
"*where*:\n",
"* $\\lambda$ is the regularization coefficient"
]
},
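{
"cell_type": "markdown",
"metadata": {},
"source": [
"The L2-regularized update above can be sketched in NumPy. This is an illustration under stated assumptions rather than what `SGDRegressor` does internally: it uses full-batch gradient descent, averages the data-fit gradient over the $N$ samples, starts from zero weights so the run is deterministic, and omits the intercept because the toy data is centered. The helper name `fit_ridge_gd` and the values of `lam`, `lr`, and `num_epochs` are made up for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def fit_ridge_gd(X, y, lam=0.0, lr=0.1, num_epochs=2000):\n",
"    # Hypothetical helper: gradient descent on the L2-regularized loss.\n",
"    N, D = X.shape\n",
"    W = np.zeros((D, 1))\n",
"    for _ in range(num_epochs):\n",
"        y_hat = X @ W\n",
"        # Averaged data-fit gradient plus the L2 penalty gradient (lam * W).\n",
"        grad = X.T @ (y_hat - y) / N + lam * W\n",
"        W = W - lr * grad\n",
"    return W\n",
"\n",
"\n",
"# Standardized toy feature (assumed for illustration); the true weight is 1.\n",
"rg_X = np.linspace(-1.0, 1.0, 50).reshape(-1, 1)\n",
"rg_X = (rg_X - rg_X.mean()) / rg_X.std()\n",
"rg_y = 1.0 * rg_X\n",
"\n",
"W_plain = fit_ridge_gd(rg_X, rg_y, lam=0.0)\n",
"W_l2 = fit_ridge_gd(rg_X, rg_y, lam=0.1)\n",
"print(W_plain.ravel(), W_l2.ravel())  # the L2 weight is pulled toward zero"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With a standardized feature and this averaged loss, the regularized weight settles at $\\frac{1}{1 + \\lambda}$ times the unregularized one, which makes the shrinkage effect easy to see."
]
},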
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "HHaazL9f8QZX"
},
"outputs": [],
"source": [
"# Initialize the model with L2 regularization\n",
"lm = SGDRegressor(loss=\"squared_loss\", penalty='l2', alpha=1e-2,\n",
"                  max_iter=args.num_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"colab_type": "code",
"id": "VTIUZLbGZP4e",
"outputId": "e284d26b-6091-4ce9-b40c-5b9d675f0837"
},
"outputs": [
{
"data": {
"text/plain": [
"SGDRegressor(alpha=0.01, average=False, epsilon=0.1, eta0=0.01,\n",
" fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',\n",
" loss='squared_loss', max_iter=100, n_iter=None, penalty='l2',\n",
" power_t=0.25, random_state=None, shuffle=True, tol=None, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train\n",
"lm.fit(X=standardized_X_train, y=standardized_y_train)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ORwkUqcuZhbX"
},
"outputs": [],
"source": [
"# Predictions, converted back to the original (unstandardized) scale\n",
"pred_train = (lm.predict(standardized_X_train) * np.sqrt(y_scaler.var_)) + y_scaler.mean_\n",
"pred_test = (lm.predict(standardized_X_test) * np.sqrt(y_scaler.var_)) + y_scaler.mean_"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "IWCvYxBxZhd5",
"outputId": "a38bb230-eb97-43c1-d43d-8dead30dedcb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_MSE: 1.09, test_MSE: 1.15\n"
]
}
],
"source": [
"# Train and test MSE\n",
"train_mse = np.mean((y_train - pred_train) ** 2)\n",
"test_mse = np.mean((y_test - pred_test) ** 2)\n",
"print(\"train_MSE: {0:.2f}, test_MSE: {1:.2f}\".format(\n",
"    train_mse, test_mse))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mdNX2W5eh2ma"
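},
"source": [
"Regularization does not help on this particular dataset, because we generated the data from an exactly linear function. On real-world data, however, regularization can help the model generalize better."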
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "C2mrVS4UZp3Q",
"outputId": "5ba6159b-0791-4093-9ed9-8925386edf36"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[3.61384419]\n",
"[11.67785083]\n"
]
}
],
"source": [
"# Unstandardize the coefficients\n",
"coef = lm.coef_ * (y_scaler.scale_/X_scaler.scale_)\n",
"intercept = lm.intercept_ * y_scaler.scale_ + y_scaler.mean_ - np.sum(coef*X_scaler.mean_)\n",
"print(coef) # ~3.65, slightly shrunk by the L2 penalty\n",
"print(intercept) # ~10, shifted because the slope is shrunk"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "V74lNFE5v5pQ"
},
"source": [
"# Categorical variables"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2r6Xhyg7v5vX"
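},
"source": [
"In our example the feature was continuous, but what if we have categorical features? One option is to one-hot encode the categorical variables, which is easy to do with Pandas, and then train the model with the same steps as above."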
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"colab_type": "code",
"id": "unhcIOfMxQEQ",
"outputId": "ecf36ce0-28af-4381-b61b-2592a3c1b289"
},
"outputs": [
{
"data": {
"text/plain": [
"  favorite_letter\n",
"0               a\n",
"1               b\n",
"2               c\n",
"3               a"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create data with a categorical feature\n",
"cat_data = pd.DataFrame(['a', 'b', 'c', 'a'], columns=['favorite_letter'])\n",
"cat_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"colab_type": "code",
"id": "m4eQmJdrxQGr",
"outputId": "247aaac2-afcb-4899-e415-91d3fe169fbf"
},
"outputs": [
{
"data": {
"text/plain": [
"   favorite_letter_a  favorite_letter_b  favorite_letter_c\n",
"0                  1                  0                  0\n",
"1                  0                  1                  0\n",
"2                  0                  0                  1\n",
"3                  1                  0                  0"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One-hot encoding; note that this differs from dummy coding, which drops one category\n",
"dummy_cat_data = pd.get_dummies(cat_data)\n",
"dummy_cat_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "B5R8x-KyiBWJ"
},
"source": [
"Now you can concatenate these columns with the continuous features and train a linear model."
]
},
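{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concatenating one-hot columns with a continuous feature could look like the sketch below. The `height` column and its values are hypothetical, made up purely for illustration; `dtype=int` is passed so that the encoded columns are 0/1 integers regardless of the pandas version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical data: one continuous and one categorical feature.\n",
"mix_df = pd.DataFrame({\n",
"    'height': [1.2, 3.4, 5.6, 7.8],\n",
"    'favorite_letter': ['a', 'b', 'c', 'a'],\n",
"})\n",
"\n",
"# One-hot encode only the categorical column.\n",
"mix_onehot = pd.get_dummies(mix_df[['favorite_letter']], dtype=int)\n",
"\n",
"# Concatenate the continuous column with the one-hot columns.\n",
"mix_features = pd.concat([mix_df[['height']], mix_onehot], axis=1)\n",
"print(mix_features.columns.tolist())"
]
},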
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eVOXoCRsokzp"
},
"source": [
"# TODO"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4c7ttuUwfeLA"
},
"source": [
"- Polynomial regression\n",
"- A simple example using the normal equation (sklearn.linear_model.LinearRegression), with an analysis of its pros and cons compared with SGD linear regression."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "04_Linear_Regression",
"provenance": [],
"toc_visible": true,
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}