{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Collaborative filtering" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This package contains all the necessary functions to quickly train a model for a collaborative filtering task. Let's start by importing all we'll need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.collab import * " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Collaborative filtering is when you're tasked to predict how much a user is going to like a certain item. The fastai library contains a [`CollabFilteringDataset`](/collab.html#CollabFilteringDataset) class that will help you create datasets suitable for training, and a function `get_colab_learner` to build a simple model directly from a ratings table. Let's first see how we can get started before delving into the documentation.\n", "\n", "For this example, we'll use a small subset of the [MovieLens](https://grouplens.org/datasets/movielens/) dataset to predict the rating a user would give a particular movie (from 0 to 5). The dataset comes in the form of a csv file where each line is a rating of a movie by a given person." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "73 | \n", "1097 | \n", "4.0 | \n", "1255504951 | \n", "
1 | \n", "561 | \n", "924 | \n", "3.5 | \n", "1172695223 | \n", "
2 | \n", "157 | \n", "260 | \n", "3.5 | \n", "1291598691 | \n", "
3 | \n", "358 | \n", "1210 | \n", "5.0 | \n", "957481884 | \n", "
4 | \n", "130 | \n", "316 | \n", "2.0 | \n", "1138999234 | \n", "
Training a model on this data for a few epochs with `collab_learner` produces losses like these:

| epoch | train_loss | valid_loss |
|---|---|---|
| 1 | 2.427430 | 1.999472 |
| 2 | 1.116335 | 0.663345 |
| 3 | 0.736155 | 0.636640 |
| 4 | 0.612827 | 0.626773 |
| 5 | 0.565003 | 0.626336 |
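The losses above come from fitting fastai's embedding-dot-bias model. As a toy, pure-Python sketch of the underlying idea (not fastai's implementation), here is plain SGD on a tiny factorization: each user and item gets a small latent vector plus a bias, and we minimize squared rating error.

```python
import random

random.seed(0)

# Toy ratings: (user, item, rating on a 0-5 scale) -- made-up data.
ratings = [(0, 0, 4.0), (0, 1, 2.0), (1, 0, 5.0),
           (1, 2, 3.5), (2, 1, 1.0), (2, 2, 4.5)]
n_users, n_items, n_factors = 3, 3, 2

# Latent factors and biases, randomly initialized.
u = [[random.gauss(0, 0.1) for _ in range(n_factors)] for _ in range(n_users)]
v = [[random.gauss(0, 0.1) for _ in range(n_factors)] for _ in range(n_items)]
bu = [0.0] * n_users
bi = [0.0] * n_items

def predict(user, item):
    return sum(a * b for a, b in zip(u[user], v[item])) + bu[user] + bi[item]

def mse():
    return sum((predict(us, it) - r) ** 2 for us, it, r in ratings) / len(ratings)

lr = 0.05
loss_before = mse()
for epoch in range(200):           # plain SGD on the squared error
    for us, it, r in ratings:
        err = predict(us, it) - r
        for f in range(n_factors):
            u_f, v_f = u[us][f], v[it][f]
            u[us][f] -= lr * err * v_f
            v[it][f] -= lr * err * u_f
        bu[us] -= lr * err
        bi[it] -= lr * err
loss_after = mse()
print(loss_before, loss_after)     # the loss drops as the factorization fits
```

fastai adds the refinements this sketch omits: minibatches, the Adam optimizer, weight decay, and a sigmoid-scaled output range.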
### `class` `CollabDataBunch`

`CollabDataBunch`(**`train_dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`valid_dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`fix_dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)=***`None`***, **`test_dl`**:`Optional`\[[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)\]=***`None`***, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`dl_tfms`**:`Optional`\[`Collection`\[`Callable`\]\]=***`None`***, **`path`**:`PathOrStr`=***`'.'`***, **`collate_fn`**:`Callable`=***`'data_collate'`***, **`no_check`**:`bool`=***`False`***) :: [`DataBunch`](/basic_data.html#DataBunch)

Base [`DataBunch`](/basic_data.html#DataBunch) for collaborative filtering.
### `from_df`

`from_df`(**`ratings`**:`DataFrame`, **`valid_pct`**:`float`=***`0.2`***, **`user_name`**:`Optional`\[`str`\]=***`None`***, **`item_name`**:`Optional`\[`str`\]=***`None`***, **`rating_name`**:`Optional`\[`str`\]=***`None`***, **`test`**:`DataFrame`=***`None`***, **`seed`**:`int`=***`None`***, **`path`**:`PathOrStr`=***`'.'`***, **`bs`**:`int`=***`64`***, **`val_bs`**:`int`=***`None`***, **`num_workers`**:`int`=***`16`***, **`dl_tfms`**:`Optional`\[`Collection`\[`Callable`\]\]=***`None`***, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`collate_fn`**:`Callable`=***`'data_collate'`***, **`no_check`**:`bool`=***`False`***) → `CollabDataBunch`

Create a `CollabDataBunch` from `ratings`, holding out `valid_pct` of the rows for validation (optionally with a fixed `seed`).
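Conceptually, the `valid_pct`/`seed` arguments of `from_df` control a random holdout split of the ratings rows. A hypothetical simplification of that split (not fastai's actual code):

```python
import random

def random_split(rows, valid_pct=0.2, seed=None):
    """Split rows into (train, valid) by randomly holding out valid_pct of them."""
    rng = random.Random(seed)          # a seed makes the split reproducible
    idxs = list(range(len(rows)))
    rng.shuffle(idxs)
    cut = int(valid_pct * len(rows))
    valid_idx = set(idxs[:cut])
    train = [r for i, r in enumerate(rows) if i not in valid_idx]
    valid = [r for i, r in enumerate(rows) if i in valid_idx]
    return train, valid

# The first rows of the MovieLens sample shown earlier.
rows = [(73, 1097, 4.0), (561, 924, 3.5), (157, 260, 3.5),
        (358, 1210, 5.0), (130, 316, 2.0)]
train, valid = random_split(rows, valid_pct=0.2, seed=42)
print(len(train), len(valid))  # 4 1
```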
### `class` `CollabLearner`

`CollabLearner`(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\[`Callable`\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`PathOrStr`=***`'models'`***, **`callback_fns`**:`Collection`\[`Callable`\]=***`None`***, **`callbacks`**:`Collection`\[[`Callback`](/callback.html#Callback)\]=***`None`***, ...) :: [`Learner`](/basic_train.html#Learner)

[`Learner`](/basic_train.html#Learner) suitable for collaborative filtering.
### `bias`

`bias`(**`arr`**:`Collection`\[`T_co`\], **`is_item`**:`bool`=***`True`***)

Bias for the items (or users, if `is_item=False`) in `arr`.
### `get_idx`

`get_idx`(**`arr`**:`Collection`\[`T_co`\], **`is_item`**:`bool`=***`True`***)

Fetch the embedding indices of the items (or users, if `is_item=False`) in `arr`.
### `weight`

`weight`(**`arr`**:`Collection`\[`T_co`\], **`is_item`**:`bool`=***`True`***)

Weight (embedding vector) for the items (or users, if `is_item=False`) in `arr`.
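Once the model is trained, these accessors let you inspect what it learned. For instance, sorting items by their learned bias surfaces titles that are liked or disliked across all users, independently of any user's taste. A sketch with made-up bias values (the real ones would come from `bias` on a trained learner):

```python
# Hypothetical learned item biases -- real values come from the trained model.
item_bias = {"Titanic": 0.41, "Battlefield Earth": -0.57,
             "Toy Story": 0.63, "Godzilla": -0.21}

# Rank items by bias: highest = globally liked, lowest = globally disliked.
ranked = sorted(item_bias, key=item_bias.get, reverse=True)
print(ranked[0], "/", ranked[-1])  # best-liked / least-liked
```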
### `class` `EmbeddingDotBias`

`EmbeddingDotBias`(**`n_factors`**:`int`, **`n_users`**:`int`, **`n_items`**:`int`, **`y_range`**:`Point`=***`None`***) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`Module`](/torch_core.html#Module)

Base dot-product model for collaborative filtering: the predicted rating is the dot product of a user embedding and an item embedding (each of size `n_factors`), plus a user bias and an item bias, optionally scaled to `y_range`.
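The prediction rule can be written out in a few lines. A pure-Python sketch of the forward computation (not the actual PyTorch module): dot the two embeddings, add both biases, and, when `y_range` is given, squash the score through a sigmoid scaled to that range.

```python
import math

def dot_bias_predict(u_emb, i_emb, u_bias, i_bias, y_range=None):
    """score = u . i + b_u + b_i, optionally sigmoid-scaled into y_range."""
    score = sum(a * b for a, b in zip(u_emb, i_emb)) + u_bias + i_bias
    if y_range is None:
        return score
    lo, hi = y_range
    return lo + (hi - lo) / (1 + math.exp(-score))  # sigmoid maps R -> (lo, hi)

# Raw score with toy 2-dim embeddings.
raw = dot_bias_predict([1.0, 0.0], [2.0, 0.0], 0.0, 0.0)
print(raw)  # 2.0

# Scaled into a rating range: always strictly inside (0, 5.5).
r = dot_bias_predict([0.5, -0.2], [0.3, 0.8], 0.1, -0.05, y_range=(0, 5.5))
print(r)
```

Scaling to a range slightly wider than the actual ratings (e.g. `(0, 5.5)` for 0-5 ratings) is a common trick, since the sigmoid only reaches its bounds asymptotically.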
### `class` `EmbeddingNN`

`EmbeddingNN`(**`emb_szs`**:`ListSizes`, **`layers`**:`Collection`\[`int`\]=***`None`***, **`ps`**:`Collection`\[`float`\]=***`None`***, **`emb_drop`**:`float`=***`0.0`***, **`y_range`**:`OptRange`=***`None`***, **`use_bn`**:`bool`=***`True`***, **`bn_final`**:`bool`=***`False`***) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`TabularModel`](/tabular.models.html#TabularModel)

Subclass of [`TabularModel`](/tabular.models.html#TabularModel) that creates a neural net suitable for collaborative filtering: user and item embeddings of sizes `emb_szs` are concatenated and passed through fully connected `layers`.
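Instead of a dot product, this architecture concatenates the user and item embeddings and feeds them through fully connected layers. A minimal pure-Python sketch of that forward pass with fixed toy weights (no batchnorm or dropout, which the real model can apply):

```python
import math

def relu(x):
    return max(0.0, x)

def mlp_predict(u_emb, i_emb, w1, b1, w2, b2, y_range=(0, 5.5)):
    """Concatenate embeddings, one hidden ReLU layer, sigmoid-scale the output."""
    x = u_emb + i_emb                  # concatenation: size 4 for two 2-dim embeddings
    hidden = [relu(sum(wi * xi for wi, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    out = sum(wi * hi for wi, hi in zip(w2, hidden)) + b2
    lo, hi = y_range
    return lo + (hi - lo) / (1 + math.exp(-out))

# Toy weights: hidden layer of 3 units over the 4-dim concatenated input.
w1 = [[0.2, -0.1, 0.4, 0.3], [-0.3, 0.5, 0.1, -0.2], [0.1, 0.1, -0.4, 0.2]]
b1 = [0.0, 0.1, -0.1]
w2 = [0.6, -0.4, 0.5]
b2 = 0.05
r = mlp_predict([0.5, -0.2], [0.3, 0.8], w1, b1, w2, b2)
print(r)
```

The extra layers let the model capture interactions between users and items that a plain dot product cannot express.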
### `collab_learner`

`collab_learner`(**`data`**, **`n_factors`**:`int`=***`None`***, **`use_nn`**:`bool`=***`False`***, **`emb_szs`**:`Dict`\[`str`, `int`\]=***`None`***, **`layers`**:`Collection`\[`int`\]=***`None`***, **`ps`**:`Collection`\[`float`\]=***`None`***, **`emb_drop`**:`float`=***`0.0`***, **`y_range`**:`OptRange`=***`None`***, **`use_bn`**:`bool`=***`True`***, **`bn_final`**:`bool`=***`False`***, **\*\*`learn_kwargs`**) → [`Learner`](/basic_train.html#Learner)

Create a [`Learner`](/basic_train.html#Learner) for collaborative filtering on `data`: an `EmbeddingDotBias` model with `n_factors` if `use_nn=False`, an `EmbeddingNN` with `emb_szs` and `layers` otherwise.
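The `use_nn` flag is what switches between the two architectures documented above. A conceptual sketch of that dispatch, with hypothetical stand-in classes (not the fastai code):

```python
class DotModel:
    """Stand-in for the dot-product model built from n_factors."""
    def __init__(self, n_factors):
        self.kind, self.n_factors = "dot", n_factors

class NNModel:
    """Stand-in for the neural-net model built from layer sizes."""
    def __init__(self, layers):
        self.kind, self.layers = "nn", layers

def make_collab_model(n_factors=None, use_nn=False, layers=None):
    # Mirrors collab_learner's choice between the two architectures;
    # the defaults here are illustrative, not fastai's.
    if use_nn:
        return NNModel(layers or [256, 128])
    return DotModel(n_factors or 50)

print(make_collab_model(n_factors=40).kind)              # dot-product model
print(make_collab_model(use_nn=True, layers=[64]).kind)  # neural-net model
```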
### `class` `CollabLine`

`CollabLine`(**`cats`**, **`conts`**, **`classes`**, **`names`**) :: [`TabularLine`](/tabular.data.html#TabularLine)

Subclass of [`TabularLine`](/tabular.data.html#TabularLine) for collaborative filtering.
### `class` `CollabList`

`CollabList`(**`items`**:`Iterator`\[`T_co`\], **`cat_names`**:`OptStrList`=***`None`***, **`cont_names`**:`OptStrList`=***`None`***, **`procs`**=***`None`***, **\*\*`kwargs`**) → `TabularList` :: [`TabularList`](/tabular.data.html#TabularList)

Subclass of [`TabularList`](/tabular.data.html#TabularList) for collaborative filtering.
### `forward`

`forward`(**`users`**:`LongTensor`, **`items`**:`LongTensor`) → `Tensor`

Compute the predicted ratings for a batch of `users` and `items`.
### `reconstruct`

`reconstruct`(**`t`**:`Tensor`)
### `forward`

`forward`(**`users`**:`LongTensor`, **`items`**:`LongTensor`) → `Tensor`

Compute the predicted ratings for a batch of `users` and `items`.