{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-14-nfm-criteo.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/P917494%20%7C%20NFM%20on%20Criteo%20DAC%20sample%20dataset%20in%20PyTorch.ipynb","timestamp":1644614102319},{"file_id":"1PcrzoopQcJ6T5CwS38RIyYqoayb0ytc7","timestamp":1641536947688},{"file_id":"1FEZmnoLGIsTsGiK2gi1TsIHLAaWCXF_a","timestamp":1640329037065}],"collapsed_sections":[],"toc_visible":true,"mount_file_id":"1FEZmnoLGIsTsGiK2gi1TsIHLAaWCXF_a","authorship_tag":"ABX9TyM0maX1ftvDR7GCOcKO5FYK"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# NFM on Criteo DAC sample dataset in PyTorch"],"metadata":{"id":"2VZ8fr0qMt5t"}},{"cell_type":"code","source":["import numpy as np\n","import pandas as pd\n","import torch\n","from sklearn.model_selection import train_test_split\n","from sklearn.metrics import log_loss, roc_auc_score\n","from sklearn.preprocessing import LabelEncoder, MinMaxScaler\n","from collections import OrderedDict, namedtuple, defaultdict\n","from torch import nn as nn\n","from torch.utils.data import DataLoader, Dataset, TensorDataset"],"metadata":{"id":"EToD4LnRLgyY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!wget -q --show-progress https://github.com/RecoHut-Datasets/criteo/raw/v1/dac_sample.txt"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"L60wRz3KLutF","executionInfo":{"status":"ok","timestamp":1641536229646,"user_tz":-330,"elapsed":1342,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"1a380f02-5cf1-43d2-8a40-974119bbecdc"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\rdac_sample.txt 0%[ ] 0 --.-KB/s \rdac_sample.txt 100%[===================>] 23.20M --.-KB/s in 0.09s \n"]}]},{"cell_type":"code","source":["class SparseFeat(namedtuple('SparseFeat', ['name', 'vocabulary_size', 'embedding_dim', 'use_hash',\n"," 'dtype', 'embedding_name', 'group_name'])):\n"," __slots__ = ()\n"," def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype=\"int32\", embedding_name=None,\n"," group_name='default_group'):\n"," if embedding_name is None:\n"," embedding_name = name\n"," return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype, embedding_name, group_name)\n","\n"," def __hash__(self):\n"," return self.name.__hash__()"],"metadata":{"id":"9Rs8-3zMLSCa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):\n"," __slots__ = ()\n"," def __new__(cls, name, dimension=1, dtype='float32'):\n"," return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)\n"," def __hash__(self):\n"," return self.name.__hash__()"],"metadata":{"id":"EVO8EyF6N8ru"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def activation_layer(act_name, hidden_size=None, dice_dim=2):\n"," if isinstance(act_name, str):\n"," if act_name.lower() == 'sigmoid':\n"," act_layer = nn.Sigmoid()\n"," elif act_name.lower() == 'relu':\n"," act_layer = nn.ReLU(inplace=True)\n"," elif act_name.lower() == 'prelu':\n"," act_layer = nn.PReLU()\n"," return act_layer"],"metadata":{"id":"2Ab6UDMIN9ii"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_auc(loader, model):\n"," pred, target = [], []\n"," model.eval()\n"," with torch.no_grad():\n"," for x, y in loader:\n"," x, y = x.to(device).float(), y.to(device).float()\n"," y_hat = model(x)\n"," pred += list(y_hat.cpu().numpy())\n"," target += list(y.cpu().numpy())\n"," auc = roc_auc_score(target, pred)\n"," return auc"],"metadata":{"id":"dBqDa29AN-jX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class DNN(nn.Module):\n"," def __init__(self, inputs_dim, hidden_units, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False,\n"," init_std=0.0001, dice_dim=3, seed=1024, device='cpu'):\n"," super(DNN, self).__init__()\n"," self.dropout = nn.Dropout(dropout_rate)\n"," self.seed = seed\n"," self.l2_reg = l2_reg\n"," self.use_bn = use_bn\n"," hidden_units = [inputs_dim] + list(hidden_units)\n"," self.linears = nn.ModuleList([\n"," nn.Linear(hidden_units[i], hidden_units[i+1]) for i in range(len(hidden_units)-1)\n"," ])\n"," if use_bn:\n"," self.bn = nn.ModuleList([\n"," nn.BatchNorm1d(hidden_units[i], hidden_units[i+1]) for i in range(len(hidden_units)-1)\n"," ])\n"," self.activation_layer = nn.ModuleList([\n"," activation_layer(activation, hidden_units[i+1], dice_dim) for i in range(len(hidden_units)-1)\n"," ])\n"," for name, tensor in self.linears.named_parameters():\n"," if 'weight' in name:\n"," nn.init.normal_(tensor, mean=0, std=init_std)\n"," self.to(device)\n"," def forward(self, inputs):\n"," deep_input = inputs\n"," for i in range(len(self.linears)):\n"," fc = self.linears[i](deep_input)\n"," if self.use_bn:\n"," fc = self.bn[i](fc)\n"," fc = self.activation_layer[i](fc)\n"," fc = self.dropout(fc)\n"," deep_input = fc\n"," return deep_input"],"metadata":{"id":"-JJ1mTc0N7Fv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class BiInteractionPooling(nn.Module):\n"," def __init__(self):\n"," super(BiInteractionPooling, self).__init__()\n"," def forward(self, inputs):\n"," concated_embeds_value = inputs\n"," square_of_sum = torch.pow(torch.sum(concated_embeds_value, dim=1, keepdim=True), 2)\n"," sum_of_square = torch.sum(concated_embeds_value * concated_embeds_value, dim=1, keepdim=True)\n"," cross_term = 0.5 * (square_of_sum - sum_of_square)\n"," return cross_term"],"metadata":{"id":"QCh31R87N5_5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class NFM(nn.Module):\n"," def __init__(self, feat_sizes, embedding_size, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(128, 128),\n"," l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, bi_dropout=1,\n"," dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu', gpus=None):\n"," super(NFM, self).__init__()\n"," self.dense_features_columns = list(\n"," filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns)) if len(dnn_feature_columns) else []\n"," dense_input_dim = sum(map(lambda x: x.dimension, self.dense_features_columns))\n","\n"," self.sparse_features_columns = list(\n"," filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if len(dnn_feature_columns) else []\n","\n"," self.feat_sizes = feat_sizes\n"," self.embedding_size = embedding_size\n"," self.embedding_dic = nn.ModuleDict({feat.name:nn.Embedding(self.feat_sizes[feat.name], self.embedding_size, sparse=False)\n"," for feat in self.sparse_features_columns})\n"," for tensor in self.embedding_dic.values():\n"," nn.init.normal_(tensor.weight, mean=0, std=init_std)\n","\n"," self.feature_index = defaultdict(int)\n"," start = 0\n"," for feat in self.feat_sizes:\n"," if feat in self.feature_index:\n"," continue\n"," self.feature_index[feat] = start\n"," start += 1\n","\n"," self.dnn = DNN(dense_input_dim+self.embedding_size, dnn_hidden_units, activation=dnn_activation,\n"," l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=False,\n"," init_std=init_std, device=device)\n"," self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)\n","\n"," dnn_hidden_units = [len(self.feature_index)] + list(dnn_hidden_units) + [1]\n"," self.Linears = nn.ModuleList(\n"," [nn.Linear(dnn_hidden_units[i], dnn_hidden_units[i + 1]) for i in range(len(dnn_hidden_units) - 1)])\n"," self.relu = nn.ReLU()\n"," self.bi_pooling = BiInteractionPooling()\n"," self.bi_dropout = bi_dropout\n"," if self.bi_dropout > 0:\n"," self.dropout = nn.Dropout(bi_dropout)\n"," self.to(device)\n","\n"," def forward(self, X):\n"," sparse_embedding = [self.embedding_dic[feat.name](X[:, self.feature_index[feat.name]].long()).reshape(X.shape[0], 1, -1)\n"," for feat in self.sparse_features_columns]\n"," dense_values = [X[:, self.feature_index[feat.name]].reshape(-1, 1) for feat in self.dense_features_columns]\n"," # print('sparse_embedding shape', sparse_embedding[0].shape)\n"," dense_input = torch.cat(dense_values, dim=1)\n"," # print('densn_input shape', dense_input.shape)\n"," fm_input = torch.cat(sparse_embedding, dim=1)\n"," # print('fm_input_shape', fm_input.shape)\n"," bi_out = self.bi_pooling(fm_input)\n"," # print('bi_out shape', bi_out.shape)\n"," if self.bi_dropout:\n"," bi_out = self.dropout(bi_out)\n","\n"," bi_out = torch.flatten(torch.cat([bi_out], dim=-1), start_dim=1)\n","\n"," dnn_input = torch.cat((dense_input, bi_out), dim=1)\n"," dnn_output = self.dnn(dnn_input)\n"," dnn_output = self.dnn_linear(dnn_output)\n","\n"," # print('X shape', X.shape)\n"," for i in range(len(self.Linears)):\n"," fc = self.Linears[i](X)\n"," fc = self.relu(fc)\n"," fc = self.dropout(fc)\n"," X = fc\n","\n"," logit = X + dnn_output\n"," y_pred = torch.sigmoid(logit)\n"," return y_pred"],"metadata":{"id":"D_slY7KGN44C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["batch_size = 1024\n","lr = 1e-3\n","wd = 1e-5\n","epochs = 10\n","seed = 1024\n","embedding_size = 4\n","\n","sparse_features = ['C' + str(i) for i in range(1, 27)]\n","dense_features = ['I' + str(i) for i in range(1, 14)]\n","col_names = ['label'] + dense_features + sparse_features\n","data = pd.read_csv('dac_sample.txt', names=col_names, sep='\\t')\n","\n","data[sparse_features] = data[sparse_features].fillna('-1',)\n","data[dense_features] = data[dense_features].fillna('0', )\n","target = ['label']\n","\n","feat_sizes = {}\n","feat_sizes_dense = {feat: 1 for feat in dense_features}\n","feat_sizes_sparse = {feat: len(data[feat].unique()) for feat in sparse_features}\n","feat_sizes.update(feat_sizes_dense)\n","feat_sizes.update(feat_sizes_sparse)\n","\n","for feat in sparse_features:\n"," lbe = LabelEncoder()\n"," data[feat] = lbe.fit_transform(data[feat])\n","\n","nms = MinMaxScaler(feature_range=(0, 1))\n","data[dense_features] = nms.fit_transform(data[dense_features])\n","\n","fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique()) for feat in sparse_features] + [DenseFeat(feat, 1,)\n"," for feat in dense_features]\n","dnn_feature_columns = fixlen_feature_columns\n","linear_feature_columns = fixlen_feature_columns\n","\n","train, test = train_test_split(data, test_size=0.2, random_state=2022)\n","feature_names = sparse_features + dense_features\n","# train_model_input = {name: train[name] for name in feature_names}\n","# test_model_input = {name: test[name] for name in feature_names}\n","\n","device = 'cpu'\n","model = NFM(feat_sizes, embedding_size, linear_feature_columns, dnn_feature_columns).to(device)\n","\n","train_label = pd.DataFrame(train['label'])\n","train = train.drop(columns=['label'])\n","train_tensor_data = TensorDataset(torch.from_numpy(np.array(train)), torch.from_numpy(np.array(train_label)))\n","train_loader = DataLoader(train_tensor_data, shuffle=True, batch_size=batch_size)\n","\n","test_label = pd.DataFrame(test['label'])\n","test = test.drop(columns=['label'])\n","test_tensor_data = TensorDataset(torch.from_numpy(np.array(test)), torch.from_numpy(np.array(test_label)))\n","test_loader = DataLoader(test_tensor_data, batch_size=batch_size)\n","\n","loss_func = nn.BCELoss(reduction='mean')\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n","\n","for epoch in range(epochs):\n"," total_loss_epoch = 0.0\n"," total_tmp = 0\n"," model.train()\n"," for index, (x, y) in enumerate(train_loader):\n"," x, y = x.to(device).float(), y.to(device).float()\n"," y_hat = model(x)\n","\n"," optimizer.zero_grad()\n"," loss = loss_func(y_hat, y)\n"," loss.backward()\n"," optimizer.step()\n"," total_loss_epoch += loss.item()\n"," total_tmp += 1\n"," auc = get_auc(test_loader, model)\n"," print('epoch/epoches: {}/{}, train loss: {:.3f}, test auc: {:.3f}'.format(epoch, epochs,\n"," total_loss_epoch / total_tmp, auc))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZH69aNy6LmvA","executionInfo":{"status":"ok","timestamp":1641536912733,"user_tz":-330,"elapsed":38318,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"0ba95cc5-745e-41c5-e3c3-88ad3a7e3223"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["epoch/epoches: 0/10, train loss: 0.567, test auc: 0.483\n","epoch/epoches: 1/10, train loss: 0.524, test auc: 0.674\n","epoch/epoches: 2/10, train loss: 0.506, test auc: 0.677\n","epoch/epoches: 3/10, train loss: 0.502, test auc: 0.682\n","epoch/epoches: 4/10, train loss: 0.498, test auc: 0.684\n","epoch/epoches: 5/10, train loss: 0.497, test auc: 0.686\n","epoch/epoches: 6/10, train loss: 0.496, test auc: 0.688\n","epoch/epoches: 7/10, train loss: 0.495, test auc: 0.689\n","epoch/epoches: 8/10, train loss: 0.495, test auc: 0.690\n","epoch/epoches: 9/10, train loss: 0.494, test auc: 0.691\n"]}]},{"cell_type":"code","source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tmhWlpwiL0bb","executionInfo":{"status":"ok","timestamp":1641536460223,"user_tz":-330,"elapsed":3599,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"fac5bee6-aa05-4846-ac1d-c13dd1f05222"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2022-01-07 06:21:01\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.144+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","IPython: 5.5.0\n","torch : 1.10.0+cu111\n","pandas : 1.1.5\n","numpy : 1.19.5\n","\n"]}]}]}