{ "cells": [ { "cell_type": "code", "execution_count": 316, "id": "df2d6b18-bd95-4c9e-a789-e8b79717b970", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib widget\n", "import seaborn as sns\n", "sns.set(style=\"whitegrid\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "3459b50f-5e94-4d74-b35c-ce4f77028ea3", "metadata": {}, "outputs": [], "source": [ "original_df=pd.read_excel('附件1/附件1.xlsx',sheet_name='Data')" ] }, { "cell_type": "code", "execution_count": 143, "id": "4476b11e-37f8-4796-ad32-9b72a6a5df75", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
eventidiyearimonthidayapproxdateextendedresolutioncountrycountry_txtregion...addnotesscite1scite2scite3dbsourceINT_LOGINT_IDEOINT_MISCINT_ANYrelated
0199801010001199811NaN0NaT34Burundi11...NaN“Burundi Rebels, Ex-Rwandan Army Soldiers Blam...“Burundi--Attack Reported on Bujumbura Airport...NaNCETIS0101NaN
1199801010002199811NaN0NaT167Russia9...NaN“Bomb injures 3 in Moscow subway system,” The ...“Bomb injures 3 in Moscow subway,” Charleston ...“Bomb Injures 3 Workers in Moscow Metro,” Los ...CETIS-9-90-9NaN
2199801010003199811NaN0NaT603United Kingdom8...NaN“Protestant gunmen kill Catholic in New Year's...“Ulster Peace Shattered by Shooting: Catholic ...NaNCETIS0011NaN
3199801020001199812NaN0NaT95Iraq10...NaN“Iraq Condemns Attack on UNSCOM Baghdad Office...Farouk Choukri , “Iraq, UN Officials Continue ...“Iraqi Interior Minister on UNSCOM Attack, Kuw...CETIS-9-911NaN
4199801020002199812NaN0NaT155West Bank and Gaza Strip10...NaN“Woman Shot,” The Philadelphia Inquirer, Janua...“Israeli Woman Critically Hurt by Gunfire in W...NaNCETIS-9-90-9NaN
\n", "

5 rows × 135 columns

\n", "
" ], "text/plain": [ " eventid iyear imonth iday approxdate extended resolution country \\\n", "0 199801010001 1998 1 1 NaN 0 NaT 34 \n", "1 199801010002 1998 1 1 NaN 0 NaT 167 \n", "2 199801010003 1998 1 1 NaN 0 NaT 603 \n", "3 199801020001 1998 1 2 NaN 0 NaT 95 \n", "4 199801020002 1998 1 2 NaN 0 NaT 155 \n", "\n", " country_txt region ... addnotes \\\n", "0 Burundi 11 ... NaN \n", "1 Russia 9 ... NaN \n", "2 United Kingdom 8 ... NaN \n", "3 Iraq 10 ... NaN \n", "4 West Bank and Gaza Strip 10 ... NaN \n", "\n", " scite1 \\\n", "0 “Burundi Rebels, Ex-Rwandan Army Soldiers Blam... \n", "1 “Bomb injures 3 in Moscow subway system,” The ... \n", "2 “Protestant gunmen kill Catholic in New Year's... \n", "3 “Iraq Condemns Attack on UNSCOM Baghdad Office... \n", "4 “Woman Shot,” The Philadelphia Inquirer, Janua... \n", "\n", " scite2 \\\n", "0 “Burundi--Attack Reported on Bujumbura Airport... \n", "1 “Bomb injures 3 in Moscow subway,” Charleston ... \n", "2 “Ulster Peace Shattered by Shooting: Catholic ... \n", "3 Farouk Choukri , “Iraq, UN Officials Continue ... \n", "4 “Israeli Woman Critically Hurt by Gunfire in W... \n", "\n", " scite3 dbsource INT_LOG \\\n", "0 NaN CETIS 0 \n", "1 “Bomb Injures 3 Workers in Moscow Metro,” Los ... CETIS -9 \n", "2 NaN CETIS 0 \n", "3 “Iraqi Interior Minister on UNSCOM Attack, Kuw... CETIS -9 \n", "4 NaN CETIS -9 \n", "\n", " INT_IDEO INT_MISC INT_ANY related \n", "0 1 0 1 NaN \n", "1 -9 0 -9 NaN \n", "2 0 1 1 NaN \n", "3 -9 1 1 NaN \n", "4 -9 0 -9 NaN \n", "\n", "[5 rows x 135 columns]" ] }, "execution_count": 143, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_df.head()" ] }, { "cell_type": "code", "execution_count": 144, "id": "8466c51b-024c-45e2-85b5-3be29aa52f82", "metadata": {}, "outputs": [], "source": [ "df=original_df.copy()" ] }, { "cell_type": "code", "execution_count": 157, "id": "d0fb1f2b-74f5-4196-938e-e3f906e4984a", "metadata": {}, "outputs": [], "source": [ "df[\"y-m\"]=df[\"iyear\"].map(str)+'-'+((df[\"imonth\"]-1)//3+1).map(str)" ] }, { "cell_type": "code", "execution_count": 158, "id": "bc14a61f-6523-4099-98e6-f080889cd449", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 2015-1\n", "1 2015-1\n", "2 2015-1\n", "3 2015-1\n", "4 2015-1\n", " ... \n", "39447 2017-4\n", "39448 2017-4\n", "39449 2017-4\n", "39450 2017-4\n", "39451 2017-4\n", "Name: y-m, Length: 39452, dtype: object" ] }, "execution_count": 158, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"y-m\"]" ] }, { "cell_type": "code", "execution_count": 186, "id": "246a2a4f-d100-4ec3-be38-50cb77190bcd", "metadata": {}, "outputs": [], "source": [ "df_ym_num=df.groupby(\"y-m\").size()" ] }, { "cell_type": "code", "execution_count": 378, "id": "c40da46b-9bd6-490b-8137-909f42ec947d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "y-m\n", "2015-1 4012\n", "2015-2 3761\n", "2015-3 3660\n", "2015-4 3532\n", "2016-1 3460\n", "2016-2 3629\n", "2016-3 3321\n", "2016-4 3177\n", "2017-1 2719\n", "2017-2 3023\n", "2017-3 2800\n", "2017-4 2358\n", "dtype: int64" ] }, "execution_count": 378, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_ym_num" ] }, { "cell_type": "code", "execution_count": 385, "id": "4ef2b298-abad-4492-935f-4961f7a4091d", "metadata": {}, "outputs": [], "source": [ "df_ym_num.to_csv('ym_num.csv')" ] }, { "cell_type": "code", "execution_count": 386, "id": "432312ce-79a1-4064-9b76-5ccf3a065a88", "metadata": {}, "outputs": [], "source": [ "class GrayForecast():\n", " # 初始化\n", " def __init__(self, data, n):\n", " \"\"\"\n", " :param data: Series/np/list\n", " :param n: 预测数量\n", " \"\"\"\n", " plt.rcParams['font.sans-serif'] = ['SimHei']\n", " plt.rcParams['axes.unicode_minus'] = False\n", " if isinstance(data, pd.Series):\n", " self.data = data.values\n", " elif isinstance(data, np.ndarray):\n", " self.data = data\n", " elif isinstance(data, list):\n", " self.data = np.array(data)\n", " self.level_check()\n", " self.GM_11_build_model(n)\n", " print(\"返回值为dataframe,可通过.res_df拿到, 可通过.plot_res画预测图\\n\", self.res_df)\n", "\n", " def level_check(self):\n", " # 数据级比校验\n", " b = self.data[0]\n", " n = len(self.data)\n", " lambda_k = np.zeros(n - 1)\n", " while (True):\n", " for i in range(n - 1):\n", " lambda_k[i] = self.data[i] / self.data[i + 1]\n", " if max(lambda_k) < np.exp(2 / (n + 2)) and min(lambda_k) > np.exp(-2 / (n + 1)):\n", " self.c = self.data[0] - b\n", " print(f\"完成数据 级比校验, 平移变换c={self.c}\")\n", " break\n", " else:\n", " self.data = self.data + 0.1\n", "\n", " # GM(1,1)建模\n", " def GM_11_build_model(self, n):\n", " '''\n", " 灰色预测\n", " x:序列,numpy对象\n", " n:需要往后预测的个数\n", " '''\n", " x = self.data\n", " # 累加生成(1-AGO)序列\n", " x1 = x.cumsum()\n", " # 紧邻均值生成序列\n", " z1 = (x1[:len(x1) - 1] + x1[1:]) / 2.0\n", " z1 = z1.reshape((len(z1), 1))\n", " B = np.append(-z1, np.ones_like(z1), axis=1)\n", " Y = x[1:].reshape((len(x) - 1, 1))\n", " # a为发展系数 b为灰色作用量\n", " [[a], [b]] = np.dot(np.dot(np.linalg.inv(np.dot(B.T, B)), B.T), Y) # 计算参数\n", " # 预测数据\n", " fit_res = [x[0]]\n", " for index in range(1, len(x) + n):\n", " fit_res.append((x[0] - b / a) * (1 - np.exp(a)) * np.exp(-a * (index)))\n", " # 数据还原\n", " self.data -= self.c\n", " fit_res -= self.c\n", " self.res_df = pd.concat([pd.DataFrame({'原始值': self.data}), pd.DataFrame({'预测值': fit_res})], axis=1)\n", " print(f\"发展系数a={a}, 灰色作用量b={b}\\n\")\n", " self.verfify(self.data, fit_res, a)\n", " return self.res_df\n", "\n", " def verfify(self, x, predict, a):\n", " S1_2 = x.var() # 原序列方差\n", " e = list() # 残差序列\n", " for index in range(x.shape[0]):\n", " e.append(x[index] - predict[index])\n", " S2_2 = np.array(e).var() # 残差方差\n", " C = S2_2 / S1_2 # 后验差比\n", " if C <= 0.35:\n", " assess = '后验差比<=0.35,模型精度等级为好'\n", " elif C <= 0.5:\n", " assess = '后验差比<=0.5,模型精度等级为合格'\n", " elif C <= 0.65:\n", " assess = '后验差比<=0.65,模型精度等级为勉强'\n", " else:\n", " assess = '后验差比>0.65,模型精度等级为不合格'\n", " print(f\"后验差比={C}, {assess} \\n\")\n", "\n", " # 级比偏差\n", " a_ = (1 - 0.5 * a) / (1 + 0.5 * a)\n", " delta = [np.nan]\n", " for i in range(x.shape[0] - 1):\n", " delta.append(1 - a_ * (x[i] / x[i + 1]))\n", "\n", " self.res_df = pd.concat([self.res_df, pd.DataFrame({'残差': e}),\n", " pd.DataFrame({'相对误差': list(map(lambda x: '{:.2%}'.format(x), np.abs(e / x)))}),\n", " pd.DataFrame({'级比偏差': delta})\n", " ],\n", " axis=1)\n", "\n", " def plot_res(self, xlabel='', ylabel=''):\n", " res_df = self.res_df\n", " f, ax = plt.subplots(figsize=(8, 5))\n", " sns.lineplot(x=res_df.index.tolist(), y=res_df['预测值'], linewidth=2, ax=ax)\n", " sns.scatterplot(x=res_df.index.tolist(), y=res_df['原始值'], s=60, color='r', marker='v', ax=ax)\n", " plt.fill_between(np.where(np.isnan(res_df[\"原始值\"]))[0], y1=min(plt.yticks()[0]), y2=max(plt.yticks()[0]),\n", " color='orange', alpha=0.2)\n", " ax.set_xlabel(xlabel, fontsize=15)\n", " ax.set_ylabel(ylabel, fontsize=15)\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": 387, "id": "698d2c86-a81e-45fc-ab9c-8bb05b70da42", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "完成数据 级比校验, 平移变换c=0.5999999999999979\n", "发展系数a=0.03244780698095551, 灰色作用量b=4.6945697959855766\n", "\n", "后验差比=0.11672033342303606, 后验差比<=0.35,模型精度等级为好 \n", "\n", "返回值为dataframe,可通过.res_df拿到, 可通过.plot_res画预测图\n", " 原始值 预测值 残差 相对误差 级比偏差\n", "0 4.012 4.012000 0.000000 0.00% NaN\n", "1 3.761 3.871975 -0.110975 2.95% -0.032677\n", "2 3.660 3.729198 -0.069198 1.89% 0.005215\n", "3 3.532 3.590980 -0.058980 1.67% -0.003153\n", "4 3.460 3.457174 0.002826 0.08% 0.011785\n", "5 3.629 3.327641 0.301359 8.30% 0.077012\n", "6 3.321 3.202243 0.118757 3.58% -0.057852\n", "7 3.177 3.080849 0.096151 3.03% -0.011949\n", "8 2.719 2.963330 -0.244330 8.99% -0.131136\n", "9 3.023 2.849564 0.173436 5.74% 0.129281\n", "10 2.800 2.739429 0.060571 2.16% -0.045170\n", "11 2.358 2.632811 -0.274811 11.65% -0.149532\n", "12 NaN 2.529597 NaN NaN NaN\n", "13 NaN 2.429679 NaN NaN NaN\n", "14 NaN 2.332950 NaN NaN NaN\n", "15 NaN 2.239310 NaN NaN NaN\n", "16 NaN 2.148659 NaN NaN NaN\n", "17 NaN 2.060902 NaN NaN NaN\n", "18 NaN 1.975948 NaN NaN NaN\n", "19 NaN 1.893705 NaN NaN NaN\n", "20 NaN 1.814089 NaN NaN NaN\n", "21 NaN 1.737014 NaN NaN NaN\n" ] } ], "source": [ "gm_model=GrayForecast(df_ym_num/1000,10)" ] }, { "cell_type": "code", "execution_count": 388, "id": "12874e49-b45d-4d7f-ad06-a6dbd6e2809d", "metadata": {}, "outputs": [], "source": [ "res_df=gm_model.res_df" ] }, { "cell_type": "code", "execution_count": 389, "id": "14877d21-3ff3-47b1-823e-f054e0c71bc1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
原始值预测值残差相对误差级比偏差
04.0124.0120000.0000000.00%NaN
13.7613.871975-0.1109752.95%-0.032677
23.6603.729198-0.0691981.89%0.005215
33.5323.590980-0.0589801.67%-0.003153
43.4603.4571740.0028260.08%0.011785
53.6293.3276410.3013598.30%0.077012
63.3213.2022430.1187573.58%-0.057852
73.1773.0808490.0961513.03%-0.011949
82.7192.963330-0.2443308.99%-0.131136
93.0232.8495640.1734365.74%0.129281
102.8002.7394290.0605712.16%-0.045170
112.3582.632811-0.27481111.65%-0.149532
12NaN2.529597NaNNaNNaN
13NaN2.429679NaNNaNNaN
14NaN2.332950NaNNaNNaN
15NaN2.239310NaNNaNNaN
16NaN2.148659NaNNaNNaN
17NaN2.060902NaNNaNNaN
18NaN1.975948NaNNaNNaN
19NaN1.893705NaNNaNNaN
20NaN1.814089NaNNaNNaN
21NaN1.737014NaNNaNNaN
\n", "
" ], "text/plain": [ " 原始值 预测值 残差 相对误差 级比偏差\n", "0 4.012 4.012000 0.000000 0.00% NaN\n", "1 3.761 3.871975 -0.110975 2.95% -0.032677\n", "2 3.660 3.729198 -0.069198 1.89% 0.005215\n", "3 3.532 3.590980 -0.058980 1.67% -0.003153\n", "4 3.460 3.457174 0.002826 0.08% 0.011785\n", "5 3.629 3.327641 0.301359 8.30% 0.077012\n", "6 3.321 3.202243 0.118757 3.58% -0.057852\n", "7 3.177 3.080849 0.096151 3.03% -0.011949\n", "8 2.719 2.963330 -0.244330 8.99% -0.131136\n", "9 3.023 2.849564 0.173436 5.74% 0.129281\n", "10 2.800 2.739429 0.060571 2.16% -0.045170\n", "11 2.358 2.632811 -0.274811 11.65% -0.149532\n", "12 NaN 2.529597 NaN NaN NaN\n", "13 NaN 2.429679 NaN NaN NaN\n", "14 NaN 2.332950 NaN NaN NaN\n", "15 NaN 2.239310 NaN NaN NaN\n", "16 NaN 2.148659 NaN NaN NaN\n", "17 NaN 2.060902 NaN NaN NaN\n", "18 NaN 1.975948 NaN NaN NaN\n", "19 NaN 1.893705 NaN NaN NaN\n", "20 NaN 1.814089 NaN NaN NaN\n", "21 NaN 1.737014 NaN NaN NaN" ] }, "execution_count": 389, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res_df" ] }, { "cell_type": "code", "execution_count": 384, "id": "5c198b7c-6239-4601-ae5c-c306316d3c59", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b81feb8c3f94da784bc6e6d199b2703", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gm_model.plot_res(xlabel=\"2015-2017季度\",ylabel=\"案件数量\")" ] }, { "cell_type": "markdown", "id": "746566e1-419b-4ea7-9dc7-400780a4464e", "metadata": {}, "source": [ "## 线性拟合\n" ] }, { "cell_type": "code", "execution_count": 285, "id": "ebb303ae-25ae-4c46-b959-0964db15ca1b", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LinearRegression\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.metrics import mean_squared_error" ] }, { "cell_type": "code", "execution_count": 286, "id": "ef866073-3198-4a44-b511-6ae210f7dba2", "metadata": {}, "outputs": [], "source": [ "def polynomial_model(degree=1):\n", " polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)\n", " linear_regression = LinearRegression(normalize=True)\n", " pipeline = Pipeline([(\"polynomial_features\", polynomial_features),\n", " (\"linear_regression\", linear_regression)])\n", " return pipeline" ] }, { "cell_type": "code", "execution_count": 287, "id": "80615ee1-ae0c-4859-933a-9961cdd2c499", "metadata": {}, "outputs": [], "source": [ "X=np.arange(12).reshape(-1,1)\n", "Y=np.array(df_ym_num.values).reshape(-1,1)/1000" ] }, { "cell_type": "code", "execution_count": 288, "id": "6e87522e-d634-47c2-802c-d5afe2b9a2e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[4.012],\n", " [3.761],\n", " [3.66 ],\n", " [3.532],\n", " [3.46 ],\n", " [3.629],\n", " [3.321],\n", " [3.177],\n", " [2.719],\n", " [3.023],\n", " [2.8 ],\n", " [2.358]])" ] }, "execution_count": 288, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y" ] }, { "cell_type": "code", "execution_count": 289, "id": "9d829377-8145-4156-94c2-e65d1451cf93", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'model': Pipeline(steps=[('polynomial_features',\n", " PolynomialFeatures(degree=1, include_bias=False)),\n", " ('linear_regression', LinearRegression(normalize=True))]),\n", " 'degree': 1,\n", " 'score': 0.8962555793414528,\n", " 'mse': 0.022496360916860938,\n", " 'RMSE': 0.14998786923235138},\n", " {'model': Pipeline(steps=[('polynomial_features',\n", " PolynomialFeatures(degree=3, include_bias=False)),\n", " ('linear_regression', LinearRegression(normalize=True))]),\n", " 'degree': 3,\n", " 'score': 0.9184471768135963,\n", " 'mse': 0.017684244921744915,\n", " 'RMSE': 0.13298212256444442},\n", " {'model': Pipeline(steps=[('polynomial_features',\n", " PolynomialFeatures(degree=5, include_bias=False)),\n", " ('linear_regression', LinearRegression(normalize=True))]),\n", " 'degree': 5,\n", " 'score': 0.9396800629966126,\n", " 'mse': 0.013080019770670167,\n", " 'RMSE': 0.11436791407851316},\n", " {'model': Pipeline(steps=[('polynomial_features',\n", " PolynomialFeatures(degree=10, include_bias=False)),\n", " ('linear_regression', LinearRegression(normalize=True))]),\n", " 'degree': 10,\n", " 'score': 0.9865833333202765,\n", " 'mse': 0.0029093244148683333,\n", " 'RMSE': 0.05393815361011474}]" ] }, "execution_count": 289, "metadata": {}, "output_type": "execute_result" } ], "source": [ "degrees = [1, 3, 5, 10]\n", "results = []\n", "for d in degrees:\n", " model = polynomial_model(degree=d)\n", " model.fit(X, Y)\n", " train_score = model.score(X, Y)\n", " mse = mean_squared_error(Y, model.predict(X))\n", " results.append({\"model\": model, \"degree\": d, \"score\": train_score, \"mse\": mse,'RMSE':np.sqrt(mse)})\n", "results" ] }, { "cell_type": "code", "execution_count": 310, "id": "5e06cc3d-c9ed-4ffe-8692-bad427984be9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f8b3bdc31ffa43d8888c301c26dcb0e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib.figure import SubplotParams\n", "plt.figure(figsize=(9, 6), dpi=200, subplotpars=SubplotParams(hspace=0.3))\n", "all_x=np.arange(22).reshape(-1,1)\n", "for i, r in enumerate(results):\n", " fig = plt.subplot(2, 2, i + 1)\n", " plt.xlim(0, 22)\n", " plt.title(\"LinearRegression degree={}\".format(r[\"degree\"]))\n", " plt.scatter(X, Y, s=5, c='b', alpha=0.5)\n", " plt.plot(all_x, r[\"model\"].predict(all_x), 'r-')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 311, "id": "c1e4891a-909a-4b6b-b10b-edfb4cd6b615", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "灰色预测 {'mse': 0.025310297904259395, 'RMSE': 0.15909210509720273}\n" ] } ], "source": [ "mse = mean_squared_error(res_df['原始值'][:12], res_df['预测值'][:12])\n", "print(\"灰色预测\",{\"mse\": mse,'RMSE':np.sqrt(mse)})" ] }, { "cell_type": "markdown", "id": "c4f60b7d-d0ac-4006-b915-23d7f4257c39", "metadata": {}, "source": [ "## 幂函数拟合" ] }, { "cell_type": "code", "execution_count": 312, "id": "1dd975f7-2e5e-4489-9636-0d73a997cfb2", "metadata": {}, "outputs": [], "source": [ "from scipy.optimize import curve_fit" ] }, { "cell_type": "code", "execution_count": 333, "id": "ff4488e7-2ad7-4d04-9900-59482f44920c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[4.700944628373652, 3.7009446283736516, 3.5019459261666057, 3.367654374534397, 3.263346740479471, 3.176851924048842, 3.1023313443550826, 3.036489460403104, 2.977266626492938, 2.923281727390908, 2.873559453844978, 2.8273831520939607]\n", "[4.012 3.761 3.66 3.532 3.46 3.629 3.321 3.177 2.719 3.023 2.8 2.358]\n", "12\n" ] } ], "source": [ "def fund(x, a, b):\n", " return -x**a + b\n", "xdata=np.arange(12)\n", "ydata=np.array(df_ym_num.values)/1000\n", "popt, pcov = curve_fit(fund, xdata, ydata)\n", "#popt数组中,三个值分别是待求参数a,b,c\n", "y2 = [fund(i, popt[0],popt[1]) for i in xdata]\n", "print(y2)\n", "print(ydata)\n", "print(len(y2))" ] }, { "cell_type": "code", "execution_count": 334, "id": "5a5ea1c1-3896-434f-91a2-eaa2cf64ea85", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "806a671697fe4232b76bb35408d122bc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib widget\n", "plt.scatter(xdata, ydata, s=5, c='b', alpha=0.5)\n", "plt.plot(xdata,y2,'r-')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 337, "id": "284e99f8-353a-4f15-838d-a5bd584fe171", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "幂函数预测 {'mse': 0.09527403535238894, 'RMSE': 0.308664924072025}\n" ] } ], "source": [ "mse = mean_squared_error(ydata, y2)\n", "print(\"幂函数预测\",{\"mse\": mse,'RMSE':np.sqrt(mse)})" ] }, { "cell_type": "code", "execution_count": null, "id": "65e96ce8-bd66-400b-845c-3a8ea7a37908", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }