{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Collaborative filtering" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This package contains the functions needed to quickly train a model for a collaborative filtering task. Let's start by importing everything we'll need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.collab import * " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Collaborative filtering is the task of predicting how much a user is going to like a certain item. The fastai library contains a [`CollabDataBunch`](/collab.html#CollabDataBunch) class that will help you create datasets suitable for training, and a function `collab_learner` to build a simple model directly from a ratings table. Let's first see how we can get started before delving into the documentation.\n", "\n", "For our example, we'll use a small subset of the [MovieLens](https://grouplens.org/datasets/movielens/) dataset. The goal is to predict the rating a user gave a given movie (from 0 to 5). The data comes as a CSV file in which each line is the rating of a movie by a given person." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | userId | movieId | rating | timestamp |\n", "
|---|---|---|---|---|\n", "
| 0 | 73 | 1097 | 4.0 | 1255504951 |\n", "
| 1 | 561 | 924 | 3.5 | 1172695223 |\n", "
| 2 | 157 | 260 | 3.5 | 1291598691 |\n", "
| 3 | 358 | 1210 | 5.0 | 957481884 |\n", "
| 4 | 130 | 316 | 2.0 | 1138999234 |\n", "
`class` `CollabDataBunch`(`train_dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `valid_dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `test_dl`:`Optional`\\[[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)\\]=`None`, `device`:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=`None`, `tfms`:`Optional`\\[`Collection`\\[`Callable`\\]\\]=`None`, `path`:`PathOrStr`=`'.'`, `collate_fn`:`Callable`=`'data_collate'`) :: [`DataBunch`](/basic_data.html#DataBunch)"
],
"text/plain": [
"`from_df`(`ratings`:`DataFrame`, `pct_val`:`float`=`0.2`, `user_name`:`Optional`\\[`str`\\]=`None`, `item_name`:`Optional`\\[`str`\\]=`None`, `rating_name`:`Optional`\\[`str`\\]=`None`, `test`:`DataFrame`=`None`, `seed`=`None`, `**kwargs`)"
],
"text/plain": [
"`class` `EmbeddingDotBias`(`n_factors`:`int`, `n_users`:`int`, `n_items`:`int`, `y_range`:`Point`=`None`) :: [`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)"
],
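The signature above does not show what the model computes. As a rough illustration only, the sketch below assumes the usual dot-product-plus-bias formulation of collaborative filtering: each user and item gets a vector of `n_factors` latent factors and a scalar bias, the score is their dot product plus both biases, and, when `y_range` is given, a sigmoid rescales the score into that range. This is plain Python rather than fastai code. The class name `DotBiasSketch`, the `seed` argument, and the initialisation scale are hypothetical choices made for the example.

```python
import math
import random


class DotBiasSketch:
    """Toy dot-product-plus-bias scorer (illustrative only, not fastai code)."""

    def __init__(self, n_factors, n_users, n_items, y_range=None, seed=0):
        rng = random.Random(seed)  # hypothetical seed argument, for repeatability

        def table(rows, cols):
            # One row of small random values per user or item.
            return [[rng.gauss(0.0, 0.01) for _ in range(cols)] for _ in range(rows)]

        self.u_weight = table(n_users, n_factors)  # user latent factors
        self.i_weight = table(n_items, n_factors)  # item latent factors
        self.u_bias = [0.0] * n_users              # per-user offset
        self.i_bias = [0.0] * n_items              # per-item offset
        self.y_range = y_range

    def forward(self, users, items):
        """Score each (user, item) index pair; inputs are equal-length lists."""
        scores = []
        for u, i in zip(users, items):
            dot = sum(a * b for a, b in zip(self.u_weight[u], self.i_weight[i]))
            raw = dot + self.u_bias[u] + self.i_bias[i]
            if self.y_range is not None:
                low, high = self.y_range
                # The sigmoid maps to (0, 1); rescale that to (low, high).
                raw = low + (high - low) / (1.0 + math.exp(-raw))
            scores.append(raw)
        return scores


model = DotBiasSketch(n_factors=3, n_users=4, n_items=5, y_range=(0.0, 5.0))
preds = model.forward(users=[0, 1, 3], items=[2, 2, 4])
```

With near-zero initial weights every prediction starts close to the midpoint of `y_range`; training would then move the factors and biases toward the observed ratings.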
"text/plain": [
"`collab_learner`(`data`, `n_factors`:`int`=`None`, `use_nn`:`bool`=`False`, `metrics`=`None`, `emb_szs`:`Dict`\\[`str`, `int`\\]=`None`, `wd`:`float`=`0.01`, `**kwargs`) → [`Learner`](/basic_train.html#Learner)"
],
"text/plain": [
"`class` `CollabLine`(`cats`, `conts`, `classes`, `names`) :: [`TabularLine`](/tabular.data.html#TabularLine)"
],
"text/plain": [
"`class` `CollabList`(`items`:`Iterator`, `cat_names`:`OptStrList`=`None`, `cont_names`:`OptStrList`=`None`, `procs`=`None`, `**kwargs`) → `TabularList` :: [`TabularList`](/tabular.data.html#TabularList)"
],
"text/plain": [
"`forward`(`users`:`LongTensor`, `items`:`LongTensor`) → `Tensor`\n",
"\n",
"Defines the computation performed at every call. Should be overridden by all subclasses.\n",
"\n",
".. note::\n",
" Although the recipe for forward pass needs to be defined within\n",
" this function, one should call the :class:`Module` instance afterwards\n",
" instead of this since the former takes care of running the\n",
" registered hooks while the latter silently ignores them. "
],
"text/plain": [
"