{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-23-eval-replay.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T218100%20%7C%20RecSys%20Evaluation%20Metrics%20using%20RePlay%20Library.ipynb","timestamp":1644663605822},{"file_id":"1Owr37Bzv24jGzcbqVh7VbnGJtSni9-0C","timestamp":1636102214814}],"collapsed_sections":[],"authorship_tag":"ABX9TyP4enTbjWDzaSxfn1ILO+a3"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"QBW8NyVnSuK6"},"source":["# RecSys Evaluation Metrics using RePlay Library"]},{"cell_type":"markdown","metadata":{"id":"YGqJzbSIP9vL"},"source":["**Links**\n","\n","- https://sberbank-ai-lab.github.io/RePlay/pages/modules/metrics.html"]},{"cell_type":"code","metadata":{"id":"MhhXDw00Ol_H"},"source":["!apt-get install openjdk-8-jdk-headless -qq > /dev/null\n","!wget -q https://archive.apache.org/dist/spark/spark-3.0.0/spark-3.0.0-bin-hadoop3.2.tgz\n","!tar xf spark-3.0.0-bin-hadoop3.2.tgz\n","!pip install -q findspark\n","!pip install -q pyspark"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-YJwm-W8PB0i"},"source":["!pip install replay-rec #v0.6.1"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4Ctev1swR4tR"},"source":["!pip install ipytest"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"A6zNhUM9SASe"},"source":["import ipytest\n","ipytest.autoconfig()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ro_aHv8kOhuB"},"source":["import os\n","os.environ[\"JAVA_HOME\"] = \"/usr/lib/jvm/java-8-openjdk-amd64\"\n","os.environ[\"SPARK_HOME\"] = \"/content/spark-3.0.0-bin-hadoop3.2\"\n","\n","import findspark\n","findspark.init()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"gYKEylAPQrfA"},"source":["import os\n","import re\n","\n","from datetime import datetime\n","from typing import Dict, List, Optional\n","\n","import numpy as np\n","import pandas as pd\n","import pytest\n","from numpy.testing import assert_allclose\n","from pyspark.ml.linalg import DenseVector\n","from pyspark.sql import DataFrame\n","\n","from replay.metrics import *\n","from replay.distributions import item_distribution\n","from replay.metrics.base_metric import sorter\n","from replay.constants import REC_SCHEMA, LOG_SCHEMA\n","from replay.session_handler import get_spark_session"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"AARhI1KISamZ"},"source":["import warnings\n","warnings.filterwarnings('ignore')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SZ_0j36kQpMQ"},"source":["def assertDictAlmostEqual(d1: Dict, d2: Dict) -> None:\n"," assert set(d1.keys()) == set(d2.keys())\n"," for key in d1:\n"," assert_allclose(d1[key], d2[key])\n","\n","\n","@pytest.fixture\n","def spark():\n"," return get_spark_session(1, 1)\n","\n","\n","@pytest.fixture\n","def log2(spark):\n"," return spark.createDataFrame(\n"," data=[\n"," [\"user1\", \"item1\", datetime(2019, 9, 12), 3.0],\n"," [\"user1\", \"item5\", datetime(2019, 9, 13), 2.0],\n"," [\"user1\", \"item2\", datetime(2019, 9, 17), 1.0],\n"," [\"user2\", \"item6\", datetime(2019, 9, 14), 4.0],\n"," [\"user2\", \"item1\", datetime(2019, 9, 15), 3.0],\n"," [\"user3\", \"item2\", datetime(2019, 9, 15), 3.0],\n"," ],\n"," schema=LOG_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def log(spark):\n"," return spark.createDataFrame(\n"," 
data=[\n"," [\"user1\", \"item1\", datetime(2019, 8, 22), 4.0],\n"," [\"user1\", \"item3\", datetime(2019, 8, 23), 3.0],\n"," [\"user1\", \"item2\", datetime(2019, 8, 27), 2.0],\n"," [\"user2\", \"item4\", datetime(2019, 8, 24), 3.0],\n"," [\"user2\", \"item1\", datetime(2019, 8, 25), 4.0],\n"," [\"user3\", \"item2\", datetime(2019, 8, 26), 5.0],\n"," [\"user3\", \"item1\", datetime(2019, 8, 26), 5.0],\n"," [\"user3\", \"item3\", datetime(2019, 8, 26), 3.0],\n"," [\"user4\", \"item2\", datetime(2019, 8, 26), 5.0],\n"," [\"user4\", \"item1\", datetime(2019, 8, 26), 5.0],\n"," [\"user4\", \"item1\", datetime(2019, 8, 26), 1.0],\n"," ],\n"," schema=LOG_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def long_log_with_features(spark):\n"," date = datetime(2019, 1, 1)\n"," return spark.createDataFrame(\n"," data=[\n"," [\"u1\", \"i1\", date, 1.0],\n"," [\"u1\", \"i4\", datetime(2019, 1, 5), 3.0],\n"," [\"u1\", \"i2\", date, 2.0],\n"," [\"u1\", \"i5\", date, 4.0],\n"," [\"u2\", \"i1\", date, 1.0],\n"," [\"u2\", \"i3\", datetime(2018, 1, 1), 2.0],\n"," [\"u2\", \"i7\", datetime(2019, 1, 1), 4.0],\n"," [\"u2\", \"i8\", datetime(2020, 1, 1), 4.0],\n"," [\"u3\", \"i9\", date, 3.0],\n"," [\"u3\", \"i2\", date, 2.0],\n"," [\"u3\", \"i6\", datetime(2020, 3, 1), 1.0],\n"," [\"u3\", \"i7\", date, 5.0],\n"," ],\n"," schema=[\"user_id\", \"item_id\", \"timestamp\", \"relevance\"],\n"," )\n","\n","\n","@pytest.fixture\n","def short_log_with_features(spark):\n"," date = datetime(2021, 1, 1)\n"," return spark.createDataFrame(\n"," data=[\n"," [\"u1\", \"i3\", date, 1.0],\n"," [\"u1\", \"i7\", datetime(2019, 1, 5), 3.0],\n"," [\"u2\", \"i2\", date, 1.0],\n"," [\"u2\", \"i10\", datetime(2018, 1, 1), 2.0],\n"," [\"u3\", \"i8\", date, 3.0],\n"," [\"u3\", \"i1\", date, 2.0],\n"," [\"u4\", \"i7\", date, 5.0],\n"," ],\n"," schema=[\"user_id\", \"item_id\", \"timestamp\", \"relevance\"],\n"," )\n","\n","\n","@pytest.fixture\n","def user_features(spark):\n"," return spark.createDataFrame(\n"," [(\"u1\", 20.0, -3.0, \"M\"), (\"u2\", 30.0, 4.0, \"F\")]\n"," ).toDF(\"user_id\", \"age\", \"mood\", \"gender\")\n","\n","\n","@pytest.fixture\n","def item_features(spark):\n"," return spark.createDataFrame(\n"," [\n"," (\"i1\", 4.0, \"cat\", \"black\"),\n"," (\"i2\", 10.0, \"dog\", \"green\"),\n"," (\"i3\", 7.0, \"mouse\", \"yellow\"),\n"," (\"i4\", -1.0, \"cat\", \"yellow\"),\n"," (\"i5\", 11.0, \"dog\", \"white\"),\n"," (\"i6\", 0.0, \"mouse\", \"yellow\"),\n"," ]\n"," ).toDF(\"item_id\", \"iq\", \"class\", \"color\")\n","\n","\n","def unify_dataframe(data_frame: DataFrame):\n"," pandas_df = data_frame.toPandas()\n"," columns_to_sort_by: List[str] = []\n","\n"," if len(pandas_df) == 0:\n"," columns_to_sort_by = pandas_df.columns\n"," else:\n"," for column in pandas_df.columns:\n"," if not type(pandas_df[column][0]) in {\n"," DenseVector,\n"," list,\n"," np.ndarray,\n"," }:\n"," columns_to_sort_by.append(column)\n","\n"," return (\n"," pandas_df[sorted(data_frame.columns)]\n"," .sort_values(by=sorted(columns_to_sort_by))\n"," .reset_index(drop=True)\n"," )\n","\n","\n","def sparkDataFrameEqual(df1: DataFrame, df2: DataFrame):\n"," return pd.testing.assert_frame_equal(\n"," unify_dataframe(df1), unify_dataframe(df2), check_like=True\n"," )\n","\n","\n","def sparkDataFrameNotEqual(df1: DataFrame, df2: DataFrame):\n"," try:\n"," sparkDataFrameEqual(df1, df2)\n"," except AssertionError:\n"," pass\n"," else:\n"," raise AssertionError(\"spark dataframes are equal\")\n","\n","\n","def del_files_by_pattern(directory: str, pattern: 
str) -> None:\n"," \"\"\"\n"," Deletes files by pattern\n"," \"\"\"\n"," for filename in os.listdir(directory):\n"," if re.match(pattern, filename):\n"," os.remove(os.path.join(directory, filename))\n","\n","\n","def find_file_by_pattern(directory: str, pattern: str) -> Optional[str]:\n"," \"\"\"\n"," Returns path to first found file, if exists\n"," \"\"\"\n"," for filename in os.listdir(directory):\n"," if re.match(pattern, filename):\n"," return os.path.join(directory, filename)\n"," return None"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hjbJkbenN85D","executionInfo":{"status":"ok","timestamp":1636102199250,"user_tz":-330,"elapsed":25125,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"eebf5dea-e1c6-4c17-914a-a3a456cfd19d"},"source":["%%ipytest\n","\n","@pytest.fixture\n","def one_user():\n"," df = pd.DataFrame({\"user_id\": [1], \"item_id\": [1], \"relevance\": [1]})\n"," return df\n","\n","\n","@pytest.fixture\n","def two_users():\n"," df = pd.DataFrame(\n"," {\"user_id\": [1, 2], \"item_id\": [1, 2], \"relevance\": [1, 1]}\n"," )\n"," return df\n","\n","\n","@pytest.fixture\n","def recs(spark):\n"," return spark.createDataFrame(\n"," data=[\n"," [\"user1\", \"item1\", 3.0],\n"," [\"user1\", \"item2\", 2.0],\n"," [\"user1\", \"item3\", 1.0],\n"," [\"user2\", \"item1\", 3.0],\n"," [\"user2\", \"item2\", 4.0],\n"," [\"user2\", \"item5\", 1.0],\n"," [\"user3\", \"item1\", 5.0],\n"," [\"user3\", \"item3\", 1.0],\n"," [\"user3\", \"item4\", 2.0],\n"," ],\n"," schema=REC_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def recs2(spark):\n"," return spark.createDataFrame(\n"," data=[[\"user1\", \"item4\", 4.0], [\"user1\", \"item5\", 5.0]],\n"," schema=REC_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def empty_recs(spark):\n"," return spark.createDataFrame(\n"," data=[],\n"," schema=REC_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def true(spark):\n"," return spark.createDataFrame(\n"," data=[\n"," [\"user1\", \"item1\", datetime(2019, 9, 12), 3.0],\n"," [\"user1\", \"item5\", datetime(2019, 9, 13), 2.0],\n"," [\"user1\", \"item2\", datetime(2019, 9, 17), 1.0],\n"," [\"user2\", \"item6\", datetime(2019, 9, 14), 4.0],\n"," [\"user2\", \"item1\", datetime(2019, 9, 15), 3.0],\n"," [\"user3\", \"item2\", datetime(2019, 9, 15), 3.0],\n"," ],\n"," schema=LOG_SCHEMA,\n"," )\n","\n","\n","@pytest.fixture\n","def quality_metrics():\n"," return [NDCG(), HitRate(), Precision(), Recall(), MAP(), MRR(), RocAuc()]\n","\n","\n","@pytest.fixture\n","def duplicate_recs(spark):\n"," return spark.createDataFrame(\n"," data=[\n"," [\"user1\", \"item1\", 3.0],\n"," [\"user1\", \"item2\", 2.0],\n"," [\"user1\", \"item3\", 1.0],\n"," [\"user1\", \"item1\", 3.0],\n"," [\"user2\", \"item1\", 3.0],\n"," [\"user2\", \"item2\", 4.0],\n"," [\"user2\", \"item5\", 1.0],\n"," [\"user2\", \"item2\", 2.0],\n"," [\"user3\", \"item1\", 5.0],\n"," [\"user3\", \"item3\", 1.0],\n"," [\"user3\", \"item4\", 2.0],\n"," ],\n"," schema=REC_SCHEMA,\n"," )\n","\n","\n","def test_test_is_bigger(quality_metrics, one_user, two_users):\n"," for metric in quality_metrics:\n"," assert metric(one_user, two_users, 1) == 0.5, str(metric)\n","\n","\n","def test_pred_is_bigger(quality_metrics, one_user, two_users):\n"," for metric in quality_metrics:\n"," assert metric(two_users, one_user, 1) == 1.0, str(metric)\n","\n","\n","def test_hit_rate_at_k(recs, 
true):\n"," assertDictAlmostEqual(\n"," HitRate()(recs, true, [3, 1]),\n"," {3: 2 / 3, 1: 1 / 3},\n"," )\n","\n","\n","def test_user_dist(log, recs, true):\n"," vals = HitRate().user_distribution(log, recs, true, 1)[\"value\"].to_list()\n"," assert_allclose(vals, [0.0, 0.5])\n","\n","\n","def test_item_dist(log, recs):\n"," assert_allclose(\n"," item_distribution(log, recs, 1)[\"rec_count\"].to_list(),\n"," [0, 0, 1, 2],\n"," )\n","\n","\n","def test_ndcg_at_k(recs, true):\n"," pred = [300, 200, 100]\n"," k_set = [1, 2, 3]\n"," user_id = 1\n"," ground_truth = [200, 400]\n"," ndcg_value = 1 / np.log2(3) / (1 / np.log2(2) + 1 / np.log2(3))\n"," assert (\n"," NDCG()._get_metric_value_by_user_all_k(\n"," k_set, user_id, pred, ground_truth\n"," )\n"," == [(1, 0, 1), (1, ndcg_value, 2), (1, ndcg_value, 3)],\n"," )\n"," assertDictAlmostEqual(\n"," NDCG()(recs, true, [1, 3]),\n"," {\n"," 1: 1 / 3,\n"," 3: 1\n"," / 3\n"," * (\n"," 1\n"," / (1 / np.log2(2) + 1 / np.log2(3) + 1 / np.log2(4))\n"," * (1 / np.log2(2) + 1 / np.log2(3))\n"," + 1 / (1 / np.log2(2) + 1 / np.log2(3)) * (1 / np.log2(3))\n"," ),\n"," },\n"," )\n","\n","\n","def test_precision_at_k(recs, true):\n"," assertDictAlmostEqual(\n"," Precision()(recs, true, [1, 2, 3]),\n"," {3: 1 / 3, 1: 1 / 3, 2: 1 / 2},\n"," )\n","\n","\n","def test_map_at_k(recs, true):\n"," assertDictAlmostEqual(\n"," MAP()(recs, true, [1, 3]),\n"," {3: 11 / 36, 1: 1 / 3},\n"," )\n","\n","\n","def test_recall_at_k(recs, true):\n"," assertDictAlmostEqual(\n"," Recall()(recs, true, [1, 3]),\n"," {3: (1 / 2 + 2 / 3) / 3, 1: 1 / 9},\n"," )\n","\n","\n","def test_surprisal_at_k(true, recs, recs2):\n"," assertDictAlmostEqual(Surprisal(true)(recs2, [1, 2]), {1: 1.0, 2: 1.0})\n","\n"," assert_allclose(\n"," Surprisal(true)(recs, 3),\n"," 5 * (1 - 1 / np.log2(3)) / 9 + 4 / 9,\n"," )\n","\n","\n","def test_unexpectedness_at_k(true, recs, recs2):\n"," assert Unexpectedness._get_metric_value_by_user(2, (), (2, 3)) == 0\n"," assert Unexpectedness._get_metric_value_by_user(2, (1, 2), (1,)) == 0.5\n","\n","\n","def test_coverage(true, recs, empty_recs):\n"," coverage = Coverage(recs.union(true.drop(\"timestamp\")))\n"," assertDictAlmostEqual(\n"," coverage(recs, [1, 3, 5]),\n"," {1: 0.3333333333333333, 3: 0.8333333333333334, 5: 0.8333333333333334},\n"," )\n"," assertDictAlmostEqual(\n"," coverage(empty_recs, [1, 3, 5]),\n"," {1: 0.0, 3: 0.0, 5: 0.0},\n"," )\n","\n","\n","def test_bad_coverage(true, recs):\n"," assert_allclose(Coverage(true)(recs, 3), 1.25)\n","\n","\n","def test_empty_recs(quality_metrics):\n"," for metric in quality_metrics:\n"," assert_allclose(\n"," metric._get_metric_value_by_user(\n"," k=4, pred=[], ground_truth=[2, 4]\n"," ),\n"," 0,\n"," err_msg=str(metric),\n"," )\n","\n","\n","def test_bad_recs(quality_metrics):\n"," for metric in quality_metrics:\n"," assert_allclose(\n"," metric._get_metric_value_by_user(\n"," k=4, pred=[1, 3], ground_truth=[2, 4]\n"," ),\n"," 0,\n"," err_msg=str(metric),\n"," )\n","\n","\n","def test_not_full_recs(quality_metrics):\n"," for metric in quality_metrics:\n"," assert_allclose(\n"," metric._get_metric_value_by_user(\n"," k=4, pred=[4, 1, 2], ground_truth=[2, 4]\n"," ),\n"," metric._get_metric_value_by_user(\n"," k=3, pred=[4, 1, 2], ground_truth=[2, 4]\n"," ),\n"," err_msg=str(metric),\n"," )\n","\n","\n","def test_duplicate_recs(quality_metrics, duplicate_recs, recs, true):\n"," for metric in quality_metrics:\n"," assert_allclose(\n"," metric(k=4, recommendations=duplicate_recs, ground_truth=true),\n"," metric(k=4, 
recommendations=recs, ground_truth=true),\n"," err_msg=str(metric),\n"," )\n","\n","\n","def test_sorter():\n"," result = sorter(((1, 2), (2, 3), (3, 2)))\n"," assert result == [2, 3]\n","\n","\n","def test_sorter_index():\n"," result = sorter([(1, 2, 3), (2, 3, 4), (3, 3, 5)], index=2)\n"," assert result == [5, 3]"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n","\u001b[32m\u001b[32m\u001b[1m19 passed\u001b[0m\u001b[32m in 24.13s\u001b[0m\u001b[0m\n"]}]}]}
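,{"cell_type":"markdown","metadata":{},"source":["## Calling the metrics directly (sketch)\n","\n","The tests above drive the metrics through pytest fixtures. The cell below is a minimal standalone sketch, not part of the original test suite: `demo_recs` and `demo_true` are illustrative names, while every call it makes (`get_spark_session`, `REC_SCHEMA`/`LOG_SCHEMA`, metrics invoked with a list of cutoffs returning a `{k: value}` dict) mirrors usage already shown in the fixtures and tests."]},{"cell_type":"code","metadata":{},"source":["# Minimal usage sketch; assumes the import cells above have been run.\n","# get_spark_session, REC_SCHEMA and LOG_SCHEMA come from RePlay itself.\n","spark = get_spark_session(1, 1)\n","\n","# Illustrative recommendations: (user_id, item_id, relevance).\n","demo_recs = spark.createDataFrame(\n","    data=[\n","        [\"user1\", \"item1\", 3.0],\n","        [\"user1\", \"item2\", 2.0],\n","        [\"user2\", \"item1\", 3.0],\n","    ],\n","    schema=REC_SCHEMA,\n",")\n","\n","# Illustrative ground truth: (user_id, item_id, timestamp, relevance).\n","demo_true = spark.createDataFrame(\n","    data=[\n","        [\"user1\", \"item1\", datetime(2019, 9, 12), 3.0],\n","        [\"user2\", \"item6\", datetime(2019, 9, 14), 4.0],\n","    ],\n","    schema=LOG_SCHEMA,\n",")\n","\n","# A list of cutoffs returns a dict {k: value}, as the tests assert.\n","print(HitRate()(demo_recs, demo_true, [1, 2]))\n","print(Precision()(demo_recs, demo_true, [1, 2]))"],"execution_count":null,"outputs":[]}]}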