{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-18-tagnnpp.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T026634%20%7C%20TAGNN%2B%2B%20Session-based%20Recommendations%20on%20Diginetica%20dataset.ipynb","timestamp":1644650309274},{"file_id":"1YAv6DxFRlOMWCLWU-8gShdcpi4JMkW0V","timestamp":1638344309523},{"file_id":"1buphX5uNmwXsebanVqUfOxSP1MkUBXsx","timestamp":1637777022918}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"2532292a67b54901ae69cc3a270ec97d":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_e9a0bd285197471e955bcd3309f0f9f4","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_8e131d3b74d14eefbca397b93d49ccb1","IPY_MODEL_ef31a250c90b4786b61b98125decd8d7","IPY_MODEL_35b3c0271ceb47cf9edd6f70ffb34af3"]}},"e9a0bd285197471e955bcd3309f0f9f4":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,
"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"8e131d3b74d14eefbca397b93d49ccb1":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_51a1648191584f8192b5f8065e7bbd75","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 60%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_3c0b22f7b0b84ae8a2d53059b36bfc16"}},"ef31a250c90b4786b61b98125decd8d7":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_98d6ae7b07e34c2c89ba5a5fd6a1ed49","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"","max":12951,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":7820,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_1403f29751ec4e9b83fbbd8c505f1b4a"}},"35b3c0271ceb47cf9edd6f70ffb34af3":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_93f9689f0c9d4e41b3f121896e983507","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 7820/12951 [59:28<38:53, 
2.20it/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_c6f37224738843ffa487cbda4c1bef90"}},"51a1648191584f8192b5f8065e7bbd75":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"3c0b22f7b0b84ae8a2d53059b36bfc16":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"98d6ae7b07e34c2c89ba5a5fd6a1ed49":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_col
or":null,"_model_module":"@jupyter-widgets/controls"}},"1403f29751ec4e9b83fbbd8c505f1b4a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"93f9689f0c9d4e41b3f121896e983507":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"c6f37224738843ffa487cbda4c1bef90":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":nul
import datetime
import math
import os
import pickle
import sys
import time
from collections.abc import Iterable

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.nn import Module, Parameter
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
def split_validation(train_set, valid_portion):
    """Randomly split ``(sessions, targets)`` into train and validation parts.

    Args:
        train_set: pair ``(x, y)`` of equally long sequences.
        valid_portion: fraction of the samples assigned to the validation part.

    Returns:
        ``((train_x, train_y), (valid_x, valid_y))`` as plain lists. The split
        is drawn from NumPy's global RNG, so seed ``np.random`` beforehand if a
        reproducible split is needed.
    """
    train_set_x, train_set_y = train_set
    n_samples = len(train_set_x)
    sidx = np.arange(n_samples, dtype='int32')
    np.random.shuffle(sidx)
    n_train = int(np.round(n_samples * (1. - valid_portion)))
    train_idx, valid_idx = sidx[:n_train], sidx[n_train:]

    train_x = [train_set_x[s] for s in train_idx]
    train_y = [train_set_y[s] for s in train_idx]
    valid_x = [train_set_x[s] for s in valid_idx]
    valid_y = [train_set_y[s] for s in valid_idx]
    return (train_x, train_y), (valid_x, valid_y)


def data_masks(all_usr_pois, item_tail):
    """Right-pad every session to the longest one and build 0/1 validity masks.

    Args:
        all_usr_pois: list of sessions, each a list of item ids.
        item_tail: one-element list holding the padding id (``[0]`` in this
            notebook); it is repeated to fill each session.

    Returns:
        ``(padded_sessions, masks, len_max)``. An empty input yields
        ``([], [], 0)`` instead of raising from ``max()`` on an empty list.
    """
    us_lens = [len(upois) for upois in all_usr_pois]
    len_max = max(us_lens, default=0)
    us_pois = [upois + item_tail * (len_max - le)
               for upois, le in zip(all_usr_pois, us_lens)]
    us_msks = [[1] * le + [0] * (len_max - le) for le in us_lens]
    return us_pois, us_msks, len_max


class Dataset():
    """Padded session dataset that produces per-batch session graphs."""

    def __init__(self, data, shuffle=False, graph=None):
        """
        Args:
            data: pair ``(sessions, targets)``.
            shuffle: reshuffle the samples on every ``generate_batch`` call.
            graph: stored on the instance; not used by the methods below.
        """
        inputs, mask, len_max = data_masks(data[0], [0])
        self.inputs = np.asarray(inputs)
        self.mask = np.asarray(mask)
        self.len_max = len_max
        self.targets = np.asarray(data[1])
        self.length = len(inputs)
        self.shuffle = shuffle
        self.graph = graph

    def generate_batch(self, batch_size):
        """Return a list of index arrays, one per batch.

        All batches hold ``batch_size`` consecutive indices except possibly the
        last one. An empty dataset yields an empty list.
        """
        if self.shuffle:
            shuffled_arg = np.arange(self.length)
            np.random.shuffle(shuffled_arg)
            self.inputs = self.inputs[shuffled_arg]
            self.mask = self.mask[shuffled_arg]
            self.targets = self.targets[shuffled_arg]
        return [np.arange(start, min(start + batch_size, self.length))
                for start in range(0, self.length, batch_size)]

    def get_slice(self, i):
        """Build the graph inputs for the samples selected by index array ``i``.

        Returns:
            ``(alias_inputs, A, items, mask, targets)`` where, per session,
            ``items`` are the sorted unique ids padded with 0 to a common node
            count, ``A`` is the column-normalised incoming adjacency
            concatenated with the normalised outgoing adjacency
            (``n_nodes x 2*n_nodes``), and ``alias_inputs`` maps every sequence
            position to its row in ``items``.
        """
        inputs, mask, targets = self.inputs[i], self.mask[i], self.targets[i]
        items, A, alias_inputs = [], [], []
        max_n_node = int(max(len(np.unique(u_input)) for u_input in inputs))

        for u_input in inputs:
            node = np.unique(u_input)
            items.append(node.tolist() + (max_n_node - len(node)) * [0])

            # `node` is sorted and contains every element of `u_input`, so a
            # binary search returns the exact node index of each position.
            alias = np.searchsorted(node, u_input)

            u_A = np.zeros((max_n_node, max_n_node))
            for pos in range(len(u_input) - 1):
                if u_input[pos + 1] == 0:  # reached the padding tail
                    break
                u_A[alias[pos]][alias[pos + 1]] = 1

            u_sum_in = np.sum(u_A, 0)
            u_sum_in[np.where(u_sum_in == 0)] = 1  # avoid division by zero
            u_A_in = np.divide(u_A, u_sum_in)
            u_sum_out = np.sum(u_A, 1)
            u_sum_out[np.where(u_sum_out == 0)] = 1
            u_A_out = np.divide(u_A.transpose(), u_sum_out)

            A.append(np.concatenate([u_A_in, u_A_out]).transpose())
            alias_inputs.append(alias.tolist())
        return alias_inputs, A, items, mask, targets


class AGC(optim.Optimizer):
    """Adaptive Gradient Clipping wrapper around another optimizer.

    Before delegating to the wrapped optimizer, every gradient whose unit-wise
    norm exceeds ``clipping * max(unit-wise parameter norm, eps)`` is rescaled
    down to that bound.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        optim (torch.optim.Optimizer): Optimizer with base class optim.Optimizer
        clipping (float, optional): clipping value (default: 1e-2)
        eps (float, optional): lower bound on the parameter norm (default: 1e-3)
        model (torch.nn.Module, optional): The original model; when given, the
            names in ``ignore_agc`` are validated against its modules
        ignore_agc (str, Iterable, optional): Layers for AGC to ignore. The
            names are only validated here; every parameter in ``params`` is
            still clipped.
    """

    def __init__(self, params, optim: optim.Optimizer, clipping: float = 1e-2, eps: float = 1e-3, model=None, ignore_agc=['']):
        if clipping < 0.0:
            raise ValueError("Invalid clipping value: {}".format(clipping))
        if eps < 0.0:
            raise ValueError("Invalid eps value: {}".format(eps))

        self.optim = optim

        # AGC's own settings go last so that a wrapped optimizer exposing an
        # `eps` default (Adam, RMSprop, ...) cannot silently overwrite them.
        defaults = {**optim.defaults, 'clipping': clipping, 'eps': eps}

        if model is not None:
            assert ignore_agc not in [
                None, []], "Specify args ignore_agc to ignore fc-like(or other) layers"

        # A plain string is itself iterable; wrap it so that it is treated as
        # one module name rather than a sequence of characters.
        if isinstance(ignore_agc, str) or not isinstance(ignore_agc, Iterable):
            ignore_agc = [ignore_agc]

        if model is not None:
            names = [name for name, _ in model.named_modules()]
            for module_name in ignore_agc:
                if module_name not in names:
                    raise ModuleNotFoundError(
                        "Module name {} not found in the model".format(module_name))

        super(AGC, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Clip the gradients, then run one step of the wrapped optimizer.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is evaluated exactly once, before
                clipping; it is not forwarded to the wrapped optimizer because
                re-running it would overwrite the clipped gradients.

        Returns:
            The loss returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                param_norm = unitwise_norm(p.detach()).clamp(min=group['eps'])
                grad_norm = unitwise_norm(p.grad.detach())
                max_norm = param_norm * group['clipping']

                # Clip only where the gradient is too large relative to the
                # parameter norm.
                trigger = grad_norm > max_norm
                clipped_grad = p.grad * (max_norm / grad_norm.clamp(min=1e-6))
                p.grad.copy_(torch.where(trigger, clipped_grad, p.grad))

        self.optim.step()
        return loss


def unitwise_norm(x: torch.Tensor):
    """L2 norm of ``x`` reduced per unit.

    Tensors with at most one dimension are reduced to a scalar; 2-D and 3-D
    tensors are reduced over dim 0 and 4-D tensors over dims 1-3, keeping the
    reduced dimensions so the result broadcasts against ``x``.

    Raises:
        ValueError: if ``x`` has more than four dimensions.
    """
    if x.ndim <= 1:
        dim = 0
        keepdim = False
    elif x.ndim in [2, 3]:
        dim = 0
        keepdim = True
    elif x.ndim == 4:
        dim = [1, 2, 3]
        keepdim = True
    else:
        raise ValueError('Wrong dimensions of x')

    return torch.sum(x**2, dim=dim, keepdim=keepdim) ** 0.5


class Attention_GNN(Module):
    """Gated graph layer applied ``step`` times over a session graph."""

    def __init__(self, hidden_size, step=1):
        super(Attention_GNN, self).__init__()
        self.step = step
        self.hidden_size = hidden_size
        self.input_size = hidden_size * 2
        self.gate_size = 3 * hidden_size
        # These tensors are allocated uninitialised; in this notebook they are
        # filled by Attention_SessionGraph.reset_parameters().
        self.w_ih = Parameter(torch.empty(self.gate_size, self.input_size))
        self.w_hh = Parameter(torch.empty(self.gate_size, self.hidden_size))
        self.b_ih = Parameter(torch.empty(self.gate_size))
        self.b_hh = Parameter(torch.empty(self.gate_size))
        self.b_iah = Parameter(torch.empty(self.hidden_size))
        self.b_oah = Parameter(torch.empty(self.hidden_size))

        self.linear_edge_in = nn.Linear(
            self.hidden_size, self.hidden_size, bias=True)
        self.linear_edge_out = nn.Linear(
            self.hidden_size, self.hidden_size, bias=True)
        # Not used in GNNCell; kept so the parameter set stays unchanged.
        self.linear_edge_f = nn.Linear(
            self.hidden_size, self.hidden_size, bias=True)

    def GNNCell(self, A, hidden):
        """One gated update of ``hidden`` given adjacency ``A``.

        ``A`` holds the incoming adjacency in its first ``A.shape[1]`` columns
        and the outgoing adjacency in the following ``A.shape[1]`` columns.
        """
        n = A.shape[1]
        input_in = torch.matmul(A[:, :, :n],
                                self.linear_edge_in(hidden)) + self.b_iah
        input_out = torch.matmul(
            A[:, :, n: 2 * n], self.linear_edge_out(hidden)) + self.b_oah

        inputs = torch.cat([input_in, input_out], 2)
        gi = F.linear(inputs, self.w_ih, self.b_ih)
        gh = F.linear(hidden, self.w_hh, self.b_hh)
        i_r, i_i, i_n = gi.chunk(3, 2)
        h_r, h_i, h_n = gh.chunk(3, 2)
        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)
        return newgate + inputgate * (hidden - newgate)

    def forward(self, A, hidden):
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden


class Attention_SessionGraph(Module):
    """Session-graph recommender: embedding -> gated GNN -> self-attention.

    Args:
        opt: configuration object providing ``hiddenSize``, ``batchSize``,
            ``nonhybrid``, ``step``, ``lr``, ``l2``, ``lr_dc_step`` and
            ``lr_dc``.
        n_node: size of the embedding table (row 0 is the padding id).
    """

    def __init__(self, opt, n_node):
        super(Attention_SessionGraph, self).__init__()
        # Read every hyper-parameter from `opt`; the model must not depend on
        # a module-level `args` global being defined.
        self.hidden_size = opt.hiddenSize
        self.n_node = n_node
        self.batch_size = opt.batchSize
        self.nonhybrid = opt.nonhybrid
        self.embedding = nn.Embedding(self.n_node, self.hidden_size)
        self.tagnn = Attention_GNN(self.hidden_size, step=opt.step)

        self.layer_norm1 = nn.LayerNorm(self.hidden_size)
        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_size, num_heads=2, dropout=0.1)

        self.linear_one = nn.Linear(
            self.hidden_size, self.hidden_size, bias=True)
        self.linear_two = nn.Linear(
            self.hidden_size, self.hidden_size, bias=True)
        self.linear_three = nn.Linear(self.hidden_size, 1, bias=False)
        self.linear_transform = nn.Linear(
            self.hidden_size * 2, self.hidden_size, bias=True)
        self.linear_t = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False)  # target attention
        self.loss_function = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=opt.lr, weight_decay=opt.l2)
        self.agc_optimizer = AGC(self.parameters(), self.optimizer, model=self)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter uniformly in +-1/sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def compute_scores(self, hidden, mask):
        """Score every non-padding item for each session.

        Args:
            hidden: batch_size x seq_length x latent_size sequence states.
            mask: batch_size x seq_length integer 0/1 mask of valid positions.

        Returns:
            batch_size x (n_node - 1) scores; column ``j`` belongs to item
            ``j + 1``.
        """
        # State at the last valid position: batch_size x latent_size
        ht = hidden[torch.arange(mask.shape[0]).long(), torch.sum(mask, 1) - 1]
        # batch_size x 1 x latent_size
        q1 = self.linear_one(ht).view(ht.shape[0], 1, ht.shape[1])
        q2 = self.linear_two(hidden)  # batch_size x seq_length x latent_size
        # batch_size x seq_length x 1
        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        alpha = F.softmax(alpha, 1)  # batch_size x seq_length x 1
        valid = mask.view(mask.shape[0], -1, 1).float()
        # batch_size x latent_size
        a = torch.sum(alpha * hidden * valid, 1)

        if not self.nonhybrid:
            a = self.linear_transform(torch.cat([a, ht], 1))
        b = self.embedding.weight[1:]  # n_nodes x latent_size

        # batch_size x seq_length x latent_size
        hidden = hidden * valid
        qt = self.linear_t(hidden)  # batch_size x seq_length x latent_size
        # batch_size x n_nodes x seq_length
        beta = F.softmax(b @ qt.transpose(1, 2), -1)
        target = beta @ hidden  # batch_size x n_nodes x latent_size
        a = a.view(ht.shape[0], 1, ht.shape[1])  # batch_size x 1 x latent_size
        a = a + target  # batch_size x n_nodes x latent_size
        scores = torch.sum(a * b, -1)  # batch_size x n_nodes
        return scores

    def forward(self, inputs, A):
        hidden = self.embedding(inputs)
        hidden = self.tagnn(A, hidden)
        hidden = hidden.permute(1, 0, 2)  # attention expects (seq, batch, dim)

        skip = self.layer_norm1(hidden)
        # Strictly upper-triangular boolean mask, built on the activations'
        # device so the model runs on CPU and GPU alike.
        seq_len = hidden.shape[0]
        attn_mask = torch.ones(
            seq_len, seq_len, device=hidden.device).triu(diagonal=1).bool()
        hidden, attn_w = self.attn(hidden, hidden, hidden, attn_mask=attn_mask)
        hidden = hidden + skip
        hidden = hidden.permute(1, 0, 2)

        return hidden
