{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-25-bpr-recbole.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T331379%20%7C%20BPR%20in%20PyTorch%20Referencing%20RecBole%20Library.ipynb","timestamp":1644671130956}],"collapsed_sections":[],"toc_visible":true,"authorship_tag":"ABX9TyNMFqprV1o4i6VbRMx1On+c"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"ql1c695Ag6pO"},"source":["# BPR in PyTorch Referencing RecBole Library"]},{"cell_type":"code","metadata":{"id":"ZlQXlVMjhXkP"},"source":["import torch\n","import torch.nn as nn\n","from torch.nn.init import xavier_normal_, constant_\n","\n","from enum import Enum"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"1cM8Zu3fihUu"},"source":["def set_color(log, color, highlight=True):\n"," color_set = ['black', 'red', 'green', 'yellow', 'blue', 'pink', 'cyan', 'white']\n"," try:\n"," index = color_set.index(color)\n"," except:\n"," index = len(color_set) - 1\n"," prev_log = '\\033['\n"," if highlight:\n"," prev_log += '1;3'\n"," else:\n"," prev_log += '0;3'\n"," prev_log += str(index) + 'm'\n"," return prev_log + log + '\\033[0m'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RXT1qGUZiNmK"},"source":["class ModelType(Enum):\n"," \"\"\"Type of models.\n"," - ``GENERAL``: General Recommendation\n"," - ``SEQUENTIAL``: Sequential Recommendation\n"," - ``CONTEXT``: Context-aware Recommendation\n"," - ``KNOWLEDGE``: Knowledge-based Recommendation\n"," \"\"\"\n","\n"," GENERAL = 1\n"," SEQUENTIAL = 2\n"," CONTEXT = 3\n"," KNOWLEDGE = 4\n"," TRADITIONAL = 5\n"," DECISIONTREE = 6\n","\n","\n","class KGDataLoaderState(Enum):\n"," \"\"\"States for Knowledge-based DataLoader.\n"," - ``RSKG``: Return both knowledge graph information and user-item interaction information.\n"," - ``RS``: Only return the user-item interaction.\n"," - ``KG``: Only return the triplets with negative examples in a knowledge graph.\n"," \"\"\"\n","\n"," RSKG = 1\n"," RS = 2\n"," KG = 3\n","\n","\n","class EvaluatorType(Enum):\n"," \"\"\"Type for evaluation metrics.\n"," - ``RANKING``: Ranking-based metrics like NDCG, Recall, etc.\n"," - ``VALUE``: Value-based metrics like AUC, etc.\n"," \"\"\"\n","\n"," RANKING = 1\n"," VALUE = 2\n","\n","\n","class InputType(Enum):\n"," \"\"\"Type of Models' input.\n"," - ``POINTWISE``: Point-wise input, like ``uid, iid, label``.\n"," - ``PAIRWISE``: Pair-wise input, like ``uid, pos_iid, neg_iid``.\n"," \"\"\"\n","\n"," POINTWISE = 1\n"," PAIRWISE = 2\n"," LISTWISE = 3\n","\n","\n","class FeatureType(Enum):\n"," \"\"\"Type of features.\n"," - ``TOKEN``: Token features like user_id and item_id.\n"," - ``FLOAT``: Float features like rating and timestamp.\n"," - ``TOKEN_SEQ``: Token sequence features like review.\n"," - ``FLOAT_SEQ``: Float sequence features like pretrained vector.\n"," \"\"\"\n","\n"," TOKEN = 'token'\n"," FLOAT = 'float'\n"," TOKEN_SEQ = 'token_seq'\n"," FLOAT_SEQ = 'float_seq'\n","\n","\n","class FeatureSource(Enum):\n"," \"\"\"Source of features.\n"," - ``INTERACTION``: Features from ``.inter`` (other than ``user_id`` and ``item_id``).\n"," - ``USER``: Features from ``.user`` (other than ``user_id``).\n"," - ``ITEM``: Features from ``.item`` (other than ``item_id``).\n"," - ``USER_ID``: ``user_id`` feature in ``inter_feat`` and ``user_feat``.\n"," - ``ITEM_ID``: ``item_id`` feature in ``inter_feat`` and ``item_feat``.\n"," - 
{"cell_type":"code","metadata":{"id":"xmI3-anMhyEC"},"source":["class AbstractRecommender(nn.Module):\n","    r\"\"\"Base class for all models.\"\"\"\n","\n","    def __init__(self):\n","        self.logger = getLogger()\n","        super(AbstractRecommender, self).__init__()\n","\n","    def calculate_loss(self, interaction):\n","        r\"\"\"Calculate the training loss for a batch of data.\n","        Args:\n","            interaction (Interaction): Interaction class of the batch.\n","        Returns:\n","            torch.Tensor: Training loss, shape: []\n","        \"\"\"\n","        raise NotImplementedError\n","\n","    def predict(self, interaction):\n","        r\"\"\"Predict the scores between users and items.\n","        Args:\n","            interaction (Interaction): Interaction class of the batch.\n","        Returns:\n","            torch.Tensor: Predicted scores for given users and items, shape: [batch_size]\n","        \"\"\"\n","        raise NotImplementedError\n","\n","    def full_sort_predict(self, interaction):\n","        r\"\"\"Full sort prediction function.\n","        Given users, calculate the scores between users and all candidate items.\n","        Args:\n","            interaction (Interaction): Interaction class of the batch.\n","        Returns:\n","            torch.Tensor: Predicted scores for given users and all candidate items,\n","            shape: [n_batch_users * n_candidate_items]\n","        \"\"\"\n","        raise NotImplementedError\n","\n","    def other_parameter(self):\n","        if hasattr(self, 'other_parameter_name'):\n","            return {key: getattr(self, key) for key in self.other_parameter_name}\n","        return dict()\n","\n","    def load_other_parameter(self, para):\n","        if para is None:\n","            return\n","        for key, value in para.items():\n","            setattr(self, key, value)\n","\n","    def __str__(self):\n","        \"\"\"Model prints with number of trainable parameters.\"\"\"\n","        model_parameters = filter(lambda p: p.requires_grad, self.parameters())\n","        params = sum([np.prod(p.size()) for p in model_parameters])\n","        return super().__str__() + set_color('\\nTrainable parameters', 'blue') + f': {params}'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"EX8z1eC9h_o8"},"source":["class GeneralRecommender(AbstractRecommender):\n","    \"\"\"This is an abstract general recommender. All general models should inherit from this class.\n","    The base general recommender class provides the basic dataset and parameter information.\n","    \"\"\"\n","    type = ModelType.GENERAL\n","\n","    def __init__(self, config, dataset):\n","        super(GeneralRecommender, self).__init__()\n","\n","        # load dataset info\n","        self.USER_ID = config['USER_ID_FIELD']\n","        self.ITEM_ID = config['ITEM_ID_FIELD']\n","        self.NEG_ITEM_ID = config['NEG_PREFIX'] + self.ITEM_ID\n","        self.n_users = dataset.num(self.USER_ID)\n","        self.n_items = dataset.num(self.ITEM_ID)\n","\n","        # load parameters info\n","        self.device = config['device']"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"ew6eY5uciqU-"},"source":["def xavier_normal_initialization(module):\n","    r\"\"\" using `xavier_normal_`_ in PyTorch to initialize the parameters in\n","    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,\n","    using constant 0 to initialize.\n","    .. _`xavier_normal_`:\n","        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_normal_#torch.nn.init.xavier_normal_\n","    Examples:\n","        >>> self.apply(xavier_normal_initialization)\n","    \"\"\"\n","    if isinstance(module, nn.Embedding):\n","        xavier_normal_(module.weight.data)\n","    elif isinstance(module, nn.Linear):\n","        xavier_normal_(module.weight.data)\n","        if module.bias is not None:\n","            constant_(module.bias.data, 0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"S3pg3ICSi4Dr"},"source":["class BPRLoss(nn.Module):\n","    \"\"\" BPRLoss, based on Bayesian Personalized Ranking\n","    Args:\n","        - gamma(float): Small value to avoid division by zero\n","    Shape:\n","        - Pos_score: (N)\n","        - Neg_score: (N), same shape as the Pos_score\n","        - Output: scalar.\n","    Examples::\n","        >>> loss = BPRLoss()\n","        >>> pos_score = torch.randn(3, requires_grad=True)\n","        >>> neg_score = torch.randn(3, requires_grad=True)\n","        >>> output = loss(pos_score, neg_score)\n","        >>> output.backward()\n","    \"\"\"\n","\n","    def __init__(self, gamma=1e-10):\n","        super(BPRLoss, self).__init__()\n","        self.gamma = gamma\n","\n","    def forward(self, pos_score, neg_score):\n","        loss = -torch.log(self.gamma + torch.sigmoid(pos_score - neg_score)).mean()\n","        return loss"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"qp7Fd64bi8HL"},"source":["class BPR(GeneralRecommender):\n","    r\"\"\"BPR is a basic matrix factorization model that is trained in a pairwise way.\n","    \"\"\"\n","    input_type = InputType.PAIRWISE\n","\n","    def __init__(self, config, dataset):\n","        super(BPR, self).__init__(config, dataset)\n","\n","        # load parameters info\n","        self.embedding_size = config['embedding_size']\n","\n","        # define layers and loss\n","        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)\n","        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)\n","        self.loss = BPRLoss()\n","\n","        # parameters initialization\n","        self.apply(xavier_normal_initialization)\n","\n","    def get_user_embedding(self, user):\n","        r\"\"\" Get a batch of user embedding tensor according to input user's id.\n","\n","        Args:\n","            user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ]\n","\n","        Returns:\n","            torch.FloatTensor: The embedding tensor of a batch of user, shape: [batch_size, embedding_size]\n","        \"\"\"\n","        return self.user_embedding(user)\n","\n","    def get_item_embedding(self, item):\n","        r\"\"\" Get a batch of item embedding tensor according to input item's id.\n","\n","        Args:\n","            item (torch.LongTensor): The input tensor that contains item's id, shape: [batch_size, ]\n","\n","        Returns:\n","            torch.FloatTensor: The embedding tensor of a batch of item, shape: [batch_size, embedding_size]\n","        \"\"\"\n","        return self.item_embedding(item)\n","\n","    def forward(self, user, item):\n","        user_e = self.get_user_embedding(user)\n","        item_e = self.get_item_embedding(item)\n","        return user_e, item_e\n","\n","    def calculate_loss(self, interaction):\n","        user = interaction[self.USER_ID]\n","        pos_item = interaction[self.ITEM_ID]\n","        neg_item = interaction[self.NEG_ITEM_ID]\n","\n","        user_e, pos_e = self.forward(user, pos_item)\n","        neg_e = self.get_item_embedding(neg_item)\n","        pos_item_score, neg_item_score = torch.mul(user_e, pos_e).sum(dim=1), torch.mul(user_e, neg_e).sum(dim=1)\n","        loss = self.loss(pos_item_score, neg_item_score)\n","        return loss\n","\n","    def predict(self, interaction):\n","        user = interaction[self.USER_ID]\n","        item = interaction[self.ITEM_ID]\n","        user_e, item_e = self.forward(user, item)\n","        return torch.mul(user_e, item_e).sum(dim=1)\n","\n","    def full_sort_predict(self, interaction):\n","        user = interaction[self.USER_ID]\n","        user_e = self.get_user_embedding(user)\n","        all_item_e = self.item_embedding.weight\n","        score = torch.matmul(user_e, all_item_e.transpose(0, 1))\n","        return score.view(-1)"],"execution_count":null,"outputs":[]},
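{"cell_type":"markdown","metadata":{"id":"bpr-demo-md"},"source":["## Sanity check on a toy setup\n","\n","The cell below is a minimal sketch that exercises the `BPR` model defined above; it is not the RecBole training pipeline. It substitutes a plain Python dict for RecBole's `Config`, a small stand-in class exposing only a `num` method for its `Dataset`, and a dict of tensors for its `Interaction`. The model code above only indexes into these objects, so the stand-ins are enough for a forward/backward pass on random ids."]},{"cell_type":"code","metadata":{"id":"bpr-demo-code"},"source":["# Minimal sketch: run BPR end-to-end on random data.\n","# The names below (_ToyDataset, toy_config, batch) are illustrative stand-ins, not RecBole objects.\n","\n","class _ToyDataset:\n","    # Stand-in for a RecBole dataset; BPR only calls num(field).\n","    def __init__(self, n_users, n_items):\n","        self._counts = {'user_id': n_users, 'item_id': n_items}\n","\n","    def num(self, field):\n","        return self._counts[field]\n","\n","\n","toy_config = {\n","    'USER_ID_FIELD': 'user_id',\n","    'ITEM_ID_FIELD': 'item_id',\n","    'NEG_PREFIX': 'neg_',\n","    'device': 'cpu',\n","    'embedding_size': 64,\n","}\n","\n","toy_dataset = _ToyDataset(n_users=100, n_items=500)\n","model = BPR(toy_config, toy_dataset)\n","print(model)\n","\n","# A fake pairwise batch: (user, positive item, negative item) ids.\n","batch = {\n","    'user_id': torch.randint(0, 100, (32,)),\n","    'item_id': torch.randint(0, 500, (32,)),\n","    'neg_item_id': torch.randint(0, 500, (32,)),\n","}\n","\n","loss = model.calculate_loss(batch)\n","loss.backward()\n","print('BPR loss on a random batch:', loss.item())\n","\n","# Full-ranking scores of the first user against all items.\n","scores = model.full_sort_predict({'user_id': batch['user_id'][:1]})\n","print('score vector length:', scores.shape[0])"],"execution_count":null,"outputs":[]}]}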