{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp tabular.core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.torch_basics import *\n", "from fastai2.data.all import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "pd.set_option('mode.chained_assignment','raise')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tabular core\n", "\n", "> Basic function to preprocess tabular data before assembling it in a `DataLoaders`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initial preprocessing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def make_date(df, date_field):\n", " \"Make sure `df[date_field]` is of the right date type.\"\n", " field_dtype = df[date_field].dtype\n", " if isinstance(field_dtype, pd.core.dtypes.dtypes.DatetimeTZDtype):\n", " field_dtype = np.datetime64\n", " if not np.issubdtype(field_dtype, np.datetime64):\n", " df[date_field] = pd.to_datetime(df[date_field], infer_datetime_format=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'date': ['2019-12-04', '2019-11-29', '2019-11-15', '2019-10-24']})\n", "make_date(df, 'date')\n", "test_eq(df['date'].dtype, np.dtype('datetime64[ns]'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def add_datepart(df, field_name, prefix=None, drop=True, time=False):\n", " \"Helper function that adds columns relevant to a date in the column `field_name` of `df`.\"\n", " make_date(df, field_name)\n", " field = df[field_name]\n", " prefix = ifnone(prefix, re.sub('[Dd]ate$', '', field_name))\n", " attr = ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',\n", " 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start']\n", " if time: attr = attr + ['Hour', 'Minute', 'Second']\n", " for n in attr: df[prefix + n] = getattr(field.dt, n.lower())\n", " df[prefix + 'Elapsed'] = field.astype(np.int64) // 10 ** 9\n", " if drop: df.drop(field_name, axis=1, inplace=True)\n", " return df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
YearMonthWeekDayDayofweekDayofyearIs_month_endIs_month_startIs_quarter_endIs_quarter_startIs_year_endIs_year_startElapsed
02019124942338FalseFalseFalseFalseFalseFalse1575417600
120191148294333FalseFalseFalseFalseFalseFalse1574985600
220191146154319FalseFalseFalseFalseFalseFalse1573776000
320191043243297FalseFalseFalseFalseFalseFalse1571875200
\n", "
" ], "text/plain": [ " Year Month Week Day Dayofweek Dayofyear Is_month_end Is_month_start \\\n", "0 2019 12 49 4 2 338 False False \n", "1 2019 11 48 29 4 333 False False \n", "2 2019 11 46 15 4 319 False False \n", "3 2019 10 43 24 3 297 False False \n", "\n", " Is_quarter_end Is_quarter_start Is_year_end Is_year_start Elapsed \n", "0 False False False False 1575417600 \n", "1 False False False False 1574985600 \n", "2 False False False False 1573776000 \n", "3 False False False False 1571875200 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame({'date': ['2019-12-04', '2019-11-29', '2019-11-15', '2019-10-24']})\n", "df = add_datepart(df, 'date')\n", "test_eq(df.columns, ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start', \n", " 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start', 'Elapsed'])\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_elapsed(df,field_names, date_field, base_field, prefix):\n", " for f in field_names:\n", " day1 = np.timedelta64(1, 'D')\n", " last_date,last_base,res = np.datetime64(),None,[]\n", " for b,v,d in zip(df[base_field].values, df[f].values, df[date_field].values):\n", " if last_base is None or b != last_base:\n", " last_date,last_base = np.datetime64(),b\n", " if v: last_date = d\n", " res.append(((d-last_date).astype('timedelta64[D]') / day1))\n", " df[prefix + f] = res\n", " return df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def add_elapsed_times(df, field_names, date_field, base_field):\n", " \"Add in `df` for each event in `field_names` the elapsed time according to `date_field` grouped by `base_field`\"\n", " field_names = list(L(field_names))\n", " #Make sure date_field is a date and base_field a bool\n", " df[field_names] = df[field_names].astype('bool')\n", " make_date(df, date_field)\n", "\n", " work_df = df[field_names + [date_field, base_field]]\n", " work_df = work_df.sort_values([base_field, date_field])\n", " work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'After')\n", " work_df = work_df.sort_values([base_field, date_field], ascending=[True, False])\n", " work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'Before')\n", "\n", " for a in ['After' + f for f in field_names] + ['Before' + f for f in field_names]:\n", " work_df[a] = work_df[a].fillna(0).astype(int)\n", "\n", " for a,s in zip([True, False], ['_bw', '_fw']):\n", " work_df = work_df.set_index(date_field)\n", " tmp = (work_df[[base_field] + field_names].sort_index(ascending=a)\n", " .groupby(base_field).rolling(7, min_periods=1).sum())\n", " tmp.drop(base_field,1,inplace=True)\n", " tmp.reset_index(inplace=True)\n", " work_df.reset_index(inplace=True)\n", " work_df = work_df.merge(tmp, 'left', [date_field, base_field], suffixes=['', s])\n", " work_df.drop(field_names,1,inplace=True)\n", " return df.merge(work_df, 'left', [date_field, base_field])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dateeventbaseAftereventBeforeeventevent_bwevent_fw
02019-12-04False1501.00.0
12019-11-29True1001.01.0
22019-11-15False22201.00.0
32019-10-24True2001.01.0
\n", "
" ], "text/plain": [ " date event base Afterevent Beforeevent event_bw event_fw\n", "0 2019-12-04 False 1 5 0 1.0 0.0\n", "1 2019-11-29 True 1 0 0 1.0 1.0\n", "2 2019-11-15 False 2 22 0 1.0 0.0\n", "3 2019-10-24 True 2 0 0 1.0 1.0" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame({'date': ['2019-12-04', '2019-11-29', '2019-11-15', '2019-10-24'],\n", " 'event': [False, True, False, True], 'base': [1,1,2,2]})\n", "df = add_elapsed_times(df, ['event'], 'date', 'base')\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def cont_cat_split(df, max_card=20, dep_var=None):\n", " \"Helper function that returns column names of cont and cat variables from given `df`.\"\n", " cont_names, cat_names = [], []\n", " for label in df:\n", " if label == dep_var: continue\n", " if df[label].dtype == int and df[label].unique().shape[0] > max_card or df[label].dtype == float:\n", " cont_names.append(label)\n", " else: cat_names.append(label)\n", " return cont_names, cat_names" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def df_shrink_dtypes(df, skip=[], obj2cat=True, int2uint=False):\n", " \"Return any possible smaller data types for DataFrame columns. Allows `object`->`category`, `int`->`uint`, and exclusion.\"\n", "\n", " # 1: Build column filter and typemap\n", " excl_types, skip = {'category','datetime64[ns]','bool'}, set(skip)\n", "\n", " typemap = {'int' : [(np.dtype(x), np.iinfo(x).min, np.iinfo(x).max) for x in (np.int8, np.int16, np.int32, np.int64)],\n", " 'uint' : [(np.dtype(x), np.iinfo(x).min, np.iinfo(x).max) for x in (np.uint8, np.uint16, np.uint32, np.uint64)],\n", " 'float' : [(np.dtype(x), np.finfo(x).min, np.finfo(x).max) for x in (np.float32, np.float64, np.longdouble)]\n", " }\n", " if obj2cat: typemap['object'] = 'category' # User wants to categorify dtype('Object'), which may not always save space\n", " else: excl_types.add('object')\n", "\n", " new_dtypes = {}\n", " exclude = lambda dt: dt[1].name not in excl_types and dt[0] not in skip\n", "\n", " for c, old_t in filter(exclude, df.dtypes.items()):\n", " t = next((v for k,v in typemap.items() if old_t.name.startswith(k)), None)\n", "\n", " if isinstance(t, list): # Find the smallest type that fits\n", " if int2uint and t==typemap['int'] and df[c].min() >= 0: t=typemap['uint']\n", " new_t = next((r[0] for r in t if r[1]<=df[c].min() and r[2]>=df[c].max()), None)\n", " if new_t and new_t == old_t: new_t = None\n", " else: new_t = t if isinstance(t, str) else None\n", "\n", " if new_t: new_dtypes[c] = new_t\n", " return new_dtypes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

df_shrink_dtypes[source]

\n", "\n", "> df_shrink_dtypes(**`df`**, **`skip`**=*`[]`*, **`obj2cat`**=*`True`*, **`int2uint`**=*`False`*)\n", "\n", "Return any possible smaller data types for DataFrame columns. Allows `object`->`category`, `int`->`uint`, and exclusion." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(df_shrink_dtypes, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'i': [-100, 0, 100], 'f': [-100.0, 0.0, 100.0], 'e': [True, False, True],\n", " 'date':['2019-12-04','2019-11-29','2019-11-15',]})\n", "dt = df_shrink_dtypes(df)\n", "test_eq(df['i'].dtype, 'int64')\n", "test_eq(dt['i'], 'int8')\n", "\n", "test_eq(df['f'].dtype, 'float64')\n", "test_eq(dt['f'], 'float32')\n", "\n", "# Default ignore 'object' and 'boolean' columns\n", "test_eq(df['date'].dtype, 'object')\n", "test_eq(dt['date'], 'category')\n", "\n", "# Test categorifying 'object' type\n", "dt2 = df_shrink_dtypes(df, obj2cat=False)\n", "test_eq('date' not in dt2, True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def df_shrink(df, skip=[], obj2cat=True, int2uint=False):\n", " \"Reduce DataFrame memory usage, by casting to smaller types returned by `df_shrink_dtypes()`.\"\n", " dt = df_shrink_dtypes(df, skip, obj2cat=obj2cat, int2uint=int2uint)\n", " return df.astype(dt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

df_shrink[source]

\n", "\n", "> df_shrink(**`df`**, **`skip`**=*`[]`*, **`obj2cat`**=*`True`*, **`int2uint`**=*`False`*)\n", "\n", "Reduce DataFrame memory usage, by casting to smaller types returned by `df_shrink_dtypes()`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(df_shrink, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`df_shrink(df)` attempts to make a DataFrame uses less memory, by fit numeric columns into smallest datatypes. In addition:\n", "\n", " * `boolean`, `category`, `datetime64[ns]` dtype columns are ignored.\n", " * 'object' type columns are categorified, which can save a lot of memory in large dataset. It can be turned off by `obj2cat=False`.\n", " * `int2uint=True`, to fit `int` types to `uint` types, if all data in the column is >= 0.\n", " * columns can be excluded by name using `excl_cols=['col1','col2']`.\n", "\n", "To get only new column data types without actually casting a DataFrame,\n", "use `df_shrink_dtypes()` with all the same parameters for `df_shrink()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'i': [-100, 0, 100], 'f': [-100.0, 0.0, 100.0], 'u':[0, 10,254],\n", " 'date':['2019-12-04','2019-11-29','2019-11-15']})\n", "df2 = df_shrink(df, skip=['date'])\n", "\n", "test_eq(df['i'].dtype=='int64' and df2['i'].dtype=='int8', True)\n", "test_eq(df['f'].dtype=='float64' and df2['f'].dtype=='float32', True)\n", "test_eq(df['u'].dtype=='int64' and df2['u'].dtype=='int16', True)\n", "test_eq(df2['date'].dtype, 'object')\n", "\n", "test_eq(df2.memory_usage().sum() < df.memory_usage().sum(), True)\n", "\n", "# Test int => uint (when col.min() >= 0)\n", "df3 = df_shrink(df, int2uint=True)\n", "test_eq(df3['u'].dtype, 'uint8') # int64 -> uint8 instead of int16\n", "\n", "# Test excluding columns\n", "df4 = df_shrink(df, skip=['i','u'])\n", "test_eq(df['i'].dtype, df4['i'].dtype)\n", "test_eq(df4['u'].dtype, 'int64')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's an example using the `ADULT_SAMPLE` dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Memory usage: 3907448 --> 818665\n" ] } ], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "new_df = df_shrink(df, int2uint=True)\n", "print(f\"Memory usage: {df.memory_usage().sum()} --> {new_df.memory_usage().sum()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tabular -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _TabIloc:\n", " \"Get/set rows by iloc and cols by name\"\n", " def __init__(self,to): self.to = to\n", " def __getitem__(self, idxs):\n", " df = self.to.items\n", " if isinstance(idxs,tuple):\n", " rows,cols = idxs\n", " cols = df.columns.isin(cols) if is_listy(cols) else df.columns.get_loc(cols)\n", " else: rows,cols = idxs,slice(None)\n", " return self.to.new(df.iloc[rows, cols])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Tabular(CollBase, GetAttr, FilteredBase):\n", " \"A `DataFrame` wrapper that knows which cols are cont/cat/y, and returns rows in `__getitem__`\"\n", " _default,with_cont='procs',True\n", " def __init__(self, df, procs=None, cat_names=None, cont_names=None, y_names=None, y_block=None, splits=None,\n", " do_setup=True, 
device=None, inplace=False, reduce_memory=True):\n", "        if inplace and splits is not None and pd.options.mode.chained_assignment is not None:\n", "            warn(\"Using inplace with splits will trigger a pandas error. Set `pd.options.mode.chained_assignment=None` to avoid it.\")\n", "        if not inplace: df = df.copy()\n", "        if reduce_memory: df = df_shrink(df)\n", "        if splits is not None: df = df.iloc[sum(splits, [])]\n", "        self.dataloaders = delegates(self._dl_type.__init__)(self.dataloaders)\n", "        super().__init__(df)\n", "\n", "        self.y_names,self.device = L(y_names),device\n", "        if y_block is None and self.y_names:\n", "            # Make ys categorical if they're not numeric\n", "            ys = df[self.y_names]\n", "            if len(ys.select_dtypes(include='number').columns)!=len(ys.columns): y_block = CategoryBlock()\n", "            else: y_block = RegressionBlock()\n", "        if y_block is not None and do_setup:\n", "            if callable(y_block): y_block = y_block()\n", "            procs = L(procs) + y_block.type_tfms\n", "        self.cat_names,self.cont_names,self.procs = L(cat_names),L(cont_names),Pipeline(procs)\n", "        self.split = len(df) if splits is None else len(splits[0])\n", "        if do_setup: self.setup()\n", "\n", "    def new(self, df):\n", "        return type(self)(df, do_setup=False, reduce_memory=False, y_block=TransformBlock(),\n", "                          **attrdict(self, 'procs','cat_names','cont_names','y_names', 'device'))\n", "\n", "    def subset(self, i): return self.new(self.items[slice(0,self.split) if i==0 else slice(self.split,len(self))])\n", "    def copy(self): self.items = self.items.copy(); return self\n", "    def decode(self): return self.procs.decode(self)\n", "    def decode_row(self, row): return self.new(pd.DataFrame(row).T).decode().items.iloc[0]\n", "    def show(self, max_n=10, **kwargs): display_df(self.new(self.all_cols[:max_n]).decode().items)\n", "    def setup(self): self.procs.setup(self)\n", "    def process(self): self.procs(self)\n", "    def loc(self): return self.items.loc\n", "    def iloc(self): return _TabIloc(self)\n", "    def targ(self): return self.items[self.y_names]\n", "    def x_names (self): return self.cat_names + self.cont_names\n", "    def n_subsets(self): return 2\n", "    def y(self): return self[self.y_names[0]]\n", "    def new_empty(self): return self.new(pd.DataFrame({}, columns=self.items.columns))\n", "    def to_device(self, d=None):\n", "        self.device = d\n", "        return self\n", "\n", "    def all_col_names (self):\n", "        ys = [n for n in self.y_names if n in self.items.columns]\n", "        return self.x_names + self.y_names if len(ys) == len(self.y_names) else self.x_names\n", "\n", "properties(Tabular,'loc','iloc','targ','all_col_names','n_subsets','x_names','y')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `df`: A `DataFrame` of your data\n", "* `cat_names`: Your categorical `x` variables\n", "* `cont_names`: Your continuous `x` variables\n", "* `y_names`: Your dependent `y` variables\n", "  * Note: Mixed targets, such as regression and classification together, are not currently supported; multiple regression or multiple classification targets are\n", "* `y_block`: How to sub-categorize the type of `y_names` (`CategoryBlock` or `RegressionBlock`)\n", "* `splits`: How to split your data\n", "* `do_setup`: Whether `Tabular` will run the data through the `procs` upon initialization\n", "* `device`: `cuda` or `cpu`\n", "* `inplace`: If `True`, `Tabular` will not keep a separate copy of your original `DataFrame` in memory. 
You should ensure `pd.options.mode.chained_assignment` is `None` before setting this\n", "* `reduce_memory`: `fastai` will attempt to reduce the overall memory usage of the input `DataFrame` with `df_shrink`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TabularPandas(Tabular):\n", "    \"A `Tabular` object with transforms\"\n", "    def transform(self, cols, f, all_col=True):\n", "        if not all_col: cols = [c for c in cols if c in self.items.columns]\n", "        if len(cols) > 0: self[cols] = self[cols].transform(f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _add_prop(cls, nm):\n", "    @property\n", "    def f(o): return o[list(getattr(o,nm+'_names'))]\n", "    @f.setter\n", "    def fset(o, v): o[getattr(o,nm+'_names')] = v\n", "    setattr(cls, nm+'s', f)\n", "    setattr(cls, nm+'s', fset)\n", "\n", "_add_prop(Tabular, 'cat')\n", "_add_prop(Tabular, 'cont')\n", "_add_prop(Tabular, 'y')\n", "_add_prop(Tabular, 'x')\n", "_add_prop(Tabular, 'all_col')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'a':[0,1,2,0,2], 'b':[0,0,0,0,1]})\n", "to = TabularPandas(df, cat_names='a')\n", "t = pickle.loads(pickle.dumps(to))\n", "test_eq(t.items,to.items)\n", "test_eq(to.all_cols,to[['a']])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TabularProc(InplaceTransform):\n", "    \"Base class to write a non-lazy tabular processor for dataframes\"\n", "    store_attrs=''\n", "    def setup(self, items=None, train_setup=False): #TODO: properly deal with train_setup\n", "        super().setup(getattr(items,'train',items), train_setup=False)\n", "        # Procs are called as soon as data is available\n", "        return self(items.items if isinstance(items,Datasets) else items)\n", "\n", "    @property\n", "    def name(self):\n", "        if self.store_attrs: attrs = self.store_attrs.split(',')\n", "        else: attrs = ''\n", "        return f\"{super().name} -- {attrdict(self, *attrs)}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _apply_cats (voc, add, c):\n", "    if not is_categorical_dtype(c):\n", "        return pd.Categorical(c, categories=voc[c.name][add:]).codes+add\n", "    return c.cat.codes+add #if is_categorical_dtype(c) else c.map(voc[c.name].o2i)\n", "def _decode_cats(voc, c): return c.map(dict(enumerate(voc[c.name].items)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Categorify(TabularProc):\n", "    \"Transform the categorical variables to something similar to `pd.Categorical`\"\n", "    order,store_attrs = 1,'classes'\n", "    def setups(self, to):\n", "        self.classes = {n:CategoryMap(to.iloc[:,n].items, add_na=(n in to.cat_names)) for n in to.cat_names}\n", "\n", "    def encodes(self, to): to.transform(to.cat_names, partial(_apply_cats, self.classes, 1))\n", "    def decodes(self, to): to.transform(to.cat_names, partial(_decode_cats, self.classes))\n", "    def __getitem__(self,k): return self.classes[k]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@Categorize\n", "def setups(self, to:Tabular):\n", "    if len(to.y_names) > 0:\n", "        if self.vocab is None:\n", "            self.vocab = CategoryMap(getattr(to, 'train', to).iloc[:,to.y_names[0]].items, strict=True)\n", "        else:\n", "            self.vocab = CategoryMap(self.vocab, sort=False, 
add_na=self.add_na)\n", " self.c = len(self.vocab)\n", " return self(to)\n", "\n", "@Categorize\n", "def encodes(self, to:Tabular):\n", " to.transform(to.y_names, partial(_apply_cats, {n: self.vocab for n in to.y_names}, 0), all_col=False)\n", " return to\n", "\n", "@Categorize\n", "def decodes(self, to:Tabular):\n", " to.transform(to.y_names, partial(_decode_cats, {n: self.vocab for n in to.y_names}), all_col=False)\n", " return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class Categorify[source]

\n", "\n", "> Categorify(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*) :: [`TabularProc`](/tabular.core#TabularProc)\n", "\n", "Transform the categorical variables to something similar to `pd.Categorical`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Categorify, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
a
00
11
22
30
42
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df = pd.DataFrame({'a':[0,1,2,0,2]})\n", "to = TabularPandas(df, Categorify, 'a')\n", "cat = to.procs.categorify\n", "test_eq(cat['a'], ['#na#',0,1,2])\n", "test_eq(to['a'], [1,2,3,1,3])\n", "to.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df1 = pd.DataFrame({'a':[1,0,3,-1,2]})\n", "to1 = to.new(df1)\n", "to1.process()\n", "#Values that weren't in the training df are sent to 0 (na)\n", "test_eq(to1['a'], [2,1,0,0,3])\n", "to2 = cat.decode(to1)\n", "test_eq(to2['a'], [1,0,'#na#','#na#',2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test with splits\n", "cat = Categorify()\n", "df = pd.DataFrame({'a':[0,1,2,3,2]})\n", "to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]])\n", "test_eq(cat['a'], ['#na#',0,1,2])\n", "test_eq(to['a'], [1,2,3,0,3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'a':pd.Categorical(['M','H','L','M'], categories=['H','M','L'], ordered=True)})\n", "to = TabularPandas(df, Categorify, 'a')\n", "cat = to.procs.categorify\n", "test_eq(cat['a'], ['#na#','H','M','L'])\n", "test_eq(to.items.a, [2,1,3,2])\n", "to2 = cat.decode(to)\n", "test_eq(to2['a'], ['M','H','L','M'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test with targets\n", "cat = Categorify()\n", "df = pd.DataFrame({'a':[0,1,2,3,2], 'b': ['a', 'b', 'a', 'b', 'b']})\n", "to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]], y_names='b')\n", "test_eq(to.vocab, ['a', 'b'])\n", "test_eq(to['b'], [0,1,0,1,1])\n", "to2 = to.procs.decode(to)\n", "test_eq(to2['b'], ['a', 'b', 'a', 'b', 'b'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat = Categorify()\n", "df = pd.DataFrame({'a':[0,1,2,3,2], 'b': ['a', 'b', 'a', 'b', 'b']})\n", "to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]], y_names='b')\n", "test_eq(to.vocab, ['a', 'b'])\n", "test_eq(to['b'], [0,1,0,1,1])\n", "to2 = to.procs.decode(to)\n", "test_eq(to2['b'], ['a', 'b', 'a', 'b', 'b'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test with targets and train\n", "cat = Categorify()\n", "df = pd.DataFrame({'a':[0,1,2,3,2], 'b': ['a', 'b', 'a', 'c', 'b']})\n", "to = TabularPandas(df, cat, 'a', splits=[[0,1,2],[3,4]], y_names='b')\n", "test_eq(to.vocab, ['a', 'b'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@Normalize\n", "def setups(self, to:Tabular):\n", " self.means,self.stds = dict(getattr(to, 'train', to).conts.mean()),dict(getattr(to, 'train', to).conts.std(ddof=0)+1e-7)\n", " self.store_attrs = 'means,stds'\n", " return self(to)\n", "\n", "@Normalize\n", "def encodes(self, to:Tabular):\n", " to.conts = (to.conts-self.means) / self.stds\n", " return to\n", "\n", "@Normalize\n", "def decodes(self, to:Tabular):\n", " to.conts = (to.conts*self.stds ) + self.means\n", " return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "norm = Normalize()\n", "df = pd.DataFrame({'a':[0,1,2,3,4]})\n", "to = TabularPandas(df, norm, cont_names='a')\n", "x = np.array([0,1,2,3,4])\n", "m,s = x.mean(),x.std()\n", "test_eq(norm.means['a'], m)\n", "test_close(norm.stds['a'], s)\n", "test_close(to['a'].values, (x-m)/s)" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df1 = pd.DataFrame({'a':[5,6,7]})\n", "to1 = to.new(df1)\n", "to1.process()\n", "test_close(to1['a'].values, (np.array([5,6,7])-m)/s)\n", "to2 = norm.decode(to1)\n", "test_close(to2['a'].values, [5,6,7])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "norm = Normalize()\n", "df = pd.DataFrame({'a':[0,1,2,3,4]})\n", "to = TabularPandas(df, norm, cont_names='a', splits=[[0,1,2],[3,4]])\n", "x = np.array([0,1,2])\n", "m,s = x.mean(),x.std()\n", "test_eq(norm.means['a'], m)\n", "test_close(norm.stds['a'], s)\n", "test_close(to['a'].values, (np.array([0,1,2,3,4])-m)/s)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class FillStrategy:\n", " \"Namespace containing the various filling strategies.\"\n", " def median (c,fill): return c.median()\n", " def constant(c,fill): return fill\n", " def mode (c,fill): return c.dropna().value_counts().idxmax()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Currently, filling with the `median`, a `constant`, and the `mode` are supported." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class FillMissing(TabularProc):\n", " \"Fill the missing values in continuous columns.\"\n", " store_attrs = 'fill_strategy,add_col,fill_vals'\n", " def __init__(self, fill_strategy=FillStrategy.median, add_col=True, fill_vals=None):\n", " if fill_vals is None: fill_vals = defaultdict(int)\n", " store_attr(self,self.store_attrs)\n", " def setups(self, dsets):\n", " missing = pd.isnull(dsets.conts).any()\n", " self.na_dict = {n:self.fill_strategy(dsets[n], self.fill_vals[n])\n", " for n in missing[missing].keys()}\n", " self.store_attrs += ',na_dict'\n", " self.fill_strategy = self.fill_strategy.__name__\n", "\n", " def encodes(self, to):\n", " missing = pd.isnull(to.conts)\n", " for n in missing.any()[missing.any()].keys():\n", " assert n in self.na_dict, f\"nan values in `{n}` but not in setup training set\"\n", " for n in self.na_dict.keys():\n", " to[n].fillna(self.na_dict[n], inplace=True)\n", " if self.add_col:\n", " to.loc[:,n+'_na'] = missing[n]\n", " if n+'_na' not in to.cat_names: to.cat_names.append(n+'_na')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class FillMissing[source]

\n", "\n", "> FillMissing(**`fill_strategy`**=*`'median'`*, **`add_col`**=*`True`*, **`fill_vals`**=*`None`*) :: [`TabularProc`](/tabular.core#TabularProc)\n", "\n", "Fill the missing values in continuous columns." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(FillMissing, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fill1,fill2,fill3 = (FillMissing(fill_strategy=s) \n", " for s in [FillStrategy.median, FillStrategy.constant, FillStrategy.mode])\n", "df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4]})\n", "df1 = df.copy(); df2 = df.copy()\n", "tos = (TabularPandas(df, fill1, cont_names='a'),\n", " TabularPandas(df1, fill2, cont_names='a'),\n", " TabularPandas(df2, fill3, cont_names='a'))\n", "test_eq(fill1.na_dict, {'a': 1.5})\n", "test_eq(fill2.na_dict, {'a': 0})\n", "test_eq(fill3.na_dict, {'a': 1.0})\n", "\n", "for t in tos: test_eq(t.cat_names, ['a_na'])\n", "\n", "for to_,v in zip(tos, [1.5, 0., 1.]):\n", " test_eq(to_['a'].values, np.array([0, 1, v, 1, 2, 3, 4]))\n", " test_eq(to_['a_na'].values, np.array([0, 0, 1, 0, 0, 0, 0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fill = FillMissing() \n", "df = pd.DataFrame({'a':[0,1,np.nan,1,2,3,4], 'b': [0,1,2,3,4,5,6]})\n", "to = TabularPandas(df, fill, cont_names=['a', 'b'])\n", "test_eq(fill.na_dict, {'a': 1.5})\n", "test_eq(to.cat_names, ['a_na'])\n", "test_eq(to['a'].values, np.array([0, 1, 1.5, 1, 2, 3, 4]))\n", "test_eq(to['a_na'].values, np.array([0, 0, 1, 0, 0, 0, 0]))\n", "test_eq(to['b'].values, np.array([0,1,2,3,4,5,6]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TabularPandas Pipelines -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "procs = [Normalize, Categorify, FillMissing, noop]\n", "df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4]})\n", "to = TabularPandas(df, procs, cat_names='a', cont_names='b')\n", "\n", "#Test setup and apply on df_main\n", "test_eq(to.cat_names, ['a', 'b_na'])\n", "test_eq(to['a'], [1,2,3,2,2,3,1])\n", "test_eq(to['b_na'], [1,1,2,1,1,1,1])\n", "x = np.array([0,1,1.5,1,2,3,4])\n", "m,s = x.mean(),x.std()\n", "test_close(to['b'].values, (x-m)/s)\n", "test_eq(to.classes, {'a': ['#na#',0,1,2], 'b_na': ['#na#',False,True]})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test apply on y_names\n", "df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n", "to = TabularPandas(df, procs, 'a', 'b', y_names='c')\n", "\n", "test_eq(to.cat_names, ['a', 'b_na'])\n", "test_eq(to['a'], [1,2,3,2,2,3,1])\n", "test_eq(to['b_na'], [1,1,2,1,1,1,1])\n", "test_eq(to['c'], [1,0,1,0,0,1,0])\n", "x = np.array([0,1,1.5,1,2,3,4])\n", "m,s = x.mean(),x.std()\n", "test_close(to['b'].values, (x-m)/s)\n", "test_eq(to.classes, {'a': ['#na#',0,1,2], 'b_na': ['#na#',False,True]})\n", "test_eq(to.vocab, ['a','b'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,1,np.nan,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n", "to = TabularPandas(df, procs, 'a', 'b', y_names='c')\n", "\n", "test_eq(to.cat_names, ['a', 'b_na'])\n", "test_eq(to['a'], [1,2,3,2,2,3,1])\n", "test_eq(df.a.dtype,int)\n", "test_eq(to['b_na'], [1,1,2,1,1,1,1])\n", "test_eq(to['c'], [1,0,1,0,0,1,0])" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame({'a':[0,1,2,1,1,2,0], 'b':[0,np.nan,1,1,2,3,4], 'c': ['b','a','b','a','a','b','a']})\n", "to = TabularPandas(df, procs, cat_names='a', cont_names='b', y_names='c', splits=[[0,1,4,6], [2,3,5]])\n", "\n", "test_eq(to.cat_names, ['a', 'b_na'])\n", "test_eq(to['a'], [1,2,2,1,0,2,0])\n", "test_eq(df.a.dtype,int)\n", "test_eq(to['b_na'], [1,2,1,1,1,1,1])\n", "test_eq(to['c'], [1,0,0,0,1,0,1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _maybe_expand(o): return o[:,None] if o.ndim==1 else o" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ReadTabBatch(ItemTransform):\n", " def __init__(self, to): self.to = to\n", "\n", " def encodes(self, to):\n", " if not to.with_cont: res = (tensor(to.cats).long(),)\n", " else: res = (tensor(to.cats).long(),tensor(to.conts).float())\n", " ys = [n for n in to.y_names if n in to.items.columns]\n", " if len(ys) == len(to.y_names): res = res + (tensor(to.targ),)\n", " if to.device is not None: res = to_device(res, to.device)\n", " return res\n", "\n", " def decodes(self, o):\n", " o = [_maybe_expand(o_) for o_ in to_np(o) if o_.size != 0]\n", " vals = np.concatenate(o, axis=1)\n", " try: df = pd.DataFrame(vals, columns=self.to.all_col_names)\n", " except: df = pd.DataFrame(vals, columns=self.to.x_names)\n", " to = self.to.new(df)\n", " return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_batch(x: Tabular, y, its, max_n=10, ctxs=None):\n", " x.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind\n", "_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates()\n", "class TabDataLoader(TfmdDL):\n", " \"A transformed `DataLoader` for Tabular data\"\n", " do_item = noops\n", " def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):\n", " if after_batch is None: after_batch = L(TransformBlock().batch_tfms)+ReadTabBatch(dataset)\n", " super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)\n", "\n", " def create_batch(self, b): return self.dataset.iloc[b]\n", "\n", "TabularPandas._dl_type = TabDataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Integration example\n", "\n", "For a more in-depth explaination, see the [tabular tutorial](http://dev.fast.ai/tutorial.tabular)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countrysalary
049Private101320Assoc-acdm12.0Married-civ-spouseNaNWifeWhiteFemale0190240United-States>=50k
144Private236746Masters14.0DivorcedExec-managerialNot-in-familyWhiteMale10520045United-States>=50k
238Private96185HS-gradNaNDivorcedNaNUnmarriedBlackFemale0032United-States<50k
338Self-emp-inc112847Prof-school15.0Married-civ-spouseProf-specialtyHusbandAsian-Pac-IslanderMale0040United-States>=50k
442Self-emp-not-inc822977th-8thNaNMarried-civ-spouseOther-serviceWifeBlackFemale0050United-States<50k
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 49 Private 101320 Assoc-acdm 12.0 \n", "1 44 Private 236746 Masters 14.0 \n", "2 38 Private 96185 HS-grad NaN \n", "3 38 Self-emp-inc 112847 Prof-school 15.0 \n", "4 42 Self-emp-not-inc 82297 7th-8th NaN \n", "\n", " marital-status occupation relationship race \\\n", "0 Married-civ-spouse NaN Wife White \n", "1 Divorced Exec-managerial Not-in-family White \n", "2 Divorced NaN Unmarried Black \n", "3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n", "4 Married-civ-spouse Other-service Wife Black \n", "\n", " sex capital-gain capital-loss hours-per-week native-country salary \n", "0 Female 0 1902 40 United-States >=50k \n", "1 Male 10520 0 45 United-States >=50k \n", "2 Female 0 0 32 United-States <50k \n", "3 Male 0 0 40 United-States >=50k \n", "4 Female 0 0 50 United-States <50k " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n", "df_test.drop('salary', axis=1, inplace=True)\n", "df_main.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "splits = RandomSplitter()(range_of(df_main))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=\"salary\", splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
workclasseducationmarital-statusoccupationrelationshipraceeducation-num_naagefnlwgteducation-numsalary
0PrivateSome-collegeMarried-spouse-absentOther-serviceNot-in-familyWhiteFalse22.99999954472.00540710.0<50k
1PrivateSome-collegeNever-marriedOther-serviceOther-relativeBlackFalse21.000001236683.99990510.0<50k
2PrivateSome-collegeNever-marriedSalesOwn-childWhiteFalse18.000001163786.99840610.0<50k
3Local-govMastersDivorced#na#UnmarriedWhiteFalse44.000000135055.99862214.0<50k
4Self-emp-incHS-gradMarried-civ-spouseAdm-clericalHusbandWhiteFalse40.000000207577.9998869.0>=50k
5State-govMastersMarried-civ-spouseExec-managerialHusbandWhiteFalse37.000000210451.99954814.0<50k
6?BachelorsNever-married?Not-in-familyWhiteFalse32.000000169885.99945313.0<50k
7PrivateHS-gradNever-marriedAdm-clericalNot-in-familyWhiteFalse20.000000236804.0004959.0<50k
8PrivateSome-collegeMarried-civ-spouseOther-serviceHusbandWhiteFalse31.000000137680.99866710.0<50k
9Self-emp-incSome-collegeMarried-civ-spouseSalesHusbandWhiteFalse46.000000284798.99746210.0<50k
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = to.dataloaders()\n", "dls.valid.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
workclasseducationmarital-statusoccupationrelationshipraceeducation-num_naagefnlwgteducation-numsalary
3380PrivateSome-collegeMarried-civ-spouseOther-serviceHusbandWhiteFalse33.0248584.010.0<50k
3158Local-govBachelorsMarried-civ-spouseExec-managerialHusbandWhiteFalse51.0110327.013.0>=50k
8904PrivateSome-collegeNever-marriedExec-managerialNot-in-familyWhiteFalse27.0133937.010.0<50k
5912Self-emp-not-incSome-collegeMarried-civ-spouseFarming-fishingHusbandWhiteFalse48.0164582.010.0>=50k
3583PrivateMastersNever-marriedExec-managerialNot-in-familyWhiteFalse39.049020.014.0<50k
2945PrivateBachelorsNever-marriedAdm-clericalOwn-childWhiteFalse26.0166051.013.0<50k
204?HS-gradMarried-civ-spouse#na#HusbandWhiteTrue60.0174073.010.0<50k
3196PrivateSome-collegeNever-marriedAdm-clericalOwn-childWhiteFalse21.0241367.010.0<50k
1183?Some-collegeMarried-civ-spouse?HusbandWhiteFalse65.052728.010.0<50k
2829PrivateMastersMarried-civ-spouseProf-specialtyHusbandWhiteFalse46.0261059.014.0>=50k
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "to.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can decode any set of transformed data by calling `to.decode_row` with our raw data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "age 33\n", "workclass Private\n", "fnlwgt 248584\n", "education Some-college\n", "education-num 10\n", "marital-status Married-civ-spouse\n", "occupation Other-service\n", "relationship Husband\n", "race White\n", "sex Male\n", "capital-gain 0\n", "capital-loss 0\n", "hours-per-week 50\n", "native-country United-States\n", "salary <50k\n", "education-num_na False\n", "Name: 3380, dtype: object" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "row = to.items.iloc[0]\n", "to.decode_row(row)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryeducation-num_na
100000.46691051.359596101.1705203212Male0040Philippines1
10001-0.93229251.27199012-0.42589331514Male0040United-States1
100021.05604750.1619112-1.2240991925Female0037United-States1
100030.5405525-0.27410012-0.4258937255Female0043United-States1
100040.76147961.46281990.3723133515Male0060United-States1
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "10000 0.466910 5 1.359596 10 1.170520 \n", "10001 -0.932292 5 1.271990 12 -0.425893 \n", "10002 1.056047 5 0.161911 2 -1.224099 \n", "10003 0.540552 5 -0.274100 12 -0.425893 \n", "10004 0.761479 6 1.462819 9 0.372313 \n", "\n", " marital-status occupation relationship race sex capital-gain \\\n", "10000 3 2 1 2 Male 0 \n", "10001 3 15 1 4 Male 0 \n", "10002 1 9 2 5 Female 0 \n", "10003 7 2 5 5 Female 0 \n", "10004 3 5 1 5 Male 0 \n", "\n", " capital-loss hours-per-week native-country education-num_na \n", "10000 0 40 Philippines 1 \n", "10001 0 40 United-States 1 \n", "10002 0 37 United-States 1 \n", "10003 0 43 United-States 1 \n", "10004 0 60 United-States 1 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to_tst = to.new(df_test)\n", "to_tst.process()\n", "to_tst.items.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
workclasseducationmarital-statusoccupationrelationshipraceeducation-num_naagefnlwgteducation-num
0PrivateBachelorsMarried-civ-spouseAdm-clericalHusbandAsian-Pac-IslanderFalse45.000000338105.00017213.0
1PrivateHS-gradMarried-civ-spouseTransport-movingHusbandOtherFalse26.000000328662.9966259.0
2Private11thDivorcedOther-serviceNot-in-familyWhiteFalse52.999999209021.9994847.0
3PrivateHS-gradWidowedAdm-clericalUnmarriedWhiteFalse46.000000162030.0015549.0
4Self-emp-incAssoc-vocMarried-civ-spouseExec-managerialHusbandWhiteFalse49.000000349230.00556111.0
5Local-govSome-collegeMarried-civ-spouseExec-managerialHusbandWhiteFalse34.000000124827.00191610.0
6Self-emp-incSome-collegeMarried-civ-spouseSalesHusbandWhiteFalse52.999999290640.00045410.0
7PrivateSome-collegeNever-marriedSalesOwn-childWhiteFalse19.000000106273.00286610.0
8PrivateSome-collegeMarried-civ-spouseProtective-servHusbandBlackFalse71.99999953683.99725410.0
9PrivateSome-collegeNever-marriedSalesOwn-childWhiteFalse20.000000505980.00455510.0
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tst_dl = dls.valid.new(to_tst)\n", "tst_dl.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other target types" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-label categories" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### one-hot encoded label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _mock_multi_label(df):\n", " sal,sex,white = [],[],[]\n", " for row in df.itertuples():\n", " sal.append(row.salary == '>=50k')\n", " sex.append(row.sex == ' Male')\n", " white.append(row.race == ' White')\n", " df['salary'] = np.array(sal)\n", " df['male'] = np.array(sex)\n", " df['white'] = np.array(white)\n", " return df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n", "df_main = _mock_multi_label(df_main)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countrysalarymalewhite
049Private101320Assoc-acdm12.0Married-civ-spouseNaNWifeWhiteFemale0190240United-StatesTrueFalseTrue
144Private236746Masters14.0DivorcedExec-managerialNot-in-familyWhiteMale10520045United-StatesTrueTrueTrue
238Private96185HS-gradNaNDivorcedNaNUnmarriedBlackFemale0032United-StatesFalseFalseFalse
338Self-emp-inc112847Prof-school15.0Married-civ-spouseProf-specialtyHusbandAsian-Pac-IslanderMale0040United-StatesTrueTrueFalse
442Self-emp-not-inc822977th-8thNaNMarried-civ-spouseOther-serviceWifeBlackFemale0050United-StatesFalseFalseFalse
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 49 Private 101320 Assoc-acdm 12.0 \n", "1 44 Private 236746 Masters 14.0 \n", "2 38 Private 96185 HS-grad NaN \n", "3 38 Self-emp-inc 112847 Prof-school 15.0 \n", "4 42 Self-emp-not-inc 82297 7th-8th NaN \n", "\n", " marital-status occupation relationship race \\\n", "0 Married-civ-spouse NaN Wife White \n", "1 Divorced Exec-managerial Not-in-family White \n", "2 Divorced NaN Unmarried Black \n", "3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n", "4 Married-civ-spouse Other-service Wife Black \n", "\n", " sex capital-gain capital-loss hours-per-week native-country \\\n", "0 Female 0 1902 40 United-States \n", "1 Male 10520 0 45 United-States \n", "2 Female 0 0 32 United-States \n", "3 Male 0 0 40 United-States \n", "4 Female 0 0 50 United-States \n", "\n", " salary male white \n", "0 True False True \n", "1 True True True \n", "2 False False False \n", "3 True True False \n", "4 False False False " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_main.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@EncodedMultiCategorize\n", "def setups(self, to:Tabular):\n", " self.c = len(self.vocab)\n", " return self(to)\n", "\n", "@EncodedMultiCategorize\n", "def encodes(self, to:Tabular): return to\n", "\n", "@EncodedMultiCategorize\n", "def decodes(self, to:Tabular):\n", " to.transform(to.y_names, lambda c: c==1)\n", " return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "splits = RandomSplitter()(range_of(df_main))\n", "y_names=[\"salary\", \"male\", \"white\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 77.2 ms, sys: 238 µs, total: 77.4 ms\n", "Wall time: 76.7 ms\n" ] } ], "source": [ "%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=y_names, y_block=MultiCategoryBlock(encoded=True, vocab=y_names), splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
workclasseducationmarital-statusoccupationrelationshipraceeducation-num_naagefnlwgteducation-numsalarymalewhite
0PrivateHS-gradMarried-civ-spouse#na#HusbandAmer-Indian-EskimoTrue30.000000216811.00073910.000000FalseTrueFalse
1PrivateBachelorsMarried-civ-spouse#na#HusbandWhiteFalse53.00000096061.99800913.000000FalseTrueTrue
2PrivateHS-gradMarried-civ-spouseAdm-clericalWifeWhiteFalse31.000000196787.9999019.000000FalseFalseTrue
3?BachelorsMarried-civ-spouse?HusbandWhiteTrue65.999999177351.00022610.000000TrueTrueTrue
4Private10thSeparatedSalesUnmarriedBlackFalse21.000000353628.0056625.999999FalseFalseFalse
5PrivateBachelorsNever-marriedProf-specialtyNot-in-familyWhiteFalse40.000000143045.99922913.000000FalseFalseTrue
6PrivateMastersMarried-civ-spouseProf-specialtyHusbandWhiteFalse37.000000117381.00256114.000000FalseTrueTrue
7PrivateHS-gradNever-marriedSalesNot-in-familyWhiteFalse29.000000183854.0002919.000000FalseFalseTrue
8PrivateHS-gradDivorcedPriv-house-servNot-in-familyWhiteFalse54.999999175942.0000539.000000FalseFalseTrue
9PrivateSome-collegeWidowedTech-supportUnmarriedWhiteFalse64.00000091342.99944810.000000FalseFalseTrue
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = to.dataloaders()\n", "dls.valid.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Not one-hot encoded" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _mock_multi_label(df):\n", " targ = []\n", " for row in df.itertuples():\n", " labels = []\n", " if row.salary == '>=50k': labels.append('>50k')\n", " if row.sex == ' Male': labels.append('male')\n", " if row.race == ' White': labels.append('white')\n", " targ.append(' '.join(labels))\n", " df['target'] = np.array(targ)\n", " return df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n", "df_main = _mock_multi_label(df_main)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countrysalarytarget
049Private101320Assoc-acdm12.0Married-civ-spouseNaNWifeWhiteFemale0190240United-States>=50k>50k white
144Private236746Masters14.0DivorcedExec-managerialNot-in-familyWhiteMale10520045United-States>=50k>50k male white
238Private96185HS-gradNaNDivorcedNaNUnmarriedBlackFemale0032United-States<50k
338Self-emp-inc112847Prof-school15.0Married-civ-spouseProf-specialtyHusbandAsian-Pac-IslanderMale0040United-States>=50k>50k male
442Self-emp-not-inc822977th-8thNaNMarried-civ-spouseOther-serviceWifeBlackFemale0050United-States<50k
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 49 Private 101320 Assoc-acdm 12.0 \n", "1 44 Private 236746 Masters 14.0 \n", "2 38 Private 96185 HS-grad NaN \n", "3 38 Self-emp-inc 112847 Prof-school 15.0 \n", "4 42 Self-emp-not-inc 82297 7th-8th NaN \n", "\n", " marital-status occupation relationship race \\\n", "0 Married-civ-spouse NaN Wife White \n", "1 Divorced Exec-managerial Not-in-family White \n", "2 Divorced NaN Unmarried Black \n", "3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n", "4 Married-civ-spouse Other-service Wife Black \n", "\n", " sex capital-gain capital-loss hours-per-week native-country salary \\\n", "0 Female 0 1902 40 United-States >=50k \n", "1 Male 10520 0 45 United-States >=50k \n", "2 Female 0 0 32 United-States <50k \n", "3 Male 0 0 40 United-States >=50k \n", "4 Female 0 0 50 United-States <50k \n", "\n", " target \n", "0 >50k white \n", "1 >50k male white \n", "2 \n", "3 >50k male \n", "4 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_main.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@MultiCategorize\n", "def encodes(self, to:Tabular): \n", " #to.transform(to.y_names, partial(_apply_cats, {n: self.vocab for n in to.y_names}, 0))\n", " return to\n", " \n", "@MultiCategorize\n", "def decodes(self, to:Tabular): \n", " #to.transform(to.y_names, partial(_decode_cats, {n: self.vocab for n in to.y_names}))\n", " return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [Categorify, FillMissing, Normalize]\n", "splits = RandomSplitter()(range_of(df_main))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 81 ms, sys: 178 µs, total: 81.2 ms\n", "Wall time: 80.1 ms\n" ] } ], "source": [ "%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names=\"target\", y_block=MultiCategoryBlock(), splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#24) ['-','_','a','c','d','e','f','g','h','i'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.procs[2].vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Regression" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@RegressionSetup\n", "def setups(self, to:Tabular):\n", " if self.c is not None: return\n", " self.c = len(to.y_names)\n", " return to\n", "\n", "@RegressionSetup\n", "def encodes(self, to:Tabular): return to\n", "\n", "@RegressionSetup\n", "def decodes(self, to:Tabular): return to" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')\n", "df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()\n", "df_main = _mock_multi_label(df_main)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['fnlwgt', 'education-num']\n", "procs = [Categorify, 
FillMissing, Normalize]\n", "splits = RandomSplitter()(range_of(df_main))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 82.2 ms, sys: 508 µs, total: 82.7 ms\n", "Wall time: 81.8 ms\n" ] } ], "source": [ "%time to = TabularPandas(df_main, procs, cat_names, cont_names, y_names='age', splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'fnlwgt': 193046.84475, 'education-num': 10.08025}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.procs[-1].means" ] },
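{ "cell_type": "markdown", "metadata": {}, "source": [ "`Normalize` keeps the per-column `means` and `stds` it computed on the training split, and decoding is simply `x * std + mean`. As an illustration (not in the original notebook), we can invert the normalization of one column by hand:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: manually undo Normalize for 'fnlwgt' on the processed frame.\n", "norm = to.procs[-1]\n", "raw_fnlwgt = to.items['fnlwgt'] * norm.stds['fnlwgt'] + norm.means['fnlwgt']\n", "raw_fnlwgt.head()" ] },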
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>education-num_na</th><th>fnlwgt</th><th>education-num</th><th>age</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>State-gov</td><td>Masters</td><td>Never-married</td><td>#na#</td><td>Not-in-family</td><td>White</td><td>False</td><td>47569.994748</td><td>14.0</td><td>36.0</td></tr>\n",
"    <tr><th>1</th><td>Federal-gov</td><td>11th</td><td>Never-married</td><td>Sales</td><td>Not-in-family</td><td>Black</td><td>False</td><td>166418.999287</td><td>7.0</td><td>50.0</td></tr>\n",
"    <tr><th>2</th><td>Private</td><td>9th</td><td>Divorced</td><td>Farming-fishing</td><td>Not-in-family</td><td>Black</td><td>False</td><td>225603.000537</td><td>5.0</td><td>58.0</td></tr>\n",
"    <tr><th>3</th><td>Local-gov</td><td>12th</td><td>Widowed</td><td>Adm-clerical</td><td>Not-in-family</td><td>White</td><td>False</td><td>48055.004282</td><td>8.0</td><td>55.0</td></tr>\n",
"    <tr><th>4</th><td>Federal-gov</td><td>Prof-school</td><td>Divorced</td><td>Prof-specialty</td><td>Not-in-family</td><td>White</td><td>False</td><td>66504.003988</td><td>15.0</td><td>57.0</td></tr>\n",
"    <tr><th>5</th><td>Private</td><td>Some-college</td><td>Never-married</td><td>Adm-clerical</td><td>Unmarried</td><td>Asian-Pac-Islander</td><td>False</td><td>91274.998927</td><td>10.0</td><td>36.0</td></tr>\n",
"    <tr><th>6</th><td>State-gov</td><td>Bachelors</td><td>Married-civ-spouse</td><td>Exec-managerial</td><td>Husband</td><td>White</td><td>False</td><td>391584.996528</td><td>13.0</td><td>49.0</td></tr>\n",
"    <tr><th>7</th><td>Self-emp-not-inc</td><td>1st-4th</td><td>Divorced</td><td>Craft-repair</td><td>Not-in-family</td><td>White</td><td>False</td><td>130435.999390</td><td>2.0</td><td>71.0</td></tr>\n",
"    <tr><th>8</th><td>Private</td><td>Bachelors</td><td>Never-married</td><td>Prof-specialty</td><td>Own-child</td><td>White</td><td>False</td><td>62507.003940</td><td>13.0</td><td>22.0</td></tr>\n",
"    <tr><th>9</th><td>Private</td><td>HS-grad</td><td>Married-civ-spouse</td><td>Handlers-cleaners</td><td>Own-child</td><td>White</td><td>False</td><td>236696.000903</td><td>9.0</td><td>24.0</td></tr>\n",
"  </tbody>\n",
"</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = to.dataloaders()\n", "dls.valid.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Not being used now - for multi-modal" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TensorTabular(fastuple):\n", " def get_ctxs(self, max_n=10, **kwargs):\n", " n_samples = min(self[0].shape[0], max_n)\n", " df = pd.DataFrame(index = range(n_samples))\n", " return [df.iloc[i] for i in range(n_samples)]\n", "\n", " def display(self, ctxs): display_df(pd.DataFrame(ctxs))\n", "\n", "class TabularLine(pd.Series):\n", " \"A line of a dataframe that knows how to show itself\"\n", " def show(self, ctx=None, **kwargs): return self if ctx is None else ctx.append(self)\n", "\n", "class ReadTabLine(ItemTransform):\n", " def __init__(self, proc): self.proc = proc\n", "\n", " def encodes(self, row):\n", " cats,conts = (o.map(row.__getitem__) for o in (self.proc.cat_names,self.proc.cont_names))\n", " return TensorTabular(tensor(cats).long(),tensor(conts).float())\n", "\n", " def decodes(self, o):\n", " to = TabularPandas(o, self.proc.cat_names, self.proc.cont_names, self.proc.y_names)\n", " to = self.proc.decode(to)\n", " return TabularLine(pd.Series({c: v for v,c in zip(to.items[0]+to.items[1], self.proc.cat_names+self.proc.cont_names)}))\n", "\n", "class ReadTabTarget(ItemTransform):\n", " def __init__(self, proc): self.proc = proc\n", " def encodes(self, row): return row[self.proc.y_names].astype(np.int64)\n", " def decodes(self, o): return Category(self.proc.classes[self.proc.y_names][o])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# tds = TfmdDS(to.items, tfms=[[ReadTabLine(proc)], ReadTabTarget(proc)])\n", "# enc = tds[1]\n", "# test_eq(enc[0][0], tensor([2,1]))\n", "# test_close(enc[0][1], tensor([-0.628828]))\n", "# test_eq(enc[1], 1)\n", "\n", "# dec = tds.decode(enc)\n", "# assert isinstance(dec[0], TabularLine)\n", "# test_close(dec[0], pd.Series({'a': 1, 'b_na': False, 'b': 1}))\n", "# test_eq(dec[1], 'a')\n", "\n", "# test_stdout(lambda: print(show_at(tds, 1)), \"\"\"a 1\n", "# b_na False\n", "# b 1\n", "# category a\n", "# dtype: object\"\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 
19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }