{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.basics import *\n", "from fastai2.tabular.core import *\n", "from fastai2.tabular.model import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *\n", "from fastai2.tabular.data import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp tabular.learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tabular learner\n", "\n", "> The function to immediately get a `Learner` ready to train for tabular data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main function you probably want to use in this module is `tabular_learner`. It will automatically create a `TabularModel` suitable for your data and infer the right loss function. See the [tabular tutorial](http://dev.fast.ai/tutorial.tabular) for an example of use in context." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args(but_as=Learner.__init__)\n", "class TabularLearner(Learner):\n", " \"`Learner` for tabular data\"\n", " def predict(self, row):\n", " tst_to = self.dls.valid_ds.new(pd.DataFrame(row).T)\n", " tst_to.process()\n", " tst_to.conts = tst_to.conts.astype(np.float32)\n", " dl = self.dls.valid.new(tst_to)\n", " inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n", " i = getattr(self.dls, 'n_inp', -1)\n", " b = (*tuplify(inp),*tuplify(dec_preds))\n", " full_dec = self.dls.decode((*tuplify(inp),*tuplify(dec_preds)))\n", " return full_dec,dec_preds[0],preds[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
class TabularLearner[source]TabularLearner(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`'Adam'`*, **`lr`**=*`0.001`*, **`splitter`**=*`'trainable_params'`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*) :: [`Learner`](/13a_learner#Learner)\n",
"\n",
"[`Learner`](/13a_learner#Learner) for tabular data"
],
"text/plain": [
"