{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp tabular.core"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai2.torch_basics import *\n",
"from fastai2.test import *\n",
"from fastai2.core import *\n",
"from fastai2.data.all import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"pd.set_option('mode.chained_assignment','raise')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tabular core\n",
"\n",
 Basic function">
"> Basic functions to preprocess tabular data before assembling it in a `DataBunch`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tabular -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class _TabIloc:\n",
" \"Get/set rows by iloc and cols by name\"\n",
" def __init__(self,to): self.to = to\n",
" def __getitem__(self, idxs):\n",
" df = self.to.items\n",
" if isinstance(idxs,tuple):\n",
" rows,cols = idxs\n",
" cols = df.columns.isin(cols) if is_listy(cols) else df.columns.get_loc(cols)\n",
" else: rows,cols = idxs,slice(None)\n",
" return self.to.new(df.iloc[rows, cols])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class Tabular(CollBase, GetAttr, FilteredBase):\n",
" \"A `DataFrame` wrapper that knows which cols are cont/cat/y, and returns rows in `__getitem__`\"\n",
" _default='items'\n",
" def __init__(self, df, procs=None, cat_names=None, cont_names=None, y_names=None, block_y=CategoryBlock, splits=None, do_setup=True):\n",
" if splits is None: splits=[range_of(df)]\n",
" df = df.iloc[sum(splits, [])].copy()\n",
" super().__init__(df)\n",
" \n",
" self.y_names = L(y_names)\n",
" if block_y is not None: \n",
" if callable(block_y): block_y = block_y()\n",
" procs = L(procs) + block_y.type_tfms\n",
" self.cat_names,self.cont_names,self.procs = L(cat_names),L(cont_names),Pipeline(procs, as_item=True)\n",
" self.split = len(splits[0])\n",
" if do_setup: self.setup()\n",
"\n",
" def subset(self, i): return self.new(self.items[slice(0,self.split) if i==0 else slice(self.split,len(self))])\n",
" def copy(self): self.items = self.items.copy(); return self\n",
" def new(self, df): return type(self)(df, do_setup=False, block_y=None, **attrdict(self, 'procs','cat_names','cont_names','y_names'))\n",
" def show(self, max_n=10, **kwargs): display_df(self.all_cols[:max_n])\n",
" def setup(self): self.procs.setup(self)\n",
" def process(self): self.procs(self)\n",
" def iloc(self): return _TabIloc(self)\n",
" def targ(self): return self.items[self.y_names]\n",
" def all_col_names (self): return self.cat_names + self.cont_names + self.y_names\n",
" def n_subsets(self): return 2\n",
"\n",
"properties(Tabular,'iloc','targ','all_col_names','n_subsets')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TabularPandas(Tabular):\n",
" def transform(self, cols, f): self[cols] = self[cols].transform(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def _add_prop(cls, nm):\n",
" @property\n",
" def f(o): return o[list(getattr(o,nm+'_names'))]\n",
" @f.setter\n",
" def fset(o, v): o[getattr(o,nm+'_names')] = v\n",
" setattr(cls, nm+'s', f)\n",
" setattr(cls, nm+'s', fset)\n",
"\n",
"_add_prop(Tabular, 'cat')\n",
"_add_prop(Tabular, 'cont')\n",
"_add_prop(Tabular, 'y')\n",
"_add_prop(Tabular, 'all_col')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" | \n",
" a | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 2 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df = pd.DataFrame({'a':[0,1,2,0,2], 'b':[0,0,0,0,1]})\n",
"to = TabularPandas(df, cat_names='a')\n",
"t = pickle.loads(pickle.dumps(to))\n",
"test_eq(t.items,to.items)\n",
"test_eq(to.all_cols,to[['a']])\n",
"to.show() # only shows 'a' since that's the only col in `TabularPandas`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TabularProc(InplaceTransform):\n",
" \"Base class to write a non-lazy tabular processor for dataframes\"\n",
" def setup(self, items=None):\n",
" super().setup(getattr(items,'train',items))\n",
" # Procs are called as soon as data is available\n",
" return self(items.items if isinstance(items,DataSource) else items)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def _apply_cats (voc, add, c): return c.cat.codes+add if is_categorical_dtype(c) else c.map(voc[c.name].o2i)\n",
"def _decode_cats(voc, c): return c.map(dict(enumerate(voc[c.name].items)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class Categorify(TabularProc):\n",
" \"Transform the categorical variables to that type.\"\n",
" order = 1\n",
" def setups(self, to):\n",
" self.classes = {n:CategoryMap(to.iloc[:,n].items, add_na=(n in to.cat_names)) for n in to.cat_names}\n",
" def encodes(self, to): to.transform(to.cat_names, partial(_apply_cats, self.classes, 1))\n",
" def decodes(self, to): to.transform(to.cat_names, partial(_decode_cats, self.classes))\n",
" def __getitem__(self,k): return self.classes[k]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@Categorize\n",
"def setups(self, to:Tabular): \n",
" if len(to.y_names) > 0: self.vocab = CategoryMap(to.iloc[:,to.y_names[0]].items)\n",
" return self(to)\n",
"\n",
"@Categorize\n",
"def encodes(self, to:Tabular): \n",
" to.transform(to.y_names, partial(_apply_cats, {n: self.vocab for n in to.y_names}, 0))\n",
" return to\n",
" \n",
"@Categorize\n",
"def decodes(self, to:Tabular): \n",
" to.transform(to.y_names, partial(_decode_cats, {n: self.vocab for n in to.y_names}))\n",
" return to"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> Categorify(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`as_item`**=*`False`*, **`order`**=*`None`*) :: [`TabularProc`](/tabular.core.html#TabularProc)\n",
"\n",
"Transform the categorical variables to that type."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Categorify, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({'a':[0,1,2,0,2]})\n",
"to = TabularPandas(df, Categorify, 'a')\n",
"cat = to.procs.categorify\n",
"test_eq(cat['a'], ['#na#',0,1,2])\n",
"test_eq(to.a, [1,2,3,1,3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df1 = pd.DataFrame({'a':[1,0,3,-1,2]})\n",
"to1 = to.new(df1)\n",
"to1.process()\n",
"#Values that weren't in the training df are sent to 0 (na)\n",
"test_eq(to1.a, [2,1,0,0,3])\n",
"to2 = cat.decode(to1)\n",
"test_eq(to2.a, [1,0,'#na#','#na#',2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test with splits\n",
"cat = Categorify()\n",
"df = pd.DataFrame({'a':[0,1,2,3,2]})\n",
"to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]])\n",
"test_eq(cat['a'], ['#na#',0,1,2])\n",
"test_eq(to['a'], [1,2,3,0,3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({'a':pd.Categorical(['M','H','L','M'], categories=['H','M','L'], ordered=True)})\n",
"to = TabularPandas(df, Categorify, 'a')\n",
"cat = to.procs.categorify\n",
"test_eq(cat['a'], ['#na#','H','M','L'])\n",
"test_eq(to.a, [2,1,3,2])\n",
"to2 = cat.decode(to)\n",
"test_eq(to2.a, ['M','H','L','M'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test with targets\n",
"cat = Categorify()\n",
"df = pd.DataFrame({'a':[0,1,2,3,2], 'b': ['a', 'b', 'a', 'b', 'b']})\n",
"to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]], y_names='b')\n",
"test_eq(to.procs.vocab, ['a', 'b'])\n",
"test_eq(to.b, [0,1,0,1,1])\n",
"to2 = to.procs.decode(to)\n",
"test_eq(to2.b, ['a', 'b', 'a', 'b', 'b'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class Normalize(TabularProc):\n",
" \"Normalize the continuous variables.\"\n",
" order = 2\n",
" def setups(self, dsrc): self.means,self.stds = dsrc.conts.mean(),dsrc.conts.std(ddof=0)+1e-7\n",
" def encodes(self, to): to.conts = (to.conts-self.means) / self.stds\n",
" def decodes(self, to): to.conts = (to.conts*self.stds ) + self.means"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> Normalize(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`as_item`**=*`False`*, **`order`**=*`None`*) :: [`TabularProc`](/tabular.core.html#TabularProc)\n",
"\n",
"Normalize the continuous variables."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Normalize, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"norm = Normalize()\n",
"df = pd.DataFrame({'a':[0,1,2,3,4]})\n",
"to = TabularPandas(df, norm, cont_names='a')\n",
"x = np.array([0,1,2,3,4])\n",
"m,s = x.mean(),x.std()\n",
"test_eq(norm.means['a'], m)\n",
"test_close(norm.stds['a'], s)\n",
"test_close(to.a.values, (x-m)/s)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df1 = pd.DataFrame({'a':[5,6,7]})\n",
"to1 = to.new(df1)\n",
"to1.process()\n",
"test_close(to1['a'].values, (np.array([5,6,7])-m)/s)\n",
"to2 = norm.decode(to1)\n",
"test_close(to2.a.values, [5,6,7])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"norm = Normalize()\n",
"df = pd.DataFrame({'a':[0,1,2,3,4]})\n",
"to = TabularPandas(df, norm, cont_names='a', splits=[[0,1,2],[3,4]])\n",
"x = np.array([0,1,2])\n",
"m,s = x.mean(),x.std()\n",
"test_eq(norm.means['a'], m)\n",
"test_close(norm.stds['a'], s)\n",
"test_close(to['a'].values, (np.array([0,1,2,3,4])-m)/s)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class FillStrategy:\n",
" \"Namespace containing the various filling strategies.\"\n",
" def median (c,fill): return c.median()\n",
" def constant(c,fill): return fill\n",
" def mode (c,fill): return c.dropna().value_counts().idxmax()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class FillMissing(TabularProc):\n",
" \"Fill the missing values in continuous columns.\"\n",
" def __init__(self, fill_strategy=FillStrategy.median, add_col=True, fill_vals=None):\n",
" if fill_vals is None: fill_vals = defaultdict(int)\n",
" store_attr(self, 'fill_strategy,add_col,fill_vals')\n",
"\n",
" def setups(self, dsrc):\n",
" self.na_dict = {n:self.fill_strategy(dsrc[n], self.fill_vals[n])\n",
" for n in pd.isnull(dsrc.conts).any().keys()}\n",
"\n",
" def encodes(self, to):\n",
" missing = pd.isnull(to.conts)\n",
" for n in missing.any().keys():\n",
" assert n in self.na_dict, f\"nan values in `{n}` but not in setup training set\"\n",
" to[n].fillna(self.na_dict[n], inplace=True)\n",
" if self.add_col:\n",
" to.loc[:,n+'_na'] = missing[n]\n",
" if n+'_na' not in to.cat_names: to.cat_names.append(n+'_na')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> FillMissing(**`fill_strategy`**=*`'median'`*, **`add_col`**=*`True`*, **`fill_vals`**=*`None`*) :: [`TabularProc`](/tabular.core.html#TabularProc)\n",
"\n",
"Fill the missing values in continuous columns."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(FillMissing, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fill1,fill2,fill3 = (FillMissing(fill_strategy=s) \n",
" for s in [FillStrategy.median, FillStrategy.constant, FillStrategy.mode])\n",
"df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4]})\n",
"df1 = df.copy(); df2 = df.copy()\n",
"tos = TabularPandas(df, fill1, cont_names='a'),TabularPandas(df1, fill2, cont_names='a'),TabularPandas(df2, fill3, cont_names='a')\n",
"test_eq(fill1.na_dict, {'a': 1.5})\n",
"test_eq(fill2.na_dict, {'a': 0})\n",
"test_eq(fill3.na_dict, {'a': 1.0})\n",
"\n",
"for t in tos: test_eq(t.cat_names, ['a_na'])\n",
"\n",
"for to_,v in zip(tos, [1.5, 0., 1.]):\n",
" test_eq(to_.a.values, np.array([0, 1, v, 1, 2, 3, 4]))\n",
" test_eq(to_.a_na.values, np.array([0, 0, 1, 0, 0, 0, 0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dfa = pd.DataFrame({'a':[np.nan,0,np.nan]})\n",
"tos = [t.new(o) for t,o in zip(tos,(dfa,dfa.copy(),dfa.copy()))]\n",
"for t in tos: t.process()\n",
"for to_,v in zip(tos, [1.5, 0., 1.]):\n",
" test_eq(to_.a.values, np.array([v, 0, v]))\n",
" test_eq(to_.a_na.values, np.array([1, 0, 1]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TabularPandas Pipelines -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"procs = [Normalize, Categorify, FillMissing, noop]\n",
"df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4]})\n",
"to = TabularPandas(df, procs, cat_names='a', cont_names='b')\n",
"\n",
"#Test setup and apply on df_main\n",
"test_eq(to.cat_names, ['a', 'b_na'])\n",
"test_eq(to.a, [1,2,3,2,2,3,1])\n",
"test_eq(to.b_na, [1,1,2,1,1,1,1])\n",
"x = np.array([0,1,1.5,1,2,3,4])\n",
"m,s = x.mean(),x.std()\n",
"test_close(to.b.values, (x-m)/s)\n",
"test_eq(to.procs.classes, {'a': ['#na#',0,1,2], 'b_na': ['#na#',False,True]})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Test apply on y_names\n",
"df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n",
"to = TabularPandas(df, procs, 'a', 'b', y_names='c')\n",
"\n",
"test_eq(to.cat_names, ['a', 'b_na'])\n",
"test_eq(to.a, [1,2,3,2,2,3,1])\n",
"test_eq(to.b_na, [1,1,2,1,1,1,1])\n",
"test_eq(to.c, [1,0,1,0,0,1,0])\n",
"x = np.array([0,1,1.5,1,2,3,4])\n",
"m,s = x.mean(),x.std()\n",
"test_close(to.b.values, (x-m)/s)\n",
"test_eq(to.procs.classes, {'a': ['#na#',0,1,2], 'b_na': ['#na#',False,True]})\n",
"test_eq(to.procs.vocab, ['a','b'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n",
"to = TabularPandas(df, procs, 'a', 'b', y_names='c')\n",
"\n",
"test_eq(to.cat_names, ['a', 'b_na'])\n",
"test_eq(to.a, [1,2,3,2,2,3,1])\n",
"test_eq(df.a.dtype,int)\n",
"test_eq(to.b_na, [1,1,2,1,1,1,1])\n",
"test_eq(to.c, [1,0,1,0,0,1,0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,np.nan,1,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n",
"to = TabularPandas(df, procs, cat_names='a', cont_names='b', y_names='c', splits=[[0,1,4,6], [2,3,5]])\n",
"\n",
"test_eq(to.cat_names, ['a', 'b_na'])\n",
"test_eq(to.a, [1,2,2,1,0,2,0])\n",
"test_eq(df.a.dtype,int)\n",
"test_eq(to.b_na, [1,2,1,1,1,1,1])\n",
"test_eq(to.c, [1,0,0,0,1,0,1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class ReadTabBatch(ItemTransform):\n",
" order = -1 #run before cuda\n",
" def __init__(self, to): self.to = to\n",
" # TODO: use float for cont targ\n",
" def encodes(self, to): return tensor(to.cats).long(),tensor(to.conts).float(), tensor(to.targ)\n",
"\n",
" def decodes(self, o):\n",
" cats,conts,targs = to_np(o)\n",
" vals = np.concatenate([cats,conts,targs], axis=1)\n",
" df = pd.DataFrame(vals, columns=self.to.all_col_names)\n",
" to = self.to.new(df)\n",
" to = self.to.procs.decode(to)\n",
" return to"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@typedispatch\n",
"def show_batch(x: Tabular, y, its, max_n=10, ctxs=None):\n",
" x.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@delegates()\n",
"class TabDataLoader(TfmdDL):\n",
" do_item = noops\n",
" def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):\n",
" after_batch = L(after_batch)+ReadTabBatch(dataset)\n",
" super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)\n",
"\n",
" def create_batch(self, b): return self.dataset.iloc[b]\n",
"\n",
"TabularPandas._dl_type = TabDataLoader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Integration example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" education-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" salary | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 49 | \n",
" Private | \n",
" 101320 | \n",
" Assoc-acdm | \n",
" 12.0 | \n",
" Married-civ-spouse | \n",
" NaN | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 1902 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 1 | \n",
" 44 | \n",
" Private | \n",
" 236746 | \n",
" Masters | \n",
" 14.0 | \n",
" Divorced | \n",
" Exec-managerial | \n",
" Not-in-family | \n",
" White | \n",
" Male | \n",
" 10520 | \n",
" 0 | \n",
" 45 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 2 | \n",
" 38 | \n",
" Private | \n",
" 96185 | \n",
" HS-grad | \n",
" NaN | \n",
" Divorced | \n",
" NaN | \n",
" Unmarried | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 32 | \n",
" United-States | \n",
" <50k | \n",
"
\n",
" \n",
" | 3 | \n",
" 38 | \n",
" Self-emp-inc | \n",
" 112847 | \n",
" Prof-school | \n",
" 15.0 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" Asian-Pac-Islander | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
"
\n",
" \n",
" | 4 | \n",
" 42 | \n",
" Self-emp-not-inc | \n",
" 82297 | \n",
" 7th-8th | \n",
" NaN | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Wife | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" <50k | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country salary \n",
"0 Female 0 1902 40 United-States >=50k \n",
"1 Male 10520 0 45 United-States >=50k \n",
"2 Female 0 0 32 United-States <50k \n",
"3 Male 0 0 40 United-States >=50k \n",
"4 Female 0 0 50 United-States <50k "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')\n",
"df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n",
"df_main.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]\n",
"splits = RandomSplitter()(range_of(df_main))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 192 ms, sys: 276 µs, total: 192 ms\n",
"Wall time: 191 ms\n"
]
}
],
"source": [
"%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=\"salary\", splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" education | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" age_na | \n",
" fnlwgt_na | \n",
" education-num_na | \n",
" age | \n",
" fnlwgt | \n",
" education-num | \n",
" salary | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" State-gov | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Craft-repair | \n",
" Husband | \n",
" Black | \n",
" False | \n",
" False | \n",
" False | \n",
" 50.0 | \n",
" 229271.999663 | \n",
" 9.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 1 | \n",
" Private | \n",
" Assoc-voc | \n",
" Never-married | \n",
" Craft-repair | \n",
" Own-child | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 30.0 | \n",
" 160633.999723 | \n",
" 11.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 2 | \n",
" Private | \n",
" HS-grad | \n",
" Never-married | \n",
" Handlers-cleaners | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 32.0 | \n",
" 164507.001425 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 3 | \n",
" Private | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Sales | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 48.0 | \n",
" 320421.005237 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 4 | \n",
" Private | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Adm-clerical | \n",
" Wife | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 46.0 | \n",
" 243189.999445 | \n",
" 9.0 | \n",
" >=50k | \n",
"
\n",
" \n",
" | 5 | \n",
" Private | \n",
" HS-grad | \n",
" Divorced | \n",
" Sales | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 31.0 | \n",
" 217802.999944 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 6 | \n",
" Private | \n",
" HS-grad | \n",
" Divorced | \n",
" Sales | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 34.0 | \n",
" 245172.999308 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 7 | \n",
" Private | \n",
" HS-grad | \n",
" Never-married | \n",
" Other-service | \n",
" Unmarried | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 52.0 | \n",
" 195638.000066 | \n",
" 9.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 8 | \n",
" Private | \n",
" Masters | \n",
" Never-married | \n",
" Prof-specialty | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 28.0 | \n",
" 274679.000327 | \n",
" 14.0 | \n",
" <50k | \n",
"
\n",
" \n",
" | 9 | \n",
" Private | \n",
" Some-college | \n",
" Never-married | \n",
" Sales | \n",
" Unmarried | \n",
" Black | \n",
" False | \n",
" False | \n",
" False | \n",
" 38.0 | \n",
" 363394.997929 | \n",
" 10.0 | \n",
" <50k | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dbch = to.databunch()\n",
"dbch.valid_dl.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" education | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" age_na | \n",
" fnlwgt_na | \n",
" education-num_na | \n",
" age | \n",
" fnlwgt | \n",
" education-num | \n",
" salary | \n",
"
\n",
" \n",
" \n",
" \n",
" | 10000 | \n",
" 5 | \n",
" 10 | \n",
" 3 | \n",
" 2 | \n",
" 1 | \n",
" 2 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 0.459514 | \n",
" 1.345251 | \n",
" 1.183611 | \n",
" 0 | \n",
"
\n",
" \n",
" | 10001 | \n",
" 5 | \n",
" 12 | \n",
" 3 | \n",
" 15 | \n",
" 1 | \n",
" 4 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" -0.935650 | \n",
" 1.257914 | \n",
" -0.427498 | \n",
" 0 | \n",
"
\n",
" \n",
" | 10002 | \n",
" 5 | \n",
" 2 | \n",
" 1 | \n",
" 9 | \n",
" 2 | \n",
" 5 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1.046951 | \n",
" 0.151258 | \n",
" -1.233052 | \n",
" 0 | \n",
"
\n",
" \n",
" | 10003 | \n",
" 5 | \n",
" 12 | \n",
" 7 | \n",
" 2 | \n",
" 5 | \n",
" 5 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 0.532943 | \n",
" -0.283410 | \n",
" -0.427498 | \n",
" 0 | \n",
"
\n",
" \n",
" | 10004 | \n",
" 6 | \n",
" 9 | \n",
" 3 | \n",
" 5 | \n",
" 1 | \n",
" 5 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 0.753232 | \n",
" 1.448155 | \n",
" 0.378057 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" workclass education marital-status occupation relationship race \\\n",
"10000 5 10 3 2 1 2 \n",
"10001 5 12 3 15 1 4 \n",
"10002 5 2 1 9 2 5 \n",
"10003 5 12 7 2 5 5 \n",
"10004 6 9 3 5 1 5 \n",
"\n",
" age_na fnlwgt_na education-num_na age fnlwgt education-num \\\n",
"10000 1 1 1 0.459514 1.345251 1.183611 \n",
"10001 1 1 1 -0.935650 1.257914 -0.427498 \n",
"10002 1 1 1 1.046951 0.151258 -1.233052 \n",
"10003 1 1 1 0.532943 -0.283410 -0.427498 \n",
"10004 1 1 1 0.753232 1.448155 0.378057 \n",
"\n",
" salary \n",
"10000 0 \n",
"10001 0 \n",
"10002 0 \n",
"10003 0 \n",
"10004 1 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to_tst = to.new(df_test)\n",
"to_tst.process()\n",
"to_tst.all_cols.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Other target types"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-label categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### one-hot encoded label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _mock_multi_label(df):\n",
" sal,sex,white = [],[],[]\n",
" for row in df.itertuples():\n",
" sal.append(row.salary == '>=50k')\n",
" sex.append(row.sex == ' Male')\n",
" white.append(row.race == ' White')\n",
" df['salary'] = np.array(sal)\n",
" df['male'] = np.array(sex)\n",
" df['white'] = np.array(white)\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')\n",
"df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n",
"df_main = _mock_multi_label(df_main)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" education-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" salary | \n",
" male | \n",
" white | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 49 | \n",
" Private | \n",
" 101320 | \n",
" Assoc-acdm | \n",
" 12.0 | \n",
" Married-civ-spouse | \n",
" NaN | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 1902 | \n",
" 40 | \n",
" United-States | \n",
" True | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
" | 1 | \n",
" 44 | \n",
" Private | \n",
" 236746 | \n",
" Masters | \n",
" 14.0 | \n",
" Divorced | \n",
" Exec-managerial | \n",
" Not-in-family | \n",
" White | \n",
" Male | \n",
" 10520 | \n",
" 0 | \n",
" 45 | \n",
" United-States | \n",
" True | \n",
" True | \n",
" True | \n",
"
\n",
" \n",
" | 2 | \n",
" 38 | \n",
" Private | \n",
" 96185 | \n",
" HS-grad | \n",
" NaN | \n",
" Divorced | \n",
" NaN | \n",
" Unmarried | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 32 | \n",
" United-States | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" | 3 | \n",
" 38 | \n",
" Self-emp-inc | \n",
" 112847 | \n",
" Prof-school | \n",
" 15.0 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" Asian-Pac-Islander | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" True | \n",
" True | \n",
" False | \n",
"
\n",
" \n",
" | 4 | \n",
" 42 | \n",
" Self-emp-not-inc | \n",
" 82297 | \n",
" 7th-8th | \n",
" NaN | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Wife | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country \\\n",
"0 Female 0 1902 40 United-States \n",
"1 Male 10520 0 45 United-States \n",
"2 Female 0 0 32 United-States \n",
"3 Male 0 0 40 United-States \n",
"4 Female 0 0 50 United-States \n",
"\n",
" salary male white \n",
"0 True False True \n",
"1 True True True \n",
"2 False False False \n",
"3 True True False \n",
"4 False False False "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_main.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@EncodedMultiCategorize\n",
"def encodes(self, to:Tabular): return to\n",
" \n",
"@EncodedMultiCategorize\n",
"def decodes(self, to:Tabular): \n",
" to.transform(to.y_names, lambda c: c==1)\n",
" return to"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]\n",
"splits = RandomSplitter()(range_of(df_main))\n",
"y_names=[\"salary\", \"male\", \"white\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 162 ms, sys: 0 ns, total: 162 ms\n",
"Wall time: 160 ms\n"
]
}
],
"source": [
"%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=y_names, block_y=MultiCategoryBlock(encoded=True, vocab=y_names), splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" education | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" age_na | \n",
" fnlwgt_na | \n",
" education-num_na | \n",
" age | \n",
" fnlwgt | \n",
" education-num | \n",
" salary | \n",
" male | \n",
" white | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" Private | \n",
" Some-college | \n",
" Never-married | \n",
" #na# | \n",
" Own-child | \n",
" White | \n",
" False | \n",
" False | \n",
" True | \n",
" 19.000000 | \n",
" 226912.999574 | \n",
" 10.0 | \n",
" False | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
" | 1 | \n",
" ? | \n",
" Some-college | \n",
" Divorced | \n",
" ? | \n",
" Unmarried | \n",
" White | \n",
" False | \n",
" False | \n",
" True | \n",
" 51.000000 | \n",
" 76437.000728 | \n",
" 10.0 | \n",
" False | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
" | 2 | \n",
" Private | \n",
" Some-college | \n",
" Never-married | \n",
" Handlers-cleaners | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 24.000000 | \n",
" 165064.000614 | \n",
" 10.0 | \n",
" False | \n",
" True | \n",
" True | \n",
"
\n",
" \n",
" | 3 | \n",
" Private | \n",
" HS-grad | \n",
" Never-married | \n",
" Machine-op-inspct | \n",
" Unmarried | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 26.000000 | \n",
" 211435.000638 | \n",
" 9.0 | \n",
" False | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
" | 4 | \n",
" Federal-gov | \n",
" 11th | \n",
" Never-married | \n",
" Other-service | \n",
" Own-child | \n",
" Asian-Pac-Islander | \n",
" False | \n",
" False | \n",
" False | \n",
" 17.999999 | \n",
" 101709.002120 | \n",
" 7.0 | \n",
" False | \n",
" True | \n",
" False | \n",
"
\n",
" \n",
" | 5 | \n",
" Federal-gov | \n",
" Some-college | \n",
" Married-civ-spouse | \n",
" Adm-clerical | \n",
" Husband | \n",
" Black | \n",
" False | \n",
" False | \n",
" False | \n",
" 39.000000 | \n",
" 314822.002568 | \n",
" 10.0 | \n",
" False | \n",
" True | \n",
" False | \n",
"
\n",
" \n",
" | 6 | \n",
" Private | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Machine-op-inspct | \n",
" Husband | \n",
" Black | \n",
" False | \n",
" False | \n",
" False | \n",
" 45.000000 | \n",
" 256649.002323 | \n",
" 9.0 | \n",
" False | \n",
" True | \n",
" False | \n",
"
\n",
" \n",
" | 7 | \n",
" Private | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Exec-managerial | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 50.000000 | \n",
" 161437.998631 | \n",
" 9.0 | \n",
" True | \n",
" True | \n",
" True | \n",
"
\n",
" \n",
" | 8 | \n",
" Private | \n",
" Assoc-acdm | \n",
" Never-married | \n",
" #na# | \n",
" Unmarried | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 26.000000 | \n",
" 159602.998851 | \n",
" 12.0 | \n",
" False | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
" | 9 | \n",
" Private | \n",
" Bachelors | \n",
" Never-married | \n",
" Prof-specialty | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" False | \n",
" 27.000000 | \n",
" 660870.007785 | \n",
" 13.0 | \n",
" False | \n",
" False | \n",
" True | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dbch = to.databunch()\n",
"dbch.valid_dl.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Not one-hot encoded"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _mock_multi_label(df):\n",
"    \"Add a `target` column of space-separated labels built from `salary`, `sex` and `race` (modifies and returns `df`)\"\n",
"    def _labels(row):\n",
"        # Label order is fixed: '>50k', then 'male', then 'white'\n",
"        flags = ((row.salary == '>=50k', '>50k'), (row.sex == ' Male', 'male'), (row.race == ' White', 'white'))\n",
"        return ' '.join(name for hit,name in flags if hit)\n",
"    df['target'] = np.array([_labels(row) for row in df.itertuples()])\n",
"    return df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')\n",
"df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n",
"df_main = _mock_multi_label(df_main)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" education-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" salary | \n",
" target | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 49 | \n",
" Private | \n",
" 101320 | \n",
" Assoc-acdm | \n",
" 12.0 | \n",
" Married-civ-spouse | \n",
" NaN | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 1902 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
" >50k white | \n",
"
\n",
" \n",
" | 1 | \n",
" 44 | \n",
" Private | \n",
" 236746 | \n",
" Masters | \n",
" 14.0 | \n",
" Divorced | \n",
" Exec-managerial | \n",
" Not-in-family | \n",
" White | \n",
" Male | \n",
" 10520 | \n",
" 0 | \n",
" 45 | \n",
" United-States | \n",
" >=50k | \n",
" >50k male white | \n",
"
\n",
" \n",
" | 2 | \n",
" 38 | \n",
" Private | \n",
" 96185 | \n",
" HS-grad | \n",
" NaN | \n",
" Divorced | \n",
" NaN | \n",
" Unmarried | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 32 | \n",
" United-States | \n",
" <50k | \n",
" | \n",
"
\n",
" \n",
" | 3 | \n",
" 38 | \n",
" Self-emp-inc | \n",
" 112847 | \n",
" Prof-school | \n",
" 15.0 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" Asian-Pac-Islander | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >=50k | \n",
" >50k male | \n",
"
\n",
" \n",
" | 4 | \n",
" 42 | \n",
" Self-emp-not-inc | \n",
" 82297 | \n",
" 7th-8th | \n",
" NaN | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Wife | \n",
" Black | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" <50k | \n",
" | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country salary \\\n",
"0 Female 0 1902 40 United-States >=50k \n",
"1 Male 10520 0 45 United-States >=50k \n",
"2 Female 0 0 32 United-States <50k \n",
"3 Male 0 0 40 United-States >=50k \n",
"4 Female 0 0 50 United-States <50k \n",
"\n",
" target \n",
"0 >50k white \n",
"1 >50k male white \n",
"2 \n",
"3 >50k male \n",
"4 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_main.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@MultiCategorize\n",
"def encodes(self, to:Tabular):\n",
"    \"Pass-through for now: `to` is returned unchanged\"\n",
"    # Disabled: to.transform(to.y_names, partial(_apply_cats, {n: self.vocab for n in to.y_names}, 0))\n",
"    return to\n",
" \n",
"@MultiCategorize\n",
"def decodes(self, to:Tabular):\n",
"    \"Pass-through for now: `to` is returned unchanged\"\n",
"    # Disabled: to.transform(to.y_names, partial(_decode_cats, {n: self.vocab for n in to.y_names}))\n",
"    return to"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]\n",
"splits = RandomSplitter()(range_of(df_main))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 214 ms, sys: 3.95 ms, total: 218 ms\n",
"Wall time: 217 ms\n"
]
}
],
"source": [
"# Multi-label target kept in one column of space-separated labels (not one-hot encoded)\n",
"%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=\"target\", block_y=MultiCategoryBlock(), splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#24) [-,_,a,c,d,e,f,g,h,i...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.procs[2].vocab"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')\n",
"df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n",
"df_main = _mock_multi_label(df_main)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['fnlwgt', 'education-num']\n",
"procs = [Categorify, FillMissing, Normalize]\n",
"splits = RandomSplitter()(range_of(df_main))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 148 ms, sys: 7.58 ms, total: 155 ms\n",
"Wall time: 154 ms\n"
]
}
],
"source": [
"# Continuous target: pass a bare `TransformBlock` instead of the default `CategoryBlock`\n",
"%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names='age', block_y=TransformBlock(), splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fnlwgt 192859.69200\n",
"education-num 10.08125\n",
"dtype: float64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.procs[-1].means"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" education | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" fnlwgt_na | \n",
" education-num_na | \n",
" fnlwgt | \n",
" education-num | \n",
" age | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" Private | \n",
" Some-college | \n",
" Married-civ-spouse | \n",
" Craft-repair | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" 213840.999630 | \n",
" 10.0 | \n",
" 37.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" Federal-gov | \n",
" Assoc-acdm | \n",
" Never-married | \n",
" Adm-clerical | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" 235890.999234 | \n",
" 12.0 | \n",
" 45.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" Private | \n",
" HS-grad | \n",
" Never-married | \n",
" Sales | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" 234108.000241 | \n",
" 9.0 | \n",
" 21.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" Self-emp-not-inc | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" 287037.000237 | \n",
" 9.0 | \n",
" 41.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" ? | \n",
" HS-grad | \n",
" Separated | \n",
" ? | \n",
" Unmarried | \n",
" Black | \n",
" False | \n",
" False | \n",
" 427965.002375 | \n",
" 9.0 | \n",
" 29.0 | \n",
"
\n",
" \n",
" | 5 | \n",
" Private | \n",
" Some-college | \n",
" Married-civ-spouse | \n",
" Handlers-cleaners | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" 368949.005576 | \n",
" 10.0 | \n",
" 29.0 | \n",
"
\n",
" \n",
" | 6 | \n",
" Private | \n",
" HS-grad | \n",
" Married-civ-spouse | \n",
" Craft-repair | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" 85341.001781 | \n",
" 9.0 | \n",
" 48.0 | \n",
"
\n",
" \n",
" | 7 | \n",
" Private | \n",
" Bachelors | \n",
" Married-civ-spouse | \n",
" Exec-managerial | \n",
" Husband | \n",
" White | \n",
" False | \n",
" False | \n",
" 110243.002830 | \n",
" 13.0 | \n",
" 49.0 | \n",
"
\n",
" \n",
" | 8 | \n",
" ? | \n",
" 10th | \n",
" Divorced | \n",
" ? | \n",
" Not-in-family | \n",
" White | \n",
" False | \n",
" False | \n",
" 124647.999756 | \n",
" 6.0 | \n",
" 61.0 | \n",
"
\n",
" \n",
" | 9 | \n",
" ? | \n",
" Masters | \n",
" Never-married | \n",
" ? | \n",
" Unmarried | \n",
" Asian-Pac-Islander | \n",
" False | \n",
" False | \n",
" 173799.999432 | \n",
" 14.0 | \n",
" 27.0 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dbch = to.databunch()\n",
"dbch.valid_dl.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Not being used now - for multi-modal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TensorTabular(Tuple):\n",
"    \"Tuple whose first element exposes a `shape`; builds empty row contexts and displays them as a `DataFrame`\"\n",
"    def get_ctxs(self, max_n=10, **kwargs):\n",
"        \"Return one empty row per sample in `self[0]`, capped at `max_n`\"\n",
"        n_rows = min(self[0].shape[0], max_n)\n",
"        empty = pd.DataFrame(index=range(n_rows))\n",
"        return [empty.iloc[i] for i in range(n_rows)]\n",
"\n",
"    def display(self, ctxs):\n",
"        \"Show `ctxs` gathered into a single `DataFrame`\"\n",
"        display_df(pd.DataFrame(ctxs))\n",
"\n",
"class TabularLine(pd.Series):\n",
"    \"A line of a dataframe that knows how to show itself\"\n",
"    def show(self, ctx=None, **kwargs):\n",
"        \"Return `self` when `ctx` is None, otherwise `ctx` with this line appended to it\"\n",
"        if ctx is None: return self\n",
"        # `Series.append` was removed in pandas 2.0: keep using it where it exists, otherwise\n",
"        # fall back to `pd.concat` (assumes `ctx` is a `Series` -- TODO confirm against callers)\n",
"        if hasattr(ctx, 'append'): return ctx.append(self)\n",
"        return pd.concat([ctx, self])\n",
"\n",
"class ReadTabLine(ItemTransform):\n",
"    \"Transform built around a tabular processor `proc`: encodes a row to a `TensorTabular`, decodes to a `TabularLine`\"\n",
"    def __init__(self, proc): self.proc = proc\n",
"\n",
"    def encodes(self, row):\n",
"        # Look up in `row` every name listed in `proc.cat_names`, then in `proc.cont_names`\n",
"        cats,conts = (o.map(row.__getitem__) for o in (self.proc.cat_names,self.proc.cont_names))\n",
"        return TensorTabular(tensor(cats).long(),tensor(conts).float())\n",
"\n",
"    def decodes(self, o):\n",
"        # NOTE(review): these arguments are positional, while the `Tabular.__init__` defined earlier in this\n",
"        # notebook takes `procs` as its second parameter -- confirm they still land on the intended parameters\n",
"        to = TabularPandas(o, self.proc.cat_names, self.proc.cont_names, self.proc.y_names)\n",
"        to = self.proc.decode(to)\n",
"        # Pair each decoded value with its column name (categorical names first, then continuous ones)\n",
"        return TabularLine(pd.Series({c: v for v,c in zip(to.items[0]+to.items[1], self.proc.cat_names+self.proc.cont_names)}))\n",
"\n",
"class ReadTabTarget(ItemTransform):\n",
"    \"Transform reading the target of a row through `proc.y_names`\"\n",
"    def __init__(self, proc):\n",
"        self.proc = proc\n",
"\n",
"    def encodes(self, row):\n",
"        \"Select `proc.y_names` from `row` and cast the result to `np.int64`\"\n",
"        target = row[self.proc.y_names]\n",
"        return target.astype(np.int64)\n",
"\n",
"    def decodes(self, o):\n",
"        \"Index `proc.classes[proc.y_names]` with `o` and wrap the result in a `Category`\"\n",
"        classes = self.proc.classes[self.proc.y_names]\n",
"        return Category(classes[o])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# tds = TfmdDS(to.items, tfms=[[ReadTabLine(proc)], ReadTabTarget(proc)])\n",
"# enc = tds[1]\n",
"# test_eq(enc[0][0], tensor([2,1]))\n",
"# test_close(enc[0][1], tensor([-0.628828]))\n",
"# test_eq(enc[1], 1)\n",
"\n",
"# dec = tds.decode(enc)\n",
"# assert isinstance(dec[0], TabularLine)\n",
"# test_close(dec[0], pd.Series({'a': 1, 'b_na': False, 'b': 1}))\n",
"# test_eq(dec[1], 'a')\n",
"\n",
"# test_stdout(lambda: print(show_at(tds, 1)), \"\"\"a 1\n",
"# b_na False\n",
"# b 1\n",
"# category a\n",
"# dtype: object\"\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_test.ipynb.\n",
"Converted 01_core_foundation.ipynb.\n",
"Converted 01a_core_utils.ipynb.\n",
"Converted 01b_core_dispatch.ipynb.\n",
"Converted 01c_core_transform.ipynb.\n",
"Converted 02_core_script.ipynb.\n",
"Converted 03_torchcore.ipynb.\n",
"Converted 03a_layers.ipynb.\n",
"Converted 04_data_load.ipynb.\n",
"Converted 05_data_core.ipynb.\n",
"Converted 06_data_transforms.ipynb.\n",
"Converted 07_data_block.ipynb.\n",
"Converted 08_vision_core.ipynb.\n",
"Converted 09_vision_augment.ipynb.\n",
"Converted 09a_vision_data.ipynb.\n",
"Converted 09b_vision_utils.ipynb.\n",
"Converted 10_pets_tutorial.ipynb.\n",
"Converted 11_vision_models_xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_learner.ipynb.\n",
"Converted 13a_metrics.ipynb.\n",
"Converted 14_callback_schedule.ipynb.\n",
"Converted 14a_callback_data.ipynb.\n",
"Converted 15_callback_hook.ipynb.\n",
"Converted 15a_vision_models_unet.ipynb.\n",
"Converted 16_callback_progress.ipynb.\n",
"Converted 17_callback_tracker.ipynb.\n",
"Converted 18_callback_fp16.ipynb.\n",
"Converted 19_callback_mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision_learner.ipynb.\n",
"Converted 22_tutorial_imagenette.ipynb.\n",
"Converted 23_tutorial_transfer_learning.ipynb.\n",
"Converted 30_text_core.ipynb.\n",
"Converted 31_text_data.ipynb.\n",
"Converted 32_text_models_awdlstm.ipynb.\n",
"Converted 33_text_models_core.ipynb.\n",
"Converted 34_callback_rnn.ipynb.\n",
"Converted 35_tutorial_wikitext.ipynb.\n",
"Converted 36_text_models_qrnn.ipynb.\n",
"Converted 37_text_learner.ipynb.\n",
"Converted 38_tutorial_ulmfit.ipynb.\n",
"Converted 40_tabular_core.ipynb.\n",
"Converted 41_tabular_model.ipynb.\n",
"Converted 42_tabular_rapids.ipynb.\n",
"Converted 50_data_block_examples.ipynb.\n",
"Converted 60_medical_imaging.ipynb.\n",
"Converted 65_medical_text.ipynb.\n",
"Converted 70_callback_wandb.ipynb.\n",
"Converted 71_callback_tensorboard.ipynb.\n",
"Converted 90_notebook_core.ipynb.\n",
"Converted 91_notebook_export.ipynb.\n",
"Converted 92_notebook_showdoc.ipynb.\n",
"Converted 93_notebook_export2html.ipynb.\n",
"Converted 94_notebook_test.ipynb.\n",
"Converted 95_index.ipynb.\n",
"Converted 96_data_external.ipynb.\n",
"Converted 97_utils_test.ipynb.\n",
"Converted notebook2jekyll.ipynb.\n",
"Converted xse_resnext.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}