{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-20-nms.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T063788%20%7C%20Neural%20Memory%20Streaming%20Recommender%20Networks%20with%20Adversarial%20Training.ipynb","timestamp":1644654511689}],"collapsed_sections":[],"authorship_tag":"ABX9TyOGWnAY0gSidRbw9dsOIl+d"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"M2GX2zB_Pl9h"},"source":["# Neural Memory Streaming Recommender Networks with Adversarial Training"]},{"cell_type":"markdown","metadata":{"id":"TuygM6y1WO66"},"source":["NMRN is a streaming recommender model based on neural memory networks with external memories to capture and store both long-term stable interests and short-term dynamic interests in a unified way. An adaptive negative sampling framework based on Generative Adversarial Nets (GAN) is developed to optimize our proposed streaming recommender model, which effectively overcomes the limitations of classical negative sampling approaches and improves the effectiveness of the model parameter inference."]},{"cell_type":"markdown","metadata":{"id":"HPDRZQSvZzNs"},"source":["![](https://github.com/recohut/incremental-learning/raw/a6fdcde2e8af7ebfd9f5efd278c487e0e9560cb3/docs/_images/T063788_1.png)"]},{"cell_type":"markdown","metadata":{"id":"kw5elHdCWrGo"},"source":["Specifically, the external memories of NMRN compose the key memory and the value memory. Given a new user-item interaction pair(u , v) arriving at the system in real time, NMRN first generates a soft address from the key memory, activated by u. The addressing process is inspired by recent advances in attention mechanism. The attention mechanism applied in recommender systems is useful to improve the retrieval accuracy and model interpretability. The fundamental idea of this attention design is to learn a weighted representation across the key memory, which is converted into a probability distribution by the Softmax function as the soft address.Then, NMRN reads from the value memory based on the soft address, resulting in a vector that represents both long-term stable and short-term emerging interests of user u. Inspired by the success of pairwise personalized ranking models (e.g., BPR) in top-k recommendations, we adopt the Hinge-loss in our model optimization. As the number of unobserved examples is very huge, we use the negative sampling method to improve the training efficiency.\n","\n","Most existing negative sampling approaches use either random sampling or popularity-biased sampling strategies. However, the majority of negative examples generated in these sampling strategies can be easily discriminated from observed examples, and will contribute little towards the training, because sampled items could be completely unrelated to the target user. Besides, these sampling approaches are not adaptive enough to generate adversarial negative examples, because (1) they are static and thus do not consider that the estimated similarity or proximity between a user and an item changes during the learning process. For example, the similarity between user u and a sampled noise item v is high at the beginning, but after several gradient descent steps it becomes low; and (2) these samplers are global and do not reflect how informative a noise item is w.r.t. a specific user.\n","\n","In light of this, we use an adaptive noise sampler based on a Generative Adversarial Network (GAN) to optimize the model, which considers both the specific user and the current values of the model parameters to adaptively generate “difficult” and informative negative examples. Moreover, in order to simultaneously capture the first-order similarity between users and items as well as the second-order similarities between users and between items to learn robust representations of users and items, we use the Euclidean distance to measure the similarity between a user and an item instead of the widely adopted dot product."]},{"cell_type":"markdown","metadata":{"id":"u6H8sp5JZ2YT"},"source":["![](https://github.com/recohut/incremental-learning/raw/a6fdcde2e8af7ebfd9f5efd278c487e0e9560cb3/docs/_images/T063788_2.png)"]},{"cell_type":"markdown","metadata":{"id":"uRxvbQ0NWunq"},"source":["\n","\n","The Architecture of the model. The total numbers of users and items are denoted as $N$ and $M$. We denote a user-item interaction pair at time $t$ as $(u_t , v_t)$ while $v_t^-$ is a sampled negative item for a specific user at time $t$."]},{"cell_type":"code","metadata":{"id":"xf6MOr-0PnDH"},"source":["import numpy as np \n","import pandas as pd \n","\n","import torch \n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.optim import Adam\n","from torch.nn.parameter import Parameter\n","from torch.utils.data.dataset import Dataset\n","from torch.utils.data.dataloader import DataLoader\n","\n","from math import log"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"qgxBOo0ZVYWO"},"source":["#some parameters of this model\n","N, D_in, H = 64, 20, 10\n","safety_margin_size = 3\n","Max_sampleNum = 50\n","total_item_num = 500"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"sjSGYoItVata"},"source":["def euclidean_distance(a, b):\n"," return (a-b).pow(2).sum(1).sqrt()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"gAFIiaF1VhM8"},"source":["value_memory = torch.randn(H, D_in, requires_grad=True)\n","key_memory = torch.randn(H, D_in, requires_grad=True)\n","\n","x = torch.randn(500, D_in)\n","y = torch.randn(500, D_in)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"8y1HvCc1Vi21"},"source":["class CandidateDataset(Dataset):\n"," def __init__(self, x, y):\n"," self.len = x.shape[0]\n"," self.x_data = torch.as_tensor(x,dtype=torch.float)\n"," self.y_data = torch.as_tensor(y,dtype=torch.float)\n","\n"," def __getitem__(self, index):\n"," return self.x_data[index], self.y_data[index]\n","\n"," def __len__(self):\n"," return self.len"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"DLJ51QrTVlin"},"source":["my_dataset = CandidateDataset(x, y)\n","train_loader = DataLoader(dataset = my_dataset, batch_size = 32, shuffle=True)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lg7CVMoJVmwz"},"source":["class generator(nn.Module):\n"," def __init__(self, D_in):\n"," super(generator, self).__init__()\n","\n"," self.fc1 = nn.Linear(2*D_in, 32)\n"," self.fc2 = nn.Linear(32, 16)\n"," self.fc3 = nn.Linear(16, 1) # Prob of Left\n","\n"," def forward(self, x, y):\n"," cat_xy = torch.cat([x,y],1)\n"," hidden1 = F.relu(self.fc1(cat_xy))\n"," hidden2 = F.relu(self.fc2(hidden1))\n"," output = torch.sigmoid(self.fc3(hidden2))\n"," return output"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_1OQDusHVn9E"},"source":["class Discriminator(torch.nn.Module):\n"," def __init__(self, D_in, H, N, key_memory, value_memory):\n"," super(Discriminator, self).__init__()\n"," self.key_memory = Parameter(torch.Tensor(H, D_in))\n"," self.value_memory = Parameter(torch.Tensor(H, D_in))\n"," self.key_memory.data = key_memory\n"," self.value_memory.data = value_memory\n"," #self.value_memory.requires_grad=True\n"," self.fc_erase = nn.Linear(D_in, D_in)\n"," self.fc_update = nn.Linear(D_in, D_in)\n"," self.D_in = D_in\n"," self.N = N\n"," self.H = H\n"," def forward(self, user, item):\n"," output_list = []\n"," e_I = torch.ones(self.H, self.D_in)\n"," \n"," for i in range(self.N):\n"," #Memory Adderssing\n"," Attention_weight = torch.empty(self.H)\n","\n"," for j in range(self.H):\n"," Attention_weight[j] = -euclidean_distance(user[i].unsqueeze(0), self.key_memory[j,:])\n","\n","\n"," #select value memory by attention\n","\n"," Attention_weight = Attention_weight.softmax(0)\n"," s = Attention_weight.matmul(self.value_memory)\n","\n"," output = euclidean_distance(s, item[i].unsqueeze(0))\n"," output_list.append(output)\n"," \n"," #update value memory by item vector\n"," e_t = self.fc_erase(item[i].unsqueeze(0)).sigmoid()\n"," self.value_memory.data = self.value_memory * (e_I - Attention_weight.unsqueeze(0).t().matmul(e_t))\n","\n"," a_t = self.fc_update(item[i].unsqueeze(0)).tanh()\n"," self.value_memory.data = self.value_memory + Attention_weight.unsqueeze(0).t().matmul(a_t)\n"," \n"," return output_list"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lldToo4eVpv9"},"source":["def ngtv_spl_by_unifm_dist(user):\n"," ngtv_item = torch.randn(1, D_in)\n"," return ngtv_item\n","\n","def ngtv_spl_by_generator(user):\n"," ngtv_item_id = list(torch.multinomial(prob_ngtv_item, 1, replacement=True))[0]\n"," ngtv_item = dict_item_id2vec[candidate_ngtv_item[ngtv_item_id].item()]\n"," return ngtv_item_id, ngtv_item\n","\n","def G_pretrain_per_datapoint(user, generator, model):\n"," ngtv_item_id, ngtv_item = generator(user)\n"," neg_rslt = model(user.unsqueeze(0), ngtv_item)[0]\n"," return -neg_rslt, ngtv_item_id\n","\n","def D_pretrain_per_datapoint(user, item, generator):\n"," ps_rslt = model(user.unsqueeze(0),item.unsqueeze(0))[0]\n"," for i in range(Max_sampleNum):\n"," weight = log((total_item_num-1)//(i+1) + 1)\n"," neg_rslt = model(user.unsqueeze(0), generator(user))[0]\n"," loss = weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size, weight)\n"," if(loss > 0):\n"," break;\n"," return loss\n"," \n","def weighted_hinge_loss(ps_rslt, neg_rslt, safety_margin_size, weight):\n"," return weight * F.relu(safety_margin_size + ps_rslt - neg_rslt)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1uksFi-TVvaw","executionInfo":{"status":"ok","timestamp":1635230723030,"user_tz":-330,"elapsed":95726,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"d8566b76-ab5c-4012-c812-a83630cfc64b"},"source":["# pre-training of discriminator\n","model = Discriminator(20, 10, 1, key_memory, value_memory)\n","total_loss = []\n","optimizer = Adam(model.parameters(), lr= 0.01)\n","\n","for epoch in range(50):\n"," epoch_loss = 0.\n"," for batch_id, (x, y) in enumerate(train_loader):\n"," losses = torch.empty(x.shape[0])\n"," for i in range(x.shape[0]):\n"," losses[i] = D_pretrain_per_datapoint(x[i],y[i], ngtv_spl_by_unifm_dist)\n","\n"," mini_batch_loss = torch.sum(losses)\n"," epoch_loss += float(mini_batch_loss.data)\n"," \n"," model.zero_grad()\n"," mini_batch_loss.backward()\n"," optimizer.step()\n"," \n"," print('epoch_' + str(epoch) + ': loss=' + str(epoch_loss))\n"," total_loss.append(epoch_loss) "],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["epoch_0: loss=9162.51040649414\n","epoch_1: loss=9512.62255859375\n","epoch_2: loss=9214.800323486328\n","epoch_3: loss=9504.176055908203\n","epoch_4: loss=9284.231536865234\n","epoch_5: loss=9422.251586914062\n","epoch_6: loss=9421.282775878906\n","epoch_7: loss=9396.276489257812\n","epoch_8: loss=9390.662292480469\n","epoch_9: loss=9425.124694824219\n","epoch_10: loss=9205.153289794922\n","epoch_11: loss=9452.185852050781\n","epoch_12: loss=9485.987365722656\n","epoch_13: loss=9410.585327148438\n","epoch_14: loss=9127.744140625\n","epoch_15: loss=9379.030639648438\n","epoch_16: loss=9347.898803710938\n","epoch_17: loss=9386.432373046875\n","epoch_18: loss=9456.192657470703\n","epoch_19: loss=9354.690643310547\n","epoch_20: loss=9413.049743652344\n","epoch_21: loss=9403.752807617188\n","epoch_22: loss=9283.259490966797\n","epoch_23: loss=9420.226654052734\n","epoch_24: loss=9239.539581298828\n","epoch_25: loss=9273.154418945312\n","epoch_26: loss=9407.462158203125\n","epoch_27: loss=9153.6669921875\n","epoch_28: loss=9260.94091796875\n","epoch_29: loss=9324.113006591797\n","epoch_30: loss=9295.09634399414\n","epoch_31: loss=9115.524719238281\n","epoch_32: loss=9340.50634765625\n","epoch_33: loss=9337.981170654297\n","epoch_34: loss=9454.823944091797\n","epoch_35: loss=9180.193145751953\n","epoch_36: loss=9336.213897705078\n","epoch_37: loss=9159.22543334961\n","epoch_38: loss=9305.651489257812\n","epoch_39: loss=9302.890197753906\n","epoch_40: loss=9213.26058959961\n","epoch_41: loss=9284.324493408203\n","epoch_42: loss=9388.560913085938\n","epoch_43: loss=9230.17855834961\n","epoch_44: loss=9386.049102783203\n","epoch_45: loss=9325.175659179688\n","epoch_46: loss=9343.497131347656\n","epoch_47: loss=9220.862823486328\n","epoch_48: loss=9395.390197753906\n","epoch_49: loss=9170.428771972656\n"]}]},{"cell_type":"code","metadata":{"id":"_LKhJLxWVzKl"},"source":["dict_user_id2vec = {}\n","dict_item_id2vec = {}\n","candidate_ngtv_item = torch.randn(100)\n","prob_ngtv_item = torch.randn(100)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"aqOWX1MIWBlK"},"source":["for i in candidate_ngtv_item:\n"," dict_item_id2vec[i.item()] = torch.randn(1,20)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LHalxg4hWcGb"},"source":["## Training procedure"]},{"cell_type":"markdown","metadata":{"id":"uZhUQDQzZ457"},"source":["![](https://github.com/recohut/incremental-learning/raw/a6fdcde2e8af7ebfd9f5efd278c487e0e9560cb3/docs/_images/T063788_2.png)"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Q20Gq0PqV_Jc","executionInfo":{"status":"ok","timestamp":1635231474888,"user_tz":-330,"elapsed":750287,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"75abddd1-c09d-45be-be92-51399bd8f7ec"},"source":["#pre-training of generator\n","gener = generator(20)\n","total_reward = []\n","total_loss = []\n","optimizer_G = Adam(gener.parameters(), lr= 0.01)\n","\n","for epoch in range(50):\n"," epoch_reward = 0.\n"," epoch_loss = 0.\n"," for batch_id, (x, y) in enumerate(train_loader):\n"," reward = torch.empty(x.shape[0])\n"," mini_batch_loss = torch.empty(x.shape[0])\n"," \n"," for data_point in range(x.shape[0]):\n"," #state = x[i]\n"," #action_pool = []\n"," #reward_pool = []\n"," candidate_ngtv_item = torch.randn(100)\n"," prob_ngtv_item = torch.randn(100)\n"," for i in candidate_ngtv_item:\n"," dict_item_id2vec[i.item()] = torch.randn(1,20)\n"," \n"," for i in range(candidate_ngtv_item.shape[0]):\n"," prob_ngtv_item[i] = gener(x[data_point].unsqueeze(0), dict_item_id2vec[candidate_ngtv_item[i].item()])\n"," prob_ngtv_item = torch.softmax(prob_ngtv_item,0)\n"," reward[data_point], ngtv_item_id = G_pretrain_per_datapoint(x[data_point], ngtv_spl_by_generator, model)\n"," mini_batch_loss[data_point] = (-reward[data_point]) * torch.log(prob_ngtv_item[ngtv_item_id])\n"," rewards = torch.sum(reward)\n"," mini_batch_losses = torch.sum(mini_batch_loss)\n"," epoch_loss = epoch + float(mini_batch_losses.data)\n"," epoch_reward = epoch + float(rewards.data)\n"," \n"," gener.zero_grad()\n"," mini_batch_losses.backward()\n"," optimizer_G.step()\n"," \n"," print('epoch_' + str(epoch) + ': loss=' + str(epoch_loss) + ' reward= ' + str(epoch_reward))\n"," total_loss.append(epoch_loss) "],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["epoch_0: loss=-434.15869140625 reward= -94.21874237060547\n","epoch_1: loss=-424.1404113769531 reward= -91.69937896728516\n","epoch_2: loss=-436.8798522949219 reward= -93.18343353271484\n","epoch_3: loss=-454.4273681640625 reward= -96.57652282714844\n","epoch_4: loss=-436.4529113769531 reward= -91.76925659179688\n","epoch_5: loss=-421.24530029296875 reward= -87.33723449707031\n","epoch_6: loss=-413.3312683105469 reward= -84.89115142822266\n","epoch_7: loss=-394.2191162109375 reward= -79.93122863769531\n","epoch_8: loss=-415.0520324707031 reward= -83.29130554199219\n","epoch_9: loss=-380.2655029296875 reward= -75.64791870117188\n","epoch_10: loss=-404.0414733886719 reward= -78.71047973632812\n","epoch_11: loss=-370.0225830078125 reward= -74.467041015625\n","epoch_12: loss=-403.89019775390625 reward= -78.97952270507812\n","epoch_13: loss=-400.32525634765625 reward= -77.73628234863281\n","epoch_14: loss=-380.1445617675781 reward= -73.97552490234375\n","epoch_15: loss=-391.7283020019531 reward= -74.8330078125\n","epoch_16: loss=-409.9129333496094 reward= -76.5369644165039\n","epoch_17: loss=-398.1524658203125 reward= -73.23690032958984\n","epoch_18: loss=-412.8946533203125 reward= -75.6321029663086\n","epoch_19: loss=-392.4568176269531 reward= -70.35201263427734\n","epoch_20: loss=-402.51251220703125 reward= -71.76426696777344\n","epoch_21: loss=-393.3529968261719 reward= -68.99085235595703\n","epoch_22: loss=-424.10772705078125 reward= -74.88196563720703\n","epoch_23: loss=-371.7626647949219 reward= -62.756309509277344\n","epoch_24: loss=-406.7032470703125 reward= -69.5239028930664\n","epoch_25: loss=-413.46826171875 reward= -70.21016693115234\n","epoch_26: loss=-381.4178771972656 reward= -62.48395538330078\n","epoch_27: loss=-393.3104553222656 reward= -64.2697982788086\n","epoch_28: loss=-411.4168395996094 reward= -67.4195556640625\n","epoch_29: loss=-402.0660705566406 reward= -64.6007080078125\n","epoch_30: loss=-411.00799560546875 reward= -65.77066802978516\n","epoch_31: loss=-404.57196044921875 reward= -63.586524963378906\n","epoch_32: loss=-396.587890625 reward= -61.06828308105469\n","epoch_33: loss=-375.3912658691406 reward= -55.68218231201172\n","epoch_34: loss=-379.1978759765625 reward= -55.68901824951172\n","epoch_35: loss=-385.91357421875 reward= -56.388343811035156\n","epoch_36: loss=-393.93780517578125 reward= -57.38147735595703\n","epoch_37: loss=-416.4791259765625 reward= -61.48567199707031\n","epoch_38: loss=-370.73846435546875 reward= -50.73662567138672\n","epoch_39: loss=-400.337890625 reward= -56.4139404296875\n","epoch_40: loss=-388.9174499511719 reward= -53.15800476074219\n","epoch_41: loss=-359.5585632324219 reward= -45.974037170410156\n","epoch_42: loss=-363.9505615234375 reward= -46.14964294433594\n","epoch_43: loss=-378.8984375 reward= -48.61058807373047\n","epoch_44: loss=-359.16448974609375 reward= -43.544586181640625\n","epoch_45: loss=-372.9488830566406 reward= -45.768882751464844\n","epoch_46: loss=-374.2965087890625 reward= -45.270103454589844\n","epoch_47: loss=-384.1467590332031 reward= -46.62992858886719\n","epoch_48: loss=-373.5057067871094 reward= -43.53697204589844\n","epoch_49: loss=-368.18255615234375 reward= -41.583251953125\n"]}]},{"cell_type":"markdown","metadata":{"id":"mjV33gv8WVJC"},"source":["## Links\n","\n","[https://github.com/lzzscl/KBGAN-KV_MemNN/tree/master](https://github.com/lzzscl/KBGAN-KV_MemNN/tree/master/)\n","\n","[Neural Memory Streaming Recommender Networks with Adversarial Training](https://dl.acm.org/doi/epdf/10.1145/3219819.3220004)"]}]}