{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-26-topk-reinforce.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T373316%20%7C%20Top-K%20Off-Policy%20Correction%20for%20a%20REINFORCE%20Recommender%20System.ipynb","timestamp":1644674011650}],"collapsed_sections":[],"toc_visible":true,"authorship_tag":"ABX9TyO0eAfCFX4cpoPXr6QSJJk0"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"Skc9_mUcfxTx"},"source":["# Top-K Off-Policy Correction for a REINFORCE Recommender System"]},{"cell_type":"markdown","metadata":{"id":"cXgEkn6wbXSb"},"source":["## CLI run"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"yGjiMX5nVGr2","executionInfo":{"status":"ok","timestamp":1634811527051,"user_tz":-330,"elapsed":68164,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"365fde77-8a2a-4d0a-f5fe-c6b59d2da8e1"},"source":["# gdown deprecated `--id` and newer releases reject it; pass the Drive file id positionally\n","!gdown 1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-\n","!git clone https://github.com/massquantity/DBRL.git\n","!unzip /content/ECommAI_EUIR_round2_train_20190821.zip\n","!mv ECommAI_EUIR_round2_train_20190816/*.csv DBRL/dbrl/resources\n","%cd DBRL/dbrl"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading...\n","From: https://drive.google.com/uc?id=1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-\n","To: /content/ECommAI_EUIR_round2_train_20190821.zip\n","100% 894M/894M [00:06<00:00, 146MB/s]\n","Cloning into 'DBRL'...\n","remote: Enumerating objects: 118, done.\u001b[K\n","remote: Counting objects: 100% (118/118), done.\u001b[K\n","remote: Compressing objects: 100% (83/83), done.\u001b[K\n","remote: Total 118 (delta 29), reused 114 (delta 25), pack-reused 0\u001b[K\n","Receiving objects: 100% (118/118), 203.89 KiB | 2.87 MiB/s, done.\n","Resolving 
deltas: 100% (29/29), done.\n","Archive: /content/ECommAI_EUIR_round2_train_20190821.zip\n"," creating: ECommAI_EUIR_round2_train_20190816/\n"," inflating: ECommAI_EUIR_round2_train_20190816/user_behavior.csv \n"," inflating: ECommAI_EUIR_round2_train_20190816/item.csv \n"," inflating: ECommAI_EUIR_round2_train_20190816/user.csv \n","/content/DBRL/dbrl\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"-SiMsatqVp6a","executionInfo":{"status":"ok","timestamp":1634811700903,"user_tz":-330,"elapsed":173863,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"041071ae-c7c0-4437-f5f7-5585f3291bbd"},"source":["!python run_prepare_data.py"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["{'seed': 0}\n","tcmalloc: large alloc 1931411456 bytes == 0x559fe1e9e000 @ 0x7fbeb7c811e7 0x7fbeb580146e 0x7fbeb5851c7b 0x7fbeb585235f 0x7fbeb58f4103 0x559fd33e4544 0x559fd33e4240 0x559fd3458627 0x559fd33e5afa 0x559fd3453915 0x559fd34529ee 0x559fd33e5bda 0x559fd3453915 0x559fd33e5afa 0x559fd3453915 0x559fd33e5afa 0x559fd3453915 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e648c 0x559fd3427159 0x559fd34240a4 0x559fd33e4d49 0x559fd345894f 0x559fd34529ee 0x559fd33e5bda\n","tcmalloc: large alloc 1931411456 bytes == 0x55a1625c8000 @ 0x7fbeb7c811e7 0x7fbeb580146e 0x7fbeb5851c7b 0x7fbeb585235f 0x7fbeb58f4103 0x559fd33e4544 0x559fd33e4240 0x559fd3458627 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e5bda 0x559fd3453915 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd34526f3 0x559fd351c4c2 0x559fd351c83d 0x559fd351c6e6\n","tcmalloc: large alloc 1931411456 bytes == 0x55a1d57b8000 @ 
0x7fbeb7c811e7 0x7fbeb580146e 0x7fbeb5851c7b 0x7fbeb5851d97 0x7fbeb584b4a5 0x7fbeb58e8eab 0x559fd33e44b0 0x559fd34d5e1d 0x559fd3457e99 0x559fd34529ee 0x559fd33e648c 0x559fd33e6698 0x559fd3454fe4 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3457d00 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd34526f3 0x559fd351c4c2 0x559fd351c83d 0x559fd351c6e6 0x559fd34f4163\n","tcmalloc: large alloc 1548664832 bytes == 0x559fe1e9e000 @ 0x7fbeb7c811e7 0x7fbeb580146e 0x7fbeb5851c7b 0x7fbeb585235f 0x7fbeb58f4103 0x559fd33e4544 0x559fd33e4240 0x559fd3458627 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd3452ced 0x559fd33e5bda 0x559fd3453915 0x559fd3452ced 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd34529ee 0x559fd33e5bda 0x559fd3454737 0x559fd33e5afa 0x559fd3457d00\n","n_users: 80000, n_items: 1047166, behavior length: 3234367\n","prepare data done!, time elapsed: 173.06\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2Bhb2NHZV7pH","executionInfo":{"status":"ok","timestamp":1634812382531,"user_tz":-330,"elapsed":681644,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"4a80ba3f-46ed-4a07-cf05-9b2d118bfab9"},"source":["!python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["A list all args: \n","======================\n","{'batch_size': 2048,\n"," 'data': 'tianchi.csv',\n"," 'embed_size': 32,\n"," 'loss': 'cosine',\n"," 'lr': 0.001,\n"," 'n_epochs': 4,\n"," 'neg_item': 1,\n"," 'seed': 0}\n","\n","n_users: 80000, n_items: 912114, train_shape: (2587518, 10), eval_shape: (646849, 10)\n","100% 2527/2527 [02:22<00:00, 
17.73it/s]\n","100% 632/632 [00:11<00:00, 56.31it/s]\n","epoch 1, train_loss: 0.3253, eval loss: 0.3537, eval roc: 0.7370\n","100% 2527/2527 [02:21<00:00, 17.85it/s]\n","100% 632/632 [00:11<00:00, 55.87it/s]\n","epoch 2, train_loss: 0.2568, eval loss: 0.3351, eval roc: 0.7697\n","100% 2527/2527 [02:21<00:00, 17.84it/s]\n","100% 632/632 [00:11<00:00, 56.34it/s]\n","epoch 3, train_loss: 0.2260, eval loss: 0.3309, eval roc: 0.7772\n","100% 2527/2527 [02:20<00:00, 17.96it/s]\n","100% 632/632 [00:11<00:00, 57.03it/s]\n","epoch 4, train_loss: 0.2036, eval loss: 0.3296, eval roc: 0.7829\n","user_embeds shape: (80000, 32), item_embeds shape: (912115, 32)\n","pretrain embeddings done!\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"825CG5pQWous","executionInfo":{"status":"ok","timestamp":1634823180654,"user_tz":-330,"elapsed":10050064,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a3f0e9d5-0e8a-4bac-a9af-bd6fd1af29df"},"source":["!python run_reinforce.py --n_epochs 1 --lr 1e-5"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["A list all args: \n","======================\n","{'batch_size': 128,\n"," 'data': 'tianchi.csv',\n"," 'gamma': 0.99,\n"," 'hidden_size': 64,\n"," 'hist_num': 10,\n"," 'item_embeds': 'tianchi_item_embeddings.npy',\n"," 'lr': 1e-05,\n"," 'n_epochs': 1,\n"," 'n_rec': 10,\n"," 'seed': 0,\n"," 'sess_mode': 'interval',\n"," 'user_embeds': 'tianchi_user_embeddings.npy',\n"," 'weight_decay': 0.0}\n","\n","Number of parameters: policy: 118628454, beta: 59310067\n","Caution: Will compute loss every 10 step(s)\n","\n","Epoch 1 start-time: 2021-10-21 10:46:33\n","\n","train: 100% 19590/19590 [2:24:05<00:00, 2.27it/s]\n","last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]\n","\n","policy_loss: 665.2355, beta_loss: 13.6349, importance_weight: 0.8856, lambda_k: 9.9993, \n","reward: 455, 
ndcg_next_item: 0.000999, ndcg_all_item: 0.039790, ndcg: 0.027800\n","\n","******************** EVAL ********************\n","eval: 100% 10516/10516 [20:26<00:00, 8.58it/s]\n","last_eval: 100% 625/625 [00:47<00:00, 13.12it/s]\n","\n","policy_loss: 1333.3558, beta_loss: 13.5444, importance_weight: 0.9655, lambda_k: 9.9992, \n","reward: 290, ndcg_next_item: 0.001165, ndcg_all_item: 0.008780, ndcg: 0.007617\n","================================================================================\n","train and save done!\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"3_d3DNySXFhc","executionInfo":{"status":"ok","timestamp":1634823421135,"user_tz":-330,"elapsed":7393,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"5347b11c-e1ce-44dd-84d4-2bd7005f724f"},"source":["!apt-get -qq install tree\n","!tree --du -h ."],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Selecting previously unselected package tree.\n","(Reading database ... 
155047 files and directories currently installed.)\n","Preparing to unpack .../tree_1.7.0-5_amd64.deb ...\n","Unpacking tree (1.7.0-5) ...\n","Setting up tree (1.7.0-5) ...\n","Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n",".\n","├── [ 56K] data\n","│   ├── [7.7K] dataset.py\n","│   ├── [ 126] __init__.py\n","│   ├── [8.4K] process.py\n","│   ├── [ 24K] __pycache__\n","│   │   ├── [5.5K] dataset.cpython-37.pyc\n","│   │   ├── [ 284] __init__.cpython-37.pyc\n","│   │   ├── [5.6K] process.cpython-37.pyc\n","│   │   ├── [6.3K] session.cpython-37.pyc\n","│   │   └── [2.2K] split.cpython-37.pyc\n","│   ├── [9.6K] session.py\n","│   └── [2.5K] split.py\n","├── [ 17K] evaluate\n","│   ├── [4.1K] evaluate.py\n","│   ├── [ 45] __init__.py\n","│   ├── [1.6K] metrics.py\n","│   └── [7.3K] __pycache__\n","│   ├── [1.9K] evaluate.cpython-37.pyc\n","│   ├── [ 178] __init__.cpython-37.pyc\n","│   └── [1.2K] metrics.cpython-37.pyc\n","├── [ 0] __init__.py\n","├── [ 44K] models\n","│   ├── [7.6K] bcq.py\n","│   ├── [4.2K] ddpg.py\n","│   ├── [3.3K] dssm.py\n","│   ├── [ 103] __init__.py\n","│   ├── [ 19K] __pycache__\n","│   │   ├── [5.1K] bcq.cpython-37.pyc\n","│   │   ├── [3.3K] ddpg.cpython-37.pyc\n","│   │   ├── [2.4K] dssm.cpython-37.pyc\n","│   │   ├── [ 256] __init__.cpython-37.pyc\n","│   │   └── [3.9K] youtube_topk.cpython-37.pyc\n","│   └── [5.7K] youtube_topk.py\n","├── [ 24K] network\n","│   ├── [ 20] __init__.py\n","│   ├── [9.1K] net.py\n","│   └── [ 11K] __pycache__\n","│   ├── [ 134] __init__.cpython-37.pyc\n","│   └── [7.2K] net.cpython-37.pyc\n","├── [4.1K] __pycache__\n","│   └── [ 106] __init__.cpython-37.pyc\n","├── [3.0G] resources\n","│   ├── [ 0] aa\n","│   ├── [114M] item.csv\n","│   ├── [ 15M] item_map.json\n","│   ├── [574M] model_reinforce.pt\n","│   ├── [160M] tianchi.csv\n","│   ├── [111M] tianchi_item_embeddings.npy\n","│   ├── [9.8M] tianchi_user_embeddings.npy\n","│   ├── [2.0G] user_behavior.csv\n","│   ├── [ 19M] user.csv\n","│   └── 
[1.2M] user_map.json\n","├── [5.9K] run_bcq.py\n","├── [5.1K] run_ddpg.py\n","├── [2.5K] run_prepare_data.py\n","├── [4.4K] run_pretrain_embeddings.py\n","├── [5.1K] run_reinforce.py\n","├── [9.9K] serialization\n","│   ├── [ 43] __init__.py\n","│   ├── [5.0K] __pycache__\n","│   │   ├── [ 182] __init__.cpython-37.pyc\n","│   │   └── [ 845] serialize.cpython-37.pyc\n","│   └── [ 889] serialize.py\n","├── [ 18K] trainer\n","│   ├── [ 68] __init__.py\n","│   ├── [1.9K] pretrain.py\n","│   ├── [8.2K] __pycache__\n","│   │   ├── [ 202] __init__.cpython-37.pyc\n","│   │   ├── [1.5K] pretrain.cpython-37.pyc\n","│   │   └── [2.5K] train.cpython-37.pyc\n","│   └── [3.9K] train.py\n","└── [ 21K] utils\n"," ├── [1.7K] info.py\n"," ├── [ 156] __init__.py\n"," ├── [2.2K] misc.py\n"," ├── [1.7K] params.py\n"," ├── [ 10K] __pycache__\n"," │   ├── [1.5K] info.cpython-37.pyc\n"," │   ├── [ 325] __init__.cpython-37.pyc\n"," │   ├── [2.0K] misc.cpython-37.pyc\n"," │   ├── [1.5K] params.cpython-37.pyc\n"," │   └── [ 920] sampling.cpython-37.pyc\n"," └── [ 908] sampling.py\n","\n"," 3.0G used in 16 directories, 67 files\n"]}]},{"cell_type":"markdown","metadata":{"id":"kxN_Fb0fbVCX"},"source":["## Code analysis"]},{"cell_type":"markdown","metadata":{"id":"pRGngazAbWdJ"},"source":["### Data preparation"]},{"cell_type":"code","metadata":{"id":"zecpI005bbOk"},"source":["import os\n","import sys\n","sys.path.append(os.pardir)\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","import argparse\n","import time\n","import numpy as np\n","import pandas as pd\n","\n","\n","def parse_args():\n"," parser = argparse.ArgumentParser(description=\"run_prepare_data\")\n"," parser.add_argument(\"--seed\", type=int, default=0)\n"," return parser.parse_args(args={})\n","\n","\n","def bucket_age(age):\n"," if age < 30:\n"," return 1\n"," elif age < 40:\n"," return 2\n"," elif age < 50:\n"," return 3\n"," else:\n"," return 4\n","\n","\n","if __name__ == \"__main__\":\n"," args = 
parse_args()\n"," print(vars(args))\n"," np.random.seed(args.seed)\n"," start_time = time.perf_counter()\n","\n"," # 1. loading the data into memory\n","\n"," user_feat = pd.read_csv(\"resources/user.csv\", header=None,\n"," names=[\"user\", \"sex\", \"age\", \"pur_power\"])\n"," item_feat = pd.read_csv(\"resources/item.csv\", header=None,\n"," names=[\"item\", \"category\", \"shop\", \"brand\"])\n"," behavior = pd.read_csv(\"resources/user_behavior.csv\", header=None,\n"," names=[\"user\", \"item\", \"behavior\", \"time\"])\n"," \n"," # 2. sorting values chronologically and dropping duplicate records\n","\n"," behavior = behavior.sort_values(by=\"time\").reset_index(drop=True)\n"," behavior = behavior.drop_duplicates(subset=[\"user\", \"item\", \"behavior\"])\n","\n"," # 3. Choosing 60K random users with short journey and 20K with long journey\n"," user_counts = behavior.groupby(\"user\")[[\"user\"]].count().rename(\n"," columns={\"user\": \"count_user\"}\n"," ).sort_values(\"count_user\", ascending=False)\n","\n"," short_users = np.array(\n"," user_counts[\n"," (user_counts.count_user > 5) & (user_counts.count_user <= 50)\n"," ].index\n"," )\n"," long_users = np.array(\n"," user_counts[\n"," (user_counts.count_user > 50) & (user_counts.count_user <= 200)\n"," ].index\n"," )\n"," short_chosen_users = np.random.choice(short_users, 60000, replace=False)\n"," long_chosen_users = np.random.choice(long_users, 20000, replace=False)\n"," chosen_users = np.concatenate([short_chosen_users, long_chosen_users])\n","\n"," behavior = behavior[behavior.user.isin(chosen_users)]\n"," print(f\"n_users: {behavior.user.nunique()}, \"\n"," f\"n_items: {behavior.item.nunique()}, \"\n"," f\"behavior length: {len(behavior)}\")\n","\n"," # 4. 
merge with all features, bucketizing the age and saving the processed data\n","    behavior = behavior.merge(user_feat, on=\"user\")\n","    behavior = behavior.merge(item_feat, on=\"item\")\n","    behavior[\"age\"] = behavior[\"age\"].apply(bucket_age)\n","    behavior = behavior.sort_values(by=\"time\").reset_index(drop=True)\n","    behavior.to_csv(\"resources/tianchi.csv\", header=None, index=False)\n","    print(f\"prepare data done!, \"\n","          f\"time elapsed: {(time.perf_counter() - start_time):.2f}\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n1C4mcuIdRaW"},"source":["### Embeddings"]},{"cell_type":"code","metadata":{"id":"lH2RLv5ddSK-"},"source":["import os\n","import sys\n","sys.path.append(os.pardir)\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","import argparse\n","from pprint import pprint\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","from torch.optim import Adam\n","from torch.utils.data import DataLoader\n","from dbrl.data import process_feat_data, FeatDataset\n","from dbrl.models import DSSM\n","from dbrl.utils import sample_items_random, init_param_dssm, generate_embeddings\n","from dbrl.trainer import pretrain_model\n","from dbrl.serialization import save_npy, save_json\n","\n","\n","def parse_args():\n","    \"\"\"Parse the command-line options for embedding pretraining.\n","\n","    ``parse_known_args`` is used so that extra arguments the Jupyter kernel\n","    puts in ``sys.argv`` (e.g. ``-f <connection file>``) do not make\n","    argparse exit when this cell runs inside the notebook.\n","    \"\"\"\n","    parser = argparse.ArgumentParser(description=\"run_pretrain_embeddings\")\n","    parser.add_argument(\"--data\", type=str, default=\"tianchi.csv\")\n","    parser.add_argument(\"--n_epochs\", type=int, default=10)\n","    parser.add_argument(\"--batch_size\", type=int, default=2048)\n","    parser.add_argument(\"--lr\", type=float, default=5e-4)\n","    parser.add_argument(\"--embed_size\", type=int, default=32)\n","    parser.add_argument(\"--loss\", type=str, default=\"cosine\",\n","                        help=\"cosine or bce loss\")\n","    parser.add_argument(\"--neg_item\", type=int, default=1)\n","    parser.add_argument(\"--seed\", type=int, default=0)\n","    args, _unknown = parser.parse_known_args()\n","    return args\n","\n","\n","if __name__ == 
\"__main__\":\n"," args = parse_args()\n"," print(\"A list all args: \\n======================\")\n"," pprint(vars(args))\n"," print()\n","\n"," # 1. Setting arguments/params\n","\n"," torch.manual_seed(args.seed)\n"," np.random.seed(args.seed)\n"," PATH = os.path.join(\"resources\", args.data)\n"," EMBEDDING_PATH = \"resources/\"\n"," static_feat = [\"sex\", \"age\", \"pur_power\"]\n"," dynamic_feat = [\"category\", \"shop\", \"brand\"]\n"," device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"," n_epochs = args.n_epochs\n"," batch_size = args.batch_size\n"," lr = args.lr\n"," item_embed_size = args.embed_size\n"," feat_embed_size = args.embed_size\n"," hidden_size = (256, 128)\n"," criterion = (\n"," nn.CosineEmbeddingLoss()\n"," if args.loss == \"cosine\"\n"," else nn.BCEWithLogitsLoss()\n"," )\n"," criterion_type = (\n"," \"cosine\"\n"," if \"cosine\" in criterion.__class__.__name__.lower()\n"," else \"bce\"\n"," )\n"," neg_label = -1. if criterion_type == \"cosine\" else 0.\n"," neg_item = args.neg_item\n","\n"," # 2. Preprocessing\n","\n"," columns = [\"user\", \"item\", \"label\", \"time\", \"sex\", \"age\", \"pur_power\",\n"," \"category\", \"shop\", \"brand\"]\n","\n"," (\n"," n_users,\n"," n_items,\n"," train_user_consumed,\n"," eval_user_consumed,\n"," train_data,\n"," eval_data,\n"," user_map,\n"," item_map,\n"," feat_map\n"," ) = process_feat_data(\n"," PATH, columns, test_size=0.2, time_col=\"time\",\n"," static_feat=static_feat, dynamic_feat=dynamic_feat\n"," )\n"," print(f\"n_users: {n_users}, n_items: {n_items}, \"\n"," f\"train_shape: {train_data.shape}, eval_shape: {eval_data.shape}\")\n"," \n"," # 3. Random negative sampling\n","\n"," train_user, train_item, train_label = sample_items_random(\n"," train_data, n_items, train_user_consumed, neg_label, neg_item\n"," )\n"," eval_user, eval_item, eval_label = sample_items_random(\n"," eval_data, n_items, eval_user_consumed, neg_label, neg_item\n"," )\n","\n"," # 4. 
Putting data into torch dataset format and dataloader\n","\n"," train_dataset = FeatDataset(\n"," train_user,\n"," train_item,\n"," train_label,\n"," feat_map,\n"," static_feat,\n"," dynamic_feat\n"," )\n"," eval_dataset = FeatDataset(\n"," eval_user,\n"," eval_item,\n"," eval_label,\n"," feat_map,\n"," static_feat,\n"," dynamic_feat\n"," )\n"," train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,\n"," shuffle=True, num_workers=0)\n"," eval_loader = DataLoader(dataset=eval_dataset, batch_size=batch_size,\n"," shuffle=False, num_workers=0)\n","\n"," # 5. DSSM embedding model training\n","\n"," model = DSSM(\n"," item_embed_size,\n"," feat_embed_size,\n"," n_users,\n"," n_items,\n"," hidden_size,\n"," feat_map,\n"," static_feat,\n"," dynamic_feat,\n"," use_bn=True\n"," ).to(device)\n"," init_param_dssm(model)\n"," optimizer = Adam(model.parameters(), lr=lr) # weight_decay\n","\n"," pretrain_model(model, train_loader, eval_loader, n_epochs, criterion,\n"," criterion_type, optimizer, device)\n"," \n"," # 6. 
Generate and save embeddings\n"," \n"," user_embeddings, item_embeddings = generate_embeddings(\n"," model, n_users, n_items, feat_map, static_feat, dynamic_feat, device\n"," )\n"," print(f\"user_embeds shape: {user_embeddings.shape},\"\n"," f\" item_embeds shape: {item_embeddings.shape}\")\n","\n"," save_npy(user_embeddings, item_embeddings, EMBEDDING_PATH)\n"," save_json(\n"," user_map, item_map, user_embeddings, item_embeddings, EMBEDDING_PATH\n"," )\n"," print(\"pretrain embeddings done!\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"oNVi23z3etT_"},"source":["## REINFORCE model"]},{"cell_type":"code","metadata":{"id":"s7upYlxygH08"},"source":["import torch\n","import torch.nn as nn\n","from torch.distributions import Categorical\n","\n","\n","class Reinforce(nn.Module):\n"," def __init__(\n"," self,\n"," policy,\n"," policy_optim,\n"," beta,\n"," beta_optim,\n"," hidden_size,\n"," gamma=0.99,\n"," k=10,\n"," weight_clip=2.0,\n"," offpolicy_correction=True,\n"," topk=True,\n"," adaptive_softmax=True,\n"," cutoffs=None,\n"," device=torch.device(\"cpu\"),\n"," ):\n"," super(Reinforce, self).__init__()\n"," self.policy = policy\n"," self.policy_optim = policy_optim\n"," self.beta = beta\n"," self.beta_optim = beta_optim\n"," self.beta_criterion = nn.CrossEntropyLoss()\n"," self.gamma = gamma\n"," self.k = k\n"," self.weight_clip = weight_clip\n"," self.offpolicy_correction = offpolicy_correction\n"," self.topk = topk\n"," self.adaptive_softmax = adaptive_softmax\n"," if adaptive_softmax:\n"," assert cutoffs is not None, (\n"," \"must provide cutoffs when using adaptive_softmax\"\n"," )\n"," self.softmax_loss = nn.AdaptiveLogSoftmaxWithLoss(\n"," in_features=hidden_size,\n"," n_classes=policy.item_embeds.weight.size(0),\n"," cutoffs=cutoffs,\n"," div_value=4.\n"," ).to(device)\n"," self.device = device\n","\n"," def update(self, data):\n"," (\n"," policy_loss,\n"," beta_loss,\n"," action,\n"," importance_weight,\n"," 
lambda_k\n","        ) = self._compute_loss(data)\n","\n","        self.policy_optim.zero_grad()\n","        policy_loss.backward()\n","        self.policy_optim.step()\n","\n","        self.beta_optim.zero_grad()\n","        beta_loss.backward()\n","        self.beta_optim.step()\n","\n","        info = {'policy_loss': policy_loss.cpu().detach().item(),\n","                'beta_loss': beta_loss.cpu().detach().item(),\n","                'importance_weight': importance_weight.cpu().mean().item(),\n","                'lambda_k': lambda_k.cpu().mean().item(),\n","                'action': action}\n","        return info\n","\n","    def _compute_weight(self, policy_logp, beta_logp):\n","        \"\"\"Importance weight pi(a|s) / beta(a|s), detached and clipped from above.\n","\n","        The ratio is computed as exp(policy_logp - beta_logp) and capped at\n","        ``self.weight_clip``. Returns a constant tensor([1.]) when off-policy\n","        correction is disabled.\n","        \"\"\"\n","        if self.offpolicy_correction:\n","            importance_weight = torch.exp(policy_logp - beta_logp).detach()\n","            # Upper clip only; clamp avoids allocating a device tensor every step.\n","            importance_weight = importance_weight.clamp(max=self.weight_clip)\n","        else:\n","            importance_weight = torch.tensor([1.]).float().to(self.device)\n","        return importance_weight\n","\n","    def _compute_lambda_k(self, policy_logp):\n","        lam = (\n","            self.k * ((1. 
- policy_logp.exp()).pow(self.k - 1)).detach()\n","            if self.topk\n","            else torch.tensor([1.]).float().to(self.device)\n","        )\n","        return lam\n","\n","    def _compute_loss(self, data):\n","        \"\"\"Compute the policy (REINFORCE) loss and the behaviour-policy (beta) loss.\n","\n","        Returns ``(policy_loss, beta_loss, action, importance_weight, lambda_k)``.\n","        \"\"\"\n","        if self.adaptive_softmax:\n","            state, action = self.policy(data)\n","            policy_out = self.softmax_loss(action, data[\"action\"])\n","            policy_logp = policy_out.output\n","\n","            beta_action = self.beta(state.detach())\n","            beta_out = self.softmax_loss(beta_action, data[\"action\"])\n","            beta_logp = beta_out.output\n","        else:\n","            state, all_logp, action = self.policy.get_log_probs(data)\n","            # Select log pi(a_i | s_i) row by row, matching the per-sample shape of\n","            # the adaptive-softmax branch. Indexing as all_logp[:, data[\"action\"]]\n","            # would instead build a (batch, batch) matrix pairing every state with\n","            # every logged action in the batch.\n","            action_idx = data[\"action\"].reshape(-1, 1)\n","            policy_logp = all_logp.gather(1, action_idx).squeeze(1)\n","\n","            b_logp, beta_logits = self.beta.get_log_probs(state.detach())\n","            beta_logp = b_logp.gather(1, action_idx).squeeze(1).detach()\n","\n","        importance_weight = self._compute_weight(policy_logp, beta_logp)\n","        lambda_k = self._compute_lambda_k(policy_logp)\n","\n","        policy_loss = -(\n","            importance_weight * lambda_k * data[\"return\"] * policy_logp\n","        ).mean()\n","\n","        if self.adaptive_softmax:\n","            if \"beta_label\" in data:\n","                b_state = self.policy.get_beta_state(data)\n","                b_action = self.beta(b_state.detach())\n","                b_out = self.softmax_loss(b_action, data[\"beta_label\"])\n","                beta_loss = b_out.loss\n","            else:\n","                beta_loss = beta_out.loss\n","        else:\n","            if \"beta_label\" in data:\n","                b_state = self.policy.get_beta_state(data)\n","                _, b_logits = self.beta.get_log_probs(b_state.detach())\n","                beta_loss = self.beta_criterion(b_logits, data[\"beta_label\"])\n","            else:\n","                beta_loss = self.beta_criterion(beta_logits, data[\"action\"])\n","        return policy_loss, beta_loss, action, importance_weight, lambda_k\n","\n","    def compute_loss(self, data):\n","        \"\"\"Compute losses and diagnostics without back-propagating or stepping optimizers.\"\"\"\n","        (\n","            policy_loss,\n","            beta_loss,\n","            action,\n","            importance_weight,\n","            lambda_k\n","        ) = self._compute_loss(data)\n","\n","        info = {'policy_loss': policy_loss.cpu().detach().item(),\n","                'beta_loss': beta_loss.cpu().detach().item(),\n","                'importance_weight': 
importance_weight.cpu().mean().item(),\n","                'lambda_k': lambda_k.cpu().mean().item(),\n","                'action': action}\n","        return info\n","\n","    def get_log_probs(self, data=None, action=None):\n","        with torch.no_grad():\n","            if self.adaptive_softmax:\n","                if action is None:\n","                    _, action = self.policy.forward(data)\n","                log_probs = self.softmax_loss.log_prob(action)\n","            else:\n","                # NOTE(review): returns the raw output of policy.softmax_fc, which may be\n","                # unnormalised logits rather than log-probs -- TODO confirm in dbrl.network.\n","                if action is None:\n","                    _, action = self.policy.forward(data)\n","                log_probs = self.policy.softmax_fc(action)\n","        return log_probs\n","\n","    def forward(self, state):\n","        \"\"\"Return the indices of the ``self.k`` highest-probability items for each state.\"\"\"\n","        policy_logits = self.policy.get_action(state)\n","        policy_dist = Categorical(logits=policy_logits)\n","        # Slate size must equal the k assumed by the top-K correction (lambda_k).\n","        _, rec_idxs = torch.topk(policy_dist.probs, self.k, dim=1)\n","        return rec_idxs"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E7fBLQGCf7gE"},"source":["## Trainer"]},{"cell_type":"code","metadata":{"id":"L0AxMTlueutE"},"source":["import os\n","import sys\n","sys.path.append(os.pardir)\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","import argparse\n","from pprint import pprint\n","import numpy as np\n","import torch\n","from torch.optim import Adam\n","from dbrl.data import process_data, build_dataloader\n","# NOTE: this imports the packaged Reinforce, shadowing the class defined in the cell above.\n","from dbrl.models import Reinforce\n","from dbrl.network import PolicyPi, Beta\n","from dbrl.trainer import train_model\n","from dbrl.utils import count_vars, init_param\n","\n","\n","def parse_args():\n","    parser = argparse.ArgumentParser(description=\"run_reinforce\")\n","    parser.add_argument(\"--data\", type=str, default=\"tianchi.csv\")\n","    parser.add_argument(\"--user_embeds\", type=str,\n","                        default=\"tianchi_user_embeddings.npy\")\n","    parser.add_argument(\"--item_embeds\", type=str,\n","                        default=\"tianchi_item_embeddings.npy\")\n","    parser.add_argument(\"--n_epochs\", type=int, default=100)\n","    parser.add_argument(\"--hist_num\", type=int, default=10,\n","                        help=\"num of history items to consider\")\n","    
parser.add_argument(\"--n_rec\", type=int, default=10,\n","                        help=\"num of items to recommend\")\n","    parser.add_argument(\"--batch_size\", type=int, default=128)\n","    parser.add_argument(\"--hidden_size\", type=int, default=64)\n","    parser.add_argument(\"--lr\", type=float, default=1e-5)\n","    parser.add_argument(\"--weight_decay\", type=float, default=0.)\n","    parser.add_argument(\"--gamma\", type=float, default=0.99)\n","    parser.add_argument(\"--sess_mode\", type=str, default=\"interval\",\n","                        help=\"Specify when to end a session\")\n","    parser.add_argument(\"--seed\", type=int, default=0)\n","    # parse_known_args ignores the extra arguments the Jupyter kernel puts in\n","    # sys.argv (e.g. \"-f <connection file>\"), which parse_args() would reject.\n","    args, _unknown = parser.parse_known_args()\n","    return args\n","\n","\n","if __name__ == \"__main__\":\n","    args = parse_args()\n","    print(\"A list all args: \\n======================\")\n","    pprint(vars(args))\n","    print()\n","\n","    # 1. Loading user and item embeddings\n","\n","    torch.manual_seed(args.seed)\n","    np.random.seed(args.seed)\n","    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","    PATH = os.path.join(\"resources\", args.data)\n","    with open(os.path.join(\"resources\", args.user_embeds), \"rb\") as f:\n","        user_embeddings = np.load(f)\n","    with open(os.path.join(\"resources\", args.item_embeds), \"rb\") as f:\n","        item_embeddings = np.load(f)\n","    item_embeddings[-1] = 0.  # last item is used for padding\n","\n","    # 2. 
Setting model arguments/params\n","\n"," n_epochs = args.n_epochs\n"," hist_num = args.hist_num\n"," batch_size = eval_batch_size = args.batch_size\n"," embed_size = item_embeddings.shape[1]\n"," hidden_size = args.hidden_size\n"," input_dim = embed_size * (hist_num + 1)\n"," action_dim = len(item_embeddings)\n"," policy_lr = args.lr\n"," beta_lr = args.lr\n"," weight_decay = args.weight_decay\n"," gamma = args.gamma\n"," n_rec = args.n_rec\n"," pad_val = len(item_embeddings) - 1\n"," sess_mode = args.sess_mode\n"," debug = True\n"," one_hour = int(60 * 60)\n"," reward_map = {\"pv\": 1., \"cart\": 2., \"fav\": 2., \"buy\": 3.}\n"," columns = [\"user\", \"item\", \"label\", \"time\", \"sex\", \"age\", \"pur_power\",\n"," \"category\", \"shop\", \"brand\"]\n","\n"," cutoffs = [\n"," len(item_embeddings) // 20,\n"," len(item_embeddings) // 10,\n"," len(item_embeddings) // 3\n"," ]\n","\n"," # 3. Building the data loader\n","\n"," (\n"," n_users,\n"," n_items,\n"," train_user_consumed,\n"," test_user_consumed,\n"," train_sess_end,\n"," test_sess_end,\n"," train_rewards,\n"," test_rewards\n"," ) = process_data(PATH, columns, 0.2, time_col=\"time\", sess_mode=sess_mode,\n"," interval=one_hour, reward_shape=reward_map)\n","\n"," train_loader, eval_loader = build_dataloader(\n"," n_users,\n"," n_items,\n"," hist_num,\n"," train_user_consumed,\n"," test_user_consumed,\n"," batch_size,\n"," sess_mode=sess_mode,\n"," train_sess_end=train_sess_end,\n"," test_sess_end=test_sess_end,\n"," n_workers=0,\n"," compute_return=True,\n"," neg_sample=False,\n"," train_rewards=train_rewards,\n"," test_rewards=test_rewards,\n"," reward_shape=reward_map\n"," )\n","\n"," # 4. 
Building the model\n","\n","    policy = PolicyPi(\n","        input_dim, action_dim, hidden_size, user_embeddings,\n","        item_embeddings, None, pad_val, 1, device\n","    ).to(device)\n","    beta = Beta(input_dim, action_dim, hidden_size).to(device)\n","    init_param(policy, beta)\n","\n","    policy_optim = Adam(policy.parameters(), policy_lr, weight_decay=weight_decay)\n","    beta_optim = Adam(beta.parameters(), beta_lr, weight_decay=weight_decay)\n","\n","    model = Reinforce(\n","        policy,\n","        policy_optim,\n","        beta,\n","        beta_optim,\n","        hidden_size,\n","        gamma,\n","        k=n_rec,  # top-K correction must assume the slate size actually recommended\n","        weight_clip=2.0,\n","        offpolicy_correction=True,\n","        topk=True,\n","        adaptive_softmax=False,\n","        cutoffs=cutoffs,\n","        device=device,\n","    )\n","\n","    var_counts = tuple(count_vars(module) for module in [policy, beta])\n","    print(f'Number of parameters: policy: {var_counts[0]}, '\n","          f' beta: {var_counts[1]}')\n","\n","    # 5. Training the model\n","\n","    train_model(\n","        model,\n","        n_epochs,\n","        n_rec,\n","        n_users,\n","        train_user_consumed,\n","        test_user_consumed,\n","        hist_num,\n","        train_loader,\n","        eval_loader,\n","        item_embeddings,\n","        eval_batch_size,\n","        pad_val,\n","        device,\n","        debug=debug,\n","        eval_interval=10\n","    )\n","\n","    torch.save(policy.state_dict(), \"resources/model_reinforce.pt\")\n","    print(\"train and save done!\")"],"execution_count":null,"outputs":[]}]}