{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; \n", "\n", "import urllib.request\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import tabular" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Income Prediction from Census Dataset\n", "\n", "In this notebook, we will use Census data to predict which individuals earn more than $50K per year. This is the same dataset used in the [AutoGluon tabular prediction example](https://autogluon.mxnet.io/tutorials/tabular_prediction/tabular-quickstart.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collect Training and Test Sets\n", "\n", "The original dataset is available from the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/datasets/Adult), but we will download it from the AutoGluon website." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('/tmp/train.csv', )" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# training set\n", "urllib.request.urlretrieve('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv', \n", " '/tmp/train.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing train: 35179 rows x 15 columns\n", "\n", "The following integer column(s) are being treated as categorical variables:\n", "['education-num']\n", "To treat any of these column(s) as numerical, cast the column to float in DataFrame or CSV\n", " and re-run tabular_from* function.\n", "\n", "processing test: 3894 rows x 15 columns\n" ] } ], "source": [ "trn, val, preproc = tabular.tabular_from_csv('/tmp/train.csv', label_columns='class', random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Create Model and Wrap in `Learner`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "done.\n" ] } ], "source": [ "model = tabular.tabular_classifier('mlp', trn)\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Estimate LR\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Train for 274 steps\n", "Epoch 1/1024\n", "274/274 [==============================] - 8s 28ms/step - loss: 0.7151 - accuracy: 0.3431\n", "Epoch 2/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.6359 - accuracy: 0.6889\n", "Epoch 3/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.4145 - accuracy: 0.8113\n", "Epoch 4/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.3268 - accuracy: 0.8486\n", "Epoch 5/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.6269 - accuracy: 0.7968\n", "Epoch 6/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.5543 - accuracy: 0.7589\n", "Epoch 7/1024\n", " 50/274 [====>.........................] - ETA: 7s - loss: 47426.1304 - accuracy: 0.7517\n", "\n", "done.\n", "Visually inspect loss plot and select learning rate associated with falling loss\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU1f3/8ddnspKVhIR930QEFYgKilu1irYV96WurWtbq+23tZv9Vlvbb239dbNupW7d3G0V99a6oFCWYAVZlU0WEUISEshCtvP7405gCElIIHfmTub9fDzyyMy9d2Y+2eadc86955hzDhERSVyhWBcgIiKxpSAQEUlwCgIRkQSnIBARSXAKAhGRBKcgEBFJcMmxLqCzCgoK3NChQ2NdhohIXFm4cOE251xha/viLgiGDh1KcXFxrMsQEYkrZvZxW/vUNSQikuAUBCIiCU5BICKS4BQEIiIJTkEgIpLgFAQiIgnOtyAws4fNbKuZLWlj/6VmttjMPjCzOWZ2hF+1iEh82VBWTUVNfazLSBh+tggeBaa1s38tcKJzbjxwBzDDx1pEJI4c/8s3+eyv3451GQnDtwvKnHOzzGxoO/vnRNydCwz0qxYRiT9bd+yKdQkJIyhjBFcDr8S6CBGRRBTzKSbM7GS8IJjazjHXAdcBDB48OEqViYgkhpi2CMzscOBBYLpzrrSt45xzM5xzRc65osLCVudMEhGRAxSzIDCzwcDfgcudcx/Gqg4RkUTnW9eQmT0OnAQUmNlG4DYgBcA59wDwI6AXcJ+ZATQ454r8qkdE4k9tfSPpKUmxLqPb8/OsoUv2s/8a4Bq/Xl9E4l9FTb2CIAqCctaQiMg+tlfrorJoUBCISGBtr66LdQkJQUEgIoG1XdNMRIWCQEQCq0JdQ1GhIBCRwEkKGQDba9Q1FA0KAhEJnIxU70yhsiq1CKJBQSAiweO8Txosjg4FgYgETjgHKKtSEESDgkBEAqtcLYKoUBCIS
OA457UJ1CKIDgWBiASOuoaiS0EgIoHjmgeLa+ppbHLtHywHTUEgIoGUmhzCObSIfRQoCEQkcByO/IxUQN1D0aAgEJHAcQ7yM70g0JlD/lMQiEjgOKBXlloE0aIgEJFA2t0iUBD4TkEgIsHjIK95jEBdQ75TEIhI4DgcPVKT6JGSRNlOBYHfFAQiEjjOgeF1D6lF4D8FgYgEjgPMvCDQGIH/FAQiEkiGkZeZSplWKfOdgkBEAqd50rn8jBS1CKJAQSAigdPcNZSnrqGoUBCISODsHizOSGXHrgbqGppiXVK3piAQkWAyIz9L00xEg4JARAJLE89Fh4JARAKleaDY8MYIQNNM+E1BICKB0rwoTfN1BKBpJvymIBCRQGlej8yw3fMNqUXgLwWBiASSGfTMSAGgrEoXlflJQSAigdI8RgCQkhQiJz1ZZw35TEEgIoGyp2vI0ysrjVJ1DflKQSAigRI5WAyQp2kmfKcgEJFAsnAS5Gem6joCnykIRCRQHG6v+3kZqRoj8JlvQWBmD5vZVjNb0sZ+M7O7zWyVmS02s4l+1SIi8cPtnQO7WwSu5Q7pMn62CB4FprWz/wxgVPjjOuB+H2sRkTize4wgM5VdDU3U1DfGtqBuzLcgcM7NAsraOWQ68GfnmQv0NLN+ftUjIvFh92Axe8YIQPMN+SmWYwQDgA0R9zeGt+3DzK4zs2IzKy4pKYlKcSISW80tAk0857+4GCx2zs1wzhU554oKCwtjXY6I+GifwWK1CHwXyyDYBAyKuD8wvE1EEtieriFPc9eQzhzyTyyDYCZwRfjsoclAhXNucwzrEZEA2H1lcXPXUDgItu1QEPgl2a8nNrPHgZOAAjPbCNwGpAA45x4AXgbOBFYB1cCX/KpFROJP82BxTnoy6Skhtu6ojXFF3ZdvQeCcu2Q/+x3wNb9eX0TiU8vrBcyM3tnpbN2xK0YVdX9xMVgsIomjZdcQQJ+cNLZUqkXgFwWBiARKaxcQq0XgLwWBiASSRTQJeueksbVSQeAXBYGIBEsbLYKduxqo2tUQ/XoSgIJARAKl+YKyiCEC+uSkAah7yCcKAhEJlJYL04DXIgDYqgFjXygIRCRQWi5VCd4YAcAWtQh8oSAQkUCKHCzuE9EiaGrSugRdTUEgIoHS2gI0OT2SSU0O8dC7axn+g5fZUFYdg8q6LwWBiARKaxeUmRl9ctLYXOGNEby3vjz6hXVjCgIRCZSWs482ax4wBti2UxPQdSUFgYgEk+0dBc2nkAKU7tSgcVdSEIhIoLRcmKZZZIugVC2CLqUgEJFgaatrKLJFUKUWQVdSEIhIoLQ2WAwwoGeP3bc1RtC1FAQiEih7Bov3ToKBeXuCQC2CrqUgEJFAatkiGJiXsfu2lq3sWgoCEQmUtgaLC7P2jBHU1DdqJtIupCAQkUBp6zqCUGjvLRvLa6JTUAJQEIhIoLQ1WNzSek0z0WUUBCISSC0Hi1tSEHQdBYGIBEprk841e+5rx/GDM8eQmZqkiee6UHKsCxARieRaW5Ag7MhBPTlyUE/+/t4mBUEXUotARAKpvY6hwfkZ6hrqQgoCEQkka2e0eHB+BhvKq9vtRpKOUxCISKB05L19UH4GtfVNlGgW0i6hIBCRQGm+oGx/XUOAxgm6iIJARAJl9wVl7STBoHAQaJygaygIRCRQOnJBWfMEdOtLdXVxV1AQiEggtXdBWXpKEoXZaWzarhZBV1AQiEigdPRMoAE9e7Bpu1oEXUFBICKB0tG5hgbm9dAYQRdREIhIoHT00oAxfbPZUFZDZW29vwUlAAWBiARSexeUAYztnwPAis07olFOt6YgEJGA6ViT4JC+XhB8uEVBcLB8DQIzm2ZmK81slZl9r5X9g83sTTP7r5ktNrMz/axHRIKvrYVpWuqfm05WWjIfKQgOmm9BYGZJwL3AGcBY4BIzG9visB8CTznnJgAXA/f5VY+IxIeODhabGSN7Z7FSQXDQ/GwRHA2scs6tcc7VAU8A0
1sc44Cc8O1c4BMf6xGROLCnRbC/NgGM7pPFqq07fa6o+/MzCAYAGyLubwxvi3Q7cJmZbQReBr7e2hOZ2XVmVmxmxSUlJX7UKiIBs78WAcDoPtls21lHWVWd/wV1Y7EeLL4EeNQ5NxA4E/iLme1Tk3NuhnOuyDlXVFhYGPUiRSR6XAcHiwFG9s4C0DjBQepQEJjZzWaWY56HzOw9MzttPw/bBAyKuD8wvC3S1cBTAM65/wDpQEHHSheR7qijg8XgtQgAPlL30EHpaIvgy865SuA0IA+4HLhzP49ZAIwys2Fmloo3GDyzxTHrgVMAzOxQvCBQ349IAuvI7KPN+unMoS7R0SBo/pGcCfzFObeU/QS2c64BuBF4DViOd3bQUjP7iZmdFT7sW8C1ZrYIeBy4ymnJIREBOtImMDNG9M5iVYlaBAejo4vXLzSzfwLDgO+bWTbQtL8HOedexhsEjtz2o4jby4DjOl6uiHR3nRkjABhekMm8NaU+VZMYOtoiuBr4HnCUc64aSAG+5FtVIpKwOtM1BF4QfFJRS3Vdg39FdXMdDYIpwErn3HYzuwzvQrAK/8oSkUTXwRxgeKF35tC6bZqJ9EB1NAjuB6rN7Ai8fv3VwJ99q0pEEtaeFkHHomBYQSYAqzVOcMA6GgQN4UHc6cA9zrl7gWz/yhIR6ZgRvTNJTQ7xwSZ1Uhyojg4W7zCz7+OdNnp8+KKvFP/KEpFE1TxY3NGuobTkJA4fkEvxujL/iurmOtoiuAjYhXc9wad4F4fd5VtVIpKwOjtYDFA0NJ8PNlVQW9/oT1HdXIeCIPzm/zcg18w+D9Q65zRGICJdrqOzj0Y6amge9Y2ORRu2+1JTd9fRKSYuBOYDFwAXAvPM7Hw/CxORxNaR2UebTRqSB0Dxx+V+ldOtdXSM4Fa8awi2AphZIfA68IxfhYlIYjqQyQV6ZqQyqncW89eW8bWTfSiqm+voGEGoOQTCSjvx2MDQ7BUiwfPmyq1sqazdfX/3X2knuoYAJg/vxYJ1ZdQ37nfSA2mhoy2CV83sNbz5gMAbPH65neMDZ8mmCr706AKuP2E42enJJIdClFfXUZidxsC8HuT2SCUvI4WQGZlpyaQmx13OicSlLz2yAIB1d34O6Nzso5GmjOjFX+Z+zAebKpg4OK8LK+z+OhQEzrlbzOw89swLNMM59w//yup6uxoaKdmxi5++tLxDx2emJpGZlkxWejJZaclkpiZ799OSyEhLJi05xICePRicn0FujxRq6hupb3Q0NDaRmZZMTo8UBvTsQWZaEmnJSSSFOvtrLdL9RbbSdzU0kpactPt+Ry8oa3bU0HwAFqwtUxB0UkdbBDjnngWe9bEWX00aks/i209jTUkVWWlJmBlpySHKqup2f+yobaCxyVFatYuauiaqdjWws67B+1zbwMbyaqrqGqje1UhNfSPVdR07VS0lyeiTk05eRio9M1LI6ZFCTnPApHmfC7LSGFaQyZBeXrB09o9AJB41Nu0JgiWbKsODvgfWhVuYncbwgkzmry3j+hNHdFGFiaHdIDCzHbT+UzHAOedyWtkXWDnpKRw5qOde2wbmZRzQcznnqKipZ3VJFdV1DWSmJZMSCpGcZOyobaCypp6N5dXsamiivLqeLZW1lFfXUV5Vx8byGnaGw6WmlfOe8zNTGVmYxYjeWYwozGRk7yxGFGbRNzedlCR1WUn30RARBP/z1Pu8fcvJB9w1BF730PPvf0JDYxPJ+lvpsHaDwDmnaSTaYGb0zEhl0pDUg3qehsYmqnY1UrKzllVbq9hQVs3qkp2sLtnJq0s2U15dv/vY5JAxYXBPJg7O48hBPRk3IJeBeT3UepC4FRkEH5dW09jkDug6gmbHjijgb/PWs1jjBJ3S4a4h8UdyUojcjBC5GSmM7L1v7pbu3MXqkirWlOxk7bYq5qwu5ZHZ66gLnxmRn5nKYf1zdofDsSN77dXPKhJkjY3e237PjBS2V9fzyOy1HD7Qa7V35jqCZ
lNG9MIMZn+0TUHQCQqCgOuVlUavrDSOHpa/e1t9YxOLN25nxac7WLRhO4s3VnD3Gx/hHBRkpXHKmN4cPSyfqaMK6JOTHsPqRdpX3+T9Q3PF5CHc/cYqfvrScp66fgpwYC2C5n+M3lm1ja+fMqorS+3WFARxKCUpxKQh+Uwaks+lxwwBoGpXA3PXlPLMwo28vGQzTxZvwAzG9c9l+pH9OXvCAAqy0mJcucjeln1SCUDf3B6M7pPFh1t2srmi5qCe87iRBTz87lqqdnljd7J/Gk3pJjLTkjnl0D7cf9kkFv3oNF6+6Xi+9dnRAPz0peVM+fm/+dpj7zFn1TaamnRhnQTDFQ/PB7zxr7vOPwKA4nXeNBEHOvJ1/MhC6hsd89dqNtKOUlx2Q6GQMbZ/DmP753DjZ0bx0ZYdPD5/A8++t5GXFm9mVO8sbjhxBGcd2V9nIUkgVNbWM7a/dxLiX+Z+7G08wCQoGppHanKId1dt4+Qxvbuowu5N7wIJYFSfbH70hbHM+8Ep/OqCIwiZ8a2nF3HSXW/xpznrqOng9RAiftlR20BKUogzxvXdve1ABosB0lOSOHpoPu9+tK2ryuv2FAQJJD0lifMmDeSVm4/noSuL6Jubzm0zlzL1F29w75urqKip3/+TiHShw8KtgOtPHA7Aby46skue97iRBazcsoMNZVrHuCMUBAkoFDJOObQPz9wwhSevm8y4Abnc9dpKjrvzDX7x6gpKd+6KdYmSABoam1j6SSX9c9PJSPV6qdNTkuid7Z3U4A7wCmOAM8b1JWTw0Ltru6TW7k5BkMDMjGOG9+JPXz6aF78+lRNHF/LA26s54ZdvMmPWaho0i6P46JPt3oyjn1TU7rX9wqJBAOyqP/Dfv6EFmZwxvh/Pvb+JXQ3q+twfBYEAMG5ALvdeOpF/ffNEJg/vxf+9vIKz7pnNwo915oX4Y1uV1/K8KPzG3+ybnx3NA5dN5MTRhQf1/BcVDWJ7dT3/XLrloJ4nESgIZC8je2fx4JVF3H/pRMqr67h4xlx+9/pHWgtWutT60mrOvW8OAJdPGbLXvqSQMW1cP0IHOWPvcSMLGNCzB08Vbzio50kECgLZh5lxxvh+vPbNE/jMmN785vUPOfXXb/PGCv1nJV3j7jc+2n27eUygqyWFjPMmDeTdVdvYWK5B4/YoCKRNOekp/OHyIh6/djI9UpL48qPFfO/ZxWyvrot1aRLn+uXumfrEzyveL5g0EIBnF27y7TW6AwWB7NeUEb148aapXHXsUJ4q3sC5981h3baqWJclcawyfKryVccOPeguoPYMys/guBEFPL1wg66ob4eCQDokLTmJ2886jCeum0J5dR3n3Deb2at0wY4cmG076xhRmMntZx3m+2tdeNQgNpbXMGd1qe+vFa8UBNIpRw/L5x9fPY68zFSueHg+d7y4TAPJ0mnNizlFw2lj+5DbI0WDxu1QEEinDS3I5O9fOZYLiwbx0LtruXjGXMqqNG4gHdfQ5EiO0jre6SlJnDtxAC9/sJkPt+yIymvGGwWBHJCeGan8/NzxPHDZJJZvruSLf5zLNl2RLB3U0OhIDkXv7efrnxlFVnoytz2/FOc0VtCSgkAOyrRxfXn4qqNYV1rFJTMUBtIxDU1NJCdFb4nV/MxUvnHKKP6zppR3Nba1DwWBHLTjRhbwyFVHs6G8mssenKduItmv+kYX9cXlLzlmMP1z0/nhc0uoa9D0KZF8/UmY2TQzW2lmq8zse20cc6GZLTOzpWb2mJ/1iH+mjOjFg1ccxdptVVz24DxdayDtamhqitoYQbO05CR+PH0cH5dWM2PW6qi+dtD5FgRmlgTcC5wBjAUuMbOxLY4ZBXwfOM45dxjwDb/qEf9NHVXAjCuKWLV1J5c/NF/TWkubvDGC6AYBwKmH9ubM8X35zesfsfJTDRw387NFcDSwyjm3xjlXBzwBTG9xzLXAvc65cgDn3FYf65EoOHF0IQ9cPpEVn1Zy5cPzqaxVGMi+G
ppcTFbHMzN+dvZ4kkLG1x57T2NaYX7+JAYAkSfubgxvizQaGG1ms81srplN87EeiZLPjOnDvV+cyJJNFZxz72yNGcg+GhqjO1gcKS8zlXu/OJF126r41T9XxqSGoIn1YHEyMAo4CbgE+KOZ9Wx5kJldZ2bFZlZcUlIS5RLlQJx2WF8euuooNpTV8J1nFumUPdlLQ5MjKQZdQ80+O7YPFxQN4vH5G5h+z7sJ//vpZxBsAiInGh8Y3hZpIzDTOVfvnFsLfIgXDHtxzs1wzhU554oKCw9ujnKJnhNHF/L9M8fw+vKt/Pk/H8e6HAmQhkZHShSvI2jNd04/hP656SzaWMHf5q2nsckl7HxEfv4kFgCjzGyYmaUCFwMzWxzzHF5rADMrwOsqWuNjTRJlVx07lONG9uK2mUt5aoEu8RdPtK8jaE1eZipvf+dkjh3Rix8+t4QRP3iZi2fMTciV+XwLAudcA3Aj8BqwHHjKObfUzH5iZmeFD3sNKDWzZcCbwC3OOc0M1Y2YGQ9fdRSThuTx4xeWsrWydv8Pkm4vmlNMtCclKcT9l01i+pH9AZi/rowrH5nPWyu3suLTSiqqE+NkB4u3vrGioiJXXFwc6zKkk9aU7OT0385iWEEmL3x9KmnJSbEuSWJo3G2vcdFRg/jfz4/d/8FR9ItXV3D/W3tfY/Djsw7jkqMHk5oc6yHVg2NmC51zRa3ti++vTOLG8MIsrjl+OB9u2cnlD86naldDrEuSGKqP4VlD7fnO6Yfww88dyumH9dm97baZSxn9w1dYvHF7DCvzV3TmgRUBbjntEHqkJPGb1z/k9plLueuCI2JdksRIY0C6hloyM645fjjXHD8cgE3bazjuzjcAOOue2Xsd++o3jmdM35wuff26hibuf2s1jU1NXFA0iEH5GQDsqK2nsraB/rnpmHX9901BIFETChk3nTKK+sYmfv/GKiYP78V54aUEJXE458JjBMHvkBjQswfr7vwc/11fzhPzNzB79TY2ltcAMO2373DH9MO4fMrQ3cc759jV0MTqkp2M6Zuz31NkS3bs4t1VJfzo+aUMzs9g6SeVu/fd/cYqLps8mMsmD2Hab98B4Nrjh3Hr57q+O01BIFF30ymjmLO6lG8/s4j8zFROHtM71iVJFDWET9FMCWDXUFsmDM5jwuC83fcXb9zOWffM5n+fX8r9b61mbP9cNpZXs6KVaSvuvmQCXzi8H2+tLGF1yU4+f3h/fv7Kcp5//5O9jmsOgUuPGczJh/Tm+UWf8Ne56/nr3PUAjOmbzYVFg/Z5/q6gwWKJiR219d7ax6VVPHHdFCYNydv/g6RbqK1vZMz/vsp3p43hKyeNiHU5B2zpJxV87u53D+o5Lj5qEBOH5FGYlcZJhxRSWlVHQVYa4LUubn7ifd7fsJ17vjiBwwfuc61tp7Q3WKwWgcREdnoKj107mbPvnc01f1rAs185luGFWbEuS6KgLnyefjy1CFpzWP9c/vXNE8hIS2blp5UkhUL0z01nZO8szAznHPWNjrv//REPz17rtXwd5PRIYfqR/Zk8vNc+z9kcAuCNV9x9yYSofC0KAomZwuw0/nz10Zx73xxueuK/vHDjVF8GwiRYtlZ6E71FvunFq1F9sgFvLKElMyM12fj26Yfw7dMPiXZpnRL80Rrp1kYUZnHrmYeyZFMlry3dEutyJAo+2e4NtvZv5c1TYkNBIDF39oQBDMrvwQ1/Xcg1f1qQ8BOAdXebK7wg6JebHuNKpJmCQGIuNTnEU9dPYWivDF5fvpWfvLgs1iWJj6rrGgHITlfPdFAoCCQQ+uX24JWbTwDgkdnrmLNaC4x3V43h00djOQ217E1BIIHRIzWJhT88lYKsVK5+tJhlERfXSPehIAgeBYEESq+sNP5w+SRq6hu56pH5lGt1s26nMTwGFNIZYoGhIJDAmTQknwcum0R5dR0T7vgXCz8ui3VJ0oUaG70gCOJcQ4lKQSCBNG1cX370hcMAuOSP89RN1
I00twjUNRQcCgIJrMsnD+HFr08lNSnEmXe/w39Wa82i7qCxyREydPFggCgIJNDGDcjl8WsnA3D9X4qZv1bdRPGuMcYL18u+FAQSeOMH5vLKzcdTWdvAt59epAvO4pzXIlAQBImCQOLCof1y+P4ZY1hfVs2tzy1RGMSxoC5Kk8gUBBI3rp46jEuPGcxj89Yz7Psv89MXlykQ4lBDkyOkIAgUBYHEjeSkEHdMH8dF4cU5Hnx3LVc8PJ/K2voYVyad0eTUIggaBYHElVDI+MX5h7Pm/87kxpNH8s5H2zjxl29SUaMwiBcNGiwOHAWBxKVQyJvn/SfTD6O8up7LH5pHfXjBEwm2hsamuFivOJHopyFx7YopQ/m/c8azeGMF1/ypmKYmjRkE3Y7aBs08GjAKAol7Fx81iONHFfD2hyX87/NLYl2O7Mf6smoFQcAoCCTuhULGo186GoC/zVvPf9eXx7giaUtjk2PpJ5VkpCoIgkRBIN1CUsh45zsnA3DOfXMYf/trrNq6M8ZVSUvNq5OdemjvGFcikRQE0m0Mys/grvMPB7x+6DN+N4tFG7bHuCqJ9HFpNQCj+2bHuBKJpCCQbuWCokGsu/NzPHPDFPIzU7n2z8VsKKuOdVkSNmPWGgCG9MqMcSUSSUEg3VLR0HzuvngCFTX1XPSH/7B2W1WsSxLg7Q9LAOibo4Xrg0RBIN3WMcN78chVR1FWXcepv36bpxZsoK5B1xrE0pi+2Zx0SKEuKAsYBYF0a8eOLOC3Fx1JZmoS33l2MaN/+Arr1DqImbKqOvpkqzUQNAoC6famjevH/FtP5dJjBgNw12srY1xRYnLOUV5dR35WaqxLkRYUBJIQ0lOS+Nk547npMyN56YPN3PHisliXlHB+8/pH1Dc6+uWqRRA0CgJJKDd+ZhRj++Xw0Ltr+cE/Poh1OQlj8cbt3P3vjwA464j+Ma5GWvI1CMxsmpmtNLNVZva9do47z8ycmRX5WY9IanKIJ6+fzHkTB/LYvPW8umRzrEtKCMs3VwLwwGWT6JmhrqGg8S0IzCwJuBc4AxgLXGJmY1s5Lhu4GZjnVy0ikbLTU7jzvPGMG5DDDX99j7dWbo11Sd1eWZU3TfiJowtjXIm0xs8WwdHAKufcGudcHfAEML2V4+4AfgHU+liLyF5SkkLcea53FfJVjyygZMeuGFfUvZVX15GeEqJHalKsS5FW+BkEA4ANEfc3hrftZmYTgUHOuZd8rEOkVeMG5PLgFV5v5LTfzuJ9TUfhm7KqOvLUJRRYMRssNrMQ8GvgWx049jozKzaz4pKSEv+Lk4Rx6tg+PHndZNJTkvj204t0wZlPyhUEgeZnEGwCBkXcHxje1iwbGAe8ZWbrgMnAzNYGjJ1zM5xzRc65osJC9TFK1zpmeC9+Mv0wVm3dyTn3zaa2vjHWJXU7/16xlZQkXU0cVH4GwQJglJkNM7NU4GJgZvNO51yFc67AOTfUOTcUmAuc5Zwr9rEmkVadcmgfJg/PZ+knldz2/NJYl9OtVNZ6A8WLNlbEuBJpi29B4JxrAG4EXgOWA08555aa2U/M7Cy/XlfkQP316mM46ZBCnizewL+WbYl1Od3G5u3eeSA3nzIqxpVIW3wdI3DOveycG+2cG+Gc+1l424+cczNbOfYktQYklpKTQjxw2ST65abz7acXsSP8n6wcnNtmesuHHjM8P8aVSFt0ZbFIhPSUJO754kQqauq55enFGi84SCs/3cHcNWUATB7WK8bVSFsUBCItTBqSx4VFA3l16adc++dihcEB+rSiltN/OwuAmTceR0hTTweWVpAWacUvzz+CnhmpzJi1hvPun8MfLp/EwLyMWJcVeM455qwu5dIH90wUcO6EARw+sGcMq5L9UYtApA0/OPNQfnfxkaz4dAdTf/Emj89fT0OjrjNoS0V1PVc+smCvELh66jB+fdGRMaxKOsKcc7GuoVOKiopccbHGlCV6XvlgM1/523u77//rmycwqo8WXwdvM
rlb//EB5dX1ey0Hesvph3D+pIEUZqWpSyggzGyhc67ViT0VBCIdsPLTHVz20DxKduwiLyOFb512CIPyM5g0JI/05BBJIcMssd7wXnRcjHgAAAwiSURBVF3yKTf8deE+22fdcjKDe6kbLWgUBCJdZOHHZVz64Dxq6/ftInrhxqmMH5gbg6qir3hdGRf+4T80Ofj5ueP57Ng+9OyRQm1DE1lpGnoMIgWBSBdqaGzixcWb+e/6cjaU1/DGij3TWA/K78HLNx1PdnpKDCvsuOa//7ZaM8s3V1JRU8/EwXlsqaxlUH4Gs1dt2z0OMP/WU+itNYjjQntBoOgW6aTkpBBnTxjA2RMG4JyjycG8taV88Y/z2FBWw/jb/8kvzz8c5xznTxpEUkD6yMur6qhrbKJPTjrVdQ38a9kW7nxlBT1Sk/jVBUdw5KCeuwNhe3Udc1aX8tWIsZGWemenKQS6CbUIRLrQd59ZzJPFe2ZfL8xO4+9fOZbkJGNTeQ2ThuR1aCyhqcmxtrSKEYVZe23fXl3Hna+s4N8rtvLdaWM47bA+5KSn8MdZa5i56BM+2FTB1JEFXD11GIf2y+F3//6IXQ2N/P29PfM9jumbzYpPd7T6ugVZqYTM2LZzF03ht4arpw5j8cbtLFhXDsBh/XO4+ZRRHDOsF7kZ8dHyEXUNiURNU5Nj+aeV/HPpFp5YsJ4tlXsveHPEwFxOHtObi48aTN82FnFfU7KT219YxqwPSxjQswdfPXkE26vr+eeyLSxqZc2E4QWZrIk4Y6cteRkplFd702aM7J1FQVYqXziiP+MH5HLWPbP3OrZvTjonjC7gs2P78tmxfTr65UuAKQhEYmTpJxVc9uA8yqvrGZyfwfqy6t37fnr2OOoamvjJi8v4zJjevLe+nO3V+5/f6JbTD2FEYRaPzF7LvLVlu7dfOWUIVxw7lDeWb2Xe2lJG9M6id3Y6q7bu5MTRhUwb15d126p4b30550wYsE/LpLqugeq6RuauKeXzh2uB+e5GQSASQzt3NZBkRo/UJEp37uKBt1fz+vKte513H+ncCQM4f9JAjh1ZwJzV21iyqQLnYEy/nH3W/C2rqiMjNYn6xqa4GaCW2FAQiATMx6VVnHjXW5w2tg9XHTeUvjnpDMjrQVqy1vQVf+isIZGAGdIrk3V3fi7WZYgAmmtIRCThKQhERBKcgkBEJMEpCEREEpyCQEQkwSkIREQSnIJARCTBKQhERBJc3F1ZbGYlwHagIrwpdz+3mz8XANs68VKRz9XR/S237a+2eKmztW2qs+vqTOlkjaozunW2tS/e6hzinCtsZb+3MEW8fQAzOno74nPxgb5GR/e33NZd6mxjm+rsojo7W6PqjG6dbe2L1zpb+4jXrqEXOnE7ctuBvkZH97fc1l3qbGt/Z6hO1dnR/UGrs6198VrnPuKua+hAmVmxa2PCpSBRnV0rHuqMhxpBdXa1INUZry2CAzEj1gV0kOrsWvFQZzzUCKqzqwWmzoRpEYiISOsSqUUgIiKtUBCIiCQ4BYGISIJL+CAws+PN7AEze9DM5sS6nraYWcjMfmZmvzezK2NdT1vM7CQzeyf8PT0p1vW0x8wyzazYzD4f61raYmaHhr+Xz5jZV2JdT1vM7Gwz+6OZPWlmp8W6nraY2XAze8jMnol1LS2Ffx//FP4+XhrN147rIDCzh81sq5ktabF9mpmtNLNVZva99p7DOfeOc+4G4EXgT0GtE5gODATqgY0BrtMBO4H0gNcJ8F3gKT9qDNfTFb+fy8O/nxcCxwW4zuecc9cCNwAXBbjONc65q/2orzWdrPlc4Jnw9/GsaNUIxOeVxRFXzJ0ATASWRGxLAlYDw4FUYBEwFhiP92Yf+dE74nFPAdlBrRP4HnB9+LHPBLjOUPhxfYC/BbjOzwIXA1cBnw9qneHHnAW8AnwxyHWGH/crYGIc1OnL39BB1vx94MjwMY9Fo77mj7hevN45N8vMhrbYfDSwyjm3BsDMngCmO
+d+DrTaBWBmg4EK59yOoNZpZhuBuvDdxqDWGaEcSAtqneFuq0y8P8AaM3vZOdcUtDrDzzMTmGlmLwGPdWWNXVWnmRlwJ/CKc+69rq6xq+qMts7UjNeCHgi8T5R7a+I6CNowANgQcX8jcMx+HnM18IhvFbWus3X+Hfi9mR0PzPKzsBY6VaeZnQucDvQE7vG3tL10qk7n3K0AZnYVsK2rQ6Adnf1+noTXZZAGvOxrZXvr7O/n14FTgVwzG+mce8DP4iJ09vvZC/gZMMHMvh8OjGhrq+a7gXvM7HMc+BQUB6Q7BkGnOedui3UN++Ocq8YLrEBzzv0dL7TignPu0VjX0B7n3FvAWzEuY7+cc3fjvZEFmnOuFG8cI3Ccc1XAl2Lx2nE9WNyGTcCgiPsDw9uCRnV2LdXZtVSnfwJXc3cMggXAKDMbZmapeAOCM2NcU2tUZ9dSnV1LdfoneDVHc2TahxH5x4HN7Dml8urw9jOBD/FG5m9VnapTdapO1dz2hyadExFJcN2xa0hERDpBQSAikuAUBCIiCU5BICKS4BQEIiIJTkEgIpLgFATiOzPbGYXXOKuDU0935WueZGbHHsDjJpjZQ+HbV5lZNOdkapOZDW05XXIrxxSa2avRqkmiQ0EgccPMktra55yb6Zy704fXbG8+rpOATgcB8APiYF6e1jjnSoDNZubL2ggSGwoCiSozu8XMFpjZYjP7ccT258xsoZktNbPrIrbvNLNfmdkiYIqZrTOzH5vZe2b2gZmNCR+3+z9rM3vUzO42szlmtsbMzg9vD5nZfWa2wsz+ZWYvN+9rUeNbZvZbMysGbjazL5jZPDP7r5m9bmZ9wlML3wB808zeN2+lu0Izezb89S1o7c3SzLKBw51zi1rZN9TM3gh/b/4dnh4dMxthZnPDX+9PW2thmbe61UtmtsjMlpjZReHtR4W/D4vMbL6ZZYdf553w9/C91lo1ZpZkZndF/Kyuj9j9HBDVFbTEZ7G+tFkf3f8D2Bn+fBowAzC8f0JeBE4I78sPf+4BLAF6he874MKI51oHfD18+6vAg+HbVwH3hG8/Cjwdfo2xeHO/A5yPN5VzCOiLt2bC+a3U+xZwX8T9PNh9Ff41wK/Ct28Hvh1x3GPA1PDtwcDyVp77ZODZiPuRdb8AXBm+/WXgufDtF4FLwrdvaP5+tnje84A/RtzPxVv0ZA1wVHhbDt6MwxlAenjbKKA4fHso4QVUgOuAH4ZvpwHFwLDw/QHAB7H+vdJH131oGmqJptPCH/8N38/CeyOaBdxkZueEtw8Kby/FW4Tn2RbP0zzN9UK8ufpb85zz1hhYZmZ9wtumAk+Ht39qZm+2U+uTEbcHAk+aWT+8N9e1bTzmVGCsmTXfzzGzLOdc5H/w/YCSNh4/JeLr+Qvwy4jtZ4dvPwb8v1Ye+wHwKzP7BfCic+4dMxsPbHbOLQBwzlWC13rAm/f+SLzv7+hWnu804PCIFlMu3s9kLbAV6N/G1yBxSEEg0WTAz51zf9hro7f4yqnAFOdctZm9hbfmMUCtc67limy7wp8baft3eFfEbWvjmPZURdz+PfBr59zMcK23t/GYEDDZOVfbzvPWsOdr6zLOuQ/NbCLeZGY/NbN/A/9o4/BvAluAI/Bqbq1ew2t5vdbKvnS8r0O6CY0RSDS9BnzZzLIAzGyAmfXG+2+zPBwCY4DJPr3+bOC88FhBH7zB3o7IZc988VdGbN8BZEfc/yfeSl0AhP/jbmk5MLKN15mDNyUxeH3w74Rvz8Xr+iFi/17MrD9Q7Zz7K3AX3jq5K4F+ZnZU+Jjs8OB3Ll5LoQm4HG8N3ZZeA75iZinhx44OtyTAa0G0e3aRxBcFgUSNc+6feF0b/zGzD4Bn8N5IXwWSzWw53rq3c30q4Vm8qYCXAX8F3gMqOvC424GnzWwhsC1i+wvAOc2DxcBNQFF4cHUZrayE5ZxbgbecY3bLfXgh8iUzW4z3Bn1zePs3gP8Jbx/ZRs3jgflm9j5wG
/BT51wdcBHeEqeLgH/h/Td/H3BleNsY9m79NHsQ7/v0XviU0j+wp/V1MvBSK4+ROKVpqCWhNPfZm7d27XzgOOfcp1Gu4ZvADufcgx08PgOocc45M7sYb+B4uq9Ftl/PLLwF4stjVYN0LY0RSKJ50cx64g363hHtEAi7H7igE8dPwhvcNWA73hlFMWFmhXjjJQqBbkQtAhGRBKcxAhGRBKcgEBFJcAoCEZEEpyAQEUlwCgIRkQSnIBARSXD/H+euULpz5E4yAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.001...\n", "Train for 275 steps, validate for 122 steps\n", "Epoch 1/1024\n", "275/275 [==============================] - 10s 38ms/step - loss: 0.3674 - accuracy: 0.8285 - val_loss: 0.2957 - val_accuracy: 0.8624\n", "Epoch 2/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3157 - accuracy: 0.8549 - val_loss: 0.2962 - val_accuracy: 0.8652\n", "Epoch 3/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.3128 - accuracy: 0.8558\n", "Epoch 00003: Reducing Max LR on Plateau: new max lr will be 0.0005 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3127 - accuracy: 0.8559 - val_loss: 0.2994 - val_accuracy: 0.8621\n", "Epoch 4/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3086 - accuracy: 0.8574 - val_loss: 0.2951 - val_accuracy: 0.8652\n", "Epoch 5/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3078 - accuracy: 0.8586 - val_loss: 0.2953 - val_accuracy: 0.8654\n", "Epoch 6/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3057 - accuracy: 0.8595 - val_loss: 0.2933 - val_accuracy: 0.8659\n", "Epoch 7/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3045 - accuracy: 0.8595 - val_loss: 0.2928 - val_accuracy: 0.8634\n", "Epoch 8/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3033 - accuracy: 0.8605 - 
val_loss: 0.2927 - val_accuracy: 0.8649\n", "Epoch 9/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3037 - accuracy: 0.8605 - val_loss: 0.2931 - val_accuracy: 0.8624\n", "Epoch 10/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.3016 - accuracy: 0.8612\n", "Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.00025 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3015 - accuracy: 0.8611 - val_loss: 0.2931 - val_accuracy: 0.8659\n", "Epoch 11/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2993 - accuracy: 0.8624 - val_loss: 0.2924 - val_accuracy: 0.8641\n", "Epoch 12/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2986 - accuracy: 0.8625 - val_loss: 0.2925 - val_accuracy: 0.8636\n", "Epoch 13/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.2982 - accuracy: 0.8636\n", "Epoch 00013: Reducing Max LR on Plateau: new max lr will be 0.000125 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2982 - accuracy: 0.8636 - val_loss: 0.2926 - val_accuracy: 0.8634\n", "Epoch 14/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2958 - accuracy: 0.8641 - val_loss: 0.2923 - val_accuracy: 0.8636\n", "Epoch 15/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2950 - accuracy: 0.8642 - val_loss: 0.2920 - val_accuracy: 0.8654\n", "Epoch 16/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2944 - accuracy: 0.8645 - val_loss: 0.2938 - val_accuracy: 0.8608\n", "Epoch 17/1024\n", "272/275 [============================>.] 
- ETA: 0s - loss: 0.2940 - accuracy: 0.8640\n", "Epoch 00017: Reducing Max LR on Plateau: new max lr will be 6.25e-05 (if not early_stopping).\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2943 - accuracy: 0.8638 - val_loss: 0.2924 - val_accuracy: 0.8641\n", "Epoch 18/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2929 - accuracy: 0.8651 - val_loss: 0.2919 - val_accuracy: 0.8647\n", "Epoch 19/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2929 - accuracy: 0.8644 - val_loss: 0.2926 - val_accuracy: 0.8649\n", "Epoch 20/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2924 - accuracy: 0.8651 - val_loss: 0.2917 - val_accuracy: 0.8647\n", "Epoch 21/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2921 - accuracy: 0.8658 - val_loss: 0.2918 - val_accuracy: 0.8652\n", "Epoch 22/1024\n", "274/275 [============================>.] - ETA: 0s - loss: 0.2908 - accuracy: 0.8667\n", "Epoch 00022: Reducing Max LR on Plateau: new max lr will be 3.125e-05 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2912 - accuracy: 0.8665 - val_loss: 0.2919 - val_accuracy: 0.8652\n", "Epoch 23/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2909 - accuracy: 0.8663 - val_loss: 0.2920 - val_accuracy: 0.8652\n", "Epoch 24/1024\n", "272/275 [============================>.] - ETA: 0s - loss: 0.2905 - accuracy: 0.8667\n", "Epoch 00024: Reducing Max LR on Plateau: new max lr will be 1.5625e-05 (if not early_stopping).\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2904 - accuracy: 0.8670 - val_loss: 0.2921 - val_accuracy: 0.8649\n", "Epoch 25/1024\n", "269/275 [============================>.] 
- ETA: 0s - loss: 0.2899 - accuracy: 0.8668Restoring model weights from the end of the best epoch.\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2900 - accuracy: 0.8666 - val_loss: 0.2921 - val_accuracy: 0.8649\n", "Epoch 00025: early stopping\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(1e-3)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " <=50K 0.89 0.94 0.91 3013\n", " >50K 0.74 0.62 0.67 881\n", "\n", " accuracy 0.86 3894\n", " macro avg 0.82 0.78 0.79 3894\n", "weighted avg 0.86 0.86 0.86 3894\n", "\n" ] }, { "data": { "text/plain": [ "array([[2824, 189],\n", " [ 338, 543]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate Model on Unseen Test Data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# download test dataset\n", "urllib.request.urlretrieve('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv', \n", " '/tmp/test.csv')\n", "\n", "\n", "test_df = pd.read_csv('/tmp/test.csv')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryclass
031Private16908511th7Married-civ-spouseSalesWifeWhiteFemale0020United-States<=50K
117Self-emp-not-inc22620312th8Never-marriedSalesOwn-childWhiteMale0045United-States<=50K
247Private54260Assoc-voc11Married-civ-spouseExec-managerialHusbandWhiteMale0188760United-States>50K
321Private176262Some-college10Never-marriedExec-managerialOwn-childWhiteFemale0030United-States<=50K
417Private24118512th8Never-marriedProf-specialtyOwn-childWhiteMale0020United-States<=50K
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 31 Private 169085 11th 7 \n", "1 17 Self-emp-not-inc 226203 12th 8 \n", "2 47 Private 54260 Assoc-voc 11 \n", "3 21 Private 176262 Some-college 10 \n", "4 17 Private 241185 12th 8 \n", "\n", " marital-status occupation relationship race sex \\\n", "0 Married-civ-spouse Sales Wife White Female \n", "1 Never-married Sales Own-child White Male \n", "2 Married-civ-spouse Exec-managerial Husband White Male \n", "3 Never-married Exec-managerial Own-child White Female \n", "4 Never-married Prof-specialty Own-child White Male \n", "\n", " capital-gain capital-loss hours-per-week native-country class \n", "0 0 0 20 United-States <=50K \n", "1 0 0 45 United-States <=50K \n", "2 0 1887 60 United-States >50K \n", "3 0 0 30 United-States <=50K \n", "4 0 0 20 United-States <=50K " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `learner.evaluate` method is just an alias for `learner.validate`. By default, both evaluate the model on\n", "`learner.val_data`, but both can accept a test set as an argument in the form of a `TabularDataset`.\n", "\n", "We use `learner.evaluate` here to compute test set metrics." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing test: 9769 rows x 15 columns\n", " precision recall f1-score support\n", "\n", " <=50K 0.88 0.94 0.91 7451\n", " >50K 0.76 0.61 0.67 2318\n", "\n", " accuracy 0.86 9769\n", " macro avg 0.82 0.77 0.79 9769\n", "weighted avg 0.85 0.86 0.85 9769\n", "\n" ] }, { "data": { "text/plain": [ "array([[6996, 455],\n", " [ 914, 1404]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.evaluate(preproc.preprocess_test(test_df), class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Generating Test Results\n", "\n", "Let's generate a DataFrame showing the test set predictions for each instance:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['<=50K', '>50K']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preproc.get_classes()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)\n", "preds = predictor.predict(test_df)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df = test_df.copy()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "df['predicted_class'] = preds" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryclasspredicted_class
031Private16908511th7Married-civ-spouseSalesWifeWhiteFemale0020United-States<=50K<=50K
117Self-emp-not-inc22620312th8Never-marriedSalesOwn-childWhiteMale0045United-States<=50K<=50K
247Private54260Assoc-voc11Married-civ-spouseExec-managerialHusbandWhiteMale0188760United-States>50K>50K
321Private176262Some-college10Never-marriedExec-managerialOwn-childWhiteFemale0030United-States<=50K<=50K
417Private24118512th8Never-marriedProf-specialtyOwn-childWhiteMale0020United-States<=50K<=50K
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 31 Private 169085 11th 7 \n", "1 17 Self-emp-not-inc 226203 12th 8 \n", "2 47 Private 54260 Assoc-voc 11 \n", "3 21 Private 176262 Some-college 10 \n", "4 17 Private 241185 12th 8 \n", "\n", " marital-status occupation relationship race sex \\\n", "0 Married-civ-spouse Sales Wife White Female \n", "1 Never-married Sales Own-child White Male \n", "2 Married-civ-spouse Exec-managerial Husband White Male \n", "3 Never-married Exec-managerial Own-child White Female \n", "4 Never-married Prof-specialty Own-child White Male \n", "\n", " capital-gain capital-loss hours-per-week native-country class \\\n", "0 0 0 20 United-States <=50K \n", "1 0 0 45 United-States <=50K \n", "2 0 1887 60 United-States >50K \n", "3 0 0 30 United-States <=50K \n", "4 0 0 20 United-States <=50K \n", "\n", " predicted_class \n", "0 <=50K \n", "1 <=50K \n", "2 >50K \n", "3 <=50K \n", "4 <=50K " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }