{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Neural Graph Collaborative Filtering"]},{"cell_type":"markdown","metadata":{"id":"hxeshW_Gpoys"},"source":["### Model architecture"]},{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":4237,"status":"ok","timestamp":1630672865987,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"nIGXj9Yo6WnR"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","\n","class NGCF(nn.Module):\n"," def __init__(self, n_user, n_item, norm_adj, args):\n"," super(NGCF, self).__init__()\n"," self.n_user = n_user\n"," self.n_item = n_item\n"," self.device = args.device\n"," self.emb_size = args.embed_size\n"," self.batch_size = args.batch_size\n"," self.node_dropout = args.node_dropout[0]\n"," self.mess_dropout = args.mess_dropout\n"," self.batch_size = args.batch_size\n","\n"," self.norm_adj = norm_adj\n","\n"," self.layers = eval(args.layer_size)\n"," self.decay = eval(args.regs)[0]\n","\n"," \"\"\"\n"," *********************************************************\n"," Init the weight of user-item.\n"," \"\"\"\n"," self.embedding_dict, self.weight_dict = self.init_weight()\n","\n"," \"\"\"\n"," *********************************************************\n"," Get sparse adj.\n"," \"\"\"\n"," self.sparse_norm_adj = self._convert_sp_mat_to_sp_tensor(self.norm_adj).to(self.device)\n","\n"," def init_weight(self):\n"," # xavier init\n"," initializer = nn.init.xavier_uniform_\n","\n"," embedding_dict = nn.ParameterDict({\n"," 'user_emb': nn.Parameter(initializer(torch.empty(self.n_user,\n"," self.emb_size))),\n"," 'item_emb': nn.Parameter(initializer(torch.empty(self.n_item,\n"," self.emb_size)))\n"," })\n","\n"," weight_dict = nn.ParameterDict()\n"," layers = [self.emb_size] + self.layers\n"," for k in range(len(self.layers)):\n"," weight_dict.update({'W_gc_%d'%k: nn.Parameter(initializer(torch.empty(layers[k],\n"," layers[k+1])))})\n"," weight_dict.update({'b_gc_%d'%k: nn.Parameter(initializer(torch.empty(1, layers[k+1])))})\n","\n"," weight_dict.update({'W_bi_%d'%k: nn.Parameter(initializer(torch.empty(layers[k],\n"," layers[k+1])))})\n"," weight_dict.update({'b_bi_%d'%k: nn.Parameter(initializer(torch.empty(1, layers[k+1])))})\n","\n"," return embedding_dict, weight_dict\n","\n"," def _convert_sp_mat_to_sp_tensor(self, X):\n"," coo = X.tocoo()\n"," i = torch.LongTensor([coo.row, coo.col])\n"," v = torch.from_numpy(coo.data).float()\n"," return torch.sparse.FloatTensor(i, v, coo.shape)\n","\n"," def sparse_dropout(self, x, rate, noise_shape):\n"," random_tensor = 1 - rate\n"," random_tensor += torch.rand(noise_shape).to(x.device)\n"," dropout_mask = torch.floor(random_tensor).type(torch.bool)\n"," i = x._indices()\n"," v = x._values()\n","\n"," i = i[:, dropout_mask]\n"," v = v[dropout_mask]\n","\n"," out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device)\n"," return out * (1. 
/ (1 - rate))\n","\n"," def create_bpr_loss(self, users, pos_items, neg_items):\n"," pos_scores = torch.sum(torch.mul(users, pos_items), axis=1)\n"," neg_scores = torch.sum(torch.mul(users, neg_items), axis=1)\n","\n"," maxi = nn.LogSigmoid()(pos_scores - neg_scores)\n","\n"," mf_loss = -1 * torch.mean(maxi)\n","\n"," # cul regularizer\n"," regularizer = (torch.norm(users) ** 2\n"," + torch.norm(pos_items) ** 2\n"," + torch.norm(neg_items) ** 2) / 2\n"," emb_loss = self.decay * regularizer / self.batch_size\n","\n"," return mf_loss + emb_loss, mf_loss, emb_loss\n","\n"," def rating(self, u_g_embeddings, pos_i_g_embeddings):\n"," return torch.matmul(u_g_embeddings, pos_i_g_embeddings.t())\n","\n"," def forward(self, users, pos_items, neg_items, drop_flag=True):\n","\n"," A_hat = self.sparse_dropout(self.sparse_norm_adj,\n"," self.node_dropout,\n"," self.sparse_norm_adj._nnz()) if drop_flag else self.sparse_norm_adj\n","\n"," ego_embeddings = torch.cat([self.embedding_dict['user_emb'],\n"," self.embedding_dict['item_emb']], 0)\n","\n"," all_embeddings = [ego_embeddings]\n","\n"," for k in range(len(self.layers)):\n"," side_embeddings = torch.sparse.mm(A_hat, ego_embeddings)\n","\n"," # transformed sum messages of neighbors.\n"," sum_embeddings = torch.matmul(side_embeddings, self.weight_dict['W_gc_%d' % k]) \\\n"," + self.weight_dict['b_gc_%d' % k]\n","\n"," # bi messages of neighbors.\n"," # element-wise product\n"," bi_embeddings = torch.mul(ego_embeddings, side_embeddings)\n"," # transformed bi messages of neighbors.\n"," bi_embeddings = torch.matmul(bi_embeddings, self.weight_dict['W_bi_%d' % k]) \\\n"," + self.weight_dict['b_bi_%d' % k]\n","\n"," # non-linear activation.\n"," ego_embeddings = nn.LeakyReLU(negative_slope=0.2)(sum_embeddings + bi_embeddings)\n","\n"," # message dropout.\n"," ego_embeddings = nn.Dropout(self.mess_dropout[k])(ego_embeddings)\n","\n"," # normalize the distribution of embeddings.\n"," norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1)\n","\n"," all_embeddings += [norm_embeddings]\n","\n"," all_embeddings = torch.cat(all_embeddings, 1)\n"," u_g_embeddings = all_embeddings[:self.n_user, :]\n"," i_g_embeddings = all_embeddings[self.n_user:, :]\n","\n"," \"\"\"\n"," *********************************************************\n"," look up.\n"," \"\"\"\n"," u_g_embeddings = u_g_embeddings[users, :]\n"," pos_i_g_embeddings = i_g_embeddings[pos_items, :]\n"," neg_i_g_embeddings = i_g_embeddings[neg_items, :]\n","\n"," return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings"]},{"cell_type":"markdown","metadata":{"id":"tRkhB2bBqm8H"},"source":["### Downloading data"]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4196,"status":"ok","timestamp":1630672870177,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"GLACW6U1qrB6","outputId":"158b550d-f3ac-4086-cbc3-d9886b36782b"},"outputs":[{"name":"stdout","output_type":"stream","text":["/content/gowalla\n","train.txt 100%[===================>] 4.42M --.-KB/s in 0.07s \n","test.txt 100%[===================>] 1.31M --.-KB/s in 0.05s \n","user_list.txt 100%[===================>] 342.97K --.-KB/s in 0.03s \n","item_list.txt 100%[===================>] 495.04K --.-KB/s in 0.04s \n","/content\n"]}],"source":["!mkdir gowalla\n","%cd gowalla\n","!wget -q --show-progress https://github.com/huangtinglin/NGCF-PyTorch/raw/master/Data/gowalla/train.txt\n","!wget -q --show-progress 
https://github.com/huangtinglin/NGCF-PyTorch/raw/master/Data/gowalla/test.txt\n","!wget -q --show-progress https://github.com/huangtinglin/NGCF-PyTorch/raw/master/Data/gowalla/user_list.txt\n","!wget -q --show-progress https://github.com/huangtinglin/NGCF-PyTorch/raw/master/Data/gowalla/item_list.txt\n","%cd .."]},{"cell_type":"markdown","metadata":{"id":"AVbNP-c8on4L"},"source":["### Argument parser"]},{"cell_type":"code","execution_count":3,"metadata":{"executionInfo":{"elapsed":15,"status":"ok","timestamp":1630672870601,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"-DxGqaaannPI"},"outputs":[],"source":["import argparse\n","\n","def parse_args():\n","    parser = argparse.ArgumentParser(description=\"Run NGCF.\")\n","    parser.add_argument('--weights_path', nargs='?', default='./',\n","                        help='Path for storing the model weights.')\n","    parser.add_argument('--data_path', nargs='?', default='./',\n","                        help='Input data path.')\n","    parser.add_argument('--proj_path', nargs='?', default='./',\n","                        help='Project path.')\n","\n","    parser.add_argument('--dataset', nargs='?', default='gowalla',\n","                        help='Choose a dataset from {gowalla, yelp2018, amazon-book}')\n","    parser.add_argument('--pretrain', type=int, default=0,\n","                        help='0: No pretrain, -1: Pretrain with the learned embeddings, 1: Pretrain with stored models.')\n","    parser.add_argument('--verbose', type=int, default=1,\n","                        help='Interval of evaluation.')\n","    parser.add_argument('--epoch', type=int, default=400,\n","                        help='Number of training epochs.')\n","\n","    parser.add_argument('--embed_size', type=int, default=64,\n","                        help='Embedding size.')\n","    parser.add_argument('--layer_size', nargs='?', default='[64,64,64]',\n","                        help='Output sizes of every propagation layer.')\n","    parser.add_argument('--batch_size', type=int, default=1024,\n","                        help='Batch size.')\n","\n","    parser.add_argument('--regs', nargs='?', default='[1e-5]',\n","                        help='Regularizations.')\n","    parser.add_argument('--lr', type=float, default=0.0001,\n","                        help='Learning rate.')\n","\n","    parser.add_argument('--model_type', nargs='?', default='ngcf',\n","                        help='Specify the name of model (ngcf).')\n","    parser.add_argument('--adj_type', nargs='?', default='norm',\n","                        help='Specify the type of the adjacency (laplacian) matrix from {plain, norm, mean}.')\n","\n","    parser.add_argument('--gpu_id', type=int, default=0)\n","\n","    parser.add_argument('--node_dropout_flag', type=int, default=1,\n","                        help='0: Disable node dropout, 1: Activate node dropout')\n","    parser.add_argument('--node_dropout', nargs='?', default='[0.1]',\n","                        help='Node dropout ratio (fraction of adjacency entries dropped) for each layer. 0.0: no dropout.')\n","    parser.add_argument('--mess_dropout', nargs='?', default='[0.1,0.1,0.1]',\n","                        help='Message dropout ratio for each layer. 0.0: no dropout.')\n","\n","    parser.add_argument('--Ks', nargs='?', default='[20, 40, 60, 80, 100]',\n","                        help='K values (cut-offs) at which the ranking metrics are computed.')\n","\n","    parser.add_argument('--save_flag', type=int, default=0,\n","                        help='0: Disable model saver, 1: Activate model saver')\n","\n","    parser.add_argument('--test_flag', nargs='?', default='part',\n","                        help='Specify the test type from {part, full}, indicating whether the evaluation is done in mini-batches')\n","\n","    parser.add_argument('--report', type=int, default=0,\n","                        help='0: Disable performance report w.r.t. sparsity levels, 1: Show performance report w.r.t. 
sparsity levels')\n","    return parser.parse_args(args={})"]},{"cell_type":"markdown","metadata":{"id":"GzCm5gf8oqJH"},"source":["### Metrics"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":1099,"status":"ok","timestamp":1630672871688,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"IN1q9gZnowJy"},"outputs":[],"source":["import numpy as np\n","from sklearn.metrics import roc_auc_score\n","\n","\n","def recall(rank, ground_truth, N):\n","    return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))\n","\n","\n","def precision_at_k(r, k):\n","    \"\"\"Score is precision @ k\n","    Relevance is binary (nonzero is relevant).\n","    Returns:\n","        Precision @ k\n","    Raises:\n","        ValueError: len(r) must be >= k\n","    \"\"\"\n","    assert k >= 1\n","    r = np.asarray(r)[:k]\n","    return np.mean(r)\n","\n","\n","def average_precision(r, cut):\n","    \"\"\"Score is average precision (area under PR curve)\n","    Relevance is binary (nonzero is relevant).\n","    Returns:\n","        Average precision\n","    \"\"\"\n","    r = np.asarray(r)\n","    out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]\n","    if not out:\n","        return 0.\n","    return np.sum(out)/float(min(cut, np.sum(r)))\n","\n","\n","def mean_average_precision(rs):\n","    \"\"\"Score is mean average precision\n","    Relevance is binary (nonzero is relevant).\n","    Returns:\n","        Mean average precision\n","    \"\"\"\n","    # average_precision requires a cut-off; use the full length of each ranking.\n","    return np.mean([average_precision(r, len(r)) for r in rs])\n","\n","\n","def dcg_at_k(r, k, method=1):\n","    \"\"\"Score is discounted cumulative gain (dcg)\n","    Relevance is positive real values. Can use binary\n","    as the previous methods.\n","    Returns:\n","        Discounted cumulative gain\n","    \"\"\"\n","    r = np.asfarray(r)[:k]\n","    if r.size:\n","        if method == 0:\n","            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))\n","        elif method == 1:\n","            return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n","        else:\n","            raise ValueError('method must be 0 or 1.')\n","    return 0.\n","\n","\n","def ndcg_at_k(r, k, ground_truth, method=1):\n","    \"\"\"Score is normalized discounted cumulative gain (ndcg)\n","    Relevance is positive real values. 
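Concretely, with method=1 the code below computes DCG@k = sum_{i=1..k} r_i / log2(i+1), and NDCG@k = DCG@k / IDCG@k, where IDCG@k is the DCG of an ideal ranking that places all relevant items first. 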
Can use binary\n","    as the previous methods.\n","    Returns:\n","        Normalized discounted cumulative gain\n","        Low but correct definition\n","    \"\"\"\n","    GT = set(ground_truth)\n","    if len(GT) > k:\n","        sent_list = [1.0] * k\n","    else:\n","        sent_list = [1.0]*len(GT) + [0.0]*(k-len(GT))\n","    dcg_max = dcg_at_k(sent_list, k, method)\n","    if not dcg_max:\n","        return 0.\n","    return dcg_at_k(r, k, method) / dcg_max\n","\n","\n","def recall_at_k(r, k, all_pos_num):\n","    # if all_pos_num == 0:\n","    #     return 0\n","    r = np.asfarray(r)[:k]\n","    return np.sum(r) / all_pos_num\n","\n","\n","def hit_at_k(r, k):\n","    r = np.array(r)[:k]\n","    if np.sum(r) > 0:\n","        return 1.\n","    else:\n","        return 0.\n","\n","def F1(pre, rec):\n","    if pre + rec > 0:\n","        return (2.0 * pre * rec) / (pre + rec)\n","    else:\n","        return 0.\n","\n","def AUC(ground_truth, prediction):\n","    try:\n","        res = roc_auc_score(y_true=ground_truth, y_score=prediction)\n","    except Exception:\n","        res = 0.\n","    return res"]},{"cell_type":"markdown","metadata":{"id":"BjElFa0do89S"},"source":["### Data loader"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":747,"status":"ok","timestamp":1630672872428,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"9vSns4aao-jG"},"outputs":[],"source":["import numpy as np\n","import random as rd\n","import scipy.sparse as sp\n","from time import time\n","\n","class Data(object):\n","    def __init__(self, path, batch_size):\n","        self.path = path\n","        self.batch_size = batch_size\n","\n","        train_file = path + '/train.txt'\n","        test_file = path + '/test.txt'\n","\n","        # get number of users and items\n","        self.n_users, self.n_items = 0, 0\n","        self.n_train, self.n_test = 0, 0\n","        self.neg_pools = {}\n","\n","        self.exist_users = []\n","\n","        with open(train_file) as f:\n","            for l in f.readlines():\n","                if len(l) > 0:\n","                    l = l.strip('\\n').split(' ')\n","                    items = [int(i) for i in l[1:]]\n","                    uid = int(l[0])\n","                    self.exist_users.append(uid)\n","                    self.n_items = max(self.n_items, max(items))\n","                    self.n_users = max(self.n_users, uid)\n","                    self.n_train += len(items)\n","\n","        with open(test_file) as f:\n","            for l in f.readlines():\n","                if len(l) > 0:\n","                    l = l.strip('\\n')\n","                    try:\n","                        items = [int(i) for i in l.split(' ')[1:]]\n","                    except Exception:\n","                        continue\n","                    self.n_items = max(self.n_items, max(items))\n","                    self.n_test += len(items)\n","        self.n_items += 1\n","        self.n_users += 1\n","\n","        self.print_statistics()\n","\n","        self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)\n","\n","        self.train_items, self.test_set = {}, {}\n","        with open(train_file) as f_train:\n","            with open(test_file) as f_test:\n","                for l in f_train.readlines():\n","                    if len(l) == 0:\n","                        break\n","                    l = l.strip('\\n')\n","                    items = [int(i) for i in l.split(' ')]\n","                    uid, train_items = items[0], items[1:]\n","\n","                    for i in train_items:\n","                        self.R[uid, i] = 1.\n","                        # self.R[uid][i] = 1\n","\n","                    self.train_items[uid] = train_items\n","\n","                for l in f_test.readlines():\n","                    if len(l) == 0: break\n","                    l = l.strip('\\n')\n","                    try:\n","                        items = [int(i) for i in l.split(' ')]\n","                    except Exception:\n","                        continue\n","\n","                    uid, test_items = items[0], items[1:]\n","                    self.test_set[uid] = test_items\n","\n","    def get_adj_mat(self):\n","        try:\n","            t1 = time()\n","            adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')\n","            norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')\n","            mean_adj_mat = sp.load_npz(self.path + 
'/s_mean_adj_mat.npz')\n"," print('already load adj matrix', adj_mat.shape, time() - t1)\n","\n"," except Exception:\n"," adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()\n"," sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)\n"," sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)\n"," sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)\n"," return adj_mat, norm_adj_mat, mean_adj_mat\n","\n"," def create_adj_mat(self):\n"," t1 = time()\n"," adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)\n"," adj_mat = adj_mat.tolil()\n"," R = self.R.tolil()\n","\n"," adj_mat[:self.n_users, self.n_users:] = R\n"," adj_mat[self.n_users:, :self.n_users] = R.T\n"," adj_mat = adj_mat.todok()\n"," print('already create adjacency matrix', adj_mat.shape, time() - t1)\n","\n"," t2 = time()\n","\n"," def mean_adj_single(adj):\n"," # D^-1 * A\n"," rowsum = np.array(adj.sum(1))\n","\n"," d_inv = np.power(rowsum, -1).flatten()\n"," d_inv[np.isinf(d_inv)] = 0.\n"," d_mat_inv = sp.diags(d_inv)\n","\n"," norm_adj = d_mat_inv.dot(adj)\n"," # norm_adj = adj.dot(d_mat_inv)\n"," print('generate single-normalized adjacency matrix.')\n"," return norm_adj.tocoo()\n","\n"," def normalized_adj_single(adj):\n"," # D^-1/2 * A * D^-1/2\n"," rowsum = np.array(adj.sum(1))\n","\n"," d_inv_sqrt = np.power(rowsum, -0.5).flatten()\n"," d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.\n"," d_mat_inv_sqrt = sp.diags(d_inv_sqrt)\n","\n"," # bi_lap = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)\n"," bi_lap = d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)\n"," return bi_lap.tocoo()\n","\n"," def check_adj_if_equal(adj):\n"," dense_A = np.array(adj.todense())\n"," degree = np.sum(dense_A, axis=1, keepdims=False)\n","\n"," temp = np.dot(np.diag(np.power(degree, -1)), dense_A)\n"," print('check normalized adjacency matrix whether equal to this laplacian matrix.')\n"," return temp\n","\n"," norm_adj_mat = mean_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))\n"," # norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))\n"," mean_adj_mat = mean_adj_single(adj_mat)\n","\n"," print('already normalize adjacency matrix', time() - t2)\n"," return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()\n","\n"," def negative_pool(self):\n"," t1 = time()\n"," for u in self.train_items.keys():\n"," neg_items = list(set(range(self.n_items)) - set(self.train_items[u]))\n"," pools = [rd.choice(neg_items) for _ in range(100)]\n"," self.neg_pools[u] = pools\n"," print('refresh negative pools', time() - t1)\n","\n"," def sample(self):\n"," if self.batch_size <= self.n_users:\n"," users = rd.sample(self.exist_users, self.batch_size)\n"," else:\n"," users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]\n","\n"," def sample_pos_items_for_u(u, num):\n"," # sample num pos items for u-th user\n"," pos_items = self.train_items[u]\n"," n_pos_items = len(pos_items)\n"," pos_batch = []\n"," while True:\n"," if len(pos_batch) == num:\n"," break\n"," pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]\n"," pos_i_id = pos_items[pos_id]\n","\n"," if pos_i_id not in pos_batch:\n"," pos_batch.append(pos_i_id)\n"," return pos_batch\n","\n"," def sample_neg_items_for_u(u, num):\n"," # sample num neg items for u-th user\n"," neg_items = []\n"," while True:\n"," if len(neg_items) == num:\n"," break\n"," neg_id = np.random.randint(low=0, high=self.n_items,size=1)[0]\n"," if neg_id not in self.train_items[u] and neg_id not in neg_items:\n"," 
neg_items.append(neg_id)\n"," return neg_items\n","\n"," def sample_neg_items_for_u_from_pools(u, num):\n"," neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))\n"," return rd.sample(neg_items, num)\n","\n"," pos_items, neg_items = [], []\n"," for u in users:\n"," pos_items += sample_pos_items_for_u(u, 1)\n"," neg_items += sample_neg_items_for_u(u, 1)\n","\n"," return users, pos_items, neg_items\n","\n"," def get_num_users_items(self):\n"," return self.n_users, self.n_items\n","\n"," def print_statistics(self):\n"," print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))\n"," print('n_interactions=%d' % (self.n_train + self.n_test))\n"," print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items)))\n","\n"," def get_sparsity_split(self):\n"," try:\n"," split_uids, split_state = [], []\n"," lines = open(self.path + '/sparsity.split', 'r').readlines()\n","\n"," for idx, line in enumerate(lines):\n"," if idx % 2 == 0:\n"," split_state.append(line.strip())\n"," print(line.strip())\n"," else:\n"," split_uids.append([int(uid) for uid in line.strip().split(' ')])\n"," print('get sparsity split.')\n","\n"," except Exception:\n"," split_uids, split_state = self.create_sparsity_split()\n"," f = open(self.path + '/sparsity.split', 'w')\n"," for idx in range(len(split_state)):\n"," f.write(split_state[idx] + '\\n')\n"," f.write(' '.join([str(uid) for uid in split_uids[idx]]) + '\\n')\n"," print('create sparsity split.')\n","\n"," return split_uids, split_state\n","\n"," def create_sparsity_split(self):\n"," all_users_to_test = list(self.test_set.keys())\n"," user_n_iid = dict()\n","\n"," # generate a dictionary to store (key=n_iids, value=a list of uid).\n"," for uid in all_users_to_test:\n"," train_iids = self.train_items[uid]\n"," test_iids = self.test_set[uid]\n","\n"," n_iids = len(train_iids) + len(test_iids)\n","\n"," if n_iids not in user_n_iid.keys():\n"," user_n_iid[n_iids] = [uid]\n"," else:\n"," user_n_iid[n_iids].append(uid)\n"," split_uids = list()\n","\n"," # split the whole user set into four subset.\n"," temp = []\n"," count = 1\n"," fold = 4\n"," n_count = (self.n_train + self.n_test)\n"," n_rates = 0\n","\n"," split_state = []\n"," for idx, n_iids in enumerate(sorted(user_n_iid)):\n"," temp += user_n_iid[n_iids]\n"," n_rates += n_iids * len(user_n_iid[n_iids])\n"," n_count -= n_iids * len(user_n_iid[n_iids])\n","\n"," if n_rates >= count * 0.25 * (self.n_train + self.n_test):\n"," split_uids.append(temp)\n","\n"," state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' %(n_iids, len(temp), n_rates)\n"," split_state.append(state)\n"," print(state)\n","\n"," temp = []\n"," n_rates = 0\n"," fold -= 1\n","\n"," if idx == len(user_n_iid.keys()) - 1 or n_count == 0:\n"," split_uids.append(temp)\n","\n"," state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' % (n_iids, len(temp), n_rates)\n"," split_state.append(state)\n"," print(state)\n","\n"," return split_uids, split_state"]},{"cell_type":"markdown","metadata":{"id":"6G9Ud3wOpGaV"},"source":["### Utils"]},{"cell_type":"code","execution_count":6,"metadata":{"executionInfo":{"elapsed":10,"status":"ok","timestamp":1630672872430,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"LiwE0Y_GpH4T"},"outputs":[],"source":["import os\n","import re\n","\n","def txt2list(file_src):\n"," orig_file = open(file_src, \"r\")\n"," lines = orig_file.readlines()\n"," return lines\n","\n","def 
ensureDir(dir_path):\n","    d = os.path.dirname(dir_path)\n","    if not os.path.exists(d):\n","        os.makedirs(d)\n","\n","def uni2str(unicode_str):\n","    return str(unicode_str.encode('ascii', 'ignore')).replace('\\n', '').strip()\n","\n","def hasNumbers(inputString):\n","    return bool(re.search(r'\\d', inputString))\n","\n","def delMultiChar(inputString, chars):\n","    for ch in chars:\n","        inputString = inputString.replace(ch, '')\n","    return inputString\n","\n","def merge_two_dicts(x, y):\n","    z = x.copy()  # start with x's keys and values\n","    z.update(y)  # modifies z with y's keys and values & returns None\n","    return z\n","\n","def early_stopping(log_value, best_value, stopping_step, expected_order='acc', flag_step=100):\n","    # early stopping strategy:\n","    assert expected_order in ['acc', 'dec']\n","\n","    if (expected_order == 'acc' and log_value >= best_value) or (expected_order == 'dec' and log_value <= best_value):\n","        stopping_step = 0\n","        best_value = log_value\n","    else:\n","        stopping_step += 1\n","\n","    if stopping_step >= flag_step:\n","        print(\"Early stopping is triggered at step: {} log:{}\".format(flag_step, log_value))\n","        should_stop = True\n","    else:\n","        should_stop = False\n","    return best_value, stopping_step, should_stop"]},{"cell_type":"markdown","metadata":{"id":"3C1euZJSpR1u"},"source":["### Batch testing"]},{"cell_type":"code","execution_count":7,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":13310,"status":"ok","timestamp":1630672885733,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"spu0_gZppTec","outputId":"3cd599d3-8b1a-4dec-c9be-803af0f8463f"},"outputs":[{"name":"stdout","output_type":"stream","text":["n_users=29858, n_items=40981\n","n_interactions=1027370\n","n_train=810128, n_test=217242, sparsity=0.00084\n"]}],"source":["import multiprocessing\n","import heapq\n","\n","cores = multiprocessing.cpu_count() // 2\n","\n","args = parse_args()\n","Ks = eval(args.Ks)\n","\n","data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)\n","USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items\n","N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test\n","BATCH_SIZE = args.batch_size\n","\n","def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):\n","    item_score = {}\n","    for i in test_items:\n","        item_score[i] = rating[i]\n","\n","    K_max = max(Ks)\n","    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)\n","\n","    r = []\n","    for i in K_max_item_score:\n","        if i in user_pos_test:\n","            r.append(1)\n","        else:\n","            r.append(0)\n","    auc = 0.\n","    return r, auc\n","\n","def get_auc(item_score, user_pos_test):\n","    item_score = sorted(item_score.items(), key=lambda kv: kv[1])\n","    item_score.reverse()\n","    item_sort = [x[0] for x in item_score]\n","    posterior = [x[1] for x in item_score]\n","\n","    r = []\n","    for i in item_sort:\n","        if i in user_pos_test:\n","            r.append(1)\n","        else:\n","            r.append(0)\n","    # use the AUC helper defined in the Metrics cell\n","    auc = AUC(ground_truth=r, prediction=posterior)\n","    return auc\n","\n","def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):\n","    item_score = {}\n","    for i in test_items:\n","        item_score[i] = rating[i]\n","\n","    K_max = max(Ks)\n","    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)\n","\n","    r = []\n","    for i in K_max_item_score:\n","        if i in user_pos_test:\n","            r.append(1)\n","        else:\n","            r.append(0)\n","    auc = get_auc(item_score, user_pos_test)\n","    return r, 
auc\n","\n","def get_performance(user_pos_test, r, auc, Ks):\n"," precision, recall, ndcg, hit_ratio = [], [], [], []\n","\n"," for K in Ks:\n"," precision.append(precision_at_k(r, K))\n"," recall.append(recall_at_k(r, K, len(user_pos_test)))\n"," ndcg.append(ndcg_at_k(r, K, user_pos_test))\n"," hit_ratio.append(hit_at_k(r, K))\n","\n"," return {'recall': np.array(recall), 'precision': np.array(precision),\n"," 'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc}\n","\n","\n","def test_one_user(x):\n"," # user u's ratings for user u\n"," rating = x[0]\n"," #uid\n"," u = x[1]\n"," #user u's items in the training set\n"," try:\n"," training_items = data_generator.train_items[u]\n"," except Exception:\n"," training_items = []\n"," #user u's items in the test set\n"," user_pos_test = data_generator.test_set[u]\n","\n"," all_items = set(range(ITEM_NUM))\n","\n"," test_items = list(all_items - set(training_items))\n","\n"," if args.test_flag == 'part':\n"," r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)\n"," else:\n"," r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)\n","\n"," return get_performance(user_pos_test, r, auc, Ks)\n","\n","\n","def test(model, users_to_test, drop_flag=False, batch_test_flag=False):\n"," result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),\n"," 'hit_ratio': np.zeros(len(Ks)), 'auc': 0.}\n","\n"," pool = multiprocessing.Pool(cores)\n","\n"," u_batch_size = BATCH_SIZE * 2\n"," i_batch_size = BATCH_SIZE\n","\n"," test_users = users_to_test\n"," n_test_users = len(test_users)\n"," n_user_batchs = n_test_users // u_batch_size + 1\n","\n"," count = 0\n","\n"," for u_batch_id in range(n_user_batchs):\n"," start = u_batch_id * u_batch_size\n"," end = (u_batch_id + 1) * u_batch_size\n","\n"," user_batch = test_users[start: end]\n","\n"," if batch_test_flag:\n"," # batch-item test\n"," n_item_batchs = ITEM_NUM // i_batch_size + 1\n"," rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))\n","\n"," i_count = 0\n"," for i_batch_id in range(n_item_batchs):\n"," i_start = i_batch_id * i_batch_size\n"," i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)\n","\n"," item_batch = range(i_start, i_end)\n","\n"," if drop_flag == False:\n"," u_g_embeddings, pos_i_g_embeddings, _ = model(user_batch,\n"," item_batch,\n"," [],\n"," drop_flag=False)\n"," i_rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()\n"," else:\n"," u_g_embeddings, pos_i_g_embeddings, _ = model(user_batch,\n"," item_batch,\n"," [],\n"," drop_flag=True)\n"," i_rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()\n","\n"," rate_batch[:, i_start: i_end] = i_rate_batch\n"," i_count += i_rate_batch.shape[1]\n","\n"," assert i_count == ITEM_NUM\n","\n"," else:\n"," # all-item test\n"," item_batch = range(ITEM_NUM)\n","\n"," if drop_flag == False:\n"," u_g_embeddings, pos_i_g_embeddings, _ = model(user_batch,\n"," item_batch,\n"," [],\n"," drop_flag=False)\n"," rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()\n"," else:\n"," u_g_embeddings, pos_i_g_embeddings, _ = model(user_batch,\n"," item_batch,\n"," [],\n"," drop_flag=True)\n"," rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()\n","\n"," user_batch_rating_uid = zip(rate_batch.numpy(), user_batch)\n"," batch_result = pool.map(test_one_user, user_batch_rating_uid)\n"," count += len(batch_result)\n","\n"," for re in batch_result:\n"," result['precision'] += 
re['precision']/n_test_users\n"," result['recall'] += re['recall']/n_test_users\n"," result['ndcg'] += re['ndcg']/n_test_users\n"," result['hit_ratio'] += re['hit_ratio']/n_test_users\n"," result['auc'] += re['auc']/n_test_users\n","\n","\n"," assert count == n_test_users\n"," pool.close()\n"," return result"]},{"cell_type":"markdown","metadata":{"id":"AR0QR33TrWKy"},"source":["### Training"]},{"cell_type":"code","execution_count":8,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"elapsed":1844273,"status":"error","timestamp":1630674729996,"user":{"displayName":"Sparsh Agarwal","photoUrl":"","userId":"13037694610922482904"},"user_tz":-330},"id":"sIioaFLcmUp8","outputId":"2a398903-b223-466a-9933-2f15e7423d66"},"outputs":[{"name":"stdout","output_type":"stream","text":["already create adjacency matrix (70839, 70839) 79.06983399391174\n","generate single-normalized adjacency matrix.\n","generate single-normalized adjacency matrix.\n","already normalize adjacency matrix 2.180420160293579\n","Epoch 0 [173.5s]: train==[432.81360=432.77740 + 0.03577]\n","Epoch 1 [173.0s]: train==[238.26131=238.22516 + 0.03618]\n","Epoch 2 [173.6s]: train==[198.87595=198.83954 + 0.03643]\n","Epoch 3 [172.9s]: train==[183.36285=183.32619 + 0.03663]\n","Epoch 4 [174.2s]: train==[169.82016=169.78339 + 0.03679]\n","Epoch 5 [173.9s]: train==[154.85112=154.81422 + 0.03695]\n","Epoch 6 [175.6s]: train==[143.06267=143.02547 + 0.03708]\n","Epoch 7 [174.5s]: train==[134.72755=134.69037 + 0.03721]\n","Epoch 8 [174.8s]: train==[129.35472=129.31735 + 0.03733]\n"]},{"ename":"NameError","evalue":"ignored","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRemoteTraceback\u001b[0m Traceback (most recent call last)","\u001b[0;31mRemoteTraceback\u001b[0m: \n\"\"\"\nTraceback (most recent call last):\n File \"/usr/lib/python3.7/multiprocessing/pool.py\", line 121, in worker\n result = (True, func(*args, **kwds))\n File \"/usr/lib/python3.7/multiprocessing/pool.py\", line 44, in mapstar\n return list(map(*args))\n File \"\", line 98, in test_one_user\n return get_performance(user_pos_test, r, auc, Ks)\n File \"\", line 67, in get_performance\n precision.append(metrics.precision_at_k(r, K))\nNameError: name 'metrics' is not defined\n\"\"\"","\nThe above exception was the direct cause of the following exception:\n","\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0musers_to_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_generator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_set\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0musers_to_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdrop_flag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mt3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36mtest\u001b[0;34m(model, users_to_test, drop_flag, batch_test_flag)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0muser_batch_rating_uid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrate_batch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muser_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 170\u001b[0;31m \u001b[0mbatch_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpool\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_one_user\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muser_batch_rating_uid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 171\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_result\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/multiprocessing/pool.py\u001b[0m in \u001b[0;36mmap\u001b[0;34m(self, func, iterable, chunksize)\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0;32min\u001b[0m \u001b[0ma\u001b[0m \u001b[0mlist\u001b[0m \u001b[0mthat\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mreturned\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 267\u001b[0m '''\n\u001b[0;32m--> 268\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmapstar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 269\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstarmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/lib/python3.7/multiprocessing/pool.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 655\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 656\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 657\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 658\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 659\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0m_set\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mNameError\u001b[0m: name 'metrics' is not defined"]}],"source":["import torch\n","import torch.optim as optim\n","\n","from time import time\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","if __name__ == '__main__':\n","\n"," args.device = torch.device('cuda:' + str(args.gpu_id))\n","\n"," plain_adj, norm_adj, mean_adj = data_generator.get_adj_mat()\n","\n"," args.node_dropout = eval(args.node_dropout)\n"," args.mess_dropout = eval(args.mess_dropout)\n","\n"," model = NGCF(data_generator.n_users,\n"," data_generator.n_items,\n"," norm_adj,\n"," args).to(args.device)\n","\n"," t0 = time()\n"," \"\"\"\n"," *********************************************************\n"," Train.\n"," \"\"\"\n"," cur_best_pre_0, stopping_step = 0, 0\n"," optimizer = optim.Adam(model.parameters(), lr=args.lr)\n","\n"," loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []\n"," for epoch in range(args.epoch):\n"," t1 = time()\n"," loss, mf_loss, emb_loss = 0., 0., 0.\n"," n_batch = data_generator.n_train // args.batch_size + 1\n","\n"," for idx in range(n_batch):\n"," users, pos_items, neg_items = data_generator.sample()\n"," u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(users,\n"," pos_items,\n"," neg_items,\n"," drop_flag=args.node_dropout_flag)\n","\n"," batch_loss, batch_mf_loss, batch_emb_loss = model.create_bpr_loss(u_g_embeddings,\n"," pos_i_g_embeddings,\n"," neg_i_g_embeddings)\n"," optimizer.zero_grad()\n"," batch_loss.backward()\n"," optimizer.step()\n","\n"," loss += batch_loss\n"," mf_loss += batch_mf_loss\n"," emb_loss += batch_emb_loss\n","\n"," if (epoch + 1) % 10 != 0:\n"," if args.verbose > 0 and epoch % args.verbose == 0:\n"," perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % (\n"," epoch, time() - t1, loss, mf_loss, emb_loss)\n"," print(perf_str)\n"," continue\n","\n"," t2 = time()\n"," users_to_test = list(data_generator.test_set.keys())\n"," ret = test(model, users_to_test, drop_flag=False)\n","\n"," t3 = time()\n","\n"," loss_loger.append(loss)\n"," rec_loger.append(ret['recall'])\n"," pre_loger.append(ret['precision'])\n"," ndcg_loger.append(ret['ndcg'])\n"," hit_loger.append(ret['hit_ratio'])\n","\n"," if args.verbose > 0:\n"," perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f], recall=[%.5f, %.5f], ' \\\n"," 'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \\\n"," (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, ret['recall'][0], ret['recall'][-1],\n"," ret['precision'][0], ret['precision'][-1], ret['hit_ratio'][0], ret['hit_ratio'][-1],\n"," ret['ndcg'][0], ret['ndcg'][-1])\n"," print(perf_str)\n","\n"," cur_best_pre_0, stopping_step, should_stop = early_stopping(ret['recall'][0], cur_best_pre_0,\n"," stopping_step, expected_order='acc', flag_step=5)\n","\n"," # *********************************************************\n"," # early stopping when cur_best_pre_0 is decreasing for ten successive steps.\n"," if should_stop == True:\n"," break\n","\n"," # *********************************************************\n"," # save the user & item embeddings for pretraining.\n"," if ret['recall'][0] == cur_best_pre_0 and args.save_flag == 1:\n"," torch.save(model.state_dict(), args.weights_path + str(epoch) + '.pkl')\n"," print('save the 
weights in path: ', args.weights_path + str(epoch) + '.pkl')\n","\n"," recs = np.array(rec_loger)\n"," pres = np.array(pre_loger)\n"," ndcgs = np.array(ndcg_loger)\n"," hit = np.array(hit_loger)\n","\n"," best_rec_0 = max(recs[:, 0])\n"," idx = list(recs[:, 0]).index(best_rec_0)\n","\n"," final_perf = \"Best Iter=[%d]@[%.1f]\\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]\" % \\\n"," (idx, time() - t0, '\\t'.join(['%.5f' % r for r in recs[idx]]),\n"," '\\t'.join(['%.5f' % r for r in pres[idx]]),\n"," '\\t'.join(['%.5f' % r for r in hit[idx]]),\n"," '\\t'.join(['%.5f' % r for r in ndcgs[idx]]))\n"," print(final_perf)"]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyNCm3/IWW2dyDD24Lq4OzgE","collapsed_sections":[],"name":"ngcf-pytorch.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}