def get_mask(seq_len, device=None):
    """Return a ``seq_len x seq_len`` boolean mask, True strictly above the diagonal.

    Args:
        seq_len: side length of the square mask.
        device: target device; defaults to CUDA when available, else CPU, so
            the helper no longer fails on machines without a GPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    upper = np.triu(np.ones((seq_len, seq_len)), k=1).astype('bool')
    return torch.from_numpy(upper).to(device)


def to_cuda(input_variable):
    """Move a tensor or module to the GPU when one is available."""
    if torch.cuda.is_available():
        return input_variable.cuda()
    else:
        return input_variable


def to_cpu(input_variable):
    """Move a tensor back to host memory when it may live on a GPU."""
    if torch.cuda.is_available():
        return input_variable.cpu()
    else:
        return input_variable


def forward(model, i, data):
    """Run ``model`` on the batch selected by index array ``i``.

    Returns:
        ``(targets, scores)``: the batch targets exactly as returned by
        ``data.get_slice`` and the output of ``model.compute_scores``.
    """
    alias_inputs, A, items, mask, targets = data.get_slice(i)
    # Build integer tensors directly (no float32 round-trip) and stack the
    # per-session arrays once in NumPy before handing them to torch.
    alias_inputs = to_cuda(torch.as_tensor(np.asarray(alias_inputs), dtype=torch.long))
    items = to_cuda(torch.as_tensor(np.asarray(items), dtype=torch.long))
    A = to_cuda(torch.as_tensor(np.asarray(A), dtype=torch.float))
    mask = to_cuda(torch.as_tensor(np.asarray(mask), dtype=torch.long))
    hidden = model(items, A)

    # seq_hidden[b, t] = hidden[b, alias_inputs[b, t]]
    batch_idx = torch.arange(alias_inputs.shape[0], device=alias_inputs.device)
    seq_hidden = hidden[batch_idx.unsqueeze(1), alias_inputs]

    return targets, model.compute_scores(seq_hidden, mask)


def train_test(model, train_data, test_data):
    """Train for one epoch, then evaluate Recall@20 and MRR@20.

    Returns:
        ``(hit, mrr)`` as percentages.
    """
    print('Start training: ', datetime.datetime.now())
    model.train()
    total_loss = 0.0
    slices = train_data.generate_batch(model.batch_size)
    log_every = int(len(slices) / 5 + 1)

    for j, i in enumerate(tqdm(slices, total=len(slices))):
        model.optimizer.zero_grad()
        targets, scores = forward(model, i, train_data)
        targets = to_cuda(torch.as_tensor(np.asarray(targets), dtype=torch.long))
        # Score column k belongs to item k + 1 (id 0 is padding).
        loss = model.loss_function(scores, targets - 1)
        loss.backward()
        model.optimizer.step()
        total_loss += loss.item()

        if j % log_every == 0:
            print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item()))

    # Step the scheduler after the epoch's optimizer updates; stepping it
    # first skips the initial learning-rate value of the schedule.
    model.scheduler.step()

    print('\tLoss Value:\t%.3f' % total_loss)
    print('Start Prediction: ', datetime.datetime.now())

    model.eval()
    hit, mrr = [], []
    slices = test_data.generate_batch(model.batch_size)

    with torch.no_grad():  # no autograd graph is needed for evaluation
        for i in slices:
            targets, scores = forward(model, i, test_data)
            sub_scores = scores.topk(20)[1]
            sub_scores = to_cpu(sub_scores).detach().numpy()

            for score, target in zip(sub_scores, targets):
                ranks = np.where(score == target - 1)[0]
                hit.append(len(ranks) > 0)
                mrr.append(0 if len(ranks) == 0 else 1 / (ranks[0] + 1))

    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    return hit, mrr


def get_pos(seq_len):
    """Return positions ``0 .. seq_len - 1`` as a ``1 x seq_len`` tensor."""
    return torch.arange(seq_len).unsqueeze(0)


def str2bool(v):
    """Return True only when ``v`` spells "true" (case-insensitive).

    The comparison is an equality test; a membership test against the bare
    string ``('true')`` would also accept substrings such as ``'t'`` or ``''``.
    """
    return v.lower() == 'true'


class Args():
    """Hyper-parameters and run configuration for this notebook."""
    dataset = 'diginetica'
    batchSize = 50
    hiddenSize = 100  # Hidden state dimensions
    epoch = 1
    lr = 0.001
    lr_dc = 0.1  # Set the decay rate for Learning rate
    lr_dc_step = 3  # Steps for learning rate decay
    l2 = 1e-5  # Assign L2 Penalty
    step = 1
    patience = 10  # Used for early stopping criterion
    nonhybrid = True
    validation = True
    valid_portion = 0.1  # Portion of train set to split into val set
    defaults = True  # Use default configuration


args = Args()
# Driver cell: load the data, build the model, train with early stopping and
# checkpoint the best epochs.
model_save_dir = 'saved/'
# torch.save does not create missing parent directories.
os.makedirs(model_save_dir, exist_ok=True)
writer = SummaryWriter(log_dir='with_pos/logs')

# NOTE(review): pickle.load executes arbitrary code from the file; these files
# are downloaded earlier in the notebook and are assumed to be trusted.
with open('train.txt', 'rb') as train_file:
    train_raw = pickle.load(train_file)
with open('test.txt', 'rb') as test_file:
    test_raw = pickle.load(test_file)

if args.validation:
    # Hold out part of the training data and evaluate on it instead of test.txt.
    train_raw, valid_data = split_validation(train_raw, args.valid_portion)
    test_raw = valid_data
else:
    print('Testing dataset used validation set')

train_data = Dataset(train_raw, shuffle=True)
test_data = Dataset(test_raw, shuffle=False)

# Embedding-table size for this dataset.
# NOTE(review): presumably the number of distinct item ids plus the padding id 0 — confirm.
n_node = 43098

model = to_cuda(Attention_SessionGraph(args, n_node))
start = time.time()
best_result = [0, 0]
best_epoch = [0, 0]
bad_counter = 0

for epoch in range(args.epoch):
    print('-' * 50)
    print('Epoch: ', epoch)
    hit, mrr = train_test(model, train_data, test_data)

    # Logging
    writer.add_scalar('epoch/recall', hit, epoch)
    writer.add_scalar('epoch/mrr', mrr, epoch)

    flag = 0  # set to 1 when either metric matches or beats its best value

    if hit >= best_result[0]:
        best_result[0] = hit
        best_epoch[0] = epoch
        flag = 1
        torch.save(model, model_save_dir + 'epoch_' +
                   str(epoch) + '_recall_' + str(hit) + '_.pt')
    if mrr >= best_result[1]:
        best_result[1] = mrr
        best_epoch[1] = epoch
        flag = 1
        torch.save(model, model_save_dir + 'epoch_' +
                   str(epoch) + '_mrr_' + str(mrr) + '_.pt')

    print('Best Result:')
    print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f\tEpoch:\t%d,\t%d' %
          (best_result[0], best_result[1], best_epoch[0], best_epoch[1]))

    bad_counter += 1 - flag

    if bad_counter >= args.patience:
        break

writer.close()  # flush pending TensorBoard events to disk

print('-' * 50)
end = time.time()
print("Running time: %f seconds" % (end - start))