{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import * # Quick access to most common functionality\n", "from fastai.collab import * # Quick access to collab filtering functionality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collaborative filtering example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`collab` models use data in a `DataFrame` of users, items, and ratings." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/data1/jhoward/git/fastai/fastai/../data/movie_lens_sample')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.ML_SAMPLE)\n", "path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>userId</th>\n",
"      <th>movieId</th>\n",
"      <th>rating</th>\n",
"      <th>timestamp</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>73</td>\n",
"      <td>1097</td>\n",
"      <td>4.0</td>\n",
"      <td>1255504951</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>561</td>\n",
"      <td>924</td>\n",
"      <td>3.5</td>\n",
"      <td>1172695223</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>157</td>\n",
"      <td>260</td>\n",
"      <td>3.5</td>\n",
"      <td>1291598691</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>358</td>\n",
"      <td>1210</td>\n",
"      <td>5.0</td>\n",
"      <td>957481884</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>130</td>\n",
"      <td>316</td>\n",
"      <td>2.0</td>\n",
"      <td>1138999234</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
], "text/plain": [ " userId movieId rating timestamp\n", "0 73 1097 4.0 1255504951\n", "1 561 924 3.5 1172695223\n", "2 157 260 3.5 1291598691\n", "3 358 1210 5.0 957481884\n", "4 130 316 2.0 1138999234" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratings = pd.read_csv(path/'ratings.csv')\n", "series2cat(ratings, 'userId', 'movieId')\n", "ratings.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's all we need to create and train a model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=4), HTML(value='0.00% [0/4 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:04\n", "epoch train loss valid loss\n", "0 2.214395 1.604201 (00:01)\n", "1 1.006937 0.719938 (00:01)\n", "2 0.704926 0.713904 (00:01)\n", "3 0.600082 0.709458 (00:01)\n", "\n" ] } ], "source": [ "learn = get_collab_learner(ratings, n_factors=50, min_score=0., max_score=5.)\n", "learn.fit_one_cycle(4, 5e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }