{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.basics import *\n", "from fastai2.tabular.core import *\n", "from fastai2.tabular.model import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *\n", "from fastai2.tabular.data import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp tabular.learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tabular learner\n", "\n", "> The function to immediately get a `Learner` ready to train for tabular data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main function you probably want to use in this module is `tabular_learner`. It will automatically create a `TabularModel` suitable for your data and infer the right loss function. See the [tabular tutorial](http://dev.fast.ai/tutorial.tabular) for an example of use in context." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args(but_as=Learner.__init__)\n", "class TabularLearner(Learner):\n", "    \"`Learner` for tabular data\"\n", "    def predict(self, row):\n", "        \"Predict on a single `row`; return the decoded inputs and prediction, the decoded prediction and the raw prediction\"\n", "        tst_to = self.dls.valid_ds.new(pd.DataFrame(row).T)\n", "        tst_to.process()\n", "        tst_to.conts = tst_to.conts.astype(np.float32)\n", "        dl = self.dls.valid.new(tst_to)\n", "        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n", "        b = (*tuplify(inp),*tuplify(dec_preds))\n", "        full_dec = self.dls.decode(b)\n", "        return full_dec,dec_preds[0],preds[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class TabularLearner[source]

\n", "\n", "> TabularLearner(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`'Adam'`*, **`lr`**=*`0.001`*, **`splitter`**=*`'trainable_params'`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*) :: [`Learner`](/13a_learner#Learner)\n", "\n", "[`Learner`](/13a_learner#Learner) for tabular data" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularLearner, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It works exactly like a normal `Learner`; the only difference is that it implements a `predict` method that works on a single row of data. `predict` returns a tuple with the fully decoded inputs and prediction, the decoded prediction, and the raw prediction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args(to_return=True, but_as=Learner.__init__)\n", "@delegates(Learner.__init__)\n", "def tabular_learner(dls, layers=None, emb_szs=None, config=None, n_out=None, y_range=None, **kwargs):\n", "    \"Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params.\"\n", "    if config is None: config = tabular_config()\n", "    if layers is None: layers = [200,100]\n", "    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)\n", "    if n_out is None: n_out = get_c(dls)\n", "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n", "    if y_range is None and 'y_range' in config: y_range = config.pop('y_range')\n", "    model = TabularModel(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **config)\n", "    return TabularLearner(dls, model, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If your data was built with fastai, you probably won't need to pass anything to `emb_szs` unless you want to change the default of the 
library (produced by `get_emb_sz`); the same goes for `n_out`, which should be inferred automatically. `layers` defaults to `[200,100]` and is passed to `TabularModel` along with the `config`.\n", "\n", "Use `tabular_config` to create a `config` and customize the model used. `y_range` is also available directly as an argument of `tabular_learner` because it is used so often.\n", "\n", "All the other arguments are passed to `Learner`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n", " y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)\n", "learn = tabular_learner(dls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "tst = learn.predict(df.iloc[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test y_range is passed\n", "learn = tabular_learner(dls, y_range=(0,32))\n", "assert isinstance(learn.model.layers[-1], SigmoidRange)\n", "test_eq(learn.model.layers[-1].low, 0)\n", "test_eq(learn.model.layers[-1].high, 32)\n", "\n", "learn = tabular_learner(dls, config=tabular_config(y_range=(0,32)))\n", "assert isinstance(learn.model.layers[-1], SigmoidRange)\n", "test_eq(learn.model.layers[-1].low, 0)\n", "test_eq(learn.model.layers[-1].high, 32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_results(x:Tabular, y:Tabular, samples, outs, ctxs=None, max_n=10, **kwargs):\n", " df = x.all_cols[:max_n]\n", " for n in x.y_names: df[n+'_pred'] = 
y[n][:max_n].values\n", " display_df(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 
41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }