{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split, RandomizedSearchCV, cross_val_score\n", "from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB\n", "from sklearn.svm import SVC\n", "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier, AdaBoostClassifier, GradientBoostingClassifier\n", "from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n", "import numpy as np\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Naive bayes" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BernoulliNB()" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import load_iris\n", "iris = load_iris()\n", "X_train, X_test, y_train, y_test = train_test_split(iris.data,\n", " iris.target,\n", " test_size=0.2,\n", " random_state=0)\n", "\n", "gauss_clf = GaussianNB()\n", "multi_clf = MultinomialNB()\n", "bernl_clf = BernoulliNB()\n", "\n", "gauss_clf.fit(X_train, y_train)\n", "multi_clf.fit(X_train, y_train)\n", "bernl_clf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "y_pred_gauss = gauss_clf.predict(X_test)\n", "y_pred_multi = multi_clf.predict(X_test)\n", "y_pred_bernl = bernl_clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 11\n", " 1 0.93 1.00 0.96 13\n", " 2 1.00 0.83 0.91 6\n", "\n", " accuracy 0.97 30\n", " macro avg 0.98 0.94 0.96 30\n", "weighted avg 0.97 0.97 0.97 30\n", "\n", " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 11\n", " 1 0.00 0.00 0.00 13\n", " 
2 0.32 1.00 0.48 6\n", "\n", " accuracy 0.57 30\n", " macro avg 0.44 0.67 0.49 30\n", "weighted avg 0.43 0.57 0.46 30\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.00 0.00 0.00 11\n", " 1 0.00 0.00 0.00 13\n", " 2 0.20 1.00 0.33 6\n", "\n", " accuracy 0.20 30\n", " macro avg 0.07 0.33 0.11 30\n", "weighted avg 0.04 0.20 0.07 30\n", "\n" ] } ], "source": [ "print(classification_report(y_test, y_pred_gauss))\n", "print(classification_report(y_test, y_pred_multi))\n", "print(classification_report(y_test, y_pred_bernl))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SVM" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_svmlight_file\n", "svc = SVC(kernel='rbf', random_state=101)\n", "X_train, y_train = load_svmlight_file('data_set/ijcnn1.bz2')" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SVC with rbf kernel -> cross validation accuracy: mean = 0.9625, std = 0.0185\n", "Wall time: 31.3 s\n" ] } ], "source": [ "%%time\n", "scores = cross_val_score(svc,\n", " X_train,\n", " y_train,\n", " cv=5,\n", " scoring='accuracy',\n", " n_jobs=-1)\n", "print(\n", " 'SVC with rbf kernel -> cross validation accuracy: mean = {:.4f}, std = {:.4f}'\n", " .format(np.mean(scores), np.std(scores)))" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best parameters {'gamma': 0.1, 'C': 100}\n", "Cross validation accuracy: mean= 0.9625\n", "Wall time: 9min 7s\n" ] } ], "source": [ "%%time\n", "svc_new = SVC(kernel='rbf', random_state=101)\n", "search_dict = {\n", " 'C': [0.01, 0.1, 1, 10, 100],\n", " 'gamma': [0.1, 0.01, 0.001, 0.0001]\n", "}\n", "search_func = RandomizedSearchCV(estimator=svc_new,\n", " param_distributions=search_dict,\n", " n_iter=10,\n", " scoring='accuracy',\n", " n_jobs=-1,\n", " 
refit=True,\n", " cv=5,\n", " random_state=101)\n", "search_func.fit(X_train, y_train)\n", "print('Best parameters {}'.format(search_func.best_params_))\n", "print('Cross validation accuracy: mean= {:.4f}'.format(search_func.best_score_))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "# Build the final model from the parameters selected by the search instead of hard-coding them\n", "svc_best = SVC(kernel='rbf', random_state=101, **search_func.best_params_)\n", "svc_best.fit(X_train, y_train)\n", "# Report the cross-validation accuracy of the selected parameters; `scores` still holds\n", "# the default-parameter SVC results from the earlier cell and must not be reused here\n", "best_idx = search_func.best_index_\n", "best_mean = search_func.cv_results_['mean_test_score'][best_idx]\n", "best_std = search_func.cv_results_['std_test_score'][best_idx]\n", "print(\n", " 'SVC with rbf kernel -> cross validation accuracy: mean = {:.4f}, std = {:.4f}'\n", " .format(best_mean, best_std))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RandomForest ExtraTrees" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(581012, 54)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import fetch_covtype\n", "covertype = fetch_covtype()\n", "covertype.data.shape" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "covertype_x = covertype.data\n", "covertype_y = covertype.target\n", "covertype_x_train, covertype_x_test_val, covertype_y_train, covertype_y_test_val = train_test_split(\n", " covertype_x, covertype_y, test_size=0.4, random_state=42)\n", "covertype_x_test, covertype_x_val, covertype_y_test, covertype_y_val = train_test_split(\n", " covertype_x_test_val, covertype_y_test_val, test_size=0.5, random_state=42)\n", "covertypes = [\n", " 'Spruce/Fir', 'Lodgepole Pine', 'Ponderosa Pine', 'Cottonwood/Willow',\n", " 'Aspen', 'Douglas-fir', 'Krummholz'\n", "]" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(348607, 54)\n", "(116203, 
54)\n", "(116202, 54)\n" ] } ], "source": [ "print(covertype_x_train.shape)\n", "print(covertype_x_val.shape)\n", "print(covertype_x_test.shape)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RandomForestClassifier -> cross validation accurary: mean = 0.9431, std = 0.0007\n", "Wall time: 2min 41s\n" ] } ], "source": [ "%%time\n", "rfc = RandomForestClassifier(n_estimators=100, random_state=101)\n", "scores = cross_val_score(rfc,\n", " covertype_x_train,\n", " covertype_y_train,\n", " cv=5,\n", " scoring='accuracy',\n", " n_jobs=-1)\n", "print(\n", " 'RandomForestClassifier -> cross validation accurary: mean = {:.4f}, std = {:.4f}'\n", " .format(np.mean(scores), np.std(scores)))" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.94251456, 0.94254324, 0.94264282, 0.94429225, 0.94356076])" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ExtraTreesClassifier -> cross validation accurary: mean = 0.9426, std = 0.0008\n", "Wall time: 2min 49s\n" ] } ], "source": [ "%%time\n", "etc = ExtraTreesClassifier(n_estimators=100, random_state=101)\n", "scores = cross_val_score(etc,\n", " covertype_x_train,\n", " covertype_y_train,\n", " cv=5,\n", " scoring='accuracy',\n", " n_jobs=-1)\n", "print(\n", " 'ExtraTreesClassifier -> cross validation accurary: mean = {:.4f}, std = {:.4f}'\n", " .format(np.mean(scores), np.std(scores)))" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.94211296, 0.94245719, 0.94202608, 0.94426356, 0.94219819])" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## CalibratedClassifierCV\n", "\n", "Compare the random forest's raw class probabilities with sigmoid-calibrated probabilities from `CalibratedClassifierCV` (5-fold)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from sklearn.calibration import CalibratedClassifierCV\n", "from sklearn.calibration import calibration_curve" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "rfc = RandomForestClassifier(n_estimators=100, random_state=101)\n", "calibration = CalibratedClassifierCV(rfc, method='sigmoid', cv=5)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [], "source": [ "# `calibration` fits its own clones of `rfc` on each CV fold, so `rfc` itself must be\n", "# fitted separately to obtain the raw, uncalibrated probabilities\n", "rfc.fit(covertype_x_train, covertype_y_train)\n", "calibration.fit(covertype_x_train, covertype_y_train)\n", "prob_raw = rfc.predict_proba(covertype_x_test)\n", "prob_cal = calibration.predict_proba(covertype_x_test)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
rawcalibrated
00.000.001132
10.420.356083
20.000.001131
30.000.001130
40.000.001130
.........
1161970.000.001131
1161980.000.001128
1161990.320.079379
1162000.000.001134
1162010.000.001130
\n", "

116202 rows × 2 columns

\n", "
" ], "text/plain": [ " raw calibrated\n", "0 0.00 0.001132\n", "1 0.42 0.356083\n", "2 0.00 0.001131\n", "3 0.00 0.001130\n", "4 0.00 0.001130\n", "... ... ...\n", "116197 0.00 0.001131\n", "116198 0.00 0.001128\n", "116199 0.32 0.079379\n", "116200 0.00 0.001134\n", "116201 0.00 0.001130\n", "\n", "[116202 rows x 2 columns]" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%matplotlib inline\n", "tree_kind = covertypes.index('Ponderosa Pine')\n", "probs = pd.DataFrame(list(zip(prob_raw[:, tree_kind], prob_cal[:, tree_kind])),\n", " columns=['raw', 'calibrated'])\n", "probs" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot = probs.plot(kind='scatter', x=0, y=1, s=64, c='blue', edgecolors='white')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## AdaBoost" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adaboost -> cross validation accurary: mean = 0.4493, std = 0.0587\n", "Wall time: 6min 1s\n" ] } ], "source": [ "%%time\n", "adbc = AdaBoostClassifier(n_estimators=300, random_state=101)\n", "scores = cross_val_score(adbc,\n", " covertype_x_train,\n", " covertype_y_train,\n", " cv=5,\n", " scoring='accuracy',\n", " n_jobs=-1)\n", "print(\n", " 'Adaboost -> cross validation accurary: mean = {:.4f}, std = {:.4f}'.format(\n", " np.mean(scores), np.std(scores)))" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.38596139, 0.40148016, 0.42959797, 0.48293914, 0.54637771])" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GradientBoost" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wall time: 19min 18s\n" ] }, { "data": { "text/plain": [ "GradientBoostingClassifier(max_depth=5, n_estimators=50, random_state=101)" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "gbc = GradientBoostingClassifier(max_depth=5, n_estimators=50, random_state=101)\n", "gbc.fit(covertype_x_train, covertype_y_train)" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Score the gradient-boosting model `gbc` fitted above on the validation split\n", "accuracy_score(covertype_y_val, gbc.predict(covertype_x_val))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "accuracy_score(covertype_y_test, gbc.predict(covertype_x_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## XGBoost" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [], "source": [ "import xgboost as xgb\n", "xgb_model = xgb.XGBClassifier(objective='multi:softprob',\n", " max_depth=24,\n", " gamma=0.1,\n", " subsample=0.9,\n", " learning_rate=0.01,\n", " n_estimators=500,\n", " n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[20:01:03] WARNING: C:\\Users\\Administrator\\workspace\\xgboost-win64_release_1.1.0\\src\\learner.cc:480: \n", "Parameters: { object } might not be used.\n", "\n", " This may not be accurate due to some parameters are only used in language bindings but\n", " passed down to XGBoost core. Or some parameters are not used but slip through this\n", " verification. 
Please open an issue if you find above cases.\n", "\n", "\n", "[0]\tvalidation_0-merror:0.07679\n", "Will train until validation_0-merror hasn't improved in 25 rounds.\n", "[1]\tvalidation_0-merror:0.06606\n", "[2]\tvalidation_0-merror:0.06125\n", "[3]\tvalidation_0-merror:0.05930\n", "[4]\tvalidation_0-merror:0.05846\n", "[5]\tvalidation_0-merror:0.05813\n", "[6]\tvalidation_0-merror:0.05737\n", "[7]\tvalidation_0-merror:0.05706\n", "[8]\tvalidation_0-merror:0.05647\n", "[9]\tvalidation_0-merror:0.05622\n", "[10]\tvalidation_0-merror:0.05571\n", "[11]\tvalidation_0-merror:0.05525\n", "[12]\tvalidation_0-merror:0.05520\n", "[13]\tvalidation_0-merror:0.05461\n", "[14]\tvalidation_0-merror:0.05428\n", "[15]\tvalidation_0-merror:0.05423\n", "[16]\tvalidation_0-merror:0.05384\n", "[17]\tvalidation_0-merror:0.05371\n", "[18]\tvalidation_0-merror:0.05373\n", "[19]\tvalidation_0-merror:0.05355\n", "[20]\tvalidation_0-merror:0.05331\n", "[21]\tvalidation_0-merror:0.05326\n", "[22]\tvalidation_0-merror:0.05323\n", "[23]\tvalidation_0-merror:0.05309\n", "[24]\tvalidation_0-merror:0.05298\n", "[25]\tvalidation_0-merror:0.05294\n", "[26]\tvalidation_0-merror:0.05270\n", "[27]\tvalidation_0-merror:0.05259\n", "[28]\tvalidation_0-merror:0.05238\n", "[29]\tvalidation_0-merror:0.05220\n", "[30]\tvalidation_0-merror:0.05217\n", "[31]\tvalidation_0-merror:0.05187\n", "[32]\tvalidation_0-merror:0.05174\n", "[33]\tvalidation_0-merror:0.05150\n", "[34]\tvalidation_0-merror:0.05134\n", "[35]\tvalidation_0-merror:0.05123\n", "[36]\tvalidation_0-merror:0.05101\n", "[37]\tvalidation_0-merror:0.05097\n", "[38]\tvalidation_0-merror:0.05085\n", "[39]\tvalidation_0-merror:0.05071\n", "[40]\tvalidation_0-merror:0.05063\n", "[41]\tvalidation_0-merror:0.05036\n", "[42]\tvalidation_0-merror:0.05034\n", "[43]\tvalidation_0-merror:0.05027\n", "[44]\tvalidation_0-merror:0.05007\n", "[45]\tvalidation_0-merror:0.05005\n", "[46]\tvalidation_0-merror:0.04995\n", "[47]\tvalidation_0-merror:0.04983\n", 
"[48]\tvalidation_0-merror:0.04978\n", "[49]\tvalidation_0-merror:0.04973\n", "[50]\tvalidation_0-merror:0.04947\n", "[51]\tvalidation_0-merror:0.04936\n", "[52]\tvalidation_0-merror:0.04921\n", "[53]\tvalidation_0-merror:0.04908\n", "[54]\tvalidation_0-merror:0.04894\n", "[55]\tvalidation_0-merror:0.04877\n", "[56]\tvalidation_0-merror:0.04866\n", "[57]\tvalidation_0-merror:0.04861\n", "[58]\tvalidation_0-merror:0.04847\n", "[59]\tvalidation_0-merror:0.04835\n", "[60]\tvalidation_0-merror:0.04824\n", "[61]\tvalidation_0-merror:0.04821\n", "[62]\tvalidation_0-merror:0.04808\n", "[63]\tvalidation_0-merror:0.04806\n", "[64]\tvalidation_0-merror:0.04791\n", "[65]\tvalidation_0-merror:0.04780\n", "[66]\tvalidation_0-merror:0.04784\n", "[67]\tvalidation_0-merror:0.04778\n", "[68]\tvalidation_0-merror:0.04767\n", "[69]\tvalidation_0-merror:0.04750\n", "[70]\tvalidation_0-merror:0.04750\n", "[71]\tvalidation_0-merror:0.04736\n", "[72]\tvalidation_0-merror:0.04727\n", "[73]\tvalidation_0-merror:0.04718\n", "[74]\tvalidation_0-merror:0.04713\n", "[75]\tvalidation_0-merror:0.04707\n", "[76]\tvalidation_0-merror:0.04705\n", "[77]\tvalidation_0-merror:0.04698\n", "[78]\tvalidation_0-merror:0.04694\n", "[79]\tvalidation_0-merror:0.04675\n", "[80]\tvalidation_0-merror:0.04668\n", "[81]\tvalidation_0-merror:0.04660\n", "[82]\tvalidation_0-merror:0.04650\n", "[83]\tvalidation_0-merror:0.04640\n", "[84]\tvalidation_0-merror:0.04630\n", "[85]\tvalidation_0-merror:0.04623\n", "[86]\tvalidation_0-merror:0.04613\n", "[87]\tvalidation_0-merror:0.04598\n", "[88]\tvalidation_0-merror:0.04586\n", "[89]\tvalidation_0-merror:0.04583\n", "[90]\tvalidation_0-merror:0.04583\n", "[91]\tvalidation_0-merror:0.04573\n", "[92]\tvalidation_0-merror:0.04557\n", "[93]\tvalidation_0-merror:0.04546\n", "[94]\tvalidation_0-merror:0.04540\n", "[95]\tvalidation_0-merror:0.04529\n", "[96]\tvalidation_0-merror:0.04527\n", "[97]\tvalidation_0-merror:0.04522\n", "[98]\tvalidation_0-merror:0.04521\n", 
"[99]\tvalidation_0-merror:0.04514\n", "[100]\tvalidation_0-merror:0.04515\n", "[101]\tvalidation_0-merror:0.04514\n", "[102]\tvalidation_0-merror:0.04503\n", "[103]\tvalidation_0-merror:0.04501\n", "[104]\tvalidation_0-merror:0.04490\n", "[105]\tvalidation_0-merror:0.04480\n", "[106]\tvalidation_0-merror:0.04476\n", "[107]\tvalidation_0-merror:0.04468\n", "[108]\tvalidation_0-merror:0.04463\n", "[109]\tvalidation_0-merror:0.04452\n", "[110]\tvalidation_0-merror:0.04430\n", "[111]\tvalidation_0-merror:0.04428\n", "[112]\tvalidation_0-merror:0.04422\n", "[113]\tvalidation_0-merror:0.04419\n", "[114]\tvalidation_0-merror:0.04416\n", "[115]\tvalidation_0-merror:0.04403\n", "[116]\tvalidation_0-merror:0.04398\n", "[117]\tvalidation_0-merror:0.04397\n", "[118]\tvalidation_0-merror:0.04387\n", "[119]\tvalidation_0-merror:0.04374\n", "[120]\tvalidation_0-merror:0.04375\n", "[121]\tvalidation_0-merror:0.04372\n", "[122]\tvalidation_0-merror:0.04356\n", "[123]\tvalidation_0-merror:0.04346\n", "[124]\tvalidation_0-merror:0.04340\n", "[125]\tvalidation_0-merror:0.04333\n", "[126]\tvalidation_0-merror:0.04335\n", "[127]\tvalidation_0-merror:0.04322\n", "[128]\tvalidation_0-merror:0.04318\n", "[129]\tvalidation_0-merror:0.04313\n", "[130]\tvalidation_0-merror:0.04304\n", "[131]\tvalidation_0-merror:0.04290\n", "[132]\tvalidation_0-merror:0.04284\n", "[133]\tvalidation_0-merror:0.04269\n", "[134]\tvalidation_0-merror:0.04258\n", "[135]\tvalidation_0-merror:0.04260\n", "[136]\tvalidation_0-merror:0.04252\n", "[137]\tvalidation_0-merror:0.04240\n", "[138]\tvalidation_0-merror:0.04240\n", "[139]\tvalidation_0-merror:0.04234\n", "[140]\tvalidation_0-merror:0.04216\n", "[141]\tvalidation_0-merror:0.04208\n", "[142]\tvalidation_0-merror:0.04197\n", "[143]\tvalidation_0-merror:0.04184\n", "[144]\tvalidation_0-merror:0.04183\n", "[145]\tvalidation_0-merror:0.04177\n", "[146]\tvalidation_0-merror:0.04169\n", "[147]\tvalidation_0-merror:0.04163\n", "[148]\tvalidation_0-merror:0.04149\n", 
"[149]\tvalidation_0-merror:0.04148\n", "[150]\tvalidation_0-merror:0.04144\n", "[151]\tvalidation_0-merror:0.04137\n", "[152]\tvalidation_0-merror:0.04125\n", "[153]\tvalidation_0-merror:0.04120\n", "[154]\tvalidation_0-merror:0.04112\n", "[155]\tvalidation_0-merror:0.04102\n", "[156]\tvalidation_0-merror:0.04100\n", "[157]\tvalidation_0-merror:0.04089\n", "[158]\tvalidation_0-merror:0.04083\n", "[159]\tvalidation_0-merror:0.04078\n", "[160]\tvalidation_0-merror:0.04074\n", "[161]\tvalidation_0-merror:0.04073\n", "[162]\tvalidation_0-merror:0.04069\n", "[163]\tvalidation_0-merror:0.04061\n", "[164]\tvalidation_0-merror:0.04054\n", "[165]\tvalidation_0-merror:0.04049\n", "[166]\tvalidation_0-merror:0.04044\n", "[167]\tvalidation_0-merror:0.04032\n", "[168]\tvalidation_0-merror:0.04027\n", "[169]\tvalidation_0-merror:0.04026\n", "[170]\tvalidation_0-merror:0.04015\n", "[171]\tvalidation_0-merror:0.04009\n", "[172]\tvalidation_0-merror:0.04001\n", "[173]\tvalidation_0-merror:0.03998\n", "[174]\tvalidation_0-merror:0.03983\n", "[175]\tvalidation_0-merror:0.03989\n", "[176]\tvalidation_0-merror:0.03986\n", "[177]\tvalidation_0-merror:0.03977\n", "[178]\tvalidation_0-merror:0.03973\n", "[179]\tvalidation_0-merror:0.03963\n", "[180]\tvalidation_0-merror:0.03957\n", "[181]\tvalidation_0-merror:0.03956\n", "[182]\tvalidation_0-merror:0.03947\n", "[183]\tvalidation_0-merror:0.03944\n", "[184]\tvalidation_0-merror:0.03938\n", "[185]\tvalidation_0-merror:0.03930\n", "[186]\tvalidation_0-merror:0.03932\n", "[187]\tvalidation_0-merror:0.03920\n", "[188]\tvalidation_0-merror:0.03916\n", "[189]\tvalidation_0-merror:0.03904\n", "[190]\tvalidation_0-merror:0.03896\n", "[191]\tvalidation_0-merror:0.03890\n", "[192]\tvalidation_0-merror:0.03888\n", "[193]\tvalidation_0-merror:0.03885\n", "[194]\tvalidation_0-merror:0.03883\n", "[195]\tvalidation_0-merror:0.03882\n", "[196]\tvalidation_0-merror:0.03875\n", "[197]\tvalidation_0-merror:0.03870\n", "[198]\tvalidation_0-merror:0.03864\n", 
"[199]\tvalidation_0-merror:0.03861\n", "[200]\tvalidation_0-merror:0.03856\n", "[201]\tvalidation_0-merror:0.03855\n", "[202]\tvalidation_0-merror:0.03852\n", "[203]\tvalidation_0-merror:0.03842\n", "[204]\tvalidation_0-merror:0.03839\n", "[205]\tvalidation_0-merror:0.03831\n", "[206]\tvalidation_0-merror:0.03825\n", "[207]\tvalidation_0-merror:0.03816\n", "[208]\tvalidation_0-merror:0.03816\n", "[209]\tvalidation_0-merror:0.03816\n", "[210]\tvalidation_0-merror:0.03816\n", "[211]\tvalidation_0-merror:0.03806\n", "[212]\tvalidation_0-merror:0.03811\n", "[213]\tvalidation_0-merror:0.03808\n", "[214]\tvalidation_0-merror:0.03794\n", "[215]\tvalidation_0-merror:0.03791\n", "[216]\tvalidation_0-merror:0.03783\n", "[217]\tvalidation_0-merror:0.03783\n", "[218]\tvalidation_0-merror:0.03781\n", "[219]\tvalidation_0-merror:0.03776\n", "[220]\tvalidation_0-merror:0.03773\n", "[221]\tvalidation_0-merror:0.03761\n", "[222]\tvalidation_0-merror:0.03759\n", "[223]\tvalidation_0-merror:0.03754\n", "[224]\tvalidation_0-merror:0.03744\n", "[225]\tvalidation_0-merror:0.03743\n", "[226]\tvalidation_0-merror:0.03747\n", "[227]\tvalidation_0-merror:0.03742\n", "[228]\tvalidation_0-merror:0.03737\n", "[229]\tvalidation_0-merror:0.03734\n", "[230]\tvalidation_0-merror:0.03736\n", "[231]\tvalidation_0-merror:0.03735\n", "[232]\tvalidation_0-merror:0.03730\n", "[233]\tvalidation_0-merror:0.03724\n", "[234]\tvalidation_0-merror:0.03719\n", "[235]\tvalidation_0-merror:0.03719\n", "[236]\tvalidation_0-merror:0.03712\n", "[237]\tvalidation_0-merror:0.03712\n", "[238]\tvalidation_0-merror:0.03705\n", "[239]\tvalidation_0-merror:0.03698\n", "[240]\tvalidation_0-merror:0.03694\n", "[241]\tvalidation_0-merror:0.03690\n", "[242]\tvalidation_0-merror:0.03692\n", "[243]\tvalidation_0-merror:0.03694\n", "[244]\tvalidation_0-merror:0.03694\n", "[245]\tvalidation_0-merror:0.03689\n", "[246]\tvalidation_0-merror:0.03687\n", "[247]\tvalidation_0-merror:0.03686\n", "[248]\tvalidation_0-merror:0.03681\n", 
"[249]\tvalidation_0-merror:0.03677\n", "[250]\tvalidation_0-merror:0.03669\n", "[251]\tvalidation_0-merror:0.03668\n", "[252]\tvalidation_0-merror:0.03668\n", "[253]\tvalidation_0-merror:0.03663\n", "[254]\tvalidation_0-merror:0.03663\n", "[255]\tvalidation_0-merror:0.03661\n", "[256]\tvalidation_0-merror:0.03652\n", "[257]\tvalidation_0-merror:0.03650\n", "[258]\tvalidation_0-merror:0.03644\n", "[259]\tvalidation_0-merror:0.03639\n", "[260]\tvalidation_0-merror:0.03639\n", "[261]\tvalidation_0-merror:0.03635\n", "[262]\tvalidation_0-merror:0.03631\n", "[263]\tvalidation_0-merror:0.03625\n", "[264]\tvalidation_0-merror:0.03619\n", "[265]\tvalidation_0-merror:0.03613\n", "[266]\tvalidation_0-merror:0.03615\n", "[267]\tvalidation_0-merror:0.03608\n", "[268]\tvalidation_0-merror:0.03601\n", "[269]\tvalidation_0-merror:0.03605\n", "[270]\tvalidation_0-merror:0.03601\n", "[271]\tvalidation_0-merror:0.03600\n", "[272]\tvalidation_0-merror:0.03594\n", "[273]\tvalidation_0-merror:0.03590\n", "[274]\tvalidation_0-merror:0.03586\n", "[275]\tvalidation_0-merror:0.03586\n", "[276]\tvalidation_0-merror:0.03584\n", "[277]\tvalidation_0-merror:0.03583\n", "[278]\tvalidation_0-merror:0.03582\n", "[279]\tvalidation_0-merror:0.03579\n", "[280]\tvalidation_0-merror:0.03573\n", "[281]\tvalidation_0-merror:0.03574\n", "[282]\tvalidation_0-merror:0.03569\n", "[283]\tvalidation_0-merror:0.03570\n", "[284]\tvalidation_0-merror:0.03567\n", "[285]\tvalidation_0-merror:0.03569\n", "[286]\tvalidation_0-merror:0.03562\n", "[287]\tvalidation_0-merror:0.03554\n", "[288]\tvalidation_0-merror:0.03552\n", "[289]\tvalidation_0-merror:0.03542\n", "[290]\tvalidation_0-merror:0.03546\n", "[291]\tvalidation_0-merror:0.03541\n", "[292]\tvalidation_0-merror:0.03535\n", "[293]\tvalidation_0-merror:0.03537\n", "[294]\tvalidation_0-merror:0.03535\n", "[295]\tvalidation_0-merror:0.03527\n", "[296]\tvalidation_0-merror:0.03524\n", "[297]\tvalidation_0-merror:0.03527\n", "[298]\tvalidation_0-merror:0.03526\n", 
"[299]\tvalidation_0-merror:0.03527\n", "[300]\tvalidation_0-merror:0.03525\n", "[301]\tvalidation_0-merror:0.03524\n", "[302]\tvalidation_0-merror:0.03524\n", "[303]\tvalidation_0-merror:0.03522\n", "[304]\tvalidation_0-merror:0.03514\n", "[305]\tvalidation_0-merror:0.03511\n", "[306]\tvalidation_0-merror:0.03508\n", "[307]\tvalidation_0-merror:0.03506\n", "[308]\tvalidation_0-merror:0.03505\n", "[309]\tvalidation_0-merror:0.03506\n", "[310]\tvalidation_0-merror:0.03499\n", "[311]\tvalidation_0-merror:0.03496\n", "[312]\tvalidation_0-merror:0.03492\n", "[313]\tvalidation_0-merror:0.03490\n", "[314]\tvalidation_0-merror:0.03487\n", "[315]\tvalidation_0-merror:0.03487\n", "[316]\tvalidation_0-merror:0.03482\n", "[317]\tvalidation_0-merror:0.03484\n", "[318]\tvalidation_0-merror:0.03482\n", "[319]\tvalidation_0-merror:0.03479\n", "[320]\tvalidation_0-merror:0.03474\n", "[321]\tvalidation_0-merror:0.03471\n", "[322]\tvalidation_0-merror:0.03473\n", "[323]\tvalidation_0-merror:0.03474\n", "[324]\tvalidation_0-merror:0.03473\n", "[325]\tvalidation_0-merror:0.03472\n", "[326]\tvalidation_0-merror:0.03470\n", "[327]\tvalidation_0-merror:0.03466\n", "[328]\tvalidation_0-merror:0.03461\n", "[329]\tvalidation_0-merror:0.03458\n", "[330]\tvalidation_0-merror:0.03458\n", "[331]\tvalidation_0-merror:0.03452\n", "[332]\tvalidation_0-merror:0.03447\n", "[333]\tvalidation_0-merror:0.03439\n", "[334]\tvalidation_0-merror:0.03435\n", "[335]\tvalidation_0-merror:0.03429\n", "[336]\tvalidation_0-merror:0.03432\n", "[337]\tvalidation_0-merror:0.03428\n", "[338]\tvalidation_0-merror:0.03427\n", "[339]\tvalidation_0-merror:0.03429\n", "[340]\tvalidation_0-merror:0.03429\n", "[341]\tvalidation_0-merror:0.03425\n", "[342]\tvalidation_0-merror:0.03427\n", "[343]\tvalidation_0-merror:0.03423\n", "[344]\tvalidation_0-merror:0.03420\n", "[345]\tvalidation_0-merror:0.03418\n", "[346]\tvalidation_0-merror:0.03416\n", "[347]\tvalidation_0-merror:0.03411\n", "[348]\tvalidation_0-merror:0.03412\n", 
"[349]\tvalidation_0-merror:0.03410\n", "[350]\tvalidation_0-merror:0.03408\n", "[351]\tvalidation_0-merror:0.03408\n", "[352]\tvalidation_0-merror:0.03406\n", "[353]\tvalidation_0-merror:0.03404\n", "[354]\tvalidation_0-merror:0.03401\n", "[355]\tvalidation_0-merror:0.03396\n", "[356]\tvalidation_0-merror:0.03395\n", "[357]\tvalidation_0-merror:0.03393\n", "[358]\tvalidation_0-merror:0.03391\n", "[359]\tvalidation_0-merror:0.03388\n", "[360]\tvalidation_0-merror:0.03389\n", "[361]\tvalidation_0-merror:0.03387\n", "[362]\tvalidation_0-merror:0.03386\n", "[363]\tvalidation_0-merror:0.03382\n", "[364]\tvalidation_0-merror:0.03379\n", "[365]\tvalidation_0-merror:0.03373\n", "[366]\tvalidation_0-merror:0.03371\n", "[367]\tvalidation_0-merror:0.03371\n", "[368]\tvalidation_0-merror:0.03368\n", "[369]\tvalidation_0-merror:0.03364\n", "[370]\tvalidation_0-merror:0.03365\n", "[371]\tvalidation_0-merror:0.03363\n", "[372]\tvalidation_0-merror:0.03367\n", "[373]\tvalidation_0-merror:0.03358\n", "[374]\tvalidation_0-merror:0.03356\n", "[375]\tvalidation_0-merror:0.03357\n", "[376]\tvalidation_0-merror:0.03359\n", "[377]\tvalidation_0-merror:0.03350\n", "[378]\tvalidation_0-merror:0.03350\n", "[379]\tvalidation_0-merror:0.03351\n", "[380]\tvalidation_0-merror:0.03348\n", "[381]\tvalidation_0-merror:0.03345\n", "[382]\tvalidation_0-merror:0.03347\n", "[383]\tvalidation_0-merror:0.03346\n", "[384]\tvalidation_0-merror:0.03343\n", "[385]\tvalidation_0-merror:0.03344\n", "[386]\tvalidation_0-merror:0.03342\n", "[387]\tvalidation_0-merror:0.03337\n", "[388]\tvalidation_0-merror:0.03335\n", "[389]\tvalidation_0-merror:0.03337\n", "[390]\tvalidation_0-merror:0.03338\n", "[391]\tvalidation_0-merror:0.03335\n", "[392]\tvalidation_0-merror:0.03333\n", "[393]\tvalidation_0-merror:0.03330\n", "[394]\tvalidation_0-merror:0.03327\n", "[395]\tvalidation_0-merror:0.03326\n", "[396]\tvalidation_0-merror:0.03326\n", "[397]\tvalidation_0-merror:0.03326\n", "[398]\tvalidation_0-merror:0.03322\n", 
"[399]\tvalidation_0-merror:0.03327\n", "[400]\tvalidation_0-merror:0.03325\n", "[401]\tvalidation_0-merror:0.03322\n", "[402]\tvalidation_0-merror:0.03320\n", "[403]\tvalidation_0-merror:0.03318\n", "[404]\tvalidation_0-merror:0.03319\n", "[405]\tvalidation_0-merror:0.03313\n", "[406]\tvalidation_0-merror:0.03309\n", "[407]\tvalidation_0-merror:0.03310\n", "[408]\tvalidation_0-merror:0.03308\n", "[409]\tvalidation_0-merror:0.03303\n", "[410]\tvalidation_0-merror:0.03302\n", "[411]\tvalidation_0-merror:0.03304\n", "[412]\tvalidation_0-merror:0.03303\n", "[413]\tvalidation_0-merror:0.03304\n", "[414]\tvalidation_0-merror:0.03299\n", "[415]\tvalidation_0-merror:0.03297\n", "[416]\tvalidation_0-merror:0.03293\n", "[417]\tvalidation_0-merror:0.03293\n", "[418]\tvalidation_0-merror:0.03291\n", "[419]\tvalidation_0-merror:0.03292\n", "[420]\tvalidation_0-merror:0.03289\n", "[421]\tvalidation_0-merror:0.03291\n", "[422]\tvalidation_0-merror:0.03292\n", "[423]\tvalidation_0-merror:0.03293\n", "[424]\tvalidation_0-merror:0.03292\n", "[425]\tvalidation_0-merror:0.03289\n", "[426]\tvalidation_0-merror:0.03288\n", "[427]\tvalidation_0-merror:0.03290\n", "[428]\tvalidation_0-merror:0.03286\n", "[429]\tvalidation_0-merror:0.03283\n", "[430]\tvalidation_0-merror:0.03282\n", "[431]\tvalidation_0-merror:0.03281\n", "[432]\tvalidation_0-merror:0.03278\n", "[433]\tvalidation_0-merror:0.03279\n", "[434]\tvalidation_0-merror:0.03275\n", "[435]\tvalidation_0-merror:0.03276\n", "[436]\tvalidation_0-merror:0.03275\n", "[437]\tvalidation_0-merror:0.03271\n", "[438]\tvalidation_0-merror:0.03270\n", "[439]\tvalidation_0-merror:0.03270\n", "[440]\tvalidation_0-merror:0.03268\n", "[441]\tvalidation_0-merror:0.03267\n", "[442]\tvalidation_0-merror:0.03266\n", "[443]\tvalidation_0-merror:0.03265\n", "[444]\tvalidation_0-merror:0.03263\n", "[445]\tvalidation_0-merror:0.03265\n", "[446]\tvalidation_0-merror:0.03261\n", "[447]\tvalidation_0-merror:0.03261\n", "[448]\tvalidation_0-merror:0.03261\n", 
"[449]\tvalidation_0-merror:0.03261\n", "[450]\tvalidation_0-merror:0.03258\n", "[451]\tvalidation_0-merror:0.03257\n", "[452]\tvalidation_0-merror:0.03256\n", "[453]\tvalidation_0-merror:0.03254\n", "[454]\tvalidation_0-merror:0.03254\n", "[455]\tvalidation_0-merror:0.03255\n", "[456]\tvalidation_0-merror:0.03251\n", "[457]\tvalidation_0-merror:0.03248\n", "[458]\tvalidation_0-merror:0.03247\n", "[459]\tvalidation_0-merror:0.03247\n", "[460]\tvalidation_0-merror:0.03244\n", "[461]\tvalidation_0-merror:0.03241\n", "[462]\tvalidation_0-merror:0.03237\n", "[463]\tvalidation_0-merror:0.03236\n", "[464]\tvalidation_0-merror:0.03237\n", "[465]\tvalidation_0-merror:0.03237\n", "[466]\tvalidation_0-merror:0.03237\n", "[467]\tvalidation_0-merror:0.03231\n", "[468]\tvalidation_0-merror:0.03234\n", "[469]\tvalidation_0-merror:0.03234\n", "[470]\tvalidation_0-merror:0.03231\n", "[471]\tvalidation_0-merror:0.03230\n", "[472]\tvalidation_0-merror:0.03227\n", "[473]\tvalidation_0-merror:0.03223\n", "[474]\tvalidation_0-merror:0.03223\n", "[475]\tvalidation_0-merror:0.03219\n", "[476]\tvalidation_0-merror:0.03220\n", "[477]\tvalidation_0-merror:0.03220\n", "[478]\tvalidation_0-merror:0.03222\n", "[479]\tvalidation_0-merror:0.03224\n", "[480]\tvalidation_0-merror:0.03224\n", "[481]\tvalidation_0-merror:0.03223\n", "[482]\tvalidation_0-merror:0.03219\n", "[483]\tvalidation_0-merror:0.03218\n", "[484]\tvalidation_0-merror:0.03216\n", "[485]\tvalidation_0-merror:0.03216\n", "[486]\tvalidation_0-merror:0.03219\n", "[487]\tvalidation_0-merror:0.03219\n", "[488]\tvalidation_0-merror:0.03216\n", "[489]\tvalidation_0-merror:0.03216\n", "[490]\tvalidation_0-merror:0.03215\n", "[491]\tvalidation_0-merror:0.03215\n", "[492]\tvalidation_0-merror:0.03219\n", "[493]\tvalidation_0-merror:0.03211\n", "[494]\tvalidation_0-merror:0.03213\n", "[495]\tvalidation_0-merror:0.03213\n", "[496]\tvalidation_0-merror:0.03210\n", "[497]\tvalidation_0-merror:0.03208\n", "[498]\tvalidation_0-merror:0.03207\n", 
"[499]\tvalidation_0-merror:0.03205\n", "Wall time: 1h 37min 25s\n" ] }, { "data": { "text/plain": [ "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, gamma=0.1, gpu_id=-1,\n", " importance_type='gain', interaction_constraints='',\n", " learning_rate=0.01, max_delta_step=0, max_depth=24,\n", " min_child_weight=1, missing=nan, monotone_constraints='()',\n", " n_estimators=500, n_jobs=-1, nthread=-1, num_parallel_tree=1,\n", " object='multi:softprob', objective='multi:softprob',\n", " random_state=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=None,\n", " subsample=0.9, tree_method='exact', validate_parameters=1,\n", " verbosity=None)" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "xgb_model.fit(covertype_x_train,\n", " covertype_y_train,\n", " eval_set=[(covertype_x_val, covertype_y_val)],\n", " eval_metric='merror',\n", " early_stopping_rounds=25,\n", " verbose=True)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9671606340682605" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(covertype_y_test, xgb_model.predict(covertype_x_test))" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[40838, 1494, 0, 0, 17, 2, 64],\n", " [ 1025, 55248, 67, 0, 90, 43, 22],\n", " [ 0, 69, 6980, 35, 6, 104, 0],\n", " [ 0, 0, 54, 463, 0, 12, 0],\n", " [ 21, 276, 31, 0, 1637, 8, 0],\n", " [ 3, 85, 141, 10, 3, 3304, 0],\n", " [ 117, 17, 0, 0, 0, 0, 3916]], dtype=int64)" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "confusion_matrix(covertype_y_test, xgb_model.predict(covertype_x_test))" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision 
recall f1-score support\n", "\n", " 1 0.97 0.96 0.97 42415\n", " 2 0.97 0.98 0.97 56495\n", " 3 0.96 0.97 0.96 7194\n", " 4 0.91 0.88 0.89 529\n", " 5 0.93 0.83 0.88 1973\n", " 6 0.95 0.93 0.94 3546\n", " 7 0.98 0.97 0.97 4050\n", "\n", " accuracy 0.97 116202\n", " macro avg 0.95 0.93 0.94 116202\n", "weighted avg 0.97 0.97 0.97 116202\n", "\n" ] } ], "source": [ "print(\n", " classification_report(covertype_y_test,\n", " xgb_model.predict(covertype_x_test)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## XGBoost GPU vs CPU" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "import xgboost as xgb\n", "\n", "# NOTE: tree_method is left unset here (the fitted model's repr reports tree_method='exact'),\n", "# whereas xgb_gpu uses 'gpu_hist', so the timing comparison also compares split algorithms.\n", "xgb_cpu = xgb.XGBClassifier(objective='multi:softprob',\n", " max_depth=8,\n", " gamma=0.1,\n", " subsample=0.9,\n", " learning_rate=0.01,\n", " n_estimators=200,\n", " nthread=-1)" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[09:34:19] WARNING: D:\\Projects\\other_projects\\xgboost\\src\\learner.cc:529: \n", "Parameters: { object } might not be used.\n", "\n", " This may not be accurate due to some parameters are only used in language bindings but\n", " passed down to XGBoost core. Or some parameters are not used but slip through this\n", " verification. 
Please open an issue if you find above cases.\n", "\n", "\n", "[0]\tvalidation_0-merror:0.21399\n", "[1]\tvalidation_0-merror:0.21641\n", "[2]\tvalidation_0-merror:0.21486\n", "[3]\tvalidation_0-merror:0.21350\n", "[4]\tvalidation_0-merror:0.21427\n", "[5]\tvalidation_0-merror:0.21401\n", "[6]\tvalidation_0-merror:0.21372\n", "[7]\tvalidation_0-merror:0.21296\n", "[8]\tvalidation_0-merror:0.21235\n", "[9]\tvalidation_0-merror:0.21210\n", "[10]\tvalidation_0-merror:0.21214\n", "[11]\tvalidation_0-merror:0.21132\n", "[12]\tvalidation_0-merror:0.21171\n", "[13]\tvalidation_0-merror:0.21129\n", "[14]\tvalidation_0-merror:0.21112\n", "[15]\tvalidation_0-merror:0.21057\n", "[16]\tvalidation_0-merror:0.21017\n", "[17]\tvalidation_0-merror:0.21022\n", "[18]\tvalidation_0-merror:0.20971\n", "[19]\tvalidation_0-merror:0.20983\n", "[20]\tvalidation_0-merror:0.20920\n", "[21]\tvalidation_0-merror:0.20917\n", "[22]\tvalidation_0-merror:0.20898\n", "[23]\tvalidation_0-merror:0.20936\n", "[24]\tvalidation_0-merror:0.20937\n", "[25]\tvalidation_0-merror:0.20906\n", "[26]\tvalidation_0-merror:0.20910\n", "[27]\tvalidation_0-merror:0.20829\n", "[28]\tvalidation_0-merror:0.20810\n", "[29]\tvalidation_0-merror:0.20779\n", "[30]\tvalidation_0-merror:0.20757\n", "[31]\tvalidation_0-merror:0.20763\n", "[32]\tvalidation_0-merror:0.20725\n", "[33]\tvalidation_0-merror:0.20682\n", "[34]\tvalidation_0-merror:0.20703\n", "[35]\tvalidation_0-merror:0.20676\n", "[36]\tvalidation_0-merror:0.20666\n", "[37]\tvalidation_0-merror:0.20687\n", "[38]\tvalidation_0-merror:0.20660\n", "[39]\tvalidation_0-merror:0.20631\n", "[40]\tvalidation_0-merror:0.20617\n", "[41]\tvalidation_0-merror:0.20582\n", "[42]\tvalidation_0-merror:0.20579\n", "[43]\tvalidation_0-merror:0.20574\n", "[44]\tvalidation_0-merror:0.20528\n", "[45]\tvalidation_0-merror:0.20444\n", "[46]\tvalidation_0-merror:0.20448\n", "[47]\tvalidation_0-merror:0.20433\n", "[48]\tvalidation_0-merror:0.20384\n", 
"[49]\tvalidation_0-merror:0.20346\n", "[50]\tvalidation_0-merror:0.20321\n", "[51]\tvalidation_0-merror:0.20327\n", "[52]\tvalidation_0-merror:0.20314\n", "[53]\tvalidation_0-merror:0.20298\n", "[54]\tvalidation_0-merror:0.20292\n", "[55]\tvalidation_0-merror:0.20249\n", "[56]\tvalidation_0-merror:0.20237\n", "[57]\tvalidation_0-merror:0.20232\n", "[58]\tvalidation_0-merror:0.20205\n", "[59]\tvalidation_0-merror:0.20184\n", "[60]\tvalidation_0-merror:0.20129\n", "[61]\tvalidation_0-merror:0.20129\n", "[62]\tvalidation_0-merror:0.20108\n", "[63]\tvalidation_0-merror:0.20096\n", "[64]\tvalidation_0-merror:0.20069\n", "[65]\tvalidation_0-merror:0.20050\n", "[66]\tvalidation_0-merror:0.20006\n", "[67]\tvalidation_0-merror:0.19992\n", "[68]\tvalidation_0-merror:0.19953\n", "[69]\tvalidation_0-merror:0.19928\n", "[70]\tvalidation_0-merror:0.19912\n", "[71]\tvalidation_0-merror:0.19871\n", "[72]\tvalidation_0-merror:0.19857\n", "[73]\tvalidation_0-merror:0.19834\n", "[74]\tvalidation_0-merror:0.19804\n", "[75]\tvalidation_0-merror:0.19808\n", "[76]\tvalidation_0-merror:0.19775\n", "[77]\tvalidation_0-merror:0.19771\n", "[78]\tvalidation_0-merror:0.19753\n", "[79]\tvalidation_0-merror:0.19729\n", "[80]\tvalidation_0-merror:0.19698\n", "[81]\tvalidation_0-merror:0.19696\n", "[82]\tvalidation_0-merror:0.19702\n", "[83]\tvalidation_0-merror:0.19699\n", "[84]\tvalidation_0-merror:0.19676\n", "[85]\tvalidation_0-merror:0.19669\n", "[86]\tvalidation_0-merror:0.19663\n", "[87]\tvalidation_0-merror:0.19641\n", "[88]\tvalidation_0-merror:0.19632\n", "[89]\tvalidation_0-merror:0.19628\n", "[90]\tvalidation_0-merror:0.19611\n", "[91]\tvalidation_0-merror:0.19601\n", "[92]\tvalidation_0-merror:0.19586\n", "[93]\tvalidation_0-merror:0.19556\n", "[94]\tvalidation_0-merror:0.19542\n", "[95]\tvalidation_0-merror:0.19538\n", "[96]\tvalidation_0-merror:0.19531\n", "[97]\tvalidation_0-merror:0.19514\n", "[98]\tvalidation_0-merror:0.19508\n", "[99]\tvalidation_0-merror:0.19492\n", 
"[100]\tvalidation_0-merror:0.19481\n", "[101]\tvalidation_0-merror:0.19459\n", "[102]\tvalidation_0-merror:0.19458\n", "[103]\tvalidation_0-merror:0.19434\n", "[104]\tvalidation_0-merror:0.19416\n", "[105]\tvalidation_0-merror:0.19407\n", "[106]\tvalidation_0-merror:0.19396\n", "[107]\tvalidation_0-merror:0.19375\n", "[108]\tvalidation_0-merror:0.19356\n", "[109]\tvalidation_0-merror:0.19359\n", "[110]\tvalidation_0-merror:0.19341\n", "[111]\tvalidation_0-merror:0.19324\n", "[112]\tvalidation_0-merror:0.19301\n", "[113]\tvalidation_0-merror:0.19291\n", "[114]\tvalidation_0-merror:0.19277\n", "[115]\tvalidation_0-merror:0.19271\n", "[116]\tvalidation_0-merror:0.19252\n", "[117]\tvalidation_0-merror:0.19237\n", "[118]\tvalidation_0-merror:0.19230\n", "[119]\tvalidation_0-merror:0.19205\n", "[120]\tvalidation_0-merror:0.19196\n", "[121]\tvalidation_0-merror:0.19179\n", "[122]\tvalidation_0-merror:0.19176\n", "[123]\tvalidation_0-merror:0.19184\n", "[124]\tvalidation_0-merror:0.19169\n", "[125]\tvalidation_0-merror:0.19149\n", "[126]\tvalidation_0-merror:0.19147\n", "[127]\tvalidation_0-merror:0.19129\n", "[128]\tvalidation_0-merror:0.19123\n", "[129]\tvalidation_0-merror:0.19119\n", "[130]\tvalidation_0-merror:0.19105\n", "[131]\tvalidation_0-merror:0.19097\n", "[132]\tvalidation_0-merror:0.19077\n", "[133]\tvalidation_0-merror:0.19065\n", "[134]\tvalidation_0-merror:0.19045\n", "[135]\tvalidation_0-merror:0.19024\n", "[136]\tvalidation_0-merror:0.19006\n", "[137]\tvalidation_0-merror:0.18999\n", "[138]\tvalidation_0-merror:0.18980\n", "[139]\tvalidation_0-merror:0.18968\n", "[140]\tvalidation_0-merror:0.18956\n", "[141]\tvalidation_0-merror:0.18958\n", "[142]\tvalidation_0-merror:0.18947\n", "[143]\tvalidation_0-merror:0.18946\n", "[144]\tvalidation_0-merror:0.18924\n", "[145]\tvalidation_0-merror:0.18918\n", "[146]\tvalidation_0-merror:0.18908\n", "[147]\tvalidation_0-merror:0.18894\n", "[148]\tvalidation_0-merror:0.18882\n", "[149]\tvalidation_0-merror:0.18889\n", 
"[150]\tvalidation_0-merror:0.18862\n", "[151]\tvalidation_0-merror:0.18839\n", "[152]\tvalidation_0-merror:0.18836\n", "[153]\tvalidation_0-merror:0.18822\n", "[154]\tvalidation_0-merror:0.18802\n", "[155]\tvalidation_0-merror:0.18798\n", "[156]\tvalidation_0-merror:0.18796\n", "[157]\tvalidation_0-merror:0.18782\n", "[158]\tvalidation_0-merror:0.18760\n", "[159]\tvalidation_0-merror:0.18759\n", "[160]\tvalidation_0-merror:0.18752\n", "[161]\tvalidation_0-merror:0.18735\n", "[162]\tvalidation_0-merror:0.18725\n", "[163]\tvalidation_0-merror:0.18716\n", "[164]\tvalidation_0-merror:0.18705\n", "[165]\tvalidation_0-merror:0.18693\n", "[166]\tvalidation_0-merror:0.18681\n", "[167]\tvalidation_0-merror:0.18673\n", "[168]\tvalidation_0-merror:0.18655\n", "[169]\tvalidation_0-merror:0.18649\n", "[170]\tvalidation_0-merror:0.18648\n", "[171]\tvalidation_0-merror:0.18633\n", "[172]\tvalidation_0-merror:0.18628\n", "[173]\tvalidation_0-merror:0.18609\n", "[174]\tvalidation_0-merror:0.18612\n", "[175]\tvalidation_0-merror:0.18604\n", "[176]\tvalidation_0-merror:0.18596\n", "[177]\tvalidation_0-merror:0.18591\n", "[178]\tvalidation_0-merror:0.18593\n", "[179]\tvalidation_0-merror:0.18590\n", "[180]\tvalidation_0-merror:0.18576\n", "[181]\tvalidation_0-merror:0.18574\n", "[182]\tvalidation_0-merror:0.18563\n", "[183]\tvalidation_0-merror:0.18566\n", "[184]\tvalidation_0-merror:0.18562\n", "[185]\tvalidation_0-merror:0.18563\n", "[186]\tvalidation_0-merror:0.18561\n", "[187]\tvalidation_0-merror:0.18560\n", "[188]\tvalidation_0-merror:0.18557\n", "[189]\tvalidation_0-merror:0.18543\n", "[190]\tvalidation_0-merror:0.18529\n", "[191]\tvalidation_0-merror:0.18518\n", "[192]\tvalidation_0-merror:0.18505\n", "[193]\tvalidation_0-merror:0.18487\n", "[194]\tvalidation_0-merror:0.18469\n", "[195]\tvalidation_0-merror:0.18459\n", "[196]\tvalidation_0-merror:0.18444\n", "[197]\tvalidation_0-merror:0.18430\n", "[198]\tvalidation_0-merror:0.18425\n", "[199]\tvalidation_0-merror:0.18417\n", 
"Wall time: 12min 30s\n" ] }, { "data": { "text/plain": [ "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, gamma=0.1, gpu_id=-1,\n", " importance_type='gain', interaction_constraints='',\n", " learning_rate=0.01, max_delta_step=0, max_depth=8,\n", " min_child_weight=1, missing=nan, monotone_constraints='()',\n", " n_estimators=200, n_jobs=-1, nthread=-1, num_parallel_tree=1,\n", " object='multi:softprob', objective='multi:softprob',\n", " random_state=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=None,\n", " subsample=0.9, tree_method='exact', validate_parameters=1,\n", " verbosity=None)" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "xgb_cpu.fit(covertype_x_train,\n", " covertype_y_train,\n", " eval_set=[(covertype_x_val, covertype_y_val)],\n", " eval_metric='merror',\n", " early_stopping_rounds=25,\n", " verbose=True)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "# Same hyperparameters as xgb_cpu; only tree_method and gpu_id differ.\n", "xgb_gpu = xgb.XGBClassifier(objective='multi:softprob',\n", " max_depth=8,\n", " gamma=0.1,\n", " subsample=0.9,\n", " learning_rate=0.01,\n", " n_estimators=200,\n", " nthread=-1,\n", " tree_method='gpu_hist',\n", " gpu_id=0)" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[10:01:00] WARNING: D:\\Projects\\other_projects\\xgboost\\src\\learner.cc:529: \n", "Parameters: { object } might not be used.\n", "\n", " This may not be accurate due to some parameters are only used in language bindings but\n", " passed down to XGBoost core. Or some parameters are not used but slip through this\n", " verification. 
Please open an issue if you find above cases.\n", "\n", "\n", "[0]\tvalidation_0-merror:0.21819\n", "[1]\tvalidation_0-merror:0.21678\n", "[2]\tvalidation_0-merror:0.21659\n", "[3]\tvalidation_0-merror:0.21447\n", "[4]\tvalidation_0-merror:0.21370\n", "[5]\tvalidation_0-merror:0.21469\n", "[6]\tvalidation_0-merror:0.21351\n", "[7]\tvalidation_0-merror:0.21265\n", "[8]\tvalidation_0-merror:0.21252\n", "[9]\tvalidation_0-merror:0.21265\n", "[10]\tvalidation_0-merror:0.21286\n", "[11]\tvalidation_0-merror:0.21154\n", "[12]\tvalidation_0-merror:0.21166\n", "[13]\tvalidation_0-merror:0.21123\n", "[14]\tvalidation_0-merror:0.21070\n", "[15]\tvalidation_0-merror:0.21067\n", "[16]\tvalidation_0-merror:0.21046\n", "[17]\tvalidation_0-merror:0.21040\n", "[18]\tvalidation_0-merror:0.21009\n", "[19]\tvalidation_0-merror:0.20922\n", "[20]\tvalidation_0-merror:0.20916\n", "[21]\tvalidation_0-merror:0.20910\n", "[22]\tvalidation_0-merror:0.20902\n", "[23]\tvalidation_0-merror:0.20827\n", "[24]\tvalidation_0-merror:0.20802\n", "[25]\tvalidation_0-merror:0.20764\n", "[26]\tvalidation_0-merror:0.20784\n", "[27]\tvalidation_0-merror:0.20749\n", "[28]\tvalidation_0-merror:0.20775\n", "[29]\tvalidation_0-merror:0.20767\n", "[30]\tvalidation_0-merror:0.20751\n", "[31]\tvalidation_0-merror:0.20745\n", "[32]\tvalidation_0-merror:0.20728\n", "[33]\tvalidation_0-merror:0.20687\n", "[34]\tvalidation_0-merror:0.20703\n", "[35]\tvalidation_0-merror:0.20666\n", "[36]\tvalidation_0-merror:0.20648\n", "[37]\tvalidation_0-merror:0.20643\n", "[38]\tvalidation_0-merror:0.20605\n", "[39]\tvalidation_0-merror:0.20585\n", "[40]\tvalidation_0-merror:0.20589\n", "[41]\tvalidation_0-merror:0.20584\n", "[42]\tvalidation_0-merror:0.20589\n", "[43]\tvalidation_0-merror:0.20549\n", "[44]\tvalidation_0-merror:0.20519\n", "[45]\tvalidation_0-merror:0.20531\n", "[46]\tvalidation_0-merror:0.20495\n", "[47]\tvalidation_0-merror:0.20471\n", "[48]\tvalidation_0-merror:0.20454\n", 
"[49]\tvalidation_0-merror:0.20448\n", "[50]\tvalidation_0-merror:0.20423\n", "[51]\tvalidation_0-merror:0.20401\n", "[52]\tvalidation_0-merror:0.20385\n", "[53]\tvalidation_0-merror:0.20358\n", "[54]\tvalidation_0-merror:0.20333\n", "[55]\tvalidation_0-merror:0.20316\n", "[56]\tvalidation_0-merror:0.20291\n", "[57]\tvalidation_0-merror:0.20253\n", "[58]\tvalidation_0-merror:0.20228\n", "[59]\tvalidation_0-merror:0.20187\n", "[60]\tvalidation_0-merror:0.20142\n", "[61]\tvalidation_0-merror:0.20136\n", "[62]\tvalidation_0-merror:0.20111\n", "[63]\tvalidation_0-merror:0.20089\n", "[64]\tvalidation_0-merror:0.20067\n", "[65]\tvalidation_0-merror:0.20047\n", "[66]\tvalidation_0-merror:0.20017\n", "[67]\tvalidation_0-merror:0.19994\n", "[68]\tvalidation_0-merror:0.19975\n", "[69]\tvalidation_0-merror:0.19975\n", "[70]\tvalidation_0-merror:0.19958\n", "[71]\tvalidation_0-merror:0.19926\n", "[72]\tvalidation_0-merror:0.19913\n", "[73]\tvalidation_0-merror:0.19890\n", "[74]\tvalidation_0-merror:0.19879\n", "[75]\tvalidation_0-merror:0.19861\n", "[76]\tvalidation_0-merror:0.19858\n", "[77]\tvalidation_0-merror:0.19837\n", "[78]\tvalidation_0-merror:0.19821\n", "[79]\tvalidation_0-merror:0.19818\n", "[80]\tvalidation_0-merror:0.19795\n", "[81]\tvalidation_0-merror:0.19770\n", "[82]\tvalidation_0-merror:0.19742\n", "[83]\tvalidation_0-merror:0.19751\n", "[84]\tvalidation_0-merror:0.19728\n", "[85]\tvalidation_0-merror:0.19732\n", "[86]\tvalidation_0-merror:0.19709\n", "[87]\tvalidation_0-merror:0.19709\n", "[88]\tvalidation_0-merror:0.19709\n", "[89]\tvalidation_0-merror:0.19672\n", "[90]\tvalidation_0-merror:0.19663\n", "[91]\tvalidation_0-merror:0.19663\n", "[92]\tvalidation_0-merror:0.19652\n", "[93]\tvalidation_0-merror:0.19638\n", "[94]\tvalidation_0-merror:0.19605\n", "[95]\tvalidation_0-merror:0.19586\n", "[96]\tvalidation_0-merror:0.19579\n", "[97]\tvalidation_0-merror:0.19562\n", "[98]\tvalidation_0-merror:0.19545\n", "[99]\tvalidation_0-merror:0.19536\n", 
"[100]\tvalidation_0-merror:0.19508\n", "[101]\tvalidation_0-merror:0.19495\n", "[102]\tvalidation_0-merror:0.19479\n", "[103]\tvalidation_0-merror:0.19463\n", "[104]\tvalidation_0-merror:0.19450\n", "[105]\tvalidation_0-merror:0.19431\n", "[106]\tvalidation_0-merror:0.19416\n", "[107]\tvalidation_0-merror:0.19395\n", "[108]\tvalidation_0-merror:0.19381\n", "[109]\tvalidation_0-merror:0.19361\n", "[110]\tvalidation_0-merror:0.19346\n", "[111]\tvalidation_0-merror:0.19323\n", "[112]\tvalidation_0-merror:0.19314\n", "[113]\tvalidation_0-merror:0.19298\n", "[114]\tvalidation_0-merror:0.19273\n", "[115]\tvalidation_0-merror:0.19260\n", "[116]\tvalidation_0-merror:0.19255\n", "[117]\tvalidation_0-merror:0.19234\n", "[118]\tvalidation_0-merror:0.19215\n", "[119]\tvalidation_0-merror:0.19194\n", "[120]\tvalidation_0-merror:0.19187\n", "[121]\tvalidation_0-merror:0.19175\n", "[122]\tvalidation_0-merror:0.19167\n", "[123]\tvalidation_0-merror:0.19154\n", "[124]\tvalidation_0-merror:0.19145\n", "[125]\tvalidation_0-merror:0.19131\n", "[126]\tvalidation_0-merror:0.19121\n", "[127]\tvalidation_0-merror:0.19118\n", "[128]\tvalidation_0-merror:0.19099\n", "[129]\tvalidation_0-merror:0.19107\n", "[130]\tvalidation_0-merror:0.19086\n", "[131]\tvalidation_0-merror:0.19081\n", "[132]\tvalidation_0-merror:0.19071\n", "[133]\tvalidation_0-merror:0.19055\n", "[134]\tvalidation_0-merror:0.19039\n", "[135]\tvalidation_0-merror:0.19022\n", "[136]\tvalidation_0-merror:0.18998\n", "[137]\tvalidation_0-merror:0.18983\n", "[138]\tvalidation_0-merror:0.18975\n", "[139]\tvalidation_0-merror:0.18975\n", "[140]\tvalidation_0-merror:0.18959\n", "[141]\tvalidation_0-merror:0.18951\n", "[142]\tvalidation_0-merror:0.18946\n", "[143]\tvalidation_0-merror:0.18926\n", "[144]\tvalidation_0-merror:0.18914\n", "[145]\tvalidation_0-merror:0.18906\n", "[146]\tvalidation_0-merror:0.18895\n", "[147]\tvalidation_0-merror:0.18879\n", "[148]\tvalidation_0-merror:0.18876\n", "[149]\tvalidation_0-merror:0.18872\n", 
"[150]\tvalidation_0-merror:0.18855\n", "[151]\tvalidation_0-merror:0.18843\n", "[152]\tvalidation_0-merror:0.18826\n", "[153]\tvalidation_0-merror:0.18827\n", "[154]\tvalidation_0-merror:0.18821\n", "[155]\tvalidation_0-merror:0.18802\n", "[156]\tvalidation_0-merror:0.18796\n", "[157]\tvalidation_0-merror:0.18780\n", "[158]\tvalidation_0-merror:0.18777\n", "[159]\tvalidation_0-merror:0.18762\n", "[160]\tvalidation_0-merror:0.18740\n", "[161]\tvalidation_0-merror:0.18740\n", "[162]\tvalidation_0-merror:0.18740\n", "[163]\tvalidation_0-merror:0.18730\n", "[164]\tvalidation_0-merror:0.18717\n", "[165]\tvalidation_0-merror:0.18697\n", "[166]\tvalidation_0-merror:0.18678\n", "[167]\tvalidation_0-merror:0.18668\n", "[168]\tvalidation_0-merror:0.18648\n", "[169]\tvalidation_0-merror:0.18648\n", "[170]\tvalidation_0-merror:0.18658\n", "[171]\tvalidation_0-merror:0.18648\n", "[172]\tvalidation_0-merror:0.18636\n", "[173]\tvalidation_0-merror:0.18634\n", "[174]\tvalidation_0-merror:0.18631\n", "[175]\tvalidation_0-merror:0.18632\n", "[176]\tvalidation_0-merror:0.18632\n", "[177]\tvalidation_0-merror:0.18617\n", "[178]\tvalidation_0-merror:0.18617\n", "[179]\tvalidation_0-merror:0.18617\n", "[180]\tvalidation_0-merror:0.18617\n", "[181]\tvalidation_0-merror:0.18605\n", "[182]\tvalidation_0-merror:0.18593\n", "[183]\tvalidation_0-merror:0.18597\n", "[184]\tvalidation_0-merror:0.18588\n", "[185]\tvalidation_0-merror:0.18580\n", "[186]\tvalidation_0-merror:0.18582\n", "[187]\tvalidation_0-merror:0.18574\n", "[188]\tvalidation_0-merror:0.18555\n", "[189]\tvalidation_0-merror:0.18548\n", "[190]\tvalidation_0-merror:0.18537\n", "[191]\tvalidation_0-merror:0.18524\n", "[192]\tvalidation_0-merror:0.18519\n", "[193]\tvalidation_0-merror:0.18520\n", "[194]\tvalidation_0-merror:0.18508\n", "[195]\tvalidation_0-merror:0.18502\n", "[196]\tvalidation_0-merror:0.18502\n", "[197]\tvalidation_0-merror:0.18483\n", "[198]\tvalidation_0-merror:0.18473\n", "[199]\tvalidation_0-merror:0.18468\n", 
"Wall time: 1min 54s\n" ] }, { "data": { "text/plain": [ "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, gamma=0.1, gpu_id=0,\n", " importance_type='gain', interaction_constraints='',\n", " learning_rate=0.01, max_delta_step=0, max_depth=8,\n", " min_child_weight=1, missing=nan, monotone_constraints='()',\n", " n_estimators=200, n_jobs=-1, nthread=-1, num_parallel_tree=1,\n", " object='multi:softprob', objective='multi:softprob',\n", " random_state=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=None,\n", " subsample=0.9, tree_method='gpu_hist', validate_parameters=1,\n", " verbosity=None)" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "xgb_gpu.fit(covertype_x_train,\n", " covertype_y_train,\n", " eval_set=[(covertype_x_val, covertype_y_val)],\n", " eval_metric='merror',\n", " early_stopping_rounds=25,\n", " verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LightGBM" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "covertype_y_train_1 = covertype_y_train - 1\n", "covertype_y_val_1 = covertype_y_val - 1\n", "covertype_y_test_1 = covertype_y_test - 1" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import lightgbm as lgb\n", "\n", "params = {\n", " 'task': 'train',\n", " 'boosting_type': 'gbdt',\n", " 'objective': 'multiclass',\n", " 'num_class': len(np.unique(covertype_y)),\n", " 'metric': 'multi_logloss',\n", " 'learning_rate': 0.01,\n", " 'max_depth': 128,\n", " 'num_leaves': 256,\n", " 'feature_fraction': 0.9,\n", " 'bagging_fraction': 0.9,\n", " 'bagging_freq': 10,\n", " 'device': 'cpu'\n", "}" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "train_data = lgb.Dataset(data=covertype_x_train, label=covertype_y_train_1)\n", "val_data = lgb.Dataset(data=covertype_x_val, 
label=covertype_y_val_1)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training until validation scores don't improve for 25 rounds\n", "[500]\tvalid_0's multi_logloss: 0.258546\n", "[1000]\tvalid_0's multi_logloss: 0.188758\n", "[1500]\tvalid_0's multi_logloss: 0.159213\n", "[2000]\tvalid_0's multi_logloss: 0.140833\n", "[2500]\tvalid_0's multi_logloss: 0.127889\n", "Did not meet early stopping. Best iteration is:\n", "[2500]\tvalid_0's multi_logloss: 0.127889\n" ] } ], "source": [ "bst = lgb.train(params,\n", " train_data,\n", " num_boost_round=2500,\n", " valid_sets=val_data,\n", " verbose_eval=500,\n", " early_stopping_rounds=25)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[250]\tcv_agg's multi_logloss: 0.383821 + 0.000454662\n", "[500]\tcv_agg's multi_logloss: 0.265271 + 0.000599288\n", "[750]\tcv_agg's multi_logloss: 0.220832 + 0.000458877\n", "[1000]\tcv_agg's multi_logloss: 0.196948 + 0.000144783\n", "[1250]\tcv_agg's multi_logloss: 0.180572 + 0.000345486\n", "[1500]\tcv_agg's multi_logloss: 0.167917 + 0.000484253\n", "[1750]\tcv_agg's multi_logloss: 0.157947 + 0.000538372\n", "[2000]\tcv_agg's multi_logloss: 0.150243 + 0.000658097\n", "[2250]\tcv_agg's multi_logloss: 0.143914 + 0.000544808\n", "[2500]\tcv_agg's multi_logloss: 0.139016 + 0.000600344\n" ] } ], "source": [ "lgb_cv = lgb.cv(params,\n", " train_data,\n", " num_boost_round=2500,\n", " nfold=3,\n", " shuffle=True,\n", " verbose_eval=250,\n", " early_stopping_rounds=25)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2500" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# argmin returns a 0-based index into the per-round CV scores;\n", "# +1 converts it into a number of boosting rounds.\n", "nround = int(np.argmin(lgb_cv['multi_logloss-mean'])) + 1\n", "nround" ] }, { "cell_type": "code", 
"execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9591917522934201\n", "[[40188 2142 0 0 14 0 71]\n", " [ 1442 54835 68 0 84 45 21]\n", " [ 2 61 7011 23 5 92 0]\n", " [ 0 0 58 455 0 16 0]\n", " [ 20 231 27 0 1689 6 0]\n", " [ 2 70 110 10 3 3351 0]\n", " [ 107 12 0 0 0 0 3931]]\n" ] } ], "source": [ "y_probs = bst.predict(covertype_x_test, num_iteration=bst.best_iteration)\n", "y_preds = np.argmax(y_probs, axis=1)\n", "print(accuracy_score(covertype_y_test_1, y_preds))\n", "print(confusion_matrix(covertype_y_test_1, y_preds))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.96 0.95 0.95 42415\n", " 1 0.96 0.97 0.96 56495\n", " 2 0.96 0.97 0.97 7194\n", " 3 0.93 0.86 0.89 529\n", " 4 0.94 0.86 0.90 1973\n", " 5 0.95 0.95 0.95 3546\n", " 6 0.98 0.97 0.97 4050\n", "\n", " accuracy 0.96 116202\n", " macro avg 0.96 0.93 0.94 116202\n", "weighted avg 0.96 0.96 0.96 116202\n", "\n" ] } ], "source": [ "print(classification_report(covertype_y_test_1, y_preds))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LightGBM GPU vs CPU" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import lightgbm as lgb\n", "\n", "params_cpu = {\n", " 'task': 'train',\n", " 'boosting_type': 'gbdt',\n", " 'objective': 'multiclass',\n", " 'num_class': len(np.unique(covertype_y)),\n", " 'metric': 'multi_logloss',\n", " 'learning_rate': 0.01,\n", " 'max_depth': 128,\n", " 'num_leaves': 256,\n", " 'feature_fraction': 0.9,\n", " 'bagging_fraction': 0.9,\n", " 'bagging_freq': 10,\n", " 'device': 'cpu',\n", " 'n_jobs': -1,\n", "}" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, 
the overhead of testing was 0.030191 seconds.\n", "You can set `force_col_wise=true` to remove the overhead.\n", "[LightGBM] [Info] Total Bins 2266\n", "[LightGBM] [Info] Number of data points in the train set: 348607, number of used features: 53\n", "[LightGBM] [Info] Start training from score -1.008130\n", "[LightGBM] [Info] Start training from score -0.718670\n", "[LightGBM] [Info] Start training from score -2.786590\n", "[LightGBM] [Info] Start training from score -5.333962\n", "[LightGBM] [Info] Start training from score -4.133503\n", "[LightGBM] [Info] Start training from score -3.514835\n", "[LightGBM] [Info] Start training from score -3.339237\n", "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", "Training until validation scores don't improve for 25 rounds\n", "[100]\tvalid_0's multi_logloss: 0.46432\n", "[200]\tvalid_0's multi_logloss: 0.316533\n", "[300]\tvalid_0's multi_logloss: 0.254985\n", "[400]\tvalid_0's multi_logloss: 0.222611\n", "[500]\tvalid_0's multi_logloss: 0.201845\n", "Did not meet early stopping. 
Best iteration is:\n", "[500]\tvalid_0's multi_logloss: 0.201845\n", "Wall time: 2min 39s\n" ] } ], "source": [ "%%time\n", "bst_cpu = lgb.train(params_cpu,\n", " train_data,\n", " num_boost_round=500,\n", " valid_sets=val_data,\n", " verbose_eval=100,\n", " early_stopping_rounds=25)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "params_gpu = {\n", " 'task': 'train',\n", " 'boosting_type': 'gbdt',\n", " 'objective': 'multiclass',\n", " 'num_class': len(np.unique(covertype_y)),\n", " 'metric': 'multi_logloss',\n", " 'learning_rate': 0.01,\n", " 'max_depth': 128,\n", " 'num_leaves': 256,\n", " 'feature_fraction': 0.9,\n", " 'bagging_fraction': 0.9,\n", " 'bagging_freq': 10,\n", " \"n_jobs\": 8,\n", " 'device': 'gpu',\n", " 'gpu_platform_id': 2,\n", " 'gpu_device_id': 1,\n", "}" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[LightGBM] [Info] This is the GPU trainer!!\n", "[LightGBM] [Info] Total Bins 2266\n", "[LightGBM] [Info] Number of data points in the train set: 348607, number of used features: 53\n", "[LightGBM] [Info] Using GPU Device: GeForce GTX 1060, Vendor: NVIDIA Corporation\n", "[LightGBM] [Info] Compiling OpenCL Kernel with 256 bins...\n", "[LightGBM] [Info] GPU programs have been built\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.99 MB) transferred to GPU in 0.008094 secs. 
0 sparse feature groups\n", "[LightGBM] [Info] Start training from score -1.008130\n", "[LightGBM] [Info] Start training from score -0.718670\n", "[LightGBM] [Info] Start training from score -2.786590\n", "[LightGBM] [Info] Start training from score -5.333962\n", "[LightGBM] [Info] Start training from score -4.133503\n", "[LightGBM] [Info] Start training from score -3.514835\n", "[LightGBM] [Info] Start training from score -3.339237\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008006 secs. 0 sparse feature groups\n", "Training until validation scores don't improve for 25 rounds\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007137 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007045 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007813 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007583 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007343 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007228 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007928 secs. 
0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007948 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008081 secs. 0 sparse feature groups\n", "[100]\tvalid_0's multi_logloss: 0.464445\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007430 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008532 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008299 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008420 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008844 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007474 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.009012 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008419 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007661 secs. 
0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007852 secs. 0 sparse feature groups\n", "[200]\tvalid_0's multi_logloss: 0.316437\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008884 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008355 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007115 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.006948 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007195 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007332 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007107 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007296 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007101 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007279 secs. 
0 sparse feature groups\n", "[300]\tvalid_0's multi_logloss: 0.255073\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.006899 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007245 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.006943 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008147 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007125 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008344 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007181 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.009159 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007620 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007047 secs. 0 sparse feature groups\n", "[400]\tvalid_0's multi_logloss: 0.223112\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007003 secs. 
0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007112 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008128 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.58 MB) transferred to GPU in 0.007451 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007739 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007358 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007277 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007078 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.007241 secs. 0 sparse feature groups\n", "[LightGBM] [Info] Size of histogram bin entry: 8\n", "[LightGBM] [Info] 12 dense feature groups (3.59 MB) transferred to GPU in 0.008423 secs. 0 sparse feature groups\n", "[500]\tvalid_0's multi_logloss: 0.202542\n", "Did not meet early stopping. 
Best iteration is:\n", "[500]\tvalid_0's multi_logloss: 0.202542\n", "Wall time: 5min 7s\n" ] } ], "source": [ "%%time\n", "bst_gpu = lgb.train(params_gpu,\n", " train_data,\n", " num_boost_round=500,\n", " valid_sets=val_data,\n", " verbose_eval=100,\n", " early_stopping_rounds=25)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CatBoost" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "from catboost import CatBoostClassifier, Pool\n", "\n", "covertype_dataset = fetch_covtype(random_state=101, shuffle=True)\n", "label = covertype_dataset.target.astype(int) - 1\n", "wilderness_area = np.argmax(covertype_dataset.data[:, 10:(10 + 4)], axis=1)\n", "soil_type = np.argmax(covertype_dataset.data[:, (10 + 4):(10 + 4 + 40)], axis=1)\n", "data = (covertype_dataset.data[:, :10], wilderness_area.reshape(-1,\n", " 1).astype(str),\n", " soil_type.reshape(-1, 1).astype(str))\n", "data = np.hstack(data)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "covertype_train = Pool(data=data[:15000, :],\n", " label=label[:15000],\n", " cat_features=[10, 11])\n", "covertype_val = Pool(data[15000:20000, :], label[15000:20000], [10, 11])\n", "covertype_test_x = Pool(data[20000:25000, :], None, [10, 11])\n", "covertype_test_y = label[20000:25000]" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "cbc_cpu = CatBoostClassifier(iterations=2500,\n", " learning_rate=0.05,\n", " depth=8,\n", " custom_loss='Accuracy',\n", " eval_metric='Accuracy',\n", " use_best_model=True,\n", " loss_function='MultiClass',\n", " task_type='CPU',\n", " thread_count=-1)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6639333\ttest: 0.6690000\tbest: 0.6690000 (0)\ttotal: 133ms\tremaining: 5m 32s\n", "500:\tlearn: 0.8766000\ttest: 0.8084000\tbest: 0.8090000 
(493)\ttotal: 2m 14s\tremaining: 8m 58s\n", "1000:\tlearn: 0.9358000\ttest: 0.8276000\tbest: 0.8284000 (987)\ttotal: 4m 45s\tremaining: 7m 6s\n", "1500:\tlearn: 0.9673333\ttest: 0.8316000\tbest: 0.8318000 (1477)\ttotal: 7m 17s\tremaining: 4m 50s\n", "2000:\tlearn: 0.9844667\ttest: 0.8344000\tbest: 0.8344000 (1996)\ttotal: 9m 39s\tremaining: 2m 24s\n", "2499:\tlearn: 0.9928667\ttest: 0.8368000\tbest: 0.8368000 (2499)\ttotal: 12m 14s\tremaining: 0us\n", "\n", "bestTest = 0.8368\n", "bestIteration = 2499\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cbc_cpu.fit(covertype_train, eval_set=covertype_val, verbose=500, plot=False)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "preds_class_cpu = cbc_cpu.predict(covertype_test_x)\n", "preds_proba_cpu = cbc_cpu.predict_proba(covertype_test_x)" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8392\n", "[[1483 321 0 0 0 0 16]\n", " [ 237 2172 14 0 9 14 2]\n", " [ 0 11 269 4 0 17 0]\n", " [ 0 0 4 19 0 4 0]\n", " [ 1 47 7 0 22 0 0]\n", " [ 0 16 42 1 0 85 0]\n", " [ 36 1 0 0 0 0 146]]\n" ] } ], "source": [ "print(accuracy_score(covertype_test_y, preds_class_cpu))\n", "print(confusion_matrix(covertype_test_y, preds_class_cpu))" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.84 0.81 0.83 1820\n", " 1 0.85 0.89 0.87 2448\n", " 2 0.80 0.89 0.84 301\n", " 3 0.79 0.70 0.75 27\n", " 4 0.71 0.29 0.41 77\n", " 5 0.71 0.59 0.64 144\n", " 6 0.89 0.80 0.84 183\n", "\n", " accuracy 0.84 5000\n", " macro avg 0.80 0.71 0.74 5000\n", "weighted avg 0.84 0.84 0.84 5000\n", "\n" ] } ], "source": [ "print(classification_report(covertype_test_y, preds_class_cpu))" ] }, { "cell_type": 
"code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "cbc_gpu = CatBoostClassifier(iterations=2500,\n", " learning_rate=0.05,\n", " depth=8,\n", " custom_loss='Accuracy',\n", " eval_metric='Accuracy',\n", " use_best_model=True,\n", " loss_function='MultiClass',\n", " task_type='GPU',\n", " thread_count=-1)" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6304000\ttest: 0.6248000\tbest: 0.6248000 (0)\ttotal: 36.5ms\tremaining: 1m 31s\n", "500:\tlearn: 0.8871333\ttest: 0.8076000\tbest: 0.8076000 (497)\ttotal: 16.7s\tremaining: 1m 6s\n", "1000:\tlearn: 0.9448667\ttest: 0.8272000\tbest: 0.8282000 (992)\ttotal: 37s\tremaining: 55.5s\n", "1500:\tlearn: 0.9734667\ttest: 0.8336000\tbest: 0.8358000 (1443)\ttotal: 54.9s\tremaining: 36.5s\n", "2000:\tlearn: 0.9886667\ttest: 0.8380000\tbest: 0.8396000 (1950)\ttotal: 1m 14s\tremaining: 18.6s\n", "2499:\tlearn: 0.9950000\ttest: 0.8394000\tbest: 0.8414000 (2348)\ttotal: 1m 35s\tremaining: 0us\n", "bestTest = 0.8414\n", "bestIteration = 2348\n", "Shrink model to first 2349 iterations.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cbc_gpu.fit(covertype_train, eval_set=covertype_val, verbose=500, plot=False)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "preds_class_gpu = cbc_gpu.predict(covertype_test_x)\n", "preds_proba_gpu = cbc_gpu.predict_proba(covertype_test_x)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.84\n", "[[1477 327 0 0 0 0 16]\n", " [ 231 2181 11 0 9 13 3]\n", " [ 0 15 265 4 0 17 0]\n", " [ 0 0 6 18 0 3 0]\n", " [ 1 44 6 0 26 0 0]\n", " [ 0 15 43 1 0 85 0]\n", " [ 34 1 0 0 0 0 148]]\n" ] } ], "source": [ "print(accuracy_score(covertype_test_y, preds_class_gpu))\n", 
"print(confusion_matrix(covertype_test_y, preds_class_gpu))" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.85 0.81 0.83 1820\n", " 1 0.84 0.89 0.87 2448\n", " 2 0.80 0.88 0.84 301\n", " 3 0.78 0.67 0.72 27\n", " 4 0.74 0.34 0.46 77\n", " 5 0.72 0.59 0.65 144\n", " 6 0.89 0.81 0.85 183\n", "\n", " accuracy 0.84 5000\n", " macro avg 0.80 0.71 0.74 5000\n", "weighted avg 0.84 0.84 0.84 5000\n", "\n" ] } ], "source": [ "print(classification_report(covertype_test_y, preds_class_gpu))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ThunderGBM" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "ename": "OSError", "evalue": "[WinError 126] 找不到指定的模块。", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mOSError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mthundersvm\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32mD:\\Programing\\Anaconda3\\lib\\site-packages\\thundersvm\\__init__.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m \"\"\"\n\u001b[0;32m 11\u001b[0m \u001b[0mname\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"thundersvm\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[1;33m.\u001b[0m\u001b[0mthundersvmScikit\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[1;33m*\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32mD:\\Programing\\Anaconda3\\lib\\site-packages\\thundersvm\\thundersvmScikit.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 39\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlib_path\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 41\u001b[1;33m \u001b[0mthundersvm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCDLL\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlib_path\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 42\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mFileNotFoundError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Please build the library first!\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mD:\\Programing\\Anaconda3\\lib\\ctypes\\__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, name, mode, handle, use_errno, use_last_error)\u001b[0m\n\u001b[0;32m 362\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 363\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 364\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_handle\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_dlopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 365\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_handle\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mhandle\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mOSError\u001b[0m: [WinError 126] 找不到指定的模块。" ] } ], "source": [ "import thundersvm" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Object `clf` not found.\n" ] } ], "source": [ "clf?" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Package Version\n", "---------------------------------- -------------------\n", "aiohttp 3.6.3\n", "alabaster 0.7.12\n", "altgraph 0.17\n", "anaconda-client 1.7.2\n", "anaconda-navigator 1.9.12\n", "anaconda-project 0.8.3\n", "argh 0.26.2\n", "argon2-cffi 20.1.0\n", "asn1crypto 1.4.0\n", "astroid 2.4.2\n", "astropy 4.0.1.post1\n", "async-timeout 3.0.1\n", "atomicwrites 1.4.0\n", "attrs 20.1.0\n", "autopep8 1.5.4\n", "Babel 2.8.0\n", "backcall 0.2.0\n", "backports.functools-lru-cache 1.6.1\n", "backports.shutil-get-terminal-size 1.0.0\n", "backports.tempfile 1.0\n", "backports.weakref 1.0.post1\n", "bcrypt 3.2.0\n", "beautifulsoup4 4.9.1\n", "bitarray 1.5.3\n", "bkcharts 0.2\n", "bleach 3.1.5\n", "blis 0.4.1\n", "bokeh 2.2.1\n", "boto 2.49.0\n", "boto3 1.9.66\n", "botocore 1.12.67\n", "Bottleneck 1.3.2\n", "branca 0.4.1\n", "brotlipy 0.7.0\n", "cachetools 4.1.1\n", "catalogue 1.0.0\n", "catboost 0.24\n", "certifi 2020.6.20\n", "cffi 1.14.2\n", "chardet 3.0.4\n", "click 7.1.2\n", "cloudpickle 1.6.0\n", "clyent 1.2.2\n", "colorama 0.4.3\n", "comtypes 1.1.7\n", "conda 4.8.5\n", "conda-build 3.20.2\n", "conda-pack 0.5.0\n", "conda-package-handling 1.7.0\n", "conda-verify 3.4.2\n", "confuse 1.3.0\n", "contextlib2 0.6.0.post1\n", "convertdate 2.2.1\n", "cryptography 3.1\n", "cssselect 1.1.0\n", "cupy 6.0.0\n", "cutecharts 1.2.0\n", "cycler 0.10.0\n", "cymem 2.0.3\n", "Cython 0.29.21\n", "cytoolz 0.10.1\n", "dash 1.16.1\n", "dash-core-components 1.3.1\n", "dash-html-components 
1.0.1\n", "dash-renderer 1.1.2\n", "dash-table 4.4.1\n", "dask 2.25.0\n", "decorator 4.4.2\n", "defusedxml 0.6.0\n", "diff-match-patch 20200713\n", "distributed 2.25.0\n", "docutils 0.16\n", "en-core-web-lg 2.3.1\n", "en-core-web-sm 2.3.1\n", "entropy-based-binning 0.0.1\n", "entrypoints 0.3\n", "ephem 3.7.7.0\n", "et-xmlfile 1.0.1\n", "fastcache 1.1.0\n", "fastrlock 0.4\n", "fbprophet 0.6\n", "ffn 0.3.4\n", "filelock 3.0.12\n", "flake8 3.8.3\n", "Flask 1.1.2\n", "Flask-Compress 1.5.0\n", "folium 0.11.0\n", "fsspec 0.8.0\n", "future 0.18.2\n", "fuzzywuzzy 0.17.0\n", "gensim 3.8.0\n", "gevent 20.6.2\n", "glob2 0.7\n", "gmpy2 2.0.8\n", "google-api-core 1.22.2\n", "google-auth 1.22.1\n", "google-cloud-core 1.4.3\n", "google-cloud-storage 1.31.0\n", "google-crc32c 1.0.0\n", "google-resumable-media 1.1.0\n", "googleapis-common-protos 1.52.0\n", "greenlet 0.4.16\n", "h5py 2.10.0\n", "HeapDict 1.0.1\n", "holidays 0.10.3\n", "html5lib 1.1\n", "htmlmin 0.1.12\n", "idna 2.10\n", "ImageHash 4.1.0\n", "imageio 2.9.0\n", "imagesize 1.2.0\n", "imbalanced-learn 0.7.0\n", "importlib-metadata 1.7.0\n", "inflection 0.5.1\n", "iniconfig 0.0.0\n", "intervaltree 3.1.0\n", "ipykernel 5.3.4\n", "ipympl 0.5.7\n", "ipython 7.18.1\n", "ipython-genutils 0.2.0\n", "ipywidgets 7.5.1\n", "isort 5.4.2\n", "itsdangerous 1.1.0\n", "jdcal 1.4.1\n", "jedi 0.14.1\n", "jieba 0.42.1\n", "Jinja2 2.11.2\n", "jmespath 0.10.0\n", "joblib 0.16.0\n", "json5 0.9.5\n", "jsonschema 3.0.2\n", "jupyter 1.0.0\n", "jupyter-bokeh 2.0.3\n", "jupyter-client 6.1.6\n", "jupyter-console 6.2.0\n", "jupyter-contrib-core 0.3.3\n", "jupyter-core 4.6.3\n", "jupyter-kite 1.0.0\n", "jupyter-nbextensions-configurator 0.4.1\n", "jupyterlab 2.2.6\n", "jupyterlab-code-formatter 1.3.6\n", "jupyterlab-server 1.2.0\n", "keyring 21.4.0\n", "kiwisolver 1.2.0\n", "korean-lunar-calendar 0.2.1\n", "lazy-object-proxy 1.4.3\n", "libarchive-c 2.9\n", "lightgbm 3.0.0.99\n", "llvmlite 0.33.0+1.g022ab0f\n", "locket 0.2.0\n", "LunarCalendar 
0.0.9\n", "lxml 4.5.2\n", "macholib 1.14\n", "MarkupSafe 1.1.1\n", "matplotlib 3.3.1\n", "mccabe 0.6.1\n", "mdlp-discretization 0.3.3\n", "menuinst 1.4.16\n", "missingno 0.4.2\n", "mistune 0.8.4\n", "mkl-fft 1.1.0\n", "mkl-random 1.1.1\n", "mkl-service 2.3.0\n", "ml-metrics 0.1.4\n", "mlxtend 0.17.3\n", "mock 4.0.2\n", "more-itertools 8.5.0\n", "MouseInfo 0.1.3\n", "mpl-finance 0.10.1\n", "mpmath 1.1.0\n", "msgpack 1.0.0\n", "multidict 4.7.6\n", "multipledispatch 0.6.0\n", "murmurhash 1.0.2\n", "mysql-connector 2.2.9\n", "mysql-connector-python 8.0.18\n", "navigator-updater 0.2.1\n", "nb-conda-kernels 2.2.4\n", "nbconvert 5.6.1\n", "nbformat 5.0.7\n", "nbresuse 0.3.6\n", "networkx 2.5\n", "nltk 3.5\n", "nose 1.3.7\n", "notebook 6.1.1\n", "numba 0.50.1\n", "numexpr 2.7.1\n", "numpy 1.19.1\n", "numpydoc 1.1.0\n", "oauthlib 3.1.0\n", "olefile 0.46\n", "opencv-python 3.4.11.41\n", "openpyxl 3.0.5\n", "packaging 20.4\n", "pandas 1.1.1\n", "pandas-datareader 0.9.0\n", "pandas-profiling 2.9.0\n", "pandocfilters 1.4.2\n", "paramiko 2.7.2\n", "parsel 1.6.0\n", "parso 0.5.2\n", "partd 1.1.0\n", "path 15.0.0\n", "pathlib2 2.3.5\n", "pathtools 0.1.2\n", "patsy 0.5.1\n", "peewee 3.13.3\n", "pefile 2019.4.18\n", "pep8 1.7.1\n", "pexpect 4.8.0\n", "phik 0.10.0\n", "pickleshare 0.7.5\n", "Pillow 7.2.0\n", "pip 20.2.3\n", "pkginfo 1.5.0.1\n", "pkuseg 0.0.25\n", "plac 0.9.6\n", "plotly 4.10.0\n", "pluggy 0.13.1\n", "ply 3.11\n", "preshed 3.0.2\n", "prettytable 0.7.2\n", "prometheus-client 0.8.0\n", "prompt-toolkit 3.0.7\n", "protobuf 3.6.0\n", "psutil 5.7.2\n", "py 1.9.0\n", "pyasn1 0.4.8\n", "pyasn1-modules 0.2.8\n", "PyAutoGUI 0.9.50\n", "pycodestyle 2.6.0\n", "pycosat 0.6.3\n", "pycparser 2.20\n", "pycrypto 2.6.1\n", "pycryptodome 3.9.8\n", "pycurl 7.43.0.5\n", "pydocstyle 5.1.1\n", "pyecharts 1.8.1\n", "pyflakes 2.2.0\n", "pyftpdlib 1.5.6\n", "PyGetWindow 0.0.8\n", "Pygments 2.6.1\n", "PyInstaller 3.6\n", "pylint 2.6.0\n", "PyMeeus 0.3.7\n", "pymemcache 3.3.0\n", "pymongo 
3.11.0\n", "PyMsgBox 1.0.8\n", "PyMySQL 0.9.3\n", "PyNaCl 1.4.0\n", "pyodbc 4.0.0-unsupported\n", "pyOpenSSL 19.1.0\n", "pyparsing 2.4.7\n", "pyperclip 1.8.0\n", "pyreadline 2.1\n", "PyRect 0.1.4\n", "pyrsistent 0.16.0\n", "PyScreeze 0.1.26\n", "PySocks 1.7.1\n", "pystan 2.19.0.0\n", "pytest 6.0.1\n", "python-bidi 0.4.2\n", "python-dateutil 2.8.1\n", "python-docx 0.8.10\n", "python-jsonrpc-server 0.3.4\n", "python-language-server 0.31.7\n", "python-Levenshtein 0.12.0\n", "PyTweening 1.0.3\n", "pytz 2020.1\n", "PyWavelets 1.1.1\n", "pywin32 227\n", "pywin32-ctypes 0.2.0\n", "pywinpty 0.5.7\n", "PyYAML 5.3.1\n", "pyzmq 19.0.1\n", "QDarkStyle 2.8.1\n", "QtAwesome 0.7.2\n", "qtconsole 4.7.6\n", "QtPy 1.9.0\n", "Quandl 3.5.2\n", "redis 3.5.3\n", "regex 2020.7.14\n", "requestium 0.1.9\n", "requests 2.24.0\n", "requests-file 1.5.1\n", "retrying 1.3.3\n", "rope 0.17.0\n", "rsa 4.6\n", "Rtree 0.9.4\n", "ruamel-yaml 0.15.87\n", "s3transfer 0.1.13\n", "scikit-image 0.16.2\n", "scikit-learn 0.23.2\n", "scikit-surprise 1.1.1\n", "scipy 1.5.2\n", "seaborn 0.11.0\n", "selenium 3.14.1\n", "Send2Trash 1.5.0\n", "setuptools 49.6.0.post20200814\n", "simplegeneric 0.8.1\n", "simplejson 3.17.2\n", "singledispatch 3.4.0.3\n", "six 1.15.0\n", "smart-open 3.0.0\n", "snowballstemmer 2.0.0\n", "sortedcollections 1.2.1\n", "sortedcontainers 2.2.2\n", "soupsieve 2.0.1\n", "spacy 2.3.1\n", "spacy-lookups-data 0.3.0\n", "Sphinx 3.2.1\n", "sphinxcontrib-applehelp 1.0.2\n", "sphinxcontrib-devhelp 1.0.2\n", "sphinxcontrib-htmlhelp 1.0.3\n", "sphinxcontrib-jsmath 1.0.1\n", "sphinxcontrib-qthelp 1.0.3\n", "sphinxcontrib-serializinghtml 1.1.4\n", "sphinxcontrib-websupport 1.2.4\n", "spyder 4.0.1\n", "spyder-kernels 1.8.1\n", "SQLAlchemy 1.3.19\n", "srsly 1.0.2\n", "statsmodels 0.11.1\n", "sympy 1.6.2\n", "tables 3.6.1\n", "tabulate 0.8.7\n", "tangled-up-in-unicode 0.0.6\n", "tblib 1.7.0\n", "terminado 0.8.3\n", "testpath 0.4.4\n", "thinc 7.4.1\n", "threadpoolctl 2.1.0\n", "thundersvm 0.3.4\n", 
"tldextract 2.2.3\n", "toml 0.10.1\n", "toolz 0.10.0\n", "tornado 6.0.4\n", "tqdm 4.48.2\n", "traitlets 4.3.3\n", "tsfresh 0.16.0\n", "typed-ast 1.4.1\n", "typing-extensions 3.7.4.3\n", "ujson 1.35\n", "unicodecsv 0.14.1\n", "urllib3 1.24.3\n", "visions 0.5.0\n", "w3lib 1.22.0\n", "wasabi 0.8.0\n", "watchdog 0.10.3\n", "wcwidth 0.2.5\n", "webencodings 0.5.1\n", "Werkzeug 1.0.1\n", "wheel 0.35.1\n", "widgetsnbextension 3.5.1\n", "win-inet-pton 1.1.0\n", "win-unicode-console 0.5\n", "wincertstore 0.2\n", "wolframclient 1.1.4\n", "wordcloud 1.8.0\n", "wrapt 1.11.2\n", "xgboost 1.3.0-SNAPSHOT\n", "xlrd 1.2.0\n", "XlsxWriter 1.3.3\n", "xlwings 0.20.5\n", "xlwt 1.3.0\n", "xmltodict 0.12.0\n", "xx-ent-wiki-sm 2.3.0\n", "yapf 0.30.0\n", "yarl 1.6.2\n", "zh-core-web-lg 2.3.1\n", "zh-core-web-sm 2.3.1\n", "zhon 1.1.5\n", "zict 2.0.0\n", "zipp 3.1.0\n", "zope.event 4.4\n", "zope.interface 5.1.0\n" ] } ], "source": [ "!pip list" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }