{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-26-bert2bert-seq-attack.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T355514%20%7C%20Black-box%20Attack%20on%20Sequential%20Recs%20-%20Bert2Bert%20Autoregressive%20Data%20Poisoning%20Attack%20Model.ipynb","timestamp":1644673010540}],"collapsed_sections":["xX1F6KrSHm4S","R-Ffm9maHuhy","vWM5QQBNHeDM","iNd5hIrHHcch","VDHlTWUwKHI-","au4jaGK2JWg7"],"authorship_tag":"ABX9TyNhrSZi7vOBShd7xIFkq9XZ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"6aEqg_CUkSIC"},"source":["# Black-box Attack on Sequential Recs"]},{"cell_type":"markdown","metadata":{"id":"HVS5pGO1u4sm"},"source":["## Context"]},{"cell_type":"markdown","metadata":{"id":"Yl4V3kE1vTT-"},"source":["Abstract\n","\n","- We will investigate whether **model extraction** can be used to **\"steal\" the weights** of sequential recommender systems, and the potential threats posed to victims of such attacks. We will use **API-based model extraction method** via limited-budget synthetic data generation and knowledge distillation.\n","- Unlike many existing recommender attackers, which assume the dataset used to train the victim model is exposed to attackers, we will consider a **data-free setting**, where training data are not accessible.\n","- We will perform attacks in two stages. (1) **Model extraction**: given different types of synthetic data and their labels retrieved from a black-box recommender, we extract the black-box model to a white-box model via distillation. (2) **Downstream attacks**: we attack the black-box model with adversarial samples generated by the white-box recommender.\n","- Experiments show the effectiveness of this data-free model extraction and downstream attacks on sequential recommenders in both **profile pollution and data poisoning settings**."]},{"cell_type":"markdown","metadata":{"id":"prsP6rpJvZLI"},"source":["Scope\n","\n","We formalize the problem with the following settings to define the research scope:\n","\n","- Unknown Weights: Weights or metrics of the victim recommender are not provided.\n","- Data-Free: Original training data is not available, and item statistics (e.g. popularity) are not accessible.\n","- Limited API Queries: Given some input data, the victim model API provides a ranked list of items (e.g. top 100 recommended items). To avoid large numbers of API requests, we define budgets for the total number of victim model queries. Here, we treat each input sequence as one budget unit.\n","- (Partially) Known Architecture: Although weights are confidential, model architectures are known (e.g. we know the victim recommender is a transformer-based model). 
We also relax this assumption to cases where the white-box recommender uses a different sequential model architecture from the victim recommender."]},{"cell_type":"markdown","metadata":{"id":"ixtg30qFu7sP"},"source":["Tasks\n","\n","- [ ] Prepare data for black-box model training\n","- [ ] Build the black-box model\n","- [ ] Train the black-box model\n","- [ ] Black-box Model validation\n","- [ ] Convert it into API\n","- [ ] Start the API in the background\n","- [ ] Generate the data by querying the API\n","- [ ] Build the white-box model architecture\n","- [ ] Train the white-box model by knowledge distillation\n","- [ ] White-box model validation\n","- [ ] Build the profile pollution attack model\n","- [ ] Train the profile pollution attack model\n","- [ ] Validate the profile pollution attack model\n","- [ ] Build the data poisoning attack model\n","- [ ] Train the data poisoning attack model\n","- [ ] Validate the data poisoning attack model"]},{"cell_type":"markdown","metadata":{"id":"OjJ_Kn5yv5H1"},"source":["## Setup"]},{"cell_type":"markdown","metadata":{"id":"LnAyk6ROv74A"},"source":["### Installation"]},{"cell_type":"code","metadata":{"id":"zFEU2MP98KFc"},"source":["!apt-get install libarchive-dev\n","!pip install faiss-cpu --no-cache\n","!apt-get install libomp-dev"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"jy1OEmV2L54t"},"source":["!pip install wget\n","!pip install libarchive"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VS04Hk30v6yR"},"source":["### Imports"]},{"cell_type":"code","metadata":{"id":"ydRh8jY6La_f"},"source":["import argparse\n","import torch\n","import pickle\n","import random\n","import shutil\n","import tempfile\n","import os\n","from pathlib import Path\n","import gzip\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","tqdm.pandas()\n","from abc import *\n","import wget\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from torch.autograd import Variable\n","# from torch.autograd.gradcheck import zero_gradients\n","from torch.optim.lr_scheduler import LambdaLR\n","from torch.utils.tensorboard import SummaryWriter\n","from tqdm import tqdm\n","\n","import json\n","import faiss\n","import numpy as np\n","from abc import *\n","from pathlib import Path\n","\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","from torch.optim.lr_scheduler import LambdaLR\n","from torch.utils.tensorboard import SummaryWriter\n","from tqdm import tqdm\n","\n","import json\n","import faiss\n","import numpy as np\n","from abc import *\n","from pathlib import Path\n","\n","import json\n","import os\n","import pprint as pp\n","import random\n","from datetime import date\n","from pathlib import Path\n","\n","import numpy as np\n","import torch\n","import torch.nn.functional as F\n","import torch.backends.cudnn as cudnn\n","from torch import optim as optim\n","\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","from torch.optim.lr_scheduler import LambdaLR\n","from torch.utils.tensorboard import SummaryWriter\n","from tqdm import tqdm\n","\n","import json\n","import faiss\n","import numpy as np\n","from abc import *\n","from pathlib import Path\n","\n","from pathlib import Path\n","import zipfile\n","import libarchive\n","import 
sys\n","from datetime import date\n","from pathlib import Path\n","import pickle\n","import shutil\n","import tempfile\n","import os\n","\n","from tqdm import trange\n","from collections import Counter\n","import numpy as np\n","\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","tqdm.pandas()\n","\n","import math\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","from torch import nn as nn\n","import math\n","import torch\n","import random\n","import torch.utils.data as data_utils\n"," \n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import math\n","from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3VqyMbwb8U7T"},"source":["### Config"]},{"cell_type":"code","metadata":{"id":"-4-DuZc58eDc"},"source":["RAW_DATASET_ROOT_FOLDER = 'data'\n","GEN_DATASET_ROOT_FOLDER = 'gen_data'\n","\n","STATE_DICT_KEY = 'model_state_dict'\n","OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"s5R3xlOrLLlT"},"source":["def fix_random_seed_as(random_seed):\n"," random.seed(random_seed)\n"," np.random.seed(random_seed)\n"," torch.manual_seed(random_seed)\n"," torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = False\n","\n","\n","def set_template(args):\n"," args.min_uc = 5\n"," args.min_sc = 5\n"," args.split = 'leave_one_out'\n"," # dataset_code = {'1': 'ml-1m', '20': 'ml-20m', 'b': 'beauty', 'bd': 'beauty_dense' , 'g': 'games', 's': 'steam', 'y': 'yoochoose'}\n"," # args.dataset_code = dataset_code[input('Input 1 / 20 for movielens, b for beauty, bd for dense beauty, g for games, s for steam and y for yoochoose: ')]\n"," args.dataset_code = 'ml-1m'\n"," if args.dataset_code == 'ml-1m':\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 64\n"," args.bert_dropout = 0.1\n"," args.bert_attn_dropout = 0.1\n"," args.bert_max_len = 200\n"," args.bert_mask_prob = 0.2\n"," args.bert_max_predictions = 40\n"," elif args.dataset_code == 'ml-20m':\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 64\n"," args.bert_dropout = 0.1\n"," args.bert_attn_dropout = 0.1\n"," args.bert_max_len = 200\n"," args.bert_mask_prob = 0.2\n"," args.bert_max_predictions = 20\n"," elif args.dataset_code in ['beauty', 'beauty_dense']:\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 64\n"," args.bert_dropout = 0.5\n"," args.bert_attn_dropout = 0.2\n"," args.bert_max_len = 50\n"," args.bert_mask_prob = 0.6\n"," args.bert_max_predictions = 30\n"," elif args.dataset_code == 'games':\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 64\n"," args.bert_dropout = 0.5\n"," args.bert_attn_dropout = 0.5\n"," args.bert_max_len = 50\n"," args.bert_mask_prob = 0.5\n"," args.bert_max_predictions = 25\n"," elif args.dataset_code == 'steam':\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 64\n"," args.bert_dropout = 0.2\n"," args.bert_attn_dropout = 0.2\n"," args.bert_max_len = 50\n"," args.bert_mask_prob = 0.4\n"," args.bert_max_predictions = 20\n"," elif args.dataset_code == 'yoochoose':\n"," args.sliding_window_size = 0.5\n"," args.bert_hidden_units = 256\n"," args.bert_dropout = 0.2\n"," args.bert_attn_dropout = 0.2\n"," args.bert_max_len = 50\n"," args.bert_mask_prob = 0.4\n"," args.bert_max_predictions = 20\n","\n"," batch = 128\n"," args.train_batch_size = batch\n"," 
args.val_batch_size = batch\n"," args.test_batch_size = batch\n"," args.train_negative_sampler_code = 'random'\n"," args.train_negative_sample_size = 0\n"," args.train_negative_sampling_seed = 0\n"," args.test_negative_sampler_code = 'random'\n"," args.test_negative_sample_size = 100\n"," args.test_negative_sampling_seed = 98765\n","\n"," # model_codes = {'b': 'bert', 's':'sas', 'n':'narm'}\n"," # args.model_code = model_codes[input('Input model code, b for BERT, s for SASRec and n for NARM: ')]\n"," args.model_code = 'bert'\n","\n"," if torch.cuda.is_available():\n"," # args.device = 'cuda:' + input('Input GPU ID: ')\n"," args.device = 'cuda:0'\n"," else:\n"," args.device = 'cpu'\n"," args.optimizer = 'AdamW'\n"," args.lr = 0.001\n"," args.weight_decay = 0.01\n"," args.enable_lr_schedule = True\n"," args.decay_step = 10000\n"," args.gamma = 1.\n"," args.enable_lr_warmup = False\n"," args.warmup_steps = 100\n"," args.num_epochs = 1000\n","\n"," args.metric_ks = [1, 5, 10]\n"," args.best_metric = 'NDCG@10'\n"," args.model_init_seed = 98765\n"," args.bert_num_blocks = 2\n"," args.bert_num_heads = 2\n"," args.bert_head_size = None\n","\n","\n","parser = argparse.ArgumentParser()\n","\n","################\n","# Dataset\n","################\n","parser.add_argument('--dataset_code', type=str, default='ml-1m', choices=DATASETS.keys())\n","parser.add_argument('--min_rating', type=int, default=0)\n","parser.add_argument('--min_uc', type=int, default=5)\n","parser.add_argument('--min_sc', type=int, default=5)\n","parser.add_argument('--split', type=str, default='leave_one_out')\n","parser.add_argument('--dataset_split_seed', type=int, default=0)\n","\n","################\n","# Dataloader\n","################\n","parser.add_argument('--dataloader_random_seed', type=float, default=0)\n","parser.add_argument('--train_batch_size', type=int, default=64)\n","parser.add_argument('--val_batch_size', type=int, default=64)\n","parser.add_argument('--test_batch_size', type=int, default=64)\n","parser.add_argument('--sliding_window_size', type=float, default=0.5)\n","\n","################\n","# NegativeSampler\n","################\n","parser.add_argument('--train_negative_sampler_code', type=str, default='random', choices=['popular', 'random'])\n","parser.add_argument('--train_negative_sample_size', type=int, default=0)\n","parser.add_argument('--train_negative_sampling_seed', type=int, default=0)\n","parser.add_argument('--test_negative_sampler_code', type=str, default='random', choices=['popular', 'random'])\n","parser.add_argument('--test_negative_sample_size', type=int, default=100)\n","parser.add_argument('--test_negative_sampling_seed', type=int, default=0)\n","\n","################\n","# Trainer\n","################\n","# device #\n","parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'])\n","parser.add_argument('--num_gpu', type=int, default=1)\n","# optimizer & lr#\n","parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam', 'SGD'])\n","parser.add_argument('--weight_decay', type=float, default=0)\n","parser.add_argument('--adam_epsilon', type=float, default=1e-9)\n","parser.add_argument('--momentum', type=float, default=None)\n","parser.add_argument('--lr', type=float, default=0.001)\n","parser.add_argument('--enable_lr_schedule', type=bool, default=True)\n","parser.add_argument('--decay_step', type=int, default=100)\n","parser.add_argument('--gamma', type=float, default=1)\n","parser.add_argument('--enable_lr_warmup', type=bool, 
default=True)\n","parser.add_argument('--warmup_steps', type=int, default=100)\n","# epochs #\n","parser.add_argument('--num_epochs', type=int, default=100)\n","# logger #\n","parser.add_argument('--log_period_as_iter', type=int, default=12800)\n","# evaluation #\n","parser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20])\n","parser.add_argument('--best_metric', type=str, default='NDCG@10')\n","\n","################\n","# Model\n","################\n","parser.add_argument('--model_code', type=str, default='bert', choices=['bert', 'sas', 'narm'])\n","# BERT specs, used for SASRec and NARM as well #\n","parser.add_argument('--bert_max_len', type=int, default=None)\n","parser.add_argument('--bert_hidden_units', type=int, default=64)\n","parser.add_argument('--bert_num_blocks', type=int, default=2)\n","parser.add_argument('--bert_num_heads', type=int, default=2)\n","parser.add_argument('--bert_head_size', type=int, default=32)\n","parser.add_argument('--bert_dropout', type=float, default=0.1)\n","parser.add_argument('--bert_attn_dropout', type=float, default=0.1)\n","parser.add_argument('--bert_mask_prob', type=float, default=0.2)\n","\n","################\n","# Distillation & Retraining\n","################\n","parser.add_argument('--num_generated_seqs', type=int, default=3000)\n","parser.add_argument('--num_original_seqs', type=int, default=0)\n","parser.add_argument('--num_poisoned_seqs', type=int, default=100)\n","parser.add_argument('--num_alter_items', type=int, default=10)\n","\n","################\n","\n","\n","args = parser.parse_args(args={})"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YL91dkqd8WKJ"},"source":["### Utils"]},{"cell_type":"code","metadata":{"id":"pLqsGqZb8GeN"},"source":["def download(url, savepath):\n"," wget.download(url, str(savepath))\n"," print()\n","\n","\n","def unzip(zippath, savepath):\n"," print(\"Extracting data...\")\n"," zip = zipfile.ZipFile(zippath)\n"," zip.extractall(savepath)\n"," zip.close()\n","\n","\n","def unzip7z(filename):\n"," print(\"Extracting data...\")\n"," libarchive.extract_file(filename)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dak-B2wHJuE6"},"source":["def ndcg(scores, labels, k):\n"," scores = scores.cpu()\n"," labels = labels.cpu()\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank[:, :k]\n"," hits = labels.gather(1, cut)\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits.float() * weights).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n"," for n in labels.sum(1)])\n"," ndcg = dcg / idcg\n"," return ndcg.mean()\n","\n","\n","def recalls_and_ndcgs_for_ks(scores, labels, ks):\n"," metrics = {}\n","\n"," scores = scores\n"," labels = labels\n"," answer_count = labels.sum(1)\n","\n"," labels_float = labels.float()\n"," rank = (-scores).argsort(dim=1)\n","\n"," cut = rank\n"," for k in sorted(ks, reverse=True):\n"," cut = cut[:, :k]\n"," hits = labels_float.gather(1, cut)\n"," metrics['Recall@%d' % k] = \\\n"," (hits.sum(1) / torch.min(torch.Tensor([k]).to(\n"," labels.device), labels.sum(1).float())).mean().cpu().item()\n","\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits * weights.to(hits.device)).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n"," for n in answer_count]).to(dcg.device)\n"," ndcg = (dcg / idcg).mean()\n"," metrics['NDCG@%d' % k] = ndcg.cpu().item()\n"," return metrics\n","\n","\n","def 
em_and_agreement(scores_rank, labels_rank):\n"," em = (scores_rank == labels_rank).float().mean()\n"," temp = np.hstack((scores_rank.numpy(), labels_rank.numpy()))\n"," temp = np.sort(temp, axis=1)\n"," agreement = np.mean(np.sum(temp[:, 1:] == temp[:, :-1], axis=1))\n"," return em, agreement\n","\n","\n","def kl_agreements_and_intersctions_for_ks(scores, soft_labels, ks, k_kl=100):\n"," metrics = {}\n"," scores = scores.cpu()\n"," soft_labels = soft_labels.cpu()\n"," scores_rank = (-scores).argsort(dim=1)\n"," labels_rank = (-soft_labels).argsort(dim=1)\n","\n"," top_kl_scores = F.log_softmax(scores.gather(1, labels_rank[:, :k_kl]), dim=-1)\n"," top_kl_labels = F.softmax(soft_labels.gather(1, labels_rank[:, :k_kl]), dim=-1)\n"," kl = F.kl_div(top_kl_scores, top_kl_labels, reduction='batchmean')\n"," metrics['KL-Div'] = kl.item()\n"," for k in sorted(ks, reverse=True):\n"," em, agreement = em_and_agreement(scores_rank[:, :k], labels_rank[:, :k])\n"," metrics['EM@%d' % k] = em.item()\n"," metrics['Agr@%d' % k] = (agreement / k).item()\n"," return metrics\n","\n","\n","class AverageMeterSet(object):\n"," def __init__(self, meters=None):\n"," self.meters = meters if meters else {}\n","\n"," def __getitem__(self, key):\n"," if key not in self.meters:\n"," meter = AverageMeter()\n"," meter.update(0)\n"," return meter\n"," return self.meters[key]\n","\n"," def update(self, name, value, n=1):\n"," if name not in self.meters:\n"," self.meters[name] = AverageMeter()\n"," self.meters[name].update(value, n)\n","\n"," def reset(self):\n"," for meter in self.meters.values():\n"," meter.reset()\n","\n"," def values(self, format_string='{}'):\n"," return {format_string.format(name): meter.val for name, meter in self.meters.items()}\n","\n"," def averages(self, format_string='{}'):\n"," return {format_string.format(name): meter.avg for name, meter in self.meters.items()}\n","\n"," def sums(self, format_string='{}'):\n"," return {format_string.format(name): meter.sum for name, meter in self.meters.items()}\n","\n"," def counts(self, format_string='{}'):\n"," return {format_string.format(name): meter.count for name, meter in self.meters.items()}\n","\n","\n","class AverageMeter(object):\n"," \"\"\"Computes and stores the average and current value\"\"\"\n","\n"," def __init__(self):\n"," self.val = 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def reset(self):\n"," self.val = 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def update(self, val, n=1):\n"," self.val = val\n"," self.sum += val\n"," self.count += n\n"," self.avg = self.sum / self.count\n","\n"," def __format__(self, format):\n"," return \"{self.val:{format}} ({self.avg:{format}})\".format(self=self, format=format)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"INd1BUQLJ0gn"},"source":["def save_state_dict(state_dict, path, filename):\n"," torch.save(state_dict, os.path.join(path, filename))\n","\n","\n","class LoggerService(object):\n"," def __init__(self, train_loggers=None, val_loggers=None):\n"," self.train_loggers = train_loggers if train_loggers else []\n"," self.val_loggers = val_loggers if val_loggers else []\n","\n"," def complete(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.complete(**log_data)\n"," for logger in self.val_loggers:\n"," logger.complete(**log_data)\n","\n"," def log_train(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.log(**log_data)\n","\n"," def log_val(self, log_data):\n"," for logger in 
self.val_loggers:\n"," logger.log(**log_data)\n","\n","\n","class AbstractBaseLogger(metaclass=ABCMeta):\n"," @abstractmethod\n"," def log(self, *args, **kwargs):\n"," raise NotImplementedError\n","\n"," def complete(self, *args, **kwargs):\n"," pass\n","\n","\n","class RecentModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n"," self.recent_epoch = None\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," epoch = kwargs['epoch']\n","\n"," if self.recent_epoch != epoch:\n"," self.recent_epoch = epoch\n"," state_dict = kwargs['state_dict']\n"," state_dict['epoch'] = kwargs['epoch']\n"," save_state_dict(state_dict, self.checkpoint_path, self.filename)\n","\n"," def complete(self, *args, **kwargs):\n"," save_state_dict(kwargs['state_dict'],\n"," self.checkpoint_path, self.filename + '.final')\n","\n","\n","class BestModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n","\n"," self.best_metric = 0.\n"," self.metric_key = metric_key\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," current_metric = kwargs[self.metric_key]\n"," if self.best_metric < current_metric:\n"," print(\"Update Best {} Model at {}\".format(\n"," self.metric_key, kwargs['epoch']))\n"," self.best_metric = current_metric\n"," save_state_dict(kwargs['state_dict'],\n"," self.checkpoint_path, self.filename)\n","\n","\n","class MetricGraphPrinter(AbstractBaseLogger):\n"," def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):\n"," self.key = key\n"," self.graph_label = graph_name\n"," self.group_name = group_name\n"," self.writer = writer\n","\n"," def log(self, *args, **kwargs):\n"," if self.key in kwargs:\n"," self.writer.add_scalar(\n"," self.group_name + '/' + self.graph_label, kwargs[self.key], kwargs['accum_iter'])\n"," else:\n"," self.writer.add_scalar(\n"," self.group_name + '/' + self.graph_label, 0, kwargs['accum_iter'])\n","\n"," def complete(self, *args, **kwargs):\n"," self.writer.close()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jp0R4oPgvBCO"},"source":["## Data"]},{"cell_type":"markdown","metadata":{"id":"UO32C7Lg7rQy"},"source":["### Datasets"]},{"cell_type":"markdown","metadata":{"id":"gWP-zGf4I0Cj"},"source":["#### Base"]},{"cell_type":"code","metadata":{"id":"q2fJSw4s74S0"},"source":["class AbstractDataset(metaclass=ABCMeta):\n"," def __init__(self, args):\n"," self.args = args\n"," self.min_rating = args.min_rating\n"," self.min_uc = args.min_uc\n"," self.min_sc = args.min_sc\n"," self.split = args.split\n","\n"," assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @classmethod\n"," def raw_code(cls):\n"," return cls.code()\n","\n"," @classmethod\n"," def zip_file_content_is_folder(cls):\n"," return True\n","\n"," @classmethod\n"," def all_raw_file_names(cls):\n"," return []\n","\n"," @classmethod\n"," @abstractmethod\n"," def url(cls):\n"," pass\n","\n"," @classmethod\n"," @abstractmethod\n"," def is_zipfile(cls):\n"," pass\n","\n"," @classmethod\n"," @abstractmethod\n"," def 
is_7zfile(cls):\n"," pass\n","\n"," @abstractmethod\n"," def preprocess(self):\n"," pass\n","\n"," @abstractmethod\n"," def load_ratings_df(self):\n"," pass\n","\n"," @abstractmethod\n"," def maybe_download_raw_dataset(self):\n"," pass\n","\n"," def load_dataset(self):\n"," self.preprocess()\n"," dataset_path = self._get_preprocessed_dataset_path()\n"," dataset = pickle.load(dataset_path.open('rb'))\n"," return dataset\n","\n"," def filter_triplets(self, df):\n"," print('Filtering triplets')\n"," if self.min_sc > 0:\n"," item_sizes = df.groupby('sid').size()\n"," good_items = item_sizes.index[item_sizes >= self.min_sc]\n"," df = df[df['sid'].isin(good_items)]\n","\n"," if self.min_uc > 0:\n"," user_sizes = df.groupby('uid').size()\n"," good_users = user_sizes.index[user_sizes >= self.min_uc]\n"," df = df[df['uid'].isin(good_users)]\n"," return df\n","\n"," def densify_index(self, df):\n"," print('Densifying index')\n"," umap = {u: i for i, u in enumerate(set(df['uid']), start=1)}\n"," smap = {s: i for i, s in enumerate(set(df['sid']), start=1)}\n"," df['uid'] = df['uid'].map(umap)\n"," df['sid'] = df['sid'].map(smap)\n"," return df, umap, smap\n","\n"," def split_df(self, df, user_count):\n"," if self.args.split == 'leave_one_out':\n"," print('Splitting')\n"," user_group = df.groupby('uid')\n"," user2items = user_group.progress_apply(\n"," lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid']))\n"," train, val, test = {}, {}, {}\n"," for i in range(user_count):\n"," user = i + 1\n"," items = user2items[user]\n"," train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]\n"," return train, val, test\n"," else:\n"," raise NotImplementedError\n","\n"," def _get_rawdata_root_path(self):\n"," return Path(RAW_DATASET_ROOT_FOLDER)\n","\n"," def _get_rawdata_folder_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath(self.raw_code())\n","\n"," def _get_preprocessed_root_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath('preprocessed')\n","\n"," def _get_preprocessed_folder_path(self):\n"," preprocessed_root = self._get_preprocessed_root_path()\n"," folder_name = '{}_min_rating{}-min_uc{}-min_sc{}-split{}' \\\n"," .format(self.code(), self.min_rating, self.min_uc, self.min_sc, self.split)\n"," return preprocessed_root.joinpath(folder_name)\n","\n"," def _get_preprocessed_dataset_path(self):\n"," folder = self._get_preprocessed_folder_path()\n"," return folder.joinpath('dataset.pkl')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3d6oNy-7IyF9"},"source":["#### MovieLens"]},{"cell_type":"code","metadata":{"id":"CTp88roj8pEO"},"source":["class ML1MDataset(AbstractDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-1m'\n","\n"," @classmethod\n"," def url(cls):\n"," return 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'\n","\n"," @classmethod\n"," def zip_file_content_is_folder(cls):\n"," return True\n","\n"," @classmethod\n"," def all_raw_file_names(cls):\n"," return ['README',\n"," 'movies.dat',\n"," 'ratings.dat',\n"," 'users.dat']\n","\n"," @classmethod\n"," def is_zipfile(cls):\n"," return True\n","\n"," @classmethod\n"," def is_7zfile(cls):\n"," return False\n","\n"," def maybe_download_raw_dataset(self):\n"," folder_path = self._get_rawdata_folder_path()\n"," if folder_path.is_dir() and\\\n"," all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):\n"," print('Raw data already exists. 
Skip downloading')\n"," return\n"," \n"," print(\"Raw file doesn't exist. Downloading...\")\n"," tmproot = Path(tempfile.mkdtemp())\n"," tmpzip = tmproot.joinpath('file.zip')\n"," tmpfolder = tmproot.joinpath('folder')\n"," download(self.url(), tmpzip)\n"," unzip(tmpzip, tmpfolder)\n"," if self.zip_file_content_is_folder():\n"," tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])\n"," shutil.move(tmpfolder, folder_path)\n"," shutil.rmtree(tmproot)\n"," print()\n","\n"," def preprocess(self):\n"," dataset_path = self._get_preprocessed_dataset_path()\n"," if dataset_path.is_file():\n"," print('Already preprocessed. Skip preprocessing')\n"," return\n"," if not dataset_path.parent.is_dir():\n"," dataset_path.parent.mkdir(parents=True)\n"," self.maybe_download_raw_dataset()\n"," df = self.load_ratings_df()\n"," df = self.filter_triplets(df)\n"," df, umap, smap = self.densify_index(df)\n"," train, val, test = self.split_df(df, len(umap))\n"," dataset = {'train': train,\n"," 'val': val,\n"," 'test': test,\n"," 'umap': umap,\n"," 'smap': smap}\n"," with dataset_path.open('wb') as f:\n"," pickle.dump(dataset, f)\n","\n"," def load_ratings_df(self):\n"," folder_path = self._get_rawdata_folder_path()\n"," file_path = folder_path.joinpath('ratings.dat')\n"," df = pd.read_csv(file_path, sep='::', header=None)\n"," df.columns = ['uid', 'sid', 'rating', 'timestamp']\n"," return df"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lB1Uk7XeIwaJ"},"source":["#### Factory"]},{"cell_type":"code","metadata":{"id":"5YTdFIux7s0m"},"source":["DATASETS = {\n"," ML1MDataset.code(): ML1MDataset\n","}\n","\n","\n","def dataset_factory(args):\n"," dataset = DATASETS[args.dataset_code]\n"," return dataset(args)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hIG_qb34KmoE"},"source":["#### Distillation"]},{"cell_type":"code","metadata":{"id":"YNteeo9cKoYR"},"source":["class AbstractDistillationDataset(metaclass=ABCMeta):\n"," def __init__(self, args, bb_model_code, mode='random'):\n"," self.args = args\n"," self.bb_model_code = bb_model_code\n"," self.mode = mode\n"," assert self.mode in ['random', 'autoregressive', 'adversarial']\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @classmethod\n"," def raw_code(cls):\n"," return cls.code()\n","\n"," def check_data_present(self):\n"," dataset_path = self._get_distillation_dataset_path()\n"," return dataset_path.is_file()\n","\n"," def load_dataset(self):\n"," dataset_path = self._get_distillation_dataset_path()\n"," if not dataset_path.is_file():\n"," print('Dataset not found, please generate distillation dataset first')\n"," return\n"," dataset = pickle.load(dataset_path.open('rb'))\n"," return dataset\n","\n"," def save_dataset(self, tokens, logits, candidates):\n"," dataset_path = self._get_distillation_dataset_path()\n"," if not dataset_path.parent.is_dir():\n"," dataset_path.parent.mkdir(parents=True)\n","\n"," dataset = {'seqs': tokens,\n"," 'logits': logits,\n"," 'candidates': candidates}\n"," \n"," with dataset_path.open('wb') as f:\n"," pickle.dump(dataset, f)\n","\n"," def _get_rawdata_root_path(self):\n"," return Path(GEN_DATASET_ROOT_FOLDER)\n","\n"," def _get_folder_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath(self.raw_code())\n","\n"," def _get_subfolder_path(self):\n"," root = self._get_folder_path()\n"," return root.joinpath(self.bb_model_code + '_' + str(self.args.num_generated_seqs))\n","\n"," def 
_get_distillation_dataset_path(self):\n"," folder = self._get_subfolder_path()\n"," return folder.joinpath(self.mode + '_dataset.pkl')\n","\n","\n","class ML1MDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-1m'\n","\n","\n","class ML20MDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-20m'\n","\n","\n","class BeautyDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'beauty'\n","\n","class BeautyDenseDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'beauty_dense'\n","\n","\n","class GamesDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'games'\n","\n","\n","class SteamDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'steam'\n","\n","\n","class YooChooseDistillationDataset(AbstractDistillationDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'yoochoose'"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nqi-bu2RH6Oa"},"source":["### Negative Samplers"]},{"cell_type":"markdown","metadata":{"id":"3x3qXHsWI7I8"},"source":["#### Base"]},{"cell_type":"code","metadata":{"id":"hrSZ7Yu1I676"},"source":["class AbstractNegativeSampler(metaclass=ABCMeta):\n"," def __init__(self, train, val, test, user_count, item_count, sample_size, seed, flag, save_folder):\n"," self.train = train\n"," self.val = val\n"," self.test = test\n"," self.user_count = user_count\n"," self.item_count = item_count\n"," self.sample_size = sample_size\n"," self.seed = seed\n"," self.flag = flag\n"," self.save_folder = save_folder\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @abstractmethod\n"," def generate_negative_samples(self):\n"," pass\n","\n"," def get_negative_samples(self):\n"," savefile_path = self._get_save_path()\n"," if savefile_path.is_file():\n"," print('Negative samples exist. Loading.')\n"," seen_samples, negative_samples = pickle.load(savefile_path.open('rb'))\n"," return seen_samples, negative_samples\n"," print(\"Negative samples don't exist. 
Generating.\")\n"," seen_samples, negative_samples = self.generate_negative_samples()\n"," with savefile_path.open('wb') as f:\n"," pickle.dump([seen_samples, negative_samples], f)\n"," return seen_samples, negative_samples\n","\n"," def _get_save_path(self):\n"," folder = Path(self.save_folder)\n"," filename = '{}-sample_size{}-seed{}-{}.pkl'.format(\n"," self.code(), self.sample_size, self.seed, self.flag)\n"," return folder.joinpath(filename)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PBvS1_vQI6uJ"},"source":["#### Random"]},{"cell_type":"code","metadata":{"id":"E0Sd7vHbI6i_"},"source":["class RandomNegativeSampler(AbstractNegativeSampler):\n"," @classmethod\n"," def code(cls):\n"," return 'random'\n","\n"," def generate_negative_samples(self):\n"," assert self.seed is not None, 'Specify seed for random sampling'\n"," np.random.seed(self.seed)\n"," num_samples = 2 * self.user_count * self.sample_size\n"," all_samples = np.random.choice(self.item_count, num_samples) + 1\n","\n"," seen_samples = {}\n"," negative_samples = {}\n"," print('Sampling negative items randomly...')\n"," j = 0\n"," for i in trange(self.user_count):\n"," user = i + 1\n"," seen = set(self.train[user])\n"," seen.update(self.val[user])\n"," seen.update(self.test[user])\n"," seen_samples[user] = seen\n","\n"," samples = []\n"," while len(samples) < self.sample_size:\n"," item = all_samples[j % num_samples]\n"," j += 1\n"," if item in seen or item in samples:\n"," continue\n"," samples.append(item)\n"," negative_samples[user] = samples\n","\n"," return seen_samples, negative_samples"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"d3WjrRavI6XQ"},"source":["#### Popular"]},{"cell_type":"code","metadata":{"id":"461MeUO3I57b"},"source":["class PopularNegativeSampler(AbstractNegativeSampler):\n"," @classmethod\n"," def code(cls):\n"," return 'popular'\n","\n"," def generate_negative_samples(self):\n"," assert self.seed is not None, 'Specify seed for random sampling'\n"," np.random.seed(self.seed)\n"," popularity = self.items_by_popularity()\n"," items = list(popularity.keys())\n"," total = 0\n"," for i in range(len(items)):\n"," total += popularity[items[i]]\n"," for i in range(len(items)):\n"," popularity[items[i]] /= total\n"," probs = list(popularity.values())\n"," num_samples = 2 * self.user_count * self.sample_size\n"," all_samples = np.random.choice(items, num_samples, p=probs)\n","\n"," seen_samples = {}\n"," negative_samples = {}\n"," print('Sampling negative items by popularity...')\n"," j = 0\n"," for i in trange(self.user_count):\n"," user = i + 1\n"," seen = set(self.train[user])\n"," seen.update(self.val[user])\n"," seen.update(self.test[user])\n"," seen_samples[user] = seen\n","\n"," samples = []\n"," while len(samples) < self.sample_size:\n"," item = all_samples[j % num_samples]\n"," j += 1\n"," if item in seen or item in samples:\n"," continue\n"," samples.append(item)\n"," negative_samples[user] = samples\n","\n"," return seen_samples, negative_samples\n","\n"," def items_by_popularity(self):\n"," popularity = Counter()\n"," self.users = sorted(self.train.keys())\n"," for user in self.users:\n"," popularity.update(self.train[user])\n"," popularity.update(self.val[user])\n"," popularity.update(self.test[user])\n","\n"," popularity = dict(popularity)\n"," popularity = {k: v for k, v in sorted(popularity.items(), key=lambda item: item[1], reverse=True)}\n"," return 
popularity"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"KkUM7O88IuxA"},"source":["#### Factory"]},{"cell_type":"code","metadata":{"id":"9DwpDAJvItab"},"source":["NEGATIVE_SAMPLERS = {\n"," PopularNegativeSampler.code(): PopularNegativeSampler,\n"," RandomNegativeSampler.code(): RandomNegativeSampler,\n","}\n","\n","\n","def negative_sampler_factory(code, train, val, test, user_count, item_count, sample_size, seed, flag, save_folder):\n"," negative_sampler = NEGATIVE_SAMPLERS[code]\n"," return negative_sampler(train, val, test, user_count, item_count, sample_size, seed, flag, save_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fk9fabJWH2Eg"},"source":["### Dataloaders"]},{"cell_type":"markdown","metadata":{"id":"X34jGB6oIoG8"},"source":["#### Base"]},{"cell_type":"code","metadata":{"id":"PfbobsB3IMms"},"source":["class AbstractDataloader(metaclass=ABCMeta):\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," self.rng = random.Random()\n"," self.save_folder = dataset._get_preprocessed_folder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," self.umap = dataset['umap']\n"," self.smap = dataset['smap']\n"," self.user_count = len(self.umap)\n"," self.item_count = len(self.smap)\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @abstractmethod\n"," def get_pytorch_dataloaders(self):\n"," pass"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U-WQnwWWImo0"},"source":["#### RNN"]},{"cell_type":"code","metadata":{"id":"LGVIx9WZIa7C"},"source":["class RNNDataloader():\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," self.rng = random.Random()\n"," self.save_folder = dataset._get_preprocessed_folder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," self.umap = dataset['umap']\n"," self.smap = dataset['smap']\n"," self.user_count = len(self.umap)\n"," self.item_count = len(self.smap)\n","\n"," args.num_items = len(self.smap)\n"," self.max_len = args.bert_max_len\n","\n"," val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'val', self.save_folder)\n"," test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'test', self.save_folder)\n","\n"," self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()\n"," self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()\n","\n"," @classmethod\n"," def code(cls):\n"," return 'rnn'\n","\n"," def get_pytorch_dataloaders(self):\n"," train_loader = self._get_train_loader()\n"," val_loader = self._get_val_loader()\n"," test_loader = self._get_test_loader()\n"," return train_loader, val_loader, test_loader\n","\n"," def _get_train_loader(self):\n"," dataset = self._get_train_dataset()\n"," dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," return dataloader\n","\n"," def 
_get_train_dataset(self):\n"," dataset = RNNTrainDataset(\n"," self.train, self.max_len)\n"," return dataset\n","\n"," def _get_val_loader(self):\n"," return self._get_eval_loader(mode='val')\n","\n"," def _get_test_loader(self):\n"," return self._get_eval_loader(mode='test')\n","\n"," def _get_eval_loader(self, mode):\n"," batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n"," dataset = self._get_eval_dataset(mode)\n"," dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n"," return dataloader\n","\n"," def _get_eval_dataset(self, mode):\n"," if mode == 'val':\n"," dataset = RNNValidDataset(self.train, self.val, self.max_len, self.val_negative_samples)\n"," elif mode == 'test':\n"," dataset = RNNTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples)\n"," return dataset\n","\n","\n","class RNNTrainDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, max_len):\n"," # self.u2seq = u2seq\n"," # self.users = sorted(self.u2seq.keys())\n"," self.max_len = max_len\n"," self.all_seqs = []\n"," self.all_labels = []\n"," for u in sorted(u2seq.keys()):\n"," seq = u2seq[u]\n"," for i in range(1, len(seq)):\n"," self.all_seqs += [seq[:-i]]\n"," self.all_labels += [seq[-i]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," tokens = self.all_seqs[index][-self.max_len:]\n"," length = len(tokens)\n"," tokens = tokens + [0] * (self.max_len - length)\n"," \n"," return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor([self.all_labels[index]])\n","\n","\n","class RNNValidDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2answer, max_len, negative_samples, valid_users=None):\n"," self.u2seq = u2seq # train\n"," if not valid_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = valid_users\n"," self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer\n"," self.max_len = max_len\n"," self.negative_samples = negative_samples\n"," \n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," tokens = self.u2seq[user][-self.max_len:]\n"," length = len(tokens)\n"," tokens = tokens + [0] * (self.max_len - length)\n","\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n"," \n"," return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(candidates), torch.LongTensor(labels)\n","\n","\n","class RNNTestDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2val, u2answer, max_len, negative_samples, test_users=None):\n"," self.u2seq = u2seq # train\n"," self.u2val = u2val # val\n"," if not test_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = test_users\n"," self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer # test\n"," self.max_len = max_len\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," tokens = (self.u2seq[user] + self.u2val[user])[-self.max_len:] # append validation item after train seq\n"," length = len(tokens)\n"," tokens = tokens + [0] * (self.max_len - length)\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n"," 
candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(candidates), torch.LongTensor(labels)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lNY2YUbeIkz2"},"source":["#### SAS"]},{"cell_type":"code","metadata":{"id":"KYsm6jy9ISl8"},"source":["class SASDataloader():\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," self.rng = random.Random()\n"," self.save_folder = dataset._get_preprocessed_folder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," self.umap = dataset['umap']\n"," self.smap = dataset['smap']\n"," self.user_count = len(self.umap)\n"," self.item_count = len(self.smap)\n","\n"," args.num_items = self.item_count\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.sliding_size = args.sliding_window_size\n"," self.CLOZE_MASK_TOKEN = self.item_count + 1\n","\n"," val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'val', self.save_folder)\n"," test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'test', self.save_folder)\n","\n"," self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()\n"," self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()\n","\n"," @classmethod\n"," def code(cls):\n"," return 'sas'\n","\n"," def get_pytorch_dataloaders(self):\n"," train_loader = self._get_train_loader()\n"," val_loader = self._get_val_loader()\n"," test_loader = self._get_test_loader()\n"," return train_loader, val_loader, test_loader\n","\n"," def _get_train_loader(self):\n"," dataset = self._get_train_dataset()\n"," dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," return dataloader\n","\n"," def _get_train_dataset(self):\n"," dataset = SASTrainDataset(\n"," self.train, self.max_len, self.sliding_size, self.seen_samples, self.item_count, self.rng)\n"," return dataset\n","\n"," def _get_val_loader(self):\n"," return self._get_eval_loader(mode='val')\n","\n"," def _get_test_loader(self):\n"," return self._get_eval_loader(mode='test')\n","\n"," def _get_eval_loader(self, mode):\n"," batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n"," dataset = self._get_eval_dataset(mode)\n"," dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n"," return dataloader\n","\n"," def _get_eval_dataset(self, mode):\n"," if mode == 'val':\n"," dataset = SASValidDataset(self.train, self.val, self.max_len, self.val_negative_samples)\n"," elif mode == 'test':\n"," dataset = SASTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples)\n"," return dataset\n","\n","\n","class SASTrainDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, max_len, sliding_size, seen_samples, num_items, rng):\n"," # self.u2seq = u2seq\n"," # self.users = 
sorted(self.u2seq.keys())\n"," self.max_len = max_len\n"," self.sliding_step = int(sliding_size * max_len)\n"," self.num_items = num_items\n"," self.rng = rng\n"," \n"," assert self.sliding_step > 0\n"," self.all_seqs = []\n"," self.seen_samples = []\n"," for u in sorted(u2seq.keys()):\n"," seq = u2seq[u]\n"," neg = seen_samples[u]\n"," if len(seq) < self.max_len + self.sliding_step:\n"," self.all_seqs.append(seq)\n"," self.seen_samples.append(neg)\n"," else:\n"," start_idx = range(len(seq) - max_len, -1, -self.sliding_step)\n"," self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]\n"," self.seen_samples = self.seen_samples + [neg for i in start_idx]\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," seq = self.all_seqs[index]\n"," labels = seq[-self.max_len:]\n"," tokens = seq[:-1][-self.max_len:]\n"," neg = []\n","\n"," mask_len = self.max_len - len(tokens)\n"," tokens = [0] * mask_len + tokens\n","\n"," mask_len = self.max_len - len(labels)\n"," while len(neg) < len(labels):\n"," item = self.rng.randint(1, self.num_items)\n"," if item in self.seen_samples[index] or item in neg:\n"," continue\n"," neg.append(item)\n"," \n"," labels = [0] * mask_len + labels\n"," neg = [0] * mask_len + neg\n","\n"," return torch.LongTensor(tokens), torch.LongTensor(labels), torch.LongTensor(neg)\n","\n","\n","class SASValidDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2answer, max_len, negative_samples, valid_users=None):\n"," self.u2seq = u2seq # train\n"," if not valid_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = valid_users\n"," self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer\n"," self.max_len = max_len\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self.u2seq[user]\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n","\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," # no mask token here\n"," seq = seq[-self.max_len:]\n"," padding_len = self.max_len - len(seq)\n"," seq = [0] * padding_len + seq\n","\n"," return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)\n","\n","\n","class SASTestDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2val, u2answer, max_len, negative_samples, test_users=None):\n"," self.u2seq = u2seq # train\n"," self.u2val = u2val # val\n"," if not test_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = test_users\n"," self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer # test\n"," self.max_len = max_len\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self.u2seq[user] + self.u2val[user] # append validation item after train seq\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n","\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," # no mask token here\n"," seq = seq[-self.max_len:]\n"," padding_len = self.max_len - len(seq)\n"," seq = [0] * padding_len + seq\n","\n"," return torch.LongTensor(seq), torch.LongTensor(candidates), 
torch.LongTensor(labels)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"96DzPnjLIihJ"},"source":["#### BERT"]},{"cell_type":"code","metadata":{"id":"T4EkWhJ-Ifm6"},"source":["class BERTDataloader():\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," self.rng = random.Random()\n"," self.save_folder = dataset._get_preprocessed_folder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," self.umap = dataset['umap']\n"," self.smap = dataset['smap']\n"," self.user_count = len(self.umap)\n"," self.item_count = len(self.smap)\n","\n"," args.num_items = self.item_count\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.sliding_size = args.sliding_window_size\n"," self.CLOZE_MASK_TOKEN = self.item_count + 1\n","\n"," val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'val', self.save_folder)\n"," test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'test', self.save_folder)\n","\n"," self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()\n"," self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()\n","\n"," @classmethod\n"," def code(cls):\n"," return 'bert'\n","\n"," def get_pytorch_dataloaders(self):\n"," train_loader = self._get_train_loader()\n"," val_loader = self._get_val_loader()\n"," test_loader = self._get_test_loader()\n"," return train_loader, val_loader, test_loader\n","\n"," def _get_train_loader(self):\n"," dataset = self._get_train_dataset()\n"," dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," return dataloader\n","\n"," def _get_train_dataset(self):\n"," dataset = BERTTrainDataset(\n"," self.train, self.max_len, self.mask_prob, self.max_predictions, self.sliding_size, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)\n"," return dataset\n","\n"," def _get_val_loader(self):\n"," return self._get_eval_loader(mode='val')\n","\n"," def _get_test_loader(self):\n"," return self._get_eval_loader(mode='test')\n","\n"," def _get_eval_loader(self, mode):\n"," batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n"," dataset = self._get_eval_dataset(mode)\n"," dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n"," return dataloader\n","\n"," def _get_eval_dataset(self, mode):\n"," if mode == 'val':\n"," dataset = BERTValidDataset(self.train, self.val, self.max_len, self.CLOZE_MASK_TOKEN, self.val_negative_samples)\n"," elif mode == 'test':\n"," dataset = BERTTestDataset(self.train, self.val, self.test, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples)\n"," return dataset\n","\n","\n","class BERTTrainDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, max_len, mask_prob, max_predictions, sliding_size, mask_token, num_items, rng):\n"," # self.u2seq = u2seq\n"," # self.users = sorted(self.u2seq.keys())\n"," self.max_len = max_len\n"," 
self.mask_prob = mask_prob\n"," self.max_predictions = max_predictions\n"," self.sliding_step = int(sliding_size * max_len)\n"," self.mask_token = mask_token\n"," self.num_items = num_items\n"," self.rng = rng\n"," \n"," assert self.sliding_step > 0\n"," self.all_seqs = []\n"," for u in sorted(u2seq.keys()):\n"," seq = u2seq[u]\n"," if len(seq) < self.max_len + self.sliding_step:\n"," self.all_seqs.append(seq)\n"," else:\n"," start_idx = range(len(seq) - max_len, -1, -self.sliding_step)\n"," self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n"," # return len(self.users)\n","\n"," def __getitem__(self, index):\n"," # user = self.users[index]\n"," # seq = self._getseq(user)\n"," seq = self.all_seqs[index]\n","\n"," tokens = []\n"," labels = []\n"," covered_items = set()\n"," for i in range(len(seq)):\n"," s = seq[i]\n"," if (len(covered_items) >= self.max_predictions) or (s in covered_items):\n"," tokens.append(s)\n"," labels.append(0)\n"," continue\n"," \n"," temp_mask_prob = self.mask_prob\n"," if i == (len(seq) - 1):\n"," temp_mask_prob += 0.1 * (1 - self.mask_prob)\n","\n"," prob = self.rng.random()\n"," if prob < temp_mask_prob:\n"," covered_items.add(s)\n"," prob /= temp_mask_prob\n"," if prob < 0.8:\n"," tokens.append(self.mask_token)\n"," elif prob < 0.9:\n"," tokens.append(self.rng.randint(1, self.num_items))\n"," else:\n"," tokens.append(s)\n","\n"," labels.append(s)\n"," else:\n"," tokens.append(s)\n"," labels.append(0)\n","\n"," tokens = tokens[-self.max_len:]\n"," labels = labels[-self.max_len:]\n","\n"," mask_len = self.max_len - len(tokens)\n","\n"," tokens = [0] * mask_len + tokens\n"," labels = [0] * mask_len + labels\n","\n"," return torch.LongTensor(tokens), torch.LongTensor(labels)\n","\n"," def _getseq(self, user):\n"," return self.u2seq[user]\n","\n","\n","class BERTValidDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples, valid_users=None):\n"," self.u2seq = u2seq # train\n"," if not valid_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = valid_users\n"," self.u2answer = u2answer\n"," self.max_len = max_len\n"," self.mask_token = mask_token\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self.u2seq[user]\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n","\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," seq = seq + [self.mask_token]\n"," seq = seq[-self.max_len:]\n"," padding_len = self.max_len - len(seq)\n"," seq = [0] * padding_len + seq\n","\n"," return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)\n","\n","\n","class BERTTestDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2val, u2answer, max_len, mask_token, negative_samples, test_users=None):\n"," self.u2seq = u2seq # train\n"," self.u2val = u2val # val\n"," if not test_users:\n"," self.users = sorted(self.u2seq.keys())\n"," else:\n"," self.users = test_users\n"," self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer # test\n"," self.max_len = max_len\n"," self.mask_token = mask_token\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self.u2seq[user] + 
self.u2val[user] # append validation item after train seq\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n","\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," seq = seq + [self.mask_token]\n"," seq = seq[-self.max_len:]\n"," padding_len = self.max_len - len(seq)\n"," seq = [0] * padding_len + seq\n","\n"," return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1m1KLGaKI2f2"},"source":["#### Factory"]},{"cell_type":"code","metadata":{"id":"ihAxE7dMH4Eo"},"source":["def dataloader_factory(args):\n"," dataset = dataset_factory(args)\n"," if args.model_code == 'bert':\n"," dataloader = BERTDataloader(args, dataset)\n"," elif args.model_code == 'sas':\n"," dataloader = SASDataloader(args, dataset)\n"," else:\n"," dataloader = RNNDataloader(args, dataset)\n"," train, val, test = dataloader.get_pytorch_dataloaders()\n"," return train, val, test"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dzGFohyfKt7N"},"source":["#### Distillation"]},{"cell_type":"code","metadata":{"id":"tOzhrI8VKvb2"},"source":["DIS_DATASETS = {\n"," ML1MDistillationDataset.code(): ML1MDistillationDataset\n","}\n","\n","\n","def dis_dataset_factory(args, bb_model_code, mode='random'):\n"," dataset = DIS_DATASETS[args.dataset_code]\n"," return dataset(args, bb_model_code, mode)\n","\n","\n","def dis_train_loader_factory(args, bb_model_code, mode='random'):\n"," dataset = dis_dataset_factory(args, bb_model_code, mode)\n"," if dataset.check_data_present():\n"," dataloader = DistillationLoader(args, dataset)\n"," train, val = dataloader.get_loaders()\n"," return train, val\n"," else:\n"," return None\n","\n","\n","class DistillationLoader():\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," dataset = dataset.load_dataset()\n"," self.tokens = dataset['seqs']\n"," self.logits = dataset['logits']\n"," self.candidates = dataset['candidates']\n","\n"," @classmethod\n"," def code(cls):\n"," return 'distillation_loader'\n","\n"," def get_loaders(self):\n"," train, val = self._get_datasets()\n"," train_loader = data_utils.DataLoader(train, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," val_loader = data_utils.DataLoader(val, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," return train_loader, val_loader\n","\n"," def _get_datasets(self):\n"," if self.args.model_code == 'bert':\n"," train_dataset = BERTDistillationTrainingDataset(self.args, self.tokens, self.logits, self.candidates)\n"," valid_dataset = BERTDistillationValidationDataset(self.args, self.tokens, self.logits, self.candidates)\n"," elif self.args.model_code == 'sas':\n"," train_dataset = SASDistillationTrainingDataset(self.args, self.tokens, self.logits, self.candidates)\n"," valid_dataset = SASDistillationValidationDataset(self.args, self.tokens, self.logits, self.candidates)\n"," elif self.args.model_code == 'narm':\n"," train_dataset = NARMDistillationTrainingDataset(self.args, self.tokens, self.logits, self.candidates)\n"," valid_dataset = NARMDistillationValidationDataset(self.args, self.tokens, self.logits, self.candidates)\n"," \n"," return train_dataset, valid_dataset\n","\n","\n","class BERTDistillationTrainingDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = 
args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.num_items = args.num_items\n"," self.mask_token = args.num_items + 1\n","\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = candidates[i]\n","\n"," for j in range(0, len(seq)-1):\n"," masked_seq = seq[:j+1] + [self.mask_token]\n"," self.all_seqs += [masked_seq]\n"," self.all_labels += [label[j]]\n"," self.all_candidates += [candidate[j]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," masked_seq = self.all_seqs[index]\n"," masked_seq = masked_seq[-self.max_len:]\n"," mask_len = self.max_len - len(masked_seq)\n"," masked_seq = [0] * mask_len + masked_seq\n","\n"," return torch.LongTensor(masked_seq), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])\n","\n","\n","class BERTDistillationValidationDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.num_items = args.num_items\n"," self.mask_token = args.num_items + 1\n","\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = candidates[i]\n"," self.all_seqs += [seq + [self.mask_token]]\n"," self.all_labels += [[1] + [0] * (len(label[-1]) - 1)]\n"," self.all_candidates += [candidate[-1]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," masked_seq = self.all_seqs[index]\n"," masked_seq = masked_seq[-self.max_len:]\n"," mask_len = self.max_len - len(masked_seq)\n"," masked_seq = [0] * mask_len + masked_seq\n","\n"," return torch.LongTensor(masked_seq), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])\n","\n","\n","class SASDistillationTrainingDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = candidates[i]\n"," \n"," for j in range(1, len(seq)):\n"," self.all_seqs += [seq[:-j]]\n"," self.all_labels += [label[-j-1]]\n"," self.all_candidates += [candidate[-j-1]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," tokens = self.all_seqs[index][-self.max_len:]\n"," mask_len = self.max_len - len(tokens)\n"," tokens = [0] * mask_len + tokens\n","\n"," return torch.LongTensor(tokens), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])\n","\n","\n","class SASDistillationValidationDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = 
candidates[i]\n"," \n"," self.all_seqs += [seq]\n"," self.all_labels += [[1] + [0] * (len(label[-1]) - 1)]\n"," self.all_candidates += [candidate[-1]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," tokens = self.all_seqs[index][-self.max_len:]\n"," mask_len = self.max_len - len(tokens)\n"," tokens = [0] * mask_len + tokens\n","\n"," return torch.LongTensor(tokens), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])\n","\n","\n","class NARMDistillationTrainingDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = candidates[i]\n"," \n"," for j in range(1, len(seq)):\n"," self.all_seqs += [seq[:-j]]\n"," self.all_labels += [label[-j-1]]\n"," self.all_candidates += [candidate[-j-1]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," tokens = self.all_seqs[index][-self.max_len:]\n"," length = len(tokens)\n"," tokens = tokens + [0] * (self.max_len - length)\n","\n"," return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])\n","\n","\n","class NARMDistillationValidationDataset(data_utils.Dataset):\n"," def __init__(self, args, tokens, labels, candidates):\n"," self.max_len = args.bert_max_len\n"," self.all_seqs = []\n"," self.all_labels = []\n"," self.all_candidates = []\n"," for i in range(len(tokens)):\n"," seq = tokens[i]\n"," label = labels[i]\n"," candidate = candidates[i]\n"," \n"," self.all_seqs += [seq]\n"," self.all_labels += [[1] + [0] * (len(label[-1]) - 1)]\n"," self.all_candidates += [candidate[-1]]\n","\n"," assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)\n","\n"," def __len__(self):\n"," return len(self.all_seqs)\n","\n"," def __getitem__(self, index):\n"," tokens = self.all_seqs[index][-self.max_len:]\n"," length = len(tokens)\n"," tokens = tokens + [0] * (self.max_len - length)\n","\n"," return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U7buZr1cHAt5"},"source":["## Models"]},{"cell_type":"markdown","metadata":{"id":"xX1F6KrSHm4S"},"source":["### Attention"]},{"cell_type":"code","metadata":{"id":"YiGO7VUPHQEf"},"source":["class TokenEmbedding(nn.Embedding):\n"," def __init__(self, vocab_size, embed_size=512):\n"," super().__init__(vocab_size, embed_size, padding_idx=0)\n","\n","\n","class PositionalEmbedding(nn.Module):\n"," def __init__(self, max_len, d_model):\n"," super().__init__()\n"," self.d_model = d_model\n"," self.pe = nn.Embedding(max_len+1, d_model)\n","\n"," def forward(self, x):\n"," pose = (x > 0) * (x > 0).sum(dim=-1).unsqueeze(1).repeat(1, x.size(-1))\n"," pose += torch.arange(start=-(x.size(1)-1), end=1, step=1, device=x.device)\n"," pose = pose * (x > 0)\n","\n"," return self.pe(pose)\n","\n","\n","class GELU(nn.Module):\n"," def forward(self, x):\n"," return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 
0.044715 * torch.pow(x, 3))))\n","\n","\n","class PositionwiseFeedForward(nn.Module):\n"," def __init__(self, d_model, d_ff):\n"," super().__init__()\n"," self.w_1 = nn.Linear(d_model, d_ff)\n"," self.w_2 = nn.Linear(d_ff, d_model)\n"," self.activation = GELU()\n","\n"," def forward(self, x):\n"," return self.w_2(self.activation(self.w_1(x)))\n","\n","\n","# layer norm\n","class LayerNorm(nn.Module):\n"," def __init__(self, features, eps=1e-6):\n"," super().__init__()\n"," self.weight = nn.Parameter(torch.ones(features))\n"," self.bias = nn.Parameter(torch.zeros(features))\n"," self.eps = eps\n","\n"," def forward(self, x):\n"," mean = x.mean(-1, keepdim=True)\n"," std = x.std(-1, keepdim=True)\n"," return self.weight * (x - mean) / (std + self.eps) + self.bias\n","\n","\n","# layer norm and dropout (dropout and then layer norm)\n","class SublayerConnection(nn.Module):\n"," def __init__(self, size, dropout):\n"," super().__init__()\n"," self.layer_norm = LayerNorm(size)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x, sublayer):\n"," # return x + self.dropout(sublayer(self.norm(x))) # original implementation\n"," return self.layer_norm(x + self.dropout(sublayer(x))) # BERT4Rec implementation\n","\n","\n","class Attention(nn.Module):\n"," def forward(self, query, key, value, mask=None, dropout=None, sas=False):\n"," scores = torch.matmul(query, key.transpose(-2, -1)) \\\n"," / math.sqrt(query.size(-1))\n","\n"," if mask is not None:\n"," scores = scores.masked_fill(mask == 0, -1e9)\n","\n"," if sas:\n"," direction_mask = torch.ones_like(scores)\n"," direction_mask = torch.tril(direction_mask)\n"," scores = scores.masked_fill(direction_mask == 0, -1e9)\n","\n"," p_attn = F.softmax(scores, dim=-1)\n","\n"," if dropout is not None:\n"," p_attn = dropout(p_attn)\n","\n"," return torch.matmul(p_attn, value), p_attn\n","\n","\n","class MultiHeadedAttention(nn.Module):\n"," def __init__(self, h, d_model, head_size=None, dropout=0.1):\n"," super().__init__()\n"," assert d_model % h == 0\n","\n"," self.h = h\n"," self.d_k = d_model // h\n"," if head_size is not None:\n"," self.head_size = head_size\n"," else:\n"," self.head_size = d_model // h\n","\n"," self.linear_layers = nn.ModuleList(\n"," [nn.Linear(d_model, self.h * self.head_size) for _ in range(3)])\n"," self.attention = Attention()\n"," self.dropout = nn.Dropout(p=dropout)\n"," self.output_linear = nn.Linear(self.h * self.head_size, d_model)\n","\n"," def forward(self, query, key, value, mask=None):\n"," batch_size = query.size(0)\n","\n"," # 1) do all the linear projections in batch from d_model => h x d_k\n"," query, key, value = [l(x).view(batch_size, -1, self.h, self.head_size).transpose(1, 2)\n"," for l, x in zip(self.linear_layers, (query, key, value))]\n"," \n"," # 2) apply attention on all the projected vectors in batch.\n"," x, attn = self.attention(\n"," query, key, value, mask=mask, dropout=self.dropout)\n","\n"," # 3) \"concat\" using a view and apply a final linear.\n"," x = x.transpose(1, 2).contiguous().view(\n"," batch_size, -1, self.h * self.head_size)\n"," return self.output_linear(x)\n","\n","\n","class TransformerBlock(nn.Module):\n"," def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):\n"," super().__init__()\n"," self.attention = MultiHeadedAttention(\n"," h=attn_heads, d_model=hidden, head_size=head_size, dropout=attn_dropout)\n"," self.feed_forward = PositionwiseFeedForward(\n"," d_model=hidden, d_ff=feed_forward_hidden)\n"," self.input_sublayer = 
SublayerConnection(size=hidden, dropout=dropout)\n"," self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n","\n"," def forward(self, x, mask):\n"," x = self.input_sublayer(\n"," x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))\n"," x = self.output_sublayer(x, self.feed_forward)\n"," return x\n","\n","\n","class SASMultiHeadedAttention(nn.Module):\n"," def __init__(self, h, d_model, head_size=None, dropout=0.1):\n"," super().__init__()\n"," assert d_model % h == 0\n","\n"," self.h = h\n"," self.d_k = d_model // h\n"," if head_size is not None:\n"," self.head_size = head_size\n"," else:\n"," self.head_size = d_model // h\n","\n"," self.linear_layers = nn.ModuleList(\n"," [nn.Linear(d_model, self.h * self.head_size) for _ in range(3)])\n"," self.attention = Attention()\n"," self.dropout = nn.Dropout(p=dropout)\n"," self.layer_norm = LayerNorm(d_model)\n","\n"," def forward(self, query, key, value, mask=None):\n"," batch_size = query.size(0)\n","\n"," # 1) do all the linear projections in batch from d_model => h x d_k\n"," query_, key_, value_ = [l(x).view(batch_size, -1, self.h, self.head_size).transpose(1, 2)\n"," for l, x in zip(self.linear_layers, (query, key, value))]\n"," \n"," # 2) apply attention on all the projected vectors in batch.\n"," x, attn = self.attention(\n"," query_, key_, value_, mask=mask, dropout=self.dropout, sas=True)\n","\n"," # 3) \"concat\" using a view and apply a final linear.\n"," x = x.transpose(1, 2).contiguous().view(\n"," batch_size, -1, self.h * self.head_size)\n"," \n"," return self.layer_norm(x + query)\n","\n","\n","class SASPositionwiseFeedForward(nn.Module):\n"," def __init__(self, d_model, d_ff, dropout=0.1):\n"," super().__init__()\n"," self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n"," self.activation = nn.ReLU()\n"," self.dropout = nn.Dropout(dropout)\n"," self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n"," self.layer_norm = LayerNorm(d_model)\n","\n"," def forward(self, x):\n"," x_ = self.dropout(self.activation(self.conv1(x.permute(0, 2, 1))))\n"," return self.layer_norm(self.dropout(self.conv2(x_)).permute(0, 2, 1) + x)\n","\n","\n","class SASTransformerBlock(nn.Module):\n"," def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):\n"," super().__init__()\n"," self.layer_norm = LayerNorm(hidden)\n"," self.attention = SASMultiHeadedAttention(\n"," h=attn_heads, d_model=hidden, head_size=head_size, dropout=attn_dropout)\n"," self.feed_forward = SASPositionwiseFeedForward(\n"," d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)\n","\n"," def forward(self, x, mask):\n"," x = self.attention(self.layer_norm(x), x, x, mask)\n"," x = self.feed_forward(x)\n"," return x"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"R-Ffm9maHuhy"},"source":["### SASRec"]},{"cell_type":"code","metadata":{"id":"I6teYn-xHv9g"},"source":["class SASRec(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," self.args = args\n"," self.embedding = SASEmbedding(self.args)\n"," self.model = SASModel(self.args)\n"," self.truncated_normal_init()\n","\n"," def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):\n"," with torch.no_grad():\n"," l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.\n"," u = (1. 
+ math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.\n","\n"," for n, p in self.model.named_parameters():\n"," if not 'layer_norm' in n:\n"," p.uniform_(2 * l - 1, 2 * u - 1)\n"," p.erfinv_()\n"," p.mul_(std * math.sqrt(2.))\n"," p.add_(mean)\n"," \n"," def forward(self, x):\n"," x, mask = self.embedding(x)\n"," scores = self.model(x, self.embedding.token.weight, mask)\n"," return scores\n","\n","\n","class SASEmbedding(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," vocab_size = args.num_items + 1\n"," hidden = args.bert_hidden_units\n"," max_len = args.bert_max_len\n"," dropout = args.bert_dropout\n","\n"," self.token = TokenEmbedding(\n"," vocab_size=vocab_size, embed_size=hidden)\n"," self.position = PositionalEmbedding(\n"," max_len=max_len, d_model=hidden)\n","\n"," self.dropout = nn.Dropout(p=dropout)\n","\n"," def get_mask(self, x):\n"," if len(x.shape) > 2:\n"," x = torch.ones(x.shape[:2]).to(x.device)\n"," return (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n","\n"," def forward(self, x):\n"," mask = self.get_mask(x)\n"," if len(x.shape) > 2:\n"," pos = self.position(torch.ones(x.shape[:2]).to(x.device))\n"," x = torch.matmul(x, self.token.weight) + pos\n"," else:\n"," x = self.token(x) + self.position(x)\n"," return self.dropout(x), mask\n","\n","\n","class SASModel(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," hidden = args.bert_hidden_units\n"," heads = args.bert_num_heads\n"," head_size = args.bert_head_size\n"," dropout = args.bert_dropout\n"," attn_dropout = args.bert_attn_dropout\n"," layers = args.bert_num_blocks\n","\n"," self.transformer_blocks = nn.ModuleList([SASTransformerBlock(\n"," hidden, heads, head_size, hidden * 4, dropout, attn_dropout) for _ in range(layers)])\n","\n"," def forward(self, x, embedding_weight, mask):\n"," for transformer in self.transformer_blocks:\n"," x = transformer.forward(x, mask)\n"," scores = torch.matmul(x, embedding_weight.permute(1, 0))\n"," return scores"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vWM5QQBNHeDM"},"source":["### BERT"]},{"cell_type":"code","metadata":{"id":"rbcbnnaIHVSr"},"source":["class BERT(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," self.args = args\n"," self.embedding = BERTEmbedding(self.args)\n"," self.model = BERTModel(self.args)\n"," self.truncated_normal_init()\n","\n"," def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):\n"," with torch.no_grad():\n"," l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.\n"," u = (1. 
+ math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.\n","\n"," for n, p in self.model.named_parameters():\n"," if not 'layer_norm' in n:\n"," p.uniform_(2 * l - 1, 2 * u - 1)\n"," p.erfinv_()\n"," p.mul_(std * math.sqrt(2.))\n"," p.add_(mean)\n"," \n"," def forward(self, x):\n"," x, mask = self.embedding(x)\n"," scores = self.model(x, self.embedding.token.weight, mask)\n"," return scores\n","\n","\n","class BERTEmbedding(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," vocab_size = args.num_items + 2\n"," hidden = args.bert_hidden_units\n"," max_len = args.bert_max_len\n"," dropout = args.bert_dropout\n","\n"," self.token = TokenEmbedding(\n"," vocab_size=vocab_size, embed_size=hidden)\n"," self.position = PositionalEmbedding(\n"," max_len=max_len, d_model=hidden)\n","\n"," self.layer_norm = LayerNorm(features=hidden)\n"," self.dropout = nn.Dropout(p=dropout)\n","\n"," def get_mask(self, x):\n"," if len(x.shape) > 2:\n"," x = torch.ones(x.shape[:2]).to(x.device)\n"," return (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n","\n"," def forward(self, x):\n"," mask = self.get_mask(x)\n"," if len(x.shape) > 2:\n"," pos = self.position(torch.ones(x.shape[:2]).to(x.device))\n"," x = torch.matmul(x, self.token.weight) + pos\n"," else:\n"," x = self.token(x) + self.position(x)\n"," return self.dropout(self.layer_norm(x)), mask\n","\n","\n","class BERTModel(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," hidden = args.bert_hidden_units\n"," heads = args.bert_num_heads\n"," head_size = args.bert_head_size\n"," dropout = args.bert_dropout\n"," attn_dropout = args.bert_attn_dropout\n"," layers = args.bert_num_blocks\n","\n"," self.transformer_blocks = nn.ModuleList([TransformerBlock(\n"," hidden, heads, head_size, hidden * 4, dropout, attn_dropout) for _ in range(layers)])\n"," self.linear = nn.Linear(hidden, hidden)\n"," self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 2))\n"," self.bias.requires_grad = True\n"," self.activation = GELU()\n","\n"," def forward(self, x, embedding_weight, mask):\n"," for transformer in self.transformer_blocks:\n"," x = transformer.forward(x, mask)\n"," x = self.activation(self.linear(x))\n"," scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias\n"," return scores"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iNd5hIrHHcch"},"source":["### NARM"]},{"cell_type":"code","metadata":{"id":"kLL_5gybHVLy"},"source":["class NARM(nn.Module):\n"," def __init__(self, args):\n"," super(NARM, self).__init__()\n"," self.args = args\n"," self.embedding = NARMEmbedding(self.args)\n"," self.model = NARMModel(self.args)\n"," self.truncated_normal_init()\n","\n"," def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):\n"," with torch.no_grad():\n"," l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.\n"," u = (1. 
+ math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.\n","\n"," for p in self.parameters():\n"," p.uniform_(2 * l - 1, 2 * u - 1)\n"," p.erfinv_()\n"," p.mul_(std * math.sqrt(2.))\n"," p.add_(mean)\n","\n"," def forward(self, x, lengths):\n"," x, mask = self.embedding(x, lengths)\n"," scores = self.model(x, self.embedding.token.weight, lengths, mask)\n"," return scores\n","\n","\n","class NARMEmbedding(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," vocab_size = args.num_items + 1\n"," embed_size = args.bert_hidden_units\n"," \n"," self.token = nn.Embedding(vocab_size, embed_size)\n"," self.embed_dropout = nn.Dropout(args.bert_dropout)\n","\n"," def get_mask(self, x, lengths):\n"," if len(x.shape) > 2:\n"," return torch.ones(x.shape[:2])[:, :max(lengths)].to(x.device)\n"," else:\n"," return ((x > 0) * 1)[:, :max(lengths)]\n","\n"," def forward(self, x, lengths):\n"," mask = self.get_mask(x, lengths)\n"," if len(x.shape) > 2:\n"," x = torch.matmul(x, self.token.weight)\n"," else:\n"," x = self.token(x)\n","\n"," return self.embed_dropout(x), mask\n","\n","\n","class NARMModel(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n"," embed_size = args.bert_hidden_units\n"," hidden_size = 2 * args.bert_hidden_units\n","\n"," self.gru = nn.GRU(embed_size, hidden_size, num_layers=1, batch_first=True)\n"," self.a_global = nn.Linear(hidden_size, hidden_size, bias=False)\n"," self.a_local = nn.Linear(hidden_size, hidden_size, bias=False)\n"," self.act = HardSigmoid()\n"," self.v_vector = nn.Linear(hidden_size, 1, bias=False)\n"," self.proj_dropout = nn.Dropout(args.bert_attn_dropout)\n"," self.b_vetor = nn.Linear(embed_size, 2 * hidden_size, bias=False)\n","\n"," def forward(self, x, embedding_weight, lengths, mask):\n"," x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)\n"," gru_out, hidden = self.gru(x)\n"," gru_out, _ = pad_packed_sequence(gru_out, batch_first=True)\n"," c_global = hidden[-1]\n","\n"," state2 = self.a_local(gru_out)\n"," state1 = self.a_global(c_global).unsqueeze(1).expand_as(state2)\n"," state1 = mask.unsqueeze(2).expand_as(state2) * state1\n"," alpha = self.act(state1 + state2).view(-1, state1.size(-1))\n"," attn = self.v_vector(alpha).view(mask.size())\n"," attn = F.softmax(attn.masked_fill(mask == 0, -1e9), dim=-1)\n"," c_local = torch.sum(attn.unsqueeze(2).expand_as(gru_out) * gru_out, 1)\n","\n"," proj = self.proj_dropout(torch.cat([c_global, c_local], 1))\n"," scores = torch.matmul(proj, self.b_vetor(embedding_weight).permute(1, 0))\n"," return scores\n","\n","\n","class HardSigmoid(nn.Module):\n"," def forward(self, x):\n"," return torch.clamp((x / 6 + 0.5), min=0., max=1.)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SfxkcBfyJU5K"},"source":["## Trainers"]},{"cell_type":"markdown","metadata":{"id":"d0P-b90nKBxR"},"source":["### RNN"]},{"cell_type":"code","metadata":{"id":"-Z7sv4oFKDFS"},"source":["class RNNTrainer(metaclass=ABCMeta):\n"," def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):\n"," self.args = args\n"," self.device = args.device\n"," self.model = model.to(self.device)\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.model = nn.DataParallel(self.model)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n"," self.train_loader = train_loader\n"," self.val_loader = val_loader\n"," self.test_loader = test_loader\n"," 
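# Optimizer / LR schedule setup: with enable_lr_warmup we follow a linear warmup-then-decay\n","        # schedule over len(train_loader) * num_epochs steps; otherwise a StepLR decays the LR by\n","        # gamma every decay_step scheduler steps (the scheduler is stepped once per batch below).\n","        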
self.optimizer = self._create_optimizer()\n"," if args.enable_lr_schedule:\n"," if args.enable_lr_warmup:\n"," self.lr_scheduler = self.get_linear_schedule_with_warmup(\n"," self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)\n"," else:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(\n"," self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n"," \n"," self.export_root = export_root\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.logger_service = LoggerService(\n"," self.train_loggers, self.val_loggers)\n"," self.log_period_as_iter = args.log_period_as_iter\n","\n"," self.ce = nn.CrossEntropyLoss(ignore_index=0)\n","\n"," def train(self):\n"," accum_iter = 0\n"," self.validate(0, accum_iter)\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter) \n"," self.validate(epoch, accum_iter)\n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," def train_one_epoch(self, epoch, accum_iter):\n"," self.model.train()\n"," average_meter_set = AverageMeterSet()\n"," tqdm_dataloader = tqdm(self.train_loader)\n","\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch_size = batch[0].size(0)\n"," seqs, lengths, labels = batch\n"," lengths = lengths.flatten()\n"," seqs, labels = seqs.to(self.device), labels.to(self.device)\n","\n"," self.optimizer.zero_grad()\n"," logits = self.model(seqs, lengths)\n"," loss = self.ce(logits, labels.squeeze())\n"," loss.backward()\n"," self.clip_gradients(5)\n"," self.optimizer.step()\n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n","\n"," average_meter_set.update('loss', loss.item())\n"," tqdm_dataloader.set_description(\n"," 'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))\n","\n"," accum_iter += batch_size\n","\n"," if self._needs_to_log(accum_iter):\n"," tqdm_dataloader.set_description('Logging to Tensorboard')\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch + 1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_train(log_data)\n"," \n"," return accum_iter\n","\n"," def validate(self, epoch, accum_iter):\n"," self.model.eval()\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," metrics = self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self):\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.model.load_state_dict(best_model_dict)\n"," self.model.eval()\n"," average_meter_set = AverageMeterSet()\n","\n"," all_scores = []\n"," average_scores = []\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n"," metrics = self.calculate_metrics(batch)\n"," \n"," # seqs, lengths, candidates, labels = batch\n"," # lengths = lengths.flatten()\n"," 
# seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n","\n"," # scores = self.model(seqs, lengths)\n"," # scores_sorted, indices = torch.sort(scores, dim=-1, descending=True)\n"," # all_scores += scores_sorted[:, :100].cpu().numpy().tolist()\n"," # average_scores += scores_sorted.cpu().numpy().tolist()\n"," # scores = scores.gather(1, candidates)\n"," # metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n","\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n","\n"," return average_metrics\n","\n"," def calculate_metrics(self, batch):\n"," seqs, lengths, candidates, labels = batch\n"," lengths = lengths.flatten()\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n","\n"," scores = self.model(seqs, lengths) # B x V\n"," scores = scores.gather(1, candidates) # B x C\n","\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics\n","\n"," def clip_gradients(self, limit=5):\n"," for p in self.model.parameters():\n"," nn.utils.clip_grad_norm_(p, 5)\n","\n"," def _update_meter_set(self, meter_set, metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n"," ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Eval: ' + \\\n"," ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(\n"," *(meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," param_optimizer = list(self.model.named_parameters())\n"," no_decay = ['bias', 'layer_norm']\n"," optimizer_grouped_parameters = [\n"," {\n"," 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n"," 'weight_decay': args.weight_decay,\n"," },\n"," {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n"," ]\n"," if args.optimizer.lower() == 'adamw':\n"," return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n"," elif args.optimizer.lower() == 'adam':\n"," return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n"," # based on hugging face get_linear_schedule_with_warmup\n"," def lr_lambda(current_step: int):\n"," if current_step < num_warmup_steps:\n"," return float(current_step) / float(max(1, num_warmup_steps))\n"," return max(\n"," 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n"," )\n","\n"," return LambdaLR(optimizer, lr_lambda, last_epoch)\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = 
SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch',\n"," graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, key='loss',\n"," graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(\n"," model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VDHlTWUwKHI-"},"source":["### SASRec"]},{"cell_type":"code","metadata":{"id":"MkyntK0UKIrg"},"source":["class SASTrainer(metaclass=ABCMeta):\n"," def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):\n"," self.args = args\n"," self.device = args.device\n"," self.model = model.to(self.device)\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.model = nn.DataParallel(self.model)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n"," self.train_loader = train_loader\n"," self.val_loader = val_loader\n"," self.test_loader = test_loader\n"," self.optimizer = self._create_optimizer()\n"," if args.enable_lr_schedule:\n"," if args.enable_lr_warmup:\n"," self.lr_scheduler = self.get_linear_schedule_with_warmup(\n"," self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)\n"," else:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(\n"," self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n"," \n"," self.export_root = export_root\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.logger_service = LoggerService(\n"," self.train_loggers, self.val_loggers)\n"," self.log_period_as_iter = args.log_period_as_iter\n","\n"," self.bce = nn.BCEWithLogitsLoss()\n","\n"," def train(self):\n"," accum_iter = 0\n"," self.validate(0, accum_iter)\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter)\n"," self.validate(epoch, accum_iter)\n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," def train_one_epoch(self, epoch, accum_iter):\n"," self.model.train()\n"," average_meter_set = AverageMeterSet()\n"," tqdm_dataloader = tqdm(self.train_loader)\n","\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch_size = batch[0].size(0)\n"," batch = [x.to(self.device) for x in batch]\n","\n"," self.optimizer.zero_grad()\n"," loss = self.calculate_loss(batch)\n"," loss.backward()\n"," self.clip_gradients(5)\n"," self.optimizer.step()\n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n","\n"," average_meter_set.update('loss', loss.item())\n"," 
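# the running-average loss is shown in the progress bar; roughly every log_period_as_iter\n","            # processed samples the current state_dict and averaged metrics are sent to the loggers\n","            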
tqdm_dataloader.set_description(\n"," 'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))\n","\n"," accum_iter += batch_size\n","\n"," if self._needs_to_log(accum_iter):\n"," tqdm_dataloader.set_description('Logging to Tensorboard')\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch + 1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_train(log_data)\n","\n"," return accum_iter\n","\n"," def validate(self, epoch, accum_iter):\n"," self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n","\n"," metrics = self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self):\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.model.load_state_dict(best_model_dict)\n"," self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," all_scores = []\n"," average_scores = []\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n"," metrics = self.calculate_metrics(batch)\n"," \n"," # seqs, candidates, labels = batch\n"," # scores = self.model(seqs)\n"," # scores = scores[:, -1, :]\n"," # scores_sorted, indices = torch.sort(scores, dim=-1, descending=True)\n"," # all_scores += scores_sorted[:, :100].cpu().numpy().tolist()\n"," # average_scores += scores_sorted.cpu().numpy().tolist()\n"," # scores = scores.gather(1, candidates)\n"," # metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n","\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n"," \n"," return average_metrics\n","\n"," def calculate_loss(self, batch):\n"," seqs, labels, negs = batch\n","\n"," logits = self.model(seqs) # F.softmax(self.model(seqs), dim=-1)\n"," pos_logits = logits.gather(-1, labels.unsqueeze(-1))[seqs > 0].squeeze()\n"," pos_targets = torch.ones_like(pos_logits)\n"," neg_logits = logits.gather(-1, negs.unsqueeze(-1))[seqs > 0].squeeze()\n"," neg_targets = torch.zeros_like(neg_logits)\n","\n"," loss = self.bce(torch.cat((pos_logits, neg_logits), 0), torch.cat((pos_targets, neg_targets), 0))\n"," return loss\n","\n"," def calculate_metrics(self, batch):\n"," seqs, candidates, labels = batch\n","\n"," scores = self.model(seqs)\n"," scores = scores[:, -1, :]\n"," scores = scores.gather(1, candidates)\n","\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics\n","\n"," def clip_gradients(self, limit=5):\n"," for p in self.model.parameters():\n"," nn.utils.clip_grad_norm_(p, 5)\n","\n"," def _update_meter_set(self, meter_set, 
metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n"," ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Eval: ' + \\\n"," ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(\n"," *(meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," param_optimizer = list(self.model.named_parameters())\n"," no_decay = ['bias', 'layer_norm']\n"," optimizer_grouped_parameters = [\n"," {\n"," 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n"," 'weight_decay': args.weight_decay,\n"," },\n"," {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n"," ]\n"," if args.optimizer.lower() == 'adamw':\n"," return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n"," elif args.optimizer.lower() == 'adam':\n"," return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n"," # based on hugging face get_linear_schedule_with_warmup\n"," def lr_lambda(current_step: int):\n"," if current_step < num_warmup_steps:\n"," return float(current_step) / float(max(1, num_warmup_steps))\n"," return max(\n"," 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n"," )\n","\n"," return LambdaLR(optimizer, lr_lambda, last_epoch)\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch',\n"," graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, key='loss',\n"," graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(\n"," model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"au4jaGK2JWg7"},"source":["### BERT"]},{"cell_type":"code","metadata":{"id":"VUO2X3B-JWcy"},"source":["class BERTTrainer(metaclass=ABCMeta):\n"," def __init__(self, args, model, 
train_loader, val_loader, test_loader, export_root):\n"," self.args = args\n"," self.device = args.device\n"," self.model = model.to(self.device)\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.model = nn.DataParallel(self.model)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n"," self.train_loader = train_loader\n"," self.val_loader = val_loader\n"," self.test_loader = test_loader\n"," self.optimizer = self._create_optimizer()\n"," if args.enable_lr_schedule:\n"," if args.enable_lr_warmup:\n"," self.lr_scheduler = self.get_linear_schedule_with_warmup(\n"," self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)\n"," else:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(\n"," self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n"," \n"," self.export_root = export_root\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.logger_service = LoggerService(\n"," self.train_loggers, self.val_loggers)\n"," self.log_period_as_iter = args.log_period_as_iter\n","\n"," self.ce = nn.CrossEntropyLoss(ignore_index=0)\n","\n"," def train(self):\n"," accum_iter = 0\n"," self.validate(0, accum_iter)\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter)\n"," self.validate(epoch, accum_iter)\n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," def train_one_epoch(self, epoch, accum_iter):\n"," self.model.train()\n"," average_meter_set = AverageMeterSet()\n"," tqdm_dataloader = tqdm(self.train_loader)\n","\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch_size = batch[0].size(0)\n"," batch = [x.to(self.device) for x in batch]\n","\n"," self.optimizer.zero_grad()\n"," loss = self.calculate_loss(batch)\n"," loss.backward()\n"," self.clip_gradients(5)\n"," self.optimizer.step()\n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n","\n"," average_meter_set.update('loss', loss.item())\n"," tqdm_dataloader.set_description(\n"," 'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))\n","\n"," accum_iter += batch_size\n","\n"," if self._needs_to_log(accum_iter):\n"," tqdm_dataloader.set_description('Logging to Tensorboard')\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch + 1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_train(log_data)\n","\n"," return accum_iter\n","\n"," def validate(self, epoch, accum_iter):\n"," self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n","\n"," metrics = self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self):\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.model.load_state_dict(best_model_dict)\n"," 
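# evaluate the checkpoint that scored best on self.best_metric during validation; metrics are\n","        # averaged over the test loader and written to <export_root>/logs/test_metrics.json below\n","        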
self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," all_scores = []\n"," average_scores = []\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n"," metrics = self.calculate_metrics(batch)\n"," \n"," # seqs, candidates, labels = batch\n"," # scores = self.model(seqs)\n"," # scores = scores[:, -1, :]\n"," # scores_sorted, indices = torch.sort(scores, dim=-1, descending=True)\n"," # all_scores += scores_sorted[:, :100].cpu().numpy().tolist()\n"," # average_scores += scores_sorted.cpu().numpy().tolist()\n"," # scores = scores.gather(1, candidates)\n"," # metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n","\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n"," \n"," return average_metrics\n","\n"," def calculate_loss(self, batch):\n"," seqs, labels = batch\n"," logits = self.model(seqs)\n","\n"," logits = logits.view(-1, logits.size(-1))\n"," labels = labels.view(-1)\n"," loss = self.ce(logits, labels)\n"," return loss\n","\n"," def calculate_metrics(self, batch):\n"," seqs, candidates, labels = batch\n","\n"," scores = self.model(seqs)\n"," scores = scores[:, -1, :]\n"," scores = scores.gather(1, candidates)\n","\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics\n","\n"," def clip_gradients(self, limit=5):\n"," for p in self.model.parameters():\n"," nn.utils.clip_grad_norm_(p, 5)\n","\n"," def _update_meter_set(self, meter_set, metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n"," ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Eval: ' + \\\n"," ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(\n"," *(meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," param_optimizer = list(self.model.named_parameters())\n"," no_decay = ['bias', 'layer_norm']\n"," optimizer_grouped_parameters = [\n"," {\n"," 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n"," 'weight_decay': args.weight_decay,\n"," },\n"," {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n"," ]\n"," if args.optimizer.lower() == 'adamw':\n"," return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n"," elif args.optimizer.lower() == 'adam':\n"," return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n"," # based on hugging face get_linear_schedule_with_warmup\n"," def lr_lambda(current_step: 
int):\n"," if current_step < num_warmup_steps:\n"," return float(current_step) / float(max(1, num_warmup_steps))\n"," return max(\n"," 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n"," )\n","\n"," return LambdaLR(optimizer, lr_lambda, last_epoch)\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch',\n"," graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, key='loss',\n"," graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(\n"," model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"g8gMldjzJWYO"},"source":["### Distiller"]},{"cell_type":"code","metadata":{"id":"MCe9ufY0JWNl"},"source":["class NoDataRankDistillationTrainer(metaclass=ABCMeta):\n"," def __init__(self, args, model_code, model, bb_model, test_loader, export_root, loss='ranking', tau=1., margin_topk=0.5, margin_neg=0.5):\n"," self.args = args\n"," self.device = args.device\n"," self.num_items = args.num_items\n"," self.max_len = args.bert_max_len\n"," self.batch_size = args.train_batch_size\n"," self.mask_prob = args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.CLOZE_MASK_TOKEN = self.num_items + 1\n","\n"," self.model = model.to(self.device)\n"," self.model_code = model_code\n"," self.bb_model = bb_model.to(self.device)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n"," self.export_root = export_root\n"," self.log_period_as_iter = args.log_period_as_iter\n","\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.model = nn.DataParallel(self.model)\n","\n"," self.test_loader = test_loader\n"," self.optimizer = self._create_optimizer()\n"," if args.enable_lr_schedule:\n"," if args.enable_lr_warmup:\n"," self.lr_scheduler = self.get_linear_schedule_with_warmup(\n"," self.optimizer, args.warmup_steps, (args.num_generated_seqs // self.batch_size + 1) * self.num_epochs * 2)\n"," else:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(\n"," self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n","\n"," self.loss = loss\n"," self.tau = tau\n"," self.margin_topk = margin_topk\n"," self.margin_neg = margin_neg\n"," if self.loss == 'kl':\n"," self.loss_func = nn.KLDivLoss(reduction='batchmean')\n"," elif self.loss == 'ranking':\n"," self.loss_func_1 = nn.MarginRankingLoss(margin=self.margin_topk)\n"," self.loss_func_2 = 
nn.MarginRankingLoss(margin=self.margin_neg)\n"," elif self.loss == 'kl+ct':\n"," self.loss_func_1 = nn.KLDivLoss(reduction='batchmean')\n"," self.loss_func_2 = nn.CrossEntropyLoss(ignore_index=0)\n","\n"," def calculate_loss(self, seqs, labels, candidates, lengths=None):\n"," if isinstance(self.model, BERT) or isinstance(self.model, SASRec):\n"," logits = self.model(seqs)[:, -1, :]\n"," elif isinstance(self.model, NARM):\n"," logits = self.model(seqs, lengths)\n"," \n"," if self.loss == 'kl':\n"," logits = torch.gather(logits, -1, candidates)\n"," logits = logits.view(-1, logits.size(-1))\n"," labels = labels.view(-1, labels.size(-1))\n"," loss = self.loss_func(F.log_softmax(logits/self.tau, dim=-1), F.softmax(labels/self.tau, dim=-1))\n"," \n"," elif self.loss == 'ranking':\n"," # logits = F.softmax(logits/self.tau, dim=-1)\n"," weight = torch.ones_like(logits).to(self.device)\n"," weight[torch.arange(weight.size(0)).unsqueeze(1), candidates] = 0\n"," neg_samples = torch.distributions.Categorical(F.softmax(weight, -1)).sample_n(candidates.size(-1)).permute(1, 0)\n"," # assume candidates are in descending order w.r.t. true label\n"," neg_logits = torch.gather(logits, -1, neg_samples)\n"," logits = torch.gather(logits, -1, candidates)\n"," logits_1 = logits[:, :-1].reshape(-1)\n"," logits_2 = logits[:, 1:].reshape(-1)\n"," loss = self.loss_func_1(logits_1, logits_2, torch.ones(logits_1.shape).to(self.device))\n"," loss += self.loss_func_2(logits, neg_logits, torch.ones(logits.shape).to(self.device))\n"," \n"," elif self.loss == 'kl+ct':\n"," logits = torch.gather(logits, -1, candidates)\n"," logits = logits.view(-1, logits.size(-1))\n"," labels = labels.view(-1, labels.size(-1))\n"," loss = self.loss_func_1(F.log_softmax(logits/self.tau, dim=-1), F.softmax(labels/self.tau, dim=-1))\n"," loss += self.loss_func_2(F.softmax(logits), torch.argmax(labels, dim=-1))\n"," return loss\n","\n"," def calculate_metrics(self, batch, similarity=False):\n"," self.model.eval()\n"," self.bb_model.eval()\n","\n"," if isinstance(self.model, BERT) or isinstance(self.model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.model(seqs)[:, -1, :]\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n"," elif isinstance(self.model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.model(seqs, lengths)\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n","\n"," if similarity:\n"," if isinstance(self.model, BERT) and isinstance(self.bb_model, BERT):\n"," soft_labels = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.model, BERT) and isinstance(self.bb_model, SASRec):\n"," temp_seqs = torch.cat((torch.zeros(seqs.size(0)).long().unsqueeze(1).to(self.device), seqs[:, :-1]), dim=1)\n"," soft_labels = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.model, BERT) and isinstance(self.bb_model, NARM):\n"," temp_seqs = torch.cat((torch.zeros(seqs.size(0)).long().unsqueeze(1).to(self.device), seqs[:, :-1]), dim=1)\n"," temp_seqs = self.pre2post_padding(temp_seqs)\n"," temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()\n"," soft_labels = self.bb_model(temp_seqs, temp_lengths)\n"," elif isinstance(self.model, SASRec) and 
isinstance(self.bb_model, SASRec):\n"," soft_labels = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.model, SASRec) and isinstance(self.bb_model, BERT):\n"," temp_seqs = torch.cat((seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)\n"," soft_labels = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.model, SASRec) and isinstance(self.bb_model, NARM):\n"," temp_seqs = self.pre2post_padding(seqs)\n"," temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()\n"," soft_labels = self.bb_model(temp_seqs, temp_lengths)\n"," elif isinstance(self.model, NARM) and isinstance(self.bb_model, NARM):\n"," soft_labels = self.bb_model(seqs, lengths)\n"," elif isinstance(self.model, NARM) and isinstance(self.bb_model, BERT):\n"," temp_seqs = self.post2pre_padding(seqs)\n"," temp_seqs = torch.cat((temp_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)\n"," soft_labels = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.model, NARM) and isinstance(self.bb_model, SASRec):\n"," temp_seqs = self.post2pre_padding(seqs)\n"," soft_labels = self.bb_model(temp_seqs)[:, -1, :]\n","\n"," similarity = kl_agreements_and_intersctions_for_ks(scores, soft_labels, self.metric_ks)\n"," metrics = {**metrics, **similarity} \n"," \n"," return metrics\n","\n"," def generate_autoregressive_data(self, k=100, batch_size=50):\n"," dataset = dis_dataset_factory(self.args, self.model_code, 'autoregressive')\n"," # if dataset.check_data_present():\n"," # print('Dataset already exists. Skip generation')\n"," # return\n"," \n"," batch_num = batch_size // self.args.num_generated_seqs\n"," print('Generating dataset...')\n"," for i in tqdm(range(batch_num)):\n"," seqs = torch.randint(1, self.num_items + 1, (batch_size, 1)).to(self.device)\n"," logits = None\n"," candidates = None\n"," \n"," self.bb_model.eval()\n"," with torch.no_grad():\n"," if isinstance(self.bb_model, BERT):\n"," mask_items = torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).to(self.device)\n"," for j in range(self.max_len - 1):\n"," input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)\n"," input_seqs[:, (self.max_len-2-j):-1] = seqs\n"," input_seqs[:, -1] = mask_items\n"," labels = self.bb_model(input_seqs.long())[:, -1, :]\n","\n"," _, sorted_items = torch.sort(labels[:, 1:-1], dim=-1, descending=True)\n"," sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True)\n","\n"," selected_indices = torch.distributions.Categorical(F.softmax(torch.ones_like(randomized_label), -1).to(randomized_label.device)).sample()\n"," row_indices = torch.arange(sorted_items.size(0))\n"," seqs = torch.cat((seqs, sorted_items[row_indices, selected_indices].unsqueeze(1)), 1)\n","\n"," try:\n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n"," except:\n"," logits = randomized_label.unsqueeze(1)\n"," candidates = sorted_items.unsqueeze(1)\n"," \n"," input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)\n"," input_seqs[:, :-1] = seqs[:, 1:]\n"," input_seqs[:, -1] = mask_items\n"," labels = self.bb_model(input_seqs.long())[:, -1, :]\n"," _, sorted_items = torch.sort(labels[:, 1:-1], dim=-1, descending=True)\n"," 
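# after the autoregressive loop: keep only the top-k item ids from this final query (+1 restores\n","                    # original ids after dropping the padding column) and attach synthetic, descending soft labels,\n","                    # since the black-box API exposes a ranking rather than true scores\n","                    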
sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True)\n"," \n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n","\n"," elif isinstance(self.bb_model, SASRec):\n"," for j in range(self.max_len - 1):\n"," input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)\n"," input_seqs[:, (self.max_len-1-j):] = seqs\n"," labels = self.bb_model(input_seqs.long())[:, -1, :]\n","\n"," _, sorted_items = torch.sort(labels[:, 1:], dim=-1, descending=True)\n"," sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True)\n"," \n"," selected_indices = torch.distributions.Categorical(F.softmax(torch.ones_like(randomized_label), -1).to(randomized_label.device)).sample()\n"," row_indices = torch.arange(sorted_items.size(0))\n"," seqs = torch.cat((seqs, sorted_items[row_indices, selected_indices].unsqueeze(1)), 1)\n","\n"," try:\n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n"," except:\n"," logits = randomized_label.unsqueeze(1)\n"," candidates = sorted_items.unsqueeze(1)\n","\n"," labels = self.bb_model(seqs.long())[:, -1, :]\n"," _, sorted_items = torch.sort(labels[:, 1:], dim=-1, descending=True)\n"," sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True)\n"," \n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n","\n"," elif isinstance(self.bb_model, NARM):\n"," for j in range(self.max_len - 1):\n"," lengths = torch.tensor([j + 1] * seqs.size(0))\n"," labels = self.bb_model(seqs.long(), lengths)\n","\n"," _, sorted_items = torch.sort(labels[:, 1:], dim=-1, descending=True)\n"," sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True) \n","\n"," selected_indices = torch.distributions.Categorical(F.softmax(torch.ones_like(randomized_label), -1).to(randomized_label.device)).sample()\n"," row_indices = torch.arange(sorted_items.size(0))\n"," seqs = torch.cat((seqs, sorted_items[row_indices, selected_indices].unsqueeze(1)), 1)\n"," \n"," try:\n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n"," except:\n"," logits = randomized_label.unsqueeze(1)\n"," candidates = sorted_items.unsqueeze(1)\n","\n"," lengths = torch.tensor([self.max_len] * seqs.size(0))\n"," labels = self.bb_model(seqs.long(), lengths)\n"," _, sorted_items = torch.sort(labels[:, 1:], dim=-1, descending=True)\n"," sorted_items = sorted_items[:, :k] + 1\n"," randomized_label = 
torch.rand(sorted_items.shape).to(self.device)\n"," randomized_label = randomized_label / randomized_label.sum(dim=-1).unsqueeze(-1)\n"," randomized_label, _ = torch.sort(randomized_label, dim=-1, descending=True)\n"," \n"," logits = torch.cat((logits, randomized_label.unsqueeze(1)), 1)\n"," candidates = torch.cat((candidates, sorted_items.unsqueeze(1)), 1)\n","\n"," if i == 0:\n"," batch_tokens = seqs.cpu().numpy()\n"," batch_logits = logits.cpu().numpy()\n"," batch_candidates = candidates.cpu().numpy()\n"," else:\n"," batch_tokens = np.concatenate((batch_tokens, seqs.cpu().numpy()))\n"," batch_logits = np.concatenate((batch_logits, logits.cpu().numpy()))\n"," batch_candidates = np.concatenate((batch_candidates, candidates.cpu().numpy()))\n","\n"," dataset.save_dataset(batch_tokens.tolist(), batch_logits.tolist(), batch_candidates.tolist())\n","\n"," def train_autoregressive(self): \n"," accum_iter = 0\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.logger_service = LoggerService(\n"," self.train_loggers, self.val_loggers)\n"," self.generate_autoregressive_data()\n"," dis_train_loader, dis_val_loader = dis_train_loader_factory(self.args, self.model_code, 'autoregressive')\n"," print('## Distilling model via autoregressive data... ##')\n"," self.validate(dis_val_loader, 0, accum_iter)\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter, dis_train_loader, dis_val_loader, stage=1)\n"," \n"," metrics = self.test()\n"," \n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," return metrics\n","\n"," def train_one_epoch(self, epoch, accum_iter, train_loader, val_loader, stage=0):\n"," self.model.train()\n"," self.bb_model.train()\n"," average_meter_set = AverageMeterSet()\n"," \n"," tqdm_dataloader = tqdm(train_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," self.optimizer.zero_grad()\n"," if isinstance(self.model, BERT) or isinstance(self.model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," loss = self.calculate_loss(seqs, labels, candidates)\n"," elif isinstance(self.model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," lengths = lengths.flatten()\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," loss = self.calculate_loss(seqs, labels, candidates, lengths=lengths)\n"," \n"," loss.backward()\n"," self.clip_gradients(5)\n"," self.optimizer.step()\n"," accum_iter += int(seqs.size(0))\n"," average_meter_set.update('loss', loss.item())\n"," tqdm_dataloader.set_description(\n"," 'Epoch {} Stage {}, loss {:.3f} '.format(epoch+1, stage, average_meter_set['loss'].avg))\n","\n"," if self._needs_to_log(accum_iter):\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_train(log_data)\n"," \n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n"," \n"," self.validate(val_loader, epoch, accum_iter)\n"," return accum_iter\n","\n"," def validate(self, val_loader, epoch, accum_iter):\n"," self.model.eval()\n"," average_meter_set = AverageMeterSet()\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," metrics 
= self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self):\n"," wb_model = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.model.load_state_dict(wb_model)\n"," \n"," self.model.eval()\n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," metrics = self.calculate_metrics(batch, similarity=True)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n"," \n"," return average_metrics\n","\n"," def bb_model_test(self):\n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," if isinstance(self.model, BERT) or isinstance(self.model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n"," elif isinstance(self.model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n","\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n"," \n"," return average_metrics\n","\n"," def pre2post_padding(self, seqs):\n"," processed = torch.zeros_like(seqs)\n"," lengths = (seqs > 0).sum(-1).squeeze()\n"," for i in range(seqs.size(0)):\n"," processed[i, :lengths[i]] = seqs[i, seqs.size(1)-lengths[i]:]\n"," return processed\n","\n"," def post2pre_padding(self, seqs):\n"," processed = torch.zeros_like(seqs)\n"," lengths = (seqs > 0).sum(-1).squeeze()\n"," for i in range(seqs.size(0)):\n"," processed[i, seqs.size(1)-lengths[i]:] = seqs[i, :lengths[i]]\n"," return processed\n","\n"," def clip_gradients(self, limit=5):\n"," for p in self.model.parameters():\n"," nn.utils.clip_grad_norm_(p, 5)\n","\n"," def _update_meter_set(self, meter_set, metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n"," ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Eval: ' + \\\n"," ', '.join(s + ' {:.3f}' for s in 
description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(\n"," *(meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," param_optimizer = list(self.model.named_parameters())\n"," no_decay = ['bias', 'layer_norm']\n"," optimizer_grouped_parameters = [\n"," {\n"," 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n"," 'weight_decay': args.weight_decay,\n"," },\n"," {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n"," ]\n"," if args.optimizer.lower() == 'adamw':\n"," return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n"," elif args.optimizer.lower() == 'adam':\n"," return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n"," # based on hugging face get_linear_schedule_with_warmup\n"," def lr_lambda(current_step: int):\n"," if current_step < num_warmup_steps:\n"," return float(current_step) / float(max(1, num_warmup_steps))\n"," return max(\n"," 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n"," )\n","\n"," return LambdaLR(optimizer, lr_lambda, last_epoch)\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch',\n"," graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, key='loss',\n"," graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(\n"," model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0X7a90BfvltQ"},"source":["## Train Black-Box Recommender Models"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wZLBMAk71hUT","executionInfo":{"status":"ok","timestamp":1631626837875,"user_tz":-330,"elapsed":201702,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"6fcc12c9-3e1e-450b-b5fe-9a4a81297f71"},"source":["def train(args, export_root=None, resume=False):\n"," args.lr = 
0.001\n"," fix_random_seed_as(args.model_init_seed)\n"," train_loader, val_loader, test_loader = dataloader_factory(args)\n","\n"," if args.model_code == 'bert':\n"," model = BERT(args)\n"," elif args.model_code == 'sas':\n"," model = SASRec(args)\n"," elif args.model_code == 'narm':\n"," model = NARM(args)\n","\n"," if export_root == None:\n"," export_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n"," \n"," if resume:\n"," try: \n"," model.load_state_dict(torch.load(os.path.join(export_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," except FileNotFoundError:\n"," print('Failed to load old model, continue training new model...')\n","\n"," if args.model_code == 'bert':\n"," args.num_epochs = 10\n"," trainer = BERTTrainer(args, model, train_loader, val_loader, test_loader, export_root)\n"," if args.model_code == 'sas':\n"," trainer = SASTrainer(args, model, train_loader, val_loader, test_loader, export_root)\n"," elif args.model_code == 'narm':\n"," args.num_epochs = 100\n"," trainer = RNNTrainer(args, model, train_loader, val_loader, test_loader, export_root)\n","\n"," trainer.train()\n"," trainer.test()\n","\n","\n","if __name__ == \"__main__\":\n"," set_template(args)\n","\n"," # when use k-core beauty and k is not 5 (beauty-dense)\n"," # args.min_uc = k\n"," # args.min_sc = k\n","\n","\n"," train(args, resume=True)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Raw file doesn't exist. Downloading...\n","\n","Extracting data...\n","\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:70: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.\n"]},{"output_type":"stream","name":"stdout","text":["Filtering triplets\n","Densifying index\n","Splitting\n"]},{"output_type":"stream","name":"stderr","text":["100%|██████████| 6040/6040 [00:10<00:00, 595.20it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Negative samples don't exist. Generating.\n","Sampling negative items randomly...\n"]},{"output_type":"stream","name":"stderr","text":["100%|██████████| 6040/6040 [00:01<00:00, 5155.44it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Negative samples don't exist. 
Generating.\n","Sampling negative items randomly...\n"]},{"output_type":"stream","name":"stderr","text":["100%|██████████| 6040/6040 [00:01<00:00, 5303.09it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Failed to load old model, continue training new model...\n"]},{"output_type":"stream","name":"stderr","text":["Eval: N@1 0.009, N@5 0.025, N@10 0.040, R@1 0.009, R@5 0.043, R@10 0.091: 100%|██████████| 48/48 [00:03<00:00, 13.01it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 1\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1, loss 8.047 : 100%|██████████| 68/68 [00:12<00:00, 5.43it/s]\n","Eval: N@1 0.026, N@5 0.066, N@10 0.095, R@1 0.026, R@5 0.108, R@10 0.196: 100%|██████████| 48/48 [00:03<00:00, 14.29it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 1\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 2, loss 7.721 : 100%|██████████| 68/68 [00:12<00:00, 5.44it/s]\n","Eval: N@1 0.041, N@5 0.103, N@10 0.142, R@1 0.041, R@5 0.166, R@10 0.287: 100%|██████████| 48/48 [00:03<00:00, 14.39it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 2\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 3, loss 7.419 : 100%|██████████| 68/68 [00:12<00:00, 5.42it/s]\n","Eval: N@1 0.055, N@5 0.130, N@10 0.172, R@1 0.055, R@5 0.205, R@10 0.337: 100%|██████████| 48/48 [00:03<00:00, 14.23it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 3\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 4, loss 7.230 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s]\n","Eval: N@1 0.070, N@5 0.153, N@10 0.200, R@1 0.070, R@5 0.235, R@10 0.380: 100%|██████████| 48/48 [00:03<00:00, 14.27it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 4\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 5, loss 7.126 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s]\n","Eval: N@1 0.076, N@5 0.167, N@10 0.215, R@1 0.076, R@5 0.256, R@10 0.405: 100%|██████████| 48/48 [00:03<00:00, 14.48it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 5\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 6, loss 7.071 : 100%|██████████| 68/68 [00:12<00:00, 5.43it/s]\n","Eval: N@1 0.076, N@5 0.169, N@10 0.218, R@1 0.076, R@5 0.260, R@10 0.412: 100%|██████████| 48/48 [00:03<00:00, 14.16it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 6\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 7, loss 7.043 : 100%|██████████| 68/68 [00:12<00:00, 5.48it/s]\n","Eval: N@1 0.077, N@5 0.174, N@10 0.223, R@1 0.077, R@5 0.269, R@10 0.424: 100%|██████████| 48/48 [00:03<00:00, 14.32it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 7\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 8, loss 7.026 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s]\n","Eval: N@1 0.080, N@5 0.177, N@10 0.226, R@1 0.080, R@5 0.272, R@10 0.425: 100%|██████████| 48/48 [00:03<00:00, 14.37it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 8\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 9, loss 7.007 : 100%|██████████| 68/68 [00:12<00:00, 5.49it/s]\n","Eval: N@1 0.080, N@5 0.180, N@10 0.227, R@1 0.080, R@5 0.277, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 14.37it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 
9\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 10, loss 7.004 : 100%|██████████| 68/68 [00:12<00:00, 5.48it/s]\n","Eval: N@1 0.080, N@5 0.178, N@10 0.226, R@1 0.080, R@5 0.274, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 14.33it/s]\n","Eval: N@1 0.080, N@5 0.174, N@10 0.218, R@1 0.080, R@5 0.266, R@10 0.405: 100%|██████████| 48/48 [00:03<00:00, 14.35it/s]\n"]}]},{"cell_type":"markdown","metadata":{"id":"E3TYYGafwLMR"},"source":["## Extract a White-Box Recommender Model"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Z4ffcQ00pFiw","executionInfo":{"status":"ok","timestamp":1631628441513,"user_tz":-330,"elapsed":681488,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"d3d09ce2-c204-4697-b7be-e66dc5f3ab67"},"source":["def distill(args, bb_model_root=None, export_root=None, resume=False):\n"," args.lr = 0.001\n"," args.enable_lr_warmup = False\n"," fix_random_seed_as(args.model_init_seed)\n"," _, _, test_loader = dataloader_factory(args)\n","\n"," if args.model_code == 'bert':\n"," model = BERT(args)\n"," elif args.model_code == 'sas':\n"," model = SASRec(args)\n"," elif args.model_code == 'narm':\n"," model = NARM(args)\n"," \n"," # model_codes = {'b': 'bert', 's':'sas', 'n':'narm'}\n"," # bb_model_code = model_codes[input('Input black box model code, b for BERT, s for SASRec and n for NARM: ')]\n"," # args.num_generated_seqs = int(input('Input integer number of seqs budget: '))\n"," args.num_generated_seqs = 5\n","\n"," bb_model_code = 'bert'\n","\n"," if bb_model_code == 'bert':\n"," bb_model = BERT(args)\n"," elif bb_model_code == 'sas':\n"," bb_model = SASRec(args)\n"," elif bb_model_code == 'narm':\n"," bb_model = NARM(args)\n"," \n"," if bb_model_root == None:\n"," bb_model_root = 'experiments/' + bb_model_code + '/' + args.dataset_code\n"," if export_root == None:\n"," folder_name = bb_model_code + '2' + args.model_code + '_autoregressive' + str(args.num_generated_seqs)\n"," export_root = 'experiments/distillation_rank/' + folder_name + '/' + args.dataset_code\n","\n"," bb_model.load_state_dict(torch.load(os.path.join(bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," if resume:\n"," try:\n"," model.load_state_dict(torch.load(os.path.join(export_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," except FileNotFoundError:\n"," print('Failed to load old model, continue training new model...')\n"," trainer = NoDataRankDistillationTrainer(args, args.model_code, model, bb_model, test_loader, export_root)\n","\n"," trainer.train_autoregressive()\n","\n","\n","if __name__ == \"__main__\":\n"," set_template(args)\n","\n"," # when use k-core beauty and k is not 5 (beauty-dense)\n"," # args.min_uc = k\n"," # args.min_sc = k\n"," args.num_epochs = 5\n"," distill(args=args, resume=False)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Already preprocessed. Skip preprocessing\n","Negatives samples exist. Loading.\n","Negatives samples exist. Loading.\n","Generating dataset...\n"]},{"output_type":"stream","name":"stderr","text":["100%|██████████| 10/10 [00:43<00:00, 4.36s/it]\n"]},{"output_type":"stream","name":"stdout","text":["## Distilling model via autoregressive data... 
##\n"]},{"output_type":"stream","name":"stderr","text":["Eval: N@1 0.000, N@5 0.000, N@10 0.000, R@1 0.000, R@5 0.000, R@10 0.000: 100%|██████████| 4/4 [00:00<00:00, 11.28it/s]\n"," 0%| | 0/778 [00:00 0).sum(-1).cpu().flatten()\n"," wb_embedding, mask = self.wb_model.embedding(perturbed_seqs.long(), lengths)\n","\n"," self.wb_model.train()\n"," wb_embedding = wb_embedding.detach().clone()\n"," wb_embedding.requires_grad = True\n"," zero_gradients(wb_embedding)\n","\n"," if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):\n"," wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, mask)[:, -1, :]\n"," elif isinstance(self.wb_model, NARM):\n"," wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, lengths, mask)\n","\n"," loss = self.adv_ce(wb_scores, torch.tensor([target] * perturbed_seqs.size(0)).to(self.device))\n"," self.wb_model.zero_grad()\n"," loss.backward()\n"," wb_embedding_grad = wb_embedding.grad.data\n","\n"," self.wb_model.eval()\n"," with torch.no_grad():\n"," appended_indicies = (perturbed_seqs != self.CLOZE_MASK_TOKEN)\n"," appended_indicies = (perturbed_seqs != 0) * appended_indicies\n"," appended_indicies = torch.arange(perturbed_seqs.shape[1]).to(self.device) * appended_indicies\n"," _, appended_indicies = torch.sort(appended_indicies, -1, descending=True)\n"," appended_indicies = appended_indicies[:, :num_attack]\n"," \n"," best_seqs = perturbed_seqs.clone().detach()\n"," for num in range(num_attack):\n"," row_indices = torch.arange(seqs.size(0))\n"," col_indices = appended_indicies[:, num]\n","\n"," current_embedding = wb_embedding[row_indices, col_indices]\n"," current_embedding_grad = wb_embedding_grad[row_indices, col_indices]\n"," all_embeddings = self.item_embeddings.unsqueeze(1).repeat_interleave(current_embedding.size(0), 1)\n"," cos = nn.CosineSimilarity(dim=-1, eps=1e-6)\n"," multipication_results = torch.t(cos(current_embedding-current_embedding_grad.sign(), all_embeddings))\n"," _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)\n","\n"," if num == 0:\n"," multipication_results[:, target-1] = multipication_results[:, target-1] - 100000000\n"," _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)\n"," best_seqs[row_indices, col_indices] = candidate_indicies[:, 0] + 1\n","\n"," if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):\n"," logits = F.softmax(self.wb_model(best_seqs)[:, -1, :], dim=-1)\n"," elif isinstance(self.wb_model, NARM):\n"," logits = F.softmax(self.wb_model(best_seqs, lengths), dim=-1)\n"," best_scores = torch.gather(logits, -1, torch.tensor([target] * best_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()\n","\n"," elif num > 0:\n"," prev_col_indices = appended_indicies[:, num-1]\n"," if_prev_target = (best_seqs[row_indices, prev_col_indices] == target)\n"," multipication_results[:, target-1] = multipication_results[:, target-1] + (if_prev_target * -100000000)\n"," _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)\n"," best_seqs[row_indices, col_indices] = best_seqs[row_indices, col_indices] * ~if_prev_target + \\\n"," (candidate_indicies[:, 0] + 1) * if_prev_target\n"," \n"," if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):\n"," logits = F.softmax(self.wb_model(best_seqs)[:, -1, :], dim=-1)\n"," elif isinstance(self.wb_model, NARM):\n"," logits = F.softmax(self.wb_model(best_seqs, lengths), dim=-1)\n"," best_scores = 
torch.gather(logits, -1, torch.tensor([target] * best_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()\n","\n"," for time in range(repeated_search):\n"," temp_seqs = best_seqs.clone().detach()\n"," temp_seqs[row_indices, col_indices] = candidate_indicies[:, time] + 1\n","\n"," if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):\n"," logits = F.softmax(self.wb_model(temp_seqs)[:, -1, :], dim=-1)\n"," elif isinstance(self.wb_model, NARM):\n"," logits = F.softmax(self.wb_model(temp_seqs, lengths), dim=-1)\n"," temp_scores = torch.gather(logits, -1, torch.tensor([target] * temp_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()\n","\n"," best_seqs[row_indices, col_indices] = temp_seqs[row_indices, col_indices] * (temp_scores >= best_scores) + best_seqs[row_indices, col_indices] * (temp_scores < best_scores)\n"," best_scores = temp_scores * (temp_scores >= best_scores) + best_scores * (temp_scores < best_scores)\n"," best_seqs = best_seqs.detach()\n"," best_scores = best_scores.detach()\n"," del temp_scores\n"," \n"," perturbed_seqs = best_seqs.detach()\n"," if isinstance(self.wb_model, BERT) and isinstance(self.bb_model, BERT):\n"," perturbed_scores = self.bb_model(perturbed_seqs)[:, -1, :]\n"," elif isinstance(self.wb_model, BERT) and isinstance(self.bb_model, SASRec):\n"," temp_seqs = torch.cat((torch.zeros(perturbed_seqs.size(0)).long().unsqueeze(1).to(self.device), perturbed_seqs[:, :-1]), dim=1)\n"," perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.wb_model, BERT) and isinstance(self.bb_model, NARM):\n"," temp_seqs = torch.cat((torch.zeros(perturbed_seqs.size(0)).long().unsqueeze(1).to(self.device), perturbed_seqs[:, :-1]), dim=1)\n"," temp_seqs = self.pre2post_padding(temp_seqs)\n"," temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()\n"," perturbed_scores = self.bb_model(temp_seqs, temp_lengths)\n"," elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, SASRec):\n"," perturbed_scores = self.bb_model(perturbed_seqs)[:, -1, :]\n"," elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, BERT):\n"," temp_seqs = torch.cat((perturbed_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * perturbed_seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)\n"," perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, NARM):\n"," temp_seqs = self.pre2post_padding(perturbed_seqs)\n"," temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()\n"," perturbed_scores = self.bb_model(temp_seqs, temp_lengths)\n"," elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, NARM):\n"," perturbed_scores = self.bb_model(perturbed_seqs, lengths)\n"," elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, BERT):\n"," temp_seqs = self.post2pre_padding(perturbed_seqs)\n"," temp_seqs = torch.cat((temp_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * perturbed_seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)\n"," perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]\n"," elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, SASRec):\n"," temp_seqs = self.post2pre_padding(perturbed_seqs)\n"," perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]\n"," \n"," candidates[:, 0] = torch.tensor([target] * candidates.size(0)).to(self.device) \n"," perturbed_scores = perturbed_scores.gather(1, candidates)\n"," metrics = recalls_and_ndcgs_for_ks(perturbed_scores, labels, self.metric_ks)\n"," self._update_meter_set(average_meter_set, 
metrics)\n"," self._update_dataloader_metrics(tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," return average_metrics\n"," \n","\n"," def test(self, target=None):\n"," if target is not None:\n"," print('## Black-Box Targeted Test on Item {} ##'.format(str(target)))\n"," else:\n"," print('## Black-Box Untargeted Test on Item Level ##')\n"," \n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n"," \n"," if target is not None:\n"," candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)\n"," scores = scores.gather(1, candidates)\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," return average_metrics\n","\n"," def calculate_metrics(self, batch):\n"," self.bb_model.eval()\n","\n"," if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n","\n"," scores = scores.gather(1, candidates) # B x C\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics\n","\n"," def pre2post_padding(self, seqs):\n"," processed = torch.zeros_like(seqs)\n"," lengths = (seqs > 0).sum(-1).squeeze()\n"," for i in range(seqs.size(0)):\n"," processed[i, :lengths[i]] = seqs[i, seqs.size(1)-lengths[i]:]\n"," return processed\n","\n"," def post2pre_padding(self, seqs):\n"," processed = torch.zeros_like(seqs)\n"," lengths = (seqs > 0).sum(-1).squeeze()\n"," for i in range(seqs.size(0)):\n"," processed[i, seqs.size(1)-lengths[i]:] = seqs[i, :lengths[i]]\n"," return processed\n","\n"," def _update_meter_set(self, meter_set, metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['Recall@%d' % k for k in self.metric_ks[:3]] + ['NDCG@%d' % k for k in self.metric_ks[1:3]]\n"," description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(*(meter_set[k].avg for k in description_metrics))\n"," 
tqdm_dataloader.set_description(description)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1fvIwzj4pkXJ"},"source":["#### Datasets"]},{"cell_type":"code","metadata":{"id":"lru97jRiplie"},"source":["import pickle\n","import shutil\n","import tempfile\n","import os\n","from pathlib import Path\n","import numpy as np\n","from abc import *\n","\n","\n","class AbstractPoisonedDataset(metaclass=ABCMeta):\n"," def __init__(self, args, target, method_code, num_poisoned_seqs=0, num_original_seqs=0):\n"," self.args = args\n"," if isinstance(target, list):\n"," self.target = target_spec = '_'.join([str(t) for t in target])\n"," else:\n"," self.target = target\n"," self.method_code = method_code\n"," self.num_poisoned_seqs = num_poisoned_seqs\n"," self.num_original_seqs = num_original_seqs\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @classmethod\n"," def raw_code(cls):\n"," return cls.code()\n","\n"," def check_data_present(self):\n"," dataset_path = self._get_poisoned_dataset_path()\n"," return dataset_path.is_file()\n","\n"," def load_dataset(self):\n"," dataset_path = self._get_poisoned_dataset_path()\n"," if not dataset_path.is_file():\n"," print('Dataset not found, please generate distillation dataset first')\n"," return\n"," dataset = pickle.load(dataset_path.open('rb'))\n"," return dataset\n","\n"," def save_dataset(self, tokens, original_dataset_size=0, valid_all=False):\n"," original_dataset = dataset_factory(self.args)\n"," original_dataset = original_dataset.load_dataset()\n"," train = original_dataset['train']\n"," val = original_dataset['val']\n"," test = original_dataset['test']\n"," self.num_poisoned_seqs = len(tokens)\n"," self.num_original_seqs = len(train)\n"," start_index = len(train) + 1\n"," \n"," if original_dataset_size > 0:\n"," sampled_users = np.random.choice(list(train.keys()), original_dataset_size)\n"," train_ = {idx + 1: train[user] for idx, user in enumerate(sampled_users)}\n"," val_ = {idx + 1: val[user] for idx, user in enumerate(sampled_users)}\n"," test_ = {idx + 1: test[user] for idx, user in enumerate(sampled_users)}\n"," train, val, test = train_, val_, test_\n"," self.num_original_seqs = original_dataset_size\n"," start_index = original_dataset_size + 1\n"," \n"," self.poisoning_users = []\n"," for i in range(len(tokens)):\n"," items = tokens[i]\n"," user = start_index + i\n"," self.poisoning_users.append(user)\n"," train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]\n","\n"," dataset_path = self._get_poisoned_dataset_path()\n"," if not dataset_path.parent.is_dir():\n"," dataset_path.parent.mkdir(parents=True)\n","\n"," dataset = {'train': train,\n"," 'val': val,\n"," 'test': test}\n","\n"," with dataset_path.open('wb') as f:\n"," pickle.dump(dataset, f)\n","\n"," return self.num_poisoned_seqs, self.num_original_seqs, self.poisoning_users\n","\n"," def _get_rawdata_root_path(self):\n"," return Path(GEN_DATASET_ROOT_FOLDER)\n","\n"," def _get_folder_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath(self.raw_code())\n","\n"," def _get_subfolder_path(self):\n"," root = self._get_folder_path()\n"," folder = 'poisoned' + str(self.num_poisoned_seqs) + '_' + 'original' + str(self.num_original_seqs)\n"," return root.joinpath(self.method_code + '_target_' + str(self.target) + '_' + folder)\n","\n"," def _get_poisoned_dataset_path(self):\n"," folder = self._get_subfolder_path()\n"," return folder.joinpath('poisoned_dataset.pkl')\n","\n","\n","class 
ML1MPoisonedDataset(AbstractPoisonedDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-1m'\n","\n","\n","class ML20MPoisonedDataset(AbstractPoisonedDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-20m'\n","\n","\n","class BeautyPoisonedDataset(AbstractPoisonedDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'beauty'\n","\n","\n","class SteamPoisonedDataset(AbstractPoisonedDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'steam'\n","\n","\n","class YooChoosePoisonedDataset(AbstractPoisonedDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'yoochoose'"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"m07w0vtro-I8"},"source":["#### Dataloader"]},{"cell_type":"code","metadata":{"id":"WKos5BMzoWW4"},"source":["import torch\n","import torch.utils.data as data_utils\n","\n","import random\n","\n","POI_DATASETS = {\n"," ML1MPoisonedDataset.code(): ML1MPoisonedDataset,\n"," ML20MPoisonedDataset.code(): ML20MPoisonedDataset,\n"," BeautyPoisonedDataset.code(): BeautyPoisonedDataset,\n"," SteamPoisonedDataset.code(): SteamPoisonedDataset,\n"," YooChoosePoisonedDataset.code(): YooChoosePoisonedDataset,\n","}\n","\n","\n","def poi_dataset_factory(args, target, method_code, num_poisoned_seqs=0, num_original_seqs=0):\n"," dataset = POI_DATASETS[args.dataset_code]\n"," return dataset(args, target, method_code, num_poisoned_seqs, num_original_seqs)\n","\n","\n","def poi_train_loader_factory(args, target, method_code, num_poisoned_seqs, num_original_seqs, poisoning_users=None):\n"," dataset = poi_dataset_factory(args, target, method_code, num_poisoned_seqs, num_original_seqs)\n"," if dataset.check_data_present():\n"," dataloader = PoisonedDataLoader(args, dataset)\n"," train, val, test = dataloader.get_loaders(poisoning_users)\n"," return train, val, test\n"," else:\n"," return None\n","\n","\n","class PoisonedDataLoader():\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," self.rng = random.Random()\n"," self.save_folder = dataset._get_subfolder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," \n"," self.user_count = len(self.train)\n"," self.item_count = self.args.num_items\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = args.bert_mask_prob\n"," self.max_predictions = args.bert_max_predictions\n"," self.sliding_size = args.sliding_window_size\n"," self.CLOZE_MASK_TOKEN = self.args.num_items + 1\n","\n"," val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'poisoned_val', self.save_folder)\n"," test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,\n"," self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," 'poisoned_test', self.save_folder)\n","\n"," self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()\n"," self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()\n","\n"," @classmethod\n"," def code(cls):\n"," return 'distillation_loader'\n","\n"," def get_loaders(self, poisoning_users=None):\n"," train, val, test = self._get_datasets(poisoning_users)\n"," train_loader = data_utils.DataLoader(train, 
batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," val_loader = data_utils.DataLoader(val, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," test_loader = data_utils.DataLoader(test, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," \n"," return train_loader, val_loader, test_loader\n","\n"," def _get_datasets(self, poisoning_users=None):\n"," if self.args.model_code == 'bert':\n"," train = BERTTrainDataset(self.train, self.max_len, self.mask_prob, self.max_predictions, self.sliding_size, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)\n"," val = BERTValidDataset(self.train, self.val, self.max_len, self.CLOZE_MASK_TOKEN, self.val_negative_samples, poisoning_users)\n"," test = BERTTestDataset(self.train, self.val, self.test, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples, poisoning_users)\n"," elif self.args.model_code == 'sas':\n"," train = SASTrainDataset(self.train, self.max_len, self.sliding_size, self.seen_samples, self.item_count, self.rng)\n"," val = SASValidDataset(self.train, self.val, self.max_len, self.val_negative_samples, poisoning_users)\n"," test = SASTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples, poisoning_users)\n"," elif self.args.model_code == 'narm':\n"," train = RNNTrainDataset(self.train, self.max_len)\n"," val = RNNValidDataset(self.train, self.val, self.max_len, self.val_negative_samples, poisoning_users)\n"," test = RNNTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples, poisoning_users)\n"," \n"," return train, val, test"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kNQ0WQb4oWQl"},"source":["#### Logger"]},{"cell_type":"code","metadata":{"id":"jlBS_vzZoWLo"},"source":["import os\n","import torch\n","from abc import ABCMeta, abstractmethod\n","\n","\n","def save_state_dict(state_dict, path, filename):\n"," torch.save(state_dict, os.path.join(path, filename))\n","\n","\n","class LoggerService(object):\n"," def __init__(self, train_loggers=None, val_loggers=None):\n"," self.train_loggers = train_loggers if train_loggers else []\n"," self.val_loggers = val_loggers if val_loggers else []\n","\n"," def complete(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.complete(**log_data)\n"," for logger in self.val_loggers:\n"," logger.complete(**log_data)\n","\n"," def log_train(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.log(**log_data)\n","\n"," def log_val(self, log_data):\n"," for logger in self.val_loggers:\n"," logger.log(**log_data)\n","\n","\n","class AbstractBaseLogger(metaclass=ABCMeta):\n"," @abstractmethod\n"," def log(self, *args, **kwargs):\n"," raise NotImplementedError\n","\n"," def complete(self, *args, **kwargs):\n"," pass\n","\n","\n","class RecentModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n"," self.recent_epoch = None\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," epoch = kwargs['epoch']\n","\n"," if self.recent_epoch != epoch:\n"," self.recent_epoch = epoch\n"," state_dict = kwargs['state_dict']\n"," state_dict['epoch'] = kwargs['epoch']\n"," save_state_dict(state_dict, self.checkpoint_path, self.filename)\n","\n"," def complete(self, *args, **kwargs):\n"," 
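# when training completes, persist one final copy of the most recent checkpoint (saved with a '.final' suffix)\n","        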
save_state_dict(kwargs['state_dict'],\n"," self.checkpoint_path, self.filename + '.final')\n","\n","\n","class BestModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n","\n"," self.best_metric = 0.\n"," self.metric_key = metric_key\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," current_metric = kwargs[self.metric_key]\n"," if self.best_metric < current_metric:\n"," print(\"Update Best {} Model at {}\".format(\n"," self.metric_key, kwargs['epoch']))\n"," self.best_metric = current_metric\n"," save_state_dict(kwargs['state_dict'],\n"," self.checkpoint_path, self.filename)\n","\n","\n","class MetricGraphPrinter(AbstractBaseLogger):\n"," def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):\n"," self.key = key\n"," self.graph_label = graph_name\n"," self.group_name = group_name\n"," self.writer = writer\n","\n"," def log(self, *args, **kwargs):\n"," if self.key in kwargs:\n"," self.writer.add_scalar(\n"," self.group_name + '/' + self.graph_label, kwargs[self.key], kwargs['accum_iter'])\n"," else:\n"," self.writer.add_scalar(\n"," self.group_name + '/' + self.graph_label, 0, kwargs['accum_iter'])\n","\n"," def complete(self, *args, **kwargs):\n"," self.writer.close()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DUM7N4K9oWGb"},"source":["#### Retrainer"]},{"cell_type":"code","metadata":{"id":"h-7WBljGoWAG"},"source":["import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","from torch.optim.lr_scheduler import LambdaLR\n","# from torch.autograd.gradcheck import zero_gradients\n","from torch.utils.tensorboard import SummaryWriter\n","from tqdm import tqdm\n","\n","import json\n","import math\n","import faiss\n","import numpy as np\n","from abc import *\n","from pathlib import Path\n","\n","\n","class PoisonedGroupRetrainer(metaclass=ABCMeta):\n"," def __init__(self, args, wb_model_spec, wb_model, bb_model, original_test_loader, bb_model_root=None):\n"," self.args = args\n"," self.device = args.device\n"," self.num_items = args.num_items\n"," self.max_len = args.bert_max_len\n"," self.wb_model_spec = wb_model_spec\n"," self.wb_model = wb_model.to(self.device)\n"," self.bb_model = bb_model.to(self.device)\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.bb_model = nn.DataParallel(self.bb_model)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n"," self.original_test_loader = original_test_loader\n"," if bb_model_root == None:\n"," self.bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n"," else:\n"," self.bb_model_root = bb_model_root\n"," \n"," if isinstance(self.wb_model, BERT):\n"," self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]\n"," else:\n"," self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:]\n"," \n"," self.faiss_index = faiss.IndexFlatL2(self.item_embeddings.shape[-1])\n"," self.faiss_index.add(self.item_embeddings)\n"," self.item_embeddings = torch.tensor(self.item_embeddings).to(self.device)\n","\n"," self.CLOZE_MASK_TOKEN = args.num_items + 1\n"," self.adv_ce = nn.CrossEntropyLoss(ignore_index=0)\n"," if isinstance(self.bb_model, 
BERT) or isinstance(self.bb_model, NARM):\n"," self.ce = nn.CrossEntropyLoss(ignore_index=0)\n"," elif isinstance(self.bb_model, SASRec):\n"," self.ce = nn.BCEWithLogitsLoss()\n","\n","\n"," def train_ours(self, targets, ratio, popular_items, num_items):\n"," num_poisoned, num_original, poisoning_users = self.generate_poisoned_data(targets, popular_items, num_items)\n"," target_spec = '_'.join([str(target) for target in targets])\n"," self.train_loader, self.val_loader, self.test_loader = poi_train_loader_factory(self.args, target_spec, self.wb_model_spec, num_poisoned, num_original)\n"," self.bb_model.load_state_dict(torch.load(os.path.join(self.bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," self.export_root = 'experiments/retrained/' + self.wb_model_spec + '/' + self.args.dataset_code + '/ratio_' + str(ratio) + '_target_' + target_spec\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.logger_service = LoggerService(\n"," self.train_loggers, self.val_loggers)\n"," self.log_period_as_iter = self.args.log_period_as_iter\n"," metrics_before, metrics_after = self.train(targets)\n"," \n"," return metrics_before, metrics_after\n","\n","\n"," def generate_poisoned_data(self, targets, popular_items, num_items, batch_size=50, sample_prob=0.0):\n"," print('## Generate Biased Data with Target {} ##'.format(targets))\n"," target_spec = '_'.join([str(target) for target in targets])\n"," dataset = poi_dataset_factory(self.args, target_spec, self.wb_model_spec)\n"," # if dataset.check_data_present():\n"," # print('Dataset already exists. Skip generation')\n"," # return\n","\n"," if isinstance(self.wb_model, BERT):\n"," self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]\n"," else:\n"," self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:]\n"," self.item_embeddings = torch.tensor(self.item_embeddings).to(self.device)\n"," \n"," batch_num = math.ceil(self.args.num_poisoned_seqs / batch_size)\n"," print('Generating poisoned dataset...')\n"," for i in tqdm(range(batch_num)):\n"," if i == batch_num - 1 and self.args.num_poisoned_seqs % batch_size != 0:\n"," batch_size = self.args.num_poisoned_seqs % batch_size\n"," seqs = torch.tensor(np.random.choice(targets, size=batch_size)).reshape(batch_size, 1).to(self.device)\n","\n"," for j in range(self.max_len - 1):\n"," self.wb_model.eval()\n"," \n"," if j % 2 == 0:\n"," selected_targets = torch.tensor(np.random.choice(targets, size=batch_size)).to(self.device)\n"," rand_items = torch.tensor(np.random.choice(self.num_items, size=seqs.size(0))+1).to(self.device)\n"," seqs = torch.cat((seqs, rand_items.unsqueeze(1)), 1)\n","\n"," if isinstance(self.wb_model, BERT):\n"," mask_items = torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).to(self.device)\n"," input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)\n"," if j < self.max_len - 2:\n"," input_seqs[:, (self.max_len-3-j):-1] = seqs\n"," elif j == self.max_len - 2:\n"," input_seqs[:, :-1] = seqs[:, 1:]\n"," input_seqs[:, -1] = mask_items\n"," wb_embedding, mask = self.wb_model.embedding(input_seqs.long())\n"," elif isinstance(self.wb_model, SASRec):\n"," input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)\n"," input_seqs[:, (self.max_len-2-j):] = seqs\n"," wb_embedding, mask = self.wb_model.embedding(input_seqs.long())\n"," elif isinstance(self.wb_model, NARM):\n"," input_seqs = seqs\n"," lengths = torch.tensor([j + 2] * 
seqs.size(0))\n"," wb_embedding, mask = self.wb_model.embedding(input_seqs, lengths)\n","\n"," self.wb_model.train()\n"," wb_embedding = wb_embedding.detach().clone()\n"," wb_embedding.requires_grad = True\n"," zero_gradients(wb_embedding)\n","\n"," if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):\n"," wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, mask)[:, -1, :]\n"," elif isinstance(self.wb_model, NARM):\n"," wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, lengths, mask)\n","\n"," loss = self.adv_ce(wb_scores, selected_targets)\n"," self.wb_model.zero_grad()\n"," loss.backward()\n"," wb_embedding_grad = wb_embedding.grad.data\n"," \n"," self.wb_model.eval()\n"," with torch.no_grad():\n"," if isinstance(self.wb_model, BERT):\n"," current_embedding = wb_embedding[:, -2]\n"," current_embedding_grad = wb_embedding_grad[:, -2]\n"," else:\n"," current_embedding = wb_embedding[:, -1]\n"," current_embedding_grad = wb_embedding_grad[:, -1]\n"," \n"," all_embeddings = self.item_embeddings.unsqueeze(1).repeat_interleave(current_embedding.size(0), 1)\n"," cos = nn.CosineSimilarity(dim=-1, eps=1e-6)\n"," multipication_results = torch.t(cos(current_embedding-current_embedding_grad.sign(), all_embeddings))\n"," multipication_results[torch.arange(seqs.size(0)), selected_targets-1] = multipication_results[torch.arange(seqs.size(0)), selected_targets-1] + 2\n"," \n"," _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=False)\n"," sample_indices = torch.randint(0, 10, [seqs.size(0)])\n"," seqs[:, -1] = candidate_indicies[torch.arange(seqs.size(0)), sample_indices] + 1\n"," seqs = torch.cat((seqs, selected_targets.unsqueeze(1)), 1)\n"," \n"," seqs = seqs[:, :self.max_len]\n"," try:\n"," batch_tokens = np.concatenate((batch_tokens, seqs.cpu().numpy()))\n"," except:\n"," batch_tokens = seqs.cpu().numpy()\n","\n"," num_poisoned, num_original, poisoning_users = dataset.save_dataset(batch_tokens.tolist(), original_dataset_size=self.args.num_original_seqs)\n"," return num_poisoned, num_original, poisoning_users\n","\n"," def train(self, targets):\n"," self.optimizer = self._create_optimizer()\n"," if self.args.enable_lr_schedule:\n"," if self.args.enable_lr_warmup:\n"," self.lr_scheduler = self.get_linear_schedule_with_warmup(\n"," self.optimizer, self.args.warmup_steps, len(train_loader) * self.num_epochs)\n"," else:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(\n"," self.optimizer, step_size=self.args.decay_step, gamma=self.args.gamma)\n","\n"," print('## Biased Retrain on Item {} ##'.format(targets))\n"," accum_iter = 0\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter) \n"," \n"," print('## Clean Black-Box Model Targeted Test on Item {} ##'.format(targets))\n"," metrics_before = self.targeted_test(targets, load_retrained=False)\n"," print('## Retrained Black-Box Model Targeted Test on Item {} ##'.format(targets))\n"," metrics_after = self.targeted_test(targets, load_retrained=True)\n"," \n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," return metrics_before, metrics_after\n","\n"," def train_one_epoch(self, epoch, accum_iter):\n"," self.bb_model.train()\n"," average_meter_set = AverageMeterSet()\n"," tqdm_dataloader = tqdm(self.train_loader)\n","\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," self.optimizer.zero_grad()\n"," if 
isinstance(self.bb_model, BERT):\n"," seqs, labels = batch\n"," seqs, labels = seqs.to(self.device), labels.to(self.device)\n"," logits = self.bb_model(seqs)\n"," logits = logits.view(-1, logits.size(-1))\n"," labels = labels.view(-1)\n"," loss = self.ce(logits, labels)\n"," elif isinstance(self.bb_model, SASRec):\n"," seqs, labels, negs = batch\n"," seqs, labels, negs = seqs.to(self.device), labels.to(self.device), negs.to(self.device)\n"," logits = self.bb_model(seqs) # F.softmax(self.bb_model(seqs), dim=-1)\n"," pos_logits = logits.gather(-1, labels.unsqueeze(-1))[seqs > 0].squeeze()\n"," pos_targets = torch.ones_like(pos_logits)\n"," neg_logits = logits.gather(-1, negs.unsqueeze(-1))[seqs > 0].squeeze()\n"," neg_targets = torch.zeros_like(neg_logits)\n"," loss = self.ce(torch.cat((pos_logits, neg_logits), 0), torch.cat((pos_targets, neg_targets), 0))\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, labels = batch\n"," lengths = lengths.flatten()\n"," seqs, labels = seqs.to(self.device), labels.to(self.device)\n"," logits = self.bb_model(seqs, lengths)\n"," loss = self.ce(logits, labels.squeeze())\n","\n"," loss.backward()\n"," self.clip_gradients(5)\n"," self.optimizer.step()\n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n","\n"," average_meter_set.update('loss', loss.item())\n"," tqdm_dataloader.set_description(\n"," 'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))\n","\n"," accum_iter += seqs.size(0)\n","\n"," if self._needs_to_log(accum_iter):\n"," tqdm_dataloader.set_description('Logging to Tensorboard')\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch + 1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.logger_service.log_train(log_data)\n"," \n"," self.validate(epoch, accum_iter)\n"," return accum_iter\n","\n"," def validate(self, epoch, accum_iter):\n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," metrics = self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," # self.log_extra_val_info(log_data)\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self, load_retrained=False):\n"," if load_retrained:\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(best_model_dict)\n"," else:\n"," bb_model_dict = torch.load(os.path.join(\n"," self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(bb_model_dict)\n","\n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.original_test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," metrics = self.calculate_metrics(batch)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') 
as f:\n"," json.dump(average_metrics, f, indent=4)\n"," return average_metrics\n","\n"," def targeted_test(self, targets, load_retrained=False):\n"," if load_retrained:\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(best_model_dict)\n"," else:\n"," bb_model_dict = torch.load(os.path.join(\n"," self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(bb_model_dict)\n"," \n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.original_test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n"," \n"," for target in targets:\n"," candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n"," self._update_meter_set(average_meter_set, metrics)\n"," \n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," return average_metrics\n"," \n"," def targeted_test_item(self, targets, load_retrained=False):\n"," if load_retrained:\n"," best_model_dict = torch.load(os.path.join(\n"," self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(best_model_dict)\n"," else:\n"," bb_model_dict = torch.load(os.path.join(\n"," self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n"," self.bb_model.load_state_dict(bb_model_dict)\n"," \n"," self.bb_model.eval()\n"," average_meter_set = AverageMeterSet()\n"," item_average_meter_set = {target: AverageMeterSet() for target in targets}\n"," \n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.original_test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n"," \n"," for target in targets:\n"," candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)\n"," metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)\n"," self._update_meter_set(average_meter_set, metrics)\n"," self._update_meter_set(item_average_meter_set[target], metrics)\n"," \n"," self._update_dataloader_metrics(\n"," tqdm_dataloader, average_meter_set)\n","\n"," average_metrics = average_meter_set.averages()\n"," for target in targets:\n"," item_average_meter_set[target] = 
item_average_meter_set[target].averages()\n"," return average_metrics, item_average_meter_set\n","\n"," def calculate_metrics(self, batch):\n"," self.bb_model.eval()\n","\n"," if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):\n"," seqs, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," scores = self.bb_model(seqs)[:, -1, :]\n"," elif isinstance(self.bb_model, NARM):\n"," seqs, lengths, candidates, labels = batch\n"," seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)\n"," lengths = lengths.flatten()\n"," scores = self.bb_model(seqs, lengths)\n","\n"," scores = scores.gather(1, candidates) # B x C\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics\n","\n"," def clip_gradients(self, limit=5):\n"," for p in self.bb_model.parameters():\n"," nn.utils.clip_grad_norm_(p, 5)\n","\n"," def _update_meter_set(self, meter_set, metrics):\n"," for k, v in metrics.items():\n"," meter_set.update(k, v)\n","\n"," def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n"," ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Eval: ' + \\\n"," ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(\n"," *(meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," param_optimizer = list(self.bb_model.named_parameters())\n"," no_decay = ['bias', 'layer_norm']\n"," optimizer_grouped_parameters = [\n"," {\n"," 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n"," 'weight_decay': args.weight_decay,\n"," },\n"," {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n"," ]\n"," if args.optimizer.lower() == 'adamw':\n"," return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n"," elif args.optimizer.lower() == 'adam':\n"," return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n"," # based on hugging face get_linear_schedule_with_warmup\n"," def lr_lambda(current_step: int):\n"," if current_step < num_warmup_steps:\n"," return float(current_step) / float(max(1, num_warmup_steps))\n"," return max(\n"," 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n"," )\n","\n"," return LambdaLR(optimizer, lr_lambda, last_epoch)\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch',\n"," graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, key='loss',\n"," graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, 
key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(\n"," model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.bb_model.module.state_dict() if self.is_parallel else self.bb_model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_FAMzmuYp6_j"},"source":["#### Utils"]},{"cell_type":"code","metadata":{"id":"h1pWOazIoV5Q"},"source":["import json\n","import os\n","import pprint as pp\n","import random\n","from datetime import date\n","from pathlib import Path\n","\n","import numpy as np\n","import torch\n","import torch.backends.cudnn as cudnn\n","from torch import optim as optim\n","\n","\n","def recall(scores, labels, k):\n"," scores = scores\n"," labels = labels\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank[:, :k]\n"," hit = labels.gather(1, cut)\n"," return (hit.sum(1).float() / torch.min(torch.Tensor([k]).to(hit.device), labels.sum(1).float())).mean().cpu().item()\n","\n","\n","def ndcg(scores, labels, k):\n"," scores = scores.cpu()\n"," labels = labels.cpu()\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank[:, :k]\n"," hits = labels.gather(1, cut)\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits.float() * weights).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n"," for n in labels.sum(1)])\n"," ndcg = dcg / idcg\n"," return ndcg.mean()\n","\n","\n","def recalls_and_ndcgs_for_ks(scores, labels, ks):\n"," metrics = {}\n","\n"," scores = scores\n"," labels = labels\n"," answer_count = labels.sum(1)\n","\n"," labels_float = labels.float()\n"," rank = (-scores).argsort(dim=1)\n","\n"," cut = rank\n"," for k in sorted(ks, reverse=True):\n"," cut = cut[:, :k]\n"," hits = labels_float.gather(1, cut)\n"," metrics['Recall@%d' % k] = \\\n"," (hits.sum(1) / torch.min(torch.Tensor([k]).to(\n"," labels.device), labels.sum(1).float())).mean().cpu().item()\n","\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits * weights.to(hits.device)).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n"," for n in answer_count]).to(dcg.device)\n"," ndcg = (dcg / idcg).mean()\n"," metrics['NDCG@%d' % k] = ndcg.cpu().item()\n","\n"," return metrics\n","\n","\n","def setup_train(args):\n"," set_up_gpu(args)\n","\n"," export_root = create_experiment_export_folder(args)\n"," export_experiments_config_as_json(args, export_root)\n","\n"," pp.pprint({k: v for k, v in vars(args).items() if v is not None}, width=1)\n"," return export_root\n","\n","\n","def create_experiment_export_folder(args):\n"," experiment_dir, experiment_description = args.experiment_dir, args.experiment_description\n"," if not os.path.exists(experiment_dir):\n"," os.mkdir(experiment_dir)\n"," experiment_path = get_name_of_experiment_path(\n"," experiment_dir, experiment_description)\n"," os.mkdir(experiment_path)\n"," print('Folder created: ' + os.path.abspath(experiment_path))\n"," return 
experiment_path\n","\n","\n","def get_name_of_experiment_path(experiment_dir, experiment_description):\n"," experiment_path = os.path.join(\n"," experiment_dir, (experiment_description + \"_\" + str(date.today())))\n"," idx = _get_experiment_index(experiment_path)\n"," experiment_path = experiment_path + \"_\" + str(idx)\n"," return experiment_path\n","\n","\n","def _get_experiment_index(experiment_path):\n"," idx = 0\n"," while os.path.exists(experiment_path + \"_\" + str(idx)):\n"," idx += 1\n"," return idx\n","\n","\n","def load_weights(model, path):\n"," pass\n","\n","\n","def save_test_result(export_root, result):\n"," filepath = Path(export_root).joinpath('test_result.txt')\n"," with filepath.open('w') as f:\n"," json.dump(result, f, indent=2)\n","\n","\n","def export_experiments_config_as_json(args, experiment_path):\n"," with open(os.path.join(experiment_path, 'config.json'), 'w') as outfile:\n"," json.dump(vars(args), outfile, indent=2)\n","\n","\n","def fix_random_seed_as(random_seed):\n"," random.seed(random_seed)\n"," torch.manual_seed(random_seed)\n"," torch.cuda.manual_seed_all(random_seed)\n"," np.random.seed(random_seed)\n"," cudnn.deterministic = True\n"," cudnn.benchmark = False\n","\n","\n","def set_up_gpu(args):\n"," os.environ['CUDA_VISIBLE_DEVICES'] = args.device_idx\n"," args.num_gpu = len(args.device_idx.split(\",\"))\n","\n","\n","def load_pretrained_weights(model, path):\n"," chk_dict = torch.load(os.path.abspath(path))\n"," model_state_dict = chk_dict[STATE_DICT_KEY] if STATE_DICT_KEY in chk_dict else chk_dict['state_dict']\n"," model.load_state_dict(model_state_dict)\n","\n","\n","def setup_to_resume(args, model, optimizer):\n"," chk_dict = torch.load(os.path.join(os.path.abspath(\n"," args.resume_training), 'models/checkpoint-recent.pth'))\n"," model.load_state_dict(chk_dict[STATE_DICT_KEY])\n"," optimizer.load_state_dict(chk_dict[OPTIMIZER_STATE_DICT_KEY])\n","\n","\n","def create_optimizer(model, args):\n"," if args.optimizer == 'Adam':\n"," return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n","\n"," return optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n","\n","\n","class AverageMeterSet(object):\n"," def __init__(self, meters=None):\n"," self.meters = meters if meters else {}\n","\n"," def __getitem__(self, key):\n"," if key not in self.meters:\n"," meter = AverageMeter()\n"," meter.update(0)\n"," return meter\n"," return self.meters[key]\n","\n"," def update(self, name, value, n=1):\n"," if name not in self.meters:\n"," self.meters[name] = AverageMeter()\n"," self.meters[name].update(value, n)\n","\n"," def reset(self):\n"," for meter in self.meters.values():\n"," meter.reset()\n","\n"," def values(self, format_string='{}'):\n"," return {format_string.format(name): meter.val for name, meter in self.meters.items()}\n","\n"," def averages(self, format_string='{}'):\n"," return {format_string.format(name): meter.avg for name, meter in self.meters.items()}\n","\n"," def sums(self, format_string='{}'):\n"," return {format_string.format(name): meter.sum for name, meter in self.meters.items()}\n","\n"," def counts(self, format_string='{}'):\n"," return {format_string.format(name): meter.count for name, meter in self.meters.items()}\n","\n","\n","class AverageMeter(object):\n"," \"\"\"Computes and stores the average and current value\"\"\"\n","\n"," def __init__(self):\n"," self.val = 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def reset(self):\n"," self.val 
= 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def update(self, val, n=1):\n"," self.val = val\n"," self.sum += val\n"," self.count += n\n"," self.avg = self.sum / self.count\n","\n"," def __format__(self, format):\n"," return \"{self.val:{format}} ({self.avg:{format}})\".format(self=self, format=format)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"o1MEoBe8qr9_"},"source":["### Attacker"]},{"cell_type":"code","metadata":{"id":"qPkfedrZoVTl"},"source":["import argparse\n","import torch\n","from pathlib import Path\n","from collections import defaultdict\n","\n","\n","def attack(args, attack_item_num=2, bb_model_root=None):\n"," fix_random_seed_as(args.model_init_seed)\n"," _, _, test_loader = dataloader_factory(args)\n","\n"," model_codes = {'b': 'bert', 's':'sas', 'n':'narm'}\n"," wb_model_code = model_codes[input('Input white box model code, b for BERT, s for SASRec and n for NARM: ')]\n","\n"," wb_model_folder = {}\n"," folder_list = [item for item in os.listdir('experiments/distillation_rank/') if (args.model_code + '2' + wb_model_code in item)]\n"," for idx, folder_name in enumerate(folder_list):\n"," wb_model_folder[idx + 1] = folder_name\n"," wb_model_folder[idx + 2] = args.model_code + '_black_box'\n"," print(wb_model_folder)\n"," wb_model_spec = wb_model_folder[int(input('Input index of desired white box model: '))]\n","\n"," wb_model_root = 'experiments/distillation_rank/' + wb_model_spec + '/' + args.dataset_code\n"," if wb_model_spec == args.model_code + '_black_box':\n"," wb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n","\n"," if bb_model_root == None:\n"," bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n","\n"," if wb_model_code == 'bert':\n"," wb_model = BERT(args)\n"," elif wb_model_code == 'sas':\n"," wb_model = SASRec(args)\n"," elif wb_model_code == 'narm':\n"," wb_model = NARM(args)\n","\n"," if args.model_code == 'bert':\n"," bb_model = BERT(args)\n"," elif args.model_code == 'sas':\n"," bb_model = SASRec(args)\n"," elif args.model_code == 'narm':\n"," bb_model = NARM(args)\n"," \n"," bb_model.load_state_dict(torch.load(os.path.join(bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," wb_model.load_state_dict(torch.load(os.path.join(wb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n","\n"," item_counter = defaultdict(int)\n"," dataset = dataset_factory(args)\n"," dataset = dataset.load_dataset()\n"," train = dataset['train']\n"," val = dataset['val']\n"," test = dataset['test']\n"," for user in train.keys():\n"," seqs = train[user] + val[user] + test[user]\n"," for i in seqs:\n"," item_counter[i] += 1\n","\n"," item_popularity = []\n"," for i in item_counter.keys():\n"," item_popularity.append((item_counter[i], i))\n"," item_popularity.sort(reverse=True)\n"," \n"," attacker = AdversarialRankAttacker(args, wb_model, bb_model, test_loader)\n"," \n"," item_id = []\n"," item_rank = []\n"," item_R1_before, item_R5_before, item_R10_before, item_N5_before, item_N10_before = [], [], [], [], []\n"," item_R1_ours, item_R5_ours, item_R10_ours, item_N5_ours, item_N10_ours = [], [], [], [], []\n"," \n"," step = len(item_popularity) // 25\n"," attack_ranks = list(range(0, len(item_popularity), step))[:25]\n"," for i in attack_ranks:\n"," item = item_popularity[i][1]\n"," metrics_before = attacker.test(target=item)\n"," metrics_ours = attacker.attack(target=item, 
num_attack=attack_item_num)\n"," \n"," item_id.append(item)\n"," item_rank.append(i)\n"," item_R1_before.append(metrics_before['Recall@1'])\n"," item_R5_before.append(metrics_before['Recall@5'])\n"," item_R10_before.append(metrics_before['Recall@10'])\n"," item_N5_before.append(metrics_before['NDCG@5'])\n"," item_N10_before.append(metrics_before['NDCG@10'])\n","\n"," item_R1_ours.append(metrics_ours['Recall@1'])\n"," item_R5_ours.append(metrics_ours['Recall@5'])\n"," item_R10_ours.append(metrics_ours['Recall@10'])\n"," item_N5_ours.append(metrics_ours['NDCG@5'])\n"," item_N10_ours.append(metrics_ours['NDCG@10'])\n","\n"," attack_metrics = {\n"," 'item_id': item_id,\n"," 'item_rank': item_rank,\n"," 'item_R1_before': item_R1_before,\n"," 'item_R5_before': item_R5_before,\n"," 'item_R10_before': item_R10_before,\n"," 'item_N5_before': item_N5_before,\n"," 'item_N10_before': item_N10_before,\n","\n"," 'item_R1_ours': item_R1_ours,\n"," 'item_R5_ours': item_R5_ours,\n"," 'item_R10_ours': item_R10_ours,\n"," 'item_N5_ours': item_N5_ours,\n"," 'item_N10_ours': item_N10_ours,\n"," }\n"," \n"," metrics_root = 'experiments/attack_rank/' + wb_model_spec + '/' + args.dataset_code\n"," if not Path(metrics_root).is_dir():\n"," Path(metrics_root).mkdir(parents=True)\n"," \n"," with open(os.path.join(metrics_root, 'attack_bb_metrics.json'), 'w') as f:\n"," json.dump(attack_metrics, f, indent=4)\n","\n","\n","if __name__ == \"__main__\":\n"," set_template(args)\n","\n"," # when use k-core beauty and k is not 5 (beauty-dense)\n"," # args.min_uc = k\n"," # args.min_sc = k\n","\n"," if args.dataset_code == 'ml-1m':\n"," args.num_epochs = 5\n"," attack(args=args, attack_item_num=1)\n"," else:\n"," attack(args=args, attack_item_num=2)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OZL5lYRlwhYb"},"source":["## Data Poisoning Attack"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Ka-49FnUwUoc","executionInfo":{"status":"ok","timestamp":1631631025244,"user_tz":-330,"elapsed":140114,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"7954869e-cda3-48a9-d704-8fb98a433afb"},"source":["import argparse\n","import torch\n","import copy\n","from pathlib import Path\n","from collections import defaultdict\n","\n","\n","def retrain(args, bb_model_root=None):\n"," fix_random_seed_as(args.model_init_seed)\n"," _, _, test_loader = dataloader_factory(args)\n","\n"," model_codes = {'b': 'bert', 's':'sas', 'n':'narm'}\n"," wb_model_code = model_codes[input('Input white box model code, b for BERT, s for SASRec and n for NARM: ')]\n","\n"," wb_model_folder = {}\n"," folder_list = [item for item in os.listdir('experiments/distillation_rank/') if (args.model_code + '2' + wb_model_code in item)]\n"," for idx, folder_name in enumerate(folder_list):\n"," wb_model_folder[idx + 1] = folder_name\n"," wb_model_folder[idx + 2] = args.model_code + '_black_box'\n"," print(wb_model_folder)\n"," wb_model_spec = wb_model_folder[int(input('Input index of desired white box model: '))]\n","\n"," wb_model_root = 'experiments/distillation_rank/' + wb_model_spec + '/' + args.dataset_code\n"," if wb_model_spec == args.model_code + '_black_box':\n"," wb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n","\n"," if bb_model_root == None:\n"," bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code\n","\n"," if args.model_code == 'bert':\n"," 
bb_model = BERT(args)\n"," elif args.model_code == 'sas':\n"," bb_model = SASRec(args)\n"," elif args.model_code == 'narm':\n"," bb_model = NARM(args)\n"," \n"," if wb_model_code == 'bert':\n"," wb_model = BERT(args)\n"," elif wb_model_code == 'sas':\n"," wb_model = SASRec(args)\n"," elif wb_model_code == 'narm':\n"," wb_model = NARM(args)\n"," \n"," item_counter = defaultdict(int)\n"," dataset = dataset_factory(args)\n"," dataset = dataset.load_dataset()\n"," train = dataset['train']\n"," val = dataset['val']\n"," test = dataset['test']\n"," lengths = []\n"," for user in train.keys():\n"," seqs = train[user] + val[user] + test[user]\n"," lengths.append(len(seqs))\n"," for i in seqs:\n"," item_counter[i] += 1\n","\n"," item_popularity = []\n"," for i in item_counter.keys():\n"," item_popularity.append((item_counter[i], i))\n"," item_popularity.sort(reverse=True)\n","\n"," wb_model.load_state_dict(torch.load(os.path.join(wb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))\n"," \n"," step = len(item_popularity) // 25\n"," popular_items = [item_popularity[i][1] for i in range(int(0.05*len(item_popularity)))]\n"," attack_ranks = list(range(0, len(item_popularity), step))[:25]\n"," targets = [item_popularity[i][1] for i in attack_ranks]\n","\n"," bb_poisoned_metrics = {}\n"," all_ratios = [0.01]\n"," for ratio in all_ratios: \n"," args.num_poisoned_seqs = int(len(train) * ratio)\n"," retrainer = PoisonedGroupRetrainer(args, wb_model_spec, wb_model, bb_model, test_loader)\n"," metrics_before, metrics_bb_after = retrainer.train_ours(targets, ratio, popular_items, int(0.05*len(item_popularity)))\n","\n"," bb_poisoned_metrics[ratio] = {\n"," 'before': metrics_before,\n"," 'ours': metrics_bb_after, \n"," }\n"," \n"," metrics_root = 'experiments/retrained/' + wb_model_spec + '/' + args.dataset_code\n"," if not Path(metrics_root).is_dir():\n"," Path(metrics_root).mkdir(parents=True)\n","\n"," with open(os.path.join(metrics_root, 'retrained_bb_metrics.json'), 'w') as f:\n"," json.dump(bb_poisoned_metrics, f, indent=4)\n","\n","\n","if __name__ == \"__main__\":\n"," set_template(args)\n","\n"," # when use k-core beauty and k is not 5 (beauty-dense)\n"," # args.min_uc = k\n"," # args.min_sc = k\n"," args.num_epochs = 5\n"," retrain(args=args)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Already preprocessed. Skip preprocessing\n","Negatives samples exist. Loading.\n","Negatives samples exist. Loading.\n","Input white box model code, b for BERT, s for SASRec and n for NARM: b\n","{1: 'bert2bert_autoregressive5', 2: 'bert_black_box'}\n","Input index of desired white box model: 1\n","Already preprocessed. Skip preprocessing\n","## Generate Biased Data with Target [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##\n","Generating poisoned dataset...\n"]},{"output_type":"stream","name":"stderr","text":["100%|██████████| 2/2 [00:07<00:00, 3.96s/it]\n"]},{"output_type":"stream","name":"stdout","text":["Already preprocessed. Skip preprocessing\n","Negatives samples exist. Loading.\n","Negatives samples exist. 
Loading.\n","## Biased Retrain on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 1, loss 7.003 : 100%|██████████| 69/69 [00:12<00:00, 5.48it/s]\n","Eval: N@1 0.078, N@5 0.175, N@10 0.225, R@1 0.078, R@5 0.270, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 13.98it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 1\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 2, loss 6.993 : 100%|██████████| 69/69 [00:12<00:00, 5.44it/s]\n","Eval: N@1 0.080, N@5 0.178, N@10 0.225, R@1 0.080, R@5 0.273, R@10 0.420: 100%|██████████| 48/48 [00:03<00:00, 13.96it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 2\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 3, loss 6.988 : 100%|██████████| 69/69 [00:12<00:00, 5.48it/s]\n","Eval: N@1 0.085, N@5 0.179, N@10 0.230, R@1 0.085, R@5 0.269, R@10 0.426: 100%|██████████| 48/48 [00:03<00:00, 14.07it/s]\n"]},{"output_type":"stream","name":"stdout","text":["Update Best NDCG@10 Model at 3\n"]},{"output_type":"stream","name":"stderr","text":["Epoch 4, loss 6.986 : 100%|██████████| 69/69 [00:12<00:00, 5.49it/s]\n","Eval: N@1 0.083, N@5 0.180, N@10 0.229, R@1 0.083, R@5 0.275, R@10 0.427: 100%|██████████| 48/48 [00:03<00:00, 13.97it/s]\n","Epoch 5, loss 6.981 : 100%|██████████| 69/69 [00:12<00:00, 5.51it/s]\n","Eval: N@1 0.078, N@5 0.179, N@10 0.227, R@1 0.078, R@5 0.274, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 13.97it/s]\n"]},{"output_type":"stream","name":"stdout","text":["## Clean Black-Box Model Targeted Test on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##\n"]},{"output_type":"stream","name":"stderr","text":["Eval: N@1 0.041, N@5 0.062, N@10 0.078, R@1 0.041, R@5 0.086, R@10 0.136: 100%|██████████| 48/48 [00:19<00:00, 2.45it/s]\n"]},{"output_type":"stream","name":"stdout","text":["## Retrained Black-Box Model Targeted Test on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##\n"]},{"output_type":"stream","name":"stderr","text":["Eval: N@1 0.044, N@5 0.066, N@10 0.080, R@1 0.044, R@5 0.089, R@10 0.134: 100%|██████████| 48/48 [00:19<00:00, 2.43it/s]\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"uu7aVgC-yAmU","executionInfo":{"status":"ok","timestamp":1631631031374,"user_tz":-330,"elapsed":4447,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a41324de-57ec-43ca-d2ad-0448b8292b2a"},"source":["!apt-get install tree"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Reading package lists... Done\n","Building dependency tree \n","Reading state information... Done\n","The following NEW packages will be installed:\n"," tree\n","0 upgraded, 1 newly installed, 0 to remove and 40 not upgraded.\n","Need to get 40.7 kB of archives.\n","After this operation, 105 kB of additional disk space will be used.\n","Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 tree amd64 1.7.0-5 [40.7 kB]\n","Fetched 40.7 kB in 0s (110 kB/s)\n","Selecting previously unselected package tree.\n","(Reading database ... 
148560 files and directories currently installed.)\n","Preparing to unpack .../tree_1.7.0-5_amd64.deb ...\n","Unpacking tree (1.7.0-5) ...\n","Setting up tree (1.7.0-5) ...\n","Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"BhQMt_2EyC5_","executionInfo":{"status":"ok","timestamp":1631631031375,"user_tz":-330,"elapsed":21,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"07645304-878f-4814-dcd6-97b901b04146"},"source":["!tree ."],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":[".\n","├── data\n","│   ├── ml-1m\n","│   │   ├── movies.dat\n","│   │   ├── ratings.dat\n","│   │   ├── README\n","│   │   └── users.dat\n","│   └── preprocessed\n","│   └── ml-1m_min_rating0-min_uc5-min_sc5-splitleave_one_out\n","│   ├── dataset.pkl\n","│   ├── random-sample_size100-seed98765-test.pkl\n","│   └── random-sample_size100-seed98765-val.pkl\n","├── experiments\n","│   ├── bert\n","│   │   └── ml-1m\n","│   │   ├── logs\n","│   │   │   ├── events.out.tfevents.1631626670.6d6b3d5241b9.74.0\n","│   │   │   └── test_metrics.json\n","│   │   └── models\n","│   │   ├── best_acc_model.pth\n","│   │   ├── checkpoint-recent.pth\n","│   │   └── checkpoint-recent.pth.final\n","│   ├── distillation_rank\n","│   │   └── bert2bert_autoregressive5\n","│   │   └── ml-1m\n","│   │   ├── logs\n","│   │   │   ├── events.out.tfevents.1631626837.6d6b3d5241b9.74.1\n","│   │   │   ├── events.out.tfevents.1631627214.6d6b3d5241b9.74.2\n","│   │   │   ├── events.out.tfevents.1631627292.6d6b3d5241b9.74.3\n","│   │   │   ├── events.out.tfevents.1631627333.6d6b3d5241b9.74.4\n","│   │   │   ├── events.out.tfevents.1631627518.6d6b3d5241b9.74.5\n","│   │   │   ├── events.out.tfevents.1631627759.6d6b3d5241b9.74.6\n","│   │   │   └── test_metrics.json\n","│   │   └── models\n","│   │   ├── best_acc_model.pth\n","│   │   ├── checkpoint-recent.pth\n","│   │   └── checkpoint-recent.pth.final\n","│   └── retrained\n","│   └── bert2bert_autoregressive5\n","│   └── ml-1m\n","│   ├── ratio_0.01_target_2459_1009_2135_918_3233_1226_498_2917_1332_3184_264_2490_1696_1448_144_365_1368_2714_1874_3285_2235_3406_3155_1322_2928\n","│   │   ├── logs\n","│   │   │   ├── events.out.tfevents.1631630835.6d6b3d5241b9.74.7\n","│   │   │   └── events.out.tfevents.1631630902.6d6b3d5241b9.74.8\n","│   │   └── models\n","│   │   ├── best_acc_model.pth\n","│   │   ├── checkpoint-recent.pth\n","│   │   └── checkpoint-recent.pth.final\n","│   └── retrained_bb_metrics.json\n","├── gen_data\n","│   └── ml-1m\n","│   ├── bert2bert_autoregressive5_target_2459_1009_2135_918_3233_1226_498_2917_1332_3184_264_2490_1696_1448_144_365_1368_2714_1874_3285_2235_3406_3155_1322_2928_poisoned60_original6040\n","│   │   ├── poisoned_dataset.pkl\n","│   │   ├── random-sample_size100-seed98765-poisoned_test.pkl\n","│   │   └── random-sample_size100-seed98765-poisoned_val.pkl\n","│   └── bert_5\n","│   └── autoregressive_dataset.pkl\n","└── sample_data\n"," ├── anscombe.json\n"," ├── california_housing_test.csv\n"," ├── california_housing_train.csv\n"," ├── mnist_test.csv\n"," ├── mnist_train_small.csv\n"," └── README.md\n","\n","25 directories, 38 files\n"]}]},{"cell_type":"markdown","metadata":{"id":"T37XQPBbwoFo"},"source":["## Performance Evaluation"]},{"cell_type":"markdown","metadata":{"id":"2q2av0161B8G"},"source":["### Black-Box and Extracted 
Models"]},{"cell_type":"markdown","metadata":{"id":"hXr-ammPy3fm"},"source":["![](https://github.com/recohut/recsys-attacks/raw/d7472b7296515249c1bd1bbb8ea0afa9b07f6d9d/docs/_images/T355514_1.png)"]},{"cell_type":"markdown","metadata":{"id":"LeKtV6HV1AJc"},"source":["### Profile Pollution Performance"]},{"cell_type":"markdown","metadata":{"id":"943FB_f-y89T"},"source":["![](https://github.com/recohut/recsys-attacks/raw/d7472b7296515249c1bd1bbb8ea0afa9b07f6d9d/docs/_images/T355514_2.png)"]},{"cell_type":"markdown","metadata":{"id":"MQ3E1dlf1EkF"},"source":["### Data Poisoning Performance"]},{"cell_type":"markdown","metadata":{"id":"ix3whZbby9Wq"},"source":["![](https://github.com/recohut/recsys-attacks/raw/d7472b7296515249c1bd1bbb8ea0afa9b07f6d9d/docs/_images/T355514_3.png)"]}]}