{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Create a Learner for inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai import *\n", "from fastai.gen_doc.nbdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we'll see how the same API allows you to create an empty [`DataBunch`](/basic_data.html#DataBunch) for a [`Learner`](/basic_train.html#Learner) at inference time (once you have trained your model) and how to call the `predict` method to get the predictions on a single item." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Note: As usual, this page is generated from a notebook that you can find in the docs_src folder of the\n", "[fastai repo](https://github.com/fastai/fastai). We use the saved models from [this tutorial](/tutorial.data.html) to\n", "have this notebook run fast.\n", "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_note(\"\"\"As usual, this page is generated from a notebook that you can find in the docs_srs folder of the\n", "[fastai repo](https://github.com/fastai/fastai). We use the saved models from [this tutorial](/tutorial.data.html) to\n", "have this notebook run fast.\n", "\"\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To quickly get acces to all the vision functions inside fastai, we use the usual import statements." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A classification problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's begin with our sample of the MNIST dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist = untar_data(URLs.MNIST_TINY)\n", "tfms = get_transforms(do_flip=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's set up with an imagenet structure so we use it to split our training and validation set, then labelling." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (ImageItemList.from_folder(mnist)\n", " .split_by_folder() \n", " .label_from_folder()\n", " .transform(tfms, size=32)\n", " .databunch()\n", " .normalize(imagenet_stats)) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that our data has been properly set up, we can train a model. Once the time comes to deploy it for inference, we'll need to save the information this [`DataBunch`](/basic_data.html#DataBunch) contains (classes for instance), to do this, we call `data.export()`. 
This will create an 'export.pkl' file that you'll need to copy along with your model file if you want to deploy on another device." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create the [`DataBunch`](/basic_data.html#DataBunch) for inference, you'll need to use the `load_empty` method. Note that for now, transforms and normalization aren't saved inside the export file. This is going to be integrated in a future version of the library. For now, we pass the transforms we applied on the validation set, along with all relevant kwargs, and we normalize with the same statistics as during training.\n", "\n", "Then, we use it to create a [`Learner`](/basic_train.html#Learner) and load the model we trained before." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = ImageDataBunch.load_empty(mnist, tfms=tfms[1], size=32).normalize(imagenet_stats)\n", "learn = create_cnn(empty_data, models.resnet18)\n", "learn.load('mini_train');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can now get the predictions on any image via `learn.predict`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Category 7, tensor(0), tensor([0.6870, 0.3130]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = data.train_ds[0][0]\n", "learn.predict(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It returns a tuple of three things: the object predicted (with the class in this instance), the underlying data (here the corresponding index) and the raw probabilities." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A multilabel problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's try the same steps on the planet dataset, which differs in that each image can have multiple tags (and not just one label)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "planet = untar_data(URLs.PLANET_TINY)\n", "planet_tfms = get_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here each image is labelled in a file named 'labels.csv'. We have to add 'train' as a prefix to the filenames and '.jpg' as a suffix, and the labels are separated by spaces." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (ImageItemList.from_csv(planet, 'labels.csv', folder='train', suffix='.jpg')\n", " .random_split_by_pct()\n", " .label_from_df(sep=' ')\n", " .transform(planet_tfms, size=128)\n", " .databunch()\n", " .normalize(imagenet_stats))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we call `data.export()` to export our data object properties." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then create the [`DataBunch`](/basic_data.html#DataBunch) for inference by using the `load_empty` method as before, passing the validation transforms and the size used for this dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = ImageDataBunch.load_empty(planet, tfms=planet_tfms[1], size=128).normalize(imagenet_stats)\n", "learn = create_cnn(empty_data, models.resnet18)\n", "learn.load('mini_train');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we get the predictions on any image via `learn.predict`." 
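, "\n", "\n",
"As a rough illustration of how the multilabel tuple can be read (a plain-Python sketch, not the library's implementation), the second element can be thought of as the probabilities compared against a threshold. The helper `hits` and the `>=` comparison are assumptions made for this example; the probabilities are copied from the output shown in this section.\n",
"\n",
"```python\n",
"# Probabilities copied from the multilabel output in this section.\n",
"probs = [0.1923, 0.6070, 0.1862, 0.3969, 0.0471, 0.1786, 0.0220,\n",
"         0.2634, 0.3177, 0.2928, 0.1829, 0.5932, 0.3870, 0.2521]\n",
"\n",
"def hits(probs, thresh=0.5):\n",
"    # Hypothetical helper: 1.0 where the probability clears thresh, else 0.0.\n",
"    # Whether the library compares with > or >= is an assumption here.\n",
"    return [1.0 if p >= thresh else 0.0 for p in probs]\n",
"\n",
"print(hits(probs))              # two positions are set at thresh=0.5\n",
"print(hits(probs, thresh=0.3))  # a lower threshold keeps more tags\n",
"```"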
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(MultiCategory selective_logging;cultivation,\n", " tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]),\n", " tensor([0.1923, 0.6070, 0.1862, 0.3969, 0.0471, 0.1786, 0.0220, 0.2634, 0.3177,\n", " 0.2928, 0.1829, 0.5932, 0.3870, 0.2521]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = data.train_ds[0][0]\n", "learn.predict(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we can specify a particular threshold to decide whether a prediction counts as a hit or not. The default is 0.5 but we can change it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(MultiCategory selective_logging;primary;partly_cloudy;cultivation;clear,\n", " tensor([0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0.]),\n", " tensor([0.1923, 0.6070, 0.1862, 0.3969, 0.0471, 0.1786, 0.0220, 0.2634, 0.3177,\n", " 0.2928, 0.1829, 0.5932, 0.3870, 0.2521]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.predict(img, thresh=0.3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A regression example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the next example, we are going to use the [BIWI head pose](https://data.vision.ee.ethz.ch/cvl/gfanelli/head_pose/head_forest.html#db) dataset. Given pictures of people, we have to find the center of their face. For the fastai docs, we have built a small subsample of the dataset (200 images) and prepared a dictionary that maps each filename to its center." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "biwi = untar_data(URLs.BIWI_SAMPLE)\n", "fn2ctr = pickle.load(open(biwi/'centers.pkl', 'rb'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To grab our data, we use this dictionary to label our items. We also use the [`PointsItemList`](/vision.data.html#PointsItemList) class to have the targets be of type [`ImagePoints`](/vision.image.html#ImagePoints) (which will make sure the data augmentation is properly applied to them). When calling [`transform`](/tabular.transform.html#tabular.transform) we make sure to set `tfm_y=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (ImageItemList.from_folder(biwi)\n", " .random_split_by_pct()\n", " .label_from_func(lambda o:fn2ctr[o.name], label_cls=PointsItemList)\n", " .transform(get_transforms(), tfm_y=True, size=(120,160))\n", " .databunch()\n", " .normalize(imagenet_stats))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As before, the road to inference is pretty straightforward: export the data, then load an empty [`DataBunch`](/basic_data.html#DataBunch)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = ImageDataBunch.load_empty(biwi, tfms=get_transforms()[1], tfm_y=True, size=(120,60)).normalize(imagenet_stats)\n", "learn = create_cnn(empty_data, models.resnet18)\n", "learn.load('mini_train');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now we can get a prediction on an image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(ImagePoints (120, 60),\n", " tensor([[ 0.7982, -0.5515]]),\n", " tensor([ 0.7982, -0.5515]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = data.train_ds[0][0]\n", "learn.predict(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visualize the predictions, we can use the [`Image.show`](/vision.image.html#Image.show) method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img.show(y=learn.predict(img)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A segmentation example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are going to look at the [camvid dataset](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) (at least a small sample of it), where we have to predict the class of each pixel in an image. Each image in the 'images' subfolder as an equivalent in 'labels' that is its segmentations mask." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "camvid = untar_data(URLs.CAMVID_TINY)\n", "path_lbl = camvid/'labels'\n", "path_img = camvid/'images'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We read the classes in 'codes.txt' and the function maps each image filename with its corresponding mask filename." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "codes = np.loadtxt(camvid/'codes.txt', dtype=str)\n", "get_y_fn = lambda x: path_lbl/f'{x.stem}_P{x.suffix}'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data block API allows us to uickly get everything in a [`DataBunch`](/basic_data.html#DataBunch) and then we can have a look with `show_batch`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (SegmentationItemList.from_folder(path_img)\n", " .random_split_by_pct()\n", " .label_from_func(get_y_fn, classes=codes)\n", " .transform(get_transforms(), tfm_y=True, size=128)\n", " .databunch(bs=16, path=camvid)\n", " .normalize(imagenet_stats))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As before, we export the data then create an empty [`DataBunch`](/basic_data.html#DataBunch) that we pass to a [`Learner`](/basic_train.html#Learner)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = ImageDataBunch.load_empty(camvid, tfms=get_transforms()[1], tfm_y=True, size=128).normalize(imagenet_stats)\n", "learn = Learner.create_unet(empty_data, models.resnet18)\n", "learn.load('mini_train');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now we can get a prediction on an image." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(ImageSegment (1, 128, 128), tensor([[[21, 21, 21, ..., 4, 4, 4],\n", " [21, 21, 21, ..., 4, 4, 4],\n", " [21, 21, 21, ..., 4, 4, 4],\n", " ...,\n", " [17, 17, 17, ..., 17, 17, 17],\n", " [17, 17, 17, ..., 17, 17, 17],\n", " [17, 17, 17, ..., 17, 17, 17]]]), tensor([[[2.3323e-02, 2.2946e-02, 4.6494e-03, ..., 3.7161e-04,\n", " 5.0104e-03, 6.4412e-03],\n", " [2.2634e-02, 2.0179e-02, 5.6307e-03, ..., 2.9151e-04,\n", " 5.1298e-03, 5.2150e-03],\n", " [6.2241e-03, 7.0151e-03, 3.7897e-04, ..., 4.0918e-06,\n", " 3.2178e-04, 4.4637e-04],\n", " ...,\n", " [9.0529e-04, 7.9842e-04, 7.6028e-05, ..., 8.6253e-04,\n", " 3.6710e-03, 3.6289e-03],\n", " [3.6277e-03, 4.0498e-03, 5.4831e-04, ..., 3.7190e-03,\n", " 8.8866e-03, 9.6405e-03],\n", " [3.8535e-03, 3.4831e-03, 6.7555e-04, ..., 2.9442e-03,\n", " 8.2562e-03, 7.8639e-03]],\n", " \n", " [[2.7580e-02, 2.4669e-02, 7.3161e-03, ..., 1.1249e-03,\n", " 9.4848e-03, 1.1947e-02],\n", " [2.5641e-02, 2.3801e-02, 8.7720e-03, ..., 7.5269e-04,\n", " 9.4482e-03, 8.9084e-03],\n", " [9.8628e-03, 9.7588e-03, 9.2763e-04, ..., 2.3099e-05,\n", " 8.9190e-04, 1.2924e-03],\n", " ...,\n", " [3.0092e-04, 3.2446e-04, 1.6413e-05, ..., 5.1571e-04,\n", " 1.9530e-03, 2.2843e-03],\n", " [1.5598e-03, 1.6954e-03, 1.7152e-04, ..., 2.1259e-03,\n", " 5.2503e-03, 5.7921e-03],\n", " [1.5411e-03, 1.6587e-03, 1.9799e-04, ..., 1.7436e-03,\n", " 
4.3894e-03, 4.9487e-03]],\n", " \n", " [[3.9882e-03, 3.1505e-03, 5.0457e-04, ..., 1.4255e-03,\n", " 1.0569e-02, 1.2237e-02],\n", " [4.3692e-03, 3.8382e-03, 7.0057e-04, ..., 1.4145e-03,\n", " 1.2426e-02, 1.2313e-02],\n", " [8.4888e-04, 7.3702e-04, 2.9357e-05, ..., 8.8902e-05,\n", " 2.1612e-03, 2.8400e-03],\n", " ...,\n", " [4.2905e-04, 4.1327e-04, 1.9656e-05, ..., 3.4418e-03,\n", " 6.5353e-03, 6.9649e-03],\n", " [1.5882e-03, 1.9241e-03, 1.7470e-04, ..., 6.1704e-03,\n", " 8.6600e-03, 9.6681e-03],\n", " [1.6485e-03, 1.6144e-03, 1.9710e-04, ..., 4.7098e-03,\n", " 7.9967e-03, 8.2181e-03]],\n", " \n", " ...,\n", " \n", " [[9.3113e-03, 6.5334e-03, 1.8013e-03, ..., 3.2796e-03,\n", " 2.0767e-02, 2.0632e-02],\n", " [8.6677e-03, 8.9047e-03, 2.1877e-03, ..., 3.3119e-03,\n", " 2.2205e-02, 2.2050e-02],\n", " [2.8595e-03, 2.1089e-03, 1.8231e-04, ..., 2.3835e-04,\n", " 4.7750e-03, 5.3993e-03],\n", " ...,\n", " [1.3388e-04, 1.5440e-04, 3.7817e-06, ..., 1.1344e-03,\n", " 2.8678e-03, 3.3601e-03],\n", " [7.2934e-04, 9.0662e-04, 5.3639e-05, ..., 2.7882e-03,\n", " 4.7959e-03, 5.4441e-03],\n", " [6.9227e-04, 8.4727e-04, 5.7668e-05, ..., 2.1941e-03,\n", " 4.1560e-03, 4.9784e-03]],\n", " \n", " [[1.2492e-02, 1.1063e-02, 4.9451e-03, ..., 3.0513e-02,\n", " 6.4178e-02, 7.0040e-02],\n", " [1.2166e-02, 1.2005e-02, 6.0362e-03, ..., 3.1947e-02,\n", " 6.7927e-02, 6.9785e-02],\n", " [7.7364e-03, 7.5642e-03, 1.4550e-03, ..., 6.8208e-03,\n", " 3.3262e-02, 4.0198e-02],\n", " ...,\n", " [1.5871e-05, 1.4742e-05, 1.9349e-07, ..., 3.2095e-04,\n", " 8.5799e-04, 9.8868e-04],\n", " [1.1769e-04, 1.4500e-04, 4.5852e-06, ..., 7.2754e-04,\n", " 1.4678e-03, 1.6742e-03],\n", " [1.1819e-04, 1.1339e-04, 5.3238e-06, ..., 5.4077e-04,\n", " 1.2279e-03, 1.3526e-03]],\n", " \n", " [[2.5760e-03, 2.1818e-03, 1.8901e-04, ..., 1.0640e-03,\n", " 1.0117e-02, 1.1868e-02],\n", " [2.7598e-03, 2.6283e-03, 2.9429e-04, ..., 9.7199e-04,\n", " 1.1117e-02, 1.1070e-02],\n", " [3.4934e-04, 3.4733e-04, 5.4813e-06, ..., 6.2566e-05,\n", " 
1.8896e-03, 2.6472e-03],\n", " ...,\n", " [6.8025e-04, 5.9770e-04, 3.0153e-05, ..., 6.9031e-03,\n", " 1.1587e-02, 1.2512e-02],\n", " [2.8487e-03, 2.9917e-03, 3.1098e-04, ..., 1.0829e-02,\n", " 1.6660e-02, 1.6246e-02],\n", " [2.4998e-03, 2.2843e-03, 2.9663e-04, ..., 8.1416e-03,\n", " 1.2735e-02, 1.3245e-02]]]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = data.train_ds[0][0]\n", "learn.predict(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visualize the predictions, we can use the [`Image.show`](/vision.image.html#Image.show) method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img.show(y=learn.predict(img)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next application is text, so let's start by importing everything we'll need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.text import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Language modelling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First let's look a how to get a language model ready for inference. Since we'll load the model trained in the [visualize data tutorial](/tutorial.data.html), we load the vocabulary used there." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imdb = untar_data(URLs.IMDB_SAMPLE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab = Vocab(pickle.load(open(imdb/'tmp'/'itos.pkl', 'rb')))\n", "data_lm = (TextList.from_csv(imdb, 'texts.csv', cols='text', vocab=vocab)\n", " .random_split_by_pct()\n", " .label_for_lm()\n", " .databunch())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like in vision, we just have to type `data_lm.export()` to save all the information inside the [`DataBunch`](/basic_data.html#DataBunch) we'll need. In this case, this includes all the vocabulary we created." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_lm.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's define a language model learner from an empty data object." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = TextLMDataBunch.load_empty(imdb)\n", "learn = language_model_learner(empty_data)\n", "learn.load('mini_train_lm');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can predict with the usual method. Here we can specify how many words we want the model to predict." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:01\n", "\n" ] }, { "data": { "text/plain": [ "'This is a simple test of these men from the \" popularity of some scenes of a xxmaj gordon - xxmaj hudson and thick parts ,'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.predict('This is a simple test of', n_words=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's see a classification example. We have to use the same vocabulary as for the language model if we want to be able to use the encoder we saved." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_clas = (TextList.from_csv(imdb, 'texts.csv', cols='text', vocab=vocab)\n", " .split_from_df(col='is_valid')\n", " .label_from_df(cols='label')\n", " .databunch(bs=42))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again we export the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_clas.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's define a text classifier from an empty data object." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "empty_data = TextClasDataBunch.load_empty(imdb)\n", "learn = text_classifier_learner(empty_data)\n", "learn.load('mini_train_clas');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can predict with the usual method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Category negative, tensor(0), tensor([0.8194, 0.1806]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.predict('I really loved that movie!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tabular" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last application brings us to tabular data. First let's import everything we'll need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.tabular import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a sample of the [adult dataset](https://archive.ics.uci.edu/ml/datasets/adult) here. Once we read the csv file, we'll need to specify the dependent variable, the categorical variables, the continuous variables and the processors we want to use." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "adult = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(adult/'adult.csv')\n", "dep_var = '>=50k'\n", "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n", "cont_names = ['education-num', 'hours-per-week', 'age', 'capital-loss', 'fnlwgt', 'capital-gain']\n", "procs = [FillMissing, Categorify, Normalize]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we can use the data block API to grab everything together before using `data.show_batch()`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (TabularList.from_df(df, path=adult, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_idx(valid_idx=range(800,1000))\n", " .label_from_df(cols=dep_var)\n", " .databunch())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a [`Learner`](/basic_train.html#Learner) object that we fit and then save the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:04\n", "epoch train_loss valid_loss accuracy\n", "1 0.328005 0.354749 0.820000 (00:04)\n", "\n" ] } ], "source": [ "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)\n", "learn.fit(1, 1e-2)\n", "learn.save('mini_train')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As in the other applications, we just have to type `data.export()` to save everything we'll need for inference (here the inner state of each processor)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.export()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we create an empty data object and a learner from it like before." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = TabularDataBunch.load_empty(adult)\n", "learn = tabular_learner(data, layers=[200,100])\n", "learn.load('mini_train');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we can predict on a row of a dataframe that has the right `cat_names` and `cont_names`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Category 1, tensor(0), tensor([0.8100, 0.1900]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.predict(df.iloc[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Intermediate tutorial, explains how to create a Learner for inference", "title": "tutorial.inference" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }