{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "2021-07-14-ncf-movielens-tensorflow.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPoY1blNGyXXSkuc6Se8AHH"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "7LrG_dfoZXHz"
},
"source": [
"# NCF from scratch in Tensorflow\n",
"> We will build a Neural Collaborative Filtering model from scratch in Tensorflow and train it on movielens data. Then we will compare it side by side with lightfm model.\n",
"\n",
"- toc: true\n",
"- badges: true\n",
"- comments: true\n",
"- categories: [Movie, NCF, Tensorflow, LightFM]\n",
"- author: \"dataroots\"\n",
"- image:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1mlny70rYqhK"
},
"source": [
"> youtube: https://youtu.be/SD3irxdKfxk"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hbh3Iop2Y3eK"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "JqmlfPOIOp_g"
},
"source": [
"%pip install -q lightfm"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Us3S7n9WQHMh"
},
"source": [
"from scipy import sparse\n",
"from typing import List\n",
"import datetime\n",
"import os\n",
"\n",
"import lightfm\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"from lightfm import LightFM\n",
"from lightfm.datasets import fetch_movielens\n",
"\n",
"import tensorflow.keras as keras\n",
"from tensorflow.keras.layers import (\n",
" Concatenate,\n",
" Dense,\n",
" Embedding,\n",
" Flatten,\n",
" Input,\n",
" Multiply,\n",
")\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.regularizers import l2\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"%reload_ext google.colab.data_table\n",
"%reload_ext tensorboard"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TyE5oqxHQoTV",
"outputId": "5cd6e39e-7358-402d-942f-c9d07d209a33"
},
"source": [
"%pip install -q watermark\n",
"%reload_ext watermark\n",
"%watermark -m -iv"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"Compiler : GCC 7.5.0\n",
"OS : Linux\n",
"Release : 5.4.104+\n",
"Machine : x86_64\n",
"Processor : x86_64\n",
"CPU cores : 2\n",
"Architecture: 64bit\n",
"\n",
"lightfm : 1.16\n",
"pandas : 1.1.5\n",
"scipy : 1.4.1\n",
"numpy : 1.19.5\n",
"tensorflow: 2.5.0\n",
"IPython : 5.5.0\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SK0cNTzQQ2LF"
},
"source": [
"# number of top-ranked items considered when computing precision@K / recall@K\n",
"TOP_K = 5\n",
"# training epochs, used for both the NCF model and the LightFM baseline\n",
"N_EPOCHS = 10"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5Fik0pDpY5rI"
},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SgA6bA85Q_Ov",
"outputId": "af120947-1c79-49c8-c94f-51dce8ef2222"
},
"source": [
"# Fetch MovieLens; min_rating=3.0 keeps only ratings of 3 or higher,\n",
"# so every stored interaction is a \"positive\" one.\n",
"data = fetch_movielens(min_rating=3.0)\n",
"\n",
"print(\"Interaction matrix:\")\n",
"print(data[\"train\"].toarray()[:10, :10])"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Interaction matrix:\n",
"[[5 3 4 3 3 5 4 0 5 3]\n",
" [4 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [4 0 0 0 0 0 0 4 4 0]\n",
" [0 0 0 5 0 0 5 5 5 4]\n",
" [0 0 0 0 0 0 3 0 0 0]\n",
" [0 0 0 0 0 0 4 0 0 0]\n",
" [4 0 0 4 0 0 0 0 4 0]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JPCc11s8RMwe",
"outputId": "afc19592-aaa6-40db-e087-50bdee8d7588"
},
"source": [
"# Make the ratings binary: any surviving rating becomes 1, absence stays 0.\n",
"for dataset in [\"test\", \"train\"]:\n",
"    data[dataset] = (data[dataset].toarray() > 0).astype(\"int8\")\n",
"\n",
"print(\"Interaction matrix:\")\n",
"print(data[\"train\"][:10, :10])\n",
"\n",
"print(\"\\nRatings:\")\n",
"# After binarization the only values present should be 0 and 1.\n",
"unique_ratings = np.unique(data[\"train\"])\n",
"print(unique_ratings)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"Interaction matrix:\n",
"[[1 1 1 1 1 1 1 0 1 1]\n",
" [1 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0]\n",
" [1 0 0 0 0 0 0 1 1 0]\n",
" [0 0 0 1 0 0 1 1 1 1]\n",
" [0 0 0 0 0 0 1 0 0 0]\n",
" [0 0 0 0 0 0 1 0 0 0]\n",
" [1 0 0 1 0 0 0 0 1 0]]\n",
"\n",
"Ratings:\n",
"[0 1]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QqGB6gUpY7y8"
},
"source": [
"## Preprocess"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ht_ua_B-RgT9"
},
"source": [
"def wide_to_long(wide: np.array, possible_ratings: List[int]) -> np.array:\n",
"    \"\"\"Go from wide table to long.\n",
"    :param wide: wide array with user-item interactions\n",
"    :param possible_ratings: list of possible ratings that we may have.\n",
"    :return: (n, 3) array with one row per matching cell: (user_id, item_id, rating).\"\"\"\n",
"\n",
"    def _get_ratings(arr: np.array, rating: int) -> np.array:\n",
"        \"\"\"Generate long array for the rating provided\n",
"        :param arr: wide array with user-item interactions\n",
"        :param rating: the rating that we are interested\n",
"        :return: (k, 3) array of (row_idx, col_idx, rating) for every matching cell.\"\"\"\n",
"        # np.where gives (row_indices, col_indices) of all cells equal to `rating`.\n",
"        # NOTE(review): the rating column is int8, fine for the binary 0/1 ratings\n",
"        # used here - would overflow for ratings outside int8 range.\n",
"        idx = np.where(arr == rating)\n",
"        return np.vstack(\n",
"            (idx[0], idx[1], np.ones(idx[0].size, dtype=\"int8\") * rating)\n",
"        ).T\n",
"\n",
"    # One long block per rating value, stacked; rows are therefore grouped by rating.\n",
"    long_arrays = []\n",
"    for r in possible_ratings:\n",
"        long_arrays.append(_get_ratings(wide, r))\n",
"\n",
"    return np.vstack(long_arrays)"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 194
},
"id": "m8WEVPFaRstc",
"outputId": "c443e069-77a2-44ed-a623-25a532bfed07"
},
"source": [
"# Convert the wide interaction matrix into (user_id, item_id, interaction) rows.\n",
"long_train = wide_to_long(data[\"train\"], unique_ratings)\n",
"df_train = pd.DataFrame(long_train, columns=[\"user_id\", \"item_id\", \"interaction\"])\n",
"df_train.head()"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/a6224c040fa35dcf/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 19,\n 'f': \"19\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 26,\n 'f': \"26\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"user_id\"], [\"number\", \"item_id\"], [\"number\", \"interaction\"]],\n columnOptions: [{\"width\": \"1px\", \"className\": \"index_column\"}],\n rowsPerPage: 25,\n helpUrl: \"https://colab.research.google.com/notebooks/data_table.ipynb\",\n suppressOutputScrolling: true,\n minimumWidth: undefined,\n });\n ",
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" user_id | \n",
" item_id | \n",
" interaction | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0 | \n",
" 7 | \n",
" 0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0 | \n",
" 10 | \n",
" 0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0 | \n",
" 19 | \n",
" 0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0 | \n",
" 20 | \n",
" 0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0 | \n",
" 26 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" user_id item_id interaction\n",
"0 0 7 0\n",
"1 0 10 0\n",
"2 0 19 0\n",
"3 0 20 0\n",
"4 0 26 0"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 211
},
"id": "ECg5mEuCR2zn",
"outputId": "1059eb6c-d133-4347-a3f2-69afce2c8dae"
},
"source": [
"print(\"Only positive interactions:\")\n",
"df_train[df_train[\"interaction\"] > 0].head()"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"Only positive interactions:\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/a6224c040fa35dcf/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 1511499,\n 'f': \"1511499\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n }],\n [{\n 'v': 1511500,\n 'f': \"1511500\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n }],\n [{\n 'v': 1511501,\n 'f': \"1511501\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n }],\n [{\n 'v': 1511502,\n 'f': \"1511502\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n }],\n [{\n 'v': 1511503,\n 'f': \"1511503\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"user_id\"], [\"number\", \"item_id\"], [\"number\", \"interaction\"]],\n columnOptions: [{\"width\": \"1px\", \"className\": \"index_column\"}],\n rowsPerPage: 25,\n helpUrl: \"https://colab.research.google.com/notebooks/data_table.ipynb\",\n suppressOutputScrolling: true,\n minimumWidth: undefined,\n });\n ",
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" user_id | \n",
" item_id | \n",
" interaction | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1511499 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
" | 1511500 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" | 1511501 | \n",
" 0 | \n",
" 2 | \n",
" 1 | \n",
"
\n",
" \n",
" | 1511502 | \n",
" 0 | \n",
" 3 | \n",
" 1 | \n",
"
\n",
" \n",
" | 1511503 | \n",
" 0 | \n",
" 4 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" user_id item_id interaction\n",
"1511499 0 0 1\n",
"1511500 0 1 1\n",
"1511501 0 2 1\n",
"1511502 0 3 1\n",
"1511503 0 4 1"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXh2VlaDSYzu"
},
"source": [
"## NCF Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B2S1b8hxSXZT"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "ooic4sNsR-36"
},
"source": [
"def create_ncf(\n",
"    number_of_users: int,\n",
"    number_of_items: int,\n",
"    latent_dim_mf: int = 4,\n",
"    latent_dim_mlp: int = 32,\n",
"    reg_mf: float = 0,\n",
"    reg_mlp: float = 0.01,\n",
"    dense_layers: List[int] = [8, 4],\n",
"    reg_layers: List[float] = [0.01, 0.01],\n",
"    activation_dense: str = \"relu\",\n",
") -> keras.Model:\n",
"    \"\"\"Build the Neural Collaborative Filtering model: a matrix-factorization\n",
"    (GMF) branch and an MLP branch, concatenated into a sigmoid output.\n",
"    :param number_of_users: number of distinct users (embedding input_dim)\n",
"    :param number_of_items: number of distinct items (embedding input_dim)\n",
"    :param latent_dim_mf: embedding size of the matrix-factorization branch\n",
"    :param latent_dim_mlp: embedding size of the MLP branch\n",
"    :param reg_mf: l2 regularization factor for the MF embeddings\n",
"    :param reg_mlp: l2 regularization factor for the MLP embeddings\n",
"    :param dense_layers: units per dense layer in the MLP tower\n",
"        (NOTE(review): mutable default; safe here since it is never mutated)\n",
"    :param reg_layers: l2 activity regularization per dense layer,\n",
"        same length as dense_layers\n",
"    :param activation_dense: activation of the MLP dense layers\n",
"    :return: uncompiled keras.Model mapping (user_id, item_id) -> probability\"\"\"\n",
"\n",
"    # input layers: one scalar integer id per example (shape=() means scalar)\n",
"    user = Input(shape=(), dtype=\"int32\", name=\"user_id\")\n",
"    item = Input(shape=(), dtype=\"int32\", name=\"item_id\")\n",
"\n",
"    # embedding layers\n",
"    mf_user_embedding = Embedding(\n",
"        input_dim=number_of_users,\n",
"        output_dim=latent_dim_mf,\n",
"        name=\"mf_user_embedding\",\n",
"        embeddings_initializer=\"RandomNormal\",\n",
"        embeddings_regularizer=l2(reg_mf),\n",
"        input_length=1,\n",
"    )\n",
"    mf_item_embedding = Embedding(\n",
"        input_dim=number_of_items,\n",
"        output_dim=latent_dim_mf,\n",
"        name=\"mf_item_embedding\",\n",
"        embeddings_initializer=\"RandomNormal\",\n",
"        embeddings_regularizer=l2(reg_mf),\n",
"        input_length=1,\n",
"    )\n",
"\n",
"    mlp_user_embedding = Embedding(\n",
"        input_dim=number_of_users,\n",
"        output_dim=latent_dim_mlp,\n",
"        name=\"mlp_user_embedding\",\n",
"        embeddings_initializer=\"RandomNormal\",\n",
"        embeddings_regularizer=l2(reg_mlp),\n",
"        input_length=1,\n",
"    )\n",
"    mlp_item_embedding = Embedding(\n",
"        input_dim=number_of_items,\n",
"        output_dim=latent_dim_mlp,\n",
"        name=\"mlp_item_embedding\",\n",
"        embeddings_initializer=\"RandomNormal\",\n",
"        embeddings_regularizer=l2(reg_mlp),\n",
"        input_length=1,\n",
"    )\n",
"\n",
"    # MF vector: element-wise product of user and item MF embeddings (GMF branch)\n",
"    mf_user_latent = Flatten()(mf_user_embedding(user))\n",
"    mf_item_latent = Flatten()(mf_item_embedding(item))\n",
"    mf_cat_latent = Multiply()([mf_user_latent, mf_item_latent])\n",
"\n",
"    # MLP vector: concatenation of user and item MLP embeddings\n",
"    mlp_user_latent = Flatten()(mlp_user_embedding(user))\n",
"    mlp_item_latent = Flatten()(mlp_item_embedding(item))\n",
"    mlp_cat_latent = Concatenate()([mlp_user_latent, mlp_item_latent])\n",
"\n",
"    mlp_vector = mlp_cat_latent\n",
"\n",
"    # build dense layers for model\n",
"    for i in range(len(dense_layers)):\n",
"        layer = Dense(\n",
"            dense_layers[i],\n",
"            activity_regularizer=l2(reg_layers[i]),\n",
"            activation=activation_dense,\n",
"            name=\"layer%d\" % i,\n",
"        )\n",
"        mlp_vector = layer(mlp_vector)\n",
"\n",
"    # fuse the two branches and squash to a probability\n",
"    predict_layer = Concatenate()([mf_cat_latent, mlp_vector])\n",
"\n",
"    result = Dense(\n",
"        1, activation=\"sigmoid\", kernel_initializer=\"lecun_uniform\", name=\"interaction\"\n",
"    )\n",
"\n",
"    output = result(predict_layer)\n",
"\n",
"    model = Model(\n",
"        inputs=[user, item],\n",
"        outputs=[output],\n",
"    )\n",
"\n",
"    return model"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GQWaSOCrSwSl",
"outputId": "b2d27924-0954-48ed-812e-83b52b132ccd"
},
"source": [
"n_users, n_items = data[\"train\"].shape\n",
"ncf_model = create_ncf(n_users, n_items)\n",
"\n",
"# Binary cross-entropy on the 0/1 interactions, with a full confusion-matrix\n",
"# breakdown plus accuracy/precision/recall/AUC tracked during training.\n",
"ncf_model.compile(\n",
"    optimizer=Adam(),\n",
"    loss=\"binary_crossentropy\",\n",
"    metrics=[\n",
"        tf.keras.metrics.TruePositives(name=\"tp\"),\n",
"        tf.keras.metrics.FalsePositives(name=\"fp\"),\n",
"        tf.keras.metrics.TrueNegatives(name=\"tn\"),\n",
"        tf.keras.metrics.FalseNegatives(name=\"fn\"),\n",
"        tf.keras.metrics.BinaryAccuracy(name=\"accuracy\"),\n",
"        tf.keras.metrics.Precision(name=\"precision\"),\n",
"        tf.keras.metrics.Recall(name=\"recall\"),\n",
"        tf.keras.metrics.AUC(name=\"auc\"),\n",
"    ],\n",
")\n",
"# NOTE(review): `_name` is a private Keras attribute, set here only so that\n",
"# summary() displays a friendly model name - may break on Keras upgrades.\n",
"ncf_model._name = \"neural_collaborative_filtering\"\n",
"ncf_model.summary()"
],
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"neural_collaborative_filtering\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"user_id (InputLayer) [(None,)] 0 \n",
"__________________________________________________________________________________________________\n",
"item_id (InputLayer) [(None,)] 0 \n",
"__________________________________________________________________________________________________\n",
"mlp_user_embedding (Embedding) (None, 32) 30176 user_id[0][0] \n",
"__________________________________________________________________________________________________\n",
"mlp_item_embedding (Embedding) (None, 32) 53824 item_id[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_10 (Flatten) (None, 32) 0 mlp_user_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_11 (Flatten) (None, 32) 0 mlp_item_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"mf_user_embedding (Embedding) (None, 4) 3772 user_id[0][0] \n",
"__________________________________________________________________________________________________\n",
"mf_item_embedding (Embedding) (None, 4) 6728 item_id[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_4 (Concatenate) (None, 64) 0 flatten_10[0][0] \n",
" flatten_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_8 (Flatten) (None, 4) 0 mf_user_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_9 (Flatten) (None, 4) 0 mf_item_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"layer0 (Dense) (None, 8) 520 concatenate_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"multiply_2 (Multiply) (None, 4) 0 flatten_8[0][0] \n",
" flatten_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"layer1 (Dense) (None, 4) 36 layer0[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_5 (Concatenate) (None, 8) 0 multiply_2[0][0] \n",
" layer1[0][0] \n",
"__________________________________________________________________________________________________\n",
"interaction (Dense) (None, 1) 9 concatenate_5[0][0] \n",
"==================================================================================================\n",
"Total params: 95,065\n",
"Trainable params: 95,065\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KeATHv3FZA7Q"
},
"source": [
"## TF Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aOuwbTKfS6f1"
},
"source": [
"def make_tf_dataset(\n",
"    df: pd.DataFrame,\n",
"    targets: List[str],\n",
"    val_split: float = 0.1,\n",
"    batch_size: int = 512,\n",
"    seed=42,\n",
"):\n",
"    \"\"\"Make TensorFlow dataset from Pandas DataFrame.\n",
"    :param df: input DataFrame - only contains features and target(s)\n",
"    :param targets: list of columns names corresponding to targets\n",
"    :param val_split: fraction of the data that should be used for validation\n",
"    :param batch_size: batch size for training\n",
"    :param seed: random seed for shuffling data - `None` won't shuffle the data\n",
"    :return: (ds_train, ds_val) tuple of batched tf.data.Datasets\"\"\"\n",
"\n",
"    n_val = round(df.shape[0] * val_split)\n",
"    # Explicit `is not None` check: a falsy seed such as 0 must still shuffle,\n",
"    # since the docstring promises that only `None` disables shuffling.\n",
"    if seed is not None:\n",
"        # shuffle all the rows\n",
"        x = df.sample(frac=1, random_state=seed).to_dict(\"series\")\n",
"    else:\n",
"        x = df.to_dict(\"series\")\n",
"    y = dict()\n",
"    # Pop targets out of the feature dict so the dataset yields (features, labels).\n",
"    for t in targets:\n",
"        y[t] = x.pop(t)\n",
"    ds = tf.data.Dataset.from_tensor_slices((x, y))\n",
"\n",
"    # First n_val rows become validation; the remainder is training data.\n",
"    ds_val = ds.take(n_val).batch(batch_size)\n",
"    ds_train = ds.skip(n_val).batch(batch_size)\n",
"    return ds_train, ds_val"
],
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "F6QXgUQgTP8d"
},
"source": [
"# create train and validation datasets\n",
"ds_train, ds_val = make_tf_dataset(df_train, [\"interaction\"])"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "S0COxKxkZDh7"
},
"source": [
"## Model Training"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "o46oVj9eTTaq",
"outputId": "5f2f5f96-cc07-45c7-82dc-30ddfb79f7bb"
},
"source": [
"%%time\n",
"# define logs and callbacks (one timestamped log dir per run for TensorBoard)\n",
"logdir = os.path.join(\"logs\", datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)\n",
"# stop early if validation loss fails to improve for 2 consecutive epochs\n",
"early_stopping_callback = tf.keras.callbacks.EarlyStopping(\n",
"    monitor=\"val_loss\", patience=2\n",
")\n",
"\n",
"train_hist = ncf_model.fit(\n",
"    ds_train,\n",
"    validation_data=ds_val,\n",
"    epochs=N_EPOCHS,\n",
"    callbacks=[tensorboard_callback, early_stopping_callback],\n",
"    verbose=1,\n",
")"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"2789/2789 [==============================] - 21s 7ms/step - loss: 0.2304 - tp: 1031.0000 - fp: 828.0000 - tn: 1359586.0000 - fn: 66068.0000 - accuracy: 0.9531 - precision: 0.5546 - recall: 0.0154 - auc: 0.7983 - val_loss: 0.1338 - val_tp: 882.0000 - val_fp: 356.0000 - val_tn: 150729.0000 - val_fn: 6646.0000 - val_accuracy: 0.9559 - val_precision: 0.7124 - val_recall: 0.1172 - val_auc: 0.9138\n",
"Epoch 2/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1204 - tp: 13114.0000 - fp: 6911.0000 - tn: 1353503.0000 - fn: 53985.0000 - accuracy: 0.9573 - precision: 0.6549 - recall: 0.1954 - auc: 0.9243 - val_loss: 0.1164 - val_tp: 1709.0000 - val_fp: 878.0000 - val_tn: 150207.0000 - val_fn: 5819.0000 - val_accuracy: 0.9578 - val_precision: 0.6606 - val_recall: 0.2270 - val_auc: 0.9263\n",
"Epoch 3/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1122 - tp: 16039.0000 - fp: 8355.0000 - tn: 1352059.0000 - fn: 51060.0000 - accuracy: 0.9584 - precision: 0.6575 - recall: 0.2390 - auc: 0.9327 - val_loss: 0.1132 - val_tp: 1905.0000 - val_fp: 975.0000 - val_tn: 150110.0000 - val_fn: 5623.0000 - val_accuracy: 0.9584 - val_precision: 0.6615 - val_recall: 0.2531 - val_auc: 0.9307\n",
"Epoch 4/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1094 - tp: 17243.0000 - fp: 8896.0000 - tn: 1351518.0000 - fn: 49856.0000 - accuracy: 0.9588 - precision: 0.6597 - recall: 0.2570 - auc: 0.9368 - val_loss: 0.1120 - val_tp: 1928.0000 - val_fp: 1010.0000 - val_tn: 150075.0000 - val_fn: 5600.0000 - val_accuracy: 0.9583 - val_precision: 0.6562 - val_recall: 0.2561 - val_auc: 0.9324\n",
"Epoch 5/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1081 - tp: 17658.0000 - fp: 9176.0000 - tn: 1351238.0000 - fn: 49441.0000 - accuracy: 0.9589 - precision: 0.6580 - recall: 0.2632 - auc: 0.9390 - val_loss: 0.1113 - val_tp: 1959.0000 - val_fp: 1035.0000 - val_tn: 150050.0000 - val_fn: 5569.0000 - val_accuracy: 0.9584 - val_precision: 0.6543 - val_recall: 0.2602 - val_auc: 0.9333\n",
"Epoch 6/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1071 - tp: 18001.0000 - fp: 9424.0000 - tn: 1350990.0000 - fn: 49098.0000 - accuracy: 0.9590 - precision: 0.6564 - recall: 0.2683 - auc: 0.9406 - val_loss: 0.1108 - val_tp: 1984.0000 - val_fp: 1056.0000 - val_tn: 150029.0000 - val_fn: 5544.0000 - val_accuracy: 0.9584 - val_precision: 0.6526 - val_recall: 0.2635 - val_auc: 0.9343\n",
"Epoch 7/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1063 - tp: 18290.0000 - fp: 9632.0000 - tn: 1350782.0000 - fn: 48809.0000 - accuracy: 0.9591 - precision: 0.6550 - recall: 0.2726 - auc: 0.9418 - val_loss: 0.1104 - val_tp: 2013.0000 - val_fp: 1071.0000 - val_tn: 150014.0000 - val_fn: 5515.0000 - val_accuracy: 0.9585 - val_precision: 0.6527 - val_recall: 0.2674 - val_auc: 0.9348\n",
"Epoch 8/10\n",
"2789/2789 [==============================] - 17s 6ms/step - loss: 0.1056 - tp: 18570.0000 - fp: 9744.0000 - tn: 1350670.0000 - fn: 48529.0000 - accuracy: 0.9592 - precision: 0.6559 - recall: 0.2768 - auc: 0.9428 - val_loss: 0.1101 - val_tp: 2057.0000 - val_fp: 1091.0000 - val_tn: 149994.0000 - val_fn: 5471.0000 - val_accuracy: 0.9586 - val_precision: 0.6534 - val_recall: 0.2732 - val_auc: 0.9353\n",
"Epoch 9/10\n",
"2789/2789 [==============================] - 18s 6ms/step - loss: 0.1051 - tp: 18870.0000 - fp: 9835.0000 - tn: 1350579.0000 - fn: 48229.0000 - accuracy: 0.9593 - precision: 0.6574 - recall: 0.2812 - auc: 0.9437 - val_loss: 0.1098 - val_tp: 2076.0000 - val_fp: 1111.0000 - val_tn: 149974.0000 - val_fn: 5452.0000 - val_accuracy: 0.9586 - val_precision: 0.6514 - val_recall: 0.2758 - val_auc: 0.9357\n",
"Epoch 10/10\n",
"2789/2789 [==============================] - 18s 7ms/step - loss: 0.1046 - tp: 19109.0000 - fp: 9916.0000 - tn: 1350498.0000 - fn: 47990.0000 - accuracy: 0.9594 - precision: 0.6584 - recall: 0.2848 - auc: 0.9443 - val_loss: 0.1095 - val_tp: 2114.0000 - val_fp: 1126.0000 - val_tn: 149959.0000 - val_fn: 5414.0000 - val_accuracy: 0.9588 - val_precision: 0.6525 - val_recall: 0.2808 - val_auc: 0.9363\n",
"CPU times: user 3min 46s, sys: 13.3 s, total: 3min 59s\n",
"Wall time: 3min 6s\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SHcbr5aYUWSA"
},
"source": [
"%tensorboard --logdir logs"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8cFvLAbLZHL9"
},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2z5WMWIOVFKo"
},
"source": [
"# NOTE(review): this \"test\" long table is built from data[\"train\"], not\n",
"# data[\"test\"] - apparently intentional, since a binary matrix expanded over\n",
"# ratings {0, 1} enumerates every (user, item) pair, which the pivot below\n",
"# needs; evaluation later scores against data[\"test\"] directly. Confirm.\n",
"long_test = wide_to_long(data[\"train\"], unique_ratings)\n",
"df_test = pd.DataFrame(long_test, columns=[\"user_id\", \"item_id\", \"interaction\"])\n",
"# val_split=0 and seed=None: keep every row, in original order, for prediction\n",
"ds_test, _ = make_tf_dataset(df_test, [\"interaction\"], val_split=0, seed=None)"
],
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Yo7BxW_rWHO3"
},
"source": [
"ncf_predictions = ncf_model.predict(ds_test)\n",
"df_test[\"ncf_predictions\"] = ncf_predictions"
],
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 194
},
"id": "MVoy544bWJjH",
"outputId": "3174fd54-6625-488b-aa3b-ad83730fe5a0"
},
"source": [
"df_test.head()"
],
"execution_count": 35,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/a6224c040fa35dcf/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0.5920354127883911,\n 'f': \"0.5920354127883911\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0.558495044708252,\n 'f': \"0.558495044708252\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 19,\n 'f': \"19\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0.12524035573005676,\n 'f': \"0.12524035573005676\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0.13628098368644714,\n 'f': \"0.13628098368644714\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 26,\n 'f': \"26\",\n },\n{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 0.12594255805015564,\n 'f': \"0.12594255805015564\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"user_id\"], [\"number\", \"item_id\"], [\"number\", \"interaction\"], [\"number\", \"ncf_predictions\"]],\n columnOptions: [{\"width\": \"1px\", \"className\": \"index_column\"}],\n rowsPerPage: 25,\n helpUrl: \"https://colab.research.google.com/notebooks/data_table.ipynb\",\n suppressOutputScrolling: true,\n minimumWidth: undefined,\n });\n ",
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" user_id | \n",
" item_id | \n",
" interaction | \n",
" ncf_predictions | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0 | \n",
" 7 | \n",
" 0 | \n",
" 0.592035 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0 | \n",
" 10 | \n",
" 0 | \n",
" 0.558495 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0 | \n",
" 19 | \n",
" 0 | \n",
" 0.125240 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0 | \n",
" 20 | \n",
" 0 | \n",
" 0.136281 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0 | \n",
" 26 | \n",
" 0 | \n",
" 0.125943 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" user_id item_id interaction ncf_predictions\n",
"0 0 7 0 0.592035\n",
"1 0 10 0 0.558495\n",
"2 0 19 0 0.125240\n",
"3 0 20 0 0.136281\n",
"4 0 26 0 0.125943"
]
},
"metadata": {
"tags": []
},
"execution_count": 35
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "03nWlokgX1IX"
},
"source": [
"> Tip: add sanity checks. Stop execution if the predictions have a very low standard deviation (i.e. all recommendations are virtually the same)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Gq4B3-uJXxAp"
},
"source": [
"# Sanity check: if predictions collapsed to a near-constant value, every user\n",
"# would receive the same recommendations - abort instead of reporting metrics.\n",
"std = df_test.describe().loc[\"std\", \"ncf_predictions\"]\n",
"if std < 0.01:\n",
"    raise ValueError(\"Model predictions have standard deviation of less than 1e-2.\")"
],
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OlcxHaCiYCPN",
"outputId": "4e3b66cc-fcc9-486c-9ce7-0d8cdd3a71b7"
},
"source": [
"data[\"ncf_predictions\"] = df_test.pivot(\n",
" index=\"user_id\", columns=\"item_id\", values=\"ncf_predictions\"\n",
").values\n",
"print(\"Neural collaborative filtering predictions\")\n",
"print(data[\"ncf_predictions\"][:10, :4])"
],
"execution_count": 46,
"outputs": [
{
"output_type": "stream",
"text": [
"Neural collaborative filtering predictions\n",
"[[7.5572348e-01 3.5788852e-01 2.1718740e-01 7.2733581e-01]\n",
" [1.5290251e-01 1.6686320e-03 3.1808287e-02 2.9422939e-03]\n",
" [2.7413875e-02 2.9927492e-04 1.8543899e-03 1.5550852e-04]\n",
" [7.2024286e-02 9.6502900e-04 1.9236505e-03 8.1184506e-04]\n",
" [6.6356444e-01 2.3927736e-01 1.2042263e-01 3.8960028e-01]\n",
" [4.0644848e-01 3.1249046e-02 1.5729398e-02 4.7397730e-01]\n",
" [7.1385705e-01 6.5075237e-01 2.0070928e-01 8.8203180e-01]\n",
" [4.3854299e-01 5.6260496e-02 3.0884445e-03 6.0800165e-02]\n",
" [1.6601676e-01 6.6775084e-04 3.9201975e-04 3.1027198e-03]\n",
" [4.4480303e-01 8.1428647e-02 3.0222714e-02 4.7680125e-01]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MEiYaieSYFeh",
"outputId": "fe7cab95-8633-4bd0-fbba-8a6bc399fa07"
},
"source": [
"# precision@K / recall@K: only each user's top-K scored items count as predicted positives\n",
"precision_ncf = tf.keras.metrics.Precision(top_k=TOP_K)\n",
"recall_ncf = tf.keras.metrics.Recall(top_k=TOP_K)\n",
"\n",
"# score the NCF predictions against the held-out test interactions\n",
"precision_ncf.update_state(data[\"test\"], data[\"ncf_predictions\"])\n",
"recall_ncf.update_state(data[\"test\"], data[\"ncf_predictions\"])\n",
"print(\n",
"    f\"At K = {TOP_K}, we have a precision of {precision_ncf.result().numpy():.5f}, and a recall of {recall_ncf.result().numpy():.5f}\",\n",
")"
],
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"text": [
"At K = 5, we have a precision of 0.10838, and a recall of 0.06474\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0TRwpTFyZK1O"
},
"source": [
"## Comparison with LightFM (WARP loss) model"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YxFxJvZmYIBi",
"outputId": "f32bd42c-3da1-49d0-e3e5-aba730a305dd"
},
"source": [
"# LightFM model (WARP loss) trained on the same binary interactions, as a baseline\n",
"def norm(x: float) -> float:\n",
"    \"\"\"Normalize vector\"\"\"\n",
"    # Min-max scaling to [0, 1]; np.ptp(x) is max(x) - min(x).\n",
"    # NOTE(review): divides by zero if all predictions are identical -\n",
"    # assumed not to happen here; the NCF branch guards this with a std check.\n",
"    return (x - np.min(x)) / np.ptp(x)\n",
"\n",
"\n",
"lightfm_model = LightFM(loss=\"warp\")\n",
"lightfm_model.fit(sparse.coo_matrix(data[\"train\"]), epochs=N_EPOCHS)\n",
"\n",
"# score every (user, item) pair in the long test table\n",
"lightfm_predictions = lightfm_model.predict(\n",
"    df_test[\"user_id\"].values, df_test[\"item_id\"].values\n",
")\n",
"df_test[\"lightfm_predictions\"] = lightfm_predictions\n",
"# pivot back to a wide (users x items) matrix to match the metric inputs\n",
"wide_predictions = df_test.pivot(\n",
"    index=\"user_id\", columns=\"item_id\", values=\"lightfm_predictions\"\n",
").values\n",
"data[\"lightfm_predictions\"] = norm(wide_predictions)\n",
"\n",
"# compute the metrics (same precision@K / recall@K setup as for the NCF model)\n",
"precision_lightfm = tf.keras.metrics.Precision(top_k=TOP_K)\n",
"recall_lightfm = tf.keras.metrics.Recall(top_k=TOP_K)\n",
"precision_lightfm.update_state(data[\"test\"], data[\"lightfm_predictions\"])\n",
"recall_lightfm.update_state(data[\"test\"], data[\"lightfm_predictions\"])\n",
"print(\n",
"    f\"At K = {TOP_K}, we have a precision of {precision_lightfm.result().numpy():.5f}, and a recall of {recall_lightfm.result().numpy():.5f}\",\n",
")"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"text": [
"At K = 5, we have a precision of 0.10944, and a recall of 0.06537\n"
],
"name": "stdout"
}
]
}
]
}