{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "\n", "# Expose only GPU 0 (PCI bus ordering); done before ktrain is imported.\n", "os.environ.update({\n", "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n", "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n", "})\n", "\n", "import ktrain\n", "from ktrain import vision as vis" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "DATADIR = 'data/planet'\n", "ORIGINAL_DATA = f'{DATADIR}/train_v2.csv'\n", "CONVERTED_DATA = f'{DATADIR}/train_v2-CONVERTED.csv'\n", "\n", "# The returned labels are reused in the next cell as label_columns.\n", "# NOTE(review): presumably rewrites the 'tags' column into per-label columns in CONVERTED_DATA - confirm against ktrain.\n", "labels = vis.preprocess_csv(\n", "    ORIGINAL_DATA,\n", "    CONVERTED_DATA,\n", "    x_col='image_name',\n", "    y_col='tags',\n", "    suffix='.jpg',\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 40479 images belonging to 1 classes.\n", "Found 36357 validated image filenames.\n", "Found 4122 validated image filenames.\n" ] } ], "source": [ "# Horizontal and vertical flips enabled; passed to images_from_csv as data_aug.\n", "augmentation = vis.get_data_aug(horizontal_flip=True, vertical_flip=True)\n", "\n", "# No validation file is given; the recorded output shows 4122 of the 40479 rows used for validation.\n", "trn, val, preproc = vis.images_from_csv(\n", "    CONVERTED_DATA,\n", "    'image_name',\n", "    directory=f'{DATADIR}/train-jpg',\n", "    val_filepath=None,\n", "    label_columns=labels,\n", "    data_aug=augmentation,\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The normalization scheme has been changed for use with a pretrained_resnet50 model. If you decide to use a different model, please reload your dataset with a ktrain.vision.data.images_from* function.\n", "\n", "Is Multi-Label? 
True\n", "pretrained_resnet50 model created.\n" ] } ], "source": [ "model = vis.image_classifier('pretrained_resnet50', trn, val_data=val)\n", "\n", "# Options forwarded unchanged to ktrain.get_learner.\n", "learner_kwargs = dict(batch_size=64, workers=8, use_multiprocessing=False)\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, **learner_kwargs)\n", "\n", "# NOTE(review): freeze(2) presumably limits which layers stay trainable - confirm against the ktrain docs.\n", "learner.freeze(2)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "568\n", "5\n", "2840\n", "Epoch 1/5\n", "568/568 [==============================] - 213s 375ms/step - loss: 0.6386 - acc: 0.7609\n", "Epoch 2/5\n", "568/568 [==============================] - 201s 353ms/step - loss: 0.2260 - acc: 0.9303\n", "Epoch 3/5\n", "380/568 [===================>..........] - ETA: 1:05 - loss: 0.3707 - acc: 0.9044\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3xV9f3H8dcngzASwgozQNiITAkgoAVnUSm4FUdFqYq1jta22q1Wf9ra2tYtjtq6B4iIIm4BmWHvvcIMCYSQkP39/XGvNGASEsjJyc19Px+P+8g94977ziXcz/2e7znfrznnEBGR8BXhdwAREfGXCoGISJhTIRARCXMqBCIiYU6FQEQkzKkQiIiEuSi/A1RWs2bNXFJSkt8xRERCysKFC/c55xJK2xZyhSApKYmUlBS/Y4iIhBQz21rWNh0aEhEJcyoEIiJhToVARCTMqRCIiIQ5FQIRkTCnQiAiEubCqhBs2JvFwdwCv2OIiNQoYVMIPlq2i3Mfn0Hv+z8ldX+O33FERGqMsCkE3VrGHrl/2bOzKSgq9jGNiEjNETaFoHPzOJ4Y0w+APQfz+Mu0NT4nEhGpGcKmEACM6tOamb8+C4A35m8jv1CtAhGRsCoEAG2b1OfZa08jJ7+In76+iDkb03l02hq2pmf7HU1ExBchN+hcVbigVytG923NB0t28vnqPQB8vHwXX94zjKjIsKuNIhLmPPvUM7OXzWyvma0oY/u1ZrbMzJab2Wwz6+NVltL886q+vDpuIGYwIKkx2zJyuHfi8uqMICJSI3j59fcVYEQ52zcDw5xzvYA/AxM8zPI9ZsaZXRLY/MhFvH3LYE5p1ZCJi1L50ZOzOJCTX51RRER85VkhcM7NADLK2T7bObc/uDgXSPQqy/FERBjv/3QIZ3dvzvIdmdzx5mJy8gv9iiMiUq1qygHxccC0sjaa2S1mlmJmKWlpaZ4EqBsdyctjB3BpvzbMXL+PBz9c5cnriIjUNL53FpvZWQQKwRll7eOcm0Dw0FFycrLzMs/jV/WlWVwME2ZsIq5uFHef25UGMb6/TSIinvG1RWBmvYEXgdHOuXQ/s5T0i/O60jsxnhdmbuZHT83SVcgiUqv5VgjMrB0wCbjeObfOrxylqRsdyQe3D+WPI3uwKS2bu99e4nckERHPeHn66JvAHKCbmaWa2TgzG29m44O7/BFoCjxjZkvMrEbNSG9mjB2SRJtG9fho2S5W7zrodyQREU+Yc54ecq9yycnJLiWl+mrG9owchj32FcUOJt42hP7tG1fba4uIVBUzW+icSy5tW005a6jGatukPv+5aSAQGLV01U61DESkdlEhqIAzuyQcGbn0l+8uJUuT24hILaJCUEGj+rTmr5f3ZtWug/xhcqmjZoiIhCQVgkq4Mrktd5zdmclLdvLVmr1+xxERqRIqBJV0x9ld6NI8lvs/XEleYZHfcURETpoKQSXViYrg9yN7sDU9h4kLd/gdR0TkpKkQnIAfdGlGv3aN+Ofn69QqEJGQp0JwAsyMe87rxt6sPD5YvNPvOCIiJ0WF4AQN7dyUHq0a8szXG8gtUKtAREKXCsEJMjPuvaA7W9JzeHP+Nr/jiIicMBWCkzCsawIDOzThqS83kJmji8xEJDSpEJykP47swf6cfB7/bK3fUURETogKwUnq2SaeK5Pb8ub87ezKPOx3HBGRSlMhqAK3De+Ew/HUlxv8jiIiUmkqBFWgfdMGXJHcljfmb2Plzky/44iIVIoKQRW594fdia0TxdNfqVUgIqFFhaCKxNeP5oYhSUxbsZv1e7L8jiMiUmEqBFXopjM6UC86Uq0CEQkpKgRVqEmDOlx3enumLN3Jln3ZfscREakQFYIq9pMzOxAdGcEzX6tVICKhQYWgijWPq8tVA9oyefFO9mbl+h1HROS4VAg8cOPQDuQXFfPaXI1BJCI1nwqBBzo0a8A53Zvz+tytGplURGo8FQKPjDujA+n
Z+UxZovkKRKRmUyHwyOBOTeneMo6Xv92Mc87vOCIiZVIh8IiZcd3p7VmzO4s1u3WBmYjUXCoEHrqgZ0siI4yPlu3yO4qISJlUCDzUNDaGwR2bMnXZTh0eEpEay7NCYGYvm9leM1tRxnYzsyfMbIOZLTOz07zK4qeL+7VhS3oOczam+x1FRKRUXrYIXgFGlLP9AqBL8HYL8KyHWXwzsncrmjSow79nb/E7iohIqTwrBM65GUBGObuMBv7rAuYCjcyslVd5/FI3OpKrB7Tli9V72J6R43ccEZHv8bOPoA2wvcRyanDd95jZLWaWYmYpaWlp1RKuKl0/uD2REcaLMzf5HUVE5HtCorPYOTfBOZfsnEtOSEjwO06ltYqvx4W9WjFp8Q7SsvL8jiMichQ/C8EOoG2J5cTgulrp9rM6k1tQxBNfrPc7iojIUfwsBFOAHwfPHjodyHTO1doT7ru2iOPy/om8nbJdo5KKSI3i5emjbwJzgG5mlmpm48xsvJmND+7yMbAJ2AC8APzUqyw1xbgzOpBfWMwnK3b7HUVE5Igor57YOTfmONsdcLtXr18TdW4eR7cWcfx3zlbGDGxHdGRIdNGISC2nT6Jqdve5Xdiw9xCPTlvjdxQREUCFoNqdfUpzAA07ISI1hgpBNYuJiuSxy3uz52Aei7cf8DuOiIgKgR9G9GxJ3egI3lmw/fg7i4h4TIXAB3F1o7mkXxsmL9lBZk6B33FEJMypEPjkmoHtyS0oZupyTWUpIv5SIfBJzzYN6dI8lokLU/2OIiJhToXAJ2bGZf0TWbTtAJv3ZfsdR0TCmAqBjy7u24YIg/cXqVUgIv5RIfBRy/i6nN6xKVOW6poCEfGPCoHPRvdtzZb0HJbvyPQ7ioiEKRUCn404tRXRkcaUJTp7SET8oULgs/j60Qzr2pypy3ZRXKzDQyJS/VQIaoBRfVuz+2AuC7aUN8WziIg3VAhqgHO6N6dOZASfr97jdxQRCUMqBDVAg5goBnVswpdr9vodRUTCkApBDXF29+ZsTMtmY9ohv6OISJhRIaghLurVKnhx2Q6/o4hImFEhqCGaN6zL0M7N+GDpDl1cJiLVSoWgBrm4bxu2Zxxm4db9fkcRkTCiQlCD/LBnSxrUieRtTVgjItVIhaAGiY2JYnS/Nny4bKcmrBGRaqNCUMNcM7AduQXFTFqsEUlFpHqoENQwPdvE06dtI96Yt02dxiJSLVQIaqCrB7Rl/d5DLNp2wO8oIhIGVAhqoB/1aU1c3ShenbPF7ygiEgZUCGqg2JgoLurVik9W7mZ3Zq7fcUSklvO0EJjZCDNba2YbzOy+Ura3M7OvzGyxmS0zswu9zBNKbh3WibzCYt6Yv83vKCJSy3lWCMwsEngauADoAYwxsx7H7PZ74B3nXD/gauAZr/KEmg7NGjCsawJvL9hGYVGx33FEpBbzskUwENjgnNvknMsH3gJGH7OPAxoG78cDmqarhDED27HnYB6zNuzzO4qI1GJeFoI2QMlLZFOD60q6H7jOzFKBj4E7SnsiM7vFzFLMLCUtLc2LrDXSsK4J1IuO5IvVGp5aRLzjd2fxGOAV51wicCHwqpl9L5NzboJzLtk5l5yQkFDtIf1SNzqSM7o044vVe3RNgYh4xstCsANoW2I5MbiupHHAOwDOuTlAXaCZh5lCzohTW7IzM5f5mzWNpYh4w8tCsADoYmYdzKwOgc7gKcfssw04B8DMTiFQCMLn2E8FXNCrJXExURqITkQ841khcM4VAj8DpgOrCZwdtNLMHjSzUcHd7gFuNrOlwJvAWKdjIEepXyeKkX1aM23FbnLyC/2OIyK1UJSXT+6c+5hAJ3DJdX8scX8VMNTLDLXBxX1b8+b8bXy2ag+j+x7b3y4icnL87iyWChiQ1IRW8XWZskRn14pI1VMhCAEREcaoPq35Zl0a+7Pz/Y4jIrWMCkGIGNW3NYXFjo9X7PI7iojUMioEIaJHq4Z0bh7LBzo8JCJVTIUgRJg
Zo/u0Zv7mDLZn5PgdR0RqkQoVAjO7y8waWsBLZrbIzM73Opwc7dL+iZjBpEXHXpcnInLiKtoiuMk5dxA4H2gMXA886lkqKVWbRvUY0L4J09RPICJVqKKFwII/LwRedc6tLLFOqtEFvVqyZncWG/Ye8juKiNQSFS0EC83sUwKFYLqZxQEaJN8HF/ZqRYTBB0t0eEhEqkZFC8E44D5ggHMuB4gGbvQslZSpRcO6nNklgYkLUykq1mgcInLyKloIBgNrnXMHzOw6AjOLZXoXS8pzef9Edmbm8q0mrBGRKlDRQvAskGNmfQgMFLcR+K9nqaRc55/agkb1o3k7RSOSisjJq2ghKAyOCjoaeMo59zQQ510sKU9MVCSX9kvk05W7ydCQEyJykipaCLLM7DcEThv9KDiLWLR3seR4rkhOpKDIMUWdxiJykipaCK4C8ghcT7CbwGxjj3mWSo7rlFYN6dGqIZMWqxCIyMmpUCEIfvi/DsSb2Ugg1zmnPgKfXdY/kWWpmazfk+V3FBEJYRUdYuJKYD5wBXAlMM/MLvcymBzf6L6tiYww3luU6ncUEQlhFT009DsC1xDc4Jz7MTAQ+IN3saQimsXGcO4pzXlr/nYO5hb4HUdEQlRFC0GEc25vieX0SjxWPHT7WZ3JPFzAk1+s9zuKiISoin6Yf2Jm081srJmNBT7imLmIxR+9Extxfo8WTFq0g9yCIr/jiEgIqmhn8a+ACUDv4G2Cc+5eL4NJxd0wJIn07Hwm6wwiETkBURXd0Tk3EZjoYRY5QUM6NaVL81jeX7yDqwe28zuOiISYclsEZpZlZgdLuWWZ2cHqCinlMzNG9m7N/C0ZbEzT8NQiUjnlFgLnXJxzrmEptzjnXMPqCinHd82gdtSJjOCFGZv8jiIiIUZn/tQSCXExXJnclomLUtmdmet3HBEJISoEtcjNZ3akqNjx8reb/Y4iIiFEhaAWade0PiN7t+b1uVvJzNEFZiJSMSoEtcz4YZ3Izi/itXlb/Y4iIiHC00JgZiPMbK2ZbTCz+8rY50ozW2VmK83sDS/zhIMerRsyrGsCL8/arAvMRKRCPCsEZhYJPA1cAPQAxphZj2P26QL8BhjqnDsVuNurPOHktuGdSM/O592FGoxORI7PyxbBQGCDc26Tcy4feIvADGcl3Qw87ZzbD3DMeEZyggZ1aEK/do2YMGMjhUXFfscRkRrOy0LQBig5qW5qcF1JXYGuZvatmc01sxGlPZGZ3WJmKWaWkpaW5lHc2sPMGD+sE9szDvPWAs1rLCLl87uzOAroAgwHxgAvmFmjY3dyzk1wziU755ITEhKqOWJoOu+UFvRr14invtxAgVoFIlIOLwvBDqBtieXE4LqSUoEpzrkC59xmYB2BwiAnKSLCuH14Z3YfzGX6yt1+xxGRGszLQrAA6GJmHcysDnA1MOWYfSYTaA1gZs0IHCrSGAlV5OzuzUlsXI/X5upUUhEpm2eFwDlXCPwMmA6sBt5xzq00swfNbFRwt+lAupmtAr4CfuWcS/cqU7iJiDCuO709czdlaF5jESmTOef8zlApycnJLiUlxe8YIWPfoTwGP/IFo/u24W9X9PE7joj4xMwWOueSS9vmd2exeKxZbAzjzujIewtTmbFOZ1yJyPepEISBO8/pTLsm9Xls+lqKi0OrBSgi3lMhCAP160Rx49Aklu/IZOIiXW0sIkdTIQgTY4ck0atNPE9+uUFXG4vIUVQIwoSZccfZndmWkcO0FbquQET+R4UgjJx7SguSmtbnldlb/I4iIjWICkEY+e66goVb9zNnoy7XEJEAFYIwc/XAdsTFRPHwx6sItWtIRMQbKgRhJjYmil+N6MaKHQd5WyOTiggqBGHpukHtGdq5KQ98uIp1GnpCJOypEIShiAjj0Ut7Exlh/PLdpbrITCTMqRCEqbZN6vPg6FNZlprJpMXHjg4uIuFEhSCMXdy3DcntG3P/lJWkH8rzO46I+ESFIIxFRBiPXtaLnPx
Cnp+haSBEwpUKQZjr3DyOkb1b8+a8bRzOL/I7joj4QIVAuHZQO7LyCnniy/W6tkAkDKkQCAM7NOHK5ESe/Xoj//hsnd9xRKSaRfkdQPxnFjid1Dl44ssN9E5sxLk9WvgdS0SqiVoEAgQ6jh++pBedm8fywNSVHMjJ9zuSiFQTFQI5ok5UBH+5rDe7M3O55x1daCYSLlQI5Cj92zfmdxeewhdr9jJ1+S6/44hINVAhkO/58eAkOjeP5Z+fryMrt8DvOCLiMRUC+Z6ICOPXP+zGprRsbnh5PgWa2lLEd1vTs8nM8eaLmQqBlOr8U1vy0MU9WbTtAK/P3ep3HJGwd94/ZvDMNxs8eW4VAinTtYPaMaRTU/46fS3bM3L8jiMStpxz5BcWUzcq0pPnVyGQMpkZf728NwDXvzSPXZmHfU4kEp7yCgOHZ2OivfnIViGQciU2rs9/bxpI+qF8rn1xHjn5hX5HEgk7RwpBKLYIzGyEma01sw1mdl85+11mZs7Mkr3MIycmOakJz/+4P5v3ZXPPO0s1HpFINcsrDAwIGRMVYi0CM4sEngYuAHoAY8ysRyn7xQF3AfO8yiInb0inZvzy/G5MW7Gb1+dt8zuOSFjJK/iuRRBihQAYCGxwzm1yzuUDbwGjS9nvz8BfgFwPs0gVuG1YJ87o3IzfT17B+FcXkpalyWxEqsOrwTP3YqJD79BQG2B7ieXU4LojzOw0oK1z7iMPc0gViYgwXhqbzG3DO/HJyt2c949v2LIv2+9YIrWWc45DeYV8vHwXEQYDkhp78jq+dRabWQTwOHBPBfa9xcxSzCwlLS3N+3BSppioSO4d0Z13bh3MgZwCrnh+zpHjlyJSdTKy8+n358/o+afppO4/zG8vPIVW8fU8eS0vC8EOoG2J5cTguu/EAT2Br81sC3A6MKW0DmPn3ATnXLJzLjkhIcHDyFJRAzs04dcjupGWlcej09b4HUek1nl02moOlLiSuH97b1oD4O18BAuALmbWgUABuBq45ruNzrlMoNl3y2b2NfBL51yKh5mkCv10eGf2Hszj399uoWfreC7rn+h3JJEaL7+wmDrH6fTNPFzAB0t2ckm/Ntw4NImZ6/fRO7GRZ5k8axE45wqBnwHTgdXAO865lWb2oJmN8up1pXr97qJTGNyxKfe8u5RB//c5//hsnU4vFSnDr99bStffT+OpL9eXu98zX20gr7CY6we3p3diI24/qzOREeZZLk/7CJxzHzvnujrnOjnnHg6u+6Nzbkop+w5XayD0REdG8My1pzFmYFv2HMzjX1+s1+mlIqXYnZnLOympAPzt03XM3rivzH3fXZhKi4Yx9GvrXSugJF1ZLCetcYM6PHJpbzY8fAGntGrI7yev4IrnZvPZqj1+RxOpMR6dthqA138yiA7NGnDzf1KYuf77J7/sz84nIzufn5zRETPvWgElqRBIlYmKjODNmwcxtHNTFmzZz83/TeGRj1frUJGEpeJid9T4XJOX7ARgcMemvH3L6dSrE8n1L83nrflHt6DX7skCoHPz2GrLqsnrpUo1ql+Hl8cOIGXLfibM2MTzMzaRW1DEr0Z0JzZGf24SPn753lImLQqcKHl58ESKC3u1JCLCaN6wLm/fOphRT87ivknLyS0o4seDk/hw2U7uemsJURFGf4+uGSiNhdq3teTkZJeSoq6EUFBYVMxdby3ho+W7aFAnkr9c3puRvVv7HUvkpBUVBy70Wp6aSd92jWhQJ/Kowzg7DxxmyKNffu9xD13ck+tOb39kOa+wiNteW8SXa/Yetd/NZ3bgdxd9b0Sek2JmC51zpY7npq9o4pmoyAieuqYfo1a15p53lvKzNxYTGxPF8G7N/Y4mckL2Z+fz2/eXM23F7qPWt2xYl8v7J7J2TxaxMVG8vzjQEvj6l8NpVD+a6St3M3P9Pkb2bnXU42KiInn++v488vEa3l+cSny9aB65tLdnVxCXRS0CqRZ7DuZy0ROziI2J5LWfDCKxcX2
/I4lU2v1TVvLK7C00rBtFTn4RZoHDoaWNu9W9ZRyf3P0DH1KWTi0C8V2LhnV5+JKe/PztJVz1/FzeuHkQ7ZrUr7azIkRO1kuzNvPK7C1c0q8N/7iq71HbDuTks3DrfvYdyuMHXRPYezCPDgkNfEpaeWoRSLVanprJNS/MJSuvkOT2jXn5xgE0rBvtdyyRcm3el81Zf/uayAhj+f3nU79O6H2HLq9FoNNHpVr1SoznwzvO4K5zurBk+wEufvpb1uw+6HcsCSOH8gp5dNoarnx+DvdPWcmkRalk5RaUuX9uQRFj/z0fgBd+3D8ki8Dx1L7fSGq8pGYN+Pl5XenVJp7bXl/ImAlzefra0xjSqdnxHyxSSfsO5fHizM1EGHRrGcdfP1nLjgOB8/vnb844st/I3q04rV1jTu/YlK4tYomKjCC3oIiHP1rN1vQcHr+yD2d3b+HXr+EpHRoSX63fk8Xlz80h83ABHZo14DcXdOe8Hi3UdyBVIq+wiKsnzGXxtgNHrf/DyB6MO6MDxcWOBVsyeHDqKjbvyyYn/+gh1ZvHxbA3K4+LerXiqWv6hfTfZXmHhlQIxHfph/K4660lzNoQGHtlSKemPHPtaTSqX8fnZBLqfjNpOW/O38YTY/qR2LgeqfsP07RBHYZ2/n7rs7ComDmb0tmw9xDfrEvj67VpdGsRx23DO3FxvzalPHtoUSGQkFBQVMzDH63mldlbiDCYfPtQT4feldot83ABfR74lKuS2/KXy3v7Hcd3On1UQkJ0ZAR/+lEPGtWP5p+fr2fUU9/yi/O6MqpPa5Kahc6peKHs5VmbeXDqKgA6NmtAx4RYbh3WkQFJTY7ssyntEC/O2szFfdswsEOTsp7KN8458gqLuezZ2QAM76bJrI5HLQKpkbZn5DDuPwtYt+cQEQYPjj760nypegu2ZHDFc3MAOLV1Q+pGR7Jw637M4Ic9WpJbWERBUTHfbkg/8pjk9o358ZAkhnVJYMGWDB7+eDUdmjVgdN/W9G3biGaxMbw5fxsb07I5pVUcPVo1pF+7xp6Nrf/F6j088OEqtmXkAHDviO6MH1Z9o3jWZDo0JCEpr7CIORvTmTBjE7M3pnN+jxb836W9aBYb43e0Wqeo2HHREzM5lFfI1DvOONI/s+9QHne+uZjZG//34d+1RSy3De/Eqp0HeWHmZgAiDIor+FEyvFsCT47pR1wVXz/y+ryt/O79FQCYBYrZxNuGEBMVWaWvE6pUCCSk5RYU8dKszTzxxXq6t2rIC9f3p3nDun7HqlXueWcpExel8uSYfvyoz/cHBkzdnxP4Vt8y7qj3fuXOTDamZbN+Txb7DuVxz/ndAPh2wz72HMxl36F8GtWP5uYzOzJ12U5+/d4yCoocbRrV46lr+tGvXdWMqXM4v4iB//c57ZvW57Vxg3SiQSlUCKRWmL5yN3e9tZjExvWZOH4I8fV1RfLJ2nHgMH+bvpb3F+/g2kHteOjinp4eRnHO8f7iHTw4dRUHcgq48+zO/CJYPE7UwdwC/jB5BR8s2clbt5zO6R2bVlHa2kVXFkut8MNTW/LvsQPZlp7DiH/NYF1wAg85MbkFRVz1/BzeX7yDyAhj/LBOnh9LNzMuPS2RWfeezYW9WvLElxv489RVJzR50aG8Qh6Ztpqb/5PCB0t2Mu6MDgyqgZ3XoUAtAgk501fuZvxrC4kw486zu3D7WZ2IitR3msooLnbc+tpCPlu1h/t/1INrT29PdDW/h4VFxTzw4SpenbuVtk3q8eilvRmQ1AQzys2SX1jMb99fznsLU4+s+9UPu3H7WZ2rI3bI0qEhqXW2Z+Tw4NRVfLZqD+2b1uee87sxqpRj2/J9uQVFXPCvmWzel82twzrymwtO8S2Lc45nv9nIXz9Ze2RddKTRtkl97j63K8O6JLA3K5f4etE0rBfN/VNW8taC7QA0bVCHMQPbMahjE4Z0aubZmUi1hQqB1ErOOd5J2c7jn61jz8E8kts35rEr+tBB1xy
Uaen2A1z30jyycgu5dVhH7hvRvUacWnk4v4jX5m5l7qZ0snILWbx9PwVFZX823X1uF+4+t2s1Jgx9KgRSq+07lMdDU1cdmRx8SKem3Di0A8O6JlAnSoeMvrMs9QDXvhgoAn+/og+XBefRrYmKix3vLUrljXnbOK1dYw4XFLE78zCDOjbl5jM76tv/CVAhkLCQsiWDb9al8ca8baRn59O4fjTPX59cI69+rW5zNqYz5oW5xERF8MHPhtK9ZUO/I0k1UyGQsHIwt4BZ6/fxt0/XsvPAYX574SlcmdyWutHheWHR+j1ZjHlhLnWjI3nm2tM0flOYUiGQsJSWlccNL89n1a7AxDfXnd6Om8/sSPum4dWHcMVzs1m/9xDvjR9C5+axfscRn2jQOQlLCXExTL3jDF6atZkPlu7gnQWBY843DEniqgFta+3hkb1Zufx9+jpioiPYnpHDgi37GT+sk4qAlEktAgkbuzIPc+2L89iUlg3AeT1acPOZHWtNH8LW9Gyen7GJ91JSyS8qBgJj7lzStw1/vby3rrUIc761CMxsBPAvIBJ40Tn36DHbfwH8BCgE0oCbnHNbvcwk4atVfD0+vfsH7MnK4+3523j66418sXoPNw3twK9GdAvpwckKi4oZ++8FbN6XTZ/EeO69oDsDkpoQYaYzbOS4PGsRmFkksA44D0gFFgBjnHOrSuxzFjDPOZdjZrcBw51zV5X3vGoRSFXJzivkDx+sYNKiHbRsWJdbh3VkzMB2Idep/NmqPdz838D/iYcu1nDdUjq/WgQDgQ3OuU3BEG8Bo4EjhcA591WJ/ecC13mYR+QoDWKiePzKvpzfowUvzdrMAx+u4pFpaygsKiaubjTdW8bxh5E9OLV1wxpx0VVpsnILuP2NRQDcODSJqwe09TmRhCIvC0EbYHuJ5VRgUDn7jwOmlbbBzG4BbgFo165dVeUTAWBEz1aM6NmKr9bu5dOVu8nOK2JrRg7zNmcw8slZ1IuO5IYhSdw4NIkWNWz460mLdpBfWMzE24bQv33VDOks4adGnDVkZtcBycCw0rY75yYAEyBwaKgao0kYOatbc87q1vzI8q7Mw7ybksrM9Wk8981GnvtmI4M6NOHifmkqdlAAAArqSURBVG0Y3bc19ev4998nr7CIqUt38ZdP1tAnMV5FQE6Kl3/JO4CS7dTE4LqjmNm5wO+AYc65PA/ziFRKq/h63HlOF+48pwuLtu1n4sJUZq7fx28mLedPH6wkMsK4bXgnxg5NomEVz7ZVnu0ZOdz51mIWbzsABKZjFDkZXnYWRxHoLD6HQAFYAFzjnFtZYp9+wHvACOfc+oo8rzqLxU/OORZu3c/L327m2w3pZB4uIK5uFKe1a8yApMbE169Do3rRREcaMVGRHMwtoEOzBrRpVI+mJzHF5rxN6Tz99UZW7sgkPTsfgD+O7MENQ5J0VpBUiC+dxc65QjP7GTCdwOmjLzvnVprZg0CKc24K8BgQC7wb7Izb5pwb5VUmkZNlZiQnNSE5KXDtweJt+3lp1maW78jkm3Vp5T62WWwMXVvEckaXZpzTvQUt4+sSX+/olsTUZTt5c/42kts3ITLCSNm6n23p2WxJD0zG3qdtIy7rn8iVyYl0bh7nzS8pYUcXlIlUAeccuQXF7DiQQ/qhfPIKi8nOK6Rtk/rM35zBuj1Z5BcWs2BrBtszDh95XExUBPXqROIcFDtHVm7hUc/bLDaGDs3q069dY24f3lnTc8oJ0xATIh4zM+rViaRz8zg6Nz96W8828UctL9iSQer+HFbuOMhXa/fSvVVDYqIiyCsoplNCA352dhf25+QTXy865K5pkNCkFoGISBjQ5PUiIlImFQIRkTCnQiAiEuZUCEREwpwKgYhImFMhEBEJcyoEIiJhToVARCTMhdwFZWaWBnw3nWU8kFnO/e9+NgP2neBLlnzeymw/dn15y15kP17u8vZRdmWvzD4VWafslctVkX0qm72Lcy6e0jjnQvYGTCj
vfomfKVXxGpXZfuz68pa9yH683Mqu7FWVvSLrlL1mZT/2FuqHhj48zv2S66riNSqz/dj15S17kb0ij1f2799X9srvU5F1yl42P7IfJeQODZ0IM0txZYyxUdMpuz+U3R/K7o9QbxFU1AS/A5wEZfeHsvtD2X0QFi0CEREpW7i0CEREpAwqBCIiYU6FQEQkzIV9ITCzM83sOTN70cxm+52nMswswsweNrMnzewGv/NUhpkNN7OZwfd+uN95KsvMGphZipmN9DtLZZjZKcH3/D0zu83vPJVhZheb2Qtm9raZne93nsows45m9pKZved3ltKEdCEws5fNbK+ZrThm/QgzW2tmG8zsvvKewzk30zk3HpgK/MfLvCVVRXZgNJAIFACpXmU9VhVld8AhoC6hlx3gXuAdb1KWror+3lcH/96vBIZ6mbekKso+2Tl3MzAeuMrLvCVVUfZNzrlx3iY9CSdyJVxNuQE/AE4DVpRYFwlsBDoCdYClQA+gF4EP+5K35iUe9w4QF0rZgfuAW4OPfS/EskcEH9cCeD3Esp8HXA2MBUaGUvbgY0YB04BrQi178HF/B04L0ezV9v+0MrcoQphzboaZJR2zeiCwwTm3CcDM3gJGO+ceAUptxptZOyDTOZflYdyjVEV2M0sF8oOLRd6lPVpVve9B+4EYL3KWpore9+FAAwL/8Q+b2cfOuWIvc0PVve/OuSnAFDP7CHjDu8RHvWZVvO8GPApMc84t8jbx/1Tx33uNFNKFoAxtgO0lllOBQcd5zDjg354lqrjKZp8EPGlmZwIzvAxWAZXKbmaXAj8EGgFPeRvtuCqV3Tn3OwAzGwvsq44iUI7Kvu/DgUsJFN+PPU12fJX9e78DOBeIN7POzrnnvAx3HJV935sCDwP9zOw3wYJRY9TGQlBpzrk/+Z3hRDjncggUsZDjnJtEoJCFLOfcK35nqCzn3NfA1z7HOCHOuSeAJ/zOcSKcc+kE+jZqpJDuLC7DDqBtieXE4LpQoOz+UHZ/KHsNURsLwQKgi5l1MLM6BDr1pvicqaKU3R/K7g9lryn87q0+yd78N4Fd/O/0yXHB9RcC6wj06v/O75zKXnNuyq7s4ZS9ojcNOiciEuZq46EhERGpBBUCEZEwp0IgIhLmVAhERMKcCoGISJhTIRARCXMqBOI5MztUDa8xqoLDR1flaw43syEn8Lh+ZvZS8P5YM/N7rCUAzCzp2KGWS9knwcw+qa5MUj1UCCRkmFlkWducc1Occ4968Jrljcc1HKh0IQB+S+iOmZMG7DKzapvLQLynQiDVysx+ZWYLzGyZmT1QYv1kM1toZivN7JYS6w+Z2d/NbCkw2My2mNkDZrbIzJabWffgfke+WZvZK2b2hJnNNrNNZnZ5cH2EmT1jZmvM7DMz+/i7bcdk/NrM/mlmKcBdZvYjM5tnZovN7HMzaxEclng88HMzW2KBme4SzGxi8PdbUNqHpZnFAb2dc0tL2ZZkZl8G35svgsOjY2adzGxu8Pd9qLQWlgVmTPvIzJaa2Qozuyq4fkDwfVhqZvPNLC74OjOD7+Gi0lo1ZhZpZo+V+Le6tcTmycC1pf4DS2jy+9Jm3Wr/DTgU/Hk+MAEwAl9CpgI/CG5rEvxZD1gBNA0uO+DKEs+1BbgjeP+nwIvB+2OBp4L3XwHeDb5GDwLjxgNcTmDo5QigJYG5EC4vJe/XwDMllhvDkavwfwL8PXj/fuCXJfZ7AzgjeL8dsLqU5z4LmFhiuWTuD4EbgvdvAiYH708FxgTvj//u/TzmeS8DXiixHE9gwpRNwIDguoYERhyuD9QNrusCpATvJxGcfAW4Bfh98H4MkAJ0CC63AZb7/XelW9XdNAy1VKfzg7fFweVYAh9EM4A7zeyS4Pq2wfXpBCbcmXjM83w3fPVCAmPrl2ayC8wTsMrMWgTXnQG8G1y/28y+Kifr2yXuJwJvm1krAh+um8t4zLlAj8D8KQA0NLNY51zJb/CtgLQyHj+4xO/zKvD
XEusvDt5/A/hbKY9dDvzdzP4CTHXOzTSzXsAu59wCAOfcQQi0HoCnzKwvgfe3aynPdz7Qu0SLKZ7Av8lmYC/QuozfQUKQCoFUJwMecc49f9TKwGQp5wKDnXM5ZvY1gbmMAXKdc8fOvpYX/FlE2X/DeSXuWxn7lCe7xP0ngcedc1OCWe8v4zERwOnOudxynvcw//vdqoxzbp2ZnUZgILSHzOwL4P0ydv85sAfoQyBzaXmNQMtreinb6hL4PaSWUB+BVKfpwE1mFgtgZm3MrDmBb5v7g0WgO3C6R6//LXBZsK+gBYHO3oqI539jzd9QYn0WEFdi+VMCs2gBEPzGfazVQOcyXmc2geGMIXAMfmbw/lwCh34osf0oZtYayHHOvQY8RmCO3bVAKzMbENwnLtj5HU+gpVAMXE9g/t1jTQduM7Po4GO7BlsSEGhBlHt2kYQWFQKpNs65Twkc2phjZsuB9wh8kH4CRJnZagJz0s71KMJEAsMIrwJeAxYBmRV43P3Au2a2ENhXYv2HwCXfdRYDdwLJwc7VVZQyI5Vzbg2BqRbjjt1GoIjcaGbLCHxA3xVcfzfwi+D6zmVk7gXMN7MlwJ+Ah5xz+cBVBKYzXQp8RuDb/DPADcF13Tm69fOdFwm8T4uCp5Q+z/9aX2cBH5XyGAlRGoZawsp3x+wtMIfsfGCoc253NWf4OZDlnHuxgvvXBw4755yZXU2g43i0pyHLzzODwETt+/3KIFVLfQQSbqaaWSMCnb5/ru4iEPQscEUl9u9PoHPXgAMEzijyhZklEOgvURGoRdQiEBEJc+ojEBEJcyoEIiJhToVARCTMqRCIiIQ5FQIRkTCnQiAiEub+H18pgNpr+mr4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.0001...\n", "Epoch 1/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1980 - acc: 0.9299 - val_loss: 0.6752 - val_acc: 0.6485\n", "Epoch 2/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1426 - acc: 0.9492 - val_loss: 0.1381 - val_acc: 0.9508\n", "Epoch 3/20\n", "568/568 [==============================] - 206s 362ms/step - loss: 0.1248 - acc: 0.9547 - val_loss: 0.1074 - val_acc: 0.9616\n", "Epoch 4/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1159 - acc: 0.9577 - val_loss: 0.0996 - val_acc: 0.9637\n", "Epoch 5/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.1090 - acc: 0.9601 - val_loss: 0.0977 - val_acc: 0.9643\n", "Epoch 6/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1054 - acc: 0.9609 - val_loss: 0.0987 - val_acc: 0.9629\n", "Epoch 7/20\n", "568/568 [==============================] - 206s 362ms/step - loss: 0.1038 - acc: 0.9616 - val_loss: 0.1052 - val_acc: 0.9622\n", "Epoch 8/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1030 - acc: 0.9617 - val_loss: 0.0936 - val_acc: 0.9650\n", "Epoch 9/20\n", "568/568 [==============================] - 206s 363ms/step - loss: 0.1003 - acc: 0.9627 - val_loss: 0.0967 - val_acc: 0.9651\n", "Epoch 10/20\n", "568/568 [==============================] - 207s 365ms/step - loss: 0.0998 - acc: 0.9628 - val_loss: 0.0957 - val_acc: 0.9651\n", "Epoch 11/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0977 - acc: 0.9635 - val_loss: 0.0918 - val_acc: 0.9660\n", "Epoch 12/20\n", "568/568 
[==============================] - 207s 364ms/step - loss: 0.0961 - acc: 0.9638 - val_loss: 0.0953 - val_acc: 0.9655\n", "Epoch 13/20\n", "568/568 [==============================] - 207s 365ms/step - loss: 0.0938 - acc: 0.9646 - val_loss: 0.0961 - val_acc: 0.9649\n", "Epoch 14/20\n", "568/568 [==============================] - 207s 365ms/step - loss: 0.0920 - acc: 0.9654 - val_loss: 0.0908 - val_acc: 0.9660\n", "Epoch 15/20\n", "568/568 [==============================] - 207s 365ms/step - loss: 0.0904 - acc: 0.9659 - val_loss: 0.0916 - val_acc: 0.9664\n", "Epoch 16/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0885 - acc: 0.9664 - val_loss: 0.0904 - val_acc: 0.9662\n", "Epoch 17/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0867 - acc: 0.9672 - val_loss: 0.0901 - val_acc: 0.9668\n", "Epoch 18/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0856 - acc: 0.9674 - val_loss: 0.0902 - val_acc: 0.9666\n", "Epoch 19/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0833 - acc: 0.9683 - val_loss: 0.0912 - val_acc: 0.9666\n", "Epoch 20/20\n", "568/568 [==============================] - 207s 364ms/step - loss: 0.0813 - acc: 0.9688 - val_loss: 0.0901 - val_acc: 0.9670\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(1e-4, 20)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import fbeta_score\n", "import numpy as np\n", "import warnings\n", "\n", "def f2(preds, targs, start=0.17, end=0.24, step=0.01):\n", "    \"\"\"Return the best samples-averaged F-beta (beta=2) score over a threshold sweep.\n", "\n", "    preds is binarized as `preds > th` for every th in\n", "    np.arange(start, end, step) and scored against targs; the highest\n", "    score is returned. Warnings raised while scoring are suppressed.\n", "    \"\"\"\n", "    with warnings.catch_warnings():\n", "        warnings.simplefilter(\"ignore\")\n", "        # beta is passed by keyword: it is keyword-only in current scikit-learn.\n", "        return max(fbeta_score(targs, preds > th, beta=2, average='samples')\n", "                   for th in np.arange(start, end, step))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "y_pred = 
learner.model.predict_generator(val)\n", "y_true = val.labels" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9249264279306654" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f2(y_pred, y_true)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }