{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_006a import *\n", "from concurrent.futures import ProcessPoolExecutor" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Camvid" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/camvid')\n", "PATH_X = PATH/'701_StillsRaw_full'\n", "PATH_Y = PATH/'LabeledApproved_full'\n", "PATH_Y_PROCESSED = PATH/'LabelProcessed'\n", "label_csv = PATH/'label_colors.txt'\n", "\n", "PATH_Y_PROCESSED.mkdir(exist_ok=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('data/camvid/LabeledApproved_full/0016E5_08109_L.png')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(PATH_Y.iterdir())[0]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def get_y_fn(x_fn):\n",
 "    \"Path of the colour-coded label image for raw frame `x_fn` (`{stem}_L.png` inside `PATH_Y`).\"\n",
 "    # `stem` drops the suffix whatever its length; `name[:-4]` only worked for 3-letter extensions\n",
 "    return PATH_Y/f'{x_fn.stem}_L.png'\n",
 "\n",
 "def get_y_proc_fn(y_fn):\n",
 "    \"Path of the processed class-id mask for label file `y_fn` (`{stem}_P.png` inside `PATH_Y_PROCESSED`).\"\n",
 "    stem = y_fn.stem\n",
 "    if stem.endswith('_L'): stem = stem[:-2]  # swap the `_L` marker for `_P`\n",
 "    return PATH_Y_PROCESSED/f'{stem}_P.png'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([PosixPath('data/camvid/701_StillsRaw_full/Seq05VD_f02220.png'),\n", " PosixPath('data/camvid/701_StillsRaw_full/0016E5_08550.png'),\n", " PosixPath('data/camvid/701_StillsRaw_full/0001TP_007260.png')],\n", " [PosixPath('data/camvid/LabeledApproved_full/Seq05VD_f02220_L.png'),\n", " PosixPath('data/camvid/LabeledApproved_full/0016E5_08550_L.png'),\n", " PosixPath('data/camvid/LabeledApproved_full/0001TP_007260_L.png')],\n", " [PosixPath('data/camvid/LabelProcessed/Seq05VD_f02220_P.png'),\n", " PosixPath('data/camvid/LabelProcessed/0016E5_08550_P.png'),\n", " PosixPath('data/camvid/LabelProcessed/0001TP_007260_P.png')])" ] }, "execution_count": null, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "x_fns = get_image_files(PATH_X)\n", "y_fns = [get_y_fn(o) for o in x_fns]\n", "y_proc_fns = [get_y_proc_fn(o) for o in y_fns]\n", "x_fns[:3],y_fns[:3],y_proc_fns[:3]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def parse_code(l):\n",
 "    \"Parse one line of the label colour file into `((r,g,b), name)`.\"\n",
 "    # colour and name are tab separated; `if c` skips the empty fields left by repeated tabs\n",
 "    a,b = [c for c in l.strip().split(\"\\t\") if c]\n",
 "    return tuple(int(o) for o in a.split(' ')), b" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(((64, 128, 64), (192, 0, 128), (0, 128, 192), (0, 128, 64), (128, 0, 0)),\n", " ('Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building'),\n", " 32)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# `with` closes the file handle; blank lines are skipped so a trailing newline cannot break parsing\n",
 "with open(label_csv) as f:\n",
 "    label_codes,label_names = zip(*[parse_code(l) for l in f if l.strip()])\n",
 "label_t = tensor(label_codes)\n",
 "n_labels = len(label_codes)\n",
 "label_codes[:5],label_names[:5], n_labels" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "name2code = dict(zip(label_names, label_codes))\n",
 "name2id = {v:k for k,v in enumerate(label_names)}\n",
 "void_code = name2id['Void']\n",
 "\n",
 "# Lookup table indexed by (r,g,b). Each channel ranges over 0..255, so every axis needs 256 slots;\n",
 "# with 255 any pixel holding a channel value of 255 fell outside the table.\n",
 "# Colours that are not listed in the label file keep the Void id.\n",
 "code2id = ByteTensor(256,256,256).fill_(void_code)\n",
 "for i,code in enumerate(label_codes): code2id[code] = i" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def colors_to_codes(color_data):\n",
 "    \"Map a channel-first colour tensor to a tensor of class ids through the `code2id` lookup table.\"\n",
 "    n = len(color_data)  # number of channels; `code2id` has three axes, so 3 (r,g,b) is expected\n",
 "    # one flat long index tensor per channel, used together as an advanced index into the table\n",
 "    idxs = tuple(color_data.reshape(n,-1).long())\n",
 "    return code2id[idxs].view(color_data.shape[1:])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "i = 0\n", "x_img = open_image(x_fns[i])\n", "y_img_mask = open_mask(y_fns[i])\n", "y_img = Image(y_img_mask.data.int())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_code = 
