{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RePlay Tutorial\n",
"This notebook introduces the RePlay library and walks through:\n",
"- data preprocessing\n",
"- data splitting\n",
"- model training and inference\n",
"- model optimization\n",
"- model saving and loading\n",
"- model comparison"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-10T16:01:45.639135Z",
"start_time": "2020-02-10T16:01:45.612577Z"
},
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%config Completer.use_jedi = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"from optuna.exceptions import ExperimentalWarning\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
"warnings.filterwarnings(\"ignore\", category=ExperimentalWarning)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pyspark.sql.functions import rand\n",
"\n",
"from replay.data_preparator import DataPreparator\n",
"from replay.experiment import Experiment\n",
"from replay.metrics import Coverage, HitRate, NDCG, MAP\n",
"from replay.model_handler import save, load\n",
"from replay.models import ALSWrap, KNN, SLIM\n",
"from replay.session_handler import State\n",
"from replay.splitters import UserSplitter\n",
"from replay.utils import convert2spark"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"K = 5\n",
"SEED=1234"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Data preprocessing \n",
"We will use the MovieLens 1M dataset as an example."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-10T15:59:42.041251Z",
"start_time": "2020-02-10T15:59:09.230636Z"
},
"jupyter": {
"outputs_hidden": false
},
"scrolled": true
},
"outputs": [],
"source": [
"df = pd.read_csv(\"data/ml1m_ratings.dat\", sep=\"\\t\", names=[\"user_id\", \"item_id\", \"relevance\", \"timestamp\"])\n",
"users = pd.read_csv(\"data/ml1m_users.dat\", sep=\"\\t\", names=[\"user_id\", \"gender\", \"age\", \"occupation\", \"zip_code\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 0.1. DataPreparator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RePlay's internal data format is a Spark dataframe.\n",
"You can pass either a Spark or a pandas dataframe as input. The columns ``user_id`` and ``item_id`` are required for the interaction matrix.\n",
"The optional interaction matrix columns are ``relevance`` and the interaction ``timestamp``.\n",
"\n",
"The ``DataPreparator`` class converts dataframes to the Spark format and preprocesses the data: it renames or creates the required and optional interaction matrix columns, checks for nulls, and parses dates.\n",
"\n",
"To convert a pandas dataframe to Spark as is, use the function ``convert2spark`` from ``replay.utils``."
]
},
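{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal pandas-only sketch of the kind of preprocessing described above (column renaming, null check, date parsing). It is an illustration, not RePlay's implementation: the raw column names `uid`, `iid`, `rating`, `ts` and all `demo_*` variables are assumptions made up for this example. `DataPreparator` itself is applied in the following cells."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical raw log with non-standard column names (assumed for illustration).\n",
"demo_raw = pd.DataFrame({\n",
"    'uid': [1, 1, 2],\n",
"    'iid': [10, 20, 10],\n",
"    'rating': [5, 3, 4],\n",
"    'ts': [978300760, 978302109, 978301968],\n",
"})\n",
"\n",
"# Rename the columns to the names expected for an interaction matrix.\n",
"demo_prepared = demo_raw.rename(\n",
"    columns={'uid': 'user_id', 'iid': 'item_id', 'rating': 'relevance', 'ts': 'timestamp'}\n",
")\n",
"\n",
"# Null check on the two required columns.\n",
"assert demo_prepared[['user_id', 'item_id']].notna().all().all()\n",
"\n",
"# Parse Unix-epoch seconds into datetimes.\n",
"demo_prepared['timestamp'] = pd.to_datetime(demo_prepared['timestamp'], unit='s')\n",
"demo_prepared.head()"
]
},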
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: An illegal reflective access operation has occurred\n",
"WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/home/u19893556/miniconda3/envs/replay/lib/python3.7/site-packages/pyspark/jars/spark-unsafe_2.12-3.1.2.jar) to constructor java.nio.DirectByteBuffer(long,int)\n",
"WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n",
"WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n",
"WARNING: All illegal access operations will be denied in a future release\n",
"22/02/27 23:04:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"22/02/27 23:04:23 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).\n",
"22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
"22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n",
"22/02/27 23:04:23 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.\n",
" \r"
]
}
],
"source": [
"preparator = DataPreparator()\n",
"log, _, _ = preparator(df)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+---------+--------+--------+\n",
"|relevance|timestamp|user_idx|item_idx|\n",
"+---------+---------+--------+--------+\n",
"| 5|978300760| 4131| 43|\n",
"| 3|978302109| 4131| 585|\n",
"| 3|978301968| 4131| 461|\n",
"+---------+---------+--------+--------+\n",
"only showing top 3 rows\n",
"\n"
]
}
],
"source": [
"log.show(3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+------+---+----------+--------+\n",
"|user_id|gender|age|occupation|zip_code|\n",
"+-------+------+---+----------+--------+\n",
"| 1| F| 1| 10| 48067|\n",
"| 2| M| 56| 16| 70072|\n",
"| 3| M| 25| 15| 55117|\n",
"+-------+------+---+----------+--------+\n",
"only showing top 3 rows\n",
"\n"
]
}
],
"source": [
"users = convert2spark(users)\n",
"users.show(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 0.2. Split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RePlay provides data splitters that reproduce validation schemes widely used in recommender systems.\n",
"\n",
"`UserSplitter` moves ``item_test_size`` items for each of ``user_test_size`` users to the test dataset."
]
},
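{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the splitting scheme concrete, the cell below is a small pandas-only sketch of a per-user holdout: it picks `user_test_size` random users and moves `item_test_size` random interactions of each to the test part. The function `demo_user_split` and the toy data are hypothetical and written only for this illustration; the sketch does not reproduce `UserSplitter` internals and ignores options such as `drop_cold_items` and `drop_cold_users`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def demo_user_split(log_pd, item_test_size, user_test_size, seed):\n",
"    # Choose which users contribute interactions to the test part.\n",
"    test_users = log_pd['user_id'].drop_duplicates().sample(n=user_test_size, random_state=seed)\n",
"    candidates = log_pd[log_pd['user_id'].isin(test_users)]\n",
"    # Shuffle, then keep the first `item_test_size` rows of every selected user.\n",
"    test_part = (\n",
"        candidates.sample(frac=1, random_state=seed)\n",
"        .groupby('user_id')\n",
"        .head(item_test_size)\n",
"    )\n",
"    # Everything that was not held out stays in the train part.\n",
"    train_part = log_pd.drop(test_part.index)\n",
"    return train_part, test_part\n",
"\n",
"# Toy interaction log (assumed data): 3 users with 3 interactions each.\n",
"demo_split_log = pd.DataFrame({\n",
"    'user_id': [1, 1, 1, 2, 2, 2, 3, 3, 3],\n",
"    'item_id': [10, 20, 30, 10, 20, 40, 20, 30, 40],\n",
"})\n",
"demo_train, demo_test = demo_user_split(demo_split_log, item_test_size=1, user_test_size=2, seed=1234)\n",
"print(len(demo_train), len(demo_test))"
]
},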
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-10T15:59:50.986401Z",
"start_time": "2020-02-10T15:59:42.042998Z"
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/02/27 23:04:37 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/02/27 23:04:38 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"[Stage 27:============================================> (121 + 23) / 144]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"997709 2500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"[Stage 30:===================================> (95 + 48) / 144]\r",
"\r",
" \r"
]
}
],
"source": [
"splitter = UserSplitter(\n",
"    drop_cold_items=True,\n",
"    drop_cold_users=True,\n",
"    item_test_size=K,\n",
"    user_test_size=500,\n",
"    seed=SEED,\n",
"    shuffle=True\n",
")\n",
"train, test = splitter.split(log)\n",
"print(train.count(), test.count())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Model training\n",
"\n",
"#### SLIM"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"slim = SLIM(seed=SEED)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.53 s, sys: 129 ms, total: 1.66 s\n",
"Wall time: 5.9 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"slim.fit(log=train)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"27-Feb-22 23:04:55, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:04:55, replay, WARNING: This model can't predict cold items, they will be ignored\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 23.1 ms, sys: 16.4 ms, total: 39.4 ms\n",
"Wall time: 1.77 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"recs = slim.predict(\n",
"    k=K,\n",
"    users=test.select('user_idx').distinct(),\n",
"    log=train,\n",
"    filter_seen_items=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 130:==================================> (94 + 48) / 144]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+--------+------------------+\n",
"|user_idx|item_idx| relevance|\n",
"+--------+--------+------------------+\n",
"| 38| 73| 1.235672623556484|\n",
"| 38| 361|1.1715979128347436|\n",
"+--------+--------+------------------+\n",
"only showing top 2 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" \r"
]
}
],
"source": [
"recs.show(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Model evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RePlay implements several popular quality metrics for recommender systems. You can compute a metric on its own, or use the ``Experiment`` class to calculate a set of chosen metrics and compare models."
]
},
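{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what two of these metrics measure, the cell below sketches textbook binary-relevance HitRate@k and NDCG@k for a single user in plain Python. The helper names and the toy lists are hypothetical, and RePlay's own implementations may differ in detail (for example, in how they aggregate over users)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def demo_hit_rate_at_k(recommended, relevant, k):\n",
"    # 1.0 if at least one of the top-k recommended items is relevant, else 0.0.\n",
"    return float(any(item in relevant for item in recommended[:k]))\n",
"\n",
"def demo_ndcg_at_k(recommended, relevant, k):\n",
"    # DCG with binary gains: a hit at 0-based position `rank` contributes 1 / log2(rank + 2).\n",
"    dcg = sum(\n",
"        1.0 / math.log2(rank + 2)\n",
"        for rank, item in enumerate(recommended[:k])\n",
"        if item in relevant\n",
"    )\n",
"    # Ideal DCG: all relevant items (at most k of them) placed at the top of the list.\n",
"    ideal_hits = min(k, len(relevant))\n",
"    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))\n",
"    return dcg / idcg if idcg > 0 else 0.0\n",
"\n",
"demo_ranked = [73, 361, 14, 5, 8]  # hypothetical ranked recommendations for one user\n",
"demo_truth = {361, 8, 99}          # hypothetical held-out items of the same user\n",
"print(demo_hit_rate_at_k(demo_ranked, demo_truth, 1), demo_hit_rate_at_k(demo_ranked, demo_truth, 5))\n",
"print(demo_ndcg_at_k(demo_ranked, demo_truth, 5))"
]
},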
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-10T16:07:28.942205Z",
"start_time": "2020-02-10T16:07:26.281475Z"
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"metrics = Experiment(test, {NDCG(): K,\n",
"                            MAP(): K,\n",
"                            HitRate(): [1, K],\n",
"                            Coverage(train): K\n",
"                           })"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 360 ms, sys: 75.5 ms, total: 436 ms\n",
"Wall time: 47.5 s\n"
]
},
{
"data": {
"text/html": [
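"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Coverage@5</th>\n",
"      <th>HitRate@1</th>\n",
"      <th>HitRate@5</th>\n",
"      <th>MAP@5</th>\n",
"      <th>NDCG@5</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>SLIM</th>\n",
"      <td>0.16055</td>\n",
"      <td>0.242</td>\n",
"      <td>0.558</td>\n",
"      <td>0.09372</td>\n",
"      <td>0.165643</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"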
],
"text/plain": [
" Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n",
"SLIM 0.16055 0.242 0.558 0.09372 0.165643"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"metrics.add_result(\"SLIM\", recs)\n",
"metrics.results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.1 Search"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/02/27 23:06:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/02/27 23:06:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
}
],
"source": [
"# split the training data again to get a validation set for hyperparameter optimization\n",
"train_opt, val_opt = splitter.split(train)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m[I 2022-02-27 23:06:17,681]\u001b[0m A new study created in memory with name: no-name-b0d54335-8d37-401f-a916-3ba55ed9c932\u001b[0m\n",
"27-Feb-22 23:06:22, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:06:22, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:06:51,535]\u001b[0m Trial 0 finished with value: 0.18130037719542139 and parameters: {'beta': 0.01, 'lambda_': 0.01}. Best is trial 0 with value: 0.18130037719542139.\u001b[0m\n",
"22/02/27 23:06:51 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:06:51 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:06:54, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:06:54, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:07:31,090]\u001b[0m Trial 1 finished with value: 0.18197356840108678 and parameters: {'beta': 0.003401392505408624, 'lambda_': 0.002240239840999655}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:07:31 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:07:31 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:07:33, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:07:33, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:07:56,590]\u001b[0m Trial 2 finished with value: 0.10199049759426765 and parameters: {'beta': 1.9301997111553214e-05, 'lambda_': 1.1554917603144903}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:07:56 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:07:56 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:07:58, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:07:59, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:08:36,960]\u001b[0m Trial 3 finished with value: 0.18040798348695616 and parameters: {'beta': 1.153628706350771e-05, 'lambda_': 2.9530757569977826e-05}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:08:36 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:08:36 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:08:39, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:08:39, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:08:54,180]\u001b[0m Trial 4 finished with value: 0.1184952197160257 and parameters: {'beta': 0.0007214008774259759, 'lambda_': 0.7632771957535475}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:08:54 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:08:54 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:08:56, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:08:56, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:09:11,341]\u001b[0m Trial 5 finished with value: 0.10801852484092478 and parameters: {'beta': 0.003501448697693071, 'lambda_': 0.9936237326658697}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:09:11 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:09:11 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:09:13, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:09:13, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:09:37,837]\u001b[0m Trial 6 finished with value: 0.17982896295330483 and parameters: {'beta': 0.00099662876434958, 'lambda_': 0.03255064745931469}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:09:37 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:09:37 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:09:40, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:09:40, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:10:08,671]\u001b[0m Trial 7 finished with value: 0.1794415531942852 and parameters: {'beta': 0.0002421671516396994, 'lambda_': 6.591737385850111e-05}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:10:08 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:10:08 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:10:11, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:10:11, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:10:32,795]\u001b[0m Trial 8 finished with value: 0.17158421706899432 and parameters: {'beta': 2.982305555199248, 'lambda_': 0.047378561915999574}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:10:32 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:10:32 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:10:35, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:10:35, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:10:50,730]\u001b[0m Trial 9 finished with value: 0.11605432663219907 and parameters: {'beta': 0.18611456836257362, 'lambda_': 0.8088532397607969}. Best is trial 1 with value: 0.18197356840108678.\u001b[0m\n",
"22/02/27 23:10:50 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:10:50 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:10:53, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:10:53, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:11:18,939]\u001b[0m Trial 10 finished with value: 0.184208319077201 and parameters: {'beta': 0.1399219194028095, 'lambda_': 1.0774584742955482e-06}. Best is trial 10 with value: 0.184208319077201.\u001b[0m\n",
"22/02/27 23:11:18 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:11:18 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:11:21, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:11:21, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:11:46,639]\u001b[0m Trial 11 finished with value: 0.1847990350329721 and parameters: {'beta': 0.11351011099824757, 'lambda_': 2.678667716748947e-06}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n",
"22/02/27 23:11:46 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:11:46 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:11:49, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:11:49, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:12:14,589]\u001b[0m Trial 12 finished with value: 0.18428506912745 and parameters: {'beta': 0.14719744446335933, 'lambda_': 1.5136391124700838e-06}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n",
"22/02/27 23:12:14 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:12:14 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:12:16, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:12:16, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:12:44,060]\u001b[0m Trial 13 finished with value: 0.1847990350329721 and parameters: {'beta': 0.11317072609268032, 'lambda_': 1.963819303556553e-06}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n",
"22/02/27 23:12:44 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:12:44 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:12:46, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:12:46, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"\u001b[32m[I 2022-02-27 23:13:10,407]\u001b[0m Trial 14 finished with value: 0.17283814674301312 and parameters: {'beta': 4.697217749389299, 'lambda_': 2.3961336617398848e-05}. Best is trial 11 with value: 0.1847990350329721.\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 24.9 s, sys: 2.73 s, total: 27.6 s\n",
"Wall time: 6min 52s\n"
]
}
],
"source": [
"%%time\n",
"best_params = slim.optimize(train_opt, val_opt, criterion=NDCG(), k=K, budget=15)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'beta': 0.11351011099824757, 'lambda_': 2.678667716748947e-06}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.2 Compare with the previous result"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def fit_predict_evaluate(model, experiment, name):\n",
"    \"\"\"Fit `model` on `train`, recommend top-K items to the test users, and log metrics under `name`.\"\"\"\n",
"    model.fit(log=train)\n",
"\n",
"    recs = model.predict(\n",
"        k=K,\n",
"        users=test.select('user_idx').distinct(),\n",
"        log=train,\n",
"        filter_seen_items=True\n",
"    )\n",
"\n",
"    experiment.add_result(name, recs)\n",
"    return recs"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/02/27 23:13:15 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:13:15 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:13:18, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:13:18, replay, WARNING: This model can't predict cold items, they will be ignored\n",
" 4]]]]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.9 s, sys: 284 ms, total: 2.18 s\n",
"Wall time: 52.9 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"[Stage 2059:==========================================> (120 + 24) / 144]\r",
"\r",
" \r"
]
},
{
"data": {
"text/html": [
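"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Coverage@5</th>\n",
"      <th>HitRate@1</th>\n",
"      <th>HitRate@5</th>\n",
"      <th>MAP@5</th>\n",
"      <th>NDCG@5</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>SLIM_optimized</th>\n",
"      <td>0.147598</td>\n",
"      <td>0.240</td>\n",
"      <td>0.570</td>\n",
"      <td>0.095547</td>\n",
"      <td>0.168684</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>SLIM</th>\n",
"      <td>0.160550</td>\n",
"      <td>0.242</td>\n",
"      <td>0.558</td>\n",
"      <td>0.093720</td>\n",
"      <td>0.165643</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"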
],
"text/plain": [
" Coverage@5 HitRate@1 HitRate@5 MAP@5 NDCG@5\n",
"SLIM_optimized 0.147598 0.240 0.570 0.095547 0.168684\n",
"SLIM 0.160550 0.242 0.558 0.093720 0.165643"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, 'SLIM_optimized')\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Convert to pandas"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4]\r"
]
},
{
"data": {
"text/html": [
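"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>user_idx</th>\n",
"      <th>item_idx</th>\n",
"      <th>relevance</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>38</td>\n",
"      <td>73</td>\n",
"      <td>1.230351</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>38</td>\n",
"      <td>361</td>\n",
"      <td>1.212302</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"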
],
"text/plain": [
" user_idx item_idx relevance\n",
"0 38 73 1.230351\n",
"1 38 361 1.212302"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"recs_pd = recs.toPandas()\n",
"recs_pd.head(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Save and load"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RePlay lets you save and load fitted models with the `save` and `load` functions of the `model_handler` module. A model is saved as a folder containing all the necessary parameters and data."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"save(slim, path='./slim_best_params')\n",
"slim_loaded = load('./slim_best_params')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"27-Feb-22 23:14:31, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:14:31, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"[Stage 2161:================================> (90 + 48) / 144]4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+--------+------------------+\n",
"|user_idx|item_idx| relevance|\n",
"+--------+--------+------------------+\n",
"| 38| 14|1.1936188460415138|\n",
"| 38| 73|1.1193345759515603|\n",
"+--------+--------+------------------+\n",
"only showing top 2 rows\n",
"\n",
"CPU times: user 67 ms, sys: 3.66 ms, total: 70.7 ms\n",
"Wall time: 13.1 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"[Stage 2161:==========================================> (120 + 24) / 144]\r",
"\r",
" \r"
]
}
],
"source": [
"%%time\n",
"pred_from_loaded = slim_loaded.predict(k=K,\n",
"                                       users=test.select('user_idx').distinct(),\n",
"                                       log=train,\n",
"                                       filter_seen_items=True)\n",
"pred_from_loaded.show(2)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.11351011099824757, 2.678667716748947e-06)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"slim_loaded.beta, slim_loaded.lambda_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Other RePlay models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### ALS\n",
"A commonly used matrix factorization algorithm."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/02/27 23:14:46 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:14:46 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:14:50 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS\n",
"22/02/27 23:14:50 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS\n",
"22/02/27 23:14:50 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK\n",
"22/02/27 23:14:50 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK\n",
"22/02/27 23:15:01 WARN DAGScheduler: Broadcasting large task binary with size 1004.5 KiB\n",
"22/02/27 23:15:02 WARN DAGScheduler: Broadcasting large task binary with size 1047.0 KiB\n",
"22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1005.4 KiB\n",
"22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1089.4 KiB\n",
"22/02/27 23:15:03 WARN DAGScheduler: Broadcasting large task binary with size 1047.8 KiB\n",
"22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1131.8 KiB\n",
"22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1090.3 KiB\n",
"22/02/27 23:15:04 WARN DAGScheduler: Broadcasting large task binary with size 1174.3 KiB\n",
"22/02/27 23:15:05 WARN DAGScheduler: Broadcasting large task binary with size 1132.7 KiB\n",
"22/02/27 23:15:05 WARN DAGScheduler: Broadcasting large task binary with size 1216.7 KiB\n",
"22/02/27 23:15:06 WARN DAGScheduler: Broadcasting large task binary with size 1175.2 KiB\n",
"22/02/27 23:15:06 WARN DAGScheduler: Broadcasting large task binary with size 1259.2 KiB\n",
"22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1217.6 KiB\n",
"22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1260.6 KiB\n",
"22/02/27 23:15:07 WARN DAGScheduler: Broadcasting large task binary with size 1218.2 KiB\n",
"27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold users, they will be ignored\n",
"27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold users, they will be ignored\n",
"27-Feb-22 23:15:08, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"22/02/27 23:15:11 WARN DAGScheduler: Broadcasting large task binary with size 1270.4 KiB\n",
"22/02/27 23:15:11 WARN DAGScheduler: Broadcasting large task binary with size 1227.9 KiB\n",
"22/02/27 23:15:26 WARN DAGScheduler: Broadcasting large task binary with size 1398.7 KiB\n",
"22/02/27 23:15:29 WARN DAGScheduler: Broadcasting large task binary with size 1423.3 KiB\n",
"22/02/27 23:15:32 WARN DAGScheduler: Broadcasting large task binary with size 1476.2 KiB\n",
"22/02/27 23:15:35 WARN DAGScheduler: Broadcasting large task binary with size 1457.6 KiB\n",
"22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1464.0 KiB\n",
"22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1482.1 KiB\n",
"22/02/27 23:15:36 WARN DAGScheduler: Broadcasting large task binary with size 1479.0 KiB\n",
"22/02/27 23:15:37 WARN DAGScheduler: Broadcasting large task binary with size 1227.7 KiB\n",
"22/02/27 23:15:37 WARN DAGScheduler: Broadcasting large task binary with size 1270.1 KiB\n",
"22/02/27 23:15:55 WARN DAGScheduler: Broadcasting large task binary with size 1461.2 KiB\n",
"22/02/27 23:15:57 WARN DAGScheduler: Broadcasting large task binary with size 1442.7 KiB\n",
"22/02/27 23:15:58 WARN DAGScheduler: Broadcasting large task binary with size 1457.8 KiB\n",
"22/02/27 23:15:58 WARN DAGScheduler: Broadcasting large task binary with size 1460.8 KiB\n",
"22/02/27 23:16:00 WARN DAGScheduler: Broadcasting large task binary with size 1270.4 KiB\n",
"22/02/27 23:16:00 WARN DAGScheduler: Broadcasting large task binary with size 1227.9 KiB\n",
"22/02/27 23:16:16 WARN DAGScheduler: Broadcasting large task binary with size 1398.7 KiB\n",
"22/02/27 23:16:19 WARN DAGScheduler: Broadcasting large task binary with size 1423.3 KiB\n",
"22/02/27 23:16:23 WARN DAGScheduler: Broadcasting large task binary with size 1476.2 KiB\n",
"22/02/27 23:16:25 WARN DAGScheduler: Broadcasting large task binary with size 1457.6 KiB\n",
"22/02/27 23:16:26 WARN DAGScheduler: Broadcasting large task binary with size 1469.2 KiB\n",
"22/02/27 23:16:29 WARN DAGScheduler: Broadcasting large task binary with size 1498.8 KiB\n",
"22/02/27 23:16:30 WARN DAGScheduler: Broadcasting large task binary with size 1498.8 KiB\n",
"22/02/27 23:16:30 WARN DAGScheduler: Broadcasting large task binary with size 1498.7 KiB\n",
"22/02/27 23:16:31 WARN DAGScheduler: Broadcasting large task binary with size 1498.7 KiB\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 437 ms, sys: 130 ms, total: 566 ms\n",
"Wall time: 1min 44s\n"
]
},
{
"data": {
"text/html": [
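"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Coverage@5</th>\n",
"      <th>HitRate@1</th>\n",
"      <th>HitRate@5</th>\n",
"      <th>MAP@5</th>\n",
"      <th>NDCG@5</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>SLIM_optimized</th>\n",
"      <td>0.147598</td>\n",
"      <td>0.240</td>\n",
"      <td>0.570</td>\n",
"      <td>0.095547</td>\n",
"      <td>0.168684</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>SLIM</th>\n",
"      <td>0.160550</td>\n",
"      <td>0.242</td>\n",
"      <td>0.558</td>\n",
"      <td>0.093720</td>\n",
"      <td>0.165643</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>ALS</th>\n",
"      <td>0.195359</td>\n",
"      <td>0.216</td>\n",
"      <td>0.540</td>\n",
"      <td>0.091600</td>\n",
"      <td>0.160843</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"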
],
"text/plain": [
"                Coverage@5  HitRate@1  HitRate@5     MAP@5    NDCG@5\n",
"SLIM_optimized    0.147598      0.240      0.570  0.095547  0.168684\n",
"SLIM              0.160550      0.242      0.558  0.093720  0.165643\n",
"ALS               0.195359      0.216      0.540  0.091600  0.160843"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, 'ALS')\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### KNN\n",
"A commonly used item-based nearest-neighbours recommender."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/02/27 23:16:31 WARN CacheManager: Asked to cache already cached data.\n",
"22/02/27 23:16:31 WARN CacheManager: Asked to cache already cached data.\n",
"27-Feb-22 23:16:33, replay, WARNING: This model can't predict cold items, they will be ignored\n",
"27-Feb-22 23:16:33, replay, WARNING: This model can't predict cold items, they will be ignored\n",
" 144]]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 283 ms, sys: 90.3 ms, total: 374 ms\n",
"Wall time: 1min 7s\n"
]
},
{
"data": {
"text/html": [
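"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Coverage@5</th>\n",
"      <th>HitRate@1</th>\n",
"      <th>HitRate@5</th>\n",
"      <th>MAP@5</th>\n",
"      <th>NDCG@5</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>SLIM_optimized</th>\n",
"      <td>0.147598</td>\n",
"      <td>0.240</td>\n",
"      <td>0.570</td>\n",
"      <td>0.095547</td>\n",
"      <td>0.168684</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>SLIM</th>\n",
"      <td>0.160550</td>\n",
"      <td>0.242</td>\n",
"      <td>0.558</td>\n",
"      <td>0.093720</td>\n",
"      <td>0.165643</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>ALS</th>\n",
"      <td>0.195359</td>\n",
"      <td>0.216</td>\n",
"      <td>0.540</td>\n",
"      <td>0.091600</td>\n",
"      <td>0.160843</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>KNN</th>\n",
"      <td>0.052348</td>\n",
"      <td>0.144</td>\n",
"      <td>0.384</td>\n",
"      <td>0.054447</td>\n",
"      <td>0.101923</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"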
],
"text/plain": [
"                Coverage@5  HitRate@1  HitRate@5     MAP@5    NDCG@5\n",
"SLIM_optimized    0.147598      0.240      0.570  0.095547  0.168684\n",
"SLIM              0.160550      0.242      0.558  0.093720  0.165643\n",
"ALS               0.195359      0.216      0.540  0.091600  0.160843\n",
"KNN               0.052348      0.144      0.384  0.054447  0.101923"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(KNN(num_neighbours=100), metrics, 'KNN')\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
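"## 6. Compare RePlay models with others\n",
"To evaluate recommendations obtained from other sources, load them and pass them to ``Experiment`` with ``add_result``. In the cell below, the KNN recommendations still stored in ``recs`` stand in for an external model, so the ``my_model`` row matches the ``KNN`` row."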
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"data": {
"text/html": [
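"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Coverage@5</th>\n",
"      <th>HitRate@1</th>\n",
"      <th>HitRate@5</th>\n",
"      <th>MAP@5</th>\n",
"      <th>NDCG@5</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>SLIM_optimized</th>\n",
"      <td>0.147598</td>\n",
"      <td>0.240</td>\n",
"      <td>0.570</td>\n",
"      <td>0.095547</td>\n",
"      <td>0.168684</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>SLIM</th>\n",
"      <td>0.160550</td>\n",
"      <td>0.242</td>\n",
"      <td>0.558</td>\n",
"      <td>0.093720</td>\n",
"      <td>0.165643</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>ALS</th>\n",
"      <td>0.195359</td>\n",
"      <td>0.216</td>\n",
"      <td>0.540</td>\n",
"      <td>0.091600</td>\n",
"      <td>0.160843</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>KNN</th>\n",
"      <td>0.052348</td>\n",
"      <td>0.144</td>\n",
"      <td>0.384</td>\n",
"      <td>0.054447</td>\n",
"      <td>0.101923</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>my_model</th>\n",
"      <td>0.052348</td>\n",
"      <td>0.144</td>\n",
"      <td>0.384</td>\n",
"      <td>0.054447</td>\n",
"      <td>0.101923</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"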
],
"text/plain": [
"                Coverage@5  HitRate@1  HitRate@5     MAP@5    NDCG@5\n",
"SLIM_optimized    0.147598      0.240      0.570  0.095547  0.168684\n",
"SLIM              0.160550      0.242      0.558  0.093720  0.165643\n",
"ALS               0.195359      0.216      0.540  0.091600  0.160843\n",
"KNN               0.052348      0.144      0.384  0.054447  0.101923\n",
"my_model          0.052348      0.144      0.384  0.054447  0.101923"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metrics.add_result(\"my_model\", recs)\n",
"metrics.results.sort_values(\"NDCG@5\", ascending=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
},
"name": "movielens_nmf.ipynb",
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": [
"null"
]
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}