{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from fastai.rnn_reg import *\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedded_dropout(embed, words, dropout=0.1, scale=None):\n",
    "    \"\"\" Applies dropout in the embedding layer by zeroing out some elements of the embedding vector.\n",
    "    Uses the dropout_mask custom layer to achieve this.\n",
    "\n",
    "    Args:\n",
    "        embed (torch.nn.Embedding): An embedding torch layer\n",
    "        words (torch.nn.Variable): A torch variable\n",
    "        dropout (float): dropout fraction to apply to the embedding weights\n",
    "        scale (float): additional scaling to apply to the modified embedding weights\n",
    "\n",
    "    Returns:\n",
    "        tensor of size: (batch_size x seq_length x embedding_size)\n",
    "\n",
    "    Example:\n",
    "\n",
    "    >> embed = torch.nn.Embedding(10,3)\n",
    "    >> words = Variable(torch.LongTensor([[1,2,4,5] ,[4,3,2,9]]))\n",
    "    >> words.size()\n",
    "        (2,4)\n",
    "\n",
    "    >> dropout_out_ = embedded_dropout(embed, words, dropout=0.40)\n",
    "    >> dropout_out_\n",
    "        Variable containing:\n",
    "        (0 ,.,.) =\n",
    "          1.2549  1.8230  1.9367\n",
    "          0.0000 -0.0000  0.0000\n",
    "          2.2540 -0.1299  1.5448\n",
    "          0.0000 -0.0000 -0.0000\n",
    "\n",
    "        (1 ,.,.) =\n",
    "          2.2540 -0.1299  1.5448\n",
    "         -4.0457  2.4815 -0.2897\n",
    "          0.0000 -0.0000  0.0000\n",
    "          1.8796 -0.4022  3.8773\n",
    "        [torch.FloatTensor of size 2x4x3]\n",
    "    \"\"\"\n",
    "\n",
    "    if dropout:\n",
    "        mask = Variable(dropout_mask(embed.weight.data, (embed.weight.size(0), 1), dropout))\n",
    "        masked_embed_weight = mask * embed.weight\n",
    "    else: masked_embed_weight = embed.weight\n",
    "    if scale: masked_embed_weight = scale * masked_embed_weight\n",
    "\n",
    "    padding_idx = embed.padding_idx\n",
    "    if padding_idx is None: padding_idx = -1\n",
    "    X = embed._backend.Embedding.apply(words, masked_embed_weight,\n",
    "        padding_idx, embed.max_norm, embed.norm_type,\n",
    "        embed.scale_grad_by_freq, embed.sparse\n",
    "    )\n",
    "    return X\n",
    "\n"
   ]
  },
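  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`embedded_dropout` leans on the `dropout_mask` helper imported from `fastai.rnn_reg`. As a reference, the sketch below shows the behaviour it is assumed to have (fastai 0.7-style semantics, not the library source): draw a Bernoulli keep/drop mask of the requested size and rescale by `1/(1-dropout)` so the expected magnitude of the weights is unchanged. The tests below use the real `dropout_mask`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dropout_mask_sketch(x, sz, dropout):\n",
    "    \"\"\"Hypothetical re-implementation of dropout_mask, for illustration only.\n",
    "    Returns a tensor of size sz (same type as x) whose entries are 0 with\n",
    "    probability dropout and 1/(1-dropout) otherwise.\"\"\"\n",
    "    return x.new(*sz).bernoulli_(1 - dropout) / (1 - dropout)"
   ]
  },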
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initialize embedding matrix and input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed = torch.nn.Embedding(10,3)\n",
    "words = torch.autograd.Variable(torch.LongTensor([[1,2,4,5] ,[4,3,2,9]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### propagate the input via the old method (embedded_dropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "(0 ,.,.) = \n",
       " -0.0000 -0.0000  0.0000\n",
       " -0.0000 -0.0000  0.0000\n",
       " -0.8727 -1.3131 -1.2833\n",
       " -0.0000 -0.0000 -0.0000\n",
       "\n",
       "(1 ,.,.) = \n",
       " -0.8727 -1.3131 -1.2833\n",
       " -0.7425 -4.4689  0.9931\n",
       " -0.0000 -0.0000  0.0000\n",
       " -1.3907 -0.6359 -3.4528\n",
       "[torch.FloatTensor of size 2x4x3]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(88123)\n",
    "dropout_out_old = embedded_dropout(embed, words, dropout=0.40)\n",
    "dropout_out_old"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### propagate the input via the forward method in the new layer (EmbeddingDropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "(0 ,.,.) = \n",
       " -0.0000 -0.0000  0.0000\n",
       " -0.0000 -0.0000  0.0000\n",
       " -0.8727 -1.3131 -1.2833\n",
       " -0.0000 -0.0000 -0.0000\n",
       "\n",
       "(1 ,.,.) = \n",
       " -0.8727 -1.3131 -1.2833\n",
       " -0.7425 -4.4689  0.9931\n",
       " -0.0000 -0.0000  0.0000\n",
       " -1.3907 -0.6359 -3.4528\n",
       "[torch.FloatTensor of size 2x4x3]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(88123)\n",
    "embed_dropout_layer = EmbeddingDropout(embed)\n",
    "dropout_out_new = embed_dropout_layer(words, dropout=0.4)\n",
    "dropout_out_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(np.testing.assert_array_equal(to_np(dropout_out_old),to_np(dropout_out_new)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initialize embedding and matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed = torch.nn.Embedding(10,7)\n",
    "words = torch.autograd.Variable(torch.LongTensor([[1,2,4,5,2,8] ,[4,3,2,9,7,6]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### get the input by propagating via the old method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "(0 ,.,.) = \n",
       " -0.1285  0.7315  0.1410 -1.6623 -3.7160 -1.1500  1.9748\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -1.4062 -0.2985 -1.2622 -1.2019 -3.9978  2.7126  4.9408\n",
       "  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000\n",
       "\n",
       "(1 ,.,.) = \n",
       " -1.4062 -0.2985 -1.2622 -1.2019 -3.9978  2.7126  4.9408\n",
       "  0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -6.4958  2.6005 -2.2331 -7.8277 -0.9653  3.3815  1.9252\n",
       " -0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000\n",
       "  0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000\n",
       "[torch.FloatTensor of size 2x6x7]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(7123)\n",
    "dropout_out_old = embedded_dropout(embed, words, dropout=0.64)\n",
    "dropout_out_old"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### get the input by propagating input via the embedding layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "(0 ,.,.) = \n",
       " -0.1285  0.7315  0.1410 -1.6623 -3.7160 -1.1500  1.9748\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -1.4062 -0.2985 -1.2622 -1.2019 -3.9978  2.7126  4.9408\n",
       "  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000\n",
       "\n",
       "(1 ,.,.) = \n",
       " -1.4062 -0.2985 -1.2622 -1.2019 -3.9978  2.7126  4.9408\n",
       "  0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000\n",
       " -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000\n",
       " -6.4958  2.6005 -2.2331 -7.8277 -0.9653  3.3815  1.9252\n",
       " -0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000\n",
       "  0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000\n",
       "[torch.FloatTensor of size 2x6x7]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(7123)\n",
    "embed_dropout_layer = EmbeddingDropout(embed)\n",
    "dropout_out_new = embed_dropout_layer(words, dropout=0.64)\n",
    "dropout_out_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(np.testing.assert_array_equal(to_np(dropout_out_old),to_np(dropout_out_new)))"
   ]
  },
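  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Exercise the `scale` argument the same way (a sketch: this assumes `EmbeddingDropout.forward` accepts the same `scale` keyword as `embedded_dropout`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(7123)\n",
    "scaled_old = embedded_dropout(embed, words, dropout=0.64, scale=2.0)\n",
    "torch.manual_seed(7123)\n",
    "# assumption: the layer's forward mirrors embedded_dropout's scale keyword\n",
    "scaled_new = embed_dropout_layer(words, dropout=0.64, scale=2.0)\n",
    "np.testing.assert_array_equal(to_np(scaled_old), to_np(scaled_new))"
   ]
  }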
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}