{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 导入第三方包"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:52:16.629173Z",
"start_time": "2021-03-15T00:52:16.621194Z"
}
},
"outputs": [],
"source": [
"import os\n",
"import gc\n",
"import math\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import lightgbm as lgb\n",
"import xgboost as xgb\n",
"from catboost import CatBoostRegressor\n",
"from sklearn.linear_model import SGDRegressor, LinearRegression, Ridge\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"\n",
"from sklearn.model_selection import StratifiedKFold, KFold\n",
"from sklearn.metrics import log_loss\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 读取数据"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:52:22.085956Z",
"start_time": "2021-03-15T00:52:19.571864Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" id | \n",
" heartbeat_signals | \n",
" label | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0 | \n",
" 0.9912297987616655,0.9435330436439665,0.764677... | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" 0.9714822034884503,0.9289687459588268,0.572932... | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2 | \n",
" 1.0,0.9591487564065292,0.7013782792997189,0.23... | \n",
" 2.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 3 | \n",
" 0.9757952826275774,0.9340884687738161,0.659636... | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 4 | \n",
" 0.0,0.055816398940721094,0.26129357194994196,0... | \n",
" 2.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" id heartbeat_signals label\n",
"0 0 0.9912297987616655,0.9435330436439665,0.764677... 0.0\n",
"1 1 0.9714822034884503,0.9289687459588268,0.572932... 0.0\n",
"2 2 1.0,0.9591487564065292,0.7013782792997189,0.23... 2.0\n",
"3 3 0.9757952826275774,0.9340884687738161,0.659636... 0.0\n",
"4 4 0.0,0.055816398940721094,0.26129357194994196,0... 2.0"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train = pd.read_csv('train.csv')\n",
"test=pd.read_csv('testA.csv')\n",
"train.head()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:52:41.773931Z",
"start_time": "2021-03-15T00:52:41.760966Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" id | \n",
" heartbeat_signals | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 100000 | \n",
" 0.9915713654170097,1.0,0.6318163407681274,0.13... | \n",
"
\n",
" \n",
" | 1 | \n",
" 100001 | \n",
" 0.6075533139615096,0.5417083883163654,0.340694... | \n",
"
\n",
" \n",
" | 2 | \n",
" 100002 | \n",
" 0.9752726292239277,0.6710965234906665,0.686758... | \n",
"
\n",
" \n",
" | 3 | \n",
" 100003 | \n",
" 0.9956348033996116,0.9170249621481004,0.521096... | \n",
"
\n",
" \n",
" | 4 | \n",
" 100004 | \n",
" 1.0,0.8879490481178918,0.745564725322326,0.531... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" id heartbeat_signals\n",
"0 100000 0.9915713654170097,1.0,0.6318163407681274,0.13...\n",
"1 100001 0.6075533139615096,0.5417083883163654,0.340694...\n",
"2 100002 0.9752726292239277,0.6710965234906665,0.686758...\n",
"3 100003 0.9956348033996116,0.9170249621481004,0.521096...\n",
"4 100004 1.0,0.8879490481178918,0.745564725322326,0.531..."
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 数据预处理"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:53:20.837171Z",
"start_time": "2021-03-15T00:53:20.824203Z"
}
},
"outputs": [],
"source": [
"def reduce_mem_usage(df):\n",
" start_mem = df.memory_usage().sum() / 1024**2 \n",
" print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))\n",
" \n",
" for col in df.columns:\n",
" col_type = df[col].dtype\n",
" \n",
" if col_type != object:\n",
" c_min = df[col].min()\n",
" c_max = df[col].max()\n",
" if str(col_type)[:3] == 'int':\n",
" if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:\n",
" df[col] = df[col].astype(np.int8)\n",
" elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:\n",
" df[col] = df[col].astype(np.int16)\n",
" elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:\n",
" df[col] = df[col].astype(np.int32)\n",
" elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:\n",
" df[col] = df[col].astype(np.int64) \n",
" else:\n",
" if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:\n",
" df[col] = df[col].astype(np.float16)\n",
" elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:\n",
" df[col] = df[col].astype(np.float32)\n",
" else:\n",
" df[col] = df[col].astype(np.float64)\n",
" else:\n",
" df[col] = df[col].astype('category')\n",
"\n",
" end_mem = df.memory_usage().sum() / 1024**2 \n",
" print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))\n",
" print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) / start_mem))\n",
" \n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:54:18.244721Z",
"start_time": "2021-03-15T00:53:59.807775Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Memory usage of dataframe is 157.93 MB\n",
"Memory usage after optimization is: 39.67 MB\n",
"Decreased by 74.9%\n",
"Memory usage of dataframe is 31.43 MB\n",
"Memory usage after optimization is: 7.90 MB\n",
"Decreased by 74.9%\n"
]
}
],
"source": [
"# 简单预处理\n",
"train_list = []\n",
"\n",
"for items in train.values:\n",
" train_list.append([items[0]] + [float(i) for i in items[1].split(',')] + [items[2]])\n",
"\n",
"train = pd.DataFrame(np.array(train_list))\n",
"train.columns = ['id'] + ['s_'+str(i) for i in range(len(train_list[0])-2)] + ['label']\n",
"train = reduce_mem_usage(train)\n",
"\n",
"test_list=[]\n",
"for items in test.values:\n",
" test_list.append([items[0]] + [float(i) for i in items[1].split(',')])\n",
"\n",
"test = pd.DataFrame(np.array(test_list))\n",
"test.columns = ['id'] + ['s_'+str(i) for i in range(len(test_list[0])-1)]\n",
"test = reduce_mem_usage(test)\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:54:57.351196Z",
"start_time": "2021-03-15T00:54:57.321310Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" id | \n",
" s_0 | \n",
" s_1 | \n",
" s_2 | \n",
" s_3 | \n",
" s_4 | \n",
" s_5 | \n",
" s_6 | \n",
" s_7 | \n",
" s_8 | \n",
" ... | \n",
" s_196 | \n",
" s_197 | \n",
" s_198 | \n",
" s_199 | \n",
" s_200 | \n",
" s_201 | \n",
" s_202 | \n",
" s_203 | \n",
" s_204 | \n",
" label | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.0 | \n",
" 0.991211 | \n",
" 0.943359 | \n",
" 0.764648 | \n",
" 0.618652 | \n",
" 0.379639 | \n",
" 0.190796 | \n",
" 0.040222 | \n",
" 0.026001 | \n",
" 0.031708 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.0 | \n",
" 0.971680 | \n",
" 0.929199 | \n",
" 0.572754 | \n",
" 0.178467 | \n",
" 0.122986 | \n",
" 0.132324 | \n",
" 0.094421 | \n",
" 0.089600 | \n",
" 0.030487 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.0 | \n",
" 1.000000 | \n",
" 0.958984 | \n",
" 0.701172 | \n",
" 0.231812 | \n",
" 0.000000 | \n",
" 0.080688 | \n",
" 0.128418 | \n",
" 0.187500 | \n",
" 0.280762 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 3.0 | \n",
" 0.975586 | \n",
" 0.934082 | \n",
" 0.659668 | \n",
" 0.249878 | \n",
" 0.237061 | \n",
" 0.281494 | \n",
" 0.249878 | \n",
" 0.249878 | \n",
" 0.241455 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 4.0 | \n",
" 0.000000 | \n",
" 0.055817 | \n",
" 0.261230 | \n",
" 0.359863 | \n",
" 0.433105 | \n",
" 0.453613 | \n",
" 0.499023 | \n",
" 0.542969 | \n",
" 0.616699 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 207 columns
\n",
"
"
],
"text/plain": [
" id s_0 s_1 s_2 s_3 s_4 s_5 s_6 \\\n",
"0 0.0 0.991211 0.943359 0.764648 0.618652 0.379639 0.190796 0.040222 \n",
"1 1.0 0.971680 0.929199 0.572754 0.178467 0.122986 0.132324 0.094421 \n",
"2 2.0 1.000000 0.958984 0.701172 0.231812 0.000000 0.080688 0.128418 \n",
"3 3.0 0.975586 0.934082 0.659668 0.249878 0.237061 0.281494 0.249878 \n",
"4 4.0 0.000000 0.055817 0.261230 0.359863 0.433105 0.453613 0.499023 \n",
"\n",
" s_7 s_8 ... s_196 s_197 s_198 s_199 s_200 s_201 s_202 \\\n",
"0 0.026001 0.031708 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"1 0.089600 0.030487 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"2 0.187500 0.280762 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"3 0.249878 0.241455 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"4 0.542969 0.616699 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" s_203 s_204 label \n",
"0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 \n",
"2 0.0 0.0 2.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 2.0 \n",
"\n",
"[5 rows x 207 columns]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.head()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:55:36.644040Z",
"start_time": "2021-03-15T00:55:36.619678Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" id | \n",
" s_0 | \n",
" s_1 | \n",
" s_2 | \n",
" s_3 | \n",
" s_4 | \n",
" s_5 | \n",
" s_6 | \n",
" s_7 | \n",
" s_8 | \n",
" ... | \n",
" s_195 | \n",
" s_196 | \n",
" s_197 | \n",
" s_198 | \n",
" s_199 | \n",
" s_200 | \n",
" s_201 | \n",
" s_202 | \n",
" s_203 | \n",
" s_204 | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 100000.0 | \n",
" 0.991699 | \n",
" 1.000000 | \n",
" 0.631836 | \n",
" 0.136230 | \n",
" 0.041412 | \n",
" 0.102722 | \n",
" 0.120850 | \n",
" 0.123413 | \n",
" 0.107910 | \n",
" ... | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00000 | \n",
"
\n",
" \n",
" | 1 | \n",
" 100001.0 | \n",
" 0.607422 | \n",
" 0.541504 | \n",
" 0.340576 | \n",
" 0.000000 | \n",
" 0.090698 | \n",
" 0.164917 | \n",
" 0.195068 | \n",
" 0.168823 | \n",
" 0.198853 | \n",
" ... | \n",
" 0.389893 | \n",
" 0.386963 | \n",
" 0.367188 | \n",
" 0.364014 | \n",
" 0.360596 | \n",
" 0.357178 | \n",
" 0.350586 | \n",
" 0.350586 | \n",
" 0.350586 | \n",
" 0.36377 | \n",
"
\n",
" \n",
" | 2 | \n",
" 100002.0 | \n",
" 0.975098 | \n",
" 0.670898 | \n",
" 0.686523 | \n",
" 0.708496 | \n",
" 0.718750 | \n",
" 0.716797 | \n",
" 0.720703 | \n",
" 0.701660 | \n",
" 0.596680 | \n",
" ... | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00000 | \n",
"
\n",
" \n",
" | 3 | \n",
" 100003.0 | \n",
" 0.995605 | \n",
" 0.916992 | \n",
" 0.520996 | \n",
" 0.000000 | \n",
" 0.221802 | \n",
" 0.404053 | \n",
" 0.490479 | \n",
" 0.527344 | \n",
" 0.518066 | \n",
" ... | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00000 | \n",
"
\n",
" \n",
" | 4 | \n",
" 100004.0 | \n",
" 1.000000 | \n",
" 0.888184 | \n",
" 0.745605 | \n",
" 0.531738 | \n",
" 0.380371 | \n",
" 0.224609 | \n",
" 0.091125 | \n",
" 0.057648 | \n",
" 0.003914 | \n",
" ... | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00000 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 206 columns
\n",
"
"
],
"text/plain": [
" id s_0 s_1 s_2 s_3 s_4 s_5 \\\n",
"0 100000.0 0.991699 1.000000 0.631836 0.136230 0.041412 0.102722 \n",
"1 100001.0 0.607422 0.541504 0.340576 0.000000 0.090698 0.164917 \n",
"2 100002.0 0.975098 0.670898 0.686523 0.708496 0.718750 0.716797 \n",
"3 100003.0 0.995605 0.916992 0.520996 0.000000 0.221802 0.404053 \n",
"4 100004.0 1.000000 0.888184 0.745605 0.531738 0.380371 0.224609 \n",
"\n",
" s_6 s_7 s_8 ... s_195 s_196 s_197 s_198 \\\n",
"0 0.120850 0.123413 0.107910 ... 0.000000 0.000000 0.000000 0.000000 \n",
"1 0.195068 0.168823 0.198853 ... 0.389893 0.386963 0.367188 0.364014 \n",
"2 0.720703 0.701660 0.596680 ... 0.000000 0.000000 0.000000 0.000000 \n",
"3 0.490479 0.527344 0.518066 ... 0.000000 0.000000 0.000000 0.000000 \n",
"4 0.091125 0.057648 0.003914 ... 0.000000 0.000000 0.000000 0.000000 \n",
"\n",
" s_199 s_200 s_201 s_202 s_203 s_204 \n",
"0 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000 \n",
"1 0.360596 0.357178 0.350586 0.350586 0.350586 0.36377 \n",
"2 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000 \n",
"3 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000 \n",
"4 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000 \n",
"\n",
"[5 rows x 206 columns]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 训练数据/测试数据准备"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:56:15.971953Z",
"start_time": "2021-03-15T00:56:15.876344Z"
}
},
"outputs": [],
"source": [
"x_train = train.drop(['id','label'], axis=1)\n",
"y_train = train['label']\n",
"x_test=test.drop(['id'], axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 模型训练"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:57:01.757175Z",
"start_time": "2021-03-15T00:57:01.750341Z"
}
},
"outputs": [],
"source": [
"def abs_sum(y_pre,y_tru):\n",
" y_pre=np.array(y_pre)\n",
" y_tru=np.array(y_tru)\n",
" loss=sum(sum(abs(y_pre-y_tru)))\n",
" return loss\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:57:42.940805Z",
"start_time": "2021-03-15T00:57:42.928082Z"
}
},
"outputs": [],
"source": [
"def cv_model(clf, train_x, train_y, test_x, clf_name):\n",
" folds = 5\n",
" seed = 2021\n",
" kf = KFold(n_splits=folds, shuffle=True, random_state=seed)\n",
" test = np.zeros((test_x.shape[0],4))\n",
"\n",
" cv_scores = []\n",
" onehot_encoder = OneHotEncoder(sparse=False)\n",
" for i, (train_index, valid_index) in enumerate(kf.split(train_x, train_y)):\n",
" print('************************************ {} ************************************'.format(str(i+1)))\n",
" trn_x, trn_y, val_x, val_y = train_x.iloc[train_index], train_y[train_index], train_x.iloc[valid_index], train_y[valid_index]\n",
" \n",
" if clf_name == \"lgb\":\n",
" train_matrix = clf.Dataset(trn_x, label=trn_y)\n",
" valid_matrix = clf.Dataset(val_x, label=val_y)\n",
"\n",
" params = {\n",
" 'boosting_type': 'gbdt',\n",
" 'objective': 'multiclass',\n",
" 'num_class': 4,\n",
" 'num_leaves': 2 ** 5,\n",
" 'feature_fraction': 0.8,\n",
" 'bagging_fraction': 0.8,\n",
" 'bagging_freq': 4,\n",
" 'learning_rate': 0.1,\n",
" 'seed': seed,\n",
" 'nthread': 28,\n",
" 'n_jobs':24,\n",
" 'verbose': -1,\n",
" }\n",
"\n",
" model = clf.train(params, \n",
" train_set=train_matrix, \n",
" valid_sets=valid_matrix, \n",
" num_boost_round=2000, \n",
" verbose_eval=100, \n",
" early_stopping_rounds=200)\n",
" val_pred = model.predict(val_x, num_iteration=model.best_iteration)\n",
" test_pred = model.predict(test_x, num_iteration=model.best_iteration) \n",
" \n",
" val_y=np.array(val_y).reshape(-1, 1)\n",
" val_y = onehot_encoder.fit_transform(val_y)\n",
" print('预测的概率矩阵为:')\n",
" print(test_pred)\n",
" test += test_pred\n",
" score=abs_sum(val_y, val_pred)\n",
" cv_scores.append(score)\n",
" print(cv_scores)\n",
" print(\"%s_scotrainre_list:\" % clf_name, cv_scores)\n",
" print(\"%s_score_mean:\" % clf_name, np.mean(cv_scores))\n",
" print(\"%s_score_std:\" % clf_name, np.std(cv_scores))\n",
" test=test/kf.n_splits\n",
"\n",
" return test"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2021-03-15T00:58:22.378103Z",
"start_time": "2021-03-15T00:58:22.373222Z"
}
},
"outputs": [],
"source": [
"def lgb_model(x_train, y_train, x_test):\n",
" lgb_test = cv_model(lgb, x_train, y_train, x_test, \"lgb\")\n",
" return lgb_test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2021-03-15T00:53:32.384Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"************************************ 1 ************************************\n",
"Training until validation scores don't improve for 200 rounds\n",
"[100]\tvalid_0's multi_logloss: 0.0525735\n",
"[200]\tvalid_0's multi_logloss: 0.0422444\n",
"[300]\tvalid_0's multi_logloss: 0.0407076\n",
"[400]\tvalid_0's multi_logloss: 0.0420398\n",
"Early stopping, best iteration is:\n",
"[289]\tvalid_0's multi_logloss: 0.0405457\n",
"预测的概率矩阵为:\n",
"[[9.99969791e-01 2.85197261e-05 1.00341946e-06 6.85357631e-07]\n",
" [7.93287264e-05 7.69060914e-04 9.99151590e-01 2.00810971e-08]\n",
" [5.75356884e-07 5.04051497e-08 3.15322414e-07 9.99999059e-01]\n",
" ...\n",
" [6.79267940e-02 4.30206297e-04 9.31640185e-01 2.81516302e-06]\n",
" [9.99960477e-01 3.94098074e-05 8.34030725e-08 2.94638661e-08]\n",
" [9.88705846e-01 2.14081630e-03 6.67418381e-03 2.47915423e-03]]\n",
"[607.0736049372186]\n",
"************************************ 2 ************************************\n",
"[LightGBM] [Warning] num_threads is set with nthread=28, will be overridden by n_jobs=24. Current value: num_threads=24\n",
"Training until validation scores don't improve for 200 rounds\n",
"[100]\tvalid_0's multi_logloss: 0.0566626\n",
"[200]\tvalid_0's multi_logloss: 0.0450852\n"
]
}
],
"source": [
"lgb_test = lgb_model(x_train, y_train, x_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2021-03-15T00:53:33.065Z"
}
},
"outputs": [],
"source": [
"lgb_test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2021-03-15T00:53:33.810Z"
}
},
"outputs": [],
"source": [
"temp=pd.DataFrame(lgb_test)\n",
"temp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2021-03-15T00:53:34.680Z"
}
},
"outputs": [],
"source": [
"result=pd.read_csv('sample_submit.csv')\n",
"result['label_0']=temp[0]\n",
"result['label_1']=temp[1]\n",
"result['label_2']=temp[2]\n",
"result['label_3']=temp[3]\n",
"result.to_csv('submit.csv',index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}