colors_to_codes(y_img.data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def codes_to_colors(label_data):\n", " h,w = label_data.shape\n", " idxs = label_data.flatten().long()\n", " return Image(label_t.index_select(0, idxs).reshape(h,w,3).permute(2,0,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_img2 = codes_to_colors(y_code)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(None, None)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAACgCAYAAACWsVm2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADlVJREFUeJztnWuWqzoOhVW9el7tmZXuyNo9suofxMEY+S2DMftbJycJ+BXwRrJsqJ+/vz8CAPTzr7sbAMAqQEwAKAExAaAExASAEhATAEpATAAoATEBoATEBIASEBMASvz77gYQEf3z86OzDINZpRhtmPnwInKvQyrhc/gefk5texj+udM4j6VlFKT7/f39KSkKlmkw7J0sPp04jnyuqqExH9AGYhrIZolCEcU+H/PVpAdzMIWbtx6SiHJp9rSyBZPKSpW/MNJxncDFh2XqxHX80JpsX5lSIoiLLba9Nk1NuocQO2bblWl/3QDE1MlRTOHnfZuXQ9gW7ovtT7dBo6xbCNufEszEQEyq8PZ/9qRz8B7bn8tPh/rOY61UObk6JmNyIRFBTDfB3juf9hpz3pYuY3svjwYXJryLUkulVZeSe4gAxAWc55fcq6tUoZyysn2XdGruEL1U5+9vUVZYpos4R+2Y0h0/tj2WNiy3JM+k3BhE6AGWqQO2Nr3fmO09GZCg0zZrpTRBjq+1k8sA17OOZZrySsaUtxglaYRcybB6qoyy8kE964hpWjh41yiHhe25fdJ+oAncvEZyLp6Qw3vnaKoUe5SPKV49Rz6D0cAyDYcz3/M4Ebmx1C4kv6xcudxUdxFTutjXAzHdAlOJEDYR8SEg0Sak0jQdpASVmsdZSIhw826FKSWkkqheWR3++0BKhOHSLCQiByzTJXDhto16IaXScmb/TTx0LinFOmKa8MSwNf634nx5IXHwmYVtdXWCfuDmDcJ+JmyJzHH74RtXlspiHncTYizteYIXjGAdy1TC410LlrdyuJ8PaSGka1hLTLmI0lSwQr7wMyf2B6VMdzyezzpiylkdf9+NHam/E4f5mYicG8cki+m8vasZfqgbovyyjphyTHfSWSG/e7mft38/Cync11s/CHmPmETsje6OXG/ZjYEuf/gKy2UhLRjFy8VExGwojLFFEhIpRcVSZbRN1LLw2X9nUhPSdBZ+HtYKjXedaEvMeUvVK6Zjdqawk5dP2DLJItm/qw5pIKIsa4nJg41Jr+xmJmITbDIXdhq5nnLLxJF3LwUEcCnruHkXdpyeTuoCBbEiysdM3xIJY6E5WEdMF6I3buovZyjud8LCFbG
kmL7PXjDm+5IT2quadKz2EMjgRMpsSadyWzAmkRdCKmY5MUWFE8EYQ6YyTw/yAyObSiItMVmr2a73spSYqoX0Gexbw5uVKrBUOi7e4dthX/2tF3t6LffTmM1agTqWEtMV6Fy5Ofi8f6+btD2WpWVVbHUQBBAtHBr3qX/4SUGZblz2Kdt+wuymehzGh2/lc0wD8MQIQdUDyzQCa8kIVuK4jYPXJ41CJ8b9S/fwCss0knCcxt6Aw9LZ9do7+nG7o+1prnKaKkaLz02Q3xRBvQJYpg5KOqzr+L4AjmHxfBleaad6pSjc1FYpWHWyEkuLia1Njpd6xgXGbFbE1LRHFJP/nnPzjvt8ozi1gF7CUmIKhZOcsKU9NF5XCe8ukfW2VRXB3rt71TZje6qrtQqhbAhRhVePmazhNkGl8DtmOF7ytvuC8q1K6ZOJ3HxQd6ASQlJjKcsUo3YyV69iPr787Ydk+8LXnJvni2+kkNQvMt867ZhyJ2A5MTlXT3L5Qro6jKXwKV5D2fTIJ0vWVWAC9XmmwhUmT2Y5Mfn4ghoxcZtD++p+lZCIBlqmhVlyzMTW3ufaBYQrJb7bt51BaksxXP9H1O4asjeXCixrmYZbInPeVOUa1c63MCFYcAG5CHCKdcU02DI54djWDs6WOBhDFBW1sKBm8SZaWVdMXiBCy0oZ5u/6OmlMUTvOYOEZFNZa2syety8sNiOoaJ8Mo4pKZG/CLCzDMct4rfb3LDlmctwRdJCItSO0TNu2/vrE6koKZnOOuCXW1LWMK0qJjTVzuItdq8dg2XxX/mPMdCHhGEn6njohHN2TTnfOx+Ie3qMWhRVZeVsipD2Tayat1L8SiGkgLe5KrD+w/xISzRLl813hUjQEqSUk/3602na9Rkyz+OEh3JrvcCsHiyKcRWA5NMe1RB1BIdpv8iSCm9fM4Y/8xdIwH06UE6imUGVRmCBYYcMUXtrg80BBaQmgN3hxQKkYIlimJpyQSgR1zMfKLUm4eZMuxbEHkeuhba1KObh5sEz1GHt3C3ZCMW3GxQrpTDS9WxDbw5XPgNASTY97t5dhmvMuHRr3ueMBIS11lnpmOUv1HU41YiwXt98Xg20URq+gNITkaG3LKyxTzZim1tXb8pSXr8EuJPv5rj80euvTiXp+92ssU4oWAfkwfR6eQl3GYCuL5c9iWrPVyp8fcFpR0dkWUMcrLFMtPeJirUaU1mf5IypDm6StSrmzTiXMDMQkcGdAosVds5bIEFMqLlwrDt/diV1ciqYTCtKsAty8B+CLO+yc5mQLgwTffMd0VrChYVnW9F1Y3iQkIlimW2HKu4W5zuyeTGQMERsm9kTjl11imXyBhal7hGFNXf67gh+9ru3rLZPGkhuOfB6Dpdw0Pwedka35ZLFERGQKxlVMWqOv9/B6MTXhOqt0BY1cVbnwqhdq+3hFt4d3tp+IXsYfY+PSm2iaWmZz4WrmxWL05oeYFAktgo/U112HdPuYmIwJF60KGcNCCgc2TlQaCw6chkvTzbTKZBQYM9UiCEbLxXfPzNuXENlkejZWvZdypM7eamazZBIYM2Vg2idUQyx9wsom2GGY6OJ5Fmvco47Dxshwbe+usWBkyT1FncnS55+339TVXd+EsvK8c2T0im1meTER6fzNI6KPX95YlE30oq+7Z9rKLsLzt6R6nOvHxhJb8xWRhLNexhNcicCe4uodRFrRd5YRU8w9qb6C0xYilg5iiZA0B/nROtxvqr3UJwZuxtvXErAIj78krtXHTuuIaeWz5HH4nRq/uSMKuK0Ql/P77qGpcMI0onI91P2B7iPLiKkI4dLoXB43+y+tDKhhRVGnRFUyxrPWkjHbY8wstY+5ZuddYsp0dGvse2cqC3ww/0JR68769zn5LuFKwnqXmBJ8AwTG6EzEPI3KCaGYBa4VWTjWsmS/254mtPXFtPqotxXpuLQGNjxK3dyY6Hy3MRZU+qQUpXbneGt9MZWc3IkepHgZqeNywbKFbtE
JgYK0+PJYIT+Tpd/C/MuIia35zpEQ0elzmMY/mZa3Bye6d6ItEBF7nsGKQYYTk6wDCo+1O6c2vAAquebWGDKNZf38/f2pNKKHf35+mhrRfSX6RJl68zrR7WUZ8iMZTHR7p3wD4cXTbfMvrLH9fn6Xxn3/+/vPT0n9jxJTr3h8zgK4riwmOo5P/O9gOn7/+1skpundPE0BzQITHVZ+fm+jIDoux4bAHsW0q8bZC5GuzrYWjnbx+CKaYbl17a2yL2UKy/QW0VQhCYsoHRQYFTDoKXOCIMZVTGuZQIRcSHs2ZmzTIKawTADchu++hkGhSiAm8E5Oz0yz0Qni103a3olGeP0VE8EzES7aVQiwQEwT8FghPTyEr30jJwIQT2HG0LSxekKa8fdVAst0IyeLFE7cat9VOzMX/74RjxeAZZqRBa7SbwRiuomkVfLf3T7/Bepwxyyy2FWLKRa6/vz87/5GXIh4In23rmXVQExkq7uHpbhjGhzbEnevdNU4xkwz4J/QVstTIhphbuU1+ELyBBXentEDLNPFJK3SDKHmtwiuwkKVWiaMme5mtjGQC3dLYW9h7CHufxL+ndmdFw5YpgvJBh3uZJQQZvhtKYS1eaGVetSdtm8QE4fRuZmQ7vrNBTRyNzNKC0hnJdLWR962vrqYouMkIpXHa03JTOPAUiIRv9Lb1jFmGkyRHz6bpdIg7JRP+I3+HF9DexEaH0hUSG7QS4bYPVZK48pd4qaF+8OQcdgWzfmrJ1nfhnbCzRtITkyPcYN6XNJQqH45DwFu3s0UCWlGpCVLYZi8xg2SFu0+weVrAJZpAMVCmkFYV7fhgZPCyzw3b3lGdabSTntlZ5ZcvoWAm6dM1irFvmsilV1aX027an+DJKSFhAXLpEhRGPxqtyZXnxRMkDq4VI40HmrhSVG+BLBMVxAbK42iZrAfduLaW9FL3Tb/GKQmsR8MxKREsXvnc9V4qZeSeatcWyTRhtsffvMj3DwFqty7OzpLb50xNyznDubcN2nCuCTfpEBMVxAbbzydlDUudRdjS3geKCiIqZOkVXpa5CoViKhpe8st9058sx+jBBBTB0VCGu3ela7H89tSU65UtrbFiK2umGFSuwIEIBrpuitTs3PUXNF7BvhXLAdy7ZMs5AMsFizTCFJXVM3V4T35NeaGHFoXB2k1+4OAmBrofVZAFyUTqqnbJmIuVE4cqY5dOslbQmpB7eRBCbh5lWSFlFrKM8K9qy0zJrzPdiZzztNiIZxr1uqixY7ZxNYKYhrFiMCDhntHdBah1+GZLJ3QWCbVs47vIYKCmAphU/AoqJHRp9IxikYIu6QMqYP71lKymhoLY1vLugCIqYDmMdLdd5eOXnhbukC2JxKYWkw7maAgJi1Gz4nULj71Ph/GQakOH3tJpFxGbfxVEhO7fBBThiKrpDlZ2ouwgJQNj6tnVPqrylIEYkpQ5d5dcYVuxWuLGK1rLTMWrcv9do1jM+Fqc4gpQrGQUqFwoumuomK07ipGLkMiul1QEFMPE1wNL6N1zHY1N54TrIDQINZ5JrNKQ+iJdGocHyl8ftNKCVgmga6gw4TWSm2cNCuTRPkgpoBb192NQLNTacz1jAzS3BwEmuIhlACsACwTAEpATAAoATEBoATEBIASEBMASkBMACgBMQGgBMQEgBIQEwBKQEwAKAExAaAExASAEhATAEpATAAoATEBoATEBIASEBMASkBMACgBMQGgBMQEgBIQEwBKQEwAKAExAaDE/wEkW6pLebya2QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAACgCAYAAACWsVm2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADlVJREFUeJztnWuWqzoOhVW9el7tmZXuyNo9suofxMEY+S2DMftbJycJ+BXwRrJsqJ+/vz8CAPTzr7sbAMAqQEwAKAExAaAExASAEhATAEpATAAoATEBoATEBIASEBMASvz77gYQEf3z86OzDINZpRhtmPnwInKvQyrhc/gefk5texj+udM4j6VlFKT7/f39KSkKlmkw7J0sPp04jnyuqqExH9AGYhrIZolCEcU+H/PVpAdzMIWbtx6SiHJp9rSyBZPKSpW/MNJxncDFh2XqxHX80JpsX5lSIoiLLba9Nk1NuocQO2bblWl/3QDE1MlRTOHnfZuXQ9gW7ovtT7dBo6xbCNufEszEQEyq8PZ/9qRz8B7bn8tPh/rOY61UObk6JmNyIRFBTDfB3juf9hpz3pYuY3svjwYXJryLUkulVZeSe4gAxAWc55fcq6tUoZyysn2XdGruEL1U5+9vUVZYpos4R+2Y0h0/tj2WNiy3JM+k3BhE6AGWqQO2Nr3fmO09GZCg0zZrpTRBjq+1k8sA17OOZZrySsaUtxglaYRcybB6qoyy8kE964hpWjh41yiHhe25fdJ+oAncvEZyLp6Qw3vnaKoUe5SPKV49Rz6D0cAyDYcz3/M4Ebmx1C4kv6xcudxUdxFTutjXAzHdAlOJEDYR8SEg0Sak0jQdpASVmsdZSIhw826FKSWkkqheWR3++0BKhOHSLCQiByzTJXDhto16IaXScmb/TTx0LinFOmKa8MSwNf634nx5IXHwmYVtdXWCfuDmDcJ+JmyJzHH74RtXlspiHncTYizteYIXjGAdy1TC410LlrdyuJ8PaSGka1hLTLmI0lSwQr7wMyf2B6VMdzyezzpiylkdf9+NHam/E4f5mYicG8cki+m8vasZfqgbovyyjphyTHfSWSG/e7mft38/Cync11s/CHmPmETsje6OXG/ZjYEuf/gKy2UhLRjFy8VExGwojLFFEhIpRcVSZbRN1LLw2X9nUhPSdBZ+HtYKjXedaEvMeUvVK6Zjdqawk5dP2DLJItm/qw5pIKIsa4nJg41Jr+xmJmITbDIXdhq5nnLLxJF3LwUEcCnruHkXdpyeTuoCBbEiysdM3xIJY6E5WEdMF6I3buovZyjud8LCFbGkmL7PXjDm+5IT2quadKz2EMjgRMpsSadyWzAmkRdCKmY5MUWFE8EYQ6YyTw/yAyObSiItMVmr2a73spSYqoX0Gexbw5uVKrBUOi7e4dthX/2tF3t6LffTmM1agTqWEtMV6Fy5Ofi8f6+btD2WpWVVbHUQBBAtHBr3qX/4SUGZblz2Kdt+wuymehzGh2/lc0wD8MQIQdUDyzQCa8kIVuK4jYPXJ41CJ8b9S/fwCss0knCcxt6Aw9LZ9do7+nG7o+1prnKaKkaLz02Q3xRBvQJYpg5KOqzr+L4AjmHxfBleaad6pSjc1FYpWHWyEkuLia1Njpd6xgXGbFbE1LRHFJP/nnPzjvt8ozi1gF7CUmIKhZOcsKU9NF5XCe8ukfW2VRXB3rt71TZje6qrtQqhbAhRhVePmazhNkGl8DtmOF7ytvuC8q1K6ZOJ3HxQd6ASQlJjKcsUo3YyV69iPr787Ydk+8LXnJvni2+kkNQvMt867ZhyJ2A5MTlXT3L5Qro6jKXwKV5D2fTIJ0vWVWAC9XmmwhUmT2Y5Mfn4ghoxcZtD++p+lZCIBlqmhVlyzMTW3ufaBYQrJb7bt51Bak
sxXP9H1O4asjeXCixrmYZbInPeVOUa1c63MCFYcAG5CHCKdcU02DI54djWDs6WOBhDFBW1sKBm8SZaWVdMXiBCy0oZ5u/6OmlMUTvOYOEZFNZa2syety8sNiOoaJ8Mo4pKZG/CLCzDMct4rfb3LDlmctwRdJCItSO0TNu2/vrE6koKZnOOuCXW1LWMK0qJjTVzuItdq8dg2XxX/mPMdCHhGEn6njohHN2TTnfOx+Ie3qMWhRVZeVsipD2Tayat1L8SiGkgLe5KrD+w/xISzRLl813hUjQEqSUk/3602na9Rkyz+OEh3JrvcCsHiyKcRWA5NMe1RB1BIdpv8iSCm9fM4Y/8xdIwH06UE6imUGVRmCBYYcMUXtrg80BBaQmgN3hxQKkYIlimJpyQSgR1zMfKLUm4eZMuxbEHkeuhba1KObh5sEz1GHt3C3ZCMW3GxQrpTDS9WxDbw5XPgNASTY97t5dhmvMuHRr3ueMBIS11lnpmOUv1HU41YiwXt98Xg20URq+gNITkaG3LKyxTzZim1tXb8pSXr8EuJPv5rj80euvTiXp+92ssU4oWAfkwfR6eQl3GYCuL5c9iWrPVyp8fcFpR0dkWUMcrLFMtPeJirUaU1mf5IypDm6StSrmzTiXMDMQkcGdAosVds5bIEFMqLlwrDt/diV1ciqYTCtKsAty8B+CLO+yc5mQLgwTffMd0VrChYVnW9F1Y3iQkIlimW2HKu4W5zuyeTGQMERsm9kTjl11imXyBhal7hGFNXf67gh+9ru3rLZPGkhuOfB6Dpdw0Pwedka35ZLFERGQKxlVMWqOv9/B6MTXhOqt0BY1cVbnwqhdq+3hFt4d3tp+IXsYfY+PSm2iaWmZz4WrmxWL05oeYFAktgo/U112HdPuYmIwJF60KGcNCCgc2TlQaCw6chkvTzbTKZBQYM9UiCEbLxXfPzNuXENlkejZWvZdypM7eamazZBIYM2Vg2idUQyx9wsom2GGY6OJ5Fmvco47Dxshwbe+usWBkyT1FncnS55+339TVXd+EsvK8c2T0im1meTER6fzNI6KPX95YlE30oq+7Z9rKLsLzt6R6nOvHxhJb8xWRhLNexhNcicCe4uodRFrRd5YRU8w9qb6C0xYilg5iiZA0B/nROtxvqr3UJwZuxtvXErAIj78krtXHTuuIaeWz5HH4nRq/uSMKuK0Ql/P77qGpcMI0onI91P2B7iPLiKkI4dLoXB43+y+tDKhhRVGnRFUyxrPWkjHbY8wstY+5ZuddYsp0dGvse2cqC3ww/0JR68769zn5LuFKwnqXmBJ8AwTG6EzEPI3KCaGYBa4VWTjWsmS/254mtPXFtPqotxXpuLQGNjxK3dyY6Hy3MRZU+qQUpXbneGt9MZWc3IkepHgZqeNywbKFbtEJgYK0+PJYIT+Tpd/C/MuIia35zpEQ0elzmMY/mZa3Bye6d6ItEBF7nsGKQYYTk6wDCo+1O6c2vAAquebWGDKNZf38/f2pNKKHf35+mhrRfSX6RJl68zrR7WUZ8iMZTHR7p3wD4cXTbfMvrLH9fn6Xxn3/+/vPT0n9jxJTr3h8zgK4riwmOo5P/O9gOn7/+1skpundPE0BzQITHVZ+fm+jIDoux4bAHsW0q8bZC5GuzrYWjnbx+CKaYbl17a2yL2UKy/QW0VQhCYsoHRQYFTDoKXOCIMZVTGuZQIRcSHs2ZmzTIKawTADchu++hkGhSiAm8E5Oz0yz0Qni103a3olGeP0VE8EzES7aVQiwQEwT8FghPTyEr30jJwIQT2HG0LSxekKa8fdVAst0IyeLFE7cat9VOzMX/74RjxeAZZqRBa7SbwRiuomkVfLf3T7/Bepwxyyy2FWLKRa6/vz87/5GXIh4In23rmXVQExkq7uHpbhjGhzbEnevdNU4xkwz4J/QVstTIhphbuU1+ELyBBXentEDLNPFJK3SDKHmtwiuwkKVWiaMme5mtjGQC3dLYW9h7CHufxL+ndmdFw
5YpgvJBh3uZJQQZvhtKYS1eaGVetSdtm8QE4fRuZmQ7vrNBTRyNzNKC0hnJdLWR962vrqYouMkIpXHa03JTOPAUiIRv9Lb1jFmGkyRHz6bpdIg7JRP+I3+HF9DexEaH0hUSG7QS4bYPVZK48pd4qaF+8OQcdgWzfmrJ1nfhnbCzRtITkyPcYN6XNJQqH45DwFu3s0UCWlGpCVLYZi8xg2SFu0+weVrAJZpAMVCmkFYV7fhgZPCyzw3b3lGdabSTntlZ5ZcvoWAm6dM1irFvmsilV1aX027an+DJKSFhAXLpEhRGPxqtyZXnxRMkDq4VI40HmrhSVG+BLBMVxAbK42iZrAfduLaW9FL3Tb/GKQmsR8MxKREsXvnc9V4qZeSeatcWyTRhtsffvMj3DwFqty7OzpLb50xNyznDubcN2nCuCTfpEBMVxAbbzydlDUudRdjS3geKCiIqZOkVXpa5CoViKhpe8st9058sx+jBBBTB0VCGu3ela7H89tSU65UtrbFiK2umGFSuwIEIBrpuitTs3PUXNF7BvhXLAdy7ZMs5AMsFizTCFJXVM3V4T35NeaGHFoXB2k1+4OAmBrofVZAFyUTqqnbJmIuVE4cqY5dOslbQmpB7eRBCbh5lWSFlFrKM8K9qy0zJrzPdiZzztNiIZxr1uqixY7ZxNYKYhrFiMCDhntHdBah1+GZLJ3QWCbVs47vIYKCmAphU/AoqJHRp9IxikYIu6QMqYP71lKymhoLY1vLugCIqYDmMdLdd5eOXnhbukC2JxKYWkw7maAgJi1Gz4nULj71Ph/GQakOH3tJpFxGbfxVEhO7fBBThiKrpDlZ2ouwgJQNj6tnVPqrylIEYkpQ5d5dcYVuxWuLGK1rLTMWrcv9do1jM+Fqc4gpQrGQUqFwoumuomK07ipGLkMiul1QEFMPE1wNL6N1zHY1N54TrIDQINZ5JrNKQ+iJdGocHyl8ftNKCVgmga6gw4TWSm2cNCuTRPkgpoBb192NQLNTacz1jAzS3BwEmuIhlACsACwTAEpATAAoATEBoATEBIASEBMASkBMACgBMQGgBMQEgBIQEwBKQEwAKAExAaAExASAEhATAEpATAAoATEBoATEBIASEBMASkBMACgBMQGgBMQEgBIQEwBKQEwAKAExAaDE/wEkW6pLebya2QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_img.show(), y_img2.show()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def process_file(fns):\n",
 "    \"Convert one colour label image into a class-id image; `fns` is a `(label_path, output_path)` pair.\"\n",
 "    yfn, pfn = fns\n",
 "    if pfn.exists(): return pfn  # already converted on a previous run\n",
 "    y_data = open_mask(yfn).px.long()\n",
 "    PIL.Image.fromarray(colors_to_codes(y_data).numpy()).save(pfn)\n",
 "    return pfn\n",
 "\n",
 "def process_label_files(y_fns, y_proc_fns, max_workers=16):\n",
 "    \"Convert every label image in `y_fns` to its `y_proc_fns` counterpart using `max_workers` processes.\"\n",
 "    # `with` shuts the pool down and reaps the workers even when a conversion raises\n",
 "    with ProcessPoolExecutor(max_workers) as ex:\n",
 "        # consuming the iterator is what surfaces exceptions raised inside the workers\n",
 "        for _ in ex.map(process_file, zip(y_fns, y_proc_fns)): pass" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 396 ms, sys: 260 ms, total: 656 ms\n", "Wall time: 39.1 s\n" ] } ], "source": [ "%time process_label_files(y_fns, y_proc_fns)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def get_datasets(path, valid_pct=0.2):\n",
 "    \"Pair each image under `path` with its processed mask and split them into `(train_ds, valid_ds)`.\"\n",
 "    x_fns = get_image_files(path)\n",
 "    y_fns = [get_y_fn(o) for o in x_fns]\n",
 "    y_proc_fns = [get_y_proc_fn(o) for o in y_fns]\n",
 "    train, valid = random_split(valid_pct, x_fns, y_proc_fns)\n",
 "    return (MatchedImageDataset(*train), MatchedImageDataset(*valid))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def get_tfm_datasets(size):\n",
 "    \"Train/valid datasets for `PATH_X` wrapped with the augmentation transforms at resolution `size`.\"\n",
 "    datasets = get_datasets(PATH_X)\n",
 "    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)\n",
 "    # `tfm_y=True` is passed so the masks go through the transforms too (presumably the same geometric ones)\n",
 "    return transform_datasets(*datasets, tfms=tfms, tfm_y=True, size=size)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "default_norm,default_denorm = normalize_funcs(*imagenet_stats)\n", "bs = 8\n", "size = 512" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def get_data(size, bs):\n",
 "    \"Build the `DataBunch` for images of resolution `size` in batches of `bs`, normalised with `default_norm`.\"\n",
 "    return DataBunch.create(*get_tfm_datasets(size), bs=bs, tfms=default_norm)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = get_data(size, bs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3, 512, 512]), torch.Size([1, 512, 512]), torch.int64)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = data.train_ds[0]\n", "x.shape, y.shape, y.data.dtype" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Unet" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def accuracy_no_void(input, target):\n",
 "    \"Pixel accuracy of `input` (per-class scores along dim 1) against `target`, ignoring `void_code` pixels.\"\n",
 "    # Drop only the channel axis: a bare `squeeze()` also removed the batch axis when the batch\n",
 "    # held a single item, and the mask then no longer lined up with the predictions.\n",
 "    target = target.squeeze(1)\n",
 "    mask = target != void_code\n",
 "    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Sanity check on a hand-made batch so this cell also runs on a fresh kernel.\n",
 "# Class 0 wins everywhere; of the three non-void pixels two are labelled 0, so 2/3 is expected\n",
 "# (assumes `void_code` is neither 0 nor 1).\n",
 "demo_scores = tensor([[[[1.,1.],[1.,1.]], [[0.,0.],[0.,0.]]]])  # (bs=1, classes=2, 2, 2)\n",
 "demo_target = tensor([[[[0,0],[1,void_code]]]])                 # (bs=1, 1, 2, 2)\n",
 "accuracy_no_void(demo_scores, demo_target)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "metrics=[accuracy_no_void]\n", "lr = 1e-3" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "body = create_body(tvm.resnet34(True), 2)\n", "model = DynamicUnet(body, n_classes=len(label_codes)).cuda()\n", "learn = Learner(data, model, metrics=metrics, loss_fn=CrossEntropyFlat())\n", "learn.split([model[0][6], model[1]])\n", "learn.freeze()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c159584c71a441e583c8c3b2cb64c482", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value=''))), HTML(value='epoch train loss va…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEACAYAAABI5zaHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8lNW9x/HPb7LvCwkBEkJEkFUWiYK1WqVWqAtatbV6a9Vr67VVW6v2avfWbrZWbatVa7WLba11r7uXqqi4IAHZwcgWAoEQkpB9z7l/JKEYgUzITJ6Zyff9es2LWU7m+fGI3zk5z5lzzDmHiIhEFp/XBYiISOAp3EVEIpDCXUQkAincRUQikMJdRCQCKdxFRCKQwl1EJAIp3EVEIpDCXUQkAincRUQiULRXB87KynIFBQVeHV5EJCwtW7Zsj3Muu692noV7QUEBRUVFXh1eRCQsmVmJP+00LCMiEoEU7iIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhHIs6mQh2tXTTOrd9QQF+0jNtpHXLSPuOgoEmOjSE2IISU+mpgofWaJyNAWduFeVFLF1Q+9d8g2CTFRpMRHkxwXTVJcNImxUSTFdT1OT4whPSGG1IQY0hNjSUuIITU+mtTu51K7f87MBulvJCISeGEX7ieOz+bZaz5OS3sHLW2dtHR00tLWQWNrB3XN7dQ2tVHb3EZtUzv1re00trTT0NJBeW0zm1raqWlqo6apjUPtCx7tM9ITY8lIjCEjMZaMpBgyk2JJT4wlMzGWjKTYfR8QaQkxpCZEk5YQQ2Js2J1OEYlQYZdGaQkxpOWmDeg9Ojsddc3t7G1qpbapvfvDoOtDoaapjb2NbVQ3tlLd0PXnlj0NLN+2l+qGVto7D/6pEB/jIzMxlszkWDKT4shKiiUrJY6s5FiykuPITum6DU+JJyMxRr8diEjQhF24B4LPZ6QlxpCWGNOvn3POUdfSTnVDKzVNXb8d1HR/KOxtbKOqoYXKhlaqum+bdtdTUd9Ca3vnR94rJsrITo4jOzW+68+U/9xGpcWTl5FIbkYCyXFD8j+RiAyQkqMfzIzU+BhS4/3/UOj5QNhT18Ke+lZ21zWzu7aF3XUt7K5rpqKuhR17m1hRupfKhpaPDBdlJMaQl5FIXkYCozO7/8xIJH9YIqMzEomN1sVjEfkohXuQ7f+BMLaPddzaOzqpamilrKaZ7dWNlFY1df1Z3cT75XW8vGH3h34L8BnkZSRSkJXE2KwkCoZ13T8iK4nc9ASiNWtIZMhSuIeQ6Cgfw1PjGZ4az4zR6R95vbPTUVHfQmlVIyWVjWytbGDLnga2VjawbGsVDa0d/3kvnzFmWCITR6Yyufs2aWQqOalxGusXGQIU7mHE5zNyUuPJSY2nsCDzQ6851xX8W/c0srU78D/YXc/K0r08t2rnvnYZiTFMHpXKpBFdYT95VCrjhyerly8SYRTuEcLMGJ4Sz/CUeI474sPBX9vcxoaddazfWbvv9td3SmjpHuKJj/ExdVQa0/LSmT46jWPyM8jLSFAPXySMmTvUhG/AzOKB14E4uj4MHnPO/eAgbc8HHgWOdc4dcieOwsJCp806vNPe0cmWPQ2sLatl5fa9rNpew9qyGprbugI/Nz2B2UdkMntsJnPGDiM/M1FhLxICzGyZc66wr3b+9NxbgLnOuXoziwEWm9kLzrl3eh0wBfgasOSwKpZBFR3lY3xOCuNzUjhnZi7QFfjvl9exrKSaJZureK24gife2wFAXkYCJ47P4sTx2ZxwZFa/p5GK
yODqM9xdV9e+vvthTPftQN39HwO/BG4IWHUyqKKjfEwZlcaUUWl88fgCnHNs3F3P25srWfzBHp5duZN/vFuKz2BaXjqnTBjO3InDmTIqFZ9PvXqRUOLXmLuZRQHLgHHA75xzS3q9PhMY7Zx71swU7hHCzPb17r94fAFtHZ2sLN3L6x/s4bXiCn79cjF3/LuYrOQ4Tp6QzamThnPi+GyS9MUrEc/1Oeb+ocZm6cCTwDXOuTXdz/mAV4BLnXNbzWwRcMOBxtzN7ArgCoD8/PxZJSV+7fMqIWpPfQuvF1fwyobdvF5cQW1zO7HRPj4+LotPTc7hk5OGMzwl3usyRSKKv2Pu/Qr37jf+AdDgnPtV9+M0YBP/GboZAVQBCw51UVUXVCNLW0cnS7dWsXBdOQvXlbO9ugkzOHZMJqcfPYJPHz2SnFQFvchABSzczSwbaHPO7TWzBOD/gF845549SPtFHKTnvj+Fe+RyzrFhVx0vrd3F86t3Ulxevy/oz5o+kgUzcklL0AVZkcMRyHCfBvwFiKJr56ZHnHM3m9nNQJFz7ule7RehcJf9bNxdx3OrdvHc6jKKy+uJj/FxxtGjuGj2aI7Jz9AUS5F+CNqwTKAo3Iem1dtreOjdbTy9YgcNrR0clZPMxXPGcN6sPK2HL+IHhbuEtIaWdp5ZWcbfl2xj9Y4a0hJiuPC4fC752BhGpiV4XZ5IyFK4S1hwzrGspJoHFm/hpbW7MDPOnDaS6z81gfxhiV6XJxJyAvkNVZGgMTMKCzIpLMiktKqRP7+1lYeWbOOF1bu49IQCrjplnC6+ihwG9dwl5JTXNvOrl97nseXbSU+I4RufOooLj8snRitXivjdc9f/LRJyclLjufWz03nm6o8zcUQq3//XWk7/zRu8vanS69JEwobCXULW1Nw0HvrybO67eBZNbR1c+Id3uPbh99hd2+x1aSIhT+EuIc3MOG3KCBZ+4xN8be44nl+9i7m3vcYDi7fQ3vHRjcdFpIvCXcJCQmwU1502gf/7xkkUFmTw42fX8Zm732LDrlqvSxMJSQp3CSsFWUn86dJj+d1Fx7Czpomz7lzMHQuLP7RxuIgo3CUMmRlnTBvJwm98gjOnjeI3L3/AgrsWs2r7Xq9LEwkZCncJWxlJsdxxwQz+eGkhexvbOPfut/jrOyV4Nb1XJJQo3CXszZ2Yw0vXnsSJ47P43lNruOnx1bS0d3hdloinFO4SEdISY7j/kmO5+pRx/LOolAt+/w67ajRlUoYuhbtEjCifccO8Cdz7hWMoLq/jzDsXs3xbtddliXhC4S4RZ/7UkTx11QkkxUVx0R/e4ZUN5V6XJDLoFO4SkY7KSeHxr3yM8cNT+PKDy3hs2XavSxIZVAp3iVhZyXH844o5HD92GDc8upJ7X9ukmTQyZCjcJaIlx0Xzx0uP5azpo7jlhQ385Ln1dHYq4CXyaT13iXix0T5+c8EMspJjeWDxFqobW/nledOI1hLCEsEU7jIk+HzG98+cTGZiLLctLKa2qZ27LppJfEyU16WJBIW6LjJkmBnXfHI8Pz57Ci9vKOfSP71LXXOb12WJBIXCXYaci48v4NcXzKBoazUX/WEJlfUtXpckEnAKdxmSzp6Ry31fnEVxeR1feOBdGlravS5JJKAU7jJkzZ2Yw+8vnsX7u2q57pEVmkUjEUXhLkPayROG8+3TJ/HS2nJ+/e9ir8sRCRjNlpEh7/KPH0FxeR2/fWUj43NSOGv6KK9LEhmwPnvuZhZvZu+a2UozW2tmPzpAm+vMbJ2ZrTKzl81sTHDKFQk8M+PH50zl2IIMbnh0pTb9kIjgz7BMCzDXOTcdmAHMN7M5vdq8BxQ656YBjwG/DGyZIsEVFx3FPV+YRVZyHF9+sIjdtVouWMJbn+HuutR3P4zpvrlebV51zjV2P3wHyAtolSKDICs5jvsvKaSmqY3rH12pC6wS1vy6oGpmUWa2AtgNLHTOLTlE88uBFwJRnMhgmzQyle+eMZk3
PtjDn9/a6nU5IofNr3B3znU452bQ1SM/zsymHqidmX0BKARuPcjrV5hZkZkVVVRUHG7NIkH1X7PzOXXScG55cQMbdtV6XY7IYenXVEjn3F5gETC/92tmdirwHWCBc+6AX/lzzt3nnCt0zhVmZ2cfRrkiwWdm3HLeNFLjo7n24RU0t2k/Vgk//syWyTaz9O77CcCpwIZebWYCv6cr2HcHo1CRwZSVHMet509nw646fvXS+16XI9Jv/vTcRwKvmtkqYCldY+7PmtnNZragu82tQDLwqJmtMLOng1SvyKA5ZeJwLp4zhvsXb2HxB3u8LkekX8yrnWkKCwtdUVGRJ8cW8VdTawdn3vkG9S3tLLzuE6TGx3hdkgxxZrbMOVfYVzstPyByCAmxUdz+uRlU1LVw64sanpHwoXAX6cP00elc8rEC/rakhGUl1V6XI+IXhbuIH64/bQIjU+P59hOraW3v9LockT4p3EX8kBwXzc1nT+X98jr+8MZmr8sR6ZPCXcRPp07O4fSjR/Cblz9gy54Gr8sROSSFu0g//PCsKcRF+/jOk6vxaqaZiD8U7iL9MDw1nhvnT+StTZU8vnyH1+WIHJTCXaSfLjoun5n56fzixQ3ae1VClsJdpJ98PuN7Z06moq6F37+2yetyRA5I4S5yGI7Jz+Cs6aO4743NlO1t8rocCSNXP7ScJ5ZvD/pxFO4ih+l/502g06GFxcRvTa0dPLtq56B0CBTuIodpdGYil3/8CJ54b4f2XRW/lFZ3bVg3OjMx6MdSuIsMwFdPPpKs5Fh+8ux6TY2UPm2r7Ar3fIW7SGhLiY/huk9N4N2tVby0dpfX5UiI21alcBcJG58rzGNCTgo/e34DLe3atUkObltVI0mxUWQmxQb9WAp3kQGKjvLx7TMmsa2qkUeKgj8LQsJXaVUjozMTMbOgH0vhLhIAJ43PonBMBne/ulG9dzmobVWNgzIkAwp3kYAwM6499Sh21jSr9y4H5JxTuIuEoxPGDVPvXQ6qoq6FlvZO8ocp3EXCinrvcig9M2UGY447KNxFAkq9dzmYwZwGCQp3kYD6UO99aanX5UgI2VbViBnkpicMyvEU7iIBdsK4YRxbkMHvXt2k3rvss62qkRGp8cTHRA3K8RTuIgHW03vfVaveu/xHzxz3waJwFwmCjx3Z1Xu/e5F679JlMKdBgsJdJCjMjGvmjmdnTTOPL9N2fENdc1sH5bUtoRXuZhZvZu+a2UozW2tmPzpAmzgz+6eZbTSzJWZWEIxiRcLJieOzmD46nbsXbaSto9PrcsRD26sHd6YM+NdzbwHmOuemAzOA+WY2p1eby4Fq59w44A7gF4EtUyT8mBlfmzuO7dVNPPWeeu9D2WDPcQc/wt11qe9+GNN9671w9dnAX7rvPwZ80gZjZRyREDd34nCmjErl7kWb6OjUeu9D1WCu497DrzF3M4sysxXAbmChc25Jrya5QCmAc64dqAGGHeB9rjCzIjMrqqioGFjlImGga+x9HFv2NPDsqjKvyxGPbKtqIiEmiqzk4C/128OvcHfOdTjnZgB5wHFmNrVXkwP10j/STXHO3eecK3TOFWZnZ/e/WpEwdNrkERyVk8zvXt1Ip3rvQ1LPTJnBHNDo12wZ59xeYBEwv9dL24HRAGYWDaQBVQGoTyTs+XzGVaeMo7i8Xrs1DVGDPccd/Jstk21m6d33E4BTgQ29mj0NXNJ9/3zgFacNJUX2OXPaKMZmJXHnKxu11+oQM9hL/fbwp+c+EnjVzFYBS+kac3/WzG42swXdbR4AhpnZRuA64KbglCsSnqJ8xldPGce6nbU8s2qn1+XIINpT30pTWwf5mYOzpkyP6L4aOOdWATMP8Pz397vfDHw2sKWJRJazZ4zir++U8K3HVzEhJ4UJI1K8LkkGwb7VIAdpHfce+oaqyCCJifJx38WzSIqL5ksPLqW6odXrkmQQlO5b6jdpUI+rcBcZRDmp8dx78SzKa1q46qHl+ubqENDT
c8/LGNxhGYW7yCA7Jj+Dn517NG9tquSnz633uhwJssFe6rdHn2PuIhJ458/KY/3OWh5YvIWJI1L4/HH5XpckQeLFTBlQz13EM9/69EROHJ/Ft55czbefXK0x+AjlxRx3ULiLeCY6yse9X5jFZR87gn8uLWXubYt4+N1t+hZrBGlu62BXbbN67iJDTVJcNN8/azLPXvNxxg1P5qYnVnPuPW+xcXed16VJAOzY24RzkD9scC+mgsJdJCRMGpnKI/9zPLd/bjollQ1c/dB76sFHgH1z3NVzFxm6zIxzj8njhwumsGFXHc9oFcmwV+rBOu49FO4iIeasaaOYOCKF2xcWax58mNu6p5H4GB/ZyXGDfmyFu0iI8fmMb86bQEllI48Wbfe6HBmA5duqmToqbVCX+u2hcBcJQXMnDueY/HR+83IxzW0dXpcjh6GhpZ3VO2qYPTbTk+Mr3EVCkJnxzXkTKa9t4a9vl3hdjhyGZSXVdHQ6Zh/xkU3pBoXCXSREHX/kME4cn8XdizZS19zmdTnST0u2VBLlM2aNyfDk+Ap3kRD2zXkTqG5s4/43tnhdivTTks1VHJ2bRlKcN6u8KNxFQti0vHTmTxnB/W9spkrLE4SNptYOVm7f69l4OyjcRULe9acdRWNbB79/fZPXpYif3ttWTVuHY45H4+2gcBcJeeNzUlgwfRQPvlXCnvoWr8sRP7yzpQqfQWGBN+PtoHAXCQtf++R4Wto7+P1r6r2HgyWbK5kyKo2U+BjPalC4i4SBI7OTOWdGLg++XcLu2mavy5FDaG7r4L3Svcw+wrvxdlC4i4SNaz45nvZOxz3qvYe0laV7aW3vZPZY78bbQeEuEjaOyEri3Jm5/H3JNsrVew9Z726pwgyOK1DPXUT8dM3c8XR2Ou5+daPXpchBLNlSxcQRqaQlejfeDgp3kbCSPyyR82fl8Y93S9lZ0+R1OdJLW0cny0qqPR9vB4W7SNi56pRxOBx3vaLee6hZtb2GprYO5nj45aUeCneRMDM6M5ELjh3Nw0tLeX+XtuMLJUu2VAJwrMfj7eBHuJvZaDN71czWm9laM/v6AdqkmdkzZrayu81lwSlXRACu+9QEUuKj+d6/1uCctuMLFUs2VzF+eDLDPNicozd/eu7twPXOuUnAHOAqM5vcq81VwDrn3HTgZOA2M4sNaKUisk9mUizfnDeBd7dU8fRKbccXCto7OinaWuXpejL76zPcnXM7nXPLu+/XAeuB3N7NgBTr2m4kGaii60NBRILk88fmMy0vjZ8+t15LAoeAlzfspqG1g+PHZnldCtDPMXczKwBmAkt6vXQXMAkoA1YDX3fOfWTzRzO7wsyKzKyooqLisAoWkS5RPuPms6dSUd/Cb1/+wOtyhrSGlnZ+9PRaJuSkcNqUHK/LAfoR7maWDDwOXOucq+318jxgBTAKmAHcZWapvd/DOXefc67QOVeYnZ09gLJFBGDG6HQ+f+xo/vjmVorLdXHVK7cvLKasppmfnTuVmKjQmKfiVxVmFkNXsP/dOffEAZpcBjzhumwEtgATA1emiBzMN+dNJCU+mu/r4qon1uyo4U9vbuG/Zucza0xojLeDf7NlDHgAWO+cu/0gzbYBn+xunwNMADYHqkgRObjMpFhuOG0C72zWxdXB1tHp+NYTq8lMiuN/54dWf9afnvsJwMXAXDNb0X073cyuNLMru9v8GPiYma0GXgZudM7tCVLNItLLhcflM310Ojc/s45q7dg0aB58eyurd9Twg7Mmk5bg7XIDvfW5uZ9zbjFgfbQpA04LVFEi0j9RPuOWc4/mrDsX89Pn1/Orz073uqSIV7a3iV+99D4nT8jmzGkjvS7nI0Jj5F9EBmzSyFSu/MSRPLZsO4s/0C/OwXbLCxvocI4fnz2VrtHr0KJwF4kgV88dx9isJL715CqaWju8LieiLdlSyelTRzI6M9HrUg5I4S4SQeJjovj5uUdTWtXEHf8u9rqciNXc1kF5bQsFWUle
l3JQCneRCDN77DAuPC6f+9/YzOrtNV6XE5FKqxoBGDMsNHvtoHAXiUg3fXoiWclx3Pj4Kto7PvJlcRmgksqucM8P0SEZULiLRKS0hBhuPnsK63bW8uDbJV6XE3G2VSncRcQj86aM4KSjsrljYTG767TnaiBtq2okOS6azKTQXfxW4S4SocyMH541meb2Dm55YYPX5USUbVWNjM5MDMkpkD0U7iIRbGx2MlecNJYnlu9g6dYqr8uJGCWVDYwJ4SEZULiLRLyrThnHqLR4vvfUGl1cDYDOTkdpdVNIz5QBhbtIxEuMjeZ7Z05mw646/vaOLq4OVHldM63tnSH75aUeCneRIWD+1BGcOD6L2/6vmIq6Fq/LCWs90yDVcxcRz5kZP1wwheb2Dn7+wnqvywlr4TANEhTuIkPGkdnJfOlEXVwdqG2VjUT5jFHpCV6XckgKd5Eh5Jq548hNT9DF1QEoqWokNz0hZLbTO5jQrk5EAmr/i6t/fmur1+WEpW1VjSE/JAMKd5EhZ96UHE6ekM2v//0B5bX65mp/batsID/EL6aCwl1kyDEzfrRgCq0dnfzkOV1c7Y/a5jaqG9vUcxeR0DRmWBJf+cSRPLOyjDc3atcmf23rmQapcBeRUPWVk48kPzOR7/9rDa3turjqj33TIDUsIyKhKj4mih8tmMKmigbuenWj1+WEhXCZ4w4Kd5Eh7ZSJw/nMzFzufnUja3Zo16a+lFQ2kpkUS0p8jNel9EnhLjLE/fCsKWQmxXL9Iytpadem2oeyraohLHrtoHAXGfLSEmO45byjeb+8jt++/IHX5YS0cJnjDgp3EQHmTszhs7PyuGfRJlaW7vW6nJDU1tFJ2d7mkF8wrEef4W5mo83sVTNbb2ZrzezrB2l3spmt6G7zWuBLFZFg+t5Zk8lJjef6R1fS3Kbhmd7K9jbR0elCfqnfHv703NuB651zk4A5wFVmNnn/BmaWDtwNLHDOTQE+G/BKRSSoUuNj+MV509i4u547FhZ7XU7IKQmjOe7gR7g753Y655Z3368D1gO5vZpdBDzhnNvW3W53oAsVkeA76ahsLjwunz+8sZnV2zV7Zn8lVT3ruCd5XIl/+jXmbmYFwExgSa+XjgIyzGyRmS0zsy8GpjwRGWw3fXoiWclx3PTEKq0cuZ/SqkZio30MT4nzuhS/+B3uZpYMPA5c65yr7fVyNDALOAOYB3zPzI46wHtcYWZFZlZUUVExgLJFJFjSEmL40YIprC2r5U9vbvW6nJBRUtk1DdLnM69L8Ytf4W5mMXQF+9+dc08coMl24EXnXINzbg/wOjC9dyPn3H3OuULnXGF2dvZA6haRIJo/dQSnTsrh9oXFlHYPRwx1JZXhMw0S/JstY8ADwHrn3O0HafYv4EQzizazRGA2XWPzIhKGzIybz56Cz+C7T63BOed1SZ5yzlEaRnPcwb+e+wnAxcDc7qmOK8zsdDO70syuBHDOrQdeBFYB7wL3O+fWBK1qEQm6UekJfHPeBF4rruDplWVel+OpyoZWGlo7wmaOO3SNlR+Sc24x0Ocgk3PuVuDWQBQlIqHh4uMLeHJFGTc/s45PHJVNemKs1yV5IpwWDOuhb6iKyEFF+Yxbzj2amqY2fvTMOq/L8cy+ddzDqOeucBeRQ5o0MpWrThnHk+/t4PnVO70uxxPvl9cR7bOw+XYqKNxFxA9Xzx3HtLw0vvPkanYPwX1X1+yo4aicFOKio7wuxW8KdxHpU0yUj9s/N4PG1g5ufHzVkJo945xjbVktU3NTvS6lXxTuIuKXccOTuenTE3n1/Qr+8W6p1+UMmp01zVQ1tHJ0bprXpfSLwl1E/HbJ8QWcMG4YP3luHSWVDV6XMyh6dqiaonAXkUjl8xm3nj+dKJ9x3SMrh8TaM2vKavEZTBqhYRkRiWCj0hP4yTlTWVZSzZ2vRP7G2mt31DBueDIJseFzMRUU7iJyGM6ekcu5x+Ry5ysf8PamSq/LCao1ZTVMHRVeQzKg
cBeRw/Tjs6dSMCyJa//5HpX1LV6XExS765opr20Ju/F2ULiLyGFKiovmzotmUt3Qxg2ProzI6ZFrd3Stbj51VHiNt4PCXUQGYMqoNL5zxiRefb+CBxZv8bqcgOuZKTNZ4S4iQ80Xjx/DaZNz+MWLG1hZutfrcgJqTVkNR2QlkRIf43Up/aZwF5EBMTN+ef40spPj+Orfl7O7LnKWJ1izo5apYTjeDgp3EQmA9MRYfn9xIVUNrVzx4DKa2zq8LmnAqhta2bG3KSzH20HhLiIBcnReGndcMIOV2/dy/aMr6ewM7wusa8u6L6aq5y4iQ938qSO4cf5Enlu1kzv+Xex1OQOypqx72YEw7bn3uROTiEh//M9JY9lcUc+dr2zkiKwkzj0mz+uSDsuaHTXkZSSE7e5TCncRCSgz4yfnHE1pVRM3Pb6a2GgfZ04b5XVZ/bZmR3h+M7WHhmVEJOBio33c84VjmDwqlasfeo9vPrqShpZ2r8vyW21zG1srG8NuDff9KdxFJCjSE2N59MrjufqUcTy2fDtn/PaNsJkHv677Ymo4LjvQQ+EuIkETE+XjhnkTePjLc2ht7+S8e97i3tc2hfxSBT3fTNWwjIjIIcweO4wXvn4S86aM4JYXNnD9oytpbQ/dteDXltWSkxpHdkqc16UcNoW7iAyKtMQY7rpoJtd96iieWL6DS/74LjWNbV6XdUBrdtSE3bZ6vSncRWTQmBlf++R47rhgOkUlVZx7z5uUVjV6XdaHPFJUysaKeqblpXtdyoAo3EVk0H1mZh5/vXw2e+pb+czdb/LXt7eydGuVpz35zk7Hz19Yz/8+toqPj8vishMKPKslEKyvCxtmNhp4EBgBdAL3Oed+c5C2xwLvABc45x471PsWFha6oqKiwypaRCLDxt31XPFgEZv3/Gez7RGp8UzNTeWGeROYOEj7lja2tvONf67gpbXl/NfsfH64YAoxUaHZ9zWzZc65wj7b+RHuI4GRzrnlZpYCLAPOcc6t69UuClgINAN/VLiLiD+cc5TVNFO8q44Nu+ooLq/jteIK6pvbuf60o/jSiWOJ8lnQjl9e28zlf1nKurJavnvGZC47oQCz4B1voPwN9z6/oeqc2wns7L5fZ2brgVxgXa+m1wCPA8f2v1wRGarMjNz0BHLTEzhl4nAAKutb+PaTq/n5Cxt4ef1ubvvcdEZnJgbl+N95cjWbKxq4/5JC5k7MCcoxvNCv3zvMrACYCSzp9Xwu8Bng3kAVJiJD17DkOO79wixu++x01u+sZf6vX+dv75TQEeCVJnfVNPPKht1cdkJBRAXw89t0AAAGGklEQVQ79CPczSyZrp75tc652l4v/xq40Tl3yEWczewKMysys6KKior+VysiQ4aZcd6sPF78xklMH53Od59aw2fufpP3tlUH7BiPL99Op4PPFY4O2HuGij7H3AHMLAZ4FnjJOXf7AV7fAvQMUmUBjcAVzrmnDvaeGnMXEX8553h6ZRk/e3495bUtnD8rjxvnTxzQl4w6Ox2n3LaIkWnxPHzF8QGsNrj8HXPvs+duXVcWHgDWHyjYAZxzRzjnCpxzBcBjwFcPFewiIv1hZpw9I5eXrz+ZKz9xJP9asYO5v1rE0yvLDvs9l2ypoqSykQuOjbxeO/g3LHMCcDEw18xWdN9ON7MrzezKINcnIrJPclw0N316Ii9eexITRqTwtX+8x69eev+wdn16pKiUlPhoPj11ZBAq9Z4/s2UW858hlz455y4dSEEiIn05MjuZh748h+//aw13vbqR4vI67rhgBklx/m1RUdPUxvOrd/LZwjziY6KCXK03QnOWvohIH2Kjffz83KP5wVmT+ff6cs675y2/lzJ4esUOWto7uaAwP8hVekfhLiJhy8y47IQj+Mt/H0fZ3ibO+d2b+5brPZR/FpUyaWRqWG/G0ReFu4iEvRPHZ/PkVScQF+3jwj+8w7KSg0+XXFtWw5odtVxQmBfS30QdKIW7iESEI7OTefQr
HyMrOY6LH1jCmxv3HLDdI0tLiY32cc7M3EGucHAp3EUkYuSmJ/DP/5lDfmYil/1pKQvXle97ra65jRWle3lqRRnzpowgPTHWw0qDz79LyyIiYWJ4SjwPXzGHS/60lCv/tozCMRlsrWygvLYFgCifcfGcMR5XGXwKdxGJOOmJsfz9S7O56fFVbK9u4sTx2RyZncyR2UlMGpkatEXIQonCXUQiUnJcNHdddIzXZXhGY+4iIhFI4S4iEoEU7iIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhFI4S4iEoH82kM1KAc2qwBKDvBSGtD3mp2Hbnew1w70vD/P7f84CzjwikQD5+/f/XB+rq82Omf9bzOQc9bXY52z8D9nwcqyMc657D7f1TkXUjfgvoG2O9hrB3ren+f2fwwUef13P5yf66uNztngnjM/Huuchfk5C2aW+XMLxWGZZwLQ7mCvHeh5f57zt6aBOtzj+PNzfbXROet/m4GcM6/O10COpXMWnJ85nCzrk2fDMuHKzIqcc4Ve1xFOdM76T+es/3TOPiwUe+6h7j6vCwhDOmf9p3PWfzpn+1HPXUQkAqnnLiISgRTuIiIRSOEuIhKBFO4BZmZJZrbMzM70upZwYGaTzOxeM3vMzL7idT3hwMzOMbM/mNm/zOw0r+sJdWY21sweMLPHvK5lMCncu5nZH81st5mt6fX8fDN738w2mtlNfrzVjcAjwakytATinDnn1jvnrgQ+B0T8NLYAnbOnnHNfBi4FLghiuZ4L0Pna7Jy7PLiVhh7NlulmZicB9cCDzrmp3c9FAcXAp4DtwFLgQiAK+Hmvt/hvYBpdX4GOB/Y4554dnOq9EYhz5pzbbWYLgJuAu5xzDw1W/V4I1Dnr/rnbgL8755YPUvmDLsDn6zHn3PmDVbvXtEF2N+fc62ZW0Ovp44CNzrnNAGb2MHC2c+7nwEeGXczsFCAJmAw0mdnzzrnOoBbuoUCcs+73eRp42syeAyI63AP078yAW4AXIjnYIXD/xoYihfuh5QKl+z3eDsw+WGPn3HcAzOxSunruERvsh9Cvc2ZmJwPnAnHA80GtLHT165wB1wCnAmlmNs45d28wiwtB/f03Ngz4KTDTzL7V/SEQ8RTuh2YHeK7PcSzn3J8DX0rY6Nc5c84tAhYFq5gw0d9z9lvgt8ErJ+T193xVAlcGr5zQpAuqh7YdGL3f4zygzKNawoXOWf/pnPWPzpcfFO6HthQYb2ZHmFks8HngaY9rCnU6Z/2nc9Y/Ol9+ULh3M7N/AG8DE8xsu5ld7pxrB64GXgLWA48459Z6WWco0TnrP52z/tH5OnyaCikiEoHUcxcRiUAKdxGRCKRwFxGJQAp3EZEIpHAXEYlACncRkQikcBcRiUAKdxGRCKRwFxGJQP8Pvn+X4EDkyyIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr_find(learn)\n", "learn.recorder.plot()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 1e-2" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=6), HTML(value=''))), HTML(value='epoch train loss va…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 02:55\n", "epoch train loss valid loss\n", "0 1.144465 1.669415 (00:28)\n", "1 0.827026 0.767053 (00:29)\n", "2 0.702947 0.656885 (00:29)\n", "3 0.644283 0.570501 (00:29)\n", "4 0.598689 0.517745 (00:28)\n", "5 0.575997 0.510366 (00:29)\n", "\n" ] } ], "source": [ "learn.fit_one_cycle(6, slice(lr), pct_start=0.05)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('u0')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('u0')\n", "x,y = next(iter(learn.data.valid_dl))\n", "py = learn.model(x).detach()\n", "py = py.softmax(dim=1).max(dim=1, keepdim=True)[1]\n", "x,y,py = x.cpu(),y.cpu(),py.cpu()\n", "x = default_denorm(x)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# One row per sample: input image | ground-truth mask | predicted mask\n",
 "n = 4\n",
 "fig, axs = plt.subplots(n,3,figsize=(10,10), sharey=True)\n",
 "for i in range(n):\n",
 "    Image(x[i]).show(ax=axs[i][0])\n",
 "    # `codes_to_colors` (defined earlier in this notebook) unpacks a 2-D id tensor, so drop the\n",
 "    # leading channel axis and keep the data as a tensor rather than converting to numpy\n",
 "    codes_to_colors(y[i].squeeze(0)).show(ax=axs[i][1])\n",
 "    codes_to_colors(py[i].squeeze(0)).show(ax=axs[i][2])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.unfreeze()\n", "lr=1e-2" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"learn.fit_one_cycle(6, slice(lr/100,lr), pct_start=0.05)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size=640\n", "bs = 4\n", "learn.data = get_data(size, bs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#learn.freeze()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(6, slice(lr), pct_start=0.05)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fin" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }