{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# RePlay Tutorial\n", "This notebook introduces the RePlay library and covers:\n", "- data preprocessing\n", "- data splitting\n", "- model training and inference\n", "- hyperparameter optimization\n", "- model saving and loading\n", "- model comparison" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-02-10T16:01:45.639135Z", "start_time": "2020-02-10T16:01:45.612577Z" }, "jupyter": { "outputs_hidden": false }, "pycharm": { "is_executing": false } }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%config Completer.use_jedi = False" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from optuna.exceptions import ExperimentalWarning\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "warnings.filterwarnings(\"ignore\", category=ExperimentalWarning)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from pyspark.sql.functions import rand\n", "\n", "from replay.data_preparator import DataPreparator\n", "from replay.experiment import Experiment\n", "from replay.metrics import Coverage, HitRate, NDCG, MAP\n", "from replay.model_handler import save, load\n", "from replay.models import ALSWrap, KNN, SLIM\n", "from replay.session_handler import State\n", "from replay.splitters import UserSplitter\n", "from replay.utils import convert2spark" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# K: number of recommendations per user, also used as the metric cutoff\n", "K = 5\n", "SEED = 1234" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Data preprocessing\n", "We use the MovieLens 1M dataset as an example." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2020-02-10T15:59:42.041251Z", "start_time": "2020-02-10T15:59:09.230636Z" }, "jupyter": { "outputs_hidden": false }, "scrolled": true }, "outputs": [], "source": [ "# tab-separated MovieLens 1M interaction log and user features\n", "df = pd.read_csv(\"data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\", \"relevance\", \"timestamp\"])\n", "users = pd.read_csv(\"data/ml1m_users.dat\", sep=\"\\t\", names=[\"user_id\", \"gender\", \"age\", \"occupation\", \"zip_code\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 0.1. DataPreparator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RePlay uses Spark dataframes as its internal data format.\n", "You can pass either a Spark or a pandas dataframe as input. The interaction matrix requires the ``user_id`` and ``item_id`` columns; the ``relevance`` and interaction ``timestamp`` columns are optional.\n", "\n", "The ``DataPreparator`` class converts dataframes to Spark format and preprocesses the data: it renames or creates the required and optional interaction matrix columns, checks for nulls and parses dates. In the prepared log, users and items are identified by the ``user_idx`` and ``item_idx`` columns, which the rest of this notebook uses.\n", "\n", "To convert a pandas dataframe to Spark as is, use the ``convert2spark`` function from ``replay.utils``." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: An illegal reflective access operation has occurred\n", "WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/home/u19893556/miniconda3/envs/replay/lib/python3.7/site-packages/pyspark/jars/spark-unsafe_2.12-3.1.2.jar) to constructor java.nio.DirectByteBuffer(long,int)\n", "WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n", "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n", "WARNING: All illegal access operations will be denied in a future release\n", "22/02/27 23:04:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "22/02/27 23:04:23 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).\n", "22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n", "22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n", "22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4042. 
Attempting port 4043.\n", " \r" ] } ], "source": [ "preparator = DataPreparator()\n", "log, _, _ = preparator(df)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------+---------+--------+--------+\n", "|relevance|timestamp|user_idx|item_idx|\n", "+---------+---------+--------+--------+\n", "| 5|978300760| 4131| 43|\n", "| 3|978302109| 4131| 585|\n", "| 3|978301968| 4131| 461|\n", "+---------+---------+--------+--------+\n", "only showing top 3 rows\n", "\n" ] } ], "source": [ "log.show(3)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-------+------+---+----------+--------+\n", "|user_id|gender|age|occupation|zip_code|\n", "+-------+------+---+----------+--------+\n", "| 1| F| 1| 10| 48067|\n", "| 2| M| 56| 16| 70072|\n", "| 3| M| 25| 15| 55117|\n", "+-------+------+---+----------+--------+\n", "only showing top 3 rows\n", "\n" ] } ], "source": [ "users = convert2spark(users)\n", "users.show(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 0.2. Split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RePlay provides data splitters that reproduce validation schemes widely used in recommender systems.\n", "\n", "`UserSplitter` moves ``item_test_size`` interactions of each of ``user_test_size`` users to the test dataset; here 500 users with 5 interactions each give the 2500 test rows printed below. With ``shuffle=True`` the test interactions are picked at random rather than as the latest ones by timestamp, and ``drop_cold_users``/``drop_cold_items`` remove test users and items that do not appear in the training part." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2020-02-10T15:59:50.986401Z", "start_time": "2020-02-10T15:59:42.042998Z" }, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/02/27 23:04:37 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n", "22/02/27 23:04:38 WARN WindowExec: No Partition Defined for Window operation! 
Moving all data to a single partition, this can cause serious performance degradation.\n", "[Stage 27:============================================> (121 + 23) / 144]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "997709 2500\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 30:===================================> (95 + 48) / 144]\r", "\r", " \r" ] } ], "source": [ "splitter = UserSplitter(\n", "    drop_cold_items=True,\n", "    drop_cold_users=True,\n", "    item_test_size=K,\n", "    user_test_size=500,\n", "    seed=SEED,\n", "    shuffle=True,\n", ")\n", "train, test = splitter.split(log)\n", "print(train.count(), test.count())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Model training\n", "\n", "#### SLIM\n", "Sparse Linear Methods, an item-based model that learns a sparse item-item weight matrix." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "slim = SLIM(seed=SEED)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.53 s, sys: 129 ms, total: 1.66 s\n", "Wall time: 5.9 s\n" ] } ], "source": [ "%%time\n", "\n", "slim.fit(log=train)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "27-Feb-22 23:04:55, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:04:55, replay, WARNING: This model can't predict cold items, they will be ignored\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 23.1 ms, sys: 16.4 ms, total: 39.4 ms\n", "Wall time: 1.77 s\n" ] } ], "source": [ "%%time\n", "\n", "# top-K recommendations for the test users; filter_seen_items=True excludes items the user already interacted with in the train log\n", "recs = slim.predict(\n", "    k=K,\n", "    users=test.select('user_idx').distinct(),\n", "    log=train,\n", "    filter_seen_items=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", 
"text": [ "[Stage 130:==================================> (94 + 48) / 144]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------+--------+------------------+\n", "|user_idx|item_idx| relevance|\n", "+--------+--------+------------------+\n", "| 38| 73| 1.235672623556484|\n", "| 38| 361|1.1715979128347436|\n", "+--------+--------+------------------+\n", "only showing top 2 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "recs.show(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RePlay implements popular quality metrics for recommender systems. You can compute each metric on its own, or use the ``Experiment`` class to calculate a chosen set of metrics for several models and compare them.\n", "\n", "``Experiment`` takes the test set and a dictionary that maps each metric to its cutoff ``k`` or to a list of cutoffs, so ``HitRate`` below is computed at both 1 and ``K``." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2020-02-10T16:07:28.942205Z", "start_time": "2020-02-10T16:07:26.281475Z" }, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "metrics = Experiment(test, {\n", "    NDCG(): K,\n", "    MAP(): K,\n", "    HitRate(): [1, K],\n", "    Coverage(train): K,\n", "})" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 360 ms, sys: 75.5 ms, total: 436 ms\n", "Wall time: 47.5 s\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>Coverage@5</th>\n", "      <th>HitRate@1</th>\n", "      <th>HitRate@5</th>\n", "      <th>MAP@5</th>\n", "      <th>NDCG@5</th>\n",
"    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>SLIM</th>\n", "      <td>0.16055</td>\n", "      <td>0.242</td>\n", "      <td>0.558</td>\n", "      <td>0.09372</td>\n", "      <td>0.165643</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n", "SLIM 0.16055 0.242 0.558 0.09372 0.165643" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "metrics.add_result(\"SLIM\", recs)\n", "metrics.results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.1 Search" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/02/27 23:06:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n", "22/02/27 23:06:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n" ] } ], "source": [ "# split the training data again into train and validation parts for the hyperparameter search,\n", "# so that the test set is not used for tuning\n", "train_opt, val_opt = splitter.split(train)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m[I 2022-02-27 23:06:17,681]\u001b[0m A new study created in memory with name: no-name-b0d54335-8d37-401f-a916-3ba55ed9c932\u001b[0m\n", "27-Feb-22 23:06:22, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:06:22, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:06:51,535]\u001b[0m Trial 0 finished with value: 0.18130037719542139 and parameters: {'beta': 0.01, 'lambda_': 0.01}. 
Best is trial 0 with value: 0.18130037719542139.\u001b[0m\n", "22/02/27 23:06:51 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:06:51 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:06:54, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:06:54, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:07:31,090]\u001b[0m Trial 1 finished with value: 0.18197356840108678 and parameters: {'beta': 0.003401392505408624, 'lambda_': 0.002240239840999655}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:07:31 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:07:31 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:07:33, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:07:33, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:07:56,590]\u001b[0m Trial 2 finished with value: 0.10199049759426765 and parameters: {'beta': 1.9301997111553214e-05, 'lambda_': 1.1554917603144903}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:07:56 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:07:56 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:07:58, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:07:59, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:08:36,960]\u001b[0m Trial 3 finished with value: 0.18040798348695616 and parameters: {'beta': 1.153628706350771e-05, 'lambda_': 2.9530757569977826e-05}. 
Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:08:36 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:08:36 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:08:39, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:08:39, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:08:54,180]\u001b[0m Trial 4 finished with value: 0.1184952197160257 and parameters: {'beta': 0.0007214008774259759, 'lambda_': 0.7632771957535475}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:08:54 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:08:54 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:08:56, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:08:56, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:09:11,341]\u001b[0m Trial 5 finished with value: 0.10801852484092478 and parameters: {'beta': 0.003501448697693071, 'lambda_': 0.9936237326658697}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:09:11 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:09:11 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:09:13, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:09:13, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:09:37,837]\u001b[0m Trial 6 finished with value: 0.17982896295330483 and parameters: {'beta': 0.00099662876434958, 'lambda_': 0.03255064745931469}. 
Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:09:37 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:09:37 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:09:40, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:09:40, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:10:08,671]\u001b[0m Trial 7 finished with value: 0.1794415531942852 and parameters: {'beta': 0.0002421671516396994, 'lambda_': 6.591737385850111e-05}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:10:08 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:10:08 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:10:11, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:10:11, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:10:32,795]\u001b[0m Trial 8 finished with value: 0.17158421706899432 and parameters: {'beta': 2.982305555199248, 'lambda_': 0.047378561915999574}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:10:32 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:10:32 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:10:35, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:10:35, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:10:50,730]\u001b[0m Trial 9 finished with value: 0.11605432663219907 and parameters: {'beta': 0.18611456836257362, 'lambda_': 0.8088532397607969}. 
Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n", "22/02/27 23:10:50 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:10:50 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:10:53, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:10:53, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:11:18,939]\u001b[0m Trial 10 finished with value: 0.184208319077201 and parameters: {'beta': 0.1399219194028095, 'lambda_': 1.0774584742955482e-06}. Best is trial 10 with value: 0.184208319077201.\u001b[0m\n", "22/02/27 23:11:18 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:11:18 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:11:21, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:11:21, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:11:46,639]\u001b[0m Trial 11 finished with value: 0.1847990350329721 and parameters: {'beta': 0.11351011099824757, 'lambda_': 2.678667716748947e-06}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n", "22/02/27 23:11:46 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:11:46 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:11:49, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:11:49, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:12:14,589]\u001b[0m Trial 12 finished with value: 0.18428506912745 and parameters: {'beta': 0.14719744446335933, 'lambda_': 1.5136391124700838e-06}. 
Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n", "22/02/27 23:12:14 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:12:14 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:12:16, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:12:16, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:12:44,060]\u001b[0m Trial 13 finished with value: 0.1847990350329721 and parameters: {'beta': 0.11317072609268032, 'lambda_': 1.963819303556553e-06}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n", "22/02/27 23:12:44 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:12:44 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:12:46, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:12:46, replay, WARNING: This model can't predict cold items, they will be ignored\n", "\u001b[32m[I 2022-02-27 23:13:10,407]\u001b[0m Trial 14 finished with value: 0.17283814674301312 and parameters: {'beta': 4.697217749389299, 'lambda_': 2.3961336617398848e-05}. 
Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 24.9 s, sys: 2.73 s, total: 27.6 s\n", "Wall time: 6min 52s\n" ] } ], "source": [ "%%time\n", "# budget is the number of optuna trials; the criterion (NDCG@K on val_opt) is maximized\n", "best_params = slim.optimize(train_opt, val_opt, criterion=NDCG(), k=K, budget=15)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'beta': 0.11351011099824757, 'lambda_': 2.678667716748947e-06}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2 Compare with the default parameters" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def fit_predict_evaluate(model, experiment, name):\n", "    \"\"\"Fit the model on train, predict top-K items for the test users and add the metrics to the experiment.\"\"\"\n", "    model.fit(log=train)\n", "\n", "    recs = model.predict(\n", "        k=K,\n", "        users=test.select('user_idx').distinct(),\n", "        log=train,\n", "        filter_seen_items=True,\n", "    )\n", "\n", "    experiment.add_result(name, recs)\n", "    return recs" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/02/27 23:13:15 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:13:15 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:13:18, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:13:18, replay, WARNING: This model can't predict cold items, they will be ignored\n", " 4]]]]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.9 s, sys: 284 ms, total: 2.18 s\n", "Wall time: 52.9 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 2059:==========================================> (120 + 24) / 144]\r", "\r", " \r" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>Coverage@5</th>\n", "      <th>HitRate@1</th>\n", "      <th>HitRate@5</th>\n", "      <th>MAP@5</th>\n", "      <th>NDCG@5</th>\n",
"    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>SLIM_optimized</th>\n", "      <td>0.147598</td>\n", "      <td>0.240</td>\n", "      <td>0.570</td>\n", "      <td>0.095547</td>\n", "      <td>0.168684</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>SLIM</th>\n", "      <td>0.160550</td>\n", "      <td>0.242</td>\n", "      <td>0.558</td>\n", "      <td>0.093720</td>\n", "      <td>0.165643</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n", "SLIM_optimized 0.147598 0.240 0.570 0.095547 0.168684\n", "SLIM 0.160550 0.242 0.558 0.093720 0.165643" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, 'SLIM_optimized')\n", "metrics.results.sort_values('NDCG@5', ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Convert recommendations to pandas" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 4]\r" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>user_idx</th>\n", "      <th>item_idx</th>\n", "      <th>relevance</th>\n",
"    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>38</td>\n", "      <td>73</td>\n", "      <td>1.230351</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>38</td>\n", "      <td>361</td>\n", "      <td>1.212302</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " user_idx item_idx relevance\n", "0 38 73 1.230351\n", "1 38 361 1.212302" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "recs_pd = recs.toPandas()\n", "recs_pd.head(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Save and load" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RePlay lets you save and load fitted models with the `save` and `load` functions of the `model_handler` module. A model is saved as a folder containing all the necessary parameters and data." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "save(slim, path='./slim_best_params')\n", "slim_loaded = load('./slim_best_params')" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "27-Feb-22 23:14:31, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:14:31, replay, WARNING: This model can't predict cold items, they will be ignored\n", "[Stage 2161:================================> (90 + 48) / 144]4]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------+--------+------------------+\n", "|user_idx|item_idx| relevance|\n", "+--------+--------+------------------+\n", "| 38| 14|1.1936188460415138|\n", "| 38| 73|1.1193345759515603|\n", "+--------+--------+------------------+\n", "only showing top 2 rows\n", "\n", "CPU times: user 67 ms, sys: 3.66 ms, total: 70.7 ms\n", "Wall time: 13.1 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 2161:==========================================> (120 + 24) / 144]\r", "\r", " \r" ] } ], "source": [ "%%time\n", "pred_from_loaded = slim_loaded.predict(\n", "    k=K,\n", "    users=test.select('user_idx').distinct(),\n", "    log=train,\n", "    filter_seen_items=True,\n", ")\n", "pred_from_loaded.show(2)" ] }, { "cell_type": 
"code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.11351011099824757, 2.678667716748947e-06)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the loaded model keeps its hyperparameters\n", "slim_loaded.beta, slim_loaded.lambda_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Other RePlay models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### ALS\n", "Alternating Least Squares, a commonly used matrix factorization algorithm." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/02/27 23:14:46 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:14:46 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:14:50 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS\n", "22/02/27 23:14:50 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS\n", "22/02/27 23:14:50 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK\n", "22/02/27 23:14:50 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK\n", "22/02/27 23:15:01 WARN DAGScheduler: Broadcasting large task binary with size 1004.5 KiB\n", "22/02/27 23:15:02 WARN DAGScheduler: Broadcasting large task binary with size 1047.0 KiB\n", "22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1005.4 KiB\n", "22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1089.4 KiB\n", "22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1047.8 KiB\n", "22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1131.8 KiB\n", "22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1090.3 KiB\n", "22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1174.3 KiB\n", "22/02/27 23:15:05 WARN 
DAGScheduler: Broadcasting large task binary with size 1132.7 KiB\n", "22/02/27 23:15:05 WARN DAGScheduler: Broadcasting large task binary with size 1216.7 KiB\n", "22/02/27 23:15:06 WARN DAGScheduler: Broadcasting large task binary with size 1175.2 KiB\n", "22/02/27 23:15:06 WARN DAGScheduler: Broadcasting large task binary with size 1259.2 KiB\n", "22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1217.6 KiB\n", "22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1260.6 KiB\n", "22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1218.2 KiB\n", "27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold users, they will be ignored\n", "27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold users, they will be ignored\n", "27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold items, they will be ignored\n", "22/02/27 23:15:11 WARN DAGScheduler: Broadcasting large task binary with size 1270.4 KiB\n", "22/02/27 23:15:11 WARN DAGScheduler: Broadcasting large task binary with size 1227.9 KiB\n", "22/02/27 23:15:26 WARN DAGScheduler: Broadcasting large task binary with size 1398.7 KiB\n", "22/02/27 23:15:29 WARN DAGScheduler: Broadcasting large task binary with size 1423.3 KiB\n", "22/02/27 23:15:32 WARN DAGScheduler: Broadcasting large task binary with size 1476.2 KiB\n", "22/02/27 23:15:35 WARN DAGScheduler: Broadcasting large task binary with size 1457.6 KiB\n", "22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1464.0 KiB\n", "22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1482.1 KiB\n", "22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1479.0 KiB\n", "22/02/27 23:15:37 WARN DAGScheduler: Broadcasting large task binary with size 1227.7 KiB\n", "22/02/27 23:15:37 WARN 
DAGScheduler: Broadcasting large task binary with size 1270.1 KiB\n", "22/02/27 23:15:55 WARN DAGScheduler: Broadcasting large task binary with size 1461.2 KiB\n", "22/02/27 23:15:57 WARN DAGScheduler: Broadcasting large task binary with size 1442.7 KiB\n", "22/02/27 23:15:58 WARN DAGScheduler: Broadcasting large task binary with size 1457.8 KiB\n", "22/02/27 23:15:58 WARN DAGScheduler: Broadcasting large task binary with size 1460.8 KiB\n", "22/02/27 23:16:00 WARN DAGScheduler: Broadcasting large task binary with size 1270.4 KiB\n", "22/02/27 23:16:00 WARN DAGScheduler: Broadcasting large task binary with size 1227.9 KiB\n", "22/02/27 23:16:16 WARN DAGScheduler: Broadcasting large task binary with size 1398.7 KiB\n", "22/02/27 23:16:19 WARN DAGScheduler: Broadcasting large task binary with size 1423.3 KiB\n", "22/02/27 23:16:23 WARN DAGScheduler: Broadcasting large task binary with size 1476.2 KiB\n", "22/02/27 23:16:25 WARN DAGScheduler: Broadcasting large task binary with size 1457.6 KiB\n", "22/02/27 23:16:26 WARN DAGScheduler: Broadcasting large task binary with size 1469.2 KiB\n", "22/02/27 23:16:29 WARN DAGScheduler: Broadcasting large task binary with size 1498.8 KiB\n", "22/02/27 23:16:30 WARN DAGScheduler: Broadcasting large task binary with size 1498.8 KiB\n", "22/02/27 23:16:30 WARN DAGScheduler: Broadcasting large task binary with size 1498.7 KiB\n", "22/02/27 23:16:31 WARN DAGScheduler: Broadcasting large task binary with size 1498.7 KiB\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 437 ms, sys: 130 ms, total: 566 ms\n", "Wall time: 1min 44s\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>Coverage@5</th>\n", "      <th>HitRate@1</th>\n", "      <th>HitRate@5</th>\n", "      <th>MAP@5</th>\n", "      <th>NDCG@5</th>\n",
"    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>SLIM_optimized</th>\n", "      <td>0.147598</td>\n", "      <td>0.240</td>\n", "      <td>0.570</td>\n", "      <td>0.095547</td>\n", "      <td>0.168684</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>SLIM</th>\n", "      <td>0.160550</td>\n", "      <td>0.242</td>\n", "      <td>0.558</td>\n", "      <td>0.093720</td>\n", "      <td>0.165643</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>ALS</th>\n", "      <td>0.195359</td>\n", "      <td>0.216</td>\n", "      <td>0.540</td>\n", "      <td>0.091600</td>\n", "      <td>0.160843</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n", "SLIM_optimized 0.147598 0.240 0.570 0.095547 0.168684\n", "SLIM 0.160550 0.242 0.558 0.093720 0.165643\n", "ALS 0.195359 0.216 0.540 0.091600 0.160843" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# rank is the number of latent factors\n", "recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, 'ALS')\n", "metrics.results.sort_values('NDCG@5', ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### KNN\n", "A commonly used item-based nearest-neighbours recommender." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/02/27 23:16:31 WARN CacheManager: Asked to cache already cached data.\n", "22/02/27 23:16:31 WARN CacheManager: Asked to cache already cached data.\n", "27-Feb-22 23:16:33, replay, WARNING: This model can't predict cold items, they will be ignored\n", "27-Feb-22 23:16:33, replay, WARNING: This model can't predict cold items, they will be ignored\n", " 144]]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 283 ms, sys: 90.3 ms, total: 374 ms\n", "Wall time: 1min 7s\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n", "SLIM_optimized 0.147598 0.240 0.570 0.095547 0.168684\n", "SLIM 0.160550 0.242 0.558 0.093720 0.165643\n", "ALS 0.195359 0.216 0.540 0.091600 0.160843\n", "KNN 0.052348 0.144 0.384 0.054447 0.101923" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "recs = fit_predict_evaluate(KNN(num_neighbours=100), metrics, 'KNN')\n", "metrics.results.sort_values('NDCG@5', ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Compare RePlay models with others\n", "To evaluate recommendations obtained from other sources, load them into a dataframe in the same format as RePlay recommendations (``user_id``, ``item_id``, ``relevance``) and pass it to ``Experiment.add_result``. Here, for illustration, we pass ``recs`` from the KNN cell above, so the ``my_model`` row matches ``KNN``." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n", "SLIM_optimized 0.147598 0.240 0.570 0.095547 0.168684\n", "SLIM 0.160550 0.242 0.558 0.093720 0.165643\n", "ALS 0.195359 0.216 0.540 0.091600 0.160843\n", "KNN 0.052348 0.144 0.384 0.054447 0.101923\n", "my_model 0.052348 0.144 0.384 0.054447 0.101923" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics.add_result(\"my_model\", recs)\n", "metrics.results.sort_values(\"NDCG@5\", ascending=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" }, "name": "movielens_nmf.ipynb", "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [ "null" ] } } }, "nbformat": 4, "nbformat_minor": 4 }