{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-24-simplex-mf.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T279677%20%7C%20SimpleX%20MF%20Model%20on%20ML-100k%20Dataset.ipynb","timestamp":1644665914220}],"collapsed_sections":[],"authorship_tag":"ABX9TyPVzsui/vwvXI0btF8U9T7z"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# SimpleX MF Model on ML-100k Dataset"],"metadata":{"id":"19bXzza1v9P0"}},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Xr1pHrKo6fRY","executionInfo":{"status":"ok","timestamp":1633243786239,"user_tz":-330,"elapsed":13929,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"b31f099d-0834-491f-ed80-aa566f78e0af"},"source":["!pip install -q git+https://github.com/sparsh-ai/recochef"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":[" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n"," Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n"," Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n","\u001b[K |████████████████████████████████| 4.3 MB 5.1 MB/s \n","\u001b[?25h Building wheel for recochef (PEP 517) ... \u001b[?25l\u001b[?25hdone\n"]}]},{"cell_type":"code","metadata":{"id":"hX4kYizs6ghj"},"source":["import os\n","import csv \n","import argparse\n","import numpy as np\n","import pandas as pd\n","import random as rd\n","from time import time\n","from pathlib import Path\n","import scipy.sparse as sp\n","from datetime import datetime\n","\n","import torch\n","from torch import nn\n","import torch.nn.functional as F\n","\n","from recochef.preprocessing.split import chrono_split"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qHeoRZag6iWP","executionInfo":{"status":"ok","timestamp":1633243793602,"user_tz":-330,"elapsed":1239,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"ed596e48-a3d1-4914-93a4-d5c726fdd3f6"},"source":["!wget -q --show-progress http://files.grouplens.org/datasets/movielens/ml-100k.zip\n","!unzip ml-100k.zip"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["ml-100k.zip 100%[===================>] 4.70M 16.6MB/s in 0.3s \n","Archive: ml-100k.zip\n"," creating: ml-100k/\n"," inflating: ml-100k/allbut.pl \n"," inflating: ml-100k/mku.sh \n"," inflating: ml-100k/README \n"," inflating: ml-100k/u.data \n"," inflating: ml-100k/u.genre \n"," inflating: ml-100k/u.info \n"," inflating: ml-100k/u.item \n"," inflating: ml-100k/u.occupation \n"," inflating: ml-100k/u.user \n"," inflating: ml-100k/u1.base \n"," inflating: ml-100k/u1.test \n"," inflating: ml-100k/u2.base \n"," inflating: ml-100k/u2.test \n"," inflating: ml-100k/u3.base \n"," inflating: ml-100k/u3.test \n"," inflating: ml-100k/u4.base \n"," inflating: ml-100k/u4.test \n"," inflating: ml-100k/u5.base \n"," inflating: ml-100k/u5.test \n"," inflating: ml-100k/ua.base \n"," inflating: ml-100k/ua.test \n"," inflating: ml-100k/ub.base \n"," inflating: ml-100k/ub.test \n"]}]},{"cell_type":"code","metadata":{"id":"IBkgDR_z6oEh"},"source":["df = pd.read_csv('ml-100k/u.data', sep='\\t', header=None, names=['USERID','ITEMID','RATING','TIMESTAMP'])\n","df_train, 
df_test = chrono_split(df, ratio=0.8)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dhfH2qbs6uZe"},"source":["def preprocess(data):\n"," data = data.copy()\n"," data = data.sort_values(by=['USERID','TIMESTAMP'])\n"," data['USERID'] = data['USERID'] - 1\n"," data['ITEMID'] = data['ITEMID'] - 1\n"," data.drop(['TIMESTAMP','RATING'], axis=1, inplace=True)\n"," data = data.groupby('USERID')['ITEMID'].apply(list).reset_index(name='ITEMID')\n"," return data"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SurjuR956uqZ"},"source":["def store(data, target_file='./data/movielens/train.txt'):\n"," Path(target_file).parent.mkdir(parents=True, exist_ok=True)\n"," with open(target_file, 'w+') as f:\n"," writer = csv.writer(f, delimiter=' ')\n"," for USERID, row in zip(data.USERID.values,data.ITEMID.values):\n"," row = [USERID] + row\n"," writer.writerow(row)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ad9JZ_Wa65Hl"},"source":["store(preprocess(df_train), '/content/data/ml-100k/train.txt')\n","store(preprocess(df_test), '/content/data/ml-100k/test.txt')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nDESlr897C9v","executionInfo":{"status":"ok","timestamp":1633243908823,"user_tz":-330,"elapsed":515,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"c7acf4cd-8ff3-4c06-a766-52c4aa8121c1"},"source":["!head /content/data/ml-100k/train.txt"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["0 167 171 164 155 165 195 186 13 249 126 180 116 108 0 245 256 247 49 248 252 261 92 223 123 18 122 136 145 6 234 14 244 259 23 263 125 236 12 24 120 250 235 239 117 129 64 189 46 30 27 113 38 51 237 198 182 10 68 160 94 59 82 178 21 97 63 134 162 25 201 88 7 213 181 47 98 159 174 191 179 127 142 184 67 54 203 55 95 80 78 150 211 22 69 83 93 196 190 183 133 206 144 187 185 96 84 35 143 158 16 173 251 104 147 107 146 219 105 242 121 106 103 246 119 44 267 266 258 260 262 9 149 233 91 70 41 175 90 192 216 176 215 193 72 58 132 40 194 217 169 212 156 222 26 226 79 230 66 118 199 3 214 163 1 205 76 52 135 45 39 152 268 253 114 172 210 228 154 202 61 89 218 166 229 34 161 60 264 111 56 48 29 232 130 151 81 140 71 32 157 197 224 112 20 148 87 100 109 102 238 33 28 42 131 209 204 115 124\r\n","1 285 257 304 306 287 311 300 305 291 302 268 298 314 295 0 18 296 292 274 256 294 276 286 254 297 289 279 273 275 272 290 277 293 24 278 13 110 9 281 12 236 283 99 126 312 284 301 282 250 310\r\n","2 301 332 343 299 267 336 302 344 353 257 287 318 340 351 271 349 352 333 342 338 341 335 298 325 293 306 331 270 244 354 323 348 322 321 334 263 324 337 329 350 346 339 328\r\n","3 257 287 299 327 270 358 361 302 326 328 359 360 353 300 323 209 355 356 49\r\n","4 266 454 221 120 404 362 256 249 24 20 99 108 368 234 411 406 410 104 367 224 150 0 180 49 405 423 412 78 396 372 230 398 228 225 175 449 182 434 88 1 227 229 226 448 209 430 173 171 143 402 397 390 384 16 371 385 392 395 166 366 89 400 389 41 152 185 455 69 383 109 79 380 363 208 450 381 427 382 429 210 432 238 172 207 203 413 167 153 421 431 422 418 142 416 414 373 28 433 364 365 379 391 386 428 424 213 134 61 374 97 447 184 233 435 199 442 444 446 218 443 378 369 440 144 445 401 240 370 215 65 420 426 377 94 419 101 415 417 98 403\r\n","5 285 241 301 268 305 257 339 302 303 320 309 258 
267 308 537 260 181 247 407 274 6 296 126 275 99 8 458 123 13 514 14 136 292 533 535 12 116 284 220 474 0 256 110 476 245 507 470 150 297 283 124 409 532 236 457 293 471 459 534 404 531 472 20 475 300 307 63 523 7 513 97 426 164 222 134 78 530 509 177 486 176 135 88 49 526 204 512 480 461 186 191 46 168 173 519 317 483 488 497 142 70 11 478 190 496 479 491 503 511 495 210 468 524 196 481 198 529 55 536 499 473 68 520 203 506 31 179 182 518 489 188 22 193 463 494 510 460 184 165 174 487 485 69 132 528 482 215 521 522 434 192 462 431 492 237 493 58 208 130 21 94 502 527 86 490 316 467 505 155\r\n","6 268 677 681 258 680 306 265 285 267 682 299 287 263 679 308 63 173 186 602 514 175 179 85 366 264 227 522 434 617 185 418 215 446 529 177 642 31 650 171 649 181 100 473 172 615 428 97 233 487 49 92 525 196 88 99 481 495 513 21 611 203 610 652 658 96 143 483 190 402 656 131 170 180 633 7 494 222 95 22 645 654 635 55 8 490 195 435 81 200 498 655 632 167 422 603 134 272 430 67 204 384 165 614 660 510 182 214 155 607 647 643 197 212 670 68 126 189 43 595 355 542 236 526 512 3 284 135 163 497 237 151 484 592 193 482 612 91 478 491 191 317 392 156 381 583 479 509 420 496 202 429 486 590 6 426 662 152 207 501 567 587 631 78 460 178 629 504 480 27 506 130 228 213 503 433 657 10 588 24 646 160 613 549 469 628 206 210 98 69 626 601 194 528 80 555 673 527 651 70 26 46 547 608 150 536 187 508 216 667 618 274 518 431 627 606 634 9 470 176 120 401 605 209 403 674 201 50 630 89 621 377 211 659 415 454 609 548 153 139 520 678 464 124 462 132 419 125 648 442 543 669 404 604 76 378 229 471 565 117 500 162 161 545 140 580 488 199 636 639 505 644 38 225 591 638 280 383 546 600 596 672 364 231 594 597 51 28 572 447 451 598 90 105 623 619 450 570 586 502 663 71 240 379 561 141 577 599 388 593 77 622 440 400 395 443 571 398 79 414 664 563 676\r\n","7 257 293 300 258 335 259 687 242 357 456 340 686 337 688 650 171 186 126 49 384 88 21 55 189 181 180 95 173 510 509 567 176 10 434 175 182 402 143 54 227 78 272 209 6 194 228 685\r\n","8 339 241 478 520 401 506 614 526 689 275 293 6 370 49 384 5 297 200\r\n","9 301 285 268 288 318 244 333 332 653 526 429 55 512 662 63 31 173 126 492 193 152 557 58 185 517 701 610 628 417 473 384 692 602 706 155 504 655 587 485 663 22 11 403 530 685 370 81 222 123 474 49 708 190 272 609 487 220 705 174 274 696 10 177 21 710 700 167 204 650 184 605 181 233 15 510 697 0 479 196 460 159 115 477 134 178 156 8 508 495 197 601 47 194 210 651 175 3 98 133 68 136 284 356 154 462 97 434 501 496 217 691 199 481 482 215 179 273 163 169 497 69 446 461 99 469 275 132 466 654 483 588 191 478 413 128 202 704 12 198 703 160 694 518 702 656 520 603\r\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MmPQCcpe7EYn","executionInfo":{"status":"ok","timestamp":1633244429809,"user_tz":-330,"elapsed":3355,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f7b1054f-064b-402b-c3c1-ad77f9328145"},"source":["import pandas as pd\n","\n","user_history_dict = dict()\n","train_data = []\n","item_corpus = []\n","corpus_index = dict()\n","\n","max_hist = 0\n","\n","with open(\"/content/data/ml-100k/train.txt\", \"r\") as fid:\n"," for line in fid:\n"," splits = line.strip().split()\n"," user_id = splits[0]\n"," items = splits[1:]\n"," if len(items)>max_hist: max_hist = len(items)\n"," user_history_dict[user_id] = items\n"," for item in items:\n"," if item not in corpus_index:\n"," 
corpus_index[item] = len(corpus_index)\n"," item_corpus.append([corpus_index[item], item])\n"," history = user_history_dict[user_id].copy()\n"," history.remove(item)\n"," train_data.append([user_id, corpus_index[item], 1, user_id, \"^\".join(history)])\n","train = pd.DataFrame(train_data, columns=[\"query_index\", \"corpus_index\", \"label\", \"user_id\", \"user_history\"])\n","print(\"train samples:\", len(train))\n","train.to_csv(\"train.csv\", index=False)\n","\n","test_data = []\n","with open(\"/content/data/ml-100k/test.txt\", \"r\") as fid:\n"," for line in fid:\n"," splits = line.strip().split()\n"," user_id = splits[0]\n"," items = splits[1:]\n"," for item in items:\n"," if item not in corpus_index:\n"," corpus_index[item] = len(corpus_index)\n"," item_corpus.append([corpus_index[item], item])\n"," history = user_history_dict[user_id].copy()\n"," test_data.append([user_id, corpus_index[item], 1, user_id, \"^\".join(history)])\n","test = pd.DataFrame(test_data, columns=[\"query_index\", \"corpus_index\", \"label\", \"user_id\", \"user_history\"])\n","print(\"test samples:\", len(test))\n","test.to_csv(\"test.csv\", index=False)\n","\n","corpus = pd.DataFrame(item_corpus, columns=[\"corpus_index\", \"item_id\"])\n","print(\"number of items:\", len(item_corpus))\n","corpus = corpus.set_index(\"corpus_index\")\n","corpus.to_csv(\"item_corpus.csv\", index=False)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["train samples: 80000\n","test samples: 20000\n","number of items: 1682\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"h3QtzxFv9Dgf","executionInfo":{"status":"ok","timestamp":1633244434859,"user_tz":-330,"elapsed":453,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"5d7b66e6-dad9-404d-9713-33da33166da3"},"source":["max_hist"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["590"]},"metadata":{},"execution_count":36}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":204},"id":"KNeeWl-i7cGn","executionInfo":{"status":"ok","timestamp":1633244049646,"user_tz":-330,"elapsed":11,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"8cad85ff-c5a5-4bd6-c908-561ef9c8f0b6"},"source":["train.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
query_indexcorpus_indexlabeluser_iduser_history
00010171^164^155^165^195^186^13^249^126^180^116^108...
10110167^164^155^165^195^186^13^249^126^180^116^108...
20210167^171^155^165^195^186^13^249^126^180^116^108...
30310167^171^164^165^195^186^13^249^126^180^116^108...
40410167^171^164^155^195^186^13^249^126^180^116^108...
\n","
"],"text/plain":[" query_index ... user_history\n","0 0 ... 171^164^155^165^195^186^13^249^126^180^116^108...\n","1 0 ... 167^164^155^165^195^186^13^249^126^180^116^108...\n","2 0 ... 167^171^155^165^195^186^13^249^126^180^116^108...\n","3 0 ... 167^171^164^165^195^186^13^249^126^180^116^108...\n","4 0 ... 167^171^164^155^195^186^13^249^126^180^116^108...\n","\n","[5 rows x 5 columns]"]},"metadata":{},"execution_count":12}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":235},"id":"fApOv9qr7mrD","executionInfo":{"status":"ok","timestamp":1633244062078,"user_tz":-330,"elapsed":450,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a5fd1efe-60a4-47e8-fed1-6758ed9fa2d4"},"source":["corpus.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
item_id
corpus_index
0167
1171
2164
3155
4165
\n","
"],"text/plain":[" item_id\n","corpus_index \n","0 167\n","1 171\n","2 164\n","3 155\n","4 165"]},"metadata":{},"execution_count":13}]},{"cell_type":"code","metadata":{"id":"FTxBApzp7plM"},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9wplOS7TrcdP"},"source":["import gc\n","import glob\n","import h5py\n","import hashlib\n","import heapq\n","import itertools\n","import json\n","import logging\n","import logging.config\n","import multiprocessing as mp\n","import numpy as np\n","import os\n","import pickle\n","import random\n","import shutil\n","import subprocess\n","import sys\n","import time\n","import yaml\n","from tqdm import tqdm\n","from collections import OrderedDict, defaultdict, Counter\n","from concurrent.futures import ProcessPoolExecutor, as_completed\n","import sklearn.preprocessing as sklearn_preprocess\n","\n","from tensorflow.keras.preprocessing.sequence import pad_sequences\n","\n","import torch\n","from torch import nn\n","from torch.utils.data import Dataset, DataLoader\n","from torch.utils.data.dataloader import default_collate\n","import torch.nn.functional as F"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Wt5GCtTQ8Fnr"},"source":["!mkdir checkpoints"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7QV05GvFusUc"},"source":["class Args:\n"," dataset = 'ml100k'\n"," data_root = '/content'\n"," data_format = 'csv'\n"," train_data = '/content/train.csv'\n"," valid_data = '/content/test.csv'\n"," item_corpus = '/content/item_corpus.csv'\n"," nrows = None\n"," data_block_size = -1\n"," min_categr_count = 1\n"," query_index = 'query_index'\n"," corpus_index = 'corpus_index'\n"," feature_cols = [\n"," {'name': 'query_index', 'active': True, 'dtype': int, 'type': 'index'},\n"," {'name': 'corpus_index', 'active': True, 'dtype': int, 'type': 'index'},\n"," {'name': 'user_id', 'active': True, 'dtype': str, 'type': 'categorical', 'source': 'user'},\n"," {'name': 'user_history', 'active': True, 'dtype': str, 'type': 'sequence', 'source': 'user', 'splitter': '^',\n"," 'max_len': 100, 'padding': 'pre', 'embedding_callback': None},\n"," {'name': 'item_id', 'active': True, 'dtype': str, 'type': 'categorical', 'source': 'item'},\n"," ]\n"," label_col = {'name': 'label', 'dtype': float}\n","\n"," model_root = 'checkpoints/'\n"," num_workers = 2\n"," verbose = 1\n"," patience = 3\n"," save_best_only = True\n"," eval_interval_epochs = 1\n"," debug_mode = False\n"," model = 'SimpleX'\n"," dataset_id = ''\n"," version = 'pytorch'\n"," metrics = ['Recall(k=20)', 'Recall(k=50)', 'NDCG(k=20)', 'NDCG(k=50)', 'HitRate(k=20)', 'HitRate(k=50)']\n"," optimizer = 'adam'\n"," learning_rate = 1.0e-3\n"," batch_size = 256\n"," num_negs = 20\n"," embedding_dim = 64\n"," aggregator = 'mean'\n"," gamma = 0.5\n"," user_id_field = 'user_id'\n"," item_id_field = 'item_id'\n"," user_history_field = 'user_history'\n"," embedding_regularizer = 0\n"," net_regularizer = 0\n"," net_dropout = 0\n"," attention_dropout = 0\n"," enable_bias = False\n"," similarity_score = 'dot'\n"," # loss = 'PairwiseLogisticLoss'\n"," loss = 'CosineContrastiveLoss'\n"," margin = 0\n"," negative_weight = None\n"," sampling_num_process = 1\n"," fix_sampling_seeds = True\n"," ignore_pos_items = False\n"," epochs = 100\n"," shuffle = True\n"," seed = 2019\n"," monitor = 'Recall(k=20)'\n"," monitor_mode = 'max'\n","\n","\n","args = 
Args()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mG_wVwEAtBla","cellView":"form"},"source":["#@markdown utils\n","def load_config(config_dir, experiment_id):\n"," params = dict()\n"," model_configs = glob.glob(os.path.join(config_dir, 'model_config.yaml'))\n"," if not model_configs:\n"," model_configs = glob.glob(os.path.join(config_dir, 'model_config/*.yaml'))\n"," if not model_configs:\n"," raise RuntimeError('config_dir={} is not valid!'.format(config_dir))\n"," found_params = dict()\n"," for config in model_configs:\n"," with open(config, 'r') as cfg:\n"," config_dict = yaml.load(cfg)\n"," if 'Base' in config_dict:\n"," found_params['Base'] = config_dict['Base']\n"," if experiment_id in config_dict:\n"," found_params[experiment_id] = config_dict[experiment_id]\n"," if len(found_params) == 2:\n"," break\n"," # Update base setting first so that values can be overrided when conflict \n"," # with experiment_id settings\n"," params.update(found_params.get('Base', {}))\n"," params.update(found_params.get(experiment_id))\n"," if 'dataset_id' not in params:\n"," raise RuntimeError('experiment_id={} is not valid in config.'.format(experiment_id))\n"," params['model_id'] = experiment_id\n"," dataset_id = params['dataset_id']\n"," dataset_configs = glob.glob(os.path.join(config_dir, 'dataset_config.yaml'))\n"," if not dataset_configs:\n"," dataset_configs = glob.glob(os.path.join(config_dir, 'dataset_config/*.yaml'))\n"," for config in dataset_configs:\n"," with open(config, 'r') as cfg:\n"," config_dict = yaml.load(cfg)\n"," if dataset_id in config_dict:\n"," params.update(config_dict[dataset_id])\n"," break\n"," return params\n","\n","\n","def set_logger(params):\n"," dataset_id = params['dataset_id']\n"," model_id = params['model_id']\n"," log_dir = os.path.join(params['model_root'], dataset_id)\n"," if not os.path.exists(log_dir):\n"," os.makedirs(log_dir)\n"," log_file = os.path.join(log_dir, model_id + '.log')\n","\n"," # logs will not show in the file without the two lines.\n"," for handler in logging.root.handlers[:]: \n"," logging.root.removeHandler(handler)\n"," \n"," logging.basicConfig(level=logging.INFO,\n"," format='%(asctime)s P%(process)d %(levelname)s %(message)s',\n"," handlers=[logging.FileHandler(log_file, mode='w'),\n"," logging.StreamHandler()])\n","\n","\n","def print_to_json(data, sort_keys=True):\n"," new_data = dict((k, str(v)) for k, v in data.items())\n"," if sort_keys:\n"," new_data = OrderedDict(sorted(new_data.items(), key=lambda x: x[0]))\n"," return json.dumps(new_data, indent=4)\n","\n","\n","def print_to_list(data):\n"," return ' - '.join('{}: {:.6f}'.format(k, v) for k, v in data.items())\n","\n","\n","class Monitor(object):\n"," def __init__(self, kv):\n"," if isinstance(kv, str):\n"," kv = {kv: 1}\n"," self.kv_pairs = kv\n","\n"," def get_value(self, logs):\n"," value = 0\n"," for k, v in self.kv_pairs.items():\n"," value += logs[k] * v\n"," return value\n","\n","\n","def seed_everything(seed=1029):\n"," random.seed(seed)\n"," os.environ[\"PYTHONHASHSEED\"] = str(seed)\n"," np.random.seed(seed)\n"," torch.manual_seed(seed)\n"," torch.cuda.manual_seed(seed)\n"," torch.backends.cudnn.deterministic = True\n","\n","def set_device(gpu=-1):\n"," if gpu >= 0 and torch.cuda.is_available():\n"," device = torch.device(\"cuda: \" + str(gpu))\n"," else:\n"," device = torch.device(\"cpu\") \n"," return device\n","\n","def set_optimizer(optimizer):\n"," if isinstance(optimizer, str):\n"," if optimizer.lower() == \"adam\":\n"," optimizer = 
\"Adam\"\n"," elif optimizer.lower() == \"rmsprop\":\n"," optimizer = \"RMSprop\"\n"," elif optimizer.lower() == \"sgd\":\n"," optimizer = \"SGD\"\n"," return getattr(torch.optim, optimizer)\n","\n","def set_loss(loss):\n"," if isinstance(loss, str):\n"," if loss in [\"bce\", \"binary_crossentropy\", \"binary_cross_entropy\"]:\n"," loss = \"binary_cross_entropy\"\n"," else:\n"," raise NotImplementedError(\"loss={} is not supported.\".format(loss))\n"," return loss\n","\n","def set_regularizer(reg):\n"," reg_pair = [] # of tuples (p_norm, weight)\n"," if isinstance(reg, float):\n"," reg_pair.append((2, reg))\n"," elif isinstance(reg, str):\n"," try:\n"," if reg.startswith(\"l1(\") or reg.startswith(\"l2(\"):\n"," reg_pair.append((int(reg[1]), float(reg.rstrip(\")\").split(\"(\")[-1])))\n"," elif reg.startswith(\"l1_l2\"):\n"," l1_reg, l2_reg = reg.rstrip(\")\").split(\"(\")[-1].split(\",\")\n"," reg_pair.append((1, float(l1_reg)))\n"," reg_pair.append((2, float(l2_reg)))\n"," else:\n"," raise NotImplementedError\n"," except:\n"," raise NotImplementedError(\"regularizer={} is not supported.\".format(reg))\n"," return reg_pair\n","\n","def set_activation(activation):\n"," if isinstance(activation, str):\n"," if activation.lower() == \"relu\":\n"," return nn.ReLU()\n"," elif activation.lower() == \"sigmoid\":\n"," return nn.Sigmoid()\n"," elif activation.lower() == \"tanh\":\n"," return nn.Tanh()\n"," else:\n"," return getattr(nn, activation)()\n"," else:\n"," return activation\n","\n","def pad_sequences(sequences, maxlen=None, dtype='int32',\n"," padding='pre', truncating='pre', value=0.):\n"," \"\"\" Pads sequences (list of list) to the ndarray of same length \n"," This is an equivalent implementation of tf.keras.preprocessing.sequence.pad_sequences\n"," for Pytorch\n"," \"\"\"\n","\n"," assert padding in [\"pre\", \"post\"], \"Invalid padding={}.\".format(padding)\n"," assert truncating in [\"pre\", \"post\"], \"Invalid truncating={}.\".format(truncating)\n"," \n"," if maxlen is None:\n"," maxlen = max(len(x) for x in sequences)\n"," arr = np.full((len(sequences), maxlen), value, dtype=dtype)\n"," for idx, x in enumerate(sequences):\n"," if len(x) == 0:\n"," continue # empty list\n"," if truncating == 'pre':\n"," trunc = x[-maxlen:]\n"," else:\n"," trunc = x[:maxlen]\n"," trunc = np.asarray(trunc, dtype=dtype)\n","\n"," if padding == 'pre':\n"," arr[idx, -len(trunc):] = trunc\n"," else:\n"," arr[idx, :len(trunc)] = trunc\n"," return arr\n","\n","\n","def save_h5(darray_dict, data_path):\n"," logging.info(\"Saving data to h5: \" + data_path)\n"," if not os.path.exists(os.path.dirname(data_path)):\n"," try:\n"," os.makedirs(os.path.dirname(data_path))\n"," except:\n"," pass\n"," with h5py.File(data_path, 'w') as hf:\n"," hf.attrs[\"num_samples\"] = len(list(darray_dict.values())[0])\n"," for key, arr in darray_dict.items():\n"," hf.create_dataset(key, data=arr)\n","\n","\n","def load_h5(data_path, verbose=True):\n"," if verbose:\n"," logging.info('Loading data from h5: ' + data_path)\n"," data_dict = dict()\n"," with h5py.File(data_path, 'r') as hf:\n"," num_samples = hf.attrs[\"num_samples\"]\n"," for key in hf.keys():\n"," data_dict[key] = hf[key][:]\n"," return data_dict, num_samples\n","\n","\n","def split_train_test(train_ddf=None, valid_ddf=None, test_ddf=None, valid_size=0, \n"," test_size=0, split_type=\"sequential\"):\n"," num_samples = len(train_ddf)\n"," train_size = num_samples\n"," instance_IDs = np.arange(num_samples)\n"," if split_type == \"random\":\n"," 
np.random.shuffle(instance_IDs)\n"," if test_size > 0:\n"," if test_size < 1:\n"," test_size = int(num_samples * test_size)\n"," train_size = train_size - test_size\n"," test_ddf = train_ddf.loc[instance_IDs[train_size:], :].reset_index()\n"," instance_IDs = instance_IDs[0:train_size]\n"," if valid_size > 0:\n"," if valid_size < 1:\n"," valid_size = int(num_samples * valid_size)\n"," train_size = train_size - valid_size\n"," valid_ddf = train_ddf.loc[instance_IDs[train_size:], :].reset_index()\n"," instance_IDs = instance_IDs[0:train_size]\n"," if valid_size > 0 or test_size > 0:\n"," train_ddf = train_ddf.loc[instance_IDs, :].reset_index()\n"," return train_ddf, valid_ddf, test_ddf\n","\n","\n","def transform_h5(feature_encoder, ddf, filename, preprocess=False, block_size=0):\n"," def _transform_block(feature_encoder, df_block, filename, preprocess):\n"," if preprocess:\n"," df_block = feature_encoder.preprocess(df_block)\n"," darray_dict = feature_encoder.transform(df_block)\n"," save_h5(darray_dict, os.path.join(feature_encoder.data_dir, filename))\n","\n"," if block_size > 0:\n"," pool = mp.Pool(mp.cpu_count() // 2)\n"," block_id = 0\n"," for idx in range(0, len(ddf), block_size):\n"," df_block = ddf[idx: (idx + block_size)]\n"," pool.apply_async(_transform_block, args=(feature_encoder, \n"," df_block, \n"," filename.replace('.h5', '_part_{}.h5'.format(block_id)),\n"," preprocess))\n"," block_id += 1\n"," pool.close()\n"," pool.join()\n"," else:\n"," _transform_block(feature_encoder, ddf, filename, preprocess)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"pZjzdjby6rSA"},"source":["#@title\n","class Tokenizer(object):\n"," def __init__(self, topk_words=None, na_value=None, min_freq=1, splitter=None, \n"," lower=False, oov_token=0, max_len=0, padding=\"pre\"):\n"," self._topk_words = topk_words\n"," self._na_value = na_value\n"," self._min_freq = min_freq\n"," self._lower = lower\n"," self._splitter = splitter\n"," self.oov_token = oov_token # use 0 for __OOV__\n"," self.vocab = dict()\n"," self.vocab_size = 0 # include oov and padding\n"," self.max_len = max_len\n"," self.padding = padding\n"," self.use_padding = None\n","\n"," def fit(self, texts, use_padding=False):\n"," self.use_padding = use_padding\n"," word_counts = Counter()\n"," if self._splitter is not None: # for sequence\n"," max_len = 0\n"," for text in texts:\n"," if not pd.isnull(text):\n"," text_split = text.split(self._splitter)\n"," max_len = max(max_len, len(text_split))\n"," for text in text_split:\n"," word_counts[text] += 1\n"," if self.max_len == 0:\n"," self.max_len = max_len # use pre-set max_len otherwise\n"," else:\n"," tokens = list(texts)\n"," word_counts = Counter(tokens)\n"," self.build_vocab(word_counts)\n","\n"," def build_vocab(self, word_counts):\n"," # sort to guarantee the determinism of index order\n"," word_counts = sorted(word_counts.items(), key=lambda x: (-x[1], x[0]))\n"," words = []\n"," for token, count in word_counts:\n"," if count >= self._min_freq:\n"," if self._na_value is None or token != self._na_value:\n"," words.append(token.lower() if self._lower else token)\n"," if self._topk_words:\n"," words = words[0:self._topk_words]\n"," self.vocab = dict((token, idx) for idx, token in enumerate(words, 1 + self.oov_token))\n"," self.vocab[\"__OOV__\"] = self.oov_token\n"," if self.use_padding:\n"," self.vocab[\"__PAD__\"] = len(words) + self.oov_token + 1 # use the last index for __PAD__\n"," self.vocab_size = len(self.vocab) + 
self.oov_token\n","\n"," def encode_category(self, categories):\n"," category_indices = [self.vocab.get(x, self.oov_token) for x in categories]\n"," return np.array(category_indices)\n","\n"," def encode_sequence(self, texts):\n"," sequence_list = []\n"," for text in texts:\n"," if pd.isnull(text) or text == '':\n"," sequence_list.append([])\n"," else:\n"," sequence_list.append([self.vocab.get(x, self.oov_token) for x in text.split(self._splitter)])\n"," sequence_list = pad_sequences(sequence_list, maxlen=self.max_len, value=self.vocab_size - 1,\n"," padding=self.padding, truncating=self.padding)\n"," return np.array(sequence_list)\n"," \n"," def load_pretrained_embedding(self, feature_name, key_dtype, pretrain_path, embedding_dim, output_path):\n"," with h5py.File(pretrain_path, 'r') as hf:\n"," keys = hf[\"key\"][:]\n"," if issubclass(keys.dtype.type, key_dtype): # in case mismatch between int and str\n"," keys = keys.astype(key_dtype)\n"," pretrained_vocab = dict(zip(keys, range(len(keys))))\n"," pretrained_emb = hf[\"value\"][:]\n"," # update vocab with pretrained keys, in case new token ids appear in validation or test set\n"," num_new_words = 0\n"," for word in pretrained_vocab.keys():\n"," if word not in self.vocab:\n"," self.vocab[word] = self.vocab.get(\"__PAD__\", self.vocab_size) + num_new_words\n"," num_new_words += 1\n"," self.vocab_size += num_new_words\n"," embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(self.vocab_size, embedding_dim))\n"," if \"__PAD__\" in self.vocab:\n"," self.vocab[\"__PAD__\"] = self.vocab_size - 1\n"," embedding_matrix[-1, :] = 0 # set as zero vector for PAD\n"," for word in pretrained_vocab.keys():\n"," embedding_matrix[self.vocab[word]] = pretrained_emb[pretrained_vocab[word]]\n"," os.makedirs(os.path.dirname(output_path), exist_ok=True)\n"," with h5py.File(output_path, 'a') as hf:\n"," hf.create_dataset(feature_name, data=embedding_matrix)\n","\n"," def load_vocab_from_file(self, vocab_file):\n"," with open(vocab_file, 'r') as fid:\n"," word_counts = json.load(fid)\n"," self.build_vocab(word_counts)\n","\n"," def set_vocab(self, vocab):\n"," self.vocab = vocab\n"," self.vocab_size = len(self.vocab) + self.oov_token\n"," \n"," \n","class Normalizer(object):\n"," def __init__(self, normalizer_name):\n"," if normalizer_name in ['StandardScaler', 'MinMaxScaler']:\n"," self.normalizer = getattr(sklearn_preprocess, normalizer_name)()\n"," else:\n"," raise NotImplementedError('normalizer={}'.format(normalizer_name))\n","\n"," def fit(self, X):\n"," null_index = np.isnan(X)\n"," self.normalizer.fit(X[~null_index].reshape(-1, 1))\n","\n"," def transform(self, X):\n"," return self.normalizer.transform(X.reshape(-1, 1)).flatten()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"JP7NWKskr92V"},"source":["class FeatureMap(object):\n"," def __init__(self, dataset_id, data_dir, query_index, corpus_index, label_name, version=\"pytorch\"):\n"," self.data_dir = data_dir\n"," self.dataset_id = dataset_id\n"," self.version = version\n"," self.num_fields = 0\n"," self.num_features = 0\n"," self.num_items = 0\n"," self.query_index = query_index\n"," self.corpus_index = corpus_index\n"," self.label_name = label_name\n"," self.feature_specs = OrderedDict()\n","\n"," def load(self, json_file):\n"," logging.info(\"Load feature_map from json: \" + json_file)\n"," with open(json_file, \"r\", encoding=\"utf-8\") as fd:\n"," feature_map = json.load(fd, object_pairs_hook=OrderedDict)\n"," if feature_map[\"dataset_id\"] != 
self.dataset_id:\n"," raise RuntimeError(\"dataset_id={} does not match to feature_map!\".format(self.dataset_id))\n"," self.num_fields = feature_map[\"num_fields\"]\n"," self.num_features = feature_map.get(\"num_features\", None)\n"," self.label_name = feature_map.get(\"label_name\", None)\n"," self.feature_specs = OrderedDict(feature_map[\"feature_specs\"])\n","\n"," def save(self, json_file):\n"," logging.info(\"Save feature_map to json: \" + json_file)\n"," os.makedirs(os.path.dirname(json_file), exist_ok=True)\n"," feature_map = OrderedDict()\n"," feature_map[\"dataset_id\"] = self.dataset_id\n"," feature_map[\"num_fields\"] = self.num_fields\n"," feature_map[\"num_features\"] = self.num_features\n"," feature_map[\"num_items\"] = self.num_items\n"," feature_map[\"query_index\"] = self.query_index\n"," feature_map[\"corpus_index\"] = self.corpus_index\n"," feature_map[\"label_name\"] = self.label_name\n"," feature_map[\"feature_specs\"] = self.feature_specs\n"," with open(json_file, \"w\", encoding=\"utf-8\") as fd:\n"," json.dump(feature_map, fd, indent=4)\n","\n"," def get_num_fields(self, feature_source=[]):\n"," if type(feature_source) != list:\n"," feature_source = [feature_source]\n"," num_fields = 0\n"," for feature, feature_spec in self.feature_specs.items():\n"," if not feature_source or feature_spec[\"source\"] in feature_source:\n"," num_fields += 1\n"," return num_fields"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TQEDFSqrsj1w"},"source":["class FeatureEncoder(object):\n"," def __init__(self,\n"," feature_cols=[], \n"," label_col={}, \n"," dataset_id=None, \n"," data_root=\"../data/\", \n"," version=\"pytorch\", \n"," **kwargs):\n"," logging.info(\"Set up feature encoder...\")\n"," self.data_dir = os.path.join(data_root, dataset_id)\n"," self.pickle_file = os.path.join(self.data_dir, \"feature_encoder.pkl\")\n"," self.json_file = os.path.join(self.data_dir, \"feature_map.json\")\n"," self.feature_cols = self._complete_feature_cols(feature_cols)\n"," self.label_col = label_col\n"," self.version = version\n"," self.feature_map = FeatureMap(dataset_id, self.data_dir, kwargs[\"query_index\"], \n"," kwargs[\"corpus_index\"], self.label_col[\"name\"], version)\n"," self.dtype_dict = dict((feat[\"name\"], eval(feat[\"dtype\"]) if type(feat[\"dtype\"]) == str else feat[\"dtype\"]) \n"," for feat in self.feature_cols + [self.label_col])\n"," self.encoders = dict()\n","\n"," def _complete_feature_cols(self, feature_cols):\n"," full_feature_cols = []\n"," for col in feature_cols:\n"," name_or_namelist = col[\"name\"]\n"," if isinstance(name_or_namelist, list):\n"," for _name in name_or_namelist:\n"," _col = col.copy()\n"," _col[\"name\"] = _name\n"," full_feature_cols.append(_col)\n"," else:\n"," full_feature_cols.append(col)\n"," return full_feature_cols\n","\n"," def read_csv(self, data_path, sep=\",\", nrows=None, **kwargs):\n"," if data_path is not None:\n"," logging.info(\"Reading file: \" + data_path)\n"," usecols_fn = lambda x: x in self.dtype_dict\n"," ddf = pd.read_csv(data_path, sep=sep, usecols=usecols_fn, \n"," dtype=object, memory_map=True, nrows=nrows)\n"," return ddf\n"," else:\n"," return None\n","\n"," def preprocess(self, ddf):\n"," logging.info(\"Preprocess feature columns...\")\n"," if self.feature_map.query_index in ddf.columns: # for train/val/test ddf\n"," all_cols = [self.label_col] + [col for col in self.feature_cols[::-1] if col.get(\"source\") != \"item\"]\n"," else: # for item_corpus ddf\n"," all_cols = [col for col in 
self.feature_cols[::-1] if col.get(\"source\") == \"item\"]\n"," for col in all_cols:\n"," name = col[\"name\"]\n"," if name in ddf.columns and ddf[name].isnull().values.any():\n"," ddf[name] = self._fill_na_(col, ddf[name])\n"," if \"preprocess\" in col and col[\"preprocess\"] != \"\":\n"," preprocess_fn = getattr(self, col[\"preprocess\"])\n"," ddf[name] = preprocess_fn(ddf, name)\n"," ddf[name] = ddf[name].astype(self.dtype_dict[name])\n"," active_cols = [col[\"name\"] for col in all_cols if col.get(\"active\") != False]\n"," ddf = ddf.loc[:, active_cols]\n"," return ddf\n","\n"," def _fill_na_(self, col, series):\n"," na_value = col.get(\"na_value\")\n"," if na_value is not None:\n"," return series.fillna(na_value)\n"," elif col[\"dtype\"] in [\"str\", str]:\n"," return series.fillna(\"\")\n"," else:\n"," raise RuntimeError(\"Feature column={} requires to assign na_value!\".format(col[\"name\"]))\n","\n"," def fit(self, train_ddf, corpus_ddf, min_categr_count=1, num_buckets=10, **kwargs): \n"," logging.info(\"Fit feature encoder...\") \n"," self.feature_map.num_items = len(corpus_ddf)\n"," train_ddf = train_ddf.join(corpus_ddf, on=self.feature_map.corpus_index)\n"," for col in self.feature_cols:\n"," name = col[\"name\"]\n"," if col[\"active\"]:\n"," self.feature_map.num_fields += 1\n"," logging.info(\"Processing column: {}\".format(col))\n"," if col[\"type\"] == \"index\":\n"," self.fit_index_col(col)\n"," elif col[\"type\"] == \"numeric\":\n"," self.fit_numeric_col(col, train_ddf[name].values)\n"," elif col[\"type\"] == \"categorical\":\n"," self.fit_categorical_col(col, train_ddf[name].values, \n"," min_categr_count=min_categr_count,\n"," num_buckets=num_buckets)\n"," elif col[\"type\"] == \"sequence\":\n"," self.fit_sequence_col(col, train_ddf[name].values, \n"," min_categr_count=min_categr_count)\n"," else:\n"," raise NotImplementedError(\"feature_col={}\".format(feature_col))\n"," self.save_pickle(self.pickle_file)\n"," self.feature_map.save(self.json_file)\n"," logging.info(\"Set feature encoder done.\")\n","\n"," def fit_index_col(self, feature_col):\n"," name = feature_col[\"name\"]\n"," feature_type = feature_col[\"type\"]\n"," feature_source = feature_col.get(\"source\", \"\")\n"," self.feature_map.feature_specs[name] = {\"source\": feature_source,\n"," \"type\": feature_type} \n","\n"," def fit_numeric_col(self, feature_col, data_vector):\n"," name = feature_col[\"name\"]\n"," feature_type = feature_col[\"type\"]\n"," feature_source = feature_col.get(\"source\", \"\")\n"," self.feature_map.feature_specs[name] = {\"source\": feature_source,\n"," \"type\": feature_type}\n"," if \"embedding_callback\" in feature_col:\n"," self.feature_map.feature_specs[name][\"embedding_callback\"] = feature_col[\"embedding_callback\"]\n"," if \"normalizer\" in feature_col:\n"," normalizer = Normalizer(feature_col[\"normalizer\"])\n"," normalizer.fit(data_vector)\n"," self.encoders[name + \"_normalizer\"] = normalizer\n"," self.feature_map.num_features += 1\n"," \n"," def fit_categorical_col(self, feature_col, data_vector, min_categr_count=1, num_buckets=10):\n"," name = feature_col[\"name\"]\n"," feature_type = feature_col[\"type\"]\n"," feature_source = feature_col.get(\"source\", \"\")\n"," min_categr_count = feature_col.get(\"min_categr_count\", min_categr_count)\n"," self.feature_map.feature_specs[name] = {\"source\": feature_source,\n"," \"type\": feature_type,\n"," \"min_categr_count\": min_categr_count}\n"," if \"embedding_callback\" in feature_col:\n"," 
self.feature_map.feature_specs[name][\"embedding_callback\"] = feature_col[\"embedding_callback\"]\n"," if \"embedding_dim\" in feature_col:\n"," self.feature_map.feature_specs[name][\"embedding_dim\"] = feature_col[\"embedding_dim\"]\n"," if \"category_encoder\" not in feature_col:\n"," tokenizer = Tokenizer(min_freq=min_categr_count, \n"," na_value=feature_col.get(\"na_value\", \"\"))\n"," if \"share_embedding\" in feature_col:\n"," self.feature_map.feature_specs[name][\"share_embedding\"] = feature_col[\"share_embedding\"]\n"," tokenizer.set_vocab(self.encoders[\"{}_tokenizer\".format(feature_col[\"share_embedding\"])].vocab)\n"," else:\n"," if self._whether_share_emb_with_sequence(name):\n"," tokenizer.fit(data_vector, use_padding=True)\n"," if \"pretrained_emb\" not in feature_col:\n"," self.feature_map.feature_specs[name][\"padding_idx\"] = tokenizer.vocab_size - 1\n"," else:\n"," tokenizer.fit(data_vector, use_padding=False)\n"," if \"pretrained_emb\" in feature_col:\n"," logging.info(\"Loading pretrained embedding: \" + name)\n"," self.feature_map.feature_specs[name][\"pretrained_emb\"] = \"pretrained_{}.h5\".format(name)\n"," self.feature_map.feature_specs[name][\"freeze_emb\"] = feature_col.get(\"freeze_emb\", True)\n"," tokenizer.load_pretrained_embedding(name,\n"," self.dtype_dict[name],\n"," feature_col[\"pretrained_emb\"], \n"," feature_col[\"embedding_dim\"],\n"," os.path.join(self.data_dir, \"pretrained_{}.h5\".format(name)))\n"," if tokenizer.use_padding: # update to account pretrained keys\n"," self.feature_map.feature_specs[name][\"padding_idx\"] = tokenizer.vocab_size - 1\n"," self.encoders[name + \"_tokenizer\"] = tokenizer\n"," self.feature_map.feature_specs[name][\"vocab_size\"] = tokenizer.vocab_size\n"," self.feature_map.num_features += tokenizer.vocab_size\n"," else:\n"," category_encoder = feature_col[\"category_encoder\"]\n"," self.feature_map.feature_specs[name][\"category_encoder\"] = category_encoder\n"," if category_encoder == \"quantile_bucket\": # transform numeric value to bucket\n"," num_buckets = feature_col.get(\"num_buckets\", num_buckets)\n"," qtf = sklearn_preprocess.QuantileTransformer(n_quantiles=num_buckets + 1)\n"," qtf.fit(data_vector)\n"," boundaries = qtf.quantiles_[1:-1]\n"," self.feature_map.feature_specs[name][\"vocab_size\"] = num_buckets\n"," self.feature_map.num_features += num_buckets\n"," self.encoders[name + \"_boundaries\"] = boundaries\n"," elif category_encoder == \"hash_bucket\":\n"," num_buckets = feature_col.get(\"num_buckets\", num_buckets)\n"," uniques = Counter(data_vector)\n"," num_buckets = min(num_buckets, len(uniques))\n"," self.feature_map.feature_specs[name][\"vocab_size\"] = num_buckets\n"," self.encoders[name + \"_num_buckets\"] = num_buckets\n"," self.feature_map.num_features += num_buckets\n"," else:\n"," raise NotImplementedError(\"category_encoder={} not supported.\".format(category_encoder))\n","\n"," def fit_sequence_col(self, feature_col, data_vector, min_categr_count=1):\n"," name = feature_col[\"name\"]\n"," feature_type = feature_col[\"type\"]\n"," feature_source = feature_col.get(\"source\", \"\")\n"," min_categr_count = feature_col.get(\"min_categr_count\", min_categr_count)\n"," self.feature_map.feature_specs[name] = {\"source\": feature_source,\n"," \"type\": feature_type,\n"," \"min_categr_count\": min_categr_count}\n"," embedding_callback = feature_col.get(\"embedding_callback\", \"layers.MaskedAveragePooling()\")\n"," if embedding_callback not in [None, \"null\", \"None\", \"none\"]:\n"," 
self.feature_map.feature_specs[name][\"embedding_callback\"] = embedding_callback\n"," splitter = feature_col.get(\"splitter\", \" \")\n"," na_value = feature_col.get(\"na_value\", \"\")\n"," max_len = feature_col.get(\"max_len\", 0)\n"," padding = feature_col.get(\"padding\", \"post\") # \"post\" or \"pre\"\n"," tokenizer = Tokenizer(min_freq=min_categr_count, splitter=splitter, \n"," na_value=na_value, max_len=max_len, padding=padding)\n"," if \"share_embedding\" in feature_col:\n"," self.feature_map.feature_specs[name][\"share_embedding\"] = feature_col[\"share_embedding\"]\n"," tokenizer.set_vocab(self.encoders[\"{}_tokenizer\".format(feature_col[\"share_embedding\"])].vocab)\n"," else:\n"," tokenizer.fit(data_vector, use_padding=True)\n"," if \"pretrained_emb\" in feature_col:\n"," logging.info(\"Loading pretrained embedding: \" + name)\n"," self.feature_map.feature_specs[name][\"pretrained_emb\"] = \"pretrained_{}.h5\".format(name)\n"," self.feature_map.feature_specs[name][\"freeze_emb\"] = feature_col.get(\"freeze_emb\", True)\n"," tokenizer.load_pretrained_embedding(name,\n"," self.dtype_dict[name],\n"," feature_col[\"pretrained_emb\"], \n"," feature_col[\"embedding_dim\"],\n"," os.path.join(self.data_dir, \"pretrained_{}.h5\".format(name)))\n"," self.encoders[name + \"_tokenizer\"] = tokenizer\n"," self.feature_map.feature_specs[name].update({\"padding_idx\": tokenizer.vocab_size - 1,\n"," \"vocab_size\": tokenizer.vocab_size,\n"," \"max_len\": tokenizer.max_len})\n"," self.feature_map.num_features += tokenizer.vocab_size\n","\n"," def transform(self, ddf):\n"," logging.info(\"Transform feature columns...\")\n"," data_dict = dict()\n"," for feature, feature_spec in self.feature_map.feature_specs.items():\n"," if feature in ddf.columns:\n"," feature_type = feature_spec[\"type\"]\n"," data_vector = ddf.loc[:, feature].values\n"," if feature_type == \"index\":\n"," data_dict[feature] = data_vector\n"," elif feature_type == \"numeric\":\n"," data_vector = data_vector.astype(float)\n"," normalizer = self.encoders.get(feature + \"_normalizer\")\n"," if normalizer:\n"," data_vector = normalizer.transform(data_vector)\n"," data_dict[feature] = data_vector\n"," elif feature_type == \"categorical\":\n"," category_encoder = feature_spec.get(\"category_encoder\")\n"," if category_encoder is None:\n"," data_dict[feature] = self.encoders.get(feature + \"_tokenizer\").encode_category(data_vector)\n"," elif encoder == \"numeric_bucket\":\n"," raise NotImplementedError\n"," elif encoder == \"hash_bucket\":\n"," raise NotImplementedError\n"," elif feature_type == \"sequence\":\n"," data_dict[feature] = self.encoders.get(feature + \"_tokenizer\").encode_sequence(data_vector)\n"," label = self.label_col[\"name\"]\n"," if label in ddf.columns:\n"," data_dict[label] = ddf.loc[:, label].values.astype(float)\n"," return data_dict\n","\n"," def _whether_share_emb_with_sequence(self, feature):\n"," for col in self.feature_cols:\n"," if col.get(\"share_embedding\", None) == feature and col[\"type\"] == \"sequence\":\n"," return True\n"," return False\n","\n"," def load_pickle(self, pickle_file=None):\n"," \"\"\" Load feature encoder from cache \"\"\"\n"," if pickle_file is None:\n"," pickle_file = self.pickle_file\n"," logging.info(\"Load feature_encoder from pickle: \" + pickle_file)\n"," if os.path.exists(pickle_file):\n"," pickled_feature_encoder = pickle.load(open(pickle_file, \"rb\"))\n"," if pickled_feature_encoder.feature_map.dataset_id == self.feature_map.dataset_id:\n"," 
pickled_feature_encoder.version = self.version\n"," return pickled_feature_encoder\n"," raise IOError(\"pickle_file={} not valid.\".format(pickle_file))\n","\n"," def save_pickle(self, pickle_file):\n"," logging.info(\"Pickle feature_encode: \" + pickle_file)\n"," if not os.path.exists(os.path.dirname(pickle_file)):\n"," os.makedirs(os.path.dirname(pickle_file))\n"," pickle.dump(self, open(pickle_file, \"wb\"))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CanPADhy51eJ"},"source":["feature_encoder = FeatureEncoder(**Args.__dict__)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"c64_Frz5s4df"},"source":["def build_dataset(feature_encoder, item_corpus=None, train_data=None, valid_data=None, \n"," test_data=None, valid_size=0, test_size=0, split_type=\"sequential\", **kwargs):\n"," \"\"\" Build feature_map and transform h5 data \"\"\"\n"," \n"," # Load csv data\n"," train_ddf = feature_encoder.read_csv(train_data, **kwargs)\n"," valid_ddf = None\n"," test_ddf = None\n","\n"," # Split data for train/validation/test\n"," if valid_size > 0 or test_size > 0:\n"," valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)\n"," test_ddf = feature_encoder.read_csv(test_data, **kwargs)\n"," train_ddf, valid_ddf, test_ddf = split_train_test(train_ddf, valid_ddf, test_ddf, \n"," valid_size, test_size, split_type)\n","\n"," # fit feature_encoder\n"," corpus_ddf = feature_encoder.read_csv(item_corpus, **kwargs)\n"," corpus_ddf = feature_encoder.preprocess(corpus_ddf)\n"," train_ddf = feature_encoder.preprocess(train_ddf)\n"," feature_encoder.fit(train_ddf, corpus_ddf, **kwargs)\n","\n"," # transform corpus_ddf\n"," item_corpus_dict = feature_encoder.transform(corpus_ddf)\n"," save_h5(item_corpus_dict, os.path.join(feature_encoder.data_dir, 'item_corpus.h5'))\n"," del item_corpus_dict, corpus_ddf\n"," gc.collect()\n","\n"," # transform train_ddf\n"," block_size = int(kwargs.get(\"data_block_size\", 0)) # Num of samples in a data block\n"," transform_h5(feature_encoder, train_ddf, 'train.h5', preprocess=False, block_size=block_size)\n"," del train_ddf\n"," gc.collect()\n","\n"," # Transfrom valid_ddf\n"," if valid_ddf is None and (valid_data is not None):\n"," valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)\n"," if valid_ddf is not None:\n"," transform_h5(feature_encoder, valid_ddf, 'valid.h5', preprocess=True, block_size=block_size)\n"," del valid_ddf\n"," gc.collect()\n","\n"," # Transfrom test_ddf\n"," if test_ddf is None and (test_data is not None):\n"," test_ddf = feature_encoder.read_csv(test_data, **kwargs)\n"," if test_ddf is not None:\n"," transform_h5(feature_encoder, test_ddf, 'test.h5', preprocess=True, block_size=block_size)\n"," del test_ddf\n"," gc.collect()\n"," logging.info(\"Transform csv data to h5 done.\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"AOK0I1-A6Xrs"},"source":["build_dataset(feature_encoder, **Args.__dict__)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ONlQCu5R66-B"},"source":["#@title\n","class TrainDataset(Dataset):\n"," def __init__(self, feature_map, data_path, item_corpus):\n"," self.data_dict, self.num_samples = load_h5(data_path)\n"," self.item_corpus_dict, self.num_items = load_h5(item_corpus)\n"," self.labels = self.data_dict[feature_map.label_name]\n"," self.pos_item_indexes = self.data_dict[feature_map.corpus_index]\n"," self.all_item_indexes = self.data_dict[feature_map.corpus_index]\n"," \n"," def __getitem__(self, 
index):\n"," user_dict = self.slice_array_dict(self.data_dict, index)\n"," item_indexes = self.all_item_indexes[index, :]\n"," item_dict = self.slice_array_dict(self.item_corpus_dict, item_indexes)\n"," label = self.labels[index]\n"," return user_dict, item_dict, label, item_indexes\n"," \n"," def __len__(self):\n"," return self.num_samples\n","\n"," def slice_array_dict(self, array_dict, slice_index):\n"," return dict((k, v[slice_index]) for k, v in array_dict.items())\n","\n","\n","def get_user2items_dict(data_dict, feature_map):\n"," user2items_dict = defaultdict(list)\n"," for query_index, corpus_index in zip(data_dict[feature_map.query_index], \n"," data_dict[feature_map.corpus_index]):\n"," user2items_dict[query_index].append(corpus_index)\n"," return user2items_dict\n","\n","\n","def collate_fn_unique(batch): \n"," # TODO: check correctness\n"," user_dict, item_dict, labels, item_indexes = default_collate(batch)\n"," num_negs = item_indexes.size(1) - 1\n"," unique, inverse_indexes = torch.unique(item_indexes.flatten(), return_inverse=True, sorted=True)\n"," perm = torch.arange(inverse_indexes.size(0), dtype=inverse_indexes.dtype, device=inverse_indexes.device)\n"," inverse_indexes, perm = inverse_indexes.flip([0]), perm.flip([0])\n"," unique_indexes = inverse_indexes.new_empty(unique.size(0)).scatter_(0, inverse_indexes, perm) # obtain return_indicies in np.unique\n"," # reshape item data with (b*(num_neg + 1) x input_dim)\n"," for k, v in item_dict.items():\n"," item_dict[k] = v.flatten(end_dim=1)[unique_indexes]\n"," # add negative labels\n"," labels = torch.cat([labels.view(-1, 1).float(), torch.zeros((labels.size(0), num_negs))], dim=1)\n"," return user_dict, item_dict, labels, inverse_indexes\n","\n","\n","def collate_fn(batch):\n"," user_dict, item_dict, labels, item_indexes = default_collate(batch)\n"," num_negs = item_indexes.size(1) - 1\n"," # reshape item data with (b*(num_neg + 1) x input_dim)\n"," for k, v in item_dict.items():\n"," item_dict[k] = v.flatten(end_dim=1)\n"," # add negative labels\n"," labels = torch.cat([labels.view(-1, 1).float(), torch.zeros((labels.size(0), num_negs))], dim=1)\n"," return user_dict, item_dict, labels, None\n","\n","\n","def sampling_block(num_items, block_query_indexes, num_negs, user2items_dict, \n"," sampling_probs=None, ignore_pos_items=False, seed=None, dump_path=None):\n"," if seed is not None:\n"," np.random.seed(seed) # used in multiprocessing\n"," if sampling_probs is None:\n"," sampling_probs = np.ones(num_items) / num_items # uniform sampling\n"," if ignore_pos_items:\n"," sampled_items = []\n"," for query_index in block_query_indexes:\n"," pos_items = user2items_dict[query_index]\n"," probs = np.array(sampling_probs)\n"," probs[pos_items] = 0\n"," probs = probs / np.sum(probs) # renomalize to sum 1\n"," sampled_items.append(np.random.choice(num_items, size=num_negs, replace=True, p=probs))\n"," sampled_array = np.array(sampled_items)\n"," else:\n"," sampled_array = np.random.choice(num_items,\n"," size=(len(block_query_indexes), num_negs), \n"," replace=True)\n"," if dump_path is not None:\n"," # To fix bug in multiprocessing: https://github.com/xue-pai/Open-CF-Benchmarks/issues/1\n"," pickle_array(sampled_array, dump_path)\n"," else:\n"," return sampled_array\n","\n","\n","def pickle_array(array, path):\n"," with open(path, \"wb\") as fout:\n"," pickle.dump(array, fout, pickle.HIGHEST_PROTOCOL)\n","\n","\n","def load_pickled_array(path):\n"," with open(path, \"rb\") as fin:\n"," return pickle.load(fin)\n","\n","\n","class 
TrainGenerator(DataLoader):\n"," # reference https://cloud.tencent.com/developer/article/1010247\n"," def __init__(self, feature_map, data_path, item_corpus, batch_size=32, shuffle=True, \n"," num_workers=1, num_negs=0, compress_duplicate_items=False, **kwargs):\n"," if type(data_path) == list:\n"," data_path = data_path[0]\n"," self.num_blocks = 1\n"," self.num_negs = num_negs\n"," self.dataset = TrainDataset(feature_map, data_path, item_corpus)\n"," super(TrainGenerator, self).__init__(dataset=self.dataset, batch_size=batch_size,\n"," shuffle=shuffle, num_workers=num_workers,\n"," collate_fn=collate_fn_unique if compress_duplicate_items else collate_fn)\n"," self.user2items_dict = get_user2items_dict(self.dataset.data_dict, feature_map)\n"," self.query_indexes = self.dataset.data_dict[feature_map.query_index]\n"," # delete some columns to speed up batch generator\n"," del self.dataset.data_dict[feature_map.query_index]\n"," del self.dataset.data_dict[feature_map.corpus_index]\n"," del self.dataset.data_dict[feature_map.label_name]\n"," self.num_samples = len(self.dataset)\n"," self.num_batches = int(np.ceil(self.num_samples * 1.0 / batch_size))\n"," self.sampling_num_process = kwargs.get(\"sampling_num_process\", 1)\n"," self.ignore_pos_items = kwargs.get(\"ignore_pos_items\", False)\n"," self.fix_sampling_seeds = kwargs.get(\"fix_sampling_seeds\", True)\n","\n"," def __iter__(self):\n"," self.negative_sampling()\n"," iter = super(TrainGenerator, self).__iter__()\n"," while True:\n"," yield next(iter) # a batch iterator\n","\n"," def __len__(self):\n"," return self.num_batches\n","\n"," def negative_sampling(self):\n"," if self.num_negs > 0:\n"," logging.info(\"Negative sampling num_negs={}\".format(self.num_negs))\n"," sampling_probs = None # set it to item popularity when using importance sampling\n"," if self.sampling_num_process > 1:\n"," chunked_query_indexes = np.array_split(self.query_indexes, self.sampling_num_process)\n"," if self.fix_sampling_seeds:\n"," seeds = np.random.randint(1000000, size=self.sampling_num_process)\n"," else:\n"," seeds = [None] * self.sampling_num_process\n"," pool = mp.Pool(self.sampling_num_process)\n"," block_result = []\n"," os.makedirs(\"./tmp/pid_{}/\".format(os.getpid()), exist_ok=True)\n"," dump_paths = [\"./tmp/pid_{}/part_{}.pkl\".format(os.getpid(), idx) for idx in range(len(chunked_query_indexes))]\n"," for idx, block_query_indexes in enumerate(chunked_query_indexes):\n"," pool.apply_async(sampling_block, args=(self.dataset.num_items, \n"," block_query_indexes, \n"," self.num_negs, \n"," self.user2items_dict, \n"," sampling_probs, \n"," self.ignore_pos_items,\n"," seeds[idx],\n"," dump_paths[idx]))\n"," pool.close()\n"," pool.join()\n"," block_result = [load_pickled_array(dump_paths[idx]) for idx in range(len(chunked_query_indexes))]\n"," shutil.rmtree(\"./tmp/pid_{}/\".format(os.getpid()))\n"," neg_item_indexes = np.vstack(block_result)\n"," else:\n"," neg_item_indexes = sampling_block(self.dataset.num_items, \n"," self.query_indexes, \n"," self.num_negs, \n"," self.user2items_dict, \n"," sampling_probs,\n"," self.ignore_pos_items)\n"," self.dataset.all_item_indexes = np.hstack([self.dataset.pos_item_indexes.reshape(-1, 1), \n"," neg_item_indexes])\n"," logging.info(\"Negative sampling done\")\n","\n","\n","class TestDataset(Dataset):\n"," def __init__(self, data_path):\n"," self.data_dict, self.num_samples = load_h5(data_path)\n","\n"," def __getitem__(self, index):\n"," batch_dict = self.slice_array_dict(index)\n"," return batch_dict\n"," 
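# each item from TestDataset is a dict of feature arrays sliced at the given row index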
\n"," def __len__(self):\n"," return self.num_samples\n","\n"," def slice_array_dict(self, slice_index):\n"," return dict((k, v[slice_index]) for k, v in self.data_dict.items())\n","\n","\n","class TestGenerator(object):\n"," def __init__(self, feature_map, data_path, item_corpus, batch_size=32, shuffle=False, \n"," num_workers=1, **kwargs):\n"," if type(data_path) == list:\n"," data_path = data_path[0]\n"," self.num_blocks = 1\n"," user_dataset = TestDataset(data_path)\n"," self.user2items_dict = get_user2items_dict(user_dataset.data_dict, feature_map)\n"," # pick users of unique query_index\n"," self.query_indexes, unique_rows = np.unique(user_dataset.data_dict[feature_map.query_index], \n"," return_index=True)\n"," user_dataset.num_samples = len(unique_rows)\n"," self.num_samples = len(user_dataset)\n"," # delete some columns to speed up batch generator\n"," del user_dataset.data_dict[feature_map.query_index]\n"," del user_dataset.data_dict[feature_map.corpus_index]\n"," del user_dataset.data_dict[feature_map.label_name]\n"," for k, v in user_dataset.data_dict.items():\n"," user_dataset.data_dict[k] = v[unique_rows]\n"," item_dataset = TestDataset(item_corpus)\n"," self.user_loader = DataLoader(dataset=user_dataset, batch_size=batch_size,\n"," shuffle=shuffle, num_workers=num_workers)\n"," self.item_loader = DataLoader(dataset=item_dataset, batch_size=batch_size,\n"," shuffle=shuffle, num_workers=num_workers)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-HmZAEs8D7v"},"source":["### h5_generator"]},{"cell_type":"code","metadata":{"id":"FtvWNHRJs4aE"},"source":["def h5_generator(feature_map, stage=\"both\", train_data=None, valid_data=None, test_data=None,\n"," item_corpus=None, batch_size=32, num_negs=10, shuffle=True, **kwargs):\n"," logging.info(\"Loading data...\")\n"," train_gen = None\n"," valid_gen = None\n"," test_gen = None\n"," if stage in [\"both\", \"train\"]:\n"," train_blocks = glob.glob(train_data)\n"," valid_blocks = glob.glob(valid_data)\n"," assert len(train_blocks) > 0 and len(valid_blocks) > 0, \"invalid data files or paths.\"\n"," train_gen = TrainGenerator(feature_map, train_blocks, item_corpus, batch_size=batch_size, \n"," num_negs=num_negs, shuffle=shuffle, **kwargs)\n"," valid_gen = TestGenerator(feature_map, valid_blocks, item_corpus, batch_size=batch_size, \n"," shuffle=False, **kwargs)\n"," logging.info(\"Train samples: total/{:d}, blocks/{:.0f}\".format(train_gen.num_samples, train_gen.num_blocks))\n"," logging.info(\"Validation samples: total/{:d}, blocks/{:.0f}\".format(valid_gen.num_samples, valid_gen.num_blocks))\n"," if stage == \"train\":\n"," logging.info(\"Loading train data done.\")\n"," return train_gen, valid_gen\n","\n"," if stage in [\"both\", \"test\"]:\n"," test_blocks = glob.glob(test_data)\n"," test_gen = TestGenerator(feature_map, test_blocks, item_corpus, batch_size=batch_size, \n"," shuffle=False, **kwargs)\n"," logging.info(\"Test samples: total/{:d}, blocks/{:.0f}\".format(test_gen.num_samples, test_gen.num_blocks))\n"," if stage == \"test\":\n"," logging.info(\"Loading test data done.\")\n"," return test_gen\n","\n"," logging.info(\"Loading data done.\")\n"," return train_gen, valid_gen, test_gen"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"abIbck5i7GmK"},"source":["args = Args()\n","args.train_data = '/content/train.h5'\n","args.valid_data = '/content/valid.h5'\n","# args.test_data = '/content/valid.h5'\n","args.item_corpus = '/content/item_corpus.h5'\n","# train_gen, 
valid_gen, test_gen = h5_generator(feature_encoder.feature_map, **args.__dict__)\n","train_gen, valid_gen = h5_generator(feature_encoder.feature_map, stage='train', **args.__dict__)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9qXVmVj8c8Fu","cellView":"form"},"source":["#@title\n","class SoftmaxCrossEntropyLoss(nn.Module):\n"," def __init__(self):\n"," \"\"\"\n"," Softmax cross-entropy over one positive and num_negs negative scores.\n"," \"\"\"\n"," super(SoftmaxCrossEntropyLoss, self).__init__()\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_true: Labels\n"," :param y_pred: Predicted result.\n"," \"\"\"\n"," probs = F.softmax(y_pred, dim=1)\n"," hit_probs = probs[:, 0]\n"," loss = -torch.log(hit_probs).mean()\n"," return loss\n","\n","\n","class CosineContrastiveLoss(nn.Module):\n"," def __init__(self, margin=0, negative_weight=None):\n"," \"\"\"\n"," :param margin: float, margin in CosineContrastiveLoss\n"," :param negative_weight: float, the weight set to the negative samples. When negative_weight=None, the\n"," negative losses are summed, i.e. the total negative weight equals num_negs\n"," \"\"\"\n"," super(CosineContrastiveLoss, self).__init__()\n"," self._margin = margin\n"," self._negative_weight = negative_weight\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_pred: predicted values of shape (batch_size, 1 + num_negs) \n"," :param y_true: true labels of shape (batch_size, 1 + num_negs)\n"," \"\"\"\n"," pos_logits = y_pred[:, 0]\n"," pos_loss = torch.relu(1 - pos_logits)\n"," neg_logits = y_pred[:, 1:]\n"," neg_loss = torch.relu(neg_logits - self._margin)\n"," if self._negative_weight:\n"," loss = pos_loss + neg_loss.mean(dim=-1) * self._negative_weight\n"," else:\n"," loss = pos_loss + neg_loss.sum(dim=-1)\n"," return loss.mean()\n","\n","\n","class MSELoss(nn.Module):\n"," def __init__(self):\n"," super(MSELoss, self).__init__()\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_pred: predicted values of shape (batch_size, 1 + num_negs) \n"," :param y_true: true labels of shape (batch_size, 1 + num_negs)\n"," \"\"\"\n"," pos_logits = y_pred[:, 0]\n"," pos_loss = torch.pow(pos_logits - 1, 2) / 2\n"," neg_logits = y_pred[:, 1:]\n"," neg_loss = torch.pow(neg_logits, 2).sum(dim=-1) / 2\n"," loss = pos_loss + neg_loss\n"," return loss.mean()\n","\n","\n","class PairwiseLogisticLoss(nn.Module):\n"," def __init__(self):\n"," super(PairwiseLogisticLoss, self).__init__()\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_true: Labels\n"," :param y_pred: Predicted result.\n"," \"\"\"\n"," pos_logits = y_pred[:, 0].unsqueeze(-1)\n"," neg_logits = y_pred[:, 1:]\n"," logits_diff = pos_logits - neg_logits\n"," loss = -torch.log(torch.sigmoid(logits_diff)).mean()\n"," return loss\n","\n","\n","class PairwiseMarginLoss(nn.Module):\n"," def __init__(self, margin=1.0):\n"," \"\"\"\n"," :param margin: float, margin of the pairwise hinge loss.\n"," \"\"\"\n"," super(PairwiseMarginLoss, self).__init__()\n"," self._margin = margin\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_true: Labels\n"," :param y_pred: Predicted result.\n"," \"\"\"\n"," pos_logits = y_pred[:, 0].unsqueeze(-1)\n"," neg_logits = y_pred[:, 1:]\n"," loss = torch.relu(self._margin + neg_logits - pos_logits).mean()\n"," return loss\n","\n","\n","class SigmoidCrossEntropyLoss(nn.Module):\n"," def __init__(self):\n"," \"\"\"\n"," Binary cross-entropy over the positive and negative logits.\n"," \"\"\"\n"," 
super(SigmoidCrossEntropyLoss, self).__init__()\n","\n"," def forward(self, y_pred, y_true):\n"," \"\"\"\n"," :param y_true: Labels\n"," :param y_pred: Predicted result\n"," \"\"\"\n"," logits = y_pred.flatten()\n"," labels = y_true.flatten()\n"," loss = F.binary_cross_entropy_with_logits(logits, labels, reduction=\"sum\")\n"," return loss"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"hOlH9BWUuLqV","cellView":"form"},"source":["#@title\n","class BaseModel(nn.Module):\n"," def __init__(self, \n"," feature_map, \n"," model_id=\"BaseModel\", \n"," gpu=-1, \n"," monitor=\"AUC\", \n"," save_best_only=True, \n"," monitor_mode=\"max\", \n"," patience=2, \n"," eval_interval_epochs=1, \n"," embedding_regularizer=None, \n"," net_regularizer=None, \n"," reduce_lr_on_plateau=True, \n"," embedding_initializer=\"lambda w: nn.init.normal_(w, std=1e-4)\", \n"," num_negs=0,\n"," **kwargs):\n"," super(BaseModel, self).__init__()\n"," self.device = set_device(gpu)\n"," self.feature_map = feature_map\n"," self._monitor = Monitor(kv=monitor)\n"," self._monitor_mode = monitor_mode\n"," self._patience = patience\n"," self._eval_interval_epochs = eval_interval_epochs # float acceptable\n"," self._save_best_only = save_best_only\n"," self._embedding_regularizer = embedding_regularizer\n"," self._net_regularizer = net_regularizer\n"," self._reduce_lr_on_plateau = reduce_lr_on_plateau\n"," self._embedding_initializer = embedding_initializer\n"," self.model_id = model_id\n"," self.model_dir = os.path.join(kwargs[\"model_root\"], feature_map.dataset_id)\n"," self.checkpoint = os.path.abspath(os.path.join(self.model_dir, self.model_id + \".model\"))\n"," self._validation_metrics = kwargs[\"metrics\"]\n"," self._verbose = kwargs[\"verbose\"]\n"," self.num_negs = num_negs\n","\n"," def compile(self, lr=1e-3, optimizer=None, loss=None, **kwargs):\n"," try:\n"," self.optimizer = set_optimizer(optimizer)(self.parameters(), lr=lr)\n"," except:\n"," raise NotImplementedError(\"optimizer={} is not supported.\".format(optimizer))\n"," if loss == \"SigmoidCrossEntropyLoss\":\n"," self.loss_fn = SigmoidCrossEntropyLoss()\n"," elif loss == \"PairwiseLogisticLoss\":\n"," self.loss_fn = PairwiseLogisticLoss()\n"," elif loss == \"SoftmaxCrossEntropyLoss\":\n"," self.loss_fn = SoftmaxCrossEntropyLoss()\n"," elif loss == \"PairwiseMarginLoss\":\n"," self.loss_fn = PairwiseMarginLoss(margin=kwargs.get(\"margin\", 1))\n"," elif loss == \"MSELoss\":\n"," self.loss_fn = MSELoss()\n"," elif loss == \"CosineContrastiveLoss\":\n"," self.loss_fn = CosineContrastiveLoss(margin=kwargs.get(\"margin\", 0),\n"," negative_weight=kwargs.get(\"negative_weight\"))\n"," else:\n"," raise NotImplementedError(\"loss={} is not supported.\".format(loss))\n"," self.apply(self.init_weights)\n"," self.to(device=self.device)\n","\n"," def get_total_loss(self, y_pred, y_true):\n"," # y_pred: N x (1 + num_negs) \n"," # y_true: N x (1 + num_negs) \n"," y_true = y_true.float().to(self.device)\n"," total_loss = self.loss_fn(y_pred, y_true)\n"," if self._embedding_regularizer or self._net_regularizer:\n"," emb_reg = set_regularizer(self._embedding_regularizer)\n"," net_reg = set_regularizer(self._net_regularizer)\n"," for name, param in self.named_parameters():\n"," if param.requires_grad:\n"," if \"embedding_layer\" in name:\n"," if self._embedding_regularizer:\n"," for emb_p, emb_lambda in emb_reg:\n"," total_loss += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p\n"," else:\n"," if self._net_regularizer:\n"," for net_p, 
net_lambda in net_reg:\n"," total_loss += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p\n"," return total_loss\n","\n"," def init_weights(self, m):\n"," if type(m) == nn.ModuleDict:\n"," for k, v in m.items():\n"," if type(v) == nn.Embedding:\n"," if \"pretrained_emb\" in self.feature_map.feature_specs[k]: # skip pretrained\n"," continue\n"," try:\n"," initialize_emb = eval(self._embedding_initializer)\n"," if v.padding_idx is not None:\n"," # using the last index as padding_idx\n"," initialize_emb(v.weight[0:-1, :])\n"," else:\n"," initialize_emb(v.weight)\n"," except:\n"," raise NotImplementedError(\"embedding_initializer={} is not supported.\"\\\n"," .format(self._embedding_initializer))\n"," elif type(v) == nn.Linear:\n"," nn.init.xavier_normal_(v.weight)\n"," if v.bias is not None:\n"," v.bias.data.fill_(0)\n"," elif type(m) == nn.Linear:\n"," nn.init.xavier_normal_(m.weight)\n"," if m.bias is not None:\n"," m.bias.data.fill_(0)\n"," \n"," def to_device(self, inputs):\n"," self.batch_size = 0\n"," for k in inputs.keys():\n"," inputs[k] = inputs[k].to(self.device)\n"," if self.batch_size < 1:\n"," self.batch_size = inputs[k].size(0)\n"," return inputs\n","\n"," def on_batch_end(self, train_generator, batch_index, logs={}):\n"," self._total_batches += 1\n"," if (batch_index + 1) % self._eval_interval_batches == 0 or (batch_index + 1) % self._batches_per_epoch == 0:\n"," val_logs = self.evaluate(train_generator, self.valid_gen)\n"," epoch = round(float(self._total_batches) / self._batches_per_epoch, 2)\n"," self.checkpoint_and_earlystop(epoch, val_logs)\n"," logging.info(\"--- {}/{} batches finished ---\".format(batch_index + 1, self._batches_per_epoch))\n","\n"," def reduce_learning_rate(self, factor=0.1, min_lr=1e-6):\n"," for param_group in self.optimizer.param_groups:\n"," reduced_lr = max(param_group[\"lr\"] * factor, min_lr)\n"," param_group[\"lr\"] = reduced_lr\n"," return reduced_lr\n","\n"," def checkpoint_and_earlystop(self, epoch, logs, min_delta=1e-6):\n"," monitor_value = self._monitor.get_value(logs)\n"," if (self._monitor_mode == \"min\" and monitor_value > self._best_metric - min_delta) or \\\n"," (self._monitor_mode == \"max\" and monitor_value < self._best_metric + min_delta):\n"," self._stopping_steps += 1\n"," logging.info(\"Monitor({}) STOP: {:.6f} !\".format(self._monitor_mode, monitor_value))\n"," if self._reduce_lr_on_plateau:\n"," current_lr = self.reduce_learning_rate()\n"," logging.info(\"Reduce learning rate on plateau: {:.6f}\".format(current_lr))\n"," logging.info(\"Load best model: {}\".format(self.checkpoint))\n"," self.load_weights(self.checkpoint)\n"," else:\n"," self._stopping_steps = 0\n"," self._best_metric = monitor_value\n"," if self._save_best_only:\n"," logging.info(\"Save best model: monitor({}): {:.6f}\"\\\n"," .format(self._monitor_mode, monitor_value))\n"," self.save_weights(self.checkpoint)\n"," if self._stopping_steps * self._eval_interval_epochs >= self._patience:\n"," self._stop_training = True\n"," logging.info(\"Early stopping at epoch={:g}\".format(epoch))\n"," if not self._save_best_only:\n"," self.save_weights(self.checkpoint)\n"," \n"," def fit(self, train_generator, epochs=1, valid_generator=None,\n"," verbose=0, max_gradient_norm=10., **kwargs):\n"," self.valid_gen = valid_generator\n"," self._max_gradient_norm = max_gradient_norm\n"," self._best_metric = np.Inf if self._monitor_mode == \"min\" else -np.Inf\n"," self._stopping_steps = 0\n"," self._total_batches = 0\n"," self._batches_per_epoch = len(train_generator)\n"," 
self._eval_interval_batches = int(np.ceil(self._eval_interval_epochs * self._batches_per_epoch))\n"," self._stop_training = False\n"," self._verbose = verbose\n"," \n"," logging.info(\"**** Start training: {} batches/epoch ****\".format(self._batches_per_epoch))\n"," for epoch in range(epochs):\n"," epoch_loss = self.train_on_epoch(train_generator, epoch)\n"," logging.info(\"Train loss: {:.6f}\".format(epoch_loss))\n"," if self._stop_training:\n"," break\n"," else:\n"," logging.info(\"************ Epoch={} end ************\".format(epoch + 1))\n"," logging.info(\"Training finished.\")\n"," logging.info(\"Load best model: {}\".format(self.checkpoint))\n"," self.load_weights(self.checkpoint)\n","\n"," def train_on_epoch(self, train_generator, epoch):\n"," epoch_loss = 0\n"," model = self.train()\n"," batch_generator = train_generator\n"," if self._verbose > 0:\n"," batch_generator = tqdm(train_generator, disable=False)#, file=sys.stdout)\n"," for batch_index, batch_data in enumerate(batch_generator):\n"," self.optimizer.zero_grad()\n"," return_dict = model.forward(batch_data)\n"," loss = return_dict[\"loss\"]\n"," loss.backward()\n"," nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)\n"," self.optimizer.step()\n"," epoch_loss += loss.item()\n"," self.on_batch_end(train_generator, batch_index)\n"," if self._stop_training:\n"," break\n"," return epoch_loss / self._batches_per_epoch\n","\n"," def evaluate(self, train_generator, valid_generator):\n"," logging.info(\"--- Start evaluation ---\")\n"," self.eval() # set to evaluation mode\n"," with torch.no_grad():\n"," user_vecs = []\n"," item_vecs = []\n"," for user_batch in valid_generator.user_loader:\n"," user_vec = self.user_tower(user_batch)\n"," user_vecs.extend(user_vec.data.cpu().numpy())\n"," for item_batch in valid_generator.item_loader:\n"," item_vec = self.item_tower(item_batch)\n"," item_vecs.extend(item_vec.data.cpu().numpy())\n"," user_vecs = np.array(user_vecs, np.float64)\n"," item_vecs = np.array(item_vecs, np.float64)\n"," val_logs = evaluate_metrics(user_vecs,\n"," item_vecs,\n"," train_generator.user2items_dict,\n"," valid_generator.user2items_dict,\n"," valid_generator.query_indexes,\n"," self._validation_metrics)\n"," return val_logs\n"," \n"," def save_weights(self, checkpoint):\n"," torch.save(self.state_dict(), checkpoint)\n"," \n"," def load_weights(self, checkpoint):\n"," self.load_state_dict(torch.load(checkpoint, map_location=self.device))\n","\n"," def count_parameters(self, count_embedding=True):\n"," total_params = 0\n"," for name, param in self.named_parameters(): \n"," if not count_embedding and \"embedding\" in name:\n"," continue\n"," if param.requires_grad:\n"," total_params += param.numel()\n"," logging.info(\"Total number of parameters: {}.\".format(total_params))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"y09Ty_uWz0Ma"},"source":["#@title\n","class EmbeddingLayer(nn.Module):\n"," def __init__(self, \n"," feature_map,\n"," embedding_dim,\n"," disable_sharing_pretrain=False,\n"," required_feature_columns=[],\n"," not_required_feature_columns=[]):\n"," super(EmbeddingLayer, self).__init__()\n"," self.embedding_layer = EmbeddingDictLayer(feature_map, \n"," embedding_dim,\n"," disable_sharing_pretrain=disable_sharing_pretrain,\n"," required_feature_columns=required_feature_columns,\n"," not_required_feature_columns=not_required_feature_columns)\n","\n"," def forward(self, X, feature_source=None):\n"," feature_emb_dict = 
self.embedding_layer(X, feature_source=feature_source)\n"," feature_emb = self.embedding_layer.dict2tensor(feature_emb_dict)\n"," return feature_emb\n","\n","\n","class EmbeddingDictLayer(nn.Module):\n"," def __init__(self, \n"," feature_map, \n"," embedding_dim,\n"," disable_sharing_pretrain=False,\n"," required_feature_columns=None,\n"," not_required_feature_columns=None):\n"," super(EmbeddingDictLayer, self).__init__()\n"," self._feature_map = feature_map\n"," self.required_feature_columns = required_feature_columns\n"," self.not_required_feature_columns = not_required_feature_columns\n"," self.embedding_layers = nn.ModuleDict()\n"," self.embedding_callbacks = nn.ModuleDict()\n"," for feature, feature_spec in self._feature_map.feature_specs.items():\n"," if self.is_required(feature):\n"," if disable_sharing_pretrain: # in case for LR\n"," assert embedding_dim == 1\n"," feat_emb_dim = embedding_dim\n"," else:\n"," feat_emb_dim = feature_spec.get(\"embedding_dim\", embedding_dim)\n"," if (not disable_sharing_pretrain) and \"embedding_callback\" in feature_spec:\n"," self.embedding_callbacks[feature] = eval(feature_spec[\"embedding_callback\"])\n"," # Set embedding_layer according to share_embedding\n"," if (not disable_sharing_pretrain) and \"share_embedding\" in feature_spec:\n"," self.embedding_layers[feature] = self.embedding_layers[feature_spec[\"share_embedding\"]]\n"," continue\n"," \n"," if feature_spec[\"type\"] == \"numeric\":\n"," self.embedding_layers[feature] = nn.Linear(1, feat_emb_dim, bias=False)\n"," elif feature_spec[\"type\"] == \"categorical\":\n"," padding_idx = feature_spec.get(\"padding_idx\", None)\n"," embedding_matrix = nn.Embedding(feature_spec[\"vocab_size\"], \n"," feat_emb_dim,\n"," padding_idx=padding_idx)\n"," if (not disable_sharing_pretrain) and \"pretrained_emb\" in feature_spec:\n"," embedding_matrix = self.load_pretrained_embedding(embedding_matrix,\n"," feature_map, \n"," feature, \n"," freeze=feature_spec[\"freeze_emb\"],\n"," padding_idx=padding_idx)\n"," self.embedding_layers[feature] = embedding_matrix\n"," elif feature_spec[\"type\"] == \"sequence\":\n"," padding_idx = feature_spec.get(\"padding_idx\", None)\n"," embedding_matrix = nn.Embedding(feature_spec[\"vocab_size\"], \n"," feat_emb_dim, \n"," padding_idx=padding_idx)\n"," if (not disable_sharing_pretrain) and \"pretrained_emb\" in feature_spec:\n"," embedding_matrix = self.load_pretrained_embedding(embedding_matrix, \n"," feature_map, \n"," feature,\n"," freeze=feature_spec[\"freeze_emb\"],\n"," padding_idx=padding_idx)\n"," self.embedding_layers[feature] = embedding_matrix\n","\n"," def is_required(self, feature):\n"," \"\"\" Check whether feature is required for embedding \"\"\"\n"," feature_spec = self._feature_map.feature_specs[feature]\n"," if self.required_feature_columns and (feature not in self.required_feature_columns):\n"," return False\n"," if self.not_required_feature_columns and (feature in self.not_required_feature_columns):\n"," return False\n"," return True\n","\n"," def get_pretrained_embedding(self, pretrained_path, feature_name):\n"," with h5py.File(pretrained_path, 'r') as hf:\n"," embeddings = hf[feature_name][:]\n"," return embeddings\n","\n"," def load_pretrained_embedding(self, embedding_matrix, feature_map, feature_name, freeze=False, padding_idx=None):\n"," pretrained_path = os.path.join(feature_map.data_dir, feature_map.feature_specs[feature_name][\"pretrained_emb\"])\n"," embeddings = self.get_pretrained_embedding(pretrained_path, feature_name)\n"," if 
padding_idx is not None:\n"," embeddings[padding_idx] = np.zeros(embeddings.shape[-1])\n"," embeddings = torch.from_numpy(embeddings).float()\n"," embedding_matrix.weight = torch.nn.Parameter(embeddings)\n"," if freeze:\n"," embedding_matrix.weight.requires_grad = False\n"," return embedding_matrix\n","\n"," def dict2tensor(self, embedding_dict):\n"," if len(embedding_dict) == 1:\n"," feature_emb = list(embedding_dict.values())[0]\n"," else:\n"," feature_emb = torch.stack(list(embedding_dict.values()), dim=1)\n"," return feature_emb\n","\n"," def forward(self, inputs, feature_source=None, feature_type=None):\n"," feature_emb_dict = OrderedDict()\n"," for feature, feature_spec in self._feature_map.feature_specs.items():\n"," if feature_source and feature_spec[\"source\"] != feature_source:\n"," continue\n"," if feature_type and feature_spec[\"type\"] != feature_type:\n"," continue\n"," if feature in self.embedding_layers:\n"," if feature_spec[\"type\"] == \"numeric\":\n"," inp = inputs[feature].float().view(-1, 1)\n"," embeddings = self.embedding_layers[feature](inp)\n"," elif feature_spec[\"type\"] == \"categorical\":\n"," inp = inputs[feature].long()\n"," embeddings = self.embedding_layers[feature](inp)\n"," elif feature_spec[\"type\"] == \"sequence\":\n"," inp = inputs[feature].long()\n"," embeddings = self.embedding_layers[feature](inp)\n"," else:\n"," raise NotImplementedError\n"," if feature in self.embedding_callbacks:\n"," embeddings = self.embedding_callbacks[feature](embeddings) \n"," feature_emb_dict[feature] = embeddings\n"," return feature_emb_dict"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tG3l6RfEuG1_","cellView":"form"},"source":["#@title\n","class SimpleX(BaseModel):\n"," def __init__(self, \n"," feature_map, \n"," model_id=\"SimpleX\", \n"," gpu=-1, \n"," learning_rate=1e-3, \n"," embedding_initializer=\"lambda w: nn.init.normal_(w, std=1e-4)\", \n"," embedding_dim=10, \n"," user_id_field=\"user_id\",\n"," item_id_field=\"item_id\",\n"," user_history_field=\"user_history\",\n"," enable_bias=False,\n"," num_negs=1,\n"," net_dropout=0,\n"," aggregator=\"mean\",\n"," gamma=0.5,\n"," attention_dropout=0,\n"," batch_norm=False,\n"," net_regularizer=None,\n"," embedding_regularizer=None,\n"," similarity_score=\"dot\",\n"," **kwargs):\n"," super(SimpleX, self).__init__(feature_map, \n"," model_id=model_id, \n"," gpu=gpu, \n"," embedding_regularizer=embedding_regularizer,\n"," net_regularizer=net_regularizer,\n"," num_negs=num_negs,\n"," embedding_initializer=embedding_initializer,\n"," **kwargs)\n"," self.similarity_score = similarity_score\n"," self.embedding_dim = embedding_dim\n"," self.user_id_field = user_id_field\n"," self.user_history_field = user_history_field\n"," self.embedding_layer = EmbeddingDictLayer(feature_map, embedding_dim)\n"," self.behavior_aggregation = BehaviorAggregator(embedding_dim, \n"," gamma=gamma,\n"," aggregator=aggregator, \n"," dropout_rate=attention_dropout)\n"," self.enable_bias = enable_bias\n"," if self.enable_bias:\n"," self.user_bias = EmbeddingLayer(feature_map, 1,\n"," disable_sharing_pretrain=True, \n"," required_feature_columns=[user_id_field])\n"," self.item_bias = EmbeddingLayer(feature_map, 1, \n"," disable_sharing_pretrain=True, \n"," required_feature_columns=[item_id_field])\n"," self.global_bias = nn.Parameter(torch.zeros(1))\n"," self.dropout = nn.Dropout(net_dropout)\n"," self.compile(lr=learning_rate, **kwargs)\n"," \n"," def forward(self, inputs):\n"," \"\"\"\n"," Inputs: [user_dict, item_dict, 
label]\n"," \"\"\"\n"," user_dict, item_dict, labels = inputs[0:3]\n"," user_vecs = self.user_tower(user_dict)\n"," user_vecs = self.dropout(user_vecs)\n"," item_vecs = self.item_tower(item_dict)\n"," y_pred = torch.bmm(item_vecs.view(user_vecs.size(0), self.num_negs + 1, -1), \n"," user_vecs.unsqueeze(-1)).squeeze(-1)\n"," if self.enable_bias: # user_bias and global_bias only influence training, but not inference for ranking\n"," y_pred += self.user_bias(self.to_device(user_dict)) + self.global_bias\n"," loss = self.get_total_loss(y_pred, labels)\n"," return_dict = {\"loss\": loss, \"y_pred\": y_pred}\n"," return return_dict\n","\n"," def user_tower(self, inputs):\n"," user_inputs = self.to_device(inputs)\n"," user_emb_dict = self.embedding_layer(user_inputs, feature_source=\"user\")\n"," user_id_emb = user_emb_dict[self.user_id_field]\n"," user_history_emb = user_emb_dict[self.user_history_field]\n"," user_vec = self.behavior_aggregation(user_id_emb, user_history_emb)\n"," if self.similarity_score == \"cosine\":\n"," user_vec = F.normalize(user_vec)\n"," if self.enable_bias: \n"," user_vec = torch.cat([user_vec, torch.ones(user_vec.size(0), 1).to(self.device)], dim=-1)\n"," return user_vec\n","\n"," def item_tower(self, inputs):\n"," item_inputs = self.to_device(inputs)\n"," item_vec_dict = self.embedding_layer(item_inputs, feature_source=\"item\")\n"," item_vec = self.embedding_layer.dict2tensor(item_vec_dict)\n"," if self.similarity_score == \"cosine\":\n"," item_vec = F.normalize(item_vec)\n"," if self.enable_bias:\n"," item_vec = torch.cat([item_vec, self.item_bias(item_inputs)], dim=-1)\n"," return item_vec\n","\n","\n","class BehaviorAggregator(nn.Module):\n"," def __init__(self, embedding_dim, gamma=0.5, aggregator=\"mean\", dropout_rate=0.):\n"," super(BehaviorAggregator, self).__init__()\n"," self.aggregator = aggregator\n"," self.gamma = gamma\n"," self.W_v = nn.Linear(embedding_dim, embedding_dim, bias=False)\n"," if self.aggregator in [\"cross_attention\", \"self_attention\"]:\n"," self.W_k = nn.Sequential(nn.Linear(embedding_dim, embedding_dim),\n"," nn.Tanh())\n"," self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None\n"," if self.aggregator == \"self_attention\":\n"," self.W_q = nn.Parameter(torch.Tensor(embedding_dim, 1))\n"," nn.init.xavier_normal_(self.W_q)\n","\n"," def forward(self, id_emb, sequence_emb):\n"," out = id_emb\n"," if self.aggregator == \"mean\":\n"," out = self.average_pooling(sequence_emb)\n"," elif self.aggregator == \"cross_attention\":\n"," out = self.cross_attention(id_emb, sequence_emb)\n"," elif self.aggregator == \"self_attention\":\n"," out = self.self_attention(sequence_emb)\n"," return self.gamma * id_emb + (1 - self.gamma) * out\n","\n"," def cross_attention(self, id_emb, sequence_emb):\n"," key = self.W_k(sequence_emb) # b x seq_len x attention_dim\n"," mask = sequence_emb.sum(dim=-1) == 0\n"," attention = torch.bmm(key, id_emb.unsqueeze(-1)).squeeze(-1) # b x seq_len\n"," attention = self.masked_softmax(attention, mask)\n"," if self.dropout is not None:\n"," attention = self.dropout(attention)\n"," output = torch.bmm(attention.unsqueeze(1), sequence_emb).squeeze(1)\n"," return self.W_v(output)\n","\n"," def self_attention(self, sequence_emb):\n"," key = self.W_k(sequence_emb) # b x seq_len x attention_dim\n"," mask = sequence_emb.sum(dim=-1) == 0\n"," attention = torch.matmul(key, self.W_q).squeeze(-1) # b x seq_len\n"," attention = self.masked_softmax(attention, mask)\n"," if self.dropout is not None:\n"," attention = 
self.dropout(attention)\n"," output = torch.bmm(attention.unsqueeze(1), sequence_emb).squeeze(1)\n"," return self.W_v(output)\n","\n"," def average_pooling(self, sequence_emb):\n"," mask = sequence_emb.sum(dim=-1) != 0\n"," mean = sequence_emb.sum(dim=1) / (mask.float().sum(dim=-1, keepdim=True) + 1.e-12)\n"," return self.W_v(mean)\n","\n"," def masked_softmax(self, X, mask):\n"," # use the following softmax to avoid nans when a sequence is entirely masked\n"," X = X.masked_fill_(mask, 0)\n"," e_X = torch.exp(X)\n"," return e_X / (e_X.sum(dim=1, keepdim=True) + 1.e-12)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kc__hGpwymEw"},"source":["model = SimpleX(feature_encoder.feature_map, **Args.__dict__)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Rf-jVMar1bBA","executionInfo":{"status":"ok","timestamp":1633244631475,"user_tz":-330,"elapsed":44,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"1560d296-77ac-4f06-fa15-f8f8e9deab92"},"source":["model"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["SimpleX(\n"," (embedding_layer): EmbeddingDictLayer(\n"," (embedding_layers): ModuleDict(\n"," (user_id): Embedding(944, 64)\n"," (user_history): Embedding(1615, 64, padding_idx=1614)\n"," (item_id): Embedding(1614, 64)\n"," )\n"," (embedding_callbacks): ModuleDict()\n"," )\n"," (behavior_aggregation): BehaviorAggregator(\n"," (W_v): Linear(in_features=64, out_features=64, bias=False)\n"," )\n"," (dropout): Dropout(p=0, inplace=False)\n"," (loss_fn): CosineContrastiveLoss()\n",")"]},"metadata":{},"execution_count":53}]},{"cell_type":"code","metadata":{"id":"nn1BpF1t1ZRS"},"source":["model.count_parameters() # print number of parameters used in model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"8eoTXaKDDyM4","executionInfo":{"status":"ok","timestamp":1633246426354,"user_tz":-330,"elapsed":453,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"ce72c4c8-40ce-45fc-8c8c-31a378b87dd2"},"source":["for batch_index, batch_data in enumerate(train_gen):\n"," print(batch_data[0]['user_history'].shape)\n"," break"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["torch.Size([32, 100])\n"]}]},{"cell_type":"markdown","source":["> Danger: Didn't work! The fit call below fails with a shape mismatch: the batch holds 11 item vectors per user (num_negs=10 on the generator side), while the model reshapes assuming num_negs=20, i.e. 21 items per user, so shape '[32, 21, -1]' is invalid for an input of size 22528 (= 32 x 11 x 64). The num_negs used to build the data generators and the model must match."],"metadata":{"id":"u9nT4dzmv4h1"}},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":578},"id":"v9Rjv_1AzDGh","executionInfo":{"status":"error","timestamp":1633244631479,"user_tz":-330,"elapsed":36,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"3dd5aac9-bb49-449a-f565-23e2ad88164a"},"source":["model.fit(train_generator=train_gen, valid_generator=valid_gen, **Args.__dict__)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":[" 0%| | 0/2500 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_generator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_gen\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mvalid_generator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalid_gen\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mArgs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, train_generator, epochs, valid_generator, verbose, max_gradient_norm, **kwargs)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"**** Start training: {} batches/epoch ****\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batches_per_epoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0mepoch_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_on_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_generator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Train loss: {:.6f}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_training\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_on_epoch\u001b[0;34m(self, train_generator, epoch)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_generator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"loss\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 59\u001b[0m 
\u001b[0muser_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0muser_vecs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0mitem_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem_tower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m y_pred = torch.bmm(item_vecs.view(user_vecs.size(0), self.num_negs + 1, -1), \n\u001b[0m\u001b[1;32m 62\u001b[0m user_vecs.unsqueeze(-1)).squeeze(-1)\n\u001b[1;32m 63\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# user_bias and global_bias only influence training, but not inference for ranking\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: shape '[32, 21, -1]' is invalid for input of size 22528"]}]}]}