{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp collab\n", "#default_class_lvl 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.tabular.all import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Collaborative filtering\n", "\n", "> Tools to quickly get the data and train models suitable for collaborative filtering" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module contains all the high-level functions you need in a collaborative filtering application to assemble your data, get a model and train it with a `Learner`. We will go other those in order but you can also check the [collaborative filtering tutorial](http://dev.fast.ai/tutorial.collab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gather the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TabularCollab(TabularPandas):\n", " \"Instance of `TabularPandas` suitable for collaborative filtering (with no continuous variable)\"\n", " with_cont=False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is just to use the internal of the tabular application, don't worry about it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class CollabDataLoaders(DataLoaders):\n", " \"Base `DataLoaders` for collaborative filtering.\"\n", " @delegates(DataLoaders.from_dblock)\n", " @classmethod\n", " def from_df(cls, ratings, valid_pct=0.2, user_name=None, item_name=None, rating_name=None, seed=None, path='.', **kwargs):\n", " \"Create a `DataLoaders` suitable for collaborative filtering from `ratings`.\"\n", " user_name = ifnone(user_name, ratings.columns[0])\n", " item_name = ifnone(item_name, ratings.columns[1])\n", " rating_name = ifnone(rating_name, ratings.columns[2])\n", " cat_names = [user_name,item_name]\n", " splits = RandomSplitter(valid_pct=valid_pct, seed=seed)(range_of(ratings))\n", " to = TabularCollab(ratings, [Categorify], cat_names, y_names=[rating_name], y_block=TransformBlock(), splits=splits)\n", " return to.dataloaders(path=path, **kwargs)\n", "\n", " @classmethod\n", " def from_csv(cls, csv, **kwargs):\n", " \"Create a `DataLoaders` suitable for collaborative filtering from `csv`.\"\n", " return cls.from_df(pd.read_csv(csv), **kwargs)\n", "\n", "CollabDataLoaders.from_csv = delegates(to=CollabDataLoaders.from_df)(CollabDataLoaders.from_csv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This class should not be used directly, one of the factory methods should be prefered instead. 
All these factory methods accept the following arguments:\n", "\n", "- `valid_pct`: the random percentage of the dataset to set aside for validation (with an optional `seed`)\n", "- `user_name`: the name of the column containing the user (defaults to the first column)\n", "- `item_name`: the name of the column containing the item (defaults to the second column)\n", "- `rating_name`: the name of the column containing the rating (defaults to the third column)\n", "- `path`: the folder to work in\n", "- `bs`: the batch size\n", "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n", "- `shuffle_train`: whether to shuffle the training `DataLoader`\n", "- `device`: the PyTorch device to use (defaults to `default_device()`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

CollabDataLoaders.from_df[source]

\n", "\n", "> CollabDataLoaders.from_df(**`ratings`**, **`valid_pct`**=*`0.2`*, **`user_name`**=*`None`*, **`item_name`**=*`None`*, **`rating_name`**=*`None`*, **`seed`**=*`None`*, **`path`**=*`'.'`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n", "\n", "Create a [`DataLoaders`](/data.core#DataLoaders) suitable for collaborative filtering from `ratings`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CollabDataLoaders.from_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how this works on an example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIdmovieIdratingtimestamp
07310974.01255504951
15619243.51172695223
21572603.51291598691
335812105.0957481884
41303162.01138999234
\n", "
" ], "text/plain": [ " userId movieId rating timestamp\n", "0 73 1097 4.0 1255504951\n", "1 561 924 3.5 1172695223\n", "2 157 260 3.5 1291598691\n", "3 358 1210 5.0 957481884\n", "4 130 316 2.0 1138999234" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.ML_SAMPLE)\n", "ratings = pd.read_csv(path/'ratings.csv')\n", "ratings.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIdmovieIdrating
015712653.0
113028582.0
248111965.0
31055973.5
41285975.0
558759524.0
611125715.0
71053563.0
87753494.0
911912404.0
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = CollabDataLoaders.from_df(ratings, bs=64)\n", "dls.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

CollabDataLoaders.from_csv[source]

\n", "\n", "> CollabDataLoaders.from_csv(**`csv`**, **`valid_pct`**=*`0.2`*, **`user_name`**=*`None`*, **`item_name`**=*`None`*, **`rating_name`**=*`None`*, **`seed`**=*`None`*, **`path`**=*`'.'`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n", "\n", "Create a [`DataLoaders`](/data.core#DataLoaders) suitable for collaborative filtering from `csv`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CollabDataLoaders.from_csv)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = CollabDataLoaders.from_csv(path/'ratings.csv', bs=64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai provides two kinds of models for collaborative filtering: a dot-product model and a neural net. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class EmbeddingDotBias(Module):\n", " \"Base dot model for collaborative filtering.\"\n", " def __init__(self, n_factors, n_users, n_items, y_range=None):\n", " self.y_range = y_range\n", " (self.u_weight, self.i_weight, self.u_bias, self.i_bias) = [Embedding(*o) for o in [\n", " (n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)\n", " ]]\n", "\n", " def forward(self, x):\n", " users,items = x[:,0],x[:,1]\n", " dot = self.u_weight(users)* self.i_weight(items)\n", " res = dot.sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()\n", " if self.y_range is None: return res\n", " return torch.sigmoid(res) * (self.y_range[1]-self.y_range[0]) + self.y_range[0]\n", "\n", " @classmethod\n", " def from_classes(cls, n_factors, classes, user=None, item=None, y_range=None):\n", " \"Build a model with `n_factors` by inferring `n_users` and `n_items` from `classes`\"\n", " if user is None: user = list(classes.keys())[0]\n", " if item is None: item = list(classes.keys())[1]\n", " res = cls(n_factors, len(classes[user]), len(classes[item]), y_range=y_range)\n", " res.classes,res.user,res.item = classes,user,item\n", " return res\n", "\n", " def _get_idx(self, arr, is_item=True):\n", " \"Fetch item or user (based on `is_item`) for all in `arr`\"\n", " assert hasattr(self, 'classes'), \"Build your model with `EmbeddingDotBias.from_classes` to use this functionality.\"\n", " classes = self.classes[self.item] if is_item else self.classes[self.user]\n", " c2i = {v:k for k,v in enumerate(classes)}\n", " try: return tensor([c2i[o] for o in arr])\n", " except Exception as e:\n", " print(f\"\"\"You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data.\n", " If it was in your original data, it may have been split such that it's only in the validation set now.\"\"\")\n", "\n", " def bias(self, arr, is_item=True):\n", " \"Bias for item or user (based on `is_item`) for all in `arr`\"\n", " idx = self._get_idx(arr, is_item)\n", " layer = (self.i_bias if is_item else self.u_bias).eval().cpu()\n", " return to_detach(layer(idx).squeeze())\n", "\n", " def weight(self, arr, is_item=True):\n", " \"Weight for item or user (based on `is_item`) for all in `arr`\"\n", " idx = self._get_idx(arr, is_item)\n", " layer = (self.i_weight if is_item else self.u_weight).eval().cpu()\n", " return to_detach(layer(idx))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is built with `n_factors` (the length of the internal vectors), 
`n_users` and `n_items`. For a given user and item, it grabs the corresponding weights and biases and returns\n", "``` python\n", "torch.dot(user_w, item_w) + user_b + item_b\n", "```\n", "Optionally, if `y_range` is passed, it applies a `SigmoidRange` to that result." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = dls.one_batch()\n", "model = EmbeddingDotBias(50, len(dls.classes['userId']), len(dls.classes['movieId']), y_range=(0,5)\n", " ).to(x.device)\n", "out = model(x)\n", "assert (0 <= out).all() and (out <= 5).all()" ] }, 
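{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the `y_range` scaling concrete, here is a sketch of the computation applied to the raw score (the tensor values are made up for illustration):\n", "\n", "``` python\n", "import torch\n", "res = torch.tensor([-2., 0., 3.])  # hypothetical raw scores\n", "lo, hi = 0, 5                      # y_range\n", "scaled = torch.sigmoid(res) * (hi - lo) + lo\n", "# sigmoid(0) = 0.5, so a raw score of 0 maps to the midpoint 2.5\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "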

EmbeddingDotBias.from_classes[source]

\n", "\n", "> EmbeddingDotBias.from_classes(**`n_factors`**, **`classes`**, **`user`**=*`None`*, **`item`**=*`None`*, **`y_range`**=*`None`*)\n", "\n", "Build a model with `n_factors` by inferring `n_users` and `n_items` from `classes`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EmbeddingDotBias.from_classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`y_range` is passed to the main init. `user` and `item` are the names of the keys for users and items in `classes` (default to the first and second key respectively). `classes` is expected to be a dictionary key to list of categories like the result of `dls.classes` in a `CollabDataLoaders`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'userId': (#101) ['#na#',15,17,19,23,30,48,56,73,77...],\n", " 'movieId': (#101) ['#na#',1,10,32,34,39,47,50,110,150...]}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dls.classes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how it can be used in practice:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = EmbeddingDotBias.from_classes(50, dls.classes, y_range=(0,5)\n", " ).to(x.device)\n", "out = model(x)\n", "assert (0 <= out).all() and (out <= 5).all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Two convenience methods are added to easily access the weights and bias when a model is created with `EmbeddingDotBias.from_classes`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

EmbeddingDotBias.weight[source]

\n", "\n", "> EmbeddingDotBias.weight(**`arr`**, **`is_item`**=*`True`*)\n", "\n", "Weight for item or user (based on `is_item`) for all in `arr`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EmbeddingDotBias.weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The elements of `arr` are expected to be class names (which is why the model needs to be created with `EmbeddingDotBias.from_classes`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mov = dls.classes['movieId'][42] \n", "w = model.weight([mov])\n", "test_eq(w, model.i_weight(tensor([42])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

EmbeddingDotBias.bias[source]

\n", "\n", "> EmbeddingDotBias.bias(**`arr`**, **`is_item`**=*`True`*)\n", "\n", "Bias for item or user (based on `is_item`) for all in `arr`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EmbeddingDotBias.bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The elements of `arr` are expected to be class names (which is why the model needs to be created with `EmbeddingDotBias.from_classes`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mov = dls.classes['movieId'][42] \n", "b = model.bias([mov])\n", "test_eq(b, model.i_bias(tensor([42])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "class EmbeddingNN(TabularModel):\n", " \"Subclass `TabularModel` to create a NN suitable for collaborative filtering.\"\n", " @delegates(TabularModel.__init__)\n", " def __init__(self, emb_szs, layers, **kwargs):\n", " super().__init__(emb_szs=emb_szs, n_cont=0, out_sz=1, layers=layers, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class EmbeddingNN[source]

\n", "\n", "> EmbeddingNN(**`emb_szs`**, **`layers`**, **`ps`**=*`None`*, **`embed_p`**=*`0.0`*, **`y_range`**=*`None`*, **`use_bn`**=*`True`*, **`bn_final`**=*`False`*, **`bn_cont`**=*`True`*) :: [`TabularModel`](/tabular.model#TabularModel)\n", "\n", "Subclass [`TabularModel`](/tabular.model#TabularModel) to create a NN suitable for collaborative filtering." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EmbeddingNN)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`emb_szs` should be a list of two tuples, one for the users, one for the items, each tuple containing the number of users/items and the corresponding embedding size (the function `get_emb_sz` can give a good default). All the other arguments are passed to `TabularModel`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_szs = get_emb_sz(dls.train_ds, {})\n", "model = EmbeddingNN(emb_szs, [50], y_range=(0,5)\n", " ).to(x.device)\n", "out = model(x)\n", "assert (0 <= out).all() and (out <= 5).all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a `Learner`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function lets us quickly create a `Learner` for collaborative filtering from the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args(to_return=True, but_as=Learner.__init__)\n", "@delegates(Learner.__init__)\n", "def collab_learner(dls, n_factors=50, use_nn=False, emb_szs=None, layers=None, config=None, y_range=None, loss_func=None, **kwargs):\n", " \"Create a Learner for collaborative filtering on `dls`.\"\n", " emb_szs = get_emb_sz(dls, ifnone(emb_szs, {}))\n", " if loss_func is None: loss_func = MSELossFlat()\n", " if config is None: config = tabular_config()\n", " if y_range is not None: config['y_range'] = y_range\n", " if layers is None: layers = [n_factors]\n", " if use_nn: model = EmbeddingNN(emb_szs=emb_szs, layers=layers, **config)\n", " else: model = EmbeddingDotBias.from_classes(n_factors, dls.classes, y_range=y_range)\n", " return Learner(dls, model, loss_func=loss_func, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `use_nn=False`, the model used is an `EmbeddingDotBias` with `n_factors` and `y_range`. Otherwise, it's a `EmbeddingNN` for which you can pass `emb_szs` (will be infered from the `dls` with `get_emb_sz` if you don't provide any), `layers` (defaults to `[n_factors]`) `y_range`, and a `config` that you can create with `tabular_config` to customize your model. \n", "\n", "`loss_func` will default to `MSELossFlat` and all the other arguments are passed to `Learner`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = collab_learner(dls, y_range=(0,5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
02.5219792.54162700:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import *\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }