{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp tabular.data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai2.torch_basics import *\n",
"from fastai2.data.all import *\n",
"from fastai2.tabular.core import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tabular data\n",
"\n",
"> Helper functions to get data into a `DataLoaders` for the tabular application, and the higher-level class `TabularDataLoaders`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main class to get your data ready for model training is `TabularDataLoaders` and its factory methods. Check out the [tabular tutorial](http://dev.fast.ai/tutorial.tabular) for examples of use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TabularDataLoaders -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TabularDataLoaders(DataLoaders):\n",
"    \"Basic wrapper around several `DataLoader`s with factory methods for tabular data\"\n",
"    @classmethod\n",
"    @delegates(Tabular.dataloaders, but=[\"dl_type\", \"dl_kwargs\"])\n",
"    def from_df(cls, df, path='.', procs=None, cat_names=None, cont_names=None, y_names=None, y_block=None,\n",
"                valid_idx=None, **kwargs):\n",
"        \"Create from `df` in `path` using `procs`\"\n",
"        if cat_names is None: cat_names = []\n",
"        if cont_names is None:\n",
"            # Continuous = every column that is neither categorical nor a target. `L` turns a single name\n",
"            # (e.g. y_names='salary') into a one-item list; `set('salary')` would be a set of letters and\n",
"            # leave the target among the inputs (and `set(None)` would raise). Columns are kept in `df`\n",
"            # order because iterating a `set` of strings can change order between Python processes.\n",
"            excluded = set(L(cat_names)) | set(L(y_names))\n",
"            cont_names = [c for c in df.columns if c not in excluded]\n",
"        # No `valid_idx`: random train/validation split; otherwise the `valid_idx` rows form the validation set\n",
"        splits = RandomSplitter()(df) if valid_idx is None else IndexSplitter(valid_idx)(df)\n",
"        to = TabularPandas(df, procs, cat_names, cont_names, y_names, splits=splits, y_block=y_block)\n",
"        return to.dataloaders(path=path, **kwargs)\n",
"\n",
"    @classmethod\n",
"    def from_csv(cls, csv, **kwargs):\n",
"        \"Create from `csv` file in `path` using `procs`\"\n",
"        # Reads `csv` with pandas defaults, then forwards every other argument to `from_df`\n",
"        return cls.from_df(pd.read_csv(csv), **kwargs)\n",
"\n",
"    @delegates(TabDataLoader.__init__)\n",
"    def test_dl(self, test_items, rm_type_tfms=None, process=True, **kwargs):\n",
"        \"Create a `DataLoader` over `test_items`, built from the training set's `TabularPandas`\"\n",
"        # NOTE(review): `test_items` is presumably a `DataFrame` with the training columns - confirm.\n",
"        # `rm_type_tfms` is not used by this implementation.\n",
"        to = self.train_ds.new(test_items)\n",
"        if process: to.process()\n",
"        return self.valid.new(to, **kwargs)\n",
"\n",
"# Presumably lets `Tabular.dataloaders` build a `TabularDataLoaders` - confirm in `tabular.core`\n",
"Tabular._dbunch_type = TabularDataLoaders\n",
"# Applied after the class exists so that `from_csv` exposes the keyword arguments of `from_df`\n",
"TabularDataLoaders.from_csv = delegates(to=TabularDataLoaders.from_df)(TabularDataLoaders.from_csv)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This class should not be used directly; one of the factory methods should be preferred instead. All those factory methods accept the following arguments:\n",
"\n",
"- `path`: the base path passed to the resulting `DataLoaders` (defaults to `'.'`)\n",
"- `procs`: the list of preprocessing steps to apply (e.g. `Categorify`, `FillMissing`, `Normalize`)\n",
"- `cat_names`: the names of the categorical variables\n",
"- `cont_names`: the names of the continuous variables\n",
"- `y_names`: the names of the dependent variables\n",
"- `y_block`: the `TransformBlock` to use for the target\n",
"- `valid_idx`: the indices to use for the validation set (defaults to a random split otherwise)\n",
"- `bs`: the batch size\n",
"- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n",
"- `shuffle_train`: whether to shuffle the training `DataLoader` or not\n",
"- `n`: overrides the number of elements in the dataset\n",
"- `device`: the PyTorch device to use (defaults to `default_device()`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"
\n",
"\n",
"> TabularDataLoaders.from_df(**`df`**, **`path`**=*`'.'`*, **`procs`**=*`None`*, **`cat_names`**=*`None`*, **`cont_names`**=*`None`*, **`y_names`**=*`None`*, **`y_block`**=*`None`*, **`valid_idx`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`n`**=*`None`*, **`device`**=*`None`*)\n",
"\n",
"Create from `df` in `path` using `procs`"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(TabularDataLoaders.from_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's have a look at an example with the adult dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" education-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" salary | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 49 | \n",
" Private | \n",
" 101320 | \n",
" Assoc-acdm | \n",
" 12.0 | \n",
" Married-civ-spouse | \n",
" NaN | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 1902 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 1 | \n",
" 44 | \n",
" Private | \n",
" 236746 | \n",
" Masters | \n",
" 14.0 | \n",
" Divorced | \n",
" Exec-managerial | \n",
" Not-in-family | \n",
" White | \n",
" Male | \n",
" 10520 | \n",
" 0 | \n",
" 45 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 2 | \n",
" 38 | \n",
" Private | \n",
" 96185 | \n",
" HS-grad | \n",
" NaN | \n",
" Divorced | \n",
" NaN | \n",
" Unmarried | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 32 | \n",
" United-States | \n",
" <50k | \n",
"
\n",
" \n",
" | 3 | \n",
" 38 | \n",
" Self-emp-inc | \n",
" 112847 | \n",
" Prof-school | \n",
" 15.0 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" Asian-Pac-Islander | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 4 | \n",
" 42 | \n",
" Self-emp-not-inc | \n",
" 82297 | \n",
" 7th-8th | \n",
" NaN | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Wife | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" <50k | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country salary \n",
"0 Female 0 1902 40 United-States >=50k \n",
"1 Male 10520 0 45 United-States >=50k \n",
"2 Female 0 0 32 United-States <50k \n",
"3 Male 0 0 40 United-States >=50k \n",
"4 Female 0 0 50 United-States <50k "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n",
" y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" education | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" education-num_na | \n",
" age | \n",
" fnlwgt | \n",
" education-num | \n",
" salary | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" Private | \n",
" HS-grad | \n",
" Divorced | \n",
" Sales | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" 40.0 | \n",
" 116632.001407 | \n",
" 9.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 1 | \n",
" State-gov | \n",
" Some-college | \n",
" Never-married | \n",
" Protective-serv | \n",
" Own-child | \n",
" Black | \n",
" False | \n",
" 22.0 | \n",
" 293363.998886 | \n",
" 10.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 2 | \n",
" Private | \n",
" HS-grad | \n",
" Divorced | \n",
" Craft-repair | \n",
" Own-child | \n",
" White | \n",
" False | \n",
" 35.0 | \n",
" 126568.998886 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 3 | \n",
" Private | \n",
" Masters | \n",
" Divorced | \n",
" Exec-managerial | \n",
" Unmarried | \n",
" Black | \n",
" False | \n",
" 39.0 | \n",
" 150061.001071 | \n",
" 14.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 4 | \n",
" Private | \n",
" Some-college | \n",
" Never-married | \n",
" Sales | \n",
" Own-child | \n",
" White | \n",
" False | \n",
" 21.0 | \n",
" 283756.998474 | \n",
" 10.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 5 | \n",
" Private | \n",
" Masters | \n",
" Married-civ-spouse | \n",
" Sales | \n",
" Husband | \n",
" White | \n",
" False | \n",
" 29.0 | \n",
" 134565.997603 | \n",
" 14.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 6 | \n",
" Self-emp-not-inc | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Farming-fishing | \n",
" Husband | \n",
" White | \n",
" False | \n",
" 39.0 | \n",
" 148442.999504 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 7 | \n",
" Private | \n",
" Some-college | \n",
" Married-civ-spouse | \n",
" Adm-clerical | \n",
" Husband | \n",
" White | \n",
" False | \n",
" 49.0 | \n",
" 280524.999991 | \n",
" 10.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 8 | \n",
" Local-gov | \n",
" HS-grad | \n",
" Divorced | \n",
" Handlers-cleaners | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" 39.0 | \n",
" 166497.000063 | \n",
" 9.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 9 | \n",
" ? | \n",
" 11th | \n",
" Never-married | \n",
" ? | \n",
" Own-child | \n",
" White | \n",
" False | \n",
" 17.0 | \n",
" 47407.001911 | \n",
" 7.0 | \n",
" <50k | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> TabularDataLoaders.from_csv(**`csv`**, **`path`**=*`'.'`*, **`procs`**=*`None`*, **`cat_names`**=*`None`*, **`cont_names`**=*`None`*, **`y_names`**=*`None`*, **`y_block`**=*`None`*, **`valid_idx`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`n`**=*`None`*, **`device`**=*`None`*)\n",
"\n",
"Create from `csv` file in `path` using `procs`"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(TabularDataLoaders.from_csv)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same setup as the `from_df` example above, but reading the data straight from the CSV file\n",
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]\n",
"dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names,\n",
"                                  y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)\n",
"# The 200 rows of `valid_idx` are the validation set and every other row of the CSV is used for training\n",
"test_eq(len(dls.valid_ds), 200)\n",
"test_eq(len(dls.train_ds), len(df) - 200)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_torch_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 02_data.load.ipynb.\n",
"Converted 03_data.core.ipynb.\n",
"Converted 04_data.external.ipynb.\n",
"Converted 05_data.transforms.ipynb.\n",
"Converted 06_data.block.ipynb.\n",
"Converted 07_vision.core.ipynb.\n",
"Converted 08_vision.data.ipynb.\n",
"Converted 09_vision.augment.ipynb.\n",
"Converted 09b_vision.utils.ipynb.\n",
"Converted 09c_vision.widgets.ipynb.\n",
"Converted 10_tutorial.pets.ipynb.\n",
"Converted 11_vision.models.xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_callback.core.ipynb.\n",
"Converted 13a_learner.ipynb.\n",
"Converted 13b_metrics.ipynb.\n",
"Converted 14_callback.schedule.ipynb.\n",
"Converted 14a_callback.data.ipynb.\n",
"Converted 15_callback.hook.ipynb.\n",
"Converted 15a_vision.models.unet.ipynb.\n",
"Converted 16_callback.progress.ipynb.\n",
"Converted 17_callback.tracker.ipynb.\n",
"Converted 18_callback.fp16.ipynb.\n",
"Converted 18a_callback.training.ipynb.\n",
"Converted 19_callback.mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision.learner.ipynb.\n",
"Converted 22_tutorial.imagenette.ipynb.\n",
"Converted 23_tutorial.vision.ipynb.\n",
"Converted 24_tutorial.siamese.ipynb.\n",
"Converted 24_vision.gan.ipynb.\n",
"Converted 30_text.core.ipynb.\n",
"Converted 31_text.data.ipynb.\n",
"Converted 32_text.models.awdlstm.ipynb.\n",
"Converted 33_text.models.core.ipynb.\n",
"Converted 34_callback.rnn.ipynb.\n",
"Converted 35_tutorial.wikitext.ipynb.\n",
"Converted 36_text.models.qrnn.ipynb.\n",
"Converted 37_text.learner.ipynb.\n",
"Converted 38_tutorial.text.ipynb.\n",
"Converted 39_tutorial.transformers.ipynb.\n",
"Converted 40_tabular.core.ipynb.\n",
"Converted 41_tabular.data.ipynb.\n",
"Converted 42_tabular.model.ipynb.\n",
"Converted 43_tabular.learner.ipynb.\n",
"Converted 44_tutorial.tabular.ipynb.\n",
"Converted 45_collab.ipynb.\n",
"Converted 46_tutorial.collab.ipynb.\n",
"Converted 50_tutorial.datablock.ipynb.\n",
"Converted 60_medical.imaging.ipynb.\n",
"Converted 61_tutorial.medical_imaging.ipynb.\n",
"Converted 65_medical.text.ipynb.\n",
"Converted 70_callback.wandb.ipynb.\n",
"Converted 71_callback.tensorboard.ipynb.\n",
"Converted 72_callback.neptune.ipynb.\n",
"Converted 73_callback.captum.ipynb.\n",
"Converted 74_callback.cutmix.ipynb.\n",
"Converted 97_test_utils.ipynb.\n",
"Converted 99_pytorch_doc.ipynb.\n",
"Converted index.ipynb.\n",
"Converted tutorial.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}