{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-20-bst-mxnet.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T088416%20%7C%20BST%20in%20MXNet.ipynb","timestamp":1644653410057}],"collapsed_sections":[],"toc_visible":true,"mount_file_id":"1fqT2IkRGsZlWZSWiyU6Mo2KWFJbsPNL1","authorship_tag":"ABX9TyPLzDrHML17eRgnG7zvquu0"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"NBj3rQU4ydEV"},"source":["# BST in MXNet\n","\n","> Implementing BST transformer model in MXNet library framework. After implementation, running on a sample dataset."]},{"cell_type":"markdown","metadata":{"id":"PwZFDzR2-_w6"},"source":["## Setup"]},{"cell_type":"markdown","metadata":{"id":"0eZJphSm_BC5"},"source":["### Installations"]},{"cell_type":"code","metadata":{"id":"Ir5ol9MsUPuc"},"source":["!pip install -q mxnet\n","!pip install -q gluonnlp"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mxpuBk6h_ClV"},"source":["### Imports"]},{"cell_type":"code","metadata":{"id":"mj9PWmgJT_ce"},"source":["import numpy as np\n","import random\n","\n","import mxnet as mx\n","from mxnet import gluon\n","from mxnet.gluon import nn\n","from mxnet import autograd as ag\n","from mxnet.gluon.nn import HybridBlock, HybridSequential, LeakyReLU\n","from mxnet.gluon.block import HybridBlock\n","from mxnet.ndarray import L2Normalization\n","from gluonnlp.model import AttentionCell"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WXtMWjYL_bfv"},"source":["### Params"]},{"cell_type":"code","metadata":{"id":"TcLEXP0E_c--"},"source":["np.random.seed(100)\n","ctx = mx.cpu()\n","mx.random.seed(100)\n","random.seed(100)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MiY7zOFk_ej-"},"source":["_BATCH = 1\n","_SEQ_LEN = 32\n","_OTHER_LEN = 32\n","_EMB_DIM = 32\n","_NUM_HEADS = 8\n","_DROP = 
0.2\n","_UNITS = 32"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sZG6z7KB_Ubu"},"source":["## Model"]},{"cell_type":"code","metadata":{"id":"01vZAyZoT_yI"},"source":["def _masked_softmax(F, att_score, mask, dtype):\n"," \"\"\"Ignore the masked elements when calculating the softmax\n"," Parameters\n"," ----------\n"," F : symbol or ndarray\n"," att_score : Symborl or NDArray\n"," Shape (batch_size, query_length, memory_length)\n"," mask : Symbol or NDArray or None\n"," Shape (batch_size, query_length, memory_length)\n"," Returns\n"," -------\n"," att_weights : Symborl or NDArray\n"," Shape (batch_size, query_length, memory_length)\n"," \"\"\"\n"," if mask is not None:\n"," # Fill in the masked scores with a very small value\n"," neg = -1e18\n"," if np.dtype(dtype) == np.float16:\n"," neg = -1e4\n"," else:\n"," try:\n"," # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.\n"," from mxnet.contrib import amp\n"," if amp.amp._amp_initialized:\n"," neg = -1e4\n"," except ImportError:\n"," pass\n"," att_score = F.where(mask, att_score, neg * F.ones_like(att_score))\n"," att_weights = F.softmax(att_score, axis=-1) * mask\n"," else:\n"," att_weights = F.softmax(att_score, axis=-1)\n"," return att_weights\n","\n","def _get_attention_cell(attention_cell, units=None,\n"," scaled=True, num_heads=None,\n"," use_bias=False, dropout=0.0, activation='relu'):\n"," \"\"\"\n"," Parameters\n"," ----------\n"," attention_cell : AttentionCell or str\n"," units : int or None\n"," Returns\n"," -------\n"," attention_cell : AttentionCell\n"," \"\"\"\n"," if isinstance(attention_cell, str):\n"," if attention_cell == 'scaled_luong':\n"," return DotProductAttentionCell(units=units, scaled=True, normalized=False,\n"," use_bias=use_bias, dropout=dropout, luong_style=True)\n"," elif attention_cell == 'scaled_dot':\n"," return DotProductAttentionCell(units=units, scaled=True, normalized=False,\n"," use_bias=use_bias, dropout=dropout, 
class DotProductAttentionCell(AttentionCell):
    r"""Dot product attention between the query and the key.

    The score is the batched dot product of the query with the key, after the
    optional projection, activation, L2 normalization and scaling steps::

        units is None:
            score = <h_q, h_k>
        units is not None and luong_style is False:
            score = <act(W_q h_q), act(W_k h_k)>
        units is not None and luong_style is True:
            score = <act(W h_q), h_k>

    where ``act`` is the identity unless `activation` is given.

    Parameters
    ----------
    units : int or None, default None
        Project the query and key to vectors with `units` dimension
        before applying the attention. If set to None, the query vector and the
        key vector are used directly and should have the same dimension.
    luong_style : bool, default False
        If turned on, only the query is projected::

            score = <W h_q, h_k>

        `units` must then be set and equal the dimension of the key vector.
    scaled : bool, default True
        Whether to divide the query by the sqrt of its last dimension before
        the dot product, as proposed in "[NIPS2017] Attention is all you need."::

            score = <h_q, h_k> / sqrt(dim_q)

    normalized : bool, default False
        If turned on, query and key are L2-normalized along the last axis, i.e.
        the cosine similarity is used::

            score = <h_q / ||h_q||, h_k / ||h_k||>

    use_bias : bool, default True
        Whether to use bias in the projection layers.
    activation : object or None, default None
        Any value other than None applies ``LeakyReLU(alpha=0.1)`` to the
        projected query (and projected key). The value itself is not inspected,
        and nothing is applied when `units` is None.
    dropout : float, default 0.0
        Attention dropout
    weight_initializer : str or `Initializer` or None, default None
        Initializer of the weights
    bias_initializer : str or `Initializer`, default 'zeros'
        Initializer of the bias
    prefix : str or None, default None
        See document of `Block`.
    params : str or None, default None
        See document of `Block`.
    """
    # Added to the norm so that an all-zero vector does not divide by zero.
    _L2_EPS = 1e-6

    def __init__(self, units=None, luong_style=False, scaled=True, normalized=False, use_bias=True,
                 activation=None,
                 dropout=0.0, weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(DotProductAttentionCell, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._scaled = scaled
        self._normalized = normalized
        self._use_bias = use_bias
        self._luong_style = luong_style
        self._dropout = dropout
        self._activation = activation

        if self._luong_style:
            assert units is not None, 'Luong style attention is not available without explicitly ' \
                                      'setting the units'
        with self.name_scope():
            self._dropout_layer = nn.Dropout(dropout)

        if self._activation is not None:
            with self.name_scope():
                self.act = gluon.nn.LeakyReLU(alpha=0.1)

        if units is not None:
            with self.name_scope():
                self._proj_query = nn.Dense(units=self._units, use_bias=self._use_bias,
                                            flatten=False, weight_initializer=weight_initializer,
                                            bias_initializer=bias_initializer,
                                            prefix='query_')

                if not self._luong_style:
                    self._proj_key = nn.Dense(units=self._units, use_bias=self._use_bias,
                                              flatten=False, weight_initializer=weight_initializer,
                                              bias_initializer=bias_initializer, prefix='key_')
        # No layer is built for `normalized`: the notebook binds `L2Normalization`
        # to the mxnet.ndarray operator, which cannot be instantiated as a layer,
        # so the normalization is done with operators in `_l2_normalize`.

    def _l2_normalize(self, F, x):
        """Divide `x` by its L2 norm (plus a small epsilon) along the last axis."""
        return F.broadcast_div(x, F.norm(x, axis=-1, keepdims=True) + self._L2_EPS)

    def _compute_weight(self, F, query, key, mask=None):
        """Return the (dropout-applied) attention weights for `query` against `key`.

        Parameters
        ----------
        F : symbol or ndarray
        query : Symbol or NDArray
        key : Symbol or NDArray
        mask : Symbol or NDArray or None, default None
            Forwarded to `_masked_softmax`.

        Returns
        -------
        att_weights : Symbol or NDArray
        """
        if self._units is not None:
            query = self._proj_query(query)

            # leakyrelu activation per alibaba rec article is used in self-attention and ffn
            if self._activation is not None:
                query = self.act(query)

            if not self._luong_style:
                key = self._proj_key(key)

                # leakyrelu activation per alibaba rec article is used in self-attention and ffn
                if self._activation is not None:
                    key = self.act(key)

            elif F == mx.nd:
                # Shapes are only known in imperative mode, so the check is skipped for symbols.
                assert query.shape[-1] == key.shape[-1], 'Luong style attention requires key to ' \
                                                         'have the same dim as the projected ' \
                                                         'query. Received key {}, query {}.'.format(
                                                             key.shape, query.shape)
        if self._normalized:
            query = self._l2_normalize(F, query)
            key = self._l2_normalize(F, key)
        if self._scaled:
            query = F.contrib.div_sqrt_dim(query)

        att_score = F.batch_dot(query, key, transpose_b=True)

        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
        return att_weights
Must be divided exactly by num_heads.\n"," num_heads : int\n"," Number of parallel attention heads\n"," use_bias : bool, default True\n"," Whether to use bias when projecting the query/key/values\n"," weight_initializer : str or `Initializer` or None, default None\n"," Initializer of the weights.\n"," bias_initializer : str or `Initializer`, default 'zeros'\n"," Initializer of the bias.\n"," prefix : str or None, default None\n"," See document of `Block`.\n"," params : str or None, default None\n"," See document of `Block`.\n"," \"\"\"\n"," def __init__(self, base_cell, query_units, key_units, value_units, num_heads, use_bias=True,\n"," weight_initializer=None, bias_initializer='zeros', prefix=None, params=None):\n"," super(MultiHeadAttentionCell, self).__init__(prefix=prefix, params=params)\n"," self._base_cell = base_cell\n"," self._num_heads = num_heads\n"," self._use_bias = use_bias\n"," units = {'query': query_units, 'key': key_units, 'value': value_units}\n"," for name, unit in units.items():\n"," if unit % self._num_heads != 0:\n"," raise ValueError(\n"," 'In MultiHeadAttetion, the {name}_units should be divided exactly'\n"," ' by the number of heads. Received {name}_units={unit}, num_heads={n}'.format(\n"," name=name, unit=unit, n=num_heads))\n"," setattr(self, '_{}_units'.format(name), unit)\n"," with self.name_scope():\n"," setattr(\n"," self, 'proj_{}'.format(name),\n"," nn.Dense(units=unit, use_bias=self._use_bias, flatten=False,\n"," weight_initializer=weight_initializer,\n"," bias_initializer=bias_initializer, prefix='{}_'.format(name)))\n","\n"," def __call__(self, query, key, value=None, mask=None):\n"," \"\"\"Compute the attention.\n"," Parameters\n"," ----------\n"," query : Symbol or NDArray\n"," Query vector. Shape (batch_size, query_length, query_dim)\n"," key : Symbol or NDArray\n"," Key of the memory. Shape (batch_size, memory_length, key_dim)\n"," value : Symbol or NDArray or None, default None\n"," Value of the memory. 
def _get_layer_norm(use_bert, units, layer_norm_eps=None):
    """Create the ``nn.LayerNorm`` block used by the feed-forward sub-layer.

    Parameters
    ----------
    use_bert : bool
        Accepted for signature compatibility; it is not consulted here.
    units : int
        Passed to ``nn.LayerNorm`` as ``in_channels``.
    layer_norm_eps : float or None, default None
        When truthy it is forwarded as ``epsilon``; otherwise the
        ``nn.LayerNorm`` default epsilon is kept.

    Returns
    -------
    nn.LayerNorm
    """
    # from gluonnlp.model.bert import BERTLayerNorm
    norm_kwargs = {'in_channels': units}
    if layer_norm_eps:
        norm_kwargs['epsilon'] = layer_norm_eps
    return nn.LayerNorm(**norm_kwargs)


class BasePositionwiseFFN(HybridBlock):
    """Two-layer position-wise feed-forward block with activation, dropout,
    optional residual connection and a closing layer norm.

    Unlike the usual transformer FFN, the activation is applied after *both*
    dense layers.

    Parameters
    ----------
    units : int
        Number of units for the output
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    dropout : float
        Dropout rate; no dropout layer is created when it is falsy.
    use_residual : bool
        Add the block input to the output before the layer norm.
    weight_initializer : str or Initializer
        Initializer for the weights of both dense layers.
    bias_initializer : str or Initializer
        Initializer for the bias vectors of both dense layers.
    activation : str or Block or None, default 'leakyrelu'
        ``'leakyrelu'`` (case-insensitive) gives ``LeakyReLU(alpha=0.1)``; any
        other string is handed to ``gluon.nn.Activation``; a Block is used as
        is; a falsy value disables the activation.
    use_bert_layer_norm : bool, default False
        Forwarded to `_get_layer_norm`.
    ffn1_dropout : bool, default False
        If True, apply dropout both after the first and second dense layers.
        If False, only apply dropout after the second.
    prefix : str, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    layer_norm_eps : float, default None
        Epsilon for layer_norm

    Inputs:
        - **inputs** : input sequence of shape (batch_size, length, C_in).

    Outputs:
        - **outputs** : output encoding of shape (batch_size, length, C_out).
    """

    def __init__(self, units=512, hidden_size=2048, dropout=0.0, use_residual=True,
                 weight_initializer=None, bias_initializer='zeros', activation='leakyrelu',
                 use_bert_layer_norm=False, ffn1_dropout=False, prefix=None, params=None,
                 layer_norm_eps=None):
        super().__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._units = units
        self._use_residual = use_residual
        self._dropout = dropout
        self._ffn1_dropout = ffn1_dropout

        def make_dense(width, tag):
            # Both dense layers share every setting except width and name prefix.
            return nn.Dense(units=width, flatten=False,
                            weight_initializer=weight_initializer,
                            bias_initializer=bias_initializer,
                            prefix=tag)

        with self.name_scope():
            self.ffn_1 = make_dense(hidden_size, 'ffn_1_')
            if activation:
                self.activation = self._get_activation(activation)
            else:
                self.activation = None
            self.ffn_2 = make_dense(units, 'ffn_2_')
            if dropout:
                self.dropout_layer = nn.Dropout(rate=dropout)
            self.layer_norm = _get_layer_norm(use_bert_layer_norm, units,
                                              layer_norm_eps=layer_norm_eps)

    def _get_activation(self, act):
        """Turn an activation name (or ready-made Block) into an activation block."""
        if not isinstance(act, str):
            assert isinstance(act, gluon.Block)
            return act
        # per alibaba rec article leakyRELU is used in self-attention and ffn
        if act.lower() == 'leakyrelu':
            return gluon.nn.LeakyReLU(alpha=0.1)
        return gluon.nn.Activation(act)

    def hybrid_forward(self, F, inputs):  # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Position-wise encoding of the inputs.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)

        Returns
        -------
        outputs : Symbol or NDArray
            Shape (batch_size, length, C_out)
        """
        # First dense layer, activation, optional early dropout.
        hidden = self.ffn_1(inputs)
        if self.activation:
            hidden = self.activation(hidden)
        if self._dropout and self._ffn1_dropout:
            hidden = self.dropout_layer(hidden)

        # Second dense layer, activation again, dropout.
        projected = self.ffn_2(hidden)
        if self.activation:
            projected = self.activation(projected)
        if self._dropout:
            projected = self.dropout_layer(projected)

        # Residual connection (if enabled) followed by layer norm.
        if self._use_residual:
            projected = projected + inputs
        return self.layer_norm(projected)
def _position_angle_table(max_length, dim):
    """Return the (max_length, dim) table ``pos / 10000 ** (2 * i / dim)``."""
    positions = np.arange(max_length).reshape((-1, 1))
    frequencies = np.power(10000, (2. / dim) * np.arange(dim).reshape((1, -1)))
    return positions / frequencies


def _position_encoding_init(max_length, dim):
    """Init the sinusoid position encoding table.

    Returns a float array of shape (max_length, dim): sine of the angle table
    in the even columns, cosine in the odd columns.
    """
    position_enc = _position_angle_table(max_length, dim)
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])
    return position_enc


def _position_encoding_init_BST(max_length, dim):
    """For the BST recommender, the positional embedding takes the time of item being clicked as
    input feature and calculates the position value of item vi as p(vt) - p(vi) where
    p(vt) is recommending time and p(vi) is time the user clicked on item vi.

    Returns a float array of shape (max_length, dim); no sine/cosine is applied.
    """
    # Assume position_enc is the p(vt) - p(vi) fed as input
    return _position_angle_table(max_length, dim)


class Rec(HybridBlock):
    """Alibaba transformer based recommender.

    One multi-head self-attention block over the embedded item sequence,
    followed by a position-wise FFN; the result is concatenated with the
    embedded "other" features and reduced by a small LeakyReLU MLP to a single
    logit per example.
    """
    def __init__(self, **kwargs):
        super(Rec, self).__init__(**kwargs)
        with self.name_scope():
            self.otherfeatures = nn.Embedding(input_dim=_OTHER_LEN,
                                              output_dim=_EMB_DIM)
            self.features = nn.Embedding(input_dim=_SEQ_LEN,
                                         output_dim=_EMB_DIM)
            # Transformer layers
            # Multi-head attention with base cell scaled dot-product attention
            # Use b=1 self-attention blocks per article recommendation
            self.cell = _get_attention_cell('multi_head',
                                            units=_UNITS,
                                            scaled=True,
                                            dropout=_DROP,
                                            num_heads=_NUM_HEADS,
                                            use_bias=False)
            self.proj = nn.Dense(units=_UNITS,
                                 use_bias=False,
                                 bias_initializer='zeros',
                                 weight_initializer=None,
                                 flatten=False
                                 )
            self.drop_out_layer = nn.Dropout(rate=_DROP)
            self.ffn = PositionwiseFFN(hidden_size=_UNITS,
                                       use_residual=True,
                                       dropout=_DROP,
                                       units=_UNITS,
                                       weight_initializer=None,
                                       bias_initializer='zeros',
                                       activation='leakyrelu'
                                       )
            self.layer_norm = nn.LayerNorm(in_channels=_UNITS)
            # Final MLP layers; BST dimensions in the article were 1024, 512, 256
            self.output = HybridSequential()
            self.output.add(nn.Dense(8))
            self.output.add(LeakyReLU(alpha=0.1))
            self.output.add(nn.Dense(4))
            self.output.add(LeakyReLU(alpha=0.1))
            self.output.add(nn.Dense(2))
            self.output.add(LeakyReLU(alpha=0.1))
            self.output.add(nn.Dense(1))

    def _arange_like(self, F, inputs, axis):
        """Helper function to generate indices of a range"""
        if F == mx.ndarray:
            seq_len = inputs.shape[axis]
            arange = F.arange(seq_len, dtype=inputs.dtype, ctx=inputs.context)
        else:
            # NOTE(review): this branch slices three axes and reads `inputs.dtype`,
            # which looks incompatible with the 2-D index input used in this
            # notebook — confirm before hybridizing the network.
            input_axis = inputs.slice(begin=(0, 0, 0), end=(1, None, 1)).reshape((-1))
            zeros = F.zeros_like(input_axis)
            arange = F.arange(start=0, repeat=1, step=1,
                              infer_range=True, dtype=inputs.dtype)
            arange = F.elemwise_add(arange, zeros)
        return arange

    def _get_positional(self, weight_type, max_length, units):
        """Build the positional table as an NDArray.

        `weight_type` is 'sinusoidal' or 'BST'; anything else raises ValueError.
        """
        if weight_type == 'sinusoidal':
            encoding = _position_encoding_init(max_length, units)
        elif weight_type == 'BST':
            # BST position fed as input
            encoding = _position_encoding_init_BST(max_length, units)
        else:
            raise ValueError('Not known')
        return mx.nd.array(encoding)

    def hybrid_forward(self, F, x, x_other, mask=None):
        """Score one batch.

        Parameters
        ----------
        x : Symbol or NDArray
            Item-sequence indices looked up in `self.features`;
            called with shape (batch, 32) in this notebook.
        x_other : Symbol or NDArray
            Indices of the manually engineered features, looked up in
            `self.otherfeatures`.
        mask : Symbol or NDArray or None, default None
            Forwarded to the attention cell.

        Returns
        -------
        out_x : Symbol or NDArray
            Output of the final ``Dense(1)`` layer (raw logits).
        """
        # The manually engineered features
        x1 = self.otherfeatures(x_other)

        # The transformer encoder
        steps = self._arange_like(F, x, axis=1)
        x = self.features(x)
        position_weight = self._get_positional('BST', _SEQ_LEN, _UNITS)
        # add positional embedding
        positional_embedding = F.Embedding(steps, position_weight, _SEQ_LEN, _UNITS)
        x = F.broadcast_add(x, F.expand_dims(positional_embedding, axis=0))
        # attention cell with dropout
        out_x, attn_w = self.cell(x, x, x, mask)
        out_x = self.proj(out_x)
        out_x = self.drop_out_layer(out_x)
        # add and norm
        out_x = x + out_x
        out_x = self.layer_norm(out_x)
        # ffn
        out_x = self.ffn(out_x)

        # concat engineered features with transformer representations;
        # routed through F (not mx.ndarray) as hybrid_forward requires,
        # with the previously implicit axis spelled out
        out_x = F.concat(out_x, x1, dim=1)
        # leakyrelu final layers
        out_x = self.output(out_x)
        return out_x


def generate_sample():
    """Generate toy X and y. One target item.

    X has shape (100, _SEQ_LEN + _OTHER_LEN): the first `_SEQ_LEN` columns are
    random item ids in [0, _SEQ_LEN), the rest random feature ids in
    [0, _OTHER_LEN), so each embedding table is actually indexed (uniform
    [0, 1) floats would all select row 0). y is a random 0/1 label of
    shape (100, 1).
    """
    n_samples = 100
    seq_ids = np.random.randint(0, _SEQ_LEN, size=(n_samples, _SEQ_LEN))
    other_ids = np.random.randint(0, _OTHER_LEN, size=(n_samples, _OTHER_LEN))
    X = mx.nd.array(np.concatenate([seq_ids, other_ids], axis=1), ctx=ctx)
    y = mx.random.uniform(shape=(n_samples, 1), ctx=ctx)
    y = y > 0.5
    # Data loader
    d = gluon.data.ArrayDataset(X, y, )
    return gluon.data.DataLoader(d, _BATCH, last_batch='keep')


train_metric = mx.metric.Accuracy()


def train():
    """Train `Rec` for one epoch on the toy data and accumulate accuracy in
    the module-level `train_metric`."""
    train_data = generate_sample()
    optimizer = mx.optimizer.Adam()
    # Binary classification problem; predict if user clicks target item
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    net = Rec()
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(),
                            optimizer)
    epochs = 1

    for epoch in range(epochs):
        train_metric.reset()
        for x, y in train_data:
            with ag.record():
                # assume x contains sequential inputs and manually engineered features
                output = net(x[:, :_SEQ_LEN], x[:, _SEQ_LEN:])
                l = loss(output, y).sum()
            l.backward()
            # the last batch may be smaller than _BATCH (last_batch='keep')
            trainer.step(x.shape[0])
            # `output` holds raw logits; sigmoid(z) > 0.5 is the same as z > 0,
            # so threshold before handing hard 0/1 predictions to Accuracy
            train_metric.update(y, output > 0)


train()
train_metric.get()

_SEQ_LEN = 32
_BATCH = 1
ctx = mx.cpu()


def _tst_module(net, x):
    """Initialize `net`, run one forward pass with `x` as both inputs and
    return the output after all pending computation has finished."""
    net.initialize()
    output = net(x, x)
    # Block until the asynchronous engine is done so failures surface here.
    mx.nd.waitall()
    return output


def test():
    """Smoke test: a freshly built `Rec` yields one value per example."""
    x = mx.random.uniform(shape=(_BATCH, _SEQ_LEN), ctx=ctx)
    net = Rec()
    output = _tst_module(net, x)
    assert output.shape == (_BATCH, 1), \
        'expected output shape {}, got {}'.format((_BATCH, 1), output.shape)


test()