{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp tabular.data" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "from fastai2.torch_basics import *\n",
  "from fastai2.data.all import *\n",
  "from fastai2.tabular.core import *" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "from nbdev.showdoc import *" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Tabular data\n",
  "\n",
  "> Helper functions to get data into a `DataLoaders` for the tabular application, and the higher-level class `TabularDataLoaders`" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "The main class to get your data ready for model training is `TabularDataLoaders` and its factory methods. Check out the [tabular tutorial](http://dev.fast.ai/tutorial.tabular) for usage examples." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## TabularDataLoaders -" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class TabularDataLoaders(DataLoaders):\n",
  "    \"Basic wrapper around several `DataLoader`s with factory methods for tabular data\"\n",
  "    @classmethod\n",
  "    @delegates(Tabular.dataloaders, but=[\"dl_type\", \"dl_kwargs\"])\n",
  "    def from_df(cls, df, path='.', procs=None, cat_names=None, cont_names=None, y_names=None, y_block=None,\n",
  "                valid_idx=None, **kwargs):\n",
  "        \"Create from `df` in `path` using `procs`\"\n",
  "        if cat_names is None: cat_names = []\n",
  "        # Wrap names in `L` so that `None` and a single column name given as a string are handled as lists\n",
  "        if cont_names is None: cont_names = list(set(df)-set(L(cat_names))-set(L(y_names)))\n",
  "        splits = RandomSplitter()(df) if valid_idx is None else IndexSplitter(valid_idx)(df)\n",
  "        to = TabularPandas(df, procs, cat_names, cont_names, y_names, splits=splits, y_block=y_block)\n",
  "        return to.dataloaders(path=path, **kwargs)\n",
  "\n",
  "    @classmethod\n",
  "    def from_csv(cls, csv, **kwargs):\n",
  "        \"Create from `csv` file in `path` using `procs`\"\n",
  "        return cls.from_df(pd.read_csv(csv), **kwargs)\n",
  "\n",
  "    @delegates(TabDataLoader.__init__)\n",
  "    def test_dl(self, test_items, rm_type_tfms=None, process=True, **kwargs):\n",
  "        \"Create a test `DataLoader` from `test_items`, applying the training `procs` if `process` is `True`\"\n",
  "        to = self.train_ds.new(test_items)\n",
  "        if process: to.process()\n",
  "        return self.valid.new(to, **kwargs)\n",
  "\n",
  "Tabular._dbunch_type = TabularDataLoaders\n",
  "TabularDataLoaders.from_csv = delegates(to=TabularDataLoaders.from_df)(TabularDataLoaders.from_csv)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "This class should not be used directly; prefer one of its factory methods instead. All of these factory methods accept the following arguments:\n",
  "\n",
  "- `cat_names`: the names of the categorical variables\n",
  "- `cont_names`: the names of the continuous variables\n",
  "- `y_names`: the names of the dependent variables\n",
  "- `y_block`: the `TransformBlock` to use for the target\n",
  "- `valid_idx`: the indices to use for the validation set (defaults to a random split otherwise)\n",
  "- `bs`: the batch size\n",
  "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n",
  "- `shuffle_train`: whether to shuffle the training `DataLoader`\n",
  "- `n`: overrides the number of elements in the dataset\n",
  "- `device`: the PyTorch device to use (defaults to `default_device()`)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

<h4 id=\"TabularDataLoaders.from_df\" class=\"doc_header\"><code>TabularDataLoaders.from_df</code></h4>

\n", "\n", "> TabularDataLoaders.from_df(**`df`**, **`path`**=*`'.'`*, **`procs`**=*`None`*, **`cat_names`**=*`None`*, **`cont_names`**=*`None`*, **`y_names`**=*`None`*, **`y_block`**=*`None`*, **`valid_idx`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`n`**=*`None`*, **`device`**=*`None`*)\n", "\n", "Create from `df` in `path` using `procs`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularDataLoaders.from_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at an example with the adult dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>age</th><th>workclass</th><th>fnlwgt</th><th>education</th><th>education-num</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>native-country</th><th>salary</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>49</td><td>Private</td><td>101320</td><td>Assoc-acdm</td><td>12.0</td><td>Married-civ-spouse</td><td>NaN</td><td>Wife</td><td>White</td><td>Female</td><td>0</td><td>1902</td><td>40</td><td>United-States</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>1</th><td>44</td><td>Private</td><td>236746</td><td>Masters</td><td>14.0</td><td>Divorced</td><td>Exec-managerial</td><td>Not-in-family</td><td>White</td><td>Male</td><td>10520</td><td>0</td><td>45</td><td>United-States</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>2</th><td>38</td><td>Private</td><td>96185</td><td>HS-grad</td><td>NaN</td><td>Divorced</td><td>NaN</td><td>Unmarried</td><td>Black</td><td>Female</td><td>0</td><td>0</td><td>32</td><td>United-States</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>3</th><td>38</td><td>Self-emp-inc</td><td>112847</td><td>Prof-school</td><td>15.0</td><td>Married-civ-spouse</td><td>Prof-specialty</td><td>Husband</td><td>Asian-Pac-Islander</td><td>Male</td><td>0</td><td>0</td><td>40</td><td>United-States</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>4</th><td>42</td><td>Self-emp-not-inc</td><td>82297</td><td>7th-8th</td><td>NaN</td><td>Married-civ-spouse</td><td>Other-service</td><td>Wife</td><td>Black</td><td>Female</td><td>0</td><td>0</td><td>50</td><td>United-States</td><td>&lt;50k</td></tr>\n",
 "  </tbody>\n",
 "</table>\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 49 Private 101320 Assoc-acdm 12.0 \n", "1 44 Private 236746 Masters 14.0 \n", "2 38 Private 96185 HS-grad NaN \n", "3 38 Self-emp-inc 112847 Prof-school 15.0 \n", "4 42 Self-emp-not-inc 82297 7th-8th NaN \n", "\n", " marital-status occupation relationship race \\\n", "0 Married-civ-spouse NaN Wife White \n", "1 Divorced Exec-managerial Not-in-family White \n", "2 Divorced NaN Unmarried Black \n", "3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n", "4 Married-civ-spouse Other-service Wife Black \n", "\n", " sex capital-gain capital-loss hours-per-week native-country salary \n", "0 Female 0 1902 40 United-States >=50k \n", "1 Male 10520 0 45 United-States >=50k \n", "2 Female 0 0 32 United-States <50k \n", "3 Male 0 0 40 United-States >=50k \n", "4 Female 0 0 50 United-States <50k " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n", " y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>education-num_na</th><th>age</th><th>fnlwgt</th><th>education-num</th><th>salary</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>Private</td><td>HS-grad</td><td>Divorced</td><td>Sales</td><td>Not-in-family</td><td>White</td><td>False</td><td>40.0</td><td>116632.001407</td><td>9.0</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>1</th><td>State-gov</td><td>Some-college</td><td>Never-married</td><td>Protective-serv</td><td>Own-child</td><td>Black</td><td>False</td><td>22.0</td><td>293363.998886</td><td>10.0</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>2</th><td>Private</td><td>HS-grad</td><td>Divorced</td><td>Craft-repair</td><td>Own-child</td><td>White</td><td>False</td><td>35.0</td><td>126568.998886</td><td>9.0</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>3</th><td>Private</td><td>Masters</td><td>Divorced</td><td>Exec-managerial</td><td>Unmarried</td><td>Black</td><td>False</td><td>39.0</td><td>150061.001071</td><td>14.0</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>4</th><td>Private</td><td>Some-college</td><td>Never-married</td><td>Sales</td><td>Own-child</td><td>White</td><td>False</td><td>21.0</td><td>283756.998474</td><td>10.0</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>5</th><td>Private</td><td>Masters</td><td>Married-civ-spouse</td><td>Sales</td><td>Husband</td><td>White</td><td>False</td><td>29.0</td><td>134565.997603</td><td>14.0</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>6</th><td>Self-emp-not-inc</td><td>HS-grad</td><td>Married-civ-spouse</td><td>Farming-fishing</td><td>Husband</td><td>White</td><td>False</td><td>39.0</td><td>148442.999504</td><td>9.0</td><td>&lt;50k</td></tr>\n",
 "    <tr><th>7</th><td>Private</td><td>Some-college</td><td>Married-civ-spouse</td><td>Adm-clerical</td><td>Husband</td><td>White</td><td>False</td><td>49.0</td><td>280524.999991</td><td>10.0</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>8</th><td>Local-gov</td><td>HS-grad</td><td>Divorced</td><td>Handlers-cleaners</td><td>Not-in-family</td><td>White</td><td>False</td><td>39.0</td><td>166497.000063</td><td>9.0</td><td>&gt;=50k</td></tr>\n",
 "    <tr><th>9</th><td>?</td><td>11th</td><td>Never-married</td><td>?</td><td>Own-child</td><td>White</td><td>False</td><td>17.0</td><td>47407.001911</td><td>7.0</td><td>&lt;50k</td></tr>\n",
 "  </tbody>\n",
 "</table>\n", "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

<h4 id=\"TabularDataLoaders.from_csv\" class=\"doc_header\"><code>TabularDataLoaders.from_csv</code></h4>

\n", "\n", "> TabularDataLoaders.from_csv(**`csv`**, **`path`**=*`'.'`*, **`procs`**=*`None`*, **`cat_names`**=*`None`*, **`cont_names`**=*`None`*, **`y_names`**=*`None`*, **`y_block`**=*`None`*, **`valid_idx`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`n`**=*`None`*, **`device`**=*`None`*)\n", "\n", "Create from `csv` file in `path` using `procs`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TabularDataLoaders.from_csv)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n", " y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 
15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true 
}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }