{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-19-bert4rec-movie.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T034923%20%7C%20BERT4Rec%20on%20ML-1m%20in%20PyTorch.ipynb","timestamp":1644651906723}],"collapsed_sections":[],"toc_visible":true,"authorship_tag":"ABX9TyPOMSLwE0ZgkxqH/UfoSvnk"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"xyVv1tr3ZLt5"},"source":["# BERT4Rec on ML-1m in PyTorch\n","\n","In this tutorial, we are building BERT4Rec model in PyTorch and then training it on the movielens 1m dataset."]},{"cell_type":"markdown","metadata":{"id":"aTzyO6mfZ5cQ"},"source":["## Setup"]},{"cell_type":"markdown","metadata":{"id":"YiFQCtLZ63rc"},"source":["### Installations"]},{"cell_type":"code","metadata":{"id":"pwdnWWZnZygU"},"source":["!pip install -q wget"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WauAGrcH65HI"},"source":["### Imports"]},{"cell_type":"code","metadata":{"id":"w2zR_C419Hmy"},"source":["import os\n","import sys\n","import wget\n","import math\n","import json\n","import random\n","import zipfile\n","import shutil\n","import pickle\n","import tempfile\n","from abc import *\n","\n","import numpy as np\n","import pandas as pd\n","import pprint as pp\n","from pathlib import Path\n","from datetime import date\n","from tqdm import tqdm, trange\n","\n","import torch\n","import torch.backends.cudnn as cudnn\n","from torch import optim as optim\n","import torch.utils.data as data_utils\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.utils.tensorboard import SummaryWriter\n","\n","tqdm.pandas()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xmTE2oIn66bq"},"source":["### Params"]},{"cell_type":"code","metadata":{"id":"nfjPTfJaKcUc"},"source":["STATE_DICT_KEY = 'model_state_dict'\n","OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'\n","RAW_DATASET_ROOT_FOLDER = '/content/ml-1m'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"pEDDwB7JBMG7"},"source":["class Args:\n"," mode = 'train'\n"," test_model_path = '/content/models'\n"," # Dataset\n"," dataset_code = 'ml-1m'\n"," min_rating = 0\n"," min_uc = 5\n"," min_sc = 0\n"," split = 'leave_one_out'\n"," dataset_split_seed = 42\n"," eval_set_size = 500\n"," # Dataloader\n"," dataloader_code = 'bert'\n"," dataloader_random_seed = 0.0\n"," train_batch_size = 128\n"," val_batch_size = 128\n"," test_batch_size = 128\n"," # NegativeSampler\n"," train_negative_sampler_code = 'random'\n"," train_negative_sample_size = 0\n"," train_negative_sampling_seed = 0\n"," test_negative_sampler_code = 'random'\n"," test_negative_sample_size = 100\n"," test_negative_sampling_seed = 42\n"," # Trainer\n"," trainer_code = 'bert'\n"," device = 'cuda'\n"," num_gpu = 1\n"," device_idx = '0'\n"," optimizer='Adam'\n"," lr=0.001\n"," weight_decay=0\n"," momentum=None\n"," enable_lr_schedule = True\n"," decay_step=25\n"," gamma=1.0\n"," num_epochs=10\n"," log_period_as_iter=12800\n"," metric_ks=[1, 5, 10, 20, 50, 100]\n"," best_metric='NDCG@10'\n"," find_best_beta=False\n"," total_anneal_steps=2000\n"," anneal_cap=0.2\n"," # Model\n"," model_code='bert'\n"," model_init_seed=0\n"," bert_max_len=100\n"," bert_num_items=None\n"," bert_hidden_units=256\n"," bert_num_blocks=2\n"," bert_num_heads=4\n"," bert_dropout=0.1\n"," bert_mask_prob=0.15\n"," # Experiment\n"," 
experiment_dir='experiments'\n"," experiment_description='test'\n","\n","args = Args()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pQ0_HSBbZ_AL"},"source":["## Utils"]},{"cell_type":"markdown","metadata":{"id":"AGjLR7B27VmM"},"source":["### Basic"]},{"cell_type":"code","metadata":{"id":"66tGWsEfFuJ1"},"source":["def download(url, savepath):\n"," wget.download(url, str(savepath))\n","\n","\n","def unzip(zippath, savepath):\n"," zip = zipfile.ZipFile(zippath)\n"," zip.extractall(savepath)\n"," zip.close()\n","\n","\n","def get_count(tp, id):\n"," groups = tp[[id]].groupby(id, as_index=False)\n"," count = groups.size()\n"," return count"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZxL_DVsw7FKi"},"source":["### Metrics"]},{"cell_type":"code","metadata":{"id":"74kcufQKKpX6"},"source":["def recall(scores, labels, k):\n"," scores = scores\n"," labels = labels\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank[:, :k]\n"," hit = labels.gather(1, cut)\n"," return (hit.sum(1).float() / torch.min(torch.Tensor([k]).to(hit.device), labels.sum(1).float())).mean().cpu().item()\n","\n","\n","def ndcg(scores, labels, k):\n"," scores = scores.cpu()\n"," labels = labels.cpu()\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank[:, :k]\n"," hits = labels.gather(1, cut)\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits.float() * weights).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in labels.sum(1)])\n"," ndcg = dcg / idcg\n"," return ndcg.mean()\n","\n","\n","def recalls_and_ndcgs_for_ks(scores, labels, ks):\n"," metrics = {}\n","\n"," scores = scores\n"," labels = labels\n"," answer_count = labels.sum(1)\n","\n"," labels_float = labels.float()\n"," rank = (-scores).argsort(dim=1)\n"," cut = rank\n"," for k in sorted(ks, reverse=True):\n"," cut = cut[:, :k]\n"," hits = labels_float.gather(1, cut)\n"," metrics['Recall@%d' % k] = \\\n"," (hits.sum(1) / torch.min(torch.Tensor([k]).to(labels.device), labels.sum(1).float())).mean().cpu().item()\n","\n"," position = torch.arange(2, 2+k)\n"," weights = 1 / torch.log2(position.float())\n"," dcg = (hits * weights.to(hits.device)).sum(1)\n"," idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in answer_count]).to(dcg.device)\n"," ndcg = (dcg / idcg).mean()\n"," metrics['NDCG@%d' % k] = ndcg.cpu().item()\n","\n"," return metrics"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QogDOi3E7SQW"},"source":["### Experiment setup"]},{"cell_type":"code","metadata":{"id":"dqWM-UoAK_43"},"source":["def setup_train(args):\n"," set_up_gpu(args)\n","\n"," export_root = create_experiment_export_folder(args)\n"," export_experiments_config_as_json(args, export_root)\n","\n"," pp.pprint({k: v for k, v in vars(args).items() if v is not None}, width=1)\n"," return export_root\n","\n","\n","def create_experiment_export_folder(args):\n"," experiment_dir, experiment_description = args.experiment_dir, args.experiment_description\n"," if not os.path.exists(experiment_dir):\n"," os.mkdir(experiment_dir)\n"," experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)\n"," os.mkdir(experiment_path)\n"," print('Folder created: ' + os.path.abspath(experiment_path))\n"," return experiment_path\n","\n","\n","def get_name_of_experiment_path(experiment_dir, experiment_description):\n"," experiment_path = os.path.join(experiment_dir, (experiment_description + \"_\" + 
str(date.today())))\n"," idx = _get_experiment_index(experiment_path)\n"," experiment_path = experiment_path + \"_\" + str(idx)\n"," return experiment_path\n","\n","\n","def _get_experiment_index(experiment_path):\n"," idx = 0\n"," while os.path.exists(experiment_path + \"_\" + str(idx)):\n"," idx += 1\n"," return idx\n","\n","\n","def load_weights(model, path):\n"," pass\n","\n","\n","def save_test_result(export_root, result):\n"," filepath = Path(export_root).joinpath('test_result.txt')\n"," with filepath.open('w') as f:\n"," json.dump(result, f, indent=2)\n","\n","\n","def export_experiments_config_as_json(args, experiment_path):\n"," with open(os.path.join(experiment_path, 'config.json'), 'w') as outfile:\n"," json.dump(vars(args), outfile, indent=2)\n","\n","\n","def fix_random_seed_as(random_seed):\n"," random.seed(random_seed)\n"," torch.manual_seed(random_seed)\n"," torch.cuda.manual_seed_all(random_seed)\n"," np.random.seed(random_seed)\n"," cudnn.deterministic = True\n"," cudnn.benchmark = False\n","\n","\n","def set_up_gpu(args):\n"," os.environ['CUDA_VISIBLE_DEVICES'] = args.device_idx\n"," args.num_gpu = len(args.device_idx.split(\",\"))\n","\n","\n","def load_pretrained_weights(model, path):\n"," chk_dict = torch.load(os.path.abspath(path))\n"," model_state_dict = chk_dict[STATE_DICT_KEY] if STATE_DICT_KEY in chk_dict else chk_dict['state_dict']\n"," model.load_state_dict(model_state_dict)\n","\n","\n","def setup_to_resume(args, model, optimizer):\n"," chk_dict = torch.load(os.path.join(os.path.abspath(args.resume_training), 'models/checkpoint-recent.pth'))\n"," model.load_state_dict(chk_dict[STATE_DICT_KEY])\n"," optimizer.load_state_dict(chk_dict[OPTIMIZER_STATE_DICT_KEY])\n","\n","\n","def create_optimizer(model, args):\n"," if args.optimizer == 'Adam':\n"," return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n","\n"," return optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n","\n","\n","class AverageMeterSet(object):\n"," def __init__(self, meters=None):\n"," self.meters = meters if meters else {}\n","\n"," def __getitem__(self, key):\n"," if key not in self.meters:\n"," meter = AverageMeter()\n"," meter.update(0)\n"," return meter\n"," return self.meters[key]\n","\n"," def update(self, name, value, n=1):\n"," if name not in self.meters:\n"," self.meters[name] = AverageMeter()\n"," self.meters[name].update(value, n)\n","\n"," def reset(self):\n"," for meter in self.meters.values():\n"," meter.reset()\n","\n"," def values(self, format_string='{}'):\n"," return {format_string.format(name): meter.val for name, meter in self.meters.items()}\n","\n"," def averages(self, format_string='{}'):\n"," return {format_string.format(name): meter.avg for name, meter in self.meters.items()}\n","\n"," def sums(self, format_string='{}'):\n"," return {format_string.format(name): meter.sum for name, meter in self.meters.items()}\n","\n"," def counts(self, format_string='{}'):\n"," return {format_string.format(name): meter.count for name, meter in self.meters.items()}\n","\n","\n","class AverageMeter(object):\n"," \"\"\"Computes and stores the average and current value\"\"\"\n","\n"," def __init__(self):\n"," self.val = 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def reset(self):\n"," self.val = 0\n"," self.avg = 0\n"," self.sum = 0\n"," self.count = 0\n","\n"," def update(self, val, n=1):\n"," self.val = val\n"," self.sum += val\n"," self.count += n\n"," self.avg = self.sum / 
self.count\n","\n"," def __format__(self, format):\n"," return \"{self.val:{format}} ({self.avg:{format}})\".format(self=self, format=format)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7K7DrpeL7Z1f"},"source":["### Logging"]},{"cell_type":"code","metadata":{"id":"5rfOwq6nLLEI"},"source":["def save_state_dict(state_dict, path, filename):\n"," torch.save(state_dict, os.path.join(path, filename))\n","\n","\n","class LoggerService(object):\n"," def __init__(self, train_loggers=None, val_loggers=None):\n"," self.train_loggers = train_loggers if train_loggers else []\n"," self.val_loggers = val_loggers if val_loggers else []\n","\n"," def complete(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.complete(**log_data)\n"," for logger in self.val_loggers:\n"," logger.complete(**log_data)\n","\n"," def log_train(self, log_data):\n"," for logger in self.train_loggers:\n"," logger.log(**log_data)\n","\n"," def log_val(self, log_data):\n"," for logger in self.val_loggers:\n"," logger.log(**log_data)\n","\n","\n","class AbstractBaseLogger(metaclass=ABCMeta):\n"," @abstractmethod\n"," def log(self, *args, **kwargs):\n"," raise NotImplementedError\n","\n"," def complete(self, *args, **kwargs):\n"," pass\n","\n","\n","class RecentModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n"," self.recent_epoch = None\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," epoch = kwargs['epoch']\n","\n"," if self.recent_epoch != epoch:\n"," self.recent_epoch = epoch\n"," state_dict = kwargs['state_dict']\n"," state_dict['epoch'] = kwargs['epoch']\n"," save_state_dict(state_dict, self.checkpoint_path, self.filename)\n","\n"," def complete(self, *args, **kwargs):\n"," save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename + '.final')\n","\n","\n","class BestModelLogger(AbstractBaseLogger):\n"," def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):\n"," self.checkpoint_path = checkpoint_path\n"," if not os.path.exists(self.checkpoint_path):\n"," os.mkdir(self.checkpoint_path)\n","\n"," self.best_metric = 0.\n"," self.metric_key = metric_key\n"," self.filename = filename\n","\n"," def log(self, *args, **kwargs):\n"," current_metric = kwargs[self.metric_key]\n"," if self.best_metric < current_metric:\n"," print(\"Update Best {} Model at {}\".format(self.metric_key, kwargs['epoch']))\n"," self.best_metric = current_metric\n"," save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename)\n","\n","\n","class MetricGraphPrinter(AbstractBaseLogger):\n"," def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):\n"," self.key = key\n"," self.graph_label = graph_name\n"," self.group_name = group_name\n"," self.writer = writer\n","\n"," def log(self, *args, **kwargs):\n"," if self.key in kwargs:\n"," self.writer.add_scalar(self.group_name + '/' + self.graph_label, kwargs[self.key], kwargs['accum_iter'])\n"," else:\n"," self.writer.add_scalar(self.group_name + '/' + self.graph_label, 0, kwargs['accum_iter'])\n","\n"," def complete(self, *args, **kwargs):\n"," self.writer.close()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XcQfTrxmaIFS"},"source":["## 
Dataset"]},{"cell_type":"markdown","metadata":{"id":"KeeF9O8d7dmq"},"source":["### Abstract class"]},{"cell_type":"code","metadata":{"id":"Bg1ChmjtF9Vi"},"source":["class AbstractDataset(metaclass=ABCMeta):\n"," def __init__(self, args):\n"," self.args = args\n"," self.min_rating = args.min_rating\n"," self.min_uc = args.min_uc\n"," self.min_sc = args.min_sc\n"," self.split = args.split\n","\n"," assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @classmethod\n"," def raw_code(cls):\n"," return cls.code()\n","\n"," @classmethod\n"," @abstractmethod\n"," def url(cls):\n"," pass\n","\n"," @classmethod\n"," def is_zipfile(cls):\n"," return True\n","\n"," @classmethod\n"," def zip_file_content_is_folder(cls):\n"," return True\n","\n"," @classmethod\n"," def all_raw_file_names(cls):\n"," return []\n","\n"," @abstractmethod\n"," def load_ratings_df(self):\n"," pass\n","\n"," def load_dataset(self):\n"," self.preprocess()\n"," dataset_path = self._get_preprocessed_dataset_path()\n"," dataset = pickle.load(dataset_path.open('rb'))\n"," return dataset\n","\n"," def preprocess(self):\n"," dataset_path = self._get_preprocessed_dataset_path()\n"," if dataset_path.is_file():\n"," print('Already preprocessed. Skip preprocessing')\n"," return\n"," if not dataset_path.parent.is_dir():\n"," dataset_path.parent.mkdir(parents=True)\n"," self.maybe_download_raw_dataset()\n"," df = self.load_ratings_df()\n"," df = self.make_implicit(df)\n"," df = self.filter_triplets(df)\n"," df, umap, smap = self.densify_index(df)\n"," train, val, test = self.split_df(df, len(umap))\n"," dataset = {'train': train,\n"," 'val': val,\n"," 'test': test,\n"," 'umap': umap,\n"," 'smap': smap}\n"," with dataset_path.open('wb') as f:\n"," pickle.dump(dataset, f)\n","\n"," def maybe_download_raw_dataset(self):\n"," folder_path = self._get_rawdata_folder_path()\n"," if folder_path.is_dir() and\\\n"," all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):\n"," print('Raw data already exists. Skip downloading')\n"," return\n"," print(\"Raw file doesn't exist. 
Downloading...\")\n"," if self.is_zipfile():\n"," tmproot = Path(tempfile.mkdtemp())\n"," tmpzip = tmproot.joinpath('file.zip')\n"," tmpfolder = tmproot.joinpath('folder')\n"," download(self.url(), tmpzip)\n"," unzip(tmpzip, tmpfolder)\n"," if self.zip_file_content_is_folder():\n"," tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])\n"," shutil.move(tmpfolder, folder_path)\n"," shutil.rmtree(tmproot)\n"," print()\n"," else:\n"," tmproot = Path(tempfile.mkdtemp())\n"," tmpfile = tmproot.joinpath('file')\n"," download(self.url(), tmpfile)\n"," folder_path.mkdir(parents=True)\n"," shutil.move(tmpfile, folder_path.joinpath('ratings.csv'))\n"," shutil.rmtree(tmproot)\n"," print()\n","\n"," def make_implicit(self, df):\n"," print('Turning into implicit ratings')\n"," df = df[df['rating'] >= self.min_rating]\n"," # return df[['uid', 'sid', 'timestamp']]\n"," return df\n","\n"," def filter_triplets(self, df):\n"," print('Filtering triplets')\n"," if self.min_sc > 0:\n"," item_sizes = df.groupby('sid').size()\n"," good_items = item_sizes.index[item_sizes >= self.min_sc]\n"," df = df[df['sid'].isin(good_items)]\n","\n"," if self.min_uc > 0:\n"," user_sizes = df.groupby('uid').size()\n"," good_users = user_sizes.index[user_sizes >= self.min_uc]\n"," df = df[df['uid'].isin(good_users)]\n","\n"," return df\n","\n"," def densify_index(self, df):\n"," print('Densifying index')\n"," umap = {u: i for i, u in enumerate(set(df['uid']))}\n"," smap = {s: i for i, s in enumerate(set(df['sid']))}\n"," df['uid'] = df['uid'].map(umap)\n"," df['sid'] = df['sid'].map(smap)\n"," return df, umap, smap\n","\n"," def split_df(self, df, user_count):\n"," if self.args.split == 'leave_one_out':\n"," print('Splitting')\n"," user_group = df.groupby('uid')\n"," user2items = user_group.progress_apply(lambda d: list(d.sort_values(by='timestamp')['sid']))\n"," train, val, test = {}, {}, {}\n"," for user in range(user_count):\n"," items = user2items[user]\n"," train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]\n"," return train, val, test\n"," elif self.args.split == 'holdout':\n"," print('Splitting')\n"," np.random.seed(self.args.dataset_split_seed)\n"," eval_set_size = self.args.eval_set_size\n","\n"," # Generate user indices\n"," permuted_index = np.random.permutation(user_count)\n"," train_user_index = permuted_index[ :-2*eval_set_size]\n"," val_user_index = permuted_index[-2*eval_set_size: -eval_set_size]\n"," test_user_index = permuted_index[ -eval_set_size: ]\n","\n"," # Split DataFrames\n"," train_df = df.loc[df['uid'].isin(train_user_index)]\n"," val_df = df.loc[df['uid'].isin(val_user_index)]\n"," test_df = df.loc[df['uid'].isin(test_user_index)]\n","\n"," # DataFrame to dict => {uid : list of sid's}\n"," train = dict(train_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))\n"," val = dict(val_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))\n"," test = dict(test_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))\n"," return train, val, test\n"," else:\n"," raise NotImplementedError\n","\n"," def _get_rawdata_root_path(self):\n"," return Path(RAW_DATASET_ROOT_FOLDER)\n","\n"," def _get_rawdata_folder_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath(self.raw_code())\n","\n"," def _get_preprocessed_root_path(self):\n"," root = self._get_rawdata_root_path()\n"," return root.joinpath('preprocessed')\n","\n"," def _get_preprocessed_folder_path(self):\n"," preprocessed_root = self._get_preprocessed_root_path()\n"," folder_name = 
'{}_min_rating{}-min_uc{}-min_sc{}-split{}' \\\n"," .format(self.code(), self.min_rating, self.min_uc, self.min_sc, self.split)\n"," return preprocessed_root.joinpath(folder_name)\n","\n"," def _get_preprocessed_dataset_path(self):\n"," folder = self._get_preprocessed_folder_path()\n"," return folder.joinpath('dataset.pkl')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WO4WVLTg7f8V"},"source":["### ML1M class"]},{"cell_type":"code","metadata":{"id":"ft3ctaOXGSrx"},"source":["class ML1MDataset(AbstractDataset):\n"," @classmethod\n"," def code(cls):\n"," return 'ml-1m'\n","\n"," @classmethod\n"," def url(cls):\n"," return 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'\n","\n"," @classmethod\n"," def zip_file_content_is_folder(cls):\n"," return True\n","\n"," @classmethod\n"," def all_raw_file_names(cls):\n"," return ['README',\n"," 'movies.dat',\n"," 'ratings.dat',\n"," 'users.dat']\n","\n"," def load_ratings_df(self):\n"," folder_path = self._get_rawdata_folder_path()\n"," file_path = folder_path.joinpath('ratings.dat')\n"," df = pd.read_csv(file_path, sep='::', header=None)\n"," df.columns = ['uid', 'sid', 'rating', 'timestamp']\n"," return df"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uDMosTT57h0M"},"source":["### Manager"]},{"cell_type":"code","metadata":{"id":"UgnbdrGxFh6B"},"source":["DATASETS = {\n"," ML1MDataset.code(): ML1MDataset\n","}\n","\n","\n","def dataset_factory(args):\n"," dataset = DATASETS[args.dataset_code]\n"," return dataset(args)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FU-1tB3SaK7u"},"source":["## Negative sampling"]},{"cell_type":"markdown","metadata":{"id":"WLQbqies7mEe"},"source":["### Abstract class"]},{"cell_type":"code","metadata":{"id":"77GSVXs7IZgP"},"source":["class AbstractNegativeSampler(metaclass=ABCMeta):\n"," def __init__(self, train, val, test, user_count, item_count, sample_size, seed, save_folder):\n"," self.train = train\n"," self.val = val\n"," self.test = test\n"," self.user_count = user_count\n"," self.item_count = item_count\n"," self.sample_size = sample_size\n"," self.seed = seed\n"," self.save_folder = save_folder\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @abstractmethod\n"," def generate_negative_samples(self):\n"," pass\n","\n"," def get_negative_samples(self):\n"," savefile_path = self._get_save_path()\n"," if savefile_path.is_file():\n"," print('Negatives samples exist. Loading.')\n"," negative_samples = pickle.load(savefile_path.open('rb'))\n"," return negative_samples\n"," print(\"Negative samples don't exist. 
Generating.\")\n"," negative_samples = self.generate_negative_samples()\n"," with savefile_path.open('wb') as f:\n"," pickle.dump(negative_samples, f)\n"," return negative_samples\n","\n"," def _get_save_path(self):\n"," folder = Path(self.save_folder)\n"," filename = '{}-sample_size{}-seed{}.pkl'.format(self.code(), self.sample_size, self.seed)\n"," return folder.joinpath(filename)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sm1cpuBx7pVU"},"source":["### Random negative sampling"]},{"cell_type":"code","metadata":{"id":"vuvvsiHFIgnP"},"source":["class RandomNegativeSampler(AbstractNegativeSampler):\n"," @classmethod\n"," def code(cls):\n"," return 'random'\n","\n"," def generate_negative_samples(self):\n"," assert self.seed is not None, 'Specify seed for random sampling'\n"," np.random.seed(self.seed)\n"," negative_samples = {}\n"," print('Sampling negative items')\n"," for user in trange(self.user_count):\n"," if isinstance(self.train[user][1], tuple):\n"," seen = set(x[0] for x in self.train[user])\n"," seen.update(x[0] for x in self.val[user])\n"," seen.update(x[0] for x in self.test[user])\n"," else:\n"," seen = set(self.train[user])\n"," seen.update(self.val[user])\n"," seen.update(self.test[user])\n","\n"," samples = []\n"," for _ in range(self.sample_size):\n"," item = np.random.choice(self.item_count) + 1\n"," while item in seen or item in samples:\n"," item = np.random.choice(self.item_count) + 1\n"," samples.append(item)\n","\n"," negative_samples[user] = samples\n","\n"," return negative_samples"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZzZe59fr7s-B"},"source":["### Manager"]},{"cell_type":"code","metadata":{"id":"tYSSmJg2IZdu"},"source":["NEGATIVE_SAMPLERS = {\n"," RandomNegativeSampler.code(): RandomNegativeSampler,\n","}\n","\n","def negative_sampler_factory(code, train, val, test, user_count, item_count, sample_size, seed, save_folder):\n"," negative_sampler = NEGATIVE_SAMPLERS[code]\n"," return negative_sampler(train, val, test, user_count, item_count, sample_size, seed, save_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"F-RX320PaV6e"},"source":["## Dataloader"]},{"cell_type":"markdown","metadata":{"id":"f9t02l5u7wu8"},"source":["### Abstract class"]},{"cell_type":"code","metadata":{"id":"z6c630DnIJLm"},"source":["class AbstractDataloader(metaclass=ABCMeta):\n"," def __init__(self, args, dataset):\n"," self.args = args\n"," seed = args.dataloader_random_seed\n"," self.rng = random.Random(seed)\n"," self.save_folder = dataset._get_preprocessed_folder_path()\n"," dataset = dataset.load_dataset()\n"," self.train = dataset['train']\n"," self.val = dataset['val']\n"," self.test = dataset['test']\n"," self.umap = dataset['umap']\n"," self.smap = dataset['smap']\n"," self.user_count = len(self.umap)\n"," self.item_count = len(self.smap)\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @abstractmethod\n"," def get_pytorch_dataloaders(self):\n"," pass"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dqOj8Np71bf"},"source":["### BERT dataloader"]},{"cell_type":"code","metadata":{"id":"fqLUnykjIJIT"},"source":["class BertDataloader(AbstractDataloader):\n"," def __init__(self, args, dataset):\n"," super().__init__(args, dataset)\n"," args.num_items = len(self.smap)\n"," self.max_len = args.bert_max_len\n"," self.mask_prob = args.bert_mask_prob\n"," self.CLOZE_MASK_TOKEN = self.item_count + 1\n","\n"," 
code = args.train_negative_sampler_code\n"," train_negative_sampler = negative_sampler_factory(code, self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.train_negative_sample_size,\n"," args.train_negative_sampling_seed,\n"," self.save_folder)\n"," code = args.test_negative_sampler_code\n"," test_negative_sampler = negative_sampler_factory(code, self.train, self.val, self.test,\n"," self.user_count, self.item_count,\n"," args.test_negative_sample_size,\n"," args.test_negative_sampling_seed,\n"," self.save_folder)\n","\n"," self.train_negative_samples = train_negative_sampler.get_negative_samples()\n"," self.test_negative_samples = test_negative_sampler.get_negative_samples()\n","\n"," @classmethod\n"," def code(cls):\n"," return 'bert'\n","\n"," def get_pytorch_dataloaders(self):\n"," train_loader = self._get_train_loader()\n"," val_loader = self._get_val_loader()\n"," test_loader = self._get_test_loader()\n"," return train_loader, val_loader, test_loader\n","\n"," def _get_train_loader(self):\n"," dataset = self._get_train_dataset()\n"," dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,\n"," shuffle=True, pin_memory=True)\n"," return dataloader\n","\n"," def _get_train_dataset(self):\n"," dataset = BertTrainDataset(self.train, self.max_len, self.mask_prob, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)\n"," return dataset\n","\n"," def _get_val_loader(self):\n"," return self._get_eval_loader(mode='val')\n","\n"," def _get_test_loader(self):\n"," return self._get_eval_loader(mode='test')\n","\n"," def _get_eval_loader(self, mode):\n"," batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n"," dataset = self._get_eval_dataset(mode)\n"," dataloader = data_utils.DataLoader(dataset, batch_size=batch_size,\n"," shuffle=False, pin_memory=True)\n"," return dataloader\n","\n"," def _get_eval_dataset(self, mode):\n"," answers = self.val if mode == 'val' else self.test\n"," dataset = BertEvalDataset(self.train, answers, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples)\n"," return dataset\n","\n","\n","class BertTrainDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, max_len, mask_prob, mask_token, num_items, rng):\n"," self.u2seq = u2seq\n"," self.users = sorted(self.u2seq.keys())\n"," self.max_len = max_len\n"," self.mask_prob = mask_prob\n"," self.mask_token = mask_token\n"," self.num_items = num_items\n"," self.rng = rng\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self._getseq(user)\n","\n"," tokens = []\n"," labels = []\n"," for s in seq:\n"," prob = self.rng.random()\n"," if prob < self.mask_prob:\n"," prob /= self.mask_prob\n","\n"," if prob < 0.8:\n"," tokens.append(self.mask_token)\n"," elif prob < 0.9:\n"," tokens.append(self.rng.randint(1, self.num_items))\n"," else:\n"," tokens.append(s)\n","\n"," labels.append(s)\n"," else:\n"," tokens.append(s)\n"," labels.append(0)\n","\n"," tokens = tokens[-self.max_len:]\n"," labels = labels[-self.max_len:]\n","\n"," mask_len = self.max_len - len(tokens)\n","\n"," tokens = [0] * mask_len + tokens\n"," labels = [0] * mask_len + labels\n","\n"," return torch.LongTensor(tokens), torch.LongTensor(labels)\n","\n"," def _getseq(self, user):\n"," return self.u2seq[user]\n","\n","\n","\n","class BertEvalDataset(data_utils.Dataset):\n"," def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples):\n"," self.u2seq = u2seq\n"," 
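# Descriptive note: each eval example appends the mask token to the user's history and ranks the true answer against the pre-sampled negatives.\n","        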
self.users = sorted(self.u2seq.keys())\n"," self.u2answer = u2answer\n"," self.max_len = max_len\n"," self.mask_token = mask_token\n"," self.negative_samples = negative_samples\n","\n"," def __len__(self):\n"," return len(self.users)\n","\n"," def __getitem__(self, index):\n"," user = self.users[index]\n"," seq = self.u2seq[user]\n"," answer = self.u2answer[user]\n"," negs = self.negative_samples[user]\n","\n"," candidates = answer + negs\n"," labels = [1] * len(answer) + [0] * len(negs)\n","\n"," seq = seq + [self.mask_token]\n"," seq = seq[-self.max_len:]\n"," padding_len = self.max_len - len(seq)\n"," seq = [0] * padding_len + seq\n","\n"," return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UB5TW2DJ745i"},"source":["### Manager"]},{"cell_type":"code","metadata":{"id":"vljL5_uSGmB3"},"source":["DATALOADERS = {\n"," BertDataloader.code(): BertDataloader,\n","}\n","\n","\n","def dataloader_factory(args):\n"," dataset = dataset_factory(args)\n"," dataloader = DATALOADERS[args.dataloader_code]\n"," dataloader = dataloader(args, dataset)\n"," train, val, test = dataloader.get_pytorch_dataloaders()\n"," return train, val, test"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v3r-1B3Oaiq7"},"source":["## Model"]},{"cell_type":"code","metadata":{"id":"k3NFXrgtJBXj"},"source":["class LayerNorm(nn.Module):\n"," \"Construct a layernorm module (See citation for details).\"\n"," def __init__(self, features, eps=1e-6):\n"," super(LayerNorm, self).__init__()\n"," self.a_2 = nn.Parameter(torch.ones(features))\n"," self.b_2 = nn.Parameter(torch.zeros(features))\n"," self.eps = eps\n","\n"," def forward(self, x):\n"," mean = x.mean(-1, keepdim=True)\n"," std = x.std(-1, keepdim=True)\n"," return self.a_2 * (x - mean) / (std + self.eps) + self.b_2\n","\n","\n","class SublayerConnection(nn.Module):\n"," \"\"\"\n"," A residual connection followed by a layer norm.\n"," Note for code simplicity the norm is first as opposed to last.\n"," \"\"\"\n","\n"," def __init__(self, size, dropout):\n"," super(SublayerConnection, self).__init__()\n"," self.norm = LayerNorm(size)\n"," self.dropout = nn.Dropout(dropout)\n","\n"," def forward(self, x, sublayer):\n"," \"Apply residual connection to any sublayer with the same size.\"\n"," return x + self.dropout(sublayer(self.norm(x)))\n","\n","\n","class GELU(nn.Module):\n"," \"\"\"\n"," Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU\n"," \"\"\"\n"," def forward(self, x):\n"," return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))\n","\n","\n","class PositionwiseFeedForward(nn.Module):\n"," \"Implements FFN equation.\"\n"," def __init__(self, d_model, d_ff, dropout=0.1):\n"," super(PositionwiseFeedForward, self).__init__()\n"," self.w_1 = nn.Linear(d_model, d_ff)\n"," self.w_2 = nn.Linear(d_ff, d_model)\n"," self.dropout = nn.Dropout(dropout)\n"," self.activation = GELU()\n","\n"," def forward(self, x):\n"," return self.w_2(self.dropout(self.activation(self.w_1(x))))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RLgZ4CpJJBRU"},"source":["class Attention(nn.Module):\n"," \"\"\"\n"," Compute 'Scaled Dot Product Attention\n"," \"\"\"\n"," def forward(self, query, key, value, mask=None, dropout=None):\n"," scores = torch.matmul(query, key.transpose(-2, -1)) \\\n"," / math.sqrt(query.size(-1))\n","\n"," if mask is not None:\n"," scores 
= scores.masked_fill(mask == 0, -1e9)\n","\n"," p_attn = F.softmax(scores, dim=-1)\n","\n"," if dropout is not None:\n"," p_attn = dropout(p_attn)\n","\n"," return torch.matmul(p_attn, value), p_attn\n","\n","\n","class MultiHeadedAttention(nn.Module):\n"," \"\"\"\n"," Take in model size and number of heads.\n"," \"\"\"\n"," def __init__(self, h, d_model, dropout=0.1):\n"," super().__init__()\n"," assert d_model % h == 0\n","\n"," # We assume d_v always equals d_k\n"," self.d_k = d_model // h\n"," self.h = h\n","\n"," self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])\n"," self.output_linear = nn.Linear(d_model, d_model)\n"," self.attention = Attention()\n","\n"," self.dropout = nn.Dropout(p=dropout)\n","\n"," def forward(self, query, key, value, mask=None):\n"," batch_size = query.size(0)\n","\n"," # 1) Do all the linear projections in batch from d_model => h x d_k\n"," query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)\n"," for l, x in zip(self.linear_layers, (query, key, value))]\n","\n"," # 2) Apply attention on all the projected vectors in batch.\n"," x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)\n","\n"," # 3) \"Concat\" using a view and apply a final linear.\n"," x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)\n","\n"," return self.output_linear(x)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"sQlJoNe-JcX4"},"source":["class PositionalEmbedding(nn.Module):\n"," def __init__(self, max_len, d_model):\n"," super().__init__()\n","\n"," # Compute the positional encodings once in log space.\n"," self.pe = nn.Embedding(max_len, d_model)\n","\n"," def forward(self, x):\n"," batch_size = x.size(0)\n"," return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)\n","\n"," \n","class SegmentEmbedding(nn.Embedding):\n"," def __init__(self, embed_size=512):\n"," super().__init__(3, embed_size, padding_idx=0)\n","\n","\n","class TokenEmbedding(nn.Embedding):\n"," def __init__(self, vocab_size, embed_size=512):\n"," super().__init__(vocab_size, embed_size, padding_idx=0)\n","\n","\n","class BERTEmbedding(nn.Module):\n"," \"\"\"\n"," BERT Embedding which is consisted with under features\n"," 1. TokenEmbedding : normal embedding matrix\n"," 2. PositionalEmbedding : adding positional information using sin, cos\n"," 2. 
SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)\n"," sum of all these features are output of BERTEmbedding\n"," \"\"\"\n"," def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):\n"," \"\"\"\n"," :param vocab_size: total vocab size\n"," :param embed_size: embedding size of token embedding\n"," :param dropout: dropout rate\n"," \"\"\"\n"," super().__init__()\n"," self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)\n"," self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)\n"," # self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)\n"," self.dropout = nn.Dropout(p=dropout)\n"," self.embed_size = embed_size\n","\n"," def forward(self, sequence):\n"," x = self.token(sequence) + self.position(sequence) # + self.segment(segment_label)\n"," return self.dropout(x)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"6MUoT1WTKEPF"},"source":["class TransformerBlock(nn.Module):\n"," \"\"\"\n"," Bidirectional Encoder = Transformer (self-attention)\n"," Transformer = MultiHead_Attention + Feed_Forward with sublayer connection\n"," \"\"\"\n"," def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):\n"," \"\"\"\n"," :param hidden: hidden size of transformer\n"," :param attn_heads: head sizes of multi-head attention\n"," :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size\n"," :param dropout: dropout rate\n"," \"\"\"\n","\n"," super().__init__()\n"," self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)\n"," self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)\n"," self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n"," self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)\n"," self.dropout = nn.Dropout(p=dropout)\n","\n"," def forward(self, x, mask):\n"," x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))\n"," x = self.output_sublayer(x, self.feed_forward)\n"," return self.dropout(x)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"exnahi0aJ9zn"},"source":["class BERT(nn.Module):\n"," def __init__(self, args):\n"," super().__init__()\n","\n"," fix_random_seed_as(args.model_init_seed)\n"," # self.init_weights()\n","\n"," max_len = args.bert_max_len\n"," num_items = args.num_items\n"," n_layers = args.bert_num_blocks\n"," heads = args.bert_num_heads\n"," vocab_size = num_items + 2\n"," hidden = args.bert_hidden_units\n"," self.hidden = hidden\n"," dropout = args.bert_dropout\n","\n"," # embedding for BERT, sum of positional, segment, token embeddings\n"," self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden, max_len=max_len, dropout=dropout)\n","\n"," # multi-layers transformer blocks, deep network\n"," self.transformer_blocks = nn.ModuleList(\n"," [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])\n","\n"," def forward(self, x):\n"," mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)\n","\n"," # embedding the indexed sequence to sequence of vectors\n"," x = self.embedding(x)\n","\n"," # running over multiple transformer blocks\n"," for transformer in self.transformer_blocks:\n"," x = transformer.forward(x, mask)\n","\n"," return x\n","\n"," def init_weights(self):\n"," pass"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"UMR-_ls5Itog"},"source":["class BaseModel(nn.Module, 
metaclass=ABCMeta):\n"," def __init__(self, args):\n"," super().__init__()\n"," self.args = args\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iY6_tI_nItmA"},"source":["class BERTModel(BaseModel):\n"," def __init__(self, args):\n"," super().__init__(args)\n"," self.bert = BERT(args)\n"," self.out = nn.Linear(self.bert.hidden, args.num_items + 1)\n","\n"," @classmethod\n"," def code(cls):\n"," return 'bert'\n","\n"," def forward(self, x):\n"," x = self.bert(x)\n"," return self.out(x)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"WWexEPvmIti7"},"source":["MODELS = {\n"," BERTModel.code(): BERTModel,\n","}\n","\n","\n","def model_factory(args):\n"," model = MODELS[args.model_code]\n"," return model(args)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EDi_Pj_1amv4"},"source":["## Training"]},{"cell_type":"code","metadata":{"id":"YD47UKG7KR6M"},"source":["class AbstractTrainer(metaclass=ABCMeta):\n"," def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):\n"," self.args = args\n"," self.device = args.device\n"," self.model = model.to(self.device)\n"," self.is_parallel = args.num_gpu > 1\n"," if self.is_parallel:\n"," self.model = nn.DataParallel(self.model)\n","\n"," self.train_loader = train_loader\n"," self.val_loader = val_loader\n"," self.test_loader = test_loader\n"," self.optimizer = self._create_optimizer()\n"," if args.enable_lr_schedule:\n"," self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n","\n"," self.num_epochs = args.num_epochs\n"," self.metric_ks = args.metric_ks\n"," self.best_metric = args.best_metric\n","\n"," self.export_root = export_root\n"," self.writer, self.train_loggers, self.val_loggers = self._create_loggers()\n"," self.add_extra_loggers()\n"," self.logger_service = LoggerService(self.train_loggers, self.val_loggers)\n"," self.log_period_as_iter = args.log_period_as_iter\n","\n"," @abstractmethod\n"," def add_extra_loggers(self):\n"," pass\n","\n"," @abstractmethod\n"," def log_extra_train_info(self, log_data):\n"," pass\n","\n"," @abstractmethod\n"," def log_extra_val_info(self, log_data):\n"," pass\n","\n"," @classmethod\n"," @abstractmethod\n"," def code(cls):\n"," pass\n","\n"," @abstractmethod\n"," def calculate_loss(self, batch):\n"," pass\n","\n"," @abstractmethod\n"," def calculate_metrics(self, batch):\n"," pass\n","\n"," def train(self):\n"," accum_iter = 0\n"," self.validate(0, accum_iter)\n"," for epoch in range(self.num_epochs):\n"," accum_iter = self.train_one_epoch(epoch, accum_iter)\n"," self.validate(epoch, accum_iter)\n"," self.logger_service.complete({\n"," 'state_dict': (self._create_state_dict()),\n"," })\n"," self.writer.close()\n","\n"," def train_one_epoch(self, epoch, accum_iter):\n"," self.model.train()\n"," if self.args.enable_lr_schedule:\n"," self.lr_scheduler.step()\n","\n"," average_meter_set = AverageMeterSet()\n"," tqdm_dataloader = tqdm(self.train_loader)\n","\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch_size = batch[0].size(0)\n"," batch = [x.to(self.device) for x in batch]\n","\n"," self.optimizer.zero_grad()\n"," loss = self.calculate_loss(batch)\n"," loss.backward()\n","\n"," self.optimizer.step()\n","\n"," average_meter_set.update('loss', loss.item())\n"," tqdm_dataloader.set_description(\n"," 'Epoch {}, loss {:.3f} '.format(epoch+1, 
average_meter_set['loss'].avg))\n","\n"," accum_iter += batch_size\n","\n"," if self._needs_to_log(accum_iter):\n"," tqdm_dataloader.set_description('Logging to Tensorboard')\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.log_extra_train_info(log_data)\n"," self.logger_service.log_train(log_data)\n","\n"," return accum_iter\n","\n"," def validate(self, epoch, accum_iter):\n"," self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.val_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n","\n"," metrics = self.calculate_metrics(batch)\n","\n"," for k, v in metrics.items():\n"," average_meter_set.update(k, v)\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\\\n"," ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(*(average_meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," log_data = {\n"," 'state_dict': (self._create_state_dict()),\n"," 'epoch': epoch+1,\n"," 'accum_iter': accum_iter,\n"," }\n"," log_data.update(average_meter_set.averages())\n"," self.log_extra_val_info(log_data)\n"," self.logger_service.log_val(log_data)\n","\n"," def test(self):\n"," print('Test best model with test set!')\n","\n"," best_model = torch.load(os.path.join(self.export_root, 'models', 'best_acc_model.pth')).get('model_state_dict')\n"," self.model.load_state_dict(best_model)\n"," self.model.eval()\n","\n"," average_meter_set = AverageMeterSet()\n","\n"," with torch.no_grad():\n"," tqdm_dataloader = tqdm(self.test_loader)\n"," for batch_idx, batch in enumerate(tqdm_dataloader):\n"," batch = [x.to(self.device) for x in batch]\n","\n"," metrics = self.calculate_metrics(batch)\n","\n"," for k, v in metrics.items():\n"," average_meter_set.update(k, v)\n"," description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\\\n"," ['Recall@%d' % k for k in self.metric_ks[:3]]\n"," description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)\n"," description = description.replace('NDCG', 'N').replace('Recall', 'R')\n"," description = description.format(*(average_meter_set[k].avg for k in description_metrics))\n"," tqdm_dataloader.set_description(description)\n","\n"," average_metrics = average_meter_set.averages()\n"," with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:\n"," json.dump(average_metrics, f, indent=4)\n"," print(average_metrics)\n","\n"," def _create_optimizer(self):\n"," args = self.args\n"," if args.optimizer.lower() == 'adam':\n"," return optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n"," elif args.optimizer.lower() == 'sgd':\n"," return optim.SGD(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)\n"," else:\n"," raise ValueError\n","\n"," def _create_loggers(self):\n"," root = Path(self.export_root)\n"," writer = SummaryWriter(root.joinpath('logs'))\n"," model_checkpoint = root.joinpath('models')\n","\n"," train_loggers = [\n"," MetricGraphPrinter(writer, key='epoch', graph_name='Epoch', group_name='Train'),\n"," MetricGraphPrinter(writer, 
key='loss', graph_name='Loss', group_name='Train'),\n"," ]\n","\n"," val_loggers = []\n"," for k in self.metric_ks:\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))\n"," val_loggers.append(\n"," MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))\n"," val_loggers.append(RecentModelLogger(model_checkpoint))\n"," val_loggers.append(BestModelLogger(model_checkpoint, metric_key=self.best_metric))\n"," return writer, train_loggers, val_loggers\n","\n"," def _create_state_dict(self):\n"," return {\n"," STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),\n"," OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n"," }\n","\n"," def _needs_to_log(self, accum_iter):\n"," return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"q3WwY7cIKR4L"},"source":["class BERTTrainer(AbstractTrainer):\n"," def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):\n"," super().__init__(args, model, train_loader, val_loader, test_loader, export_root)\n"," self.ce = nn.CrossEntropyLoss(ignore_index=0)\n","\n"," @classmethod\n"," def code(cls):\n"," return 'bert'\n","\n"," def add_extra_loggers(self):\n"," pass\n","\n"," def log_extra_train_info(self, log_data):\n"," pass\n","\n"," def log_extra_val_info(self, log_data):\n"," pass\n","\n"," def calculate_loss(self, batch):\n"," seqs, labels = batch\n"," logits = self.model(seqs) # B x T x V\n","\n"," logits = logits.view(-1, logits.size(-1)) # (B*T) x V\n"," labels = labels.view(-1) # B*T\n"," loss = self.ce(logits, labels)\n"," return loss\n","\n"," def calculate_metrics(self, batch):\n"," seqs, candidates, labels = batch\n"," scores = self.model(seqs) # B x T x V\n"," scores = scores[:, -1, :] # B x V\n"," scores = scores.gather(1, candidates) # B x C\n","\n"," metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)\n"," return metrics"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iM00f0PAKR1T"},"source":["TRAINERS = {\n"," BERTTrainer.code(): BERTTrainer,\n","}\n","\n","\n","def trainer_factory(args, model, train_loader, val_loader, test_loader, export_root):\n"," trainer = TRAINERS[args.trainer_code]\n"," return trainer(args, model, train_loader, val_loader, test_loader, export_root)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"id-5bt57LmRW","executionInfo":{"elapsed":283689,"status":"ok","timestamp":1632645271496,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"},"user_tz":-330},"outputId":"98a71d76-b161-4f93-a682-806530d3c6ef"},"source":["def train():\n"," export_root = setup_train(args)\n"," train_loader, val_loader, test_loader = dataloader_factory(args)\n"," model = model_factory(args)\n"," trainer = trainer_factory(args, model, train_loader, val_loader, test_loader, export_root)\n"," trainer.train()\n","\n"," test_model = (input('Test model with test dataset? 
y/[n]: ') == 'y')\n"," if test_model:\n"," trainer.test()\n","\n","\n","if __name__ == '__main__':\n"," if args.mode == 'train':\n"," train()\n"," else:\n"," raise ValueError('Invalid mode')"],"execution_count":null,"outputs":[{"name":"stdout","output_type":"stream","text":["Folder created: /content/experiments/test_2021-09-26_1\n","{'num_gpu': 1}\n","Already preprocessed. Skip preprocessing\n","Negatives samples exist. Loading.\n","Negatives samples exist. Loading.\n"]},{"name":"stderr","output_type":"stream","text":["Val: N@1 0.013, N@5 0.035, N@10 0.051, R@1 0.013, R@5 0.059, R@10 0.111: 100%|██████████| 48/48 [01:30<00:00, 1.88s/it]\n","/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n"," \"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\", UserWarning)\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 1\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 1, loss 7.841 : 100%|██████████| 48/48 [05:01<00:00, 6.29s/it]\n","Val: N@1 0.083, N@5 0.185, N@10 0.232, R@1 0.083, R@5 0.285, R@10 0.430: 100%|██████████| 48/48 [01:36<00:00, 2.00s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 1\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 2, loss 7.427 : 100%|██████████| 48/48 [04:59<00:00, 6.24s/it]\n","Val: N@1 0.099, N@5 0.208, N@10 0.255, R@1 0.099, R@5 0.315, R@10 0.461: 100%|██████████| 48/48 [01:35<00:00, 2.00s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 2\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 3, loss 7.192 : 100%|██████████| 48/48 [04:59<00:00, 6.24s/it]\n","Val: N@1 0.104, N@5 0.223, N@10 0.276, R@1 0.104, R@5 0.336, R@10 0.501: 100%|██████████| 48/48 [01:36<00:00, 2.00s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 3\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 4, loss 6.957 : 100%|██████████| 48/48 [05:07<00:00, 6.40s/it]\n","Val: N@1 0.132, N@5 0.263, N@10 0.315, R@1 0.132, R@5 0.390, R@10 0.553: 100%|██████████| 48/48 [01:41<00:00, 2.11s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 4\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 5, loss 6.783 : 100%|██████████| 48/48 [05:07<00:00, 6.41s/it]\n","Val: N@1 0.154, N@5 0.302, N@10 0.357, R@1 0.154, R@5 0.444, R@10 0.616: 100%|██████████| 48/48 [01:44<00:00, 2.17s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 5\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 6, loss 6.606 : 100%|██████████| 48/48 [05:08<00:00, 6.43s/it]\n","Val: N@1 0.174, N@5 0.335, N@10 0.383, R@1 0.174, R@5 0.484, R@10 0.634: 100%|██████████| 48/48 [01:44<00:00, 2.17s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 6\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 7, loss 6.440 : 100%|██████████| 48/48 [05:08<00:00, 6.44s/it]\n","Val: N@1 0.197, N@5 0.367, N@10 0.414, R@1 0.197, R@5 0.521, R@10 0.668: 100%|██████████| 48/48 [01:44<00:00, 
2.17s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 7\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 8, loss 6.323 : 100%|██████████| 48/48 [05:08<00:00, 6.42s/it]\n","Val: N@1 0.219, N@5 0.388, N@10 0.434, R@1 0.219, R@5 0.545, R@10 0.687: 100%|██████████| 48/48 [01:44<00:00, 2.18s/it]\n"]},{"name":"stdout","output_type":"stream","text":["Update Best NDCG@10 Model at 8\n"]},{"name":"stderr","output_type":"stream","text":["Epoch 9, loss 6.222 : 79%|███████▉ | 38/48 [04:11<01:05, 6.57s/it]"]}]}]}