{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-24-agnn.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T290734%20%7C%20AGGN%20Cold-start%20Recommendation%20on%20ML-100k.ipynb","timestamp":1644665912540}],"collapsed_sections":[],"authorship_tag":"ABX9TyNdlExN8FUpf3MHz/hjVhXz"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"gcKmx_vMzTTs"},"source":["# AGNN Cold-start Recommendation on ML-100k"]},{"cell_type":"markdown","metadata":{"id":"aqBXKpEN1VZh"},"source":["![](https://github.com/recohut/coldstart-recsys/raw/da72950ca514faee94f010a2cb6e99a373044ec1/docs/_images/T290734_1.png)"]},{"cell_type":"markdown","metadata":{"id":"DIWL_ukaz3-A"},"source":["## Introduction"]},{"cell_type":"markdown","metadata":{"id":"k6EyiwszzpyN"},"source":["### Input Layer\n","\n","We first design an input layer to build the user (item) attribute graph $\\mathcal{A}_u$ ($\\mathcal{A}_i$). To do so, we compute two kinds of proximity scores between nodes, preference proximity and attribute proximity, each of which can be measured with cosine similarity.\n","\n","- The preference proximity measures how similar the historical preferences of two nodes are. If two users have similar rating histories (or two items have been rated similarly), they have a high preference proximity. Note that preference proximity cannot be computed for cold start nodes, since they have no historical ratings.\n","- The attribute proximity measures the similarity between the attributes of two nodes. If two users have similar profiles, e.g., gender and occupation (or two items have similar properties, e.g., category), they have a high attribute proximity.\n","\n","Once the overall proximity between two nodes is computed, a natural choice is to build a k-NN graph, as in (Monti, Bronstein, and Bresson 2017). 
This method keeps a fixed number of neighbors for each node once the graph is constructed.\n","\n","### Attribute Interaction Layer\n","\n","In the constructed attribute graphs $\\mathcal{A}_u$ and $\\mathcal{A}_i$, each node has an attached multi-hot attribute encoding and a unique one-hot representation denoting its identity. Because web-scale recommender systems contain huge numbers of users and items, the dimensionality of the nodes’ one-hot representations is extremely high. Moreover, the multi-hot attribute representation simply combines multiple types of attributes into one long vector without modeling the interactions among them. The goal of the interaction layer is to reduce the dimensionality of the one-hot identity representation and to learn high-order attribute interactions from the multi-hot attribute representation. To this end, we first set up a lookup table that maps a node’s one-hot representation to a low-dimensional dense vector. The lookup layers correspond to two parameter matrices $\\mathbf{M} \\in \\mathbb{R}^{M \\times D}$ and $\\mathbf{N} \\in \\mathbb{R}^{N \\times D}$, where $M$ and $N$ are the numbers of users and items and $D$ is the embedding size. Each row $m_u \\in \\mathbb{R}^D$ of $\\mathbf{M}$ encodes user $u$’s preference, and each row $n_i \\in \\mathbb{R}^D$ of $\\mathbf{N}$ encodes item $i$’s property. Note that $m_u$ and $n_i$ are meaningless for cold start nodes, since no interactions are observed to train their preference embeddings. Inspired by (He and Chua 2017), we capture high-order attribute interactions with a ***Bi-Interaction pooling operation***, in addition to a linear combination of the attribute embeddings.\n","\n","### Gated GNN Layer\n","\n","Intuitively, different neighbors have different relations to a node. Furthermore, a neighbor usually has multiple attributes. For example, in a social network, a user’s neighborhood may consist of classmates, family members, colleagues, and so on, and each neighbor may have several attributes such as age, gender, and occupation. 
Since all these attributes (along with the preferences) are now encoded in the node’s embedding, the model needs to pay different attention to different dimensions of a neighbor’s embedding. However, existing GCN (Kipf and Welling 2017) and GAT (Veličković et al. 2018) structures cannot do this because they operate at a coarse granularity: GCN treats all neighbors equally, and GAT differentiates the importance of neighbors only at the node level. To solve this problem, we design a gated-GNN structure that aggregates fine-grained neighbor information.\n","\n","### Prediction Layer\n","\n","Given a user $u$’s final representation $\\tilde{p}_u$ and an item $i$’s final representation $\\tilde{q}_i$ after the gated-GNN layer, we model the predicted rating of user $u$ for item $i$ as:\n","\n","$$\\hat{R}_{u,i} = \\mathrm{MLP}([\\tilde{p}_u; \\tilde{q}_i]) + \\tilde{p}_u\\tilde{q}_i^T + b_u + b_i + \\mu,$$\n","\n","where MLP is a multilayer perceptron with one hidden layer, and $b_u$, $b_i$, and $\\mu$ denote the user bias, item bias, and global bias, respectively. The second term is the inner-product interaction function (Koren, Bell, and Volinsky 2009), and the first term is added to capture complex nonlinear interactions between the user and the item. Note that the code in this notebook implements the first term with a single linear layer (`FC_pre`) rather than a hidden-layer MLP."]},{"cell_type":"markdown","metadata":{"id":"lq1voif5zuYJ"},"source":["### Solution to the Cold Start Problem\n","\n","The cold start problem is caused by the lack of historical interactions for cold start nodes. 
We treat this as a missing preference problem and solve it with a variational autoencoder (VAE) that reconstructs the preference embedding from the attribute distribution."]},{"cell_type":"markdown","metadata":{"id":"XmTZBwpH1YRW"},"source":["![](https://github.com/recohut/coldstart-recsys/raw/da72950ca514faee94f010a2cb6e99a373044ec1/docs/_images/T290734_2.png)"]},{"cell_type":"markdown","metadata":{"id":"HYSIQmIozybK"},"source":["### Loss Functions\n","\n","For the rating prediction loss, we use the squared loss as the objective function:\n","\n","$$L_{pred} = \\sum_{(u,i) \\in \\mathcal{T}} (\\hat{R}_{u,i} - R_{u,i})^2,$$\n","\n","where $\\mathcal{T}$ denotes the set of training instances, i.e., $\\mathcal{T} = \\{(u, i, R_{u,i}, a_u, a_i)\\}$, $R_{u,i}$ is the ground-truth rating in $\\mathcal{T}$, and $\\hat{R}_{u,i}$ is the predicted rating.\n","\n","The reconstruction loss of the extended VAE (eVAE) is defined as follows:\n","\n","$$L_{recon} = \\mathrm{KL}\\big(q_\\phi(z_u \\mid x_u) \\,\\|\\, p(z_u)\\big) - \\mathbb{E}_{q_\\phi(z_u \\mid x_u)}\\big[\\log p_\\theta(x'_u \\mid z_u)\\big] + \\|x'_u - m_u\\|_2,$$\n","\n","where the first two terms form the negative evidence lower bound (ELBO) of a standard VAE, and the last term is our extension: it pulls the reconstructed preference $x'_u$ toward the learned preference embedding $m_u$.\n","\n","The overall loss function then becomes:\n","\n","$$L = L_{pred} + L_{recon},$$\n","\n","where $L_{pred}$ is the task-specific rating prediction loss, and $L_{recon}$ is the reconstruction loss."]},{"cell_type":"markdown","metadata":{"id":"7NIAMKRYuhPq"},"source":["## Setup"]},{"cell_type":"markdown","metadata":{"id":"eNDV6EZpvmOe"},"source":["### Imports"]},{"cell_type":"code","metadata":{"id":"iwoXoXHwvnrH"},"source":["import torch\n","import torch.nn as nn\n","import torch.nn.init as init\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from torch.autograd import Variable\n","from torch.utils.tensorboard import SummaryWriter\n","\n","import numpy as np\n","import pandas as pd\n","from torch.utils.data import Dataset, DataLoader\n","\n","import 
os\n","import json\n","import time\n","import pickle\n","import argparse\n","from collections import OrderedDict"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"09-uWbsAyRKD"},"source":["import warnings\n","warnings.filterwarnings('ignore')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zLMkK0q7wvmG"},"source":["### Params"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_aDgk0kGwwjl","executionInfo":{"status":"ok","timestamp":1635858986458,"user_tz":-330,"elapsed":18,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"099ab43a-2cb2-43a0-bcb5-60ec79e4c9f7"},"source":["parser = argparse.ArgumentParser()\n","parser.add_argument(\"--lr\", default=0.0005, type=float,\n","\t\t\t\t\thelp=\"learning rate.\")\n","parser.add_argument(\"--dropout\", default=0.5, type=float,\n","\t\t\t\t\thelp=\"dropout rate.\")\n","parser.add_argument(\"--batch_size\", default=128, type=int,\n","\t\t\t\t\thelp=\"batch size when training.\")\n","parser.add_argument(\"--gpu\", default=\"0\", type=str,\n","\t\t\t\t\thelp=\"gpu card ID.\")\n","parser.add_argument(\"--epochs\", default=20, type=int,\n","\t\t\t\t\thelp=\"number of training epochs.\")\n","parser.add_argument(\"--clip_norm\", default=5.0, type=float,\n","\t\t\t\t\thelp=\"max gradient norm, to prevent exploding gradients.\")\n","parser.add_argument(\"--embed_size\", default=30, type=int, help=\"embedding size for users and items.\")\n","parser.add_argument(\"--attention_size\", default=50, type=int, help=\"attention size.\")\n","parser.add_argument(\"--item_layer1_nei_num\", default=10, type=int, help=\"number of one-hop neighbor slots sampled per item.\")\n","parser.add_argument(\"--user_layer1_nei_num\", default=10, type=int, help=\"number of one-hop neighbor slots sampled per user.\")\n","parser.add_argument(\"--vae_lambda\", default=1, 
type=int)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["_StoreAction(option_strings=['--vae_lambda'], dest='vae_lambda', nargs=None, const=None, default=1, type=, choices=None, help=None, metavar=None)"]},"metadata":{},"execution_count":3}]},{"cell_type":"markdown","metadata":{"id":"MSfLDlCQw9VD"},"source":["## Dataset"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"utY5GhUavnjt","executionInfo":{"status":"ok","timestamp":1635858988645,"user_tz":-330,"elapsed":2191,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"2459cbd5-5d6d-48a3-aacd-3c1eb30ed840"},"source":["!wget -q --show-progress https://github.com/sparsh-ai/coldstart-recsys/raw/main/data/AGNN/ml100k.zip\n","!unzip ml100k.zip"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\rml100k.zip 0%[ ] 0 --.-KB/s \rml100k.zip 100%[===================>] 15.31M 96.6MB/s in 0.2s \n","Archive: ml100k.zip\n"," creating: ml100k/\n"," inflating: ml100k/neighbor_aspect_extension_2_zscore_warm_uuii.pkl \n"," inflating: ml100k/uiinfo.pkl \n"," inflating: ml100k/ics_train.dat \n"," inflating: ml100k/warm_val.dat \n"," inflating: ml100k/neighbor_aspect_extension_2_zscore_ics_uuii_0.20.pkl \n"," inflating: ml100k/ucs_train.dat \n"," inflating: ml100k/ucs_val.dat \n"," inflating: ml100k/neighbor_aspect_extension_2_zscore_ucs_uuii.pkl \n"," inflating: ml100k/ics_val.dat \n"," inflating: ml100k/warm_train.dat \n"," creating: ml100k/source_data/\n"," inflating: ml100k/source_data/item_content.dat \n"," extracting: ml100k/source_data/ml-100k.zip \n"," inflating: ml100k/source_data/item_url_all.dat \n"," extracting: ml100k/source_data/__ \n"]}]},{"cell_type":"code","metadata":{"id":"-f34Bwvow54T"},"source":["def get_data_list(ftrain, batch_size):\n"," f = open(ftrain, 'r')\n"," train_list = []\n"," for eachline in 
f:\n"," eachline = eachline.strip().split('\\t')\n"," u, i, l = int(eachline[0]), int(eachline[1]), float(eachline[2])\n"," train_list.append([u, i, l])\n"," num_batches_per_epoch = int((len(train_list) - 1) / batch_size) + 1\n"," return num_batches_per_epoch, train_list"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"rRWBAr5IxClg"},"source":["def get_batch_instances(train_list, user_feature_dict, item_feature_dict, item_director_dict, item_writer_dict, item_star_dict, item_country_dict, batch_size, user_nei_dict, item_nei_dict, shuffle=True):\n"," num_batches_per_epoch = int((len(train_list) - 1) / batch_size) + 1\n"," def data_generator(train_list):\n"," data_size = len(train_list)\n"," user_feature_arr = np.array(list(user_feature_dict.values()))\n"," max_user_cate_size = user_feature_arr.shape[1]\n","\n"," item_genre_arr = np.array(list(item_feature_dict.values())) # genre: 6 slots, columns 0:6\n"," item_director_arr = np.array(list(item_director_dict.values())) # director: 3 slots, columns 6:9\n"," item_writer_arr = np.array(list(item_writer_dict.values())) # writer: 3 slots, columns 9:12\n"," item_star_arr = np.array(list(item_star_dict.values())) # star: 3 slots, columns 12:15\n"," item_country_arr = np.array(list(item_country_dict.values())) # country: 8 slots, columns 15:23\n","\n"," item_feature_arr = np.concatenate([item_genre_arr, item_director_arr, item_writer_arr, item_star_arr, item_country_arr], axis=1)\n"," max_item_cate_size = item_feature_arr.shape[1]\n","\n"," item_layer1_nei_num = FLAGS.item_layer1_nei_num\n"," user_layer1_nei_num = FLAGS.user_layer1_nei_num\n","\n"," if shuffle:\n"," np.random.shuffle(train_list)\n"," train_list = np.array(train_list)\n","\n"," for batch_num in range(num_batches_per_epoch):\n"," start_index = batch_num * batch_size\n"," end_index = min((batch_num + 1) * batch_size, data_size)\n"," current_batch_size = end_index - start_index\n","\n"," u = train_list[start_index: end_index][:, 0].astype(int) # np.int was removed in NumPy 1.24\n"," i = train_list[start_index: end_index][:, 1].astype(int)\n"," l = 
train_list[start_index: end_index][:, 2]\n","\n"," i_self_cate = np.zeros([current_batch_size, max_item_cate_size], dtype=int)\n"," i_onehop_id = np.zeros([current_batch_size, item_layer1_nei_num], dtype=int)\n"," i_onehop_cate = np.zeros([current_batch_size, item_layer1_nei_num, max_item_cate_size], dtype=int)\n","\n"," u_self_cate = np.zeros([current_batch_size, max_user_cate_size], dtype=int)\n"," u_onehop_id = np.zeros([current_batch_size, user_layer1_nei_num], dtype=int)\n"," u_onehop_cate = np.zeros([current_batch_size, user_layer1_nei_num, max_user_cate_size], dtype=int)\n","\n"," for index, each_i in enumerate(i):\n"," i_self_cate[index] = item_feature_arr[each_i] # item_self_cate\n","\n"," tmp_one_nei = item_nei_dict[each_i][0] # neighbor ids\n"," tmp_prob = item_nei_dict[each_i][1] # neighbor sampling probabilities\n"," if len(tmp_one_nei) > item_layer1_nei_num: # resample to a fixed neighbor count\n"," tmp_one_nei = np.random.choice(tmp_one_nei, item_layer1_nei_num, replace=False, p=tmp_prob)\n"," elif len(tmp_one_nei) < item_layer1_nei_num:\n"," tmp_one_nei = np.random.choice(tmp_one_nei, item_layer1_nei_num, replace=True, p=tmp_prob)\n"," tmp_one_nei[-1] = each_i # the last slot holds the item itself\n","\n"," i_onehop_id[index] = tmp_one_nei # item_1_neigh\n"," i_onehop_cate[index] = item_feature_arr[tmp_one_nei] # item_1_neigh_cate\n","\n"," for index, each_u in enumerate(u):\n"," u_self_cate[index] = user_feature_dict[each_u] # user_self_cate\n","\n"," tmp_one_nei = user_nei_dict[each_u][0] # neighbor ids\n"," tmp_prob = user_nei_dict[each_u][1] # neighbor sampling probabilities\n"," if len(tmp_one_nei) > user_layer1_nei_num: # resample to a fixed neighbor count\n"," tmp_one_nei = np.random.choice(tmp_one_nei, user_layer1_nei_num, replace=False, p=tmp_prob)\n"," elif len(tmp_one_nei) < user_layer1_nei_num:\n"," tmp_one_nei = np.random.choice(tmp_one_nei, user_layer1_nei_num, replace=True, p=tmp_prob)\n"," tmp_one_nei[-1] = each_u # the last slot holds the user itself\n","\n"," u_onehop_id[index] = tmp_one_nei # user_1_neigh\n"," u_onehop_cate[index] = user_feature_arr[tmp_one_nei] # user_1_neigh_cate\n","\n"," yield ([u, i, l, u_self_cate, 
u_onehop_id, u_onehop_cate, i_self_cate, i_onehop_id, i_onehop_cate])\n"," return data_generator(train_list)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EGEcbwIxvng5"},"source":["## Model"]},{"cell_type":"code","metadata":{"id":"PAev2rRAvndh"},"source":["class VAE(nn.Module):\n","\n"," def __init__(self, embed_size):\n"," super(VAE, self).__init__()\n","\n"," Z_dim = X_dim = h_dim = embed_size\n"," self.Z_dim = Z_dim\n"," self.X_dim = X_dim\n"," self.h_dim = h_dim\n"," self.embed_size = embed_size\n","\n"," def init_weights(m):\n"," if isinstance(m, nn.Linear):\n"," nn.init.xavier_uniform_(m.weight)\n"," if m.bias is not None:\n"," nn.init.constant_(m.bias, 0)\n","\n"," # =============================== Q(z|X) ======================================\n"," self.dense_xh = nn.Linear(X_dim, h_dim)\n"," init_weights(self.dense_xh)\n","\n"," self.dense_hz_mu = nn.Linear(h_dim, Z_dim)\n"," init_weights(self.dense_hz_mu)\n","\n"," self.dense_hz_var = nn.Linear(h_dim, Z_dim)\n"," init_weights(self.dense_hz_var)\n","\n"," # =============================== P(X|z) ======================================\n"," self.dense_zh = nn.Linear(Z_dim, h_dim)\n"," init_weights(self.dense_zh)\n","\n"," self.dense_hx = nn.Linear(h_dim, X_dim)\n"," init_weights(self.dense_hx)\n","\n"," def Q(self, X):\n"," # encoder: returns the mean and the log-variance of q(z|X)\n"," h = nn.ReLU()(self.dense_xh(X))\n"," z_mu = self.dense_hz_mu(h)\n"," z_var = self.dense_hz_var(h) # log-variance, as consumed by sample_z\n"," return z_mu, z_var\n","\n"," def sample_z(self, mu, log_var):\n"," # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)\n"," eps = torch.randn_like(mu) # same shape and device as mu\n"," return mu + torch.exp(log_var / 2) * eps\n","\n"," def P(self, z):\n"," # decoder: maps z back to the preference-embedding space\n"," h = nn.ReLU()(self.dense_zh(z))\n"," X = self.dense_hx(h)\n"," return X"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MA3ueUySwMW1"},"source":["class AGNN(torch.nn.Module):\n"," def __init__(self, user_size, item_size, gender_size, age_size, occupation_size, genre_size, director_size, writer_size, 
star_size, country_size, embed_size, attention_size, dropout):\n"," super(AGNN, self).__init__()\n"," self.user_size = user_size\n"," self.item_size = item_size\n"," self.gender_size = gender_size\n"," self.age_size = age_size\n"," self.occupation_size = occupation_size\n"," self.genre_size = genre_size\n"," self.director_size = director_size\n"," self.writer_size = writer_size\n"," self.star_size = star_size\n"," self.country_size = country_size\n"," self.embed_size = embed_size\n"," self.dropout = dropout\n"," self.attention_size = attention_size\n","\n"," def init_weights(m):\n"," if isinstance(m, nn.Linear):\n"," nn.init.xavier_uniform_(m.weight)\n"," if m.bias is not None:\n"," nn.init.constant_(m.bias, 0)\n","\n"," self.user_embed = torch.nn.Embedding(self.user_size, self.embed_size)\n"," self.item_embed = torch.nn.Embedding(self.item_size, self.embed_size)\n"," nn.init.xavier_uniform_(self.user_embed.weight)\n"," nn.init.xavier_uniform_(self.item_embed.weight)\n","\n"," self.user_bias = torch.nn.Embedding(self.user_size, 1)\n"," self.item_bias = torch.nn.Embedding(self.item_size, 1)\n"," nn.init.constant_(self.user_bias.weight, 0)\n"," nn.init.constant_(self.item_bias.weight, 0)\n","\n"," self.miu = torch.nn.Parameter(torch.zeros(1), requires_grad=True) # global bias mu\n","\n"," self.gender_embed = torch.nn.Embedding(self.gender_size, self.embed_size)\n"," self.gender_embed.weight.data.normal_(0, 0.05)\n"," self.age_embed = torch.nn.Embedding(self.age_size, self.embed_size)\n"," self.age_embed.weight.data.normal_(0, 0.05)\n"," self.occupation_embed = torch.nn.Embedding(self.occupation_size, self.embed_size)\n"," self.occupation_embed.weight.data.normal_(0, 0.05)\n","\n"," self.genre_embed = torch.nn.Embedding(self.genre_size, self.embed_size)\n"," self.genre_embed.weight.data.normal_(0, 0.05)\n"," self.director_embed = torch.nn.Embedding(self.director_size, self.embed_size)\n"," self.director_embed.weight.data.normal_(0, 0.05)\n"," self.writer_embed = 
torch.nn.Embedding(self.writer_size, self.embed_size)\n"," self.writer_embed.weight.data.normal_(0, 0.05)\n"," self.star_embed = torch.nn.Embedding(self.star_size, self.embed_size)\n"," self.star_embed.weight.data.normal_(0, 0.05)\n"," self.country_embed = torch.nn.Embedding(self.country_size, self.embed_size)\n"," self.country_embed.weight.data.normal_(0, 0.05)\n","\n","\n"," #--------------------------------------------------\n"," self.dense_item_self_biinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_item_self_siinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_item_onehop_biinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_item_onehop_siinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_user_self_biinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_user_self_siinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_user_onehop_biinter = nn.Linear(self.embed_size, self.embed_size)\n"," self.dense_user_onehop_siinter = nn.Linear(self.embed_size, self.embed_size)\n"," init_weights(self.dense_item_self_biinter)\n"," init_weights(self.dense_item_self_siinter)\n"," init_weights(self.dense_item_onehop_biinter)\n"," init_weights(self.dense_item_onehop_siinter)\n"," init_weights(self.dense_user_self_biinter)\n"," init_weights(self.dense_user_self_siinter)\n"," init_weights(self.dense_user_onehop_biinter)\n"," init_weights(self.dense_user_onehop_siinter)\n","\n"," self.dense_item_cate_self = nn.Linear(2 * self.embed_size, self.embed_size)\n"," self.dense_item_cate_hop1 = nn.Linear(2 * self.embed_size, self.embed_size)\n"," self.dense_user_cate_self = nn.Linear(2 * self.embed_size, self.embed_size)\n"," self.dense_user_cate_hop1 = nn.Linear(2 * self.embed_size, self.embed_size)\n"," init_weights(self.dense_item_cate_self)\n"," init_weights(self.dense_item_cate_hop1)\n"," init_weights(self.dense_user_cate_self)\n"," init_weights(self.dense_user_cate_hop1)\n","\n"," 
self.dense_item_addgate = nn.Linear(self.embed_size * 2, self.embed_size)\n"," init_weights(self.dense_item_addgate)\n"," self.dense_item_erasegate = nn.Linear(self.embed_size * 2, self.embed_size)\n"," init_weights(self.dense_item_erasegate)\n"," self.dense_user_addgate = nn.Linear(self.embed_size * 2, self.embed_size)\n"," init_weights(self.dense_user_addgate)\n"," self.dense_user_erasegate = nn.Linear(self.embed_size * 2, self.embed_size)\n"," init_weights(self.dense_user_erasegate)\n","\n"," self.user_vae = VAE(embed_size)\n"," self.item_vae = VAE(embed_size)\n","\n"," #----------------------------------------------------\n"," # concat -> single linear layer (used for the MLP term of the prediction layer)\n"," self.FC_pre = nn.Linear(2 * embed_size, 1)\n"," init_weights(self.FC_pre)\n","\n"," \"\"\"# dot\n"," self.user_bias = nn.Embedding(self.user_size, 1)\n"," self.item_bias = nn.Embedding(self.item_size, 1)\n"," self.user_bias.weight.data.normal_(0, 0.01)\n"," self.item_bias.weight.data.normal_(0, 0.01)\n"," self.bias = torch.nn.Parameter(torch.rand(1), requires_grad=True)\n"," self.bias.data.uniform_(0, 0.1)\"\"\"\n","\n"," self.sigmoid = nn.Sigmoid()\n"," self.tanh = nn.Tanh()\n"," self.relu = nn.ReLU()\n"," self.leakyrelu = nn.LeakyReLU()\n"," self.dropout = nn.Dropout(p=0.2) # note: replaces the dropout argument; forward() does not apply it\n","\n"," def feat_interaction(self, feature_embedding, fun_bi, fun_si, dimension):\n"," # Bi-Interaction pooling: 0.5 * ((sum of embeddings)^2 - sum of squared embeddings)\n"," summed_features_emb_square = (torch.sum(feature_embedding, dim=dimension)).pow(2)\n"," squared_sum_features_emb = torch.sum(feature_embedding.pow(2), dim=dimension)\n"," deep_fm = 0.5 * (summed_features_emb_square - squared_sum_features_emb)\n"," deep_fm = self.leakyrelu(fun_bi(deep_fm))\n"," # linear term: plain sum of the attribute embeddings\n"," bias_fm = self.leakyrelu(fun_si(feature_embedding.sum(dim=dimension)))\n"," nfm = deep_fm + bias_fm\n"," return nfm\n","\n"," def forward(self, user, item, user_self_cate, user_onehop_id, user_onehop_cate, item_self_cate, item_self_director, item_self_writer, item_self_star, item_self_country, item_onehop_id, item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star, item_onehop_country, 
mode='train'):\n","\n"," uids_list = user.cuda()\n"," sids_list = item.cuda()\n"," if mode == 'train' or mode == 'warm':\n"," user_embedding = self.user_embed(torch.autograd.Variable(uids_list))\n"," item_embedding = self.item_embed(torch.autograd.Variable(sids_list))\n"," if mode == 'ics':\n"," user_embedding = self.user_embed(torch.autograd.Variable(uids_list))\n"," if mode == 'ucs':\n"," item_embedding = self.item_embed(torch.autograd.Variable(sids_list))\n","\n"," batch_size = item_self_cate.shape[0]\n"," cate_size = item_self_cate.shape[1]\n"," director_size = item_self_director.shape[1]\n"," writer_size = item_self_writer.shape[1]\n"," star_size = item_self_star.shape[1]\n"," country_size = item_self_country.shape[1]\n"," user_onehop_size = user_onehop_id.shape[1]\n"," item_onehop_size = item_onehop_id.shape[1]\n","\n"," #------------------------------------------------------GCN-item\n"," # K=2\n"," item_onehop_id = self.item_embed(Variable(item_onehop_id))\n","\n"," item_onehop_cate = self.genre_embed(Variable(item_onehop_cate).view(-1, cate_size)).view(batch_size,item_onehop_size,cate_size, -1)\n"," item_onehop_director = self.director_embed(Variable(item_onehop_director).view(-1, director_size)).view(batch_size, item_onehop_size, director_size, -1)\n"," item_onehop_writer = self.writer_embed(Variable(item_onehop_writer).view(-1, writer_size)).view(batch_size, item_onehop_size, writer_size, -1)\n"," item_onehop_star = self.star_embed(Variable(item_onehop_star).view(-1, star_size)).view(batch_size, item_onehop_size, star_size, -1)\n"," item_onehop_country = self.country_embed(Variable(item_onehop_country).view(-1, country_size)).view(batch_size, item_onehop_size, country_size, -1)\n","\n"," item_onehop_feature = torch.cat([item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star, item_onehop_country], dim=2)\n"," item_onehop_embed = self.dense_item_cate_hop1(torch.cat([self.feat_interaction(item_onehop_feature, 
self.dense_item_onehop_biinter, self.dense_item_onehop_siinter, dimension=2), item_onehop_id], dim=-1))\n","\n"," # K=1\n"," item_self_cate = self.genre_embed(Variable(item_self_cate))\n"," item_self_director = self.director_embed(Variable(item_self_director))\n"," item_self_writer = self.writer_embed(Variable(item_self_writer))\n"," item_self_star = self.star_embed(Variable(item_self_star))\n"," item_self_country = self.country_embed(Variable(item_self_country))\n","\n"," item_self_feature = torch.cat([item_self_cate, item_self_director, item_self_writer, item_self_star, item_self_country], dim=1)\n"," item_self_feature = self.feat_interaction(item_self_feature, self.dense_item_self_biinter, self.dense_item_self_siinter, dimension=1)\n","\n"," if mode == 'ics': # cold-start items: reconstruct the preference embedding from attributes with the VAE\n"," item_mu, item_var = self.item_vae.Q(item_self_feature)\n"," item_z = self.item_vae.sample_z(item_mu, item_var)\n"," item_embedding = self.item_vae.P(item_z)\n"," item_self_embed = self.dense_item_cate_self(torch.cat([item_self_feature, item_embedding], dim=-1))\n","\n"," item_addgate = self.sigmoid(self.dense_item_addgate(torch.cat([item_self_embed.unsqueeze(1).repeat(1, item_onehop_size, 1), item_onehop_embed], dim=-1))) # item neighbor (add) gate: controls how much neighbor information flows in\n"," item_erasegate = self.sigmoid(self.dense_item_erasegate(torch.cat([item_self_embed, item_onehop_embed.mean(dim=1)], dim=-1))) # erase gate: how much of the item's own representation is forgotten\n"," item_onehop_embed_final = (item_onehop_embed * item_addgate).mean(1)\n"," item_self_embed = (1 - item_erasegate) * item_self_embed\n","\n"," item_gcn_embed = self.leakyrelu(item_self_embed + item_onehop_embed_final) # [batch, embed]\n","\n"," #----------------------------------------------------------GCN-user\n"," # K=2\n"," user_onehop_id = self.user_embed(Variable(user_onehop_id))\n","\n"," user_onehop_gender_emb = self.gender_embed(Variable(user_onehop_cate[:, :, 0]))\n"," user_onehop_age_emb = self.age_embed(Variable(user_onehop_cate[:, :, 1]))\n"," user_onehop_occupation_emb = 
self.occupation_embed(Variable(user_onehop_cate[:, :, 2]))\n","\n"," user_onehop_feat = torch.cat([user_onehop_gender_emb.unsqueeze(2), user_onehop_age_emb.unsqueeze(2), user_onehop_occupation_emb.unsqueeze(2)], dim=2)\n"," user_onehop_embed = self.dense_user_cate_hop1(torch.cat([self.feat_interaction(user_onehop_feat, self.dense_user_onehop_biinter, self.dense_user_onehop_siinter, dimension=2), user_onehop_id], dim=-1))\n","\n"," # K=1\n"," user_gender_emb = self.gender_embed(Variable(user_self_cate[:, 0]))\n"," user_age_emb = self.age_embed(Variable(user_self_cate[:, 1]))\n"," user_occupation_emb = self.occupation_embed(Variable(user_self_cate[:, 2]))\n","\n"," user_self_feature = torch.cat([user_gender_emb.unsqueeze(1), user_age_emb.unsqueeze(1), user_occupation_emb.unsqueeze(1)], dim=1)\n"," user_self_feature = self.feat_interaction(user_self_feature, self.dense_user_self_biinter, self.dense_user_self_siinter, dimension=1)\n","\n"," if mode == 'ucs': # cold-start users: reconstruct the preference embedding from attributes with the VAE\n"," user_mu, user_var = self.user_vae.Q(user_self_feature)\n"," user_z = self.user_vae.sample_z(user_mu, user_var)\n"," user_embedding = self.user_vae.P(user_z)\n"," user_self_embed = self.dense_user_cate_self(torch.cat([user_self_feature, user_embedding], dim=-1))\n","\n"," user_addgate = self.sigmoid(self.dense_user_addgate(torch.cat([user_self_embed.unsqueeze(1).repeat(1, user_onehop_size, 1), user_onehop_embed],dim=-1)))\n"," user_erasegate = self.sigmoid(self.dense_user_erasegate(torch.cat([user_self_embed, user_onehop_embed.mean(dim=1)], dim=-1)))\n"," user_onehop_embed_final = (user_onehop_embed * user_addgate).mean(dim=1)\n"," user_self_embed = (1 - user_erasegate) * user_self_embed\n","\n"," user_gcn_embed = self.leakyrelu(user_self_embed + user_onehop_embed_final)\n","\n"," #-------------------------------------------------- VAE reconstruction and KL terms\n"," item_mu, item_var = self.item_vae.Q(item_self_feature)\n"," item_z = self.item_vae.sample_z(item_mu, item_var)\n"," item_preference_sample = 
self.item_vae.P(item_z)\n","\n"," user_mu, user_var = self.user_vae.Q(user_self_feature)\n"," user_z = self.user_vae.sample_z(user_mu, user_var)\n"," user_preference_sample = self.user_vae.P(user_z)\n","\n"," # ||x' - m||_2: pull the VAE reconstructions toward the learned preference embeddings\n"," recon_loss = torch.norm(item_preference_sample - item_embedding) + torch.norm(user_preference_sample - user_embedding)\n"," # KL(N(mu, exp(log_var)) || N(0, I)); item_var and user_var are log-variances\n"," kl_loss = torch.mean(0.5 * torch.sum(torch.exp(item_var) + item_mu ** 2 - 1. - item_var, 1)) + \\\n"," torch.mean(0.5 * torch.sum(torch.exp(user_var) + user_mu ** 2 - 1. - user_var, 1))\n","\n"," ####################################prediction#####################################################\n","\n"," # prediction = FC_pre([p_u; q_i]) + <p_u, q_i> + b_u + b_i + mu\n"," bu = self.user_bias(Variable(uids_list))\n"," bi = self.item_bias(Variable(sids_list))\n"," #pred = (user_gcn_embed * item_gcn_embed).sum(1, keepdim=True) + bu + bi + (self.miu).repeat(batch_size, 1)\n"," tmp = torch.cat([user_gcn_embed, item_gcn_embed], dim=1)\n"," pred = self.FC_pre(tmp) + (user_gcn_embed * item_gcn_embed).sum(1, keepdim=True) + bu + bi + (self.miu).repeat(batch_size, 1)\n","\n"," return pred.squeeze(), recon_loss, kl_loss"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SdtWr805wMR0"},"source":["## Metrics and Evaluation"]},{"cell_type":"code","metadata":{"id":"dmvZXQfdwMLT"},"source":["def metrics(model, test_dataloader):\n"," label_lst, pred_lst = [], []\n"," rmse, mse, mae = 0,0,0\n"," count = 0\n"," for batch_data in test_dataloader:\n"," user = torch.LongTensor(batch_data[0]).cuda()\n"," item = torch.LongTensor(batch_data[1]).cuda()\n"," label = torch.FloatTensor(batch_data[2]).cuda()\n"," user_self_cate = torch.LongTensor(batch_data[3]).cuda()\n"," user_onehop_id = torch.LongTensor(batch_data[4]).cuda()\n"," user_onehop_cate = torch.LongTensor(batch_data[5]).cuda()\n"," item_self_cate, item_self_director, item_self_writer, item_self_star, item_self_country = torch.LongTensor(\n"," batch_data[6])[:, 0:6].cuda(), torch.LongTensor(batch_data[6])[:, 6:9].cuda(), 
torch.LongTensor(\n"," batch_data[6])[:, 9:12].cuda(), torch.LongTensor(batch_data[6])[:, 12:15].cuda(), torch.LongTensor(\n"," batch_data[6])[:, 15:].cuda()\n"," item_onehop_id = torch.LongTensor(batch_data[7]).cuda()\n"," item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star, item_onehop_country = torch.LongTensor(\n"," batch_data[8])[:, :, 0:6].cuda(), torch.LongTensor(batch_data[8])[:, :, 6:9].cuda(), torch.LongTensor(\n"," batch_data[8])[:, :, 9:12].cuda(), torch.LongTensor(batch_data[8])[:, :, 12:15].cuda(), torch.LongTensor(\n"," batch_data[8])[:, :, 15:].cuda()\n","\n"," prediction, recon_loss, kl_loss = model(user, item, user_self_cate, user_onehop_id, user_onehop_cate, item_self_cate,\n"," item_self_director, item_self_writer, item_self_star, item_self_country, item_onehop_id,\n"," item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star,\n"," item_onehop_country, mode = mode)\n"," prediction = prediction.cpu().data.numpy()\n"," prediction = prediction.reshape(prediction.shape[0])\n"," label = label.cpu().numpy()\n"," my_rmse = np.sum((prediction - label) ** 2)\n"," my_mse = np.sum((prediction - label) ** 2)\n"," my_mae = np.sum(np.abs(prediction - label))\n"," # my_rmse = torch.sqrt(torch.sum((prediction - label) ** 2) / FLAGS.batch_size)\n"," rmse+=my_rmse\n"," mse+=my_mse\n"," mae+=my_mae\n"," count += len(user)\n"," label_lst.extend(list([float(l) for l in label]))\n"," pred_lst.extend(list([float(l) for l in prediction]))\n","\n"," my_mse = mse/count\n"," my_rmse = np.sqrt(rmse/count)\n"," my_mae = mae/count\n"," return my_rmse, my_mse, my_mae, label_lst, pred_lst"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"w4sXP-MRxRVM","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1635860051781,"user_tz":-330,"elapsed":993538,"user":{"displayName":"Sparsh 
Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"e3f51476-ec76-404f-a86e-5eacbc1e5291"},"source":["if __name__ == '__main__':\n"," #item cold start\n"," f_info = 'ml100k/uiinfo.pkl'\n"," f_neighbor = 'ml100k/neighbor_aspect_extension_2_zscore_ics_uuii_0.20.pkl'\n"," f_train = 'ml100k/ics_train.dat'\n"," f_test = 'ml100k/ics_val.dat'\n"," f_model = 'ml100k/agnn_ics_'\n"," mode = 'ics'\n","\n"," \"\"\"# user cold start\n"," f_info = 'ml100k/uiinfo.pkl'\n"," f_neighbor = 'ml100k/neighbor_aspect_extension_2_zscore_ucs_uuii.pkl'\n"," f_train = 'ml100k/ucs_train.dat'\n"," f_test = 'ml100k/ucs_val.dat'\n"," f_model = 'ml100k/agnn_ucs_'\n"," mode = 'ucs'\"\"\"\n","\n"," \"\"\"# warm start\n"," f_info = 'ml100k/uiinfo.pkl'\n"," f_neighbor = 'ml100k/neighbor_aspect_extension_2_zscore_warm_uuii.pkl'\n"," f_train = 'ml100k/warm_train.dat'\n"," f_test = 'ml100k/warm_val.dat'\n"," f_model = 'ml100k/agnn_warm_'\n"," mode = 'warm'\"\"\"\n","\n","\n"," FLAGS = parser.parse_args(args={})\n"," print(\"\\nParameters:\")\n"," print(FLAGS.__dict__)\n","\n"," with open(f_neighbor, 'rb') as f:\n"," neighbor_dict = pickle.load(f)\n"," user_nei_dict = neighbor_dict['user_nei_dict']\n"," item_nei_dict = neighbor_dict['item_nei_dict']\n"," director_num = neighbor_dict['director_num']\n"," writer_num = neighbor_dict['writer_num']\n"," star_num = neighbor_dict['star_num']\n"," country_num = neighbor_dict['country_num']\n","\n"," item_director_dict = neighbor_dict['item_director_dict'] #dict[i]=[x,x,x]\n"," item_writer_dict = neighbor_dict['item_writer_dict'] #dict[i]=[x,x,x]\n"," item_star_dict = neighbor_dict['item_star_dict'] #dict[i]=[x,x,x]\n"," item_country_dict = neighbor_dict['item_country_dict'] #dict[i]=[x,x,x,x,x,x,x,x]\n","\n"," with open(f_info, 'rb') as f:\n"," item_info = pickle.load(f)\n"," user_num = item_info['user_num']\n"," item_num = item_info['item_num']\n"," gender_num = item_info['gender_num']\n"," 
age_num = item_info['age_num']\n"," occupation_num = item_info['occupation_num']\n"," genre_num = item_info['genre_num']\n"," user_feature_dict = item_info['user_feature_dict'] #gender, age, occupation dict[u]=[x,x,x]\n"," item_feature_dict = item_info['item_feature_dict'] #genre dict[i]=[x,x,x,x,x,x]\n","\n"," print(\"user_num {}, item_num {}, gender_num {}, age_num {}, occupation_num {}, genre_num {}, director_num {}, writer_num {}, star_num {}, country_num {}, mode {} \".format(user_num, item_num, gender_num, age_num, occupation_num, genre_num, director_num, writer_num, star_num, country_num, mode))\n","\n"," train_steps, train_list = get_data_list(f_train, batch_size=FLAGS.batch_size)\n"," test_steps, test_list = get_data_list(f_test, batch_size=FLAGS.batch_size)\n","\n"," model = AGNN(user_num, item_num, gender_num, age_num, occupation_num, genre_num, director_num, writer_num, star_num, country_num, FLAGS.embed_size, FLAGS.attention_size, FLAGS.dropout)\n"," model.cuda()\n","\n"," loss_function = torch.nn.MSELoss(reduction='sum') # summed (not averaged) squared error; replaces the deprecated size_average=False\n"," optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=FLAGS.lr, weight_decay=0.001)\n","\n"," writer = SummaryWriter() # For visualization\n"," #f_loss_curve = open('tmp_loss_curve.txt', 'w')\n"," best_rmse = 5\n","\n"," count = 0\n"," for epoch in range(FLAGS.epochs):\n"," #tmp_main_loss, tmp_vae_loss = [], []\n"," model.train() # Enable dropout (if any).\n"," start_time = time.time()\n"," train_dataloader = get_batch_instances(train_list, user_feature_dict, item_feature_dict, item_director_dict, item_writer_dict, item_star_dict, item_country_dict, batch_size=FLAGS.batch_size, user_nei_dict=user_nei_dict, item_nei_dict=item_nei_dict, shuffle=True)\n","\n"," for idx, batch_data in enumerate(train_dataloader): # batch_data: user, item, label, user_self_cate, user_onehop_id, user_onehop_cate, item self attributes, item_onehop_id, item one-hop attributes\n"," user = torch.LongTensor(batch_data[0]).cuda()\n"," item = 
torch.LongTensor(batch_data[1]).cuda()\n"," label = torch.FloatTensor(batch_data[2]).cuda()\n"," user_self_cate = torch.LongTensor(batch_data[3]).cuda()\n"," user_onehop_id = torch.LongTensor(batch_data[4]).cuda()\n"," user_onehop_cate = torch.LongTensor(batch_data[5]).cuda()\n"," item_self_cate, item_self_director, item_self_writer, item_self_star, item_self_country = torch.LongTensor(batch_data[6])[:, 0:6].cuda(), torch.LongTensor(batch_data[6])[:, 6:9].cuda(), torch.LongTensor(batch_data[6])[:, 9:12].cuda(), torch.LongTensor(batch_data[6])[:, 12:15].cuda(), torch.LongTensor(batch_data[6])[:, 15:].cuda()\n"," item_onehop_id = torch.LongTensor(batch_data[7]).cuda()\n"," item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star, item_onehop_country = torch.LongTensor(batch_data[8])[:, :, 0:6].cuda(), torch.LongTensor(batch_data[8])[:, :, 6:9].cuda(), torch.LongTensor(batch_data[8])[:, :, 9:12].cuda(), torch.LongTensor(batch_data[8])[:, :, 12:15].cuda(), torch.LongTensor(batch_data[8])[:, :, 15:].cuda()\n","\n"," model.zero_grad()\n"," prediction, recon_loss, kl_loss = model(user, item, user_self_cate, user_onehop_id, user_onehop_cate, item_self_cate, item_self_director, item_self_writer, item_self_star, item_self_country, item_onehop_id, item_onehop_cate, item_onehop_director, item_onehop_writer, item_onehop_star, item_onehop_country, mode='train')\n","\n"," label = Variable(label)\n","\n"," main_loss = loss_function(prediction, label)\n"," loss = main_loss + FLAGS.vae_lambda * (recon_loss + kl_loss)\n","\n"," loss.backward()\n"," # nn.utils.clip_grad_norm(model.parameters(), FLAGS.clip_norm)\n"," optimizer.step()\n"," writer.add_scalar('data/loss', loss.data, count)\n"," count += 1\n","\n"," tmploss = torch.sqrt(loss / FLAGS.batch_size)\n"," print(50 * '#')\n"," print('epoch: ', epoch, ' ', tmploss.detach())\n","\n"," model.eval()\n"," print('time = ', time.time() - start_time)\n"," test_dataloader = get_batch_instances(test_list, 
user_feature_dict, item_feature_dict, item_director_dict, item_writer_dict, item_star_dict, item_country_dict, batch_size=FLAGS.batch_size, user_nei_dict=user_nei_dict, item_nei_dict=item_nei_dict, shuffle=False)\n"," rmse, mse, mae, label_lst, pred_lst = metrics(model, test_dataloader)\n"," print('test rmse,mse,mae: ', rmse,mse,mae)\n","\n"," \"\"\"if (rmse < best_rmse):\n"," best_rmse = rmse\n"," f_name = f_model + str(best_rmse)[:7] + '.dat' #f_model + str(best_rmse)[:7] + '.dat'\n"," #torch.save(model, f_name)\n"," f = open(f_name, 'w')\n"," res_dict = {}\n"," res_dict['label'] = label_lst\n"," res_dict['pred'] = pred_lst\n"," json.dump(res_dict, f)\n"," f.close()\n"," print('save result ok')\"\"\""],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Parameters:\n","{'lr': 0.0005, 'dropout': 0.5, 'batch_size': 128, 'gpu': '0', 'epochs': 20, 'clip_norm': 5.0, 'embed_size': 30, 'attention_size': 50, 'item_layer1_nei_num': 10, 'user_layer1_nei_num': 10, 'vae_lambda': 1}\n","user_num 944, item_num 1683, gender_num 2, age_num 7, occupation_num 21, genre_num 19, director_num 1112, writer_num 2016, star_num 2568, country_num 128, mode ics \n","##################################################\n","epoch: 0 tensor(0.5626, device='cuda:0')\n","time = 47.153045654296875\n","test rmse,mse,mae: 1.0325273649289848 1.066112759327193 0.8309432697873491\n","##################################################\n","epoch: 1 tensor(0.5689, device='cuda:0')\n","time = 43.06915545463562\n","test rmse,mse,mae: 1.0266010076226462 1.0539096288518324 0.824870026105042\n","##################################################\n","epoch: 2 tensor(0.6943, device='cuda:0')\n","time = 42.44387221336365\n","test rmse,mse,mae: 1.0571549817251507 1.1175766553863038 0.8638132363016552\n","##################################################\n","epoch: 3 tensor(0.5527, device='cuda:0')\n","time = 42.13542699813843\n","test rmse,mse,mae: 1.0517702752306248 
1.1062207118587042 0.8609723428611776\n","##################################################\n","epoch: 4 tensor(0.5443, device='cuda:0')\n","time = 43.24214553833008\n","test rmse,mse,mae: 1.0358337215744027 1.0729514987506772 0.8453118500108566\n","##################################################\n","epoch: 5 tensor(0.4815, device='cuda:0')\n","time = 42.43221974372864\n","test rmse,mse,mae: 1.0231939566804946 1.0469258729874857 0.8253297993686741\n","##################################################\n","epoch: 6 tensor(0.4968, device='cuda:0')\n","time = 42.93929886817932\n","test rmse,mse,mae: 1.0359929845205296 1.0732814639757542 0.8411421427834342\n","##################################################\n","epoch: 7 tensor(0.4895, device='cuda:0')\n","time = 42.229517459869385\n","test rmse,mse,mae: 1.0532774718633056 1.1093934327347565 0.8607855100322045\n","##################################################\n","epoch: 8 tensor(0.5074, device='cuda:0')\n","time = 42.55834674835205\n","test rmse,mse,mae: 1.0247800932391622 1.050174239499266 0.824709580819818\n","##################################################\n","epoch: 9 tensor(0.4578, device='cuda:0')\n","time = 41.558250427246094\n","test rmse,mse,mae: 1.0922955741421183 1.19310962129046 0.8990770569855391\n","##################################################\n","epoch: 10 tensor(0.5202, device='cuda:0')\n","time = 42.752477407455444\n","test rmse,mse,mae: 1.0405772957327133 1.0828011083944065 0.8449710432355982\n","##################################################\n","epoch: 11 tensor(0.5196, device='cuda:0')\n","time = 41.88001823425293\n","test rmse,mse,mae: 1.0647690947606663 1.1337332251574486 0.8694145861863434\n","##################################################\n","epoch: 12 tensor(0.5833, device='cuda:0')\n","time = 41.569355964660645\n","test rmse,mse,mae: 1.0307747052978313 1.0624964930818313 0.832099199058856\n","##################################################\n","epoch: 13 
tensor(0.5642, device='cuda:0')\n","time = 41.10500383377075\n","test rmse,mse,mae: 1.0283506332487717 1.05750502490315 0.8281871831344915\n","##################################################\n","epoch: 14 tensor(0.5756, device='cuda:0')\n","time = 41.294716119766235\n","test rmse,mse,mae: 1.05585361314044 1.1148268523817215 0.8628918800500965\n","##################################################\n","epoch: 15 tensor(0.5247, device='cuda:0')\n","time = 41.94802975654602\n","test rmse,mse,mae: 1.0287578659232348 1.0583427466989286 0.831378873784134\n","##################################################\n","epoch: 16 tensor(0.5186, device='cuda:0')\n","time = 41.139464139938354\n","test rmse,mse,mae: 1.0345978113820935 1.070392631316618 0.8337778758005447\n","##################################################\n","epoch: 17 tensor(0.4468, device='cuda:0')\n","time = 42.24388003349304\n","test rmse,mse,mae: 1.027209256353785 1.0551588563388956 0.8190214309314482\n","##################################################\n","epoch: 18 tensor(0.4668, device='cuda:0')\n","time = 41.94647932052612\n","test rmse,mse,mae: 1.0637774120911032 1.1316223824752447 0.8678730945240749\n","##################################################\n","epoch: 19 tensor(0.4666, device='cuda:0')\n","time = 41.875648021698\n","test rmse,mse,mae: 1.063158955390531 1.1303069644270851 0.8644698513033106\n"]}]},{"cell_type":"markdown","metadata":{"id":"Uo_PSXGbxxil"},"source":["**END**"]}]}