{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple model for tabular data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.tabular.models import TabularModel" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TabularModel

\n", "\n", "> TabularModel(**`emb_szs`**:`ListSizes`, **`n_cont`**:`int`, **`out_sz`**:`int`, **`layers`**:`Collection`\\[`int`\\], **`ps`**:`Collection`\\[`float`\\]=***`None`***, **`emb_drop`**:`float`=***`0.0`***, **`y_range`**:`OptRange`=***`None`***, **`use_bn`**:`bool`=***`True`***, **`bn_final`**:`bool`=***`False`***) :: [`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\n", "\n", "

\n", "\n", "Basic model for tabular data. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularModel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`emb_szs` pairs the size of each categorical variable with an embedding size, and `n_cont` is the number of continuous variables. The model passes the categorical variables through `Embedding` layers followed by a `Dropout` of `emb_drop`, and the continuous variables through a `BatchNorm`. The results are concatenated and fed through blocks of `BatchNorm`, `Dropout`, `Linear` and `ReLU` (the first block skips `BatchNorm` and `Dropout`, the last block skips the `ReLU`).\n", "\n", "The sizes of the blocks are given in `layers` and the `Dropout` probabilities in `ps`. The last size is `out_sz`. If `y_range` is not `None`, a final activation is added: a sigmoid rescaled to cover `y_range`. Lastly, if `use_bn` is set to `False`, all `BatchNorm` layers are skipped except the one applied to the continuous variables.\n", "\n", "Generally it's easiest to create a learner with [`tabular_learner`](/tabular.data.html#tabular_learner), which will automatically create a [`TabularModel`](/tabular.models.html#TabularModel) for you." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

forward

\n", "\n", "> forward(**`x_cat`**:`Tensor`, **`x_cont`**:`Tensor`) → `Tensor`\n", "\n", "

\n", "\n", "Defines the computation performed at every call. Should be overridden by all subclasses.\n", "\n", ".. note::\n", " Although the recipe for forward pass needs to be defined within\n", " this function, one should call the :class:`Module` instance afterwards\n", " instead of this since the former takes care of running the\n", " registered hooks while the latter silently ignores them. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularModel.forward)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

get_sizes

\n", "\n", "> get_sizes(**`layers`**, **`out_sz`**)\n", "\n", "

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularModel.get_sizes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Model for training tabular/structured data", "title": "tabular.models" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }