{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-13-ae-ml.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/P508050%20%7C%20Autoencoder%20RecSys%20Models%20on%20ML-1m.ipynb","timestamp":1644611319207}],"collapsed_sections":[],"mount_file_id":"1XD0NS0Y4bQ3fES70_DavV6DXMqW6NaMS","authorship_tag":"ABX9TyMi9lp1Ic3GEuacq/CzDbFz"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"0bc4610585dc4b6e89d3bc4598293dc4":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_0e1c65f3d0a04bbfa4d53d8287faf7b0","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_707a58e866a9416eaf94f9aea50bb9f5","IPY_MODEL_35c412197fb6462b9d03d9fe8f65e3cb","IPY_MODEL_a89f0935d9c846ae92716177f9a3da31"]}},"0e1c65f3d0a04bbfa4d53d8287faf7b0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"707a58e866a9416eaf94f9aea50bb9f5":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_ae8bf7ee34674a0e9fe0b9aef3b8063d","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_eea7588021c24de29cd5018918139907"}},"35c412197fb6462b9d03d9fe8f65e3cb":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_b997a6a3013f4f649bbba23c10d1bbe3","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":1,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":1,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_c04be49d501e4f70a3ed182c6253a336"}},"a89f0935d9c846ae92716177f9a3da31":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_ee609560d3104683b742c9ed66e782d7","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 100/? [02:07<00:00, 1.26s/it]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_4e3a404483954fa3916ac4bbf4bd4929"}},"ae8bf7ee34674a0e9fe0b9aef3b8063d":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"eea7588021c24de29cd5018918139907":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"b997a6a3013f4f649bbba23c10d1bbe3":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"c04be49d501e4f70a3ed182c6253a336":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":"20px","min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"ee609560d3104683b742c9ed66e782d7":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"4e3a404483954fa3916ac4bbf4bd4929":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"87c9bc69718d40c0acd55be8b3d028c3":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_592a360bb8f74fd690223f2e4bb14f0e","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_291c4790efeb43b68b78eaef6a99ced7","IPY_MODEL_8d0ed0b8a3e94733aad8d23ae0265d3c","IPY_MODEL_4e554fb4f30a48faa6859d457d08ba1e"]}},"592a360bb8f74fd690223f2e4bb14f0e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"291c4790efeb43b68b78eaef6a99ced7":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_32727136c19e46558e6016eea4fa6fec","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_fa3e40aea0f14e4db3be250b51c8ede0"}},"8d0ed0b8a3e94733aad8d23ae0265d3c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_7a486090a89343168b6c82943865733c","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"success","max":1,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":1,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_57b0ddec4be6490b8da427622de4ebac"}},"4e554fb4f30a48faa6859d457d08ba1e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_2d1191aacad24eba93d354e26e6ee37b","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 100/? [03:04<00:00, 1.84s/it]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_9940ceebfb384b6da8e6d5deffcce3a4"}},"32727136c19e46558e6016eea4fa6fec":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"fa3e40aea0f14e4db3be250b51c8ede0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"7a486090a89343168b6c82943865733c":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color":null,"_model_module":"@jupyter-widgets/controls"}},"57b0ddec4be6490b8da427622de4ebac":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":"20px","min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"2d1191aacad24eba93d354e26e6ee37b":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"9940ceebfb384b6da8e6d5deffcce3a4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}}}}},"cells":[{"cell_type":"markdown","source":["# Autoencoder RecSys Models on ML-1m"],"metadata":{"id":"tT6lcgsxsYzT"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"MugCaZzrsrM9"}},{"cell_type":"code","metadata":{"id":"FwYoeMVyJvzL"},"source":["import numpy as np\n","import pandas as pd\n","\n","import seaborn as sns\n","import matplotlib.pyplot as plt\n","\n","import os,sys,inspect\n","import gc\n","from tqdm.notebook import tqdm\n","import random\n","import heapq\n","\n","from sklearn.preprocessing import LabelEncoder\n","from scipy.sparse import csr_matrix\n","\n","from tensorflow import keras\n","import tensorflow as tf\n","from tensorflow.keras import optimizers, callbacks, layers, losses\n","from tensorflow.keras.layers import Dense, Concatenate, Activation, Add, BatchNormalization, Dropout, Input, Embedding, Flatten, Multiply\n","from tensorflow.keras.models import Model, Sequential, load_model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"sB_OJdPrYZSK"},"source":["SEED = 42\n","np.random.seed(SEED)\n","tf.random.set_seed(SEED)\n","os.environ['PYTHONHASHSEED']=str(SEED)\n","random.seed(SEED)\n","gpus = tf.config.experimental.list_physical_devices('GPU')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0fxfIpXLTe-B"},"source":["if gpus:\n"," try:\n"," tf.config.experimental.set_memory_growth(gpus[0], True)\n"," except RuntimeError as e:\n"," print(e)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"EpIyG5zgtaai","executionInfo":{"status":"ok","timestamp":1639716148938,"user_tz":-330,"elapsed":28,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9d5c6806-8dde-45d4-a715-1639547315ab"},"source":["print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Num GPUs Available: 0\n"]}]},{"cell_type":"code","source":["!wget -q --show-progress https://files.grouplens.org/datasets/movielens/ml-1m.zip\n","!unzip ml-1m.zip"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"UlyLzNq1sooD","executionInfo":{"status":"ok","timestamp":1639716176022,"user_tz":-330,"elapsed":1132,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"bdabe7d6-81af-4772-c63e-7b466287e619"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["ml-1m.zip 100%[===================>] 5.64M 25.0MB/s in 0.2s \n","Archive: ml-1m.zip\n"," creating: ml-1m/\n"," inflating: ml-1m/movies.dat \n"," inflating: ml-1m/ratings.dat \n"," inflating: ml-1m/README \n"," inflating: ml-1m/users.dat \n"]}]},{"cell_type":"code","metadata":{"id":"4DbVJCtKPuBF"},"source":["def mish(x):\n"," return x*tf.math.tanh(tf.math.softplus(x))\n","\n","def leakyrelu(x, factor=0.2):\n"," return tf.maximum(x, factor*x)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XOfeaOEgLsGG"},"source":["def load_data(filepath, threshold=0):\n"," df = pd.read_csv(filepath,\n"," sep=\"::\",\n"," header=None,\n"," engine='python',\n"," names=['userId', 'movieId', 'rating', 'time'])\n"," df = df.drop('time', axis=1)\n"," df['userId'] = df['userId'].astype(int)\n"," df['movieId'] = df['movieId'].astype(int)\n"," df['rating'] = df['rating'].astype(float)\n"," \n"," df = df[['userId', 'movieId', 'rating']]\n"," if threshold > 0:\n"," df['rating'] = np.where(df['rating']>threshold, 1, 0) \n"," else:\n"," df['rating'] = 1.\n"," m_codes = df['movieId'].astype('category').cat.codes\n"," u_codes = df['userId'].astype('category').cat.codes\n"," df['movieId'] = m_codes\n"," df['userId'] = u_codes\n","\n"," return df\n"," \n","\n","def add_negative(df, uiid, times=4):\n"," df_ = df.copy()\n"," user_id = df_['userId'].unique()\n"," item_id = df_['movieId'].unique()\n"," \n"," for i in tqdm(user_id):\n"," cnt = 0\n"," n = len(df_[df_['userId']==i])\n"," n_negative = min(n*times, len(item_id)-n-1)\n"," available_negative = list(set(uiid) - set(df[df['userId']==i]['movieId'].values))\n"," \n"," new = np.random.choice(available_negative, n_negative, replace=False)\n"," new = [[i, j, 0] for j in new]\n"," df_ = df_.append(pd.DataFrame(new, columns=df.columns), ignore_index=True)\n"," \n"," return df_\n","\n","def extract_from_df(df, n_positive, n_negative):\n"," df_ = df.copy()\n"," rtd = []\n"," \n"," user_id = df['userId'].unique()\n"," \n"," for i in tqdm(user_id):\n"," rtd += list(np.random.choice(df[df['userId']==i][df['rating']==1]['movieId'].index, n_positive, replace=False))\n"," rtd += list(np.random.choice(df[df['userId']==i][df['rating']==0]['movieId'].index, n_negative, replace=False))\n"," \n"," return rtd"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"GFQjwqAPMH8C"},"source":["def eval_NDCG(true, pred):\n"," top_k = pred\n","\n"," for i, item in enumerate(top_k, 1):\n"," if item == true:\n"," return 1 / np.log2(i+1)\n"," return 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"S78F5a7AYgQz"},"source":["## CDAE"]},{"cell_type":"markdown","metadata":{"id":"1sAQYk5eYv1j"},"source":["### Load data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":215},"id":"rCg36MGFTue5","executionInfo":{"status":"ok","timestamp":1639716202713,"user_tz":-330,"elapsed":5930,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"d87757a9-77d6-4c1f-e057-5adab8f37e60"},"source":["df = load_data('./ml-1m/ratings.dat', threshold=3)\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","
\n","
\n","
\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 639 | \n"," 0 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 853 | \n"," 0 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n","
\n","
\n","
\n"," \n"," \n","\n"," \n","
\n","
\n"," "],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 639 0\n","2 0 853 0\n","3 0 3177 1\n","4 0 2162 1"]},"metadata":{},"execution_count":9}]},{"cell_type":"markdown","metadata":{"id":"uo-H-H5yZjSI"},"source":["### Preprocessing"]},{"cell_type":"code","metadata":{"id":"4_0IEsjWM4nI"},"source":["df = df[df['rating']==1].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","\n","cnt = tdf.sum(1)\n","df = df[df['userId'].isin(np.where(cnt >= 10)[0])].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","tdf.iloc[:,:] = 0\n","\n","test_idx = []\n","for i in tdf.index:\n"," test_idx += list(np.random.choice(df[df['userId']==i].index, 1))\n"," \n","train = df.loc[list(set(df.index)-set(test_idx)),:]\n","test = df.loc[test_idx, :]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":424},"id":"QCEA3JWYNgHT","executionInfo":{"status":"ok","timestamp":1639716243239,"user_tz":-330,"elapsed":49,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"3077ce6b-b907-4e0d-b2b2-94c6e019dc42"},"source":["df"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n"," \n","
\n","
\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 1195 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2599 | \n"," 1 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 570512 | \n"," 6037 | \n"," 346 | \n"," 1 | \n","
\n"," \n"," | 570513 | \n"," 6037 | \n"," 1120 | \n"," 1 | \n","
\n"," \n"," | 570514 | \n"," 6037 | \n"," 1133 | \n"," 1 | \n","
\n"," \n"," | 570515 | \n"," 6037 | \n"," 1204 | \n"," 1 | \n","
\n"," \n"," | 570516 | \n"," 6037 | \n"," 1007 | \n"," 1 | \n","
\n"," \n","
\n","
570517 rows × 3 columns
\n","
\n","
\n"," \n"," \n","\n"," \n","
\n","
\n"," "],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 3177 1\n","2 0 2162 1\n","3 0 1195 1\n","4 0 2599 1\n","... ... ... ...\n","570512 6037 346 1\n","570513 6037 1120 1\n","570514 6037 1133 1\n","570515 6037 1204 1\n","570516 6037 1007 1\n","\n","[570517 rows x 3 columns]"]},"metadata":{},"execution_count":11}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Ov0ZmaZabiVe","executionInfo":{"status":"ok","timestamp":1639716249445,"user_tz":-330,"elapsed":425,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f4c9d415-cd86-4274-8d0e-dca78697f41d"},"source":["df.shape, train.shape, test.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((570517, 3), (564569, 3), (5948, 3))"]},"metadata":{},"execution_count":12}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":470},"id":"AyFVjZiUNtBV","executionInfo":{"status":"ok","timestamp":1630835048249,"user_tz":-330,"elapsed":51728,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"e9cbcb22-0779-4089-a265-085b72948158"},"source":["for uid, iid in zip(train['userId'].values, train['movieId'].values):\n"," tdf.loc[uid, iid] = 1\n","train = tdf.copy()\n","train"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | movieId | \n"," 0 | \n"," 1 | \n"," 2 | \n"," 3 | \n"," 4 | \n"," 5 | \n"," 6 | \n"," 7 | \n"," 8 | \n"," 9 | \n"," 10 | \n"," 11 | \n"," 12 | \n"," 13 | \n"," 14 | \n"," 15 | \n"," 16 | \n"," 17 | \n"," 18 | \n"," 19 | \n"," 20 | \n"," 21 | \n"," 22 | \n"," 23 | \n"," 24 | \n"," 25 | \n"," 26 | \n"," 27 | \n"," 28 | \n"," 29 | \n"," 30 | \n"," 31 | \n"," 32 | \n"," 33 | \n"," 34 | \n"," 35 | \n"," 36 | \n"," 37 | \n"," 38 | \n"," 39 | \n"," ... | \n"," 3666 | \n"," 3667 | \n"," 3668 | \n"," 3669 | \n"," 3670 | \n"," 3671 | \n"," 3672 | \n"," 3673 | \n"," 3674 | \n"," 3675 | \n"," 3676 | \n"," 3677 | \n"," 3678 | \n"," 3679 | \n"," 3680 | \n"," 3681 | \n"," 3682 | \n"," 3683 | \n"," 3684 | \n"," 3685 | \n"," 3686 | \n"," 3687 | \n"," 3688 | \n"," 3689 | \n"," 3690 | \n"," 3691 | \n"," 3692 | \n"," 3693 | \n"," 3694 | \n"," 3695 | \n"," 3696 | \n"," 3697 | \n"," 3698 | \n"," 3699 | \n"," 3700 | \n"," 3701 | \n"," 3702 | \n"," 3703 | \n"," 3704 | \n"," 3705 | \n","
\n"," \n"," | userId | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 1 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 2 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 3 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 4 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 6033 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6034 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6035 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6036 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6037 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n","
\n","
5948 rows × 3533 columns
\n","
"],"text/plain":["movieId 0 1 2 3 4 5 ... 3700 3701 3702 3703 3704 3705\n","userId ... \n","0 1.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","1 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","2 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","3 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","4 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","... ... ... ... ... ... ... ... ... ... ... ... ... ...\n","6033 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6034 1.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6035 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6036 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6037 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","\n","[5948 rows x 3533 columns]"]},"metadata":{},"execution_count":15}]},{"cell_type":"markdown","metadata":{"id":"_br-gyTNd40x"},"source":["### Model architecture"]},{"cell_type":"code","metadata":{"id":"UK-_v59OYjHz"},"source":["class CDAE(tf.keras.models.Model):\n"," def __init__(self, input_dim, latent_dim, n_user, lamda=1e-4):\n"," super().__init__()\n"," self.input_dim = input_dim\n"," self.latent_dim = latent_dim\n"," self.lamda = lamda\n"," self.n_user = n_user\n"," self.embedding = Embedding(n_user, latent_dim, ) \n","\n"," self.model = self.build()\n","\n"," def compile(self, optimizer, loss_fn=None):\n"," super().compile()\n"," self.optimizer = optimizer\n"," self.loss_fn = loss_fn\n","\n"," \n"," def build(self):\n"," self.encoder = self.build_encoder()\n"," self.decoder = self.build_decoder()\n"," \n"," rating = Input(shape=(self.input_dim, ), name='rating_input')\n"," user_id = Input(shape=(1, ), name='user_input')\n"," \n"," emb = self.embedding(user_id)\n"," emb = tf.squeeze(emb, 1)\n"," enc = self.encoder(rating) + emb\n"," enc = tf.nn.tanh(enc)\n"," outputs = self.decoder(enc)\n"," \n"," return Model([rating, user_id], outputs)\n"," \n"," def build_encoder(self):\n"," inputs = Input(shape = (self.input_dim, ))\n"," \n"," encoder = Sequential()\n"," encoder.add(Dropout(0.2))\n"," encoder.add(Dense(self.latent_dim, activation='tanh'))\n"," \n"," outputs = encoder(inputs)\n"," \n"," return Model(inputs, outputs)\n"," \n"," def build_decoder(self):\n"," inputs = Input(shape = (self.latent_dim, ))\n"," \n"," encoder = Sequential()\n"," encoder.add(Dense(self.input_dim, activation='sigmoid'))\n"," \n"," outputs = encoder(inputs)\n"," \n"," return Model(inputs, outputs)\n"," \n"," def train_step(self, data):\n"," x = data['rating']\n"," user_ids = data['id']\n"," with tf.GradientTape() as tape:\n"," pred = self.model([x, user_ids])\n"," \n"," rec_loss = tf.losses.binary_crossentropy(x, pred)\n"," loss = rec_loss\n","\n"," grads = tape.gradient(loss, self.model.trainable_weights)\n"," self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n"," \n"," return {'loss': loss}"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-R-qDHX_dD3T"},"source":["### Training"]},{"cell_type":"code","metadata":{"id":"Fjtaq3RkcIO6","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630835242291,"user_tz":-330,"elapsed":194051,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"a69c188d-f725-4553-ee6d-d6df3484f19d"},"source":["loader = tf.data.Dataset.from_tensor_slices({'rating': train.values, 'id': np.arange(len(train))})\n","loader = loader.batch(32, drop_remainder=True).shuffle(len(train))\n","model = CDAE(train.shape[1], 200, len(train))\n","model.compile(optimizer=tf.optimizers.Adam())\n","model.fit(loader, epochs=25)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/25\n","185/185 [==============================] - 7s 31ms/step - loss: 0.1558\n","Epoch 2/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.1036\n","Epoch 3/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.1007\n","Epoch 4/25\n","185/185 [==============================] - 6s 29ms/step - loss: 0.0972\n","Epoch 5/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0922\n","Epoch 6/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0884\n","Epoch 7/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0857\n","Epoch 8/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0833\n","Epoch 9/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0810\n","Epoch 10/25\n","185/185 [==============================] - 5s 29ms/step - loss: 0.0793\n","Epoch 11/25\n","185/185 [==============================] - 6s 29ms/step - loss: 0.0776\n","Epoch 12/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0763\n","Epoch 13/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0747\n","Epoch 14/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0735\n","Epoch 15/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0722\n","Epoch 16/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0709\n","Epoch 17/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0699\n","Epoch 18/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0687\n","Epoch 19/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0678\n","Epoch 20/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0668\n","Epoch 21/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0660\n","Epoch 22/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0649\n","Epoch 23/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0640\n","Epoch 24/25\n","185/185 [==============================] - 6s 30ms/step - loss: 0.0633\n","Epoch 25/25\n","185/185 [==============================] - 6s 31ms/step - loss: 0.0624\n"]},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":17}]},{"cell_type":"markdown","metadata":{"id":"NfDBp4XndBDM"},"source":["### Evaluation"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":66,"referenced_widgets":["0bc4610585dc4b6e89d3bc4598293dc4","0e1c65f3d0a04bbfa4d53d8287faf7b0","707a58e866a9416eaf94f9aea50bb9f5","35c412197fb6462b9d03d9fe8f65e3cb","a89f0935d9c846ae92716177f9a3da31","ae8bf7ee34674a0e9fe0b9aef3b8063d","eea7588021c24de29cd5018918139907","b997a6a3013f4f649bbba23c10d1bbe3","c04be49d501e4f70a3ed182c6253a336","ee609560d3104683b742c9ed66e782d7","4e3a404483954fa3916ac4bbf4bd4929"]},"id":"4vR5WPgGOwbn","executionInfo":{"status":"ok","timestamp":1630835457454,"user_tz":-330,"elapsed":128721,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"6ad93a00-03d5-4158-bb59-0babb6fac1b5"},"source":["top_k = 10\n","np.random.seed(42)\n","\n","scores = []\n","for idx, i in tqdm(enumerate(np.random.choice(train.index, 100))):\n"," item_to_pred = {item: pred for item, pred in zip(train.columns, model.model.predict([train.values, np.arange(len(train))])[idx])}\n"," test_ = test[(test['userId']==i) & (test['rating']==1)]['movieId'].values\n"," items = list(np.random.choice(list(filter(lambda x: x not in np.argwhere(train.values[idx]).flatten(), item_to_pred.keys())), 100)) + list(test_)\n"," top_k_items = heapq.nlargest(top_k, items, key=item_to_pred.get)\n"," \n"," score = eval_NDCG(test_, top_k_items)\n"," scores.append(score)\n"," \n","np.mean(scores)"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"0bc4610585dc4b6e89d3bc4598293dc4","version_minor":0,"version_major":2},"text/plain":["0it [00:00, ?it/s]"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["0.2810483955912976"]},"metadata":{},"execution_count":21}]},{"cell_type":"markdown","metadata":{"id":"cqBW2VetZvne"},"source":["## EASE"]},{"cell_type":"markdown","metadata":{"id":"R-XpG8cMZvnf"},"source":["### Load data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":204},"id":"D_oXGFDhZvnf","executionInfo":{"status":"ok","timestamp":1630834671821,"user_tz":-330,"elapsed":5997,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"66c7d829-c8f3-4e7b-f03c-1c832040c3f5"},"source":["df = load_data('./ml-1m/ratings.dat', threshold=3)\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 639 | \n"," 0 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 853 | \n"," 0 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n","
\n","
"],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 639 0\n","2 0 853 0\n","3 0 3177 1\n","4 0 2162 1"]},"metadata":{},"execution_count":13}]},{"cell_type":"markdown","metadata":{"id":"OfHD5JjfZvnh"},"source":["### Preprocessing"]},{"cell_type":"code","metadata":{"id":"KRzrwrHpZvnh"},"source":["test_idx = []\n","user_id = df\n","for i in df['userId'].unique():\n"," test_idx += list(np.random.choice(df[df['userId']==i].index, 1))\n"," \n","train = df.iloc[list(set(df.index)-set(test_idx)),:]\n","test = df.iloc[test_idx, :]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5hbeGQeyZvni","executionInfo":{"status":"ok","timestamp":1630834902542,"user_tz":-330,"elapsed":16,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"19ad5720-8a92-47bc-b018-b360fe911342"},"source":["df.shape, train.shape, test.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((570517, 3), (564569, 3), (5948, 3))"]},"metadata":{},"execution_count":25}]},{"cell_type":"markdown","metadata":{"id":"REsycBTbZvni"},"source":["### Model architecture"]},{"cell_type":"code","metadata":{"id":"N8qLutMpZvnj"},"source":["class EASE:\n"," def __init__(self):\n"," self.user_enc = LabelEncoder()\n"," self.item_enc = LabelEncoder()\n","\n"," def _get_users_and_items(self, df):\n"," users = self.user_enc.fit_transform(df.loc[:, 'userId'])\n"," items = self.item_enc.fit_transform(df.loc[:, 'movieId'])\n"," return users, items\n","\n"," def fit(self, df, lambda_: float = 0.5, implicit=True):\n"," \"\"\"\n"," df: pandas.DataFrame with columns user_id, item_id and (rating)\n"," lambda_: l2-regularization term\n"," implicit: if True, ratings are ignored and taken as 1, else normalized ratings are used\n"," \"\"\"\n"," users, items = self._get_users_and_items(df)\n"," values = np.ones(df.shape[0]) if implicit else df['rating'].to_numpy() / df['rating'].max()\n","\n"," X = csr_matrix((values, (users, items)))\n"," self.X = X\n","\n"," G = X.T.dot(X).toarray()\n"," diagIndices = np.diag_indices(G.shape[0])\n"," G[diagIndices] += lambda_\n"," P = np.linalg.inv(G)\n"," B = P / (-np.diag(P))\n"," B[diagIndices] = 0\n","\n"," self.B = B\n"," self.pred = X.dot(B)\n","\n"," def predict(self, train, users, items, k):\n"," df = pd.DataFrame()\n"," items = self.item_enc.transform(items)\n"," dd = train.loc[train['userId'].isin(users)]\n"," dd['ci'] = self.item_enc.transform(dd['movieId'])\n"," dd['cu'] = self.user_enc.transform(dd['userId'])\n"," g = dd.groupby('userId')\n"," for user, group in tqdm(g):\n"," watched = set(group['ci'])\n"," candidates = [item for item in items if item not in watched]\n"," u = group['cu'].iloc[0]\n"," pred = np.take(self.pred[u, :], candidates)\n"," res = np.argpartition(pred, -k)[-k:]\n"," r = pd.DataFrame({\n"," \"userId\": [user] * len(res),\n"," \"movieId\": np.take(candidates, res),\n"," \"score\": np.take(pred, res)\n"," }).sort_values('score', ascending=False)\n"," df = df.append(r, ignore_index=True)\n"," df['movieId'] = self.item_enc.inverse_transform(df['movieId'])\n"," return df"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"KwfcFM2yZvnj"},"source":["### Training"]},{"cell_type":"code","metadata":{"id":"zEzkqTbWVCzD"},"source":["ease = EASE()\n","ease.fit(train)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5I5877tBU7yB","executionInfo":{"status":"ok","timestamp":1630834999318,"user_tz":-330,"elapsed":644,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"dcb8117a-e70e-42b8-d713-4da7dd8cc4e4"},"source":["uid = 0\n","ease.user_enc.inverse_transform([0])[0]"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["0"]},"metadata":{},"execution_count":31}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZdpFQsXgV1Mp","executionInfo":{"status":"ok","timestamp":1630835007002,"user_tz":-330,"elapsed":551,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"80c3d15c-99a2-4603-b91f-f5c7bb2045f7"},"source":["ease.item_enc.inverse_transform(np.argsort(ease.pred[0]))"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([2845, 1204, 1502, ..., 957, 574, 581], dtype=int16)"]},"metadata":{},"execution_count":32}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"d9fCVWvOV2sK","executionInfo":{"status":"ok","timestamp":1630835012996,"user_tz":-330,"elapsed":785,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"8b52e238-bc6f-47d4-ce62-2e5602f4c4f7"},"source":["np.argsort(-ease.pred[0])"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([ 565, 559, 903, ..., 1420, 1138, 2716])"]},"metadata":{},"execution_count":33}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"j-arIMZHV4e5","executionInfo":{"status":"ok","timestamp":1630835020095,"user_tz":-330,"elapsed":508,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"257943c7-262d-4aee-99fa-a97e53ec23fe"},"source":["ease.pred[0][np.argsort(-ease.pred[0])]"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([ 0.88894744, 0.86783598, 0.76730558, ..., -0.26904345,\n"," -0.29024257, -0.29286189])"]},"metadata":{},"execution_count":34}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"GTdoYucRU30u","executionInfo":{"status":"ok","timestamp":1630835023395,"user_tz":-330,"elapsed":738,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"7a35e7f5-8f17-4200-b8bb-a18c5cf8b4d9"},"source":["np.unique(train[train['userId']==0]['movieId'])"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([ 0, 47, 144, 253, 513, 517, 574, 580, 581, 593, 740,\n"," 858, 877, 957, 963, 964, 970, 1025, 1104, 1117, 1154, 1178,\n"," 1195, 1421, 1439, 1574, 1658, 1727, 1781, 1782, 1838, 1848, 2102,\n"," 2162, 2205, 2488, 2557, 2586, 2592, 2599, 2710, 2889, 2969, 3177],\n"," dtype=int16)"]},"metadata":{},"execution_count":35}]},{"cell_type":"markdown","metadata":{"id":"pV3bdkthZvnn"},"source":["### Evaluation"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":451,"referenced_widgets":["2f741acf5a964291af3d1314627871e7","508e94b448894be08342221ff5ca5515","4de8e9ae4b8e4e4c954bec8acd09d992","c5882e7ec03a4e978341e8406e7ec604","23ef87acab4e46bf9eb6e11870fbb6fc","ee3e841f043d404fa1a8a8409cc3f421","386ab749cd954d36a9c7ecebccf5f2df","eacf302707954850878334ff6a5c1fae","ad83565fcfd34dfc824ab4c16ad6cd7a","f9ef7106318146e89ecbcde0ad71bcd7","6c3ee40cc767411d98efd6d4af1ca2eb"]},"id":"01b3WkGrV9Nk","executionInfo":{"status":"ok","timestamp":1630835080341,"user_tz":-330,"elapsed":40605,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"b6635be8-427b-4694-84bf-90d422c9d30e"},"source":["pred = ease.predict(train, train['userId'].unique(), train['movieId'].unique(), 100)\n","pred"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"2f741acf5a964291af3d1314627871e7","version_minor":0,"version_major":2},"text/plain":[" 0%| | 0/5947 [00:00, ?it/s]"]},"metadata":{}},{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," score | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 354 | \n"," 0.659144 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 2511 | \n"," 0.420217 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 2058 | \n"," 0.397685 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 853 | \n"," 0.382166 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 892 | \n"," 0.325232 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 594695 | \n"," 6037 | \n"," 1729 | \n"," 0.105471 | \n","
\n"," \n"," | 594696 | \n"," 6037 | \n"," 1978 | \n"," 0.104400 | \n","
\n"," \n"," | 594697 | \n"," 6037 | \n"," 1172 | \n"," 0.104144 | \n","
\n"," \n"," | 594698 | \n"," 6037 | \n"," 27 | \n"," 0.103118 | \n","
\n"," \n"," | 594699 | \n"," 6037 | \n"," 2128 | \n"," 0.102213 | \n","
\n"," \n","
\n","
594700 rows × 3 columns
\n","
"],"text/plain":[" userId movieId score\n","0 0 354 0.659144\n","1 0 2511 0.420217\n","2 0 2058 0.397685\n","3 0 853 0.382166\n","4 0 892 0.325232\n","... ... ... ...\n","594695 6037 1729 0.105471\n","594696 6037 1978 0.104400\n","594697 6037 1172 0.104144\n","594698 6037 27 0.103118\n","594699 6037 2128 0.102213\n","\n","[594700 rows x 3 columns]"]},"metadata":{},"execution_count":36}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":49},"id":"ggRtLpiNV_c2","executionInfo":{"status":"ok","timestamp":1630835091546,"user_tz":-330,"elapsed":588,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"744a2a4c-d13f-4211-eafa-1dea1e4671ee"},"source":["uid = 1\n","df[(df['userId']==uid) & (df['movieId'].isin(pred[pred['userId']==uid]['movieId']))]"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n","
\n","
"],"text/plain":["Empty DataFrame\n","Columns: [userId, movieId, rating]\n","Index: []"]},"metadata":{},"execution_count":37}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":49},"id":"QPkVz-gnWCQW","executionInfo":{"status":"ok","timestamp":1630835091967,"user_tz":-330,"elapsed":13,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"def7da36-b6ea-4131-a7ac-bbf81ae15a3a"},"source":["train[(train['userId']==uid) & (train['movieId'].isin(pred[pred['userId']==uid]['movieId']))]"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n","
\n","
"],"text/plain":["Empty DataFrame\n","Columns: [userId, movieId, rating]\n","Index: []"]},"metadata":{},"execution_count":38}]},{"cell_type":"code","metadata":{"id":"UcilNIv2WD_n"},"source":["for uid in range(942):\n"," pdf = df[(df['userId']==uid) & (df['movieId'].isin(pred[pred['userId']==uid]['movieId']))]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"8WvaXEwZWFTu","executionInfo":{"status":"ok","timestamp":1630835114633,"user_tz":-330,"elapsed":726,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"3eb88b98-6150-41d2-fad5-c97043016ef2"},"source":["ease.pred.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(5947, 3532)"]},"metadata":{},"execution_count":42}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WudrTlyWZvnq","executionInfo":{"status":"ok","timestamp":1630835116746,"user_tz":-330,"elapsed":11,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"528edaa4-2162-45fe-e108-41d0c5de33bb"},"source":["train['userId'].unique().shape, train['movieId'].unique().shape, "],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((5947,), (3532,))"]},"metadata":{},"execution_count":43}]},{"cell_type":"markdown","metadata":{"id":"dQUI23C-WWWO"},"source":["## MultiVAE"]},{"cell_type":"markdown","metadata":{"id":"hA2o80xvWlah"},"source":["### Load data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":204},"id":"7NKMTvN_Wlai","executionInfo":{"status":"ok","timestamp":1630835218743,"user_tz":-330,"elapsed":6555,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"a5b11636-ea6a-458e-e482-49b72a12d954"},"source":["df = load_data('./ml-1m/ratings.dat', threshold=3)\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 639 | \n"," 0 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 853 | \n"," 0 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n","
\n","
"],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 639 0\n","2 0 853 0\n","3 0 3177 1\n","4 0 2162 1"]},"metadata":{},"execution_count":44}]},{"cell_type":"markdown","metadata":{"id":"vm3oUHOcWlaj"},"source":["### Preprocessing"]},{"cell_type":"code","metadata":{"id":"nxvgVlESWlaj"},"source":["df = df[df['rating']==1].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","\n","cnt = tdf.sum(1)\n","df = df[df['userId'].isin(np.where(cnt >= 10)[0])].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","tdf.iloc[:,:] = 0\n","test_idx = []\n","\n","for i in tdf.index:\n"," test_idx += list(np.random.choice(df[df['userId']==i].index, 1))\n"," \n","train = df.iloc[list(set(df.index)-set(test_idx)),:]\n","test = df.iloc[test_idx, :]\n","\n","for uid, iid in zip(train['userId'].values, train['movieId'].values):\n"," tdf.loc[uid, iid] = 1\n","\n","train = tdf.copy()\n","\n","def sampling(args):\n"," z_mean, z_log_var = args\n"," batch = tf.shape(z_mean)[0]\n"," dim = tf.shape(z_mean)[1]\n"," epsilon = tf.random.normal(shape=(batch, dim), stddev=0.01)\n"," return z_mean + tf.exp(0.5 * z_log_var) * epsilon"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Lr6hS321Wlak","executionInfo":{"status":"ok","timestamp":1630835304957,"user_tz":-330,"elapsed":20,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"7e9aa5fa-d0ee-47c7-c909-daca50c8b412"},"source":["df.shape, train.shape, test.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((570517, 3), (5948, 3533), (5948, 3))"]},"metadata":{},"execution_count":46}]},{"cell_type":"markdown","metadata":{"id":"Q3XcxcpoWlak"},"source":["### Model architecture"]},{"cell_type":"code","metadata":{"id":"25ZXRGkfXLjd"},"source":["class MultVAE(tf.keras.models.Model):\n"," def __init__(self, input_dim, latent_dim, lamda=1e-4):\n"," super().__init__()\n"," self.input_dim = input_dim\n"," self.latent_dim = latent_dim\n"," self.anneal = 0.\n"," \n"," self.model = self.build()\n","\n"," def compile(self, optimizer, loss_fn=None):\n"," super().compile()\n"," self.optimizer = optimizer\n"," self.loss_fn = loss_fn\n","\n"," \n"," def build(self):\n"," self.encoder = self.build_encoder()\n"," self.decoder = self.build_decoder()\n"," \n"," inputs = self.encoder.input\n"," \n"," mu, log_var = self.encoder(inputs)\n"," h = sampling([mu, log_var])\n"," \n"," outputs = self.decoder(h)\n"," \n"," return Model(inputs, outputs)\n"," \n"," def build_encoder(self):\n"," inputs = Input(shape = (self.input_dim, ))\n"," h = Dropout(0.2)(inputs)\n"," \n"," mu = Dense(self.latent_dim)(h)\n"," log_var = Dense(self.latent_dim)(h)\n"," \n"," return Model(inputs, [mu, log_var])\n"," \n"," def build_decoder(self):\n"," inputs = Input(shape = (self.latent_dim, ))\n"," \n"," outputs = Dense(self.input_dim, activation='sigmoid')(inputs)\n","\n"," return Model(inputs, outputs)\n"," \n"," def train_step(self, data):\n"," x = data\n"," with tf.GradientTape() as tape:\n"," mu, log_var = self.encoder(x)\n"," pred = self.model(x)\n"," \n"," kl_loss = tf.reduce_mean(tf.reduce_sum(0.5*(log_var + tf.exp(log_var) + tf.pow(mu, 2)-1), 1, keepdims=True))\n"," ce_loss = -tf.reduce_mean(tf.reduce_sum(tf.nn.log_softmax(pred) * x, -1))\n"," \n"," loss = ce_loss + kl_loss*self.anneal\n"," \n"," grads = tape.gradient(loss, self.model.trainable_weights)\n"," self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n"," \n"," return {'loss': loss}\n"," \n"," def predict(self, data):\n"," mu, log_var = self.encoder(data)\n"," return self.decoder(mu)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zh7u-CfiXNwo"},"source":["### Training"]},{"cell_type":"code","metadata":{"id":"OMgAqDDAXQ0s"},"source":["loader = tf.data.Dataset.from_tensor_slices(train.values.astype(np.float32))\n","loader = loader.batch(8, drop_remainder=True).shuffle(len(train))\n","\n","model = MultVAE(train.shape[1], 200)\n","model.compile(optimizer=tf.optimizers.Adam())"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"rRtvceiYXXmX"},"source":["class AnnealCallback(callbacks.Callback):\n"," def __init__(self):\n"," super().__init__()\n"," self.anneal_cap = 0.3\n"," \n"," def on_train_batch_end(self, batch, logs=None):\n"," self.model.anneal = min(self.anneal_cap, self.model.anneal+1e-4)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vPj4AythXUo6","executionInfo":{"status":"ok","timestamp":1630835847202,"user_tz":-330,"elapsed":429687,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"23552135-8011-439d-a61e-22d007fb7f44"},"source":["model.fit(loader, epochs=25, callbacks=[AnnealCallback()])"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/25\n","743/743 [==============================] - 13s 16ms/step - loss: 730.4190\n","Epoch 2/25\n","743/743 [==============================] - 12s 16ms/step - loss: 722.0552\n","Epoch 3/25\n","743/743 [==============================] - 12s 16ms/step - loss: 719.0447\n","Epoch 4/25\n","743/743 [==============================] - 12s 16ms/step - loss: 717.2124\n","Epoch 5/25\n","743/743 [==============================] - 12s 17ms/step - loss: 715.4193\n","Epoch 6/25\n","743/743 [==============================] - 12s 16ms/step - loss: 713.9064\n","Epoch 7/25\n","743/743 [==============================] - 12s 17ms/step - loss: 711.3033\n","Epoch 8/25\n","743/743 [==============================] - 12s 16ms/step - loss: 709.6291\n","Epoch 9/25\n","743/743 [==============================] - 12s 16ms/step - loss: 707.6846\n","Epoch 10/25\n","743/743 [==============================] - 12s 16ms/step - loss: 707.0621\n","Epoch 11/25\n","743/743 [==============================] - 12s 16ms/step - loss: 705.8719\n","Epoch 12/25\n","743/743 [==============================] - 12s 16ms/step - loss: 704.4416\n","Epoch 13/25\n","743/743 [==============================] - 12s 16ms/step - loss: 703.5001\n","Epoch 14/25\n","743/743 [==============================] - 12s 16ms/step - loss: 703.4260\n","Epoch 15/25\n","743/743 [==============================] - 12s 16ms/step - loss: 703.5530\n","Epoch 16/25\n","743/743 [==============================] - 12s 16ms/step - loss: 701.2676\n","Epoch 17/25\n","743/743 [==============================] - 12s 16ms/step - loss: 700.5692\n","Epoch 18/25\n","743/743 [==============================] - 12s 16ms/step - loss: 700.5253\n","Epoch 19/25\n","743/743 [==============================] - 12s 16ms/step - loss: 699.8253\n","Epoch 20/25\n","743/743 [==============================] - 12s 16ms/step - loss: 700.0319\n","Epoch 21/25\n","743/743 [==============================] - 12s 16ms/step - loss: 699.0198\n","Epoch 22/25\n","743/743 [==============================] - 12s 16ms/step - loss: 699.0835\n","Epoch 23/25\n","743/743 [==============================] - 12s 16ms/step - loss: 698.7805\n","Epoch 24/25\n","743/743 [==============================] - 12s 16ms/step - loss: 698.1454\n","Epoch 25/25\n","743/743 [==============================] - 12s 16ms/step - loss: 698.6210\n"]},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":49}]},{"cell_type":"markdown","metadata":{"id":"OutR1NzaXZ4p"},"source":["### Evaluation"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":66,"referenced_widgets":["356fbd553cd74283bd40f67e43b24e2f","4e71c9f1b0a74ffbae2c04f977efd6a6","2b4e62ee18cf43a59467818ebb76289e","39515edfa1324e2e8b3ded22fd6f1384","a7a16d4486fb4a9ab1ff40b1b61a41a4","e2d3d44b598242ed84c780fab3cea768","0f589d8eceff4716a2029420f1da243c","c0dc1db826314529aadd1884683d4eda","59269807a50b4aecac1472fdece6a0be","ef2b338adc7140a8856263406b28c8d3","388c8dca4a624437a32b04844a53e84d"]},"id":"EGuXlAcpXtCA","executionInfo":{"status":"ok","timestamp":1630835990749,"user_tz":-330,"elapsed":143578,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"9dc168b8-fc8d-46f8-ff4f-3aef4f32cc5b"},"source":["top_k = 10\n","np.random.seed(42)\n","\n","scores = []\n","for idx, i in tqdm(enumerate(np.random.choice(train.index, 100))):\n"," item_to_pred = {item: pred for item, pred in zip(train.columns, model.model.predict(train.values)[idx])}\n"," test_ = test[(test['userId']==i) & (test['rating']==1)]['movieId'].values\n"," items = list(np.random.choice(list(filter(lambda x: x not in np.argwhere(train.values[idx]).flatten(), item_to_pred.keys())), 100)) + list(test_)\n"," top_k_items = heapq.nlargest(top_k, items, key=item_to_pred.get)\n"," \n"," score = eval_NDCG(test_, top_k_items)\n"," scores.append(score)\n"," \n","np.mean(scores)"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"356fbd553cd74283bd40f67e43b24e2f","version_minor":0,"version_major":2},"text/plain":["0it [00:00, ?it/s]"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["0.22758533052845625"]},"metadata":{},"execution_count":50}]},{"cell_type":"markdown","metadata":{"id":"gyrB_wB1aXGf"},"source":["## DAE"]},{"cell_type":"markdown","metadata":{"id":"p27z1a9RaXGg"},"source":["### Load data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":204},"id":"FUCzohpraXGh","executionInfo":{"status":"ok","timestamp":1630833675434,"user_tz":-330,"elapsed":5144,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"28528ff8-0aec-41dd-8355-823359e5f830"},"source":["df = load_data('./ml-1m/ratings.dat', threshold=3)\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 639 | \n"," 0 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 853 | \n"," 0 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n","
\n","
"],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 639 0\n","2 0 853 0\n","3 0 3177 1\n","4 0 2162 1"]},"metadata":{},"execution_count":13}]},{"cell_type":"markdown","metadata":{"id":"VYGBPKuKaXGj"},"source":["### Preprocessing"]},{"cell_type":"code","metadata":{"id":"95emN_MmaXGk"},"source":["df = df[df['rating']==1].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","\n","cnt = tdf.sum(1)\n","df = df[df['userId'].isin(np.where(cnt >= 10)[0])].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","tdf.iloc[:,:] = 0\n","\n","test_idx = []\n","for i in tdf.index:\n"," test_idx += list(np.random.choice(df[df['userId']==i].index, 1))\n"," \n","train = df.loc[list(set(df.index)-set(test_idx)),:]\n","test = df.loc[test_idx, :]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":419},"id":"3Qjyrw5SaXGl","executionInfo":{"status":"ok","timestamp":1630833683218,"user_tz":-330,"elapsed":31,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"6286da55-c46b-42d7-de7e-bcfda08b3227"},"source":["df"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 1195 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2599 | \n"," 1 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 570512 | \n"," 6037 | \n"," 346 | \n"," 1 | \n","
\n"," \n"," | 570513 | \n"," 6037 | \n"," 1120 | \n"," 1 | \n","
\n"," \n"," | 570514 | \n"," 6037 | \n"," 1133 | \n"," 1 | \n","
\n"," \n"," | 570515 | \n"," 6037 | \n"," 1204 | \n"," 1 | \n","
\n"," \n"," | 570516 | \n"," 6037 | \n"," 1007 | \n"," 1 | \n","
\n"," \n","
\n","
570517 rows × 3 columns
\n","
"],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 3177 1\n","2 0 2162 1\n","3 0 1195 1\n","4 0 2599 1\n","... ... ... ...\n","570512 6037 346 1\n","570513 6037 1120 1\n","570514 6037 1133 1\n","570515 6037 1204 1\n","570516 6037 1007 1\n","\n","[570517 rows x 3 columns]"]},"metadata":{},"execution_count":15}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tFJKSzzfaXGm","executionInfo":{"status":"ok","timestamp":1630833683220,"user_tz":-330,"elapsed":27,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"bc4ae887-d201-46be-be65-6fab4a39d805"},"source":["df.shape, train.shape, test.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((570517, 3), (564569, 3), (5948, 3))"]},"metadata":{},"execution_count":16}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":470},"id":"24m_wWovaXGo","executionInfo":{"status":"ok","timestamp":1630833732782,"user_tz":-330,"elapsed":49579,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"0bc3342b-926b-41be-c4f6-19764a138450"},"source":["for uid, iid in zip(train['userId'].values, train['movieId'].values):\n"," tdf.loc[uid, iid] = 1\n","train = tdf.copy()\n","train"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | movieId | \n"," 0 | \n"," 1 | \n"," 2 | \n"," 3 | \n"," 4 | \n"," 5 | \n"," 6 | \n"," 7 | \n"," 8 | \n"," 9 | \n"," 10 | \n"," 11 | \n"," 12 | \n"," 13 | \n"," 14 | \n"," 15 | \n"," 16 | \n"," 17 | \n"," 18 | \n"," 19 | \n"," 20 | \n"," 21 | \n"," 22 | \n"," 23 | \n"," 24 | \n"," 25 | \n"," 26 | \n"," 27 | \n"," 28 | \n"," 29 | \n"," 30 | \n"," 31 | \n"," 32 | \n"," 33 | \n"," 34 | \n"," 35 | \n"," 36 | \n"," 37 | \n"," 38 | \n"," 39 | \n"," ... | \n"," 3666 | \n"," 3667 | \n"," 3668 | \n"," 3669 | \n"," 3670 | \n"," 3671 | \n"," 3672 | \n"," 3673 | \n"," 3674 | \n"," 3675 | \n"," 3676 | \n"," 3677 | \n"," 3678 | \n"," 3679 | \n"," 3680 | \n"," 3681 | \n"," 3682 | \n"," 3683 | \n"," 3684 | \n"," 3685 | \n"," 3686 | \n"," 3687 | \n"," 3688 | \n"," 3689 | \n"," 3690 | \n"," 3691 | \n"," 3692 | \n"," 3693 | \n"," 3694 | \n"," 3695 | \n"," 3696 | \n"," 3697 | \n"," 3698 | \n"," 3699 | \n"," 3700 | \n"," 3701 | \n"," 3702 | \n"," 3703 | \n"," 3704 | \n"," 3705 | \n","
\n"," \n"," | userId | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n"," | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 1 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 2 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 3 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 4 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 6033 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6034 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6035 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6036 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 1.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n"," | 6037 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," ... | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n"," 0.0 | \n","
\n"," \n","
\n","
5948 rows × 3533 columns
\n","
"],"text/plain":["movieId 0 1 2 3 4 5 ... 3700 3701 3702 3703 3704 3705\n","userId ... \n","0 1.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","1 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","2 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","3 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","4 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","... ... ... ... ... ... ... ... ... ... ... ... ... ...\n","6033 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6034 1.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6035 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6036 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","6037 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0\n","\n","[5948 rows x 3533 columns]"]},"metadata":{},"execution_count":17}]},{"cell_type":"markdown","metadata":{"id":"jJQnYRyGaXGp"},"source":["### Model architecture"]},{"cell_type":"code","metadata":{"id":"k0osEM5vaXGq"},"source":["class DAE(tf.keras.models.Model):\n"," def __init__(self, input_dim, latent_dim, lamda=1e-4):\n"," super().__init__()\n"," self.input_dim = input_dim\n"," self.latent_dim = latent_dim\n"," self.lamda = lamda\n"," self.model = self.build()\n"," \n"," def compile(self, optimizer, loss_fn=None):\n"," super().compile()\n"," self.optimizer = optimizer\n"," self.loss_fn = loss_fn\n","\n"," \n"," def build(self):\n"," self.encoder = self.build_encoder()\n"," self.decoder = self.build_decoder()\n"," inputs = self.encoder.input\n"," outputs = self.decoder(self.encoder(inputs))\n"," \n"," return Model(inputs, outputs)\n"," \n"," def build_encoder(self):\n"," inputs = Input(shape = (self.input_dim, ))\n"," \n"," encoder = Sequential()\n"," encoder.add(Dropout(0.2))\n"," encoder.add(Dense(self.latent_dim, activation='tanh'))\n"," \n"," outputs = encoder(inputs)\n"," \n"," return Model(inputs, outputs)\n"," \n"," def build_decoder(self):\n"," inputs = Input(shape = (self.latent_dim, ))\n"," \n"," encoder = Sequential()\n"," encoder.add(Dense(self.input_dim, activation='sigmoid'))\n"," \n"," outputs = encoder(inputs)\n"," \n"," return Model(inputs, outputs)\n"," \n"," def train_step(self, x):\n"," with tf.GradientTape() as tape:\n"," pred = self.model(x)\n"," \n"," rec_loss = tf.losses.binary_crossentropy(x, pred)\n"," loss = rec_loss\n","\n"," grads = tape.gradient(loss, self.model.trainable_weights)\n"," self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n"," \n"," return {'loss': loss}"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Nx3cmNKaXGr"},"source":["### Training"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_O0JYFMeaXGr","executionInfo":{"status":"ok","timestamp":1630833850254,"user_tz":-330,"elapsed":113676,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"00b48b7c-859a-478d-f0a9-4349db9ca679"},"source":["loader = tf.data.Dataset.from_tensor_slices(train.values)\n","loader = loader.batch(32, drop_remainder=True).shuffle(len(df))\n","model = DAE(train.shape[1], 200)\n","model.compile(optimizer=tf.optimizers.Adam())\n","model.fit(loader, epochs = 25)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/25\n","185/185 [==============================] - 4s 16ms/step - loss: 0.1585\n","Epoch 2/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.1032\n","Epoch 3/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0940\n","Epoch 4/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0876\n","Epoch 5/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0823\n","Epoch 6/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0785\n","Epoch 7/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0758\n","Epoch 8/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0737\n","Epoch 9/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0719\n","Epoch 10/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0705\n","Epoch 11/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0693\n","Epoch 12/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0684\n","Epoch 13/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0675\n","Epoch 14/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0668\n","Epoch 15/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0661\n","Epoch 16/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0654\n","Epoch 17/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0649\n","Epoch 18/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0640\n","Epoch 19/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0635\n","Epoch 20/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0628\n","Epoch 21/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0621\n","Epoch 22/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0613\n","Epoch 23/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0606\n","Epoch 24/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0601\n","Epoch 25/25\n","185/185 [==============================] - 3s 16ms/step - loss: 0.0595\n"]},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":19}]},{"cell_type":"markdown","metadata":{"id":"UG7OWzoCaXGt"},"source":["### Evaluation"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":66,"referenced_widgets":["54ec6362c43f4c9fb0da5fae155dc21e","b5bae646ec9141fa858b0f24e55c3919","26275847d59a42ef9a5199401a59e6d5","f4e0639d4d194b15aeeb3e8fcf17b0ee","79e2ba356105419aba51015935b1206c","a569e072c0dc409eb91f27bab7ad4c4f","4276f2e7a5aa4042ac0fea7a4433341a","e0c831a414ab443ab1b2d5053d34ca43","55d8f25f6840463689954ae7061546cf","bb840c41a45e44b78d9dffb837ef2872","ac9ba7893cbd4c44aacd670a5ab2ac1b"]},"id":"zaBPGF_paXGt","executionInfo":{"status":"ok","timestamp":1630836102534,"user_tz":-330,"elapsed":131749,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"a05cf1da-f629-41b6-8a38-755acaf0ad39"},"source":["top_k = 10\n","np.random.seed(42)\n","\n","scores = []\n","for idx, i in tqdm(enumerate(np.random.choice(train.index, 100))):\n"," item_to_pred = {item: pred for item, pred in zip(train.columns, model.model.predict(train.values)[idx])}\n"," test_ = test[(test['userId']==i) & (test['rating']==1)]['movieId'].values\n"," items = list(np.random.choice(list(filter(lambda x: x not in np.argwhere(train.values[idx]).flatten(), item_to_pred.keys())), 100)) + list(test_)\n"," top_k_items = heapq.nlargest(top_k, items, key=item_to_pred.get)\n"," \n"," score = eval_NDCG(test_, top_k_items)\n"," scores.append(score)\n"," \n","np.mean(scores)"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"54ec6362c43f4c9fb0da5fae155dc21e","version_minor":0,"version_major":2},"text/plain":["0it [00:00, ?it/s]"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["0.2853871661973964"]},"metadata":{},"execution_count":21}]},{"cell_type":"markdown","metadata":{"id":"3WzMlytjXENu"},"source":["## RecVAE"]},{"cell_type":"markdown","metadata":{"id":"3JGgb4cpX-8S"},"source":["### Load data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":204},"id":"iJY1toVZX-8S","executionInfo":{"status":"ok","timestamp":1630835582915,"user_tz":-330,"elapsed":6008,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"5d82f1ef-db9e-4ce1-a201-ec5f3ed4b04e"},"source":["df = load_data('./ml-1m/ratings.dat', threshold=3)\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," userId | \n"," movieId | \n"," rating | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 0 | \n"," 1104 | \n"," 1 | \n","
\n"," \n"," | 1 | \n"," 0 | \n"," 639 | \n"," 0 | \n","
\n"," \n"," | 2 | \n"," 0 | \n"," 853 | \n"," 0 | \n","
\n"," \n"," | 3 | \n"," 0 | \n"," 3177 | \n"," 1 | \n","
\n"," \n"," | 4 | \n"," 0 | \n"," 2162 | \n"," 1 | \n","
\n"," \n","
\n","
"],"text/plain":[" userId movieId rating\n","0 0 1104 1\n","1 0 639 0\n","2 0 853 0\n","3 0 3177 1\n","4 0 2162 1"]},"metadata":{},"execution_count":22}]},{"cell_type":"markdown","metadata":{"id":"_er5KwLBX-8T"},"source":["### Preprocessing"]},{"cell_type":"code","metadata":{"id":"STkKDPOiX-8U"},"source":["df = df[df['rating']==1].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","\n","cnt = tdf.sum(1)\n","df = df[df['userId'].isin(np.where(cnt >= 10)[0])].reset_index(drop=True)\n","tdf = pd.pivot_table(df, index='userId', values='rating', columns='movieId').fillna(0)\n","tdf.iloc[:,:] = 0\n","test_idx = []\n","for i in tdf.index:\n"," test_idx += list(np.random.choice(df[df['userId']==i].index, 1))\n"," \n","train = df.iloc[list(set(df.index)-set(test_idx)),:]\n","test = df.iloc[test_idx, :]\n","\n","for uid, iid in zip(train['userId'].values, train['movieId'].values):\n"," tdf.loc[uid, iid] = 1\n","train = tdf.copy().astype(np.float32)\n","\n","loader = tf.data.Dataset.from_tensor_slices(train.values.astype(np.float32))\n","loader = loader.batch(8, drop_remainder=True).shuffle(len(train))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6ZDpx3i9X-8V"},"source":["### Model architecture"]},{"cell_type":"code","metadata":{"id":"axGJNldkYLEe"},"source":["def log_norm_pdf(x, mu, logvar):\n"," return -0.5*(logvar + tf.math.log(2 * np.pi) + tf.pow((x - mu), 2) / tf.exp(logvar))\n","\n","def sampling(args):\n"," z_mean, z_log_var = args\n"," batch = tf.shape(z_mean)[0]\n"," dim = tf.shape(z_mean)[1]\n"," epsilon = tf.random.normal(shape=(batch, dim), stddev=0.01)\n"," return z_mean + tf.exp(0.5 * z_log_var) * epsilon"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"48AlS5OtX-8W"},"source":["class CompositePrior(tf.keras.models.Model):\n"," def __init__(self, x_dim, latent_dim, mixture_weights = [3/20, 15/20, 2/20]):\n"," super().__init__()\n"," self.encoder_old = Encoder(x_dim, latent_dim, dropout_rate=0)\n"," self.latent_dim = latent_dim\n"," self.mixture_weights = mixture_weights\n"," \n"," self.mu_prior = self.add_weight(shape=(self.latent_dim, ), initializer = tf.zeros_initializer(), trainable=False)\n"," self.logvar_prior = self.add_weight(shape=(self.latent_dim, ), initializer = tf.zeros_initializer(), trainable=False)\n"," self.logvar_unif_prior = self.add_weight(shape=(self.latent_dim, ), initializer = tf.constant_initializer(10), trainable=False)\n"," \n"," def call(self, x, z):\n"," post_mu, post_logvar = self.encoder_old(x)\n"," \n"," stnd_prior = log_norm_pdf(z, self.mu_prior, self.logvar_prior)\n"," post_prior = log_norm_pdf(z, post_mu, post_logvar)\n"," unif_prior = log_norm_pdf(z, self.mu_prior, self.logvar_unif_prior)\n"," \n"," gaussians = [stnd_prior, post_prior, unif_prior]\n"," gaussians = [g+tf.math.log(w) for g, w in zip(gaussians, self.mixture_weights)]\n"," \n"," density = tf.stack(gaussians, -1)\n"," return tf.math.log(tf.reduce_sum(tf.exp(density), -1)) # logsumexp"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CQG9eIBxYPWl"},"source":["class Encoder(tf.keras.models.Model):\n"," def __init__(self, x_dim, latent_dim, dropout_rate = 0.1):\n"," super().__init__()\n"," self.latent_dim = latent_dim\n"," self.x_dim = x_dim\n"," self.dropout_rate = dropout_rate\n"," self.model = self.build_model()\n"," \n"," def build_model(self): # now just shallow net\n"," x_in = Input(shape=(self.x_dim, ))\n"," \n"," h = Dense(1024, activation='relu')(x_in)\n"," mu = Dense(self.latent_dim)(h)\n"," logvar = Dense(self.latent_dim)(h)\n"," \n"," return Model(x_in, [mu, logvar])\n"," \n"," def call(self, x):\n"," norm = tf.sqrt(tf.reduce_sum(tf.pow(x, 2), -1, keepdims=True))\n"," x = x/norm\n"," if self.dropout_rate>0:\n"," x = Dropout(self.dropout_rate)(x)\n"," \n"," return self.model(x)\n","\n","class RecVAE(tf.keras.models.Model):\n"," def __init__(self, x_dim, latent_dim):\n"," super().__init__()\n"," \n"," self.encoder = Encoder(x_dim, latent_dim)\n"," self.decoder = Dense(x_dim)\n"," self.prior = CompositePrior(x_dim, latent_dim)\n"," \n"," def call(self, data):\n"," mu, logvar = self.encoder(data)\n"," z = sampling([mu, logvar])\n"," recon = self.decoder(z)\n"," \n"," return mu, logvar, z, recon\n"," \n"," def predict(self, data):\n"," mu, logvar = self.encoder(data)\n"," z = sampling([mu, logvar])\n"," recon = self.decoder(z)\n"," \n"," return recon\n"," \n"," def update_prior(self):\n"," self.prior.encoder_old.set_weights(self.encoder.get_weights())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HdkKDWyXX-8W"},"source":["### Training"]},{"cell_type":"code","metadata":{"id":"peb46Hd9X-8W"},"source":["def tf_train(model, loader, optimizer, target, gamma=1.):\n"," total_loss = 0.\n"," for x in loader:\n"," norm = tf.reduce_sum(x, -1, keepdims=True)\n"," kl_weight = gamma*norm\n"," \n"," with tf.GradientTape() as tape:\n"," mu, logvar, z, pred = model(x)\n"," \n","# kl_loss = tf.reduce_mean(tf.reduce_sum(0.5*(logvar + tf.exp(logvar) + tf.pow(mu, 2)-1), 1, keepdims=True))\n"," kl_loss = tf.reduce_mean(log_norm_pdf(z, mu, logvar) - tf.multiply(model.prior(x, z), kl_weight))\n"," ce_loss = -tf.reduce_mean(tf.reduce_sum(tf.nn.log_softmax(pred) * x, -1))\n"," \n"," loss = ce_loss + kl_loss*kl_weight\n"," \n"," if target == 'encoder':\n"," grads = tape.gradient(loss, model.encoder.trainable_weights)\n"," optimizer.apply_gradients(zip(grads, model.encoder.trainable_weights))\n"," else:\n"," grads = tape.gradient(loss, model.decoder.trainable_weights)\n"," optimizer.apply_gradients(zip(grads, model.decoder.trainable_weights))\n"," \n"," total_loss += tf.reduce_sum(loss)\n"," return total_loss"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"FzTZPD9CYfpj"},"source":["epochs = 25\n","\n","model = RecVAE(train.shape[1], 200)\n","enc_opt = optimizers.Adam()\n","dec_opt = optimizers.Adam()\n","\n","for e in range(epochs):\n"," # alternating \n"," ## train step\n"," tf_train(model, loader, enc_opt, 'encoder')\n"," model.update_prior()\n"," tf_train(model, loader, dec_opt, 'decoder')\n"," ## eval step"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bhZuJedDX-8X"},"source":["### Evaluation"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":66,"referenced_widgets":["87c9bc69718d40c0acd55be8b3d028c3","592a360bb8f74fd690223f2e4bb14f0e","291c4790efeb43b68b78eaef6a99ced7","8d0ed0b8a3e94733aad8d23ae0265d3c","4e554fb4f30a48faa6859d457d08ba1e","32727136c19e46558e6016eea4fa6fec","fa3e40aea0f14e4db3be250b51c8ede0","7a486090a89343168b6c82943865733c","57b0ddec4be6490b8da427622de4ebac","2d1191aacad24eba93d354e26e6ee37b","9940ceebfb384b6da8e6d5deffcce3a4"]},"id":"KgYQGSPBYUCb","executionInfo":{"status":"ok","timestamp":1630838390198,"user_tz":-330,"elapsed":184666,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"}},"outputId":"b760dcd0-9831-49d3-c3ee-3ab95dc24eda"},"source":["top_k = 10\n","np.random.seed(42)\n","\n","scores = []\n","for idx, i in tqdm(enumerate(np.random.choice(train.index, 100))):\n"," item_to_pred = {item: pred.numpy() for item, pred in zip(train.columns, model.predict(train.values)[idx])}\n"," test_ = test[(test['userId']==i) & (test['rating']==1)]['movieId'].values\n"," items = list(np.random.choice(list(filter(lambda x: x not in np.argwhere(train.values[idx]).flatten(), item_to_pred.keys())), 100)) + list(test_)\n"," top_k_items = heapq.nlargest(top_k, items, key=item_to_pred.get)\n"," \n"," score = eval_NDCG(test_, top_k_items)\n"," scores.append(score)\n","# break\n","np.mean(scores)"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"87c9bc69718d40c0acd55be8b3d028c3","version_minor":0,"version_major":2},"text/plain":["0it [00:00, ?it/s]"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["0.0031546487678572877"]},"metadata":{},"execution_count":29}]},{"cell_type":"markdown","source":["---"],"metadata":{"id":"RAaqLy1UtcIC"}},{"cell_type":"code","source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Jit1oP3jtd7k","executionInfo":{"status":"ok","timestamp":1639716362410,"user_tz":-330,"elapsed":4112,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"8470ae76-a435-4a4f-f909-351ed5e37fed"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2021-12-17 04:46:00\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","IPython : 5.5.0\n","matplotlib: 3.2.2\n","sys : 3.7.12 (default, Sep 10 2021, 00:21:48) \n","[GCC 7.5.0]\n","tensorflow: 2.7.0\n","pandas : 1.1.5\n","numpy : 1.19.5\n","seaborn : 0.11.2\n","keras : 2.7.0\n","google : 2.0.3\n","\n"]}]},{"cell_type":"markdown","source":["---"],"metadata":{"id":"qrYL9Jx-tcIF"}},{"cell_type":"markdown","source":["**END**"],"metadata":{"id":"pZR6MBOZtcIG"}}]}