{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [], "mount_file_id": "1DXQ2nyL3PBZqaxXXHuOji_Ff49Tcu7Ws", "authorship_tag": "ABX9TyP7AbGxg6Hn7Mz56WhP0tfb", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard", "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# Natural Language Processing Demystified | Transformers, Pre-training, and Transfer Learning\n", "https://nlpdemystified.org
\n", "https://github.com/nitinpunjabi/nlp-demystified

\n", "\n", "Course module for this demo: https://www.nlpdemystified.org/course/transformers" ], "metadata": { "id": "KmEyadzTtGxY" } }, { "cell_type": "markdown", "source": [ "**IMPORTANT**
\n", "Enable **GPU acceleration** by going to *Runtime > Change Runtime Type*. Keep in mind that, on certain tiers, you're not guaranteed GPU access depending on usage history and current load.\n", "

\n", "Also, if you're running this in the cloud rather than a local Jupyter server on your machine, then the notebook will *timeout* after a period of inactivity.\n", "

\n", "Refer to this link on how to run Colab notebooks locally on your machine to avoid this issue:
\n", "https://research.google.com/colaboratory/local-runtimes.html" ], "metadata": { "id": "uOVYaAveQJia" } }, { "cell_type": "code", "source": [ "!pip install BPEmb\n", "\n", "import math\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "from bpemb import BPEmb" ], "metadata": { "id": "vWyjB-YNwTG_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Transformers From Scratch" ], "metadata": { "id": "MenE2varZEXc" } }, { "cell_type": "markdown", "source": [ "We'll build a transformer from scratch, layer-by-layer. We'll start with the **Multi-Head Self-Attention** layer since that's the most involved bit. Once we have that working, the rest of the model will look familiar if you've been following the course so far." ], "metadata": { "id": "mDkTVv3KMJX_" } }, { "cell_type": "markdown", "source": [ "## Multi-Head Self-Attention" ], "metadata": { "id": "LqX04fFXBdxy" } }, { "cell_type": "markdown", "source": [ "#### Scaled Dot Product Self-Attention" ], "metadata": { "id": "-XnKHnlYyijq" } }, { "cell_type": "markdown", "source": [ "\n", "Inside each attention head is a **Scaled Dot Product Self-Attention** operation as we covered in the slides. Given *queries*, *keys*, and *values*, the operation returns a new \"mix\" of the values.\n", "\n", "$$Attention(Q, K, V) = softmax(\\frac{QK^T)}{\\sqrt{d_k}})V$$\n", "\n", "The following function implements this and also takes a mask to account for padding and for masking future tokens for decoding (i.e. **look-ahead mask**). We'll cover masking later in the notebook." ], "metadata": { "id": "3NAf9HP7RsQu" } }, { "cell_type": "code", "source": [ "def scaled_dot_product_attention(query, key, value, mask=None):\n", " key_dim = tf.cast(tf.shape(key)[-1], tf.float32)\n", " scaled_scores = tf.matmul(query, key, transpose_b=True) / np.sqrt(key_dim)\n", "\n", " if mask is not None:\n", " scaled_scores = tf.where(mask==0, -np.inf, scaled_scores)\n", "\n", " softmax = tf.keras.layers.Softmax()\n", " weights = softmax(scaled_scores) \n", " return tf.matmul(weights, value), weights" ], "metadata": { "id": "7hpO6cGEN7HK" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Suppose our *queries*, *keys*, and *values* are each a length of 3 with a dimension of 4." ], "metadata": { "id": "lC_HhsreXh3H" } }, { "cell_type": "code", "source": [ "seq_len = 3\n", "embed_dim = 4\n", "\n", "queries = np.random.rand(seq_len, embed_dim)\n", "keys = np.random.rand(seq_len, embed_dim)\n", "values = np.random.rand(seq_len, embed_dim)\n", "\n", "print(\"Queries:\\n\", queries)" ], "metadata": { "id": "WB2cDybgX5LZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "This would be the self-attention output and weights." ], "metadata": { "id": "QuNdMuz5vb1c" } }, { "cell_type": "code", "source": [ "output, attn_weights = scaled_dot_product_attention(queries, keys, values)\n", "\n", "print(\"Output\\n\", output, \"\\n\")\n", "print(\"Weights\\n\", attn_weights)" ], "metadata": { "id": "pxKj56hNX5UO" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#### Generating queries, keys, and values for multiple heads." ], "metadata": { "id": "O8NLm6qaN7DE" } }, { "cell_type": "markdown", "source": [ "Now that we have a way to calculate self-attention, let's actually generate the input *queries*, *keys*, and *values* for multiple heads.\n", "

\n", "In the slides (and in most references), each attention head had its own separate set of *query*, *key*, and *value* weights. Each weight matrix was of dimension $d\\ x \\ d/h$ where h was the number of heads. " ], "metadata": { "id": "wBm9jbpSN6-L" } }, { "cell_type": "markdown", "source": [ "![](https://drive.google.com/uc?export=view&id=1SLWkHQgy4nQPFvvjG5_V8UTtpSAJ2zrr)" ], "metadata": { "id": "YLiJy9OzfMu5" } }, { "cell_type": "markdown", "source": [ "It's easier to understand things this way and we can certainly code it this way as well. But we can also \"simulate\" different heads with a single query matrix, single key matrix, and single value matrix.\n", "

\n", "We'll do both. First we'll create *query*, *key*, and *value* vectors using separate weights per head.\n", "

\n", "In the slides, we used an example of 12 dimensional embeddings processed by three attentions heads, and we'll do the same here." ], "metadata": { "id": "3tKPwmi3fbys" } }, { "cell_type": "code", "source": [ "batch_size = 1\n", "seq_len = 3\n", "embed_dim = 12\n", "num_heads = 3\n", "head_dim = embed_dim // num_heads\n", "\n", "print(f\"Dimension of each head: {head_dim}\")" ], "metadata": { "id": "rJLyGtqbX3uW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Using separate weight matrices per head**" ], "metadata": { "id": "JDl37YzAf7bh" } }, { "cell_type": "markdown", "source": [ "Suppose these are our input embeddings. Here we have a batch of 1 containing a sequence of length 3, with each element being a 12-dimensional embedding." ], "metadata": { "id": "xQ_KoJq3fv-A" } }, { "cell_type": "code", "source": [ "x = np.random.rand(batch_size, seq_len, embed_dim).round(1)\n", "print(\"Input shape: \", x.shape, \"\\n\")\n", "print(\"Input:\\n\", x)" ], "metadata": { "id": "7NcX3KBrX3uW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We'll declare three sets of *query* weights (one for each head), three sets of *key* weights, and three sets of *value* weights. Remember each weight matrix should have a dimension of $\\text{d}\\ \\text{x}\\ \\text{d/h}$." ], "metadata": { "id": "uvJicbp6f7pI" } }, { "cell_type": "code", "source": [ "# The query weights for each head.\n", "wq0 = np.random.rand(embed_dim, head_dim).round(1)\n", "wq1 = np.random.rand(embed_dim, head_dim).round(1)\n", "wq2 = np.random.rand(embed_dim, head_dim).round(1)\n", "\n", "# The key weights for each head. \n", "wk0 = np.random.rand(embed_dim, head_dim).round(1)\n", "wk1 = np.random.rand(embed_dim, head_dim).round(1)\n", "wk2 = np.random.rand(embed_dim, head_dim).round(1)\n", "\n", "# The value weights for each head.\n", "wv0 = np.random.rand(embed_dim, head_dim).round(1)\n", "wv1 = np.random.rand(embed_dim, head_dim).round(1)\n", "wv2 = np.random.rand(embed_dim, head_dim).round(1)" ], "metadata": { "id": "8zdg7rqrX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"The three sets of query weights (one for each head):\")\n", "print(\"wq0:\\n\", wq0)\n", "print(\"wq1:\\n\", wq1)\n", "print(\"wq2:\\n\", wq1)" ], "metadata": { "id": "QzMRHZooX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We'll generate our *queries*, *keys*, and *values* for each head by multiplying our input by the weights." ], "metadata": { "id": "HmwGKV9qgch-" } }, { "cell_type": "code", "source": [ "# Geneated queries, keys, and values for the first head.\n", "q0 = np.dot(x, wq0)\n", "k0 = np.dot(x, wk0)\n", "v0 = np.dot(x, wv0)\n", "\n", "# Geneated queries, keys, and values for the second head.\n", "q1 = np.dot(x, wq1)\n", "k1 = np.dot(x, wk1)\n", "v1 = np.dot(x, wv1)\n", "\n", "# Geneated queries, keys, and values for the third head.\n", "q2 = np.dot(x, wq2)\n", "k2 = np.dot(x, wk2)\n", "v2 = np.dot(x, wv2)" ], "metadata": { "id": "NucbYNNSX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "These are the resulting *query*, *key*, and *value* vectors for the first head." 
], "metadata": { "id": "AIDiwWZ0gqhm" } }, { "cell_type": "code", "source": [ "print(\"Q, K, and V for first head:\\n\")\n", "\n", "print(f\"q0 {q0.shape}:\\n\", q0, \"\\n\")\n", "print(f\"k0 {k0.shape}:\\n\", k0, \"\\n\")\n", "print(f\"v0 {v0.shape}:\\n\", v0)" ], "metadata": { "id": "NMcMmbkqX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now that we have our Q, K, V vectors, we can just pass them to our self-attention operation. Here we're calculating the output and attention weights for the first head." ], "metadata": { "id": "iw5CQ9i6qZDv" } }, { "cell_type": "code", "source": [ "out0, attn_weights0 = scaled_dot_product_attention(q0, k0, v0)\n", "\n", "print(\"Output from first attention head: \", out0, \"\\n\")\n", "print(\"Attention weights from first head: \", attn_weights0)" ], "metadata": { "id": "i7tHIvXKX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Here are the other two (attention weights are ignored)." ], "metadata": { "id": "DoYEXSm7qr_A" } }, { "cell_type": "code", "source": [ "out1, _ = scaled_dot_product_attention(q1, k1, v1)\n", "out2, _ = scaled_dot_product_attention(q2, k2, v2)\n", "\n", "print(\"Output from second attention head: \", out1, \"\\n\")\n", "print(\"Output from third attention head: \", out2,)" ], "metadata": { "id": "otnqbaDSqpJ7" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As we covered in the slides, once we have each head's output, we concatenate them and then put them through a linear layer for further processing." ], "metadata": { "id": "lOV717bqX3uX" } }, { "cell_type": "code", "source": [ "combined_out_a = np.concatenate((out0, out1, out2), axis=-1)\n", "print(f\"Combined output from all heads {combined_out_a.shape}:\")\n", "print(combined_out_a)\n", "\n", "# The final step would be to run combined_out_a through a linear/dense layer \n", "# for further processing." ], "metadata": { "id": "gmSv5trtt2v9" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "So that's a complete run of **multi-head self-attention** using separate sets of weights per head.
\n", "\n", "Let's now get the same thing done using a single query weight matrix, single key weight matrix, and single value weight matrix.

\n", "These were our separate per-head query weights:" ], "metadata": { "id": "RRZpFR0Wt8h9" } }, { "cell_type": "code", "source": [ "print(\"Query weights for first head: \\n\", wq0, \"\\n\")\n", "print(\"Query weights for second head: \\n\", wq1, \"\\n\")\n", "print(\"Query weights for third head: \\n\", wq2)" ], "metadata": { "id": "XoJmLAsUX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Suppose instead of declaring three separate query weight matrices, we had declared one. i.e. a single $d\\ x\\ d$ matrix. We're concatenating our per-head query weights here instead of declaring a new set of weights so that we get the same results." ], "metadata": { "id": "oa_p3bk8mO9D" } }, { "cell_type": "code", "source": [ "wq = np.concatenate((wq0, wq1, wq2), axis=1)\n", "print(f\"Single query weight matrix {wq.shape}: \\n\", wq)" ], "metadata": { "id": "7jh6zeg1X3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "In the same vein, pretend we declared a single key weight matrix, and single value weight matrix." ], "metadata": { "id": "-9MzE5Okmdbl" } }, { "cell_type": "code", "source": [ "wk = np.concatenate((wk0, wk1, wk2), axis=1)\n", "wv = np.concatenate((wv0, wv1, wv2), axis=1)\n", "\n", "print(f\"Single key weight matrix {wk.shape}:\\n\", wk, \"\\n\")\n", "print(f\"Single value weight matrix {wv.shape}:\\n\", wv)" ], "metadata": { "id": "xq2guuobX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now we can calculate all our *queries*, *keys*, and *values* with three dot products." ], "metadata": { "id": "WA7dl1VRnXHz" } }, { "cell_type": "code", "source": [ "q_s = np.dot(x, wq)\n", "k_s = np.dot(x, wk)\n", "v_s = np.dot(x, wv)" ], "metadata": { "id": "UQ5i98bLX3uX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "These are our resulting query vectors (we'll call them \"combined queries\"). How do we simulate different heads with this?" ], "metadata": { "id": "xkAzG-bgnx1U" } }, { "cell_type": "code", "source": [ "print(f\"Query vectors using a single weight matrix {q_s.shape}:\\n\", q_s)" ], "metadata": { "id": "H-qKM3jZr242" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Somehow, we need to separate these vectors such they're treated like three separate sets by the self-attention operation." ], "metadata": { "id": "qsUULAgRsB2n" } }, { "cell_type": "code", "source": [ "print(q0, \"\\n\")\n", "print(q1, \"\\n\")\n", "print(q2)" ], "metadata": { "id": "FKXYVHbJvnGp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Notice how each set of per-head queries looks like we took the combined queries, and chopped them vertically every four dimensions.\n", "

\n", "We can split our combined queries into $\\text{d}\\ \\text{x}\\ \\text{d/h}$ heads using **reshape** and **transpose**.

\n", "The first step is to *reshape* our combined queries from a shape of:
\n", "(batch_size, seq_len, embed_dim)
\n", "\n", "into a shape of
\n", " (batch_size, seq_len, num_heads, head_dim).\n", "
\n", "\n", " https://www.tensorflow.org/api_docs/python/tf/reshape" ], "metadata": { "id": "twXi0Sx-sTut" } }, { "cell_type": "code", "source": [ "# Note: we can achieve the same thing by passing -1 instead of seq_len.\n", "q_s_reshaped = tf.reshape(q_s, (batch_size, seq_len, num_heads, head_dim))\n", "print(f\"Combined queries: {q_s.shape}\\n\", q_s, \"\\n\")\n", "print(f\"Reshaped into separate heads: {q_s_reshaped.shape}\\n\", q_s_reshaped)" ], "metadata": { "id": "d3iHh7XxX3uY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "At this point, we have our desired shape. The next step is to *transpose* it such that simulates vertically chopping our combined queries. By transposing, our matrix dimensions become:
\n", "(batch_size, num_heads, seq_len, head_dim)
\n", "\n", "https://www.tensorflow.org/api_docs/python/tf/transpose" ], "metadata": { "id": "6fIWohaZvVs9" } }, { "cell_type": "code", "source": [ "q_s_transposed = tf.transpose(q_s_reshaped, perm=[0, 2, 1, 3]).numpy()\n", "print(f\"Queries transposed into \\\"separate\\\" heads {q_s_transposed.shape}:\\n\", \n", " q_s_transposed)" ], "metadata": { "id": "6Vv3kV3jX3uY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "If we compare this against the separate per-head queries we calculated previously, we see the same result except we now have all our queries in a single matrix." ], "metadata": { "id": "J2DOWEPewUns" } }, { "cell_type": "code", "source": [ "print(\"The separate per-head query matrices from before: \")\n", "print(q0, \"\\n\")\n", "print(q1, \"\\n\")\n", "print(q2)" ], "metadata": { "id": "ZMLEBmtowQ02" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's do the exact same thing with our combined keys and values." ], "metadata": { "id": "kmVPAaE3wmGj" } }, { "cell_type": "code", "source": [ "k_s_transposed = tf.transpose(tf.reshape(k_s, (batch_size, -1, num_heads, head_dim)), perm=[0, 2, 1, 3]).numpy()\n", "v_s_transposed = tf.transpose(tf.reshape(v_s, (batch_size, -1, num_heads, head_dim)), perm=[0, 2, 1, 3]).numpy()\n", "\n", "print(f\"Keys for all heads in a single matrix {k_s.shape}: \\n\", k_s_transposed, \"\\n\")\n", "print(f\"Values for all heads in a single matrix {v_s.shape}: \\n\", v_s_transposed)" ], "metadata": { "id": "vauGkBv3X3uY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Set up this way, we can now calculate the outputs from all attention heads with a single call to our self-attention operation." ], "metadata": { "id": "ebGFAKGrxCoe" } }, { "cell_type": "code", "source": [ "all_heads_output, all_attn_weights = scaled_dot_product_attention(q_s_transposed, \n", " k_s_transposed, \n", " v_s_transposed)\n", "print(\"Self attention output:\\n\", all_heads_output)" ], "metadata": { "id": "hIElo1ObX3uY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As a sanity check, we can compare this against the outputs from individual heads we calculated earlier:" ], "metadata": { "id": "PCPtOI_awd-Z" } }, { "cell_type": "code", "source": [ "print(\"Per head outputs from using separate sets of weights per head:\")\n", "print(out0, \"\\n\")\n", "print(out1, \"\\n\")\n", "print(out2)" ], "metadata": { "id": "bXIB_z11xsh7" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "To get the final concatenated result, we need to reverse our **reshape** and **transpose** operation, starting with the **transpose** this time." ], "metadata": { "id": "hPlpXbZI74mX" } }, { "cell_type": "code", "source": [ "combined_out_b = tf.reshape(tf.transpose(all_heads_output, perm=[0, 2, 1, 3]), \n", " shape=(batch_size, seq_len, embed_dim))\n", "print(\"Final output from using single query, key, value matrices:\\n\", \n", " combined_out_b, \"\\n\")\n", "print(\"Final output from using separate query, key, value matrices per head:\\n\", \n", " combined_out_a)" ], "metadata": { "id": "9lWtCPk1wuod" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can encapsulate everything we just covered in a class." 
], "metadata": { "id": "Wi8WnhwL9UIa" } }, { "cell_type": "code", "source": [ "class MultiHeadSelfAttention(tf.keras.layers.Layer):\n", " def __init__(self, d_model, num_heads):\n", " super(MultiHeadSelfAttention, self).__init__()\n", " self.d_model = d_model\n", " self.num_heads = num_heads\n", "\n", " self.d_head = self.d_model // self.num_heads\n", "\n", " self.wq = tf.keras.layers.Dense(self.d_model)\n", " self.wk = tf.keras.layers.Dense(self.d_model)\n", " self.wv = tf.keras.layers.Dense(self.d_model)\n", "\n", " # Linear layer to generate the final output.\n", " self.dense = tf.keras.layers.Dense(self.d_model)\n", " \n", " def split_heads(self, x):\n", " batch_size = x.shape[0]\n", "\n", " split_inputs = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_head))\n", " return tf.transpose(split_inputs, perm=[0, 2, 1, 3])\n", " \n", " def merge_heads(self, x):\n", " batch_size = x.shape[0]\n", "\n", " merged_inputs = tf.transpose(x, perm=[0, 2, 1, 3])\n", " return tf.reshape(merged_inputs, (batch_size, -1, self.d_model))\n", "\n", " def call(self, q, k, v, mask):\n", " qs = self.wq(q)\n", " ks = self.wk(k)\n", " vs = self.wv(v)\n", "\n", " qs = self.split_heads(qs)\n", " ks = self.split_heads(ks)\n", " vs = self.split_heads(vs)\n", "\n", " output, attn_weights = scaled_dot_product_attention(qs, ks, vs, mask)\n", " output = self.merge_heads(output)\n", "\n", " return self.dense(output), attn_weights\n" ], "metadata": { "id": "Sd_IgJI34vP4" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "mhsa = MultiHeadSelfAttention(12, 3)\n", "\n", "output, attn_weights = mhsa(x, x, x, None)\n", "print(f\"MHSA output{output.shape}:\")\n", "print(output)" ], "metadata": { "id": "nuvv-8cg6owq" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Encoder Block" ], "metadata": { "id": "uAk-GG2yMM59" } }, { "cell_type": "markdown", "source": [ "We can now build our **Encoder Block**. In addition to the **Multi-Head Self Attention** layer, the **Encoder Block** also has **skip connections**, **layer normalization steps**, and a **two-layer feed-forward neural network**. The original **Attention Is All You Need** paper also included some **dropout** applied to the self-attention output which isn't shown in the illustration below (see references for a link to the paper).\n", "\n", "
\n", "\n", "
" ], "metadata": { "id": "BHrQaN_B_rLh" } }, { "cell_type": "markdown", "source": [ "Since a two-layer feed forward neural network is used in multiple places in the transformer, here's a function which creates and returns one." ], "metadata": { "id": "S7Yc_FnvDNx4" } }, { "cell_type": "code", "source": [ "def feed_forward_network(d_model, hidden_dim):\n", " return tf.keras.Sequential([\n", " tf.keras.layers.Dense(hidden_dim, activation='relu'),\n", " tf.keras.layers.Dense(d_model)\n", " ])" ], "metadata": { "id": "mN5B0vduMM9a" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "This is our encoder block containing all the layers and steps from the preceding illustration (plus dropout)." ], "metadata": { "id": "4FrRAMJFDnVQ" } }, { "cell_type": "code", "source": [ "class EncoderBlock(tf.keras.layers.Layer):\n", " def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):\n", " super(EncoderBlock, self).__init__()\n", "\n", " self.mhsa = MultiHeadSelfAttention(d_model, num_heads)\n", " self.ffn = feed_forward_network(d_model, hidden_dim)\n", "\n", " self.dropout1 = tf.keras.layers.Dropout(dropout_rate)\n", " self.dropout2 = tf.keras.layers.Dropout(dropout_rate)\n", "\n", " self.layernorm1 = tf.keras.layers.LayerNormalization()\n", " self.layernorm2 = tf.keras.layers.LayerNormalization()\n", " \n", " def call(self, x, training, mask):\n", " mhsa_output, attn_weights = self.mhsa(x, x, x, mask)\n", " mhsa_output = self.dropout1(mhsa_output, training=training)\n", " mhsa_output = self.layernorm1(x + mhsa_output)\n", "\n", " ffn_output = self.ffn(mhsa_output)\n", " ffn_output = self.dropout2(ffn_output, training=training)\n", " output = self.layernorm2(mhsa_output + ffn_output)\n", "\n", " return output, attn_weights\n" ], "metadata": { "id": "q8uu0mISAb0n" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Suppose we have an embedding dimension of 12, and we want 3 attention heads and a feed forward network with a hidden dimension of 48 (4x the embedding dimension). We would declare and use a single encoder block like so:" ], "metadata": { "id": "q3_2uXRBFBEY" } }, { "cell_type": "code", "source": [ "encoder_block = EncoderBlock(12, 3, 48)\n", "\n", "block_output, _ = encoder_block(x, True, None)\n", "print(f\"Output from single encoder block {block_output.shape}:\")\n", "print(block_output)" ], "metadata": { "id": "vBnumPJ7C7Jj" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Word and Positional Embeddings" ], "metadata": { "id": "I5z32v2QKYdy" } }, { "cell_type": "markdown", "source": [ "Let's now deal with the actual input to the **initial** encoder block. The inputs are going to be *positional word embeddings*. That is, word embeddings with some positional information added to them.\n", "
\n", "\n", "Let's start with **subword** tokenization. For demonstration, we'll use a subword tokenizer called **BPEmb**. It uses **Byte-Pair Encoding** and supports over two hundred languages. \n", "\n", "https://bpemb.h-its.org/\n" ], "metadata": { "id": "S4NuyQpYGBUo" } }, { "cell_type": "code", "source": [ "# Load the English tokenizer.\n", "bpemb_en = BPEmb(lang=\"en\")" ], "metadata": { "id": "nmMOHYDEKdvQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The library comes with embeddings for a number of words." ], "metadata": { "id": "uAjjB6ykHHyQ" } }, { "cell_type": "code", "source": [ "bpemb_vocab_size, bpemb_embed_size = bpemb_en.vectors.shape\n", "print(\"Vocabulary size:\", bpemb_vocab_size)\n", "print(\"Embedding size:\", bpemb_embed_size)" ], "metadata": { "id": "FhtnbTmdH6jU" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Embedding for the word \"car\".\n", "bpemb_en.vectors[bpemb_en.words.index('car')]" ], "metadata": { "id": "vKvODSJDIdt0" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We don't need the embeddings since we're going to use our own embedding layer. What we're interested in are the subword tokens and their respective ids. The ids will be used as indexes into our embedding layer.
\n", "\n", "If this doesn't sound familiar, refer to the module on word vectors:
\n", "https://www.nlpdemystified.org/course/word-vectors" ], "metadata": { "id": "YZ7wTWoUI4Zz" } }, { "cell_type": "markdown", "source": [ "These are the subword tokens for our example sentence from the slides. **BPEmb** places underscores in front of any tokens which are whole words or intended to begin words.
\n", "\n", "Remember that subword tokenizers are trained using count frequencies over a corpus. So these subword tokens are specific to **BPEmb**. Another subword tokenizer may output something different. This is why it's important that when we use a pretrained model, we make sure to use the pretrained model's tokenizer. We'll see this when we use pretrained transformers later in this module." ], "metadata": { "id": "JnW_aHliJdRD" } }, { "cell_type": "code", "source": [ "sample_sentence = \"Where can I find a pizzeria?\"\n", "tokens = bpemb_en.encode(sample_sentence)\n", "print(tokens)" ], "metadata": { "id": "AIgpfG3hKjbZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can retrieve each subword token's respective id using the *encode_ids* method." ], "metadata": { "id": "-WIjAEwLKwwh" } }, { "cell_type": "code", "source": [ "token_seq = np.array(bpemb_en.encode_ids(\"Where can I find a pizzeria?\"))\n", "print(token_seq)" ], "metadata": { "id": "grMR-DHEKjWx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now that we have a way to tokenize and vectorize sentences, we can declare and use an embedding layer with the same vocabulary size as **BPEmb** and a desired embedding size." ], "metadata": { "id": "Mqz7PY5nSGiW" } }, { "cell_type": "code", "source": [ "token_embed = tf.keras.layers.Embedding(bpemb_vocab_size, embed_dim)\n", "token_embeddings = token_embed(token_seq)\n", "\n", "# The untrained embeddings for our sample sentence.\n", "print(\"Embeddings for: \", sample_sentence)\n", "print(token_embeddings)" ], "metadata": { "id": "UO7eOOrWKjSc" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Next, we need to add *positional* information to each token embedding. As we covered in the slides, the original paper used sinusoidals but it's more common these days to just use another set of embeddings. We'll do the latter here.
\n", "\n", "Here, we're declaring an embedding layer with rows equalling a maximum sequence length and columns equalling our token embedding size. We then generate a vector of position ids." ], "metadata": { "id": "20Bg_sB5TzEE" } }, { "cell_type": "code", "source": [ "max_seq_len = 256\n", "pos_embed = tf.keras.layers.Embedding(max_seq_len, embed_dim)\n", "\n", "# Generate ids for each position of the token sequence.\n", "pos_idx = tf.range(len(token_seq))\n", "print(pos_idx)" ], "metadata": { "id": "pcurqcv3KjNY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We'll use these position ids to index into the positional embedding layer." ], "metadata": { "id": "z4jK-iJP4Fve" } }, { "cell_type": "code", "source": [ "# These are our positon embeddings.\n", "position_embeddings = pos_embed(pos_idx)\n", "print(\"Position embeddings for the input sequence\\n\", position_embeddings)" ], "metadata": { "id": "6vIgau8YMTgi" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The final step is to add our token and position embeddings. The result will be the input to the first encoder block." ], "metadata": { "id": "UC6V2IodUhbH" } }, { "cell_type": "code", "source": [ "input = token_embeddings + position_embeddings\n", "print(\"Input to the initial encoder block:\\n\", input)" ], "metadata": { "id": "K6x9JVlTKjIi" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Encoder" ], "metadata": { "id": "LDctrWODMNG4" } }, { "cell_type": "markdown", "source": [ "Now that we have an encoder block and a way to embed our tokens with position information, we can create the **encoder** itself.
\n", "\n", "Given a batch of vectorized sequences, the encoder creates positional embeddings, runs them through its encoder blocks, and returns contextualized tokens." ], "metadata": { "id": "LmV5KuIXWSUr" } }, { "cell_type": "code", "source": [ "class Encoder(tf.keras.layers.Layer):\n", " def __init__(self, num_blocks, d_model, num_heads, hidden_dim, src_vocab_size,\n", " max_seq_len, dropout_rate=0.1):\n", " super(Encoder, self).__init__()\n", "\n", " self.d_model = d_model\n", " self.max_seq_len = max_seq_len\n", "\n", " self.token_embed = tf.keras.layers.Embedding(src_vocab_size, self.d_model)\n", " self.pos_embed = tf.keras.layers.Embedding(max_seq_len, self.d_model)\n", "\n", " # The original Attention Is All You Need paper applied dropout to the\n", " # input before feeding it to the first encoder block.\n", " self.dropout = tf.keras.layers.Dropout(dropout_rate)\n", "\n", " # Create encoder blocks.\n", " self.blocks = [EncoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate) \n", " for _ in range(num_blocks)]\n", " \n", " def call(self, input, training, mask):\n", " token_embeds = self.token_embed(input)\n", "\n", " # Generate position indices for a batch of input sequences.\n", " num_pos = input.shape[0] * self.max_seq_len\n", " pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)\n", " pos_idx = np.reshape(pos_idx, input.shape)\n", " pos_embeds = self.pos_embed(pos_idx)\n", "\n", " x = self.dropout(token_embeds + pos_embeds, training=training)\n", "\n", " # Run input through successive encoder blocks.\n", " for block in self.blocks:\n", " x, weights = block(x, training, mask)\n", "\n", " return x, weights" ], "metadata": { "id": "NinUihSpC6K-" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "If you're wondering about this code block here:\n", "\n", "\n", "```\n", "num_pos = input.shape[0] * self.max_seq_len\n", "pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)\n", "pos_idx = np.reshape(pos_idx, input.shape)\n", "pos_embeds = self.pos_embed(pos_idx)\n", "```\n", "\n", "\n", "This generates positional embeddings for a *batch* of input sequences. Suppose this was our batch of input sequences to the encoder." ], "metadata": { "id": "xb7v8lKuYTT6" } }, { "cell_type": "code", "source": [ "# Batch of 3 sequences, each of length 10 (10 is also the \n", "# maximum sequence length in this case).\n", "seqs = np.random.randint(0, 10000, size=(3, 10))\n", "print(seqs.shape)\n", "print(seqs)" ], "metadata": { "id": "Cllud1-mJhNi" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We need to retrieve a positional embedding for every element in this batch. The first step is to create the respective positional ids..." ], "metadata": { "id": "DUjolKY8ZC-6" } }, { "cell_type": "code", "source": [ "pos_ids = np.resize(np.arange(seqs.shape[1]), seqs.shape[0] * seqs.shape[1])\n", "print(pos_ids)" ], "metadata": { "id": "WgfMkY6fk4I4" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "...and then reshape them to match the input batch dimensions." ], "metadata": { "id": "5OMssAJLZbAg" } }, { "cell_type": "code", "source": [ "pos_ids = np.reshape(pos_ids, (3, 10))\n", "print(pos_ids.shape)\n", "print(pos_ids)" ], "metadata": { "id": "ah0t-pZznGWt" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can now retrieve position embeddings for every token embedding." 
], "metadata": { "id": "TphnVF8_ZxzL" } }, { "cell_type": "code", "source": [ "pos_embed(pos_ids)" ], "metadata": { "id": "cAODAGYAwpAr" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's try our encoder on a batch of sentences." ], "metadata": { "id": "e-4hBnztXfN5" } }, { "cell_type": "code", "source": [ "input_batch = [\n", " \"Where can I find a pizzeria?\",\n", " \"Mass hysteria over listeria.\",\n", " \"I ain't no circle back girl.\"\n", "]\n", "\n", "bpemb_en.encode(input_batch)" ], "metadata": { "id": "jbX82NUpwyGL" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "input_seqs = bpemb_en.encode_ids(input_batch)\n", "print(\"Vectorized inputs:\")\n", "input_seqs" ], "metadata": { "id": "wOXHqq2Kxh5r" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Note how the input sequences aren't the same length in this batch. In this case, we need to pad them out so that they are. If you're unfamiliar with why, refer to the notebook on Recurrent Neural Networks:
\n", "https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_recurrent_neural_networks.ipynb
\n", "\n", "We'll do this using *pad_sequences*.
\n", "https://www.tensorflow.org/api_docs/python/tf/keras/utils/pad_sequences" ], "metadata": { "id": "EOgoulJTb7Q6" } }, { "cell_type": "code", "source": [ "padded_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(input_seqs, padding=\"post\")\n", "print(\"Input to the encoder:\")\n", "print(padded_input_seqs.shape)\n", "print(padded_input_seqs)" ], "metadata": { "id": "np2vsXpwxMS8" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Since our input now has padding, now's a good time to cover **masking**.\n", "
\n", "\n", "So given a mask, wherever there's a mask position set to 0, the corresponding position in the attention scores will be set to *-inf*. The resulting attention weight for the position will then be zero and no attending will occur for that position.\n", "
\n", "\n", "In the slides, we covered *look-ahead* masks for the decoder to prevent it from attending to future tokens, but we also need masks for padding.\n", "
\n", "\n", "In total, there are three masks involved:\n", "1. The *encoder mask* to mask out any padding in the encoder sequences.\n", "\n", "2. The *decoder mask* which is used in the decoder's **first** multi-head self-attention layer. It's a combination of two masks: one to account for the padding in target sequences, and the look-ahead mask.\n", "\n", "3. The *memory mask* which is used in the decoder's **second** multi-head self-attention layer. The keys and values for this layer are going to be the encoder's output, and this mask will ensure the decoder doesn't attend to any encoder output which corresponds to padding. In practice, 1 and 3 are often the same.\n", "\n", "The *scaled_dot_product_attention* function has this line:\n", "```\n", " if mask is not None:\n", " scaled_scores = tf.where(mask==0, -np.inf, scaled_scores)\n", "```" ], "metadata": { "id": "PqkDdMKJVSa6" } }, { "cell_type": "markdown", "source": [ "Let's create an encoder mask for our batch of input sequences.
\n", "\n", "Wherever there's padding, we want the mask position set to zero." ], "metadata": { "id": "41HyT3jSVq0B" } }, { "cell_type": "code", "source": [ "enc_mask = tf.cast(tf.math.not_equal(padded_input_seqs, 0), tf.float32)\n", "print(\"Input:\")\n", "print(padded_input_seqs, '\\n')\n", "print(\"Encoder mask:\")\n", "print(enc_mask)" ], "metadata": { "id": "AHvAAVhnZouZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Keep in mind that the dimension of the attention matrix (for this example) is going to be:
\n", "*(batch size, number of heads, query size, key size)*
\n", "(3, 3, 10, 10)" ], "metadata": { "id": "idqcJwFhZ7zD" } }, { "cell_type": "markdown", "source": [ "So we need to expand the mask dimensions like so:" ], "metadata": { "id": "vgVXwdwra84q" } }, { "cell_type": "code", "source": [ "enc_mask = enc_mask[:, tf.newaxis, tf.newaxis, :]\n", "enc_mask" ], "metadata": { "id": "aYPlbsrvZu8_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "This way, the encoder mask will now be *broadcasted*.
\n", "https://www.tensorflow.org/xla/broadcasting" ], "metadata": { "id": "nsJEDxNPckz5" } }, { "cell_type": "markdown", "source": [ "Now we can declare an encoder and pass it batches of vectorized sequences." ], "metadata": { "id": "X87_VQmiVbSj" } }, { "cell_type": "code", "source": [ "num_encoder_blocks = 6\n", "\n", "# d_model is the embedding dimension used throughout.\n", "d_model = 12\n", "\n", "num_heads = 3\n", "\n", "# Feed-forward network hidden dimension width.\n", "ffn_hidden_dim = 48\n", "\n", "src_vocab_size = bpemb_vocab_size\n", "max_input_seq_len = padded_input_seqs.shape[1]\n", "\n", "encoder = Encoder(\n", " num_encoder_blocks,\n", " d_model,\n", " num_heads,\n", " ffn_hidden_dim,\n", " src_vocab_size,\n", " max_input_seq_len)" ], "metadata": { "id": "Ns8G5ujRVQMv" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can now pass our input sequences and mask to the encoder." ], "metadata": { "id": "hGQ6lg3fJhIg" } }, { "cell_type": "code", "source": [ "encoder_output, attn_weights = encoder(padded_input_seqs, training=True, \n", " mask=enc_mask)\n", "print(f\"Encoder output {encoder_output.shape}:\")\n", "print(encoder_output)" ], "metadata": { "id": "rf6q86hBj8eV" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Decoder Block" ], "metadata": { "id": "24TYaX3zMNAh" } }, { "cell_type": "markdown", "source": [ "Let's build the **Decoder Block**. Everything we did to create the **encoder** block applies here. The major differences are that the **Decoder Block** has:\n", "1. a **Multi-Head Cross-Attention** layer which uses the encoder's outputs as the keys and values.\n", "\n", "2. an extra skip/residual connection along with an extra layer normalization step.\n", "\n", "
\n", "\n", "
" ], "metadata": { "id": "uH-5iDDXeU_j" } }, { "cell_type": "code", "source": [ "class DecoderBlock(tf.keras.layers.Layer):\n", " def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):\n", " super(DecoderBlock, self).__init__()\n", "\n", " self.mhsa1 = MultiHeadSelfAttention(d_model, num_heads)\n", " self.mhsa2 = MultiHeadSelfAttention(d_model, num_heads)\n", "\n", " self.ffn = feed_forward_network(d_model, hidden_dim)\n", "\n", " self.dropout1 = tf.keras.layers.Dropout(dropout_rate)\n", " self.dropout2 = tf.keras.layers.Dropout(dropout_rate)\n", " self.dropout3 = tf.keras.layers.Dropout(dropout_rate)\n", "\n", " self.layernorm1 = tf.keras.layers.LayerNormalization()\n", " self.layernorm2 = tf.keras.layers.LayerNormalization()\n", " self.layernorm3 = tf.keras.layers.LayerNormalization()\n", " \n", " # Note the decoder block takes two masks. One for the first MHSA, another\n", " # for the second MHSA.\n", " def call(self, encoder_output, target, training, decoder_mask, memory_mask):\n", " mhsa_output1, attn_weights = self.mhsa1(target, target, target, decoder_mask)\n", " mhsa_output1 = self.dropout1(mhsa_output1, training=training)\n", " mhsa_output1 = self.layernorm1(mhsa_output1 + target)\n", "\n", " mhsa_output2, attn_weights = self.mhsa2(mhsa_output1, encoder_output, \n", " encoder_output, \n", " memory_mask)\n", " mhsa_output2 = self.dropout2(mhsa_output2, training=training)\n", " mhsa_output2 = self.layernorm2(mhsa_output2 + mhsa_output1)\n", "\n", " ffn_output = self.ffn(mhsa_output2)\n", " ffn_output = self.dropout3(ffn_output, training=training)\n", " output = self.layernorm3(ffn_output + mhsa_output2)\n", "\n", " return output, attn_weights\n" ], "metadata": { "id": "Hco1IwfutNqD" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Decoder" ], "metadata": { "id": "YVstTioxMNDq" } }, { "cell_type": "markdown", "source": [ "The decoder is almost the same as the encoder except it takes the encoder's output as part of its input, and it takes two masks: the decoder mask and memory mask." 
], "metadata": { "id": "M3iT7wyOi_bv" } }, { "cell_type": "code", "source": [ "class Decoder(tf.keras.layers.Layer):\n", " def __init__(self, num_blocks, d_model, num_heads, hidden_dim, target_vocab_size,\n", " max_seq_len, dropout_rate=0.1):\n", " super(Decoder, self).__init__()\n", "\n", " self.d_model = d_model\n", " self.max_seq_len = max_seq_len\n", "\n", " self.token_embed = tf.keras.layers.Embedding(target_vocab_size, self.d_model)\n", " self.pos_embed = tf.keras.layers.Embedding(max_seq_len, self.d_model)\n", "\n", " self.dropout = tf.keras.layers.Dropout(dropout_rate)\n", "\n", " self.blocks = [DecoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate) for _ in range(num_blocks)]\n", "\n", " def call(self, encoder_output, target, training, decoder_mask, memory_mask):\n", " token_embeds = self.token_embed(target)\n", "\n", " # Generate position indices.\n", " num_pos = target.shape[0] * self.max_seq_len\n", " pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)\n", " pos_idx = np.reshape(pos_idx, target.shape)\n", "\n", " pos_embeds = self.pos_embed(pos_idx)\n", "\n", " x = self.dropout(token_embeds + pos_embeds, training=training)\n", "\n", " for block in self.blocks:\n", " x, weights = block(encoder_output, x, training, decoder_mask, memory_mask)\n", "\n", " return x, weights" ], "metadata": { "id": "27zG_wV3MNJ_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Before we try the decoder, let's cover the masks involved. The decoder takes two masks:\n", "\n", "The *decoder mask* which is a combination of two masks: one to account for the padding in target sequences, and the look-ahead mask. This mask is used in the decoder's **first** multi-head self-attention layer.\n", "\n", "The *memory mask* which is used in the decoder's **second** multi-head self-attention. The keys and values for this layer are going to be the encoder's output, and this mask will ensure the decoder doesn't attend to any encoder output which corresponds to padding." ], "metadata": { "id": "gkZ1T-hSscOw" } }, { "cell_type": "markdown", "source": [ "Suppose this is our batch of vectorized target *input* sequences for the decoder. These values are just made up.
\n", "\n", "**Note**: If you need a refresher on how to prepare target input and output sequences for the decoder, refer to the [seq2seq notebook](https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_seq2seq_and_attention.ipynb).\n", "\n" ], "metadata": { "id": "EjiEOx5WoOb8" } }, { "cell_type": "code", "source": [ "# Made up values.\n", "target_input_seqs = [\n", " [1, 652, 723, 123, 62],\n", " [1, 25, 98, 129, 248, 215, 359, 249],\n", " [1, 2369, 1259, 125, 486],\n", "]" ], "metadata": { "id": "0X6gKNzgv0gP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As we did with the encoder input sequences, we need to pad out this batch so that all sequences within it are the same length." ], "metadata": { "id": "SgriJUKgyxNN" } }, { "cell_type": "code", "source": [ "padded_target_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(target_input_seqs, padding=\"post\")\n", "print(\"Padded target inputs to the decoder:\")\n", "print(padded_target_input_seqs.shape)\n", "print(padded_target_input_seqs)" ], "metadata": { "id": "4hFp1nkSypnz" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can create the padding mask the same way we did for the encoder." ], "metadata": { "id": "qZysfgvUzNBI" } }, { "cell_type": "code", "source": [ "dec_padding_mask = tf.cast(tf.math.not_equal(padded_target_input_seqs, 0), tf.float32)\n", "dec_padding_mask = dec_padding_mask[:, tf.newaxis, tf.newaxis, :]\n", "print(dec_padding_mask)" ], "metadata": { "id": "PLKeI4R20axA" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As we covered in the slides, the look-ahead mask is a diagonal where the lower half are 1s and the upper half are zeros. This is easy to create using the *band_part* method:
\n", "https://www.tensorflow.org/api_docs/python/tf/linalg/band_part" ], "metadata": { "id": "S7EwYtJa0uvH" } }, { "cell_type": "code", "source": [ "target_input_seq_len = padded_target_input_seqs.shape[1]\n", "look_ahead_mask = tf.linalg.band_part(tf.ones((target_input_seq_len, \n", " target_input_seq_len)), -1, 0)\n", "print(look_ahead_mask)" ], "metadata": { "id": "yZFnGgJa04a-" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "To create the decoder mask, we just need to combine the padding and look-ahead masks. Note how the columns of the resulting decoder mask are all zero for padding positions." ], "metadata": { "id": "WPzxVG2S87T2" } }, { "cell_type": "code", "source": [ "dec_mask = tf.minimum(dec_padding_mask, look_ahead_mask)\n", "print(\"The decoder mask:\")\n", "print(dec_mask)" ], "metadata": { "id": "vArTOY1x2bzn" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can now declare a decoder and pass it everything it needs. In our case, the *memory* mask is the same as the *encoder* mask." ], "metadata": { "id": "iLHbt7nJ9xUX" } }, { "cell_type": "code", "source": [ "decoder = Decoder(6, 12, 3, 48, 10000, 8)\n", "decoder_output, _ = decoder(encoder_output, padded_target_input_seqs, \n", " True, dec_mask, enc_mask)\n", "print(f\"Decoder output {decoder_output.shape}:\")\n", "print(decoder_output)" ], "metadata": { "id": "bFE-VaCrmLKu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Transformer" ], "metadata": { "id": "UgFtxMQxMNNJ" } }, { "cell_type": "markdown", "source": [ "We now have all the pieces to build the **Transformer** itself, and it's pretty simple. " ], "metadata": { "id": "bYFJuqbl-Jt7" } }, { "cell_type": "code", "source": [ "class Transformer(tf.keras.Model):\n", " def __init__(self, num_blocks, d_model, num_heads, hidden_dim, source_vocab_size,\n", " target_vocab_size, max_input_len, max_target_len, dropout_rate=0.1):\n", " super(Transformer, self).__init__()\n", "\n", " self.encoder = Encoder(num_blocks, d_model, num_heads, hidden_dim, source_vocab_size, \n", " max_input_len, dropout_rate)\n", " \n", " self.decoder = Decoder(num_blocks, d_model, num_heads, hidden_dim, target_vocab_size,\n", " max_target_len, dropout_rate)\n", " \n", " # The final dense layer to generate logits from the decoder output.\n", " self.output_layer = tf.keras.layers.Dense(target_vocab_size)\n", "\n", " def call(self, input_seqs, target_input_seqs, training, encoder_mask,\n", " decoder_mask, memory_mask):\n", " encoder_output, encoder_attn_weights = self.encoder(input_seqs, \n", " training, encoder_mask)\n", "\n", " decoder_output, decoder_attn_weights = self.decoder(encoder_output, \n", " target_input_seqs, training,\n", " decoder_mask, memory_mask)\n", "\n", " return self.output_layer(decoder_output), encoder_attn_weights, decoder_attn_weights\n" ], "metadata": { "id": "DfNkAsv8MNQ8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "transformer = Transformer(\n", " num_blocks = 6,\n", " d_model = 12,\n", " num_heads = 3,\n", " hidden_dim = 48,\n", " source_vocab_size = bpemb_vocab_size,\n", " target_vocab_size = 7000, # made-up target vocab size.\n", " max_input_len = padded_input_seqs.shape[1],\n", " max_target_len = padded_target_input_seqs.shape[1])\n", "\n", "transformer_output, _, _ = transformer(padded_input_seqs, \n", " padded_target_input_seqs, True, \n", " enc_mask, dec_mask, memory_mask=enc_mask)\n", "print(f\"Transformer output 
{transformer_output.shape}:\")\n", "print(transformer_output) # If training, we would use this output to calculate losses." ], "metadata": { "id": "1VOou7zjQ7el" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "That's the whole original transformer from scratch. From here, if you want to train this transformer, you can use the same approach we used when we built the translation model with attention in the [seq2seq notebook](https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_seq2seq_and_attention.ipynb#scrollTo=x8Ef_eWXjWMn&line=3&uniqifier=1). Remember to use a learning rate warmup (Refer to the paper for more information on this)." ], "metadata": { "id": "BV_fyVfIPzjH" } }, { "cell_type": "markdown", "source": [ "It's useful to know how these models work under the hood, but to train our own transformer to get impressive results is expensive. Both in terms of compute and data.
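 If you do decide to train it yourself, the learning rate warmup mentioned above could be sketched as a custom Keras schedule like the one below (an illustrative sketch of the paper's formula; the class name and the warmup_steps default are assumptions, not something defined elsewhere in this notebook).\n", "\n", "```\n", "class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):\n", "  def __init__(self, d_model, warmup_steps=4000):\n", "    super(TransformerSchedule, self).__init__()\n", "    self.d_model = tf.cast(d_model, tf.float32)\n", "    self.warmup_steps = warmup_steps\n", "\n", "  def __call__(self, step):\n", "    # Linear warmup for warmup_steps steps, then inverse square root decay.\n", "    step = tf.cast(step, tf.float32)\n", "    arg1 = tf.math.rsqrt(step)\n", "    arg2 = step * (self.warmup_steps ** -1.5)\n", "    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)\n", "\n", "# e.g. optimizer = tf.keras.optimizers.Adam(TransformerSchedule(d_model=12))\n", "```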
\n", "\n", "Fortunately, there's a zoo of **pretrained** transformer models we can use. We'll explore that next." ], "metadata": { "id": "UReJEI3rFKN2" } }, { "cell_type": "markdown", "source": [ "# Pre-Training and Transfer Learning with Hugging Face and OpenAI" ], "metadata": { "id": "Biy-OojYMNdg" } }, { "cell_type": "markdown", "source": [ "**IMPORTANT**
\n", "Enable **GPU acceleration** by going to *Runtime > Change Runtime Type*. Keep in mind that, on certain tiers, you're not guaranteed GPU access depending on usage history and current load.\n", "

\n", "Also, if you're running this in the cloud rather than a local Jupyter server on your machine, then the notebook will *timeout* after a period of inactivity.\n", "

\n", "Refer to this link on how to run Colab notebooks locally on your machine to avoid this issue:
\n", "https://research.google.com/colaboratory/local-runtimes.html" ], "metadata": { "id": "sdbQkTP3MNgw" } }, { "cell_type": "markdown", "source": [ "We'll explore pre-training and transfer learning using the **Transformers** library from [Hugging Face](https://huggingface.co/). **Transformers** is an API and toolkit to download pre-trained models and further train them as needed.
\n", "\n", "We'll start with the **pipelines** module which abstracts a lot of operations such as tokenization, vectorization, inference, etc.
\n", "\n", "With **Transformers pipelines**, we can just feed text input and get text output. And there are **pipelines** for common tasks including classification, NER, summarization, etc.
\n", "https://huggingface.co/docs/transformers/index
\n", "https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#pipelines" ], "metadata": { "id": "4d5zQqs_LU75" } }, { "cell_type": "markdown", "source": [ "To get started, we'll need to install **Transformers**." ], "metadata": { "id": "NxKmRric-yUR" } }, { "cell_type": "code", "source": [ "!pip install transformers\n", "!pip install datasets" ], "metadata": { "id": "NBnq4tryF5Iv" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import operator\n", "import pandas as pd\n", "import tensorflow as tf\n", "import transformers\n", "\n", "from datasets import load_dataset\n", "from tensorflow import keras\n", "from transformers import AutoTokenizer\n", "from transformers import pipeline\n", "from transformers import TFAutoModelForQuestionAnswering" ], "metadata": { "id": "AShuLLbDx-KA" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Getting up and running quickly with Hugging Face Pipelines" ], "metadata": { "id": "xvKIhFMJMNj4" } }, { "cell_type": "markdown", "source": [ "We'll use the **pipeline** (note the singular) abstraction which wraps all the other pipelines. Put simply, it'll be our interface to doing a bunch of NLP tasks." ], "metadata": { "id": "YqDi3x3e7-mw" } }, { "cell_type": "markdown", "source": [ "Using the **pipeline** abstraction is easy. We can instantiate a pipeline with a particular task, and it'll automatically download a suitable tokenizer and model behind the scenes for us and take care of the input and output operations.
\n", "https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline
\n", "\n" ], "metadata": { "id": "FswxoIXjDQec" } }, { "cell_type": "markdown", "source": [ "Here, we're retrieving a pipeline for text-classification." ], "metadata": { "id": "wfOgySwaH5nm" } }, { "cell_type": "code", "source": [ "classifier = pipeline(\"text-classification\")" ], "metadata": { "id": "0H74TebNGEqy" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Note the warning message about how no model was supplied. When we instantiate a pipeline for a task without specifying a particular model to perform the task, **Transformers** uses a default model. This is good enough for prototyping but for production, we'll want to specify which model to use for the task since the default can change. We'll see how to do this further below." ], "metadata": { "id": "C4zaqdrVx-hw" } }, { "cell_type": "markdown", "source": [ "We can use the pipeline immediately to classify some text. Tokenization, vectorization, etc is taken care of behind the scenes." ], "metadata": { "id": "-DaXJ_6jNTBe" } }, { "cell_type": "code", "source": [ "classifier(\"Alice was excited to go the island but it didn't live up to the hype.\")" ], "metadata": { "id": "dIpPUNWdGTJy" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "classifier(\"Bob doesn't do well in group situations but he said it wasn't bad.\")" ], "metadata": { "id": "fVkpFKK3Nhso" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "There's support for summarization..." ], "metadata": { "id": "rypbsJ3pNhn-" } }, { "cell_type": "code", "source": [ "summarizer = pipeline(\"summarization\")" ], "metadata": { "id": "asGq_aIdNhhW" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "text = \"\"\"\n", "Hans Niemann is launching a counterattack in his dispute with chess world \n", "champion Magnus Carlsen, filing a federal lawsuit that accuses Carlsen of \n", "maliciously colluding with others to defame the 19-year-old grandmaster and \n", "ruin his career.\n", "\n", "It's the latest move in a scandal that has injected unprecedented levels of \n", "drama into the world of elite chess since early September, when Carlsen \n", "suggested Niemann's upset victory over him at the Sinquefield Cup tournament \n", "in St. Louis was the result of cheating.\n", "\n", "Niemann wants a federal court in Missouri's eastern district to award him at \n", "least $100 million in damages. Defendants in the lawsuit include Carlsen, his \n", "company Play Magnus Group, the online platform Chess.com and its leader, Danny \n", "Rensch, along with grandmaster Hikaru Nakamura.\n", "\"\"\"" ], "metadata": { "id": "6uErT9XbIHjn" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "summarizer(text)" ], "metadata": { "id": "1OVj0SPWRU2O" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "...and question answering (extractive in this example)." ], "metadata": { "id": "L6zPiJBzRinR" } }, { "cell_type": "code", "source": [ "qa = pipeline(\"question-answering\")" ], "metadata": { "id": "JF9aDUUUSOUt" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "context=\"\"\"\n", "Hugging Face was founded in 2016 by Clément Delangue, Julien Chaumond, and \n", "Thomas Wolf originally as a company that developed a chatbot app targeted at \n", "teenagers.[2] After open-sourcing the model behind the chatbot, the company \n", "pivoted to focus on being a platform for democratizing machine learning. 
In March \n", "2021, Hugging Face raised $40 million in a Series B funding round.\n", "\"\"\"\n", "\n", "question = \"Who are the Hugging Face founders?\"\n", "\n", "qa(question=question, context=context)" ], "metadata": { "id": "7Iv7feIrRp5F" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Extractive question-answering models work fine for certain domains, document structures, and questions. But situations that require reasoning, more complex parsing, or contain ambiguity can trip it up." ], "metadata": { "id": "eWuoNhDfzz5q" } }, { "cell_type": "code", "source": [ "question = \"What does Hugging Face do?\"\n", "qa(question=question, context=context)" ], "metadata": { "id": "F3Zx1oG0zpCa" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "There are ready-made pipelines for a number of tasks:
\n", "https://huggingface.co/docs/transformers/main/en/quicktour#pipeline" ], "metadata": { "id": "pJxzUhrqRp06" } }, { "cell_type": "markdown", "source": [ "Let's say we want a pipeline that uses a particular model. On the Hugging Face model hub, you'll find both pre-trained models (e.g. BERT) *and* pre-trained models that have been fine-tuned for all sorts of tasks (e.g. BERT for text classification). These models are contributed by Hugging Face, other companies, institutions, and individuals. You can (and are encouraged) to train or fine-tune a model and upload it for others to use.
\n", "https://huggingface.co/models\n", "

\n", "For example, here's a collection of pre-trained models that have been tuned for text classification.
\n", "https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads\n", "

\n", "This particular one is a pre-trained *Roberta-base* model that's been fine-tuned on Twitter data for sentiment analysis:\n", "https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment\n", "

\n", "Note how you can try the model directly on the model page." ], "metadata": { "id": "Iuyd8fzyRpwM" } }, { "cell_type": "markdown", "source": [ "Let's say we want to download and use a particular model. For example, this *BERT-base* model fine-tuned for NER:\n", "https://huggingface.co/dslim/bert-base-NER\n", "
\n", "\n", "We just need to pass the model path during pipeline instantiation." ], "metadata": { "id": "Mj4dqCcTRpq1" } }, { "cell_type": "code", "source": [ "ner = pipeline(model=\"dslim/bert-base-NER\")" ], "metadata": { "id": "FwVmLNnqL16A" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "text = \"Panic ensues in Redmond as love child of Microsoft and OpenAI declares humanity obsolete.\" \n", "ner(text)" ], "metadata": { "id": "jWr8Fzl3DLKb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The **Transformers** library provides a bunch of helper classes to help with training models. And beyond the model hub, Hugging Face also hosts datasets, provides *spaces* where you can host your app, and offers a bunch of services such as cloud hardware and inference endpoints to help deploy your model.
\n", "Datasets: https://huggingface.co/datasets
\n", "Spaces: https://huggingface.co/spaces
\n", "\n", "With Hugging Face, you can build an ML app prototype within minutes and iterate quickly from there.
\n", "https://huggingface.co/docs
\n", "\n", "Learn more about how to build with Hugging Face through their free course and fantastic book:
\n", "https://huggingface.co/course
\n", "https://www.oreilly.com/library/view/natural-language-processing/9781098136789/\n" ], "metadata": { "id": "RdEpdaUeGlGg" } }, { "cell_type": "markdown", "source": [ "## Fine-Tuning a Pre-Trained Model." ], "metadata": { "id": "_simw4wbGlEi" } }, { "cell_type": "markdown", "source": [ "Let's say the model hub doesn't have a model that exactly suits your purpose. Perhaps you work in a particular domain and need to fine-tune a model using your own dataset.
\n", "\n", "In this section, we'll walk through how to download a pre-trained model and fine-tune it. Our example covers extractive question answering but it's the same idea with other tasks." ], "metadata": { "id": "j8bv2QOPGk-m" } }, { "cell_type": "markdown", "source": [ "We'll fine-tune using a dataset from the **Datasets** hub.
\n", "https://huggingface.co/datasets

\n", "Hugging Face provides a **datasets** library to download and interact with the datasets. It's similar to the Tensorflow Dataset library we used in that it can hold data and provides a bunch of methods to preprocess that data.
\n", "https://huggingface.co/docs/datasets/ndex" ], "metadata": { "id": "XSceJ-5GNsPY" } }, { "cell_type": "markdown", "source": [ "The **Datasets** hub holds a bunch of question answering datasets.
\n", "https://huggingface.co/datasets?task_categories=task_categories:question-answering&sort=downloads
\n", "\n", "\n", "They differ based on data source, domain, and level of challenge. Since we're in a constrained environment (Colab free tier) and just learning how to fine-tune, we'll use SQuAD, a famous dataset comprised of crowd-sourced questions on a set of Wikipedia articles, and where the answer is a span of text in the article.
\n", "https://huggingface.co/datasets/squad\n" ], "metadata": { "id": "Gr_RfVRwQXME" } }, { "cell_type": "code", "source": [ "data = load_dataset(\"squad\")" ], "metadata": { "id": "n-44MBYSyDIY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The **datasets** library downloads and automatically splits the data into train and validation sets. It returns a dictionary of **Dataset** objects:
\n", "https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict
\n", "https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset
\n", "\n", "\n", "A **Dataset** object wraps an Apache Arrow table and provides a bunch of helper functions on top of it.
\n", "https://arrow.apache.org/" ], "metadata": { "id": "cOEwPv9pSpjy" } }, { "cell_type": "code", "source": [ "data" ], "metadata": { "id": "JKrr8vvLyDDX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Glaning at the data, we see every context (Wikipedia passage) is used multiple times. i.e., there are multiple questions and answers for each context.
\n", "\n", "Every answer is a span of text from the context and the character position where the answer starts in the context is given." ], "metadata": { "id": "jhO8e_RqbCNL" } }, { "cell_type": "code", "source": [ "pd.DataFrame(data['train'][0, 1, 2, 100, 101, 102], \n", " columns=[\"context\", \"question\", \"answers\"])" ], "metadata": { "id": "SRrBEsC2VhSF" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Here's what we need to do:\n", "1. Choose a pre-trained model based on what we want to accomplish and our constraints.\n", "2. Download the appropriate tokenizer for the pre-trained model.\n", "3. Tokenize and vectorize our dataset.\n", "4. Mark where each answer starts and ends in our vectorized dataset.\n", "5. Download the pre-trained model.\n", "6. Fine-tune the pre-trained model with the vectorized dataset." ], "metadata": { "id": "qK6WR9y6WGxf" } }, { "cell_type": "markdown", "source": [ "Given the free tier of Colab doesn't have a lot of GPU memory and that we're just trying to fine-tune a simple, extractive question answering model, we'll use *distilroberta-base*.
\n", "https://huggingface.co/distilroberta-base\n", "

\n", "Recall from the slides that *DistilBERT* was created using a technique called *knowledge distillation*. The result is a model that performs almost as well as BERT but is 40% smaller and 60% faster.
\n", "DistilBert Paper: https://arxiv.org/abs/1910.01108
\n", "https://en.wikipedia.org/wiki/Knowledge_distillation\n", "

\n", "*distilroberta-base* was created by applying knowledge distillation to *Roberta-Base*, a more powerful model than BERT.
\n", "Roberta paper: https://arxiv.org/abs/1907.11692

\n", "The **Transformers** library provides a set of Auto Classes that can automatically retrieve configurations, tokenizers, and models based on a path or a name. We'll use the **AutoTokenizer** class to get the right tokenizer for *distilroberta-base*.
\n", "https://huggingface.co/docs/transformers/main/en/model_doc/auto
\n", "https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoTokenizer\n", "\n" ], "metadata": { "id": "9-v69DgXcQgn" } }, { "cell_type": "code", "source": [ "model_name = 'distilroberta-base'\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ], "metadata": { "id": "HaCAmq1GjmgT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Calling *encode* converts a string to a sequence of integer token ids.
\n", "https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.encode" ], "metadata": { "id": "BPRj0m_8LkDS" } }, { "cell_type": "code", "source": [ "t = \"Where can I find a pizzeria?\"\n", "print(tokenizer.encode(t))" ], "metadata": { "id": "FEFBpuLvoCxX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "But to tokenize, we call the tokenizer object directly (i.e. using *\\_\\_call\\_\\_*).
\n", "https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
\n", "\n", "This returns a sequence of ids and an attention mask in a **BatchEncoding** object:
\n", "https://huggingface.co/docs/transformers/main/en/glossary#input-ids
\n", "https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.BatchEncoding
\n", "\n", "Since there's no padding on this sample string, the mask is all 1s." ], "metadata": { "id": "tRYgNIrfM0RH" } }, { "cell_type": "code", "source": [ "encoded_t = tokenizer(t)\n", "print(encoded_t)" ], "metadata": { "id": "LvOuVBSTSgHe" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can convert the ids back to tokens using *convert_ids_to_tokens*.
\n", "https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.convert_ids_to_tokens
\n", "\n", "Note how the tokenizer added a start of sequence token (\\), end of sequence token (\\), and how it uses Ġ to signal a word has preceding whitespace. Keep in mind that what you're seeing here is the output from the *distilroberta-base* tokenizer. Other tokenizers may work differently." ], "metadata": { "id": "3qVZ7qsXNfdc" } }, { "cell_type": "code", "source": [ "print(tokenizer.convert_ids_to_tokens(encoded_t['input_ids']))" ], "metadata": { "id": "M1HDLRBDTNij" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As we covered in the slides, for question answering, we need to encode the question and context as a pair. In our case, we can do that by passing in both strings separated by a comma." ], "metadata": { "id": "uwTdWW7YOb-l" } }, { "cell_type": "code", "source": [ "encoded_pair = tokenizer(\"this is a question\", \"this is the context\")\n", "print(encoded_pair)" ], "metadata": { "id": "u3EzHVvCyCf3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The *distilroberta-base* tokenizer uses a double \\\\ as a separator." ], "metadata": { "id": "7mXGDtJ9OuPQ" } }, { "cell_type": "code", "source": [ "print(tokenizer.convert_ids_to_tokens(encoded_pair['input_ids']))" ], "metadata": { "id": "a8Qen04EyCR5" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "**Side note**:
\n", "Most of the tokenizers in the **Transformers** library come in two versions: a Python implementation and a faster Rust implementation. When available, **Autotokenizer** will download the fast version.
\n", "https://huggingface.co/docs/transformers/main_classes/tokenizer
\n", "\n", "We can check whether we have a fast tokenizer.\n" ], "metadata": { "id": "_9E0BcYjPSQk" } }, { "cell_type": "code", "source": [ "assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)" ], "metadata": { "id": "wBtq5L5cnKeH" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Suppose we tokenize this question/context pair..." ], "metadata": { "id": "ymWFVX5xQMXe" } }, { "cell_type": "code", "source": [ "context = \"Sarah went to The Mirthless Cafe last night to meet her friend.\"\n", "question = \"Where did Sarah go?\"\n", "\n", "# The answer span and the answer's starting character position in the context.\n", "answer = \"The Mirthless Cafe\"\n", "answer_start = 14" ], "metadata": { "id": "N843CjJGjmjK" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "x = tokenizer(question, context)\n", "x" ], "metadata": { "id": "0uukBv7QuRCh" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Note how the word *Mirthless* gets tokenized into subwords. For legibility, we're using *batch_decode* to convert the input_ids to strings.
\n", "https://huggingface.co/docs/transformers/v4.23.1/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.batch_decode" ], "metadata": { "id": "MuoUE5QnQW9J" } }, { "cell_type": "code", "source": [ "tokenizer.batch_decode(x['input_ids'])" ], "metadata": { "id": "DOAB5UtyuRGZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "When we tokenize our dataset, there will probably be question/context pairs which exceed our model's maximum sequence length. In *Roberta*'s case, that's 512. Available GPU memory may make us further reduce the maximum sequence length of our input.
\n", "\n", "Let's say the maximum sequence length we can handle is 15, so we truncate the context." ], "metadata": { "id": "ZJU_5RBSQ1fA" } }, { "cell_type": "code", "source": [ "example_max_length = 15\n", "x = tokenizer(question, context, max_length=example_max_length, \n", " truncation=\"only_second\")\n", "x" ], "metadata": { "id": "wdFH9GY4I5TT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The problem here is that the answer span gets chopped off by truncation. In other situations, the answer may not be included at all." ], "metadata": { "id": "zteMsetFSI8M" } }, { "cell_type": "code", "source": [ "tokenizer.batch_decode(x['input_ids'])" ], "metadata": { "id": "oPC7ZJ_FSGVE" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "To ensure we tokenize all context tokens while respecting a maximum length, we can set *return_overflowing_tokens* to **True**. The end effect is to split the input into multiple question/context sequences, with each context sequence being a continuation of the previous one. Since the last one may be shorter than the max length, we set the right padding length as well.
\n", "\n", "What we get back are multiple *input_id* sequences." ], "metadata": { "id": "FrgW73TpTnJC" } }, { "cell_type": "code", "source": [ "x = tokenizer(question, context, max_length=example_max_length, \n", " truncation=\"only_second\", return_overflowing_tokens=True, \n", " padding=\"max_length\")\n", "x" ], "metadata": { "id": "oh3jDv9tuRJ9" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "len(x['input_ids'])" ], "metadata": { "id": "LgMpEkAhvSKy" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Looking at the decoded sequences, we see the entire context is included across three sequences (along with padding on the last one).
" ], "metadata": { "id": "nekyxb-WU3bR" } }, { "cell_type": "code", "source": [ "tokenizer.batch_decode(x['input_ids'])" ], "metadata": { "id": "yx6a1rVWuRQn" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Note a few things from the encoded object *x*:\n", "- The last *attention_mask* sequence has 0s to signify padding.\n", "- The *overflow_to_sample_mapping* array tells us which question/context pair each *input_ids* sequence comes from. In our example, we tokenized a single question/context pair which resulted in three *input_ids* sequences, so *overflow_to_sample_mapping* is 3 0s.
\n", "\n", "If we tokenize two question/context pairs, we'll see the *overflow_to_sample_mapping* reflect that." ], "metadata": { "id": "nWXHeyOMY_Id" } }, { "cell_type": "code", "source": [ "tokenizer(['question 1', 'question 2'], \n", " ['context 1', 'context 2'], \n", " return_overflowing_tokens=True)" ], "metadata": { "id": "C3qOppteYZDe" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "But there's still a problem here in that none of the sequences contain the full answer (\"The Mirthless Cafe\"). Right now, the correct full answer is split across sequences.
\n", "\n", "To counter this, we can tokenize our question/context pair into overlapping sequences by setting a *stride* length. We did something similar when we prepared the dataset for our [character-level language model](https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_recurrent_neural_networks.ipynb#scrollTo=X1c-ihOByy88)." ], "metadata": { "id": "YaBX1vSFVWV6" } }, { "cell_type": "code", "source": [ "stride = 5\n", "x = tokenizer(question, context, max_length=example_max_length, \n", " truncation=\"only_second\", return_overflowing_tokens=True,\n", " stride=stride, padding=\"max_length\")" ], "metadata": { "id": "DvUsKuFVuRT3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "By setting a stride of 5, each context sequence starts 5 subwords back from the previous sequence.
\n", "\n", "This way, two of our tokenized sequences now contain the full answer.\n" ], "metadata": { "id": "0J-wk8eoV8wr" } }, { "cell_type": "code", "source": [ "tokenizer.batch_decode(x['input_ids'])" ], "metadata": { "id": "e4nQa25GxVbE" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We now have a way to tokenize our question/context pairs.
\n", "\n", "Our tokenizer returned this **BatchEncoding** object:" ], "metadata": { "id": "08cJKo07d35u" } }, { "cell_type": "code", "source": [ "print(x.keys(), '\\n')\n", "x" ], "metadata": { "id": "WcMMKk7odl8j" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "To fine-tune a model for question answering, our pre-trained *distilroberta-base* model expects this object to contain two more pieces of information:\n", "- *start_positions*: the token positions where answers begin.\n", "- *end_positions*: the token positions where answers end.
\n", "\n", "https://huggingface.co/docs/transformers/main/en/model_doc/roberta#transformers.RobertaForQuestionAnswering.forward" ], "metadata": { "id": "lBOIZvU_e49e" } }, { "cell_type": "markdown", "source": [ "All we have in our example (and the SQuAD dataset) is the position of the starting character of the answer." ], "metadata": { "id": "n1dj61DBe46s" } }, { "cell_type": "code", "source": [ "print(answer_start)\n", "print(context[answer_start:answer_start+len(answer)])" ], "metadata": { "id": "x4Ut-VWqe43K" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We need to use this to locate the token positions where each answer starts and ends in every *input_ids* sequence. In some cases, the complete answer may not be in a particular sequence. We need to handle those cases as well.
\n", "\n", "To do this, we'll get more information by setting *return_offsets_mapping* to **True** in the tokenizer." ], "metadata": { "id": "yOGFqWeme4yX" } }, { "cell_type": "code", "source": [ "x = tokenizer(question, context, max_length=example_max_length, \n", " truncation=\"only_second\", return_overflowing_tokens=True,\n", " stride=stride, return_offsets_mapping=True,\n", " padding=\"max_length\")\n", "x" ], "metadata": { "id": "9n8nn9xRjplX" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "This results in *offset_mapping* sequences, one for each *input_ids* sequence. " ], "metadata": { "id": "77HlwNtVpHnE" } }, { "cell_type": "code", "source": [ "print(len(x['input_ids']))\n", "print(len(x['offset_mapping']))" ], "metadata": { "id": "7b24VEPvo6_2" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Each entry in an *offset_mapping* tells us the starting and ending character position of each token in the original string. An offset mapping of (0,0) represents a special token (e.g. \\).
\n", "\n", "For example, here's the first *input_ids* sequence along with its respective *offset_mapping*." ], "metadata": { "id": "sy4c2PBko681" } }, { "cell_type": "code", "source": [ "print(x['input_ids'][0])\n", "print(x['offset_mapping'][0])" ], "metadata": { "id": "HHP7O9lgo65N" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "If we convert the first non-special input id to a token, and use the first non-special offset_mapping to extract a span from the question string, we get a match." ], "metadata": { "id": "YKQbrYGLw6Hq" } }, { "cell_type": "code", "source": [ "print(\"First non-special input_id converted to token:\")\n", "print(tokenizer.convert_ids_to_tokens(x['input_ids'][0][1]), \"\\n\")\n", "\n", "offset = x['offset_mapping'][0][1]\n", "print(f\"Span extracted from context using corresponding offset_mapping {offset}:\")\n", "print(question[offset[0]:offset[1]])" ], "metadata": { "id": "XNCI2va9sgPm" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Since we know the character position of where the answer starts, we can use that and *offset_mapping* to get the start and ending token positions of the answer span.
\n", "\n", "The only remaining issue is identifying whether an offset is for a question or a context. Looking at the first two *offset_mappings*, note that:
\n", "1. In the first sequence, both the question and context *offset_mappings* start from zero.\n", "2. In the second sequence, the context *offset_mapping* values carry on from the previous sequence (after accounting in the stride)." ], "metadata": { "id": "JuS8cxMQxVZn" } }, { "cell_type": "code", "source": [ "print(x['offset_mapping'][0])\n", "print(x['offset_mapping'][1])" ], "metadata": { "id": "5iGMftD8o611" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "This means we need to identify\n", "1. which *offset_mappings* belong to a context.\n", "2. whether a particular sequence contains the answer at all.
\n", "\n", "The first can be done using the *sequence_ids* method on the encoding object. Each *input_ids* sequence has a corresponding *sequence_ids* list which tells us whether a token is part of a question, part of a context, or a special token.
\n", "https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.BatchEncoding.sequence_ids" ], "metadata": { "id": "gFIDFQlTo6yN" } }, { "cell_type": "code", "source": [ "print(x['input_ids'][0])\n", "print(x.sequence_ids(0))" ], "metadata": { "id": "a5at6BbBIZ43" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "So to identify whether a token is part a context, we can use *sequence_ids* to check whether a token position maps to 1." ], "metadata": { "id": "L6aM4eiwIjCU" } }, { "cell_type": "markdown", "source": [ "For the second issue, we can check whether the answer start and end character positions are within the lowest and highest offset mapping values respectively." ], "metadata": { "id": "zGr2Ecv5ICQm" } }, { "cell_type": "code", "source": [ "# We can calculate the answer end character position using the answer length.\n", "answer_end = answer_start + len(answer)\n", "\n", "print(\"Answer start character position:\", answer_start)\n", "print(\"Answer end character position:\", answer_end)\n", "print(\"Answer pulled from context:\", context[answer_start:answer_end])" ], "metadata": { "id": "3-nvMAABG3sS" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's find the start and end token positions from our collection of sequences. The full answer is not in the first sequence, but is in the third sequence. So let's experiment with those." ], "metadata": { "id": "mAA_vmugJ1U3" } }, { "cell_type": "code", "source": [ "tokenizer.batch_decode(x['input_ids'])" ], "metadata": { "id": "ybaE6cdrbX2d" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "First get all the information we need for the first sequence." ], "metadata": { "id": "O2IcdQvvMOVN" } }, { "cell_type": "code", "source": [ "input_ids = x['input_ids'][0]\n", "offset_mapping = x['offset_mapping'][0]\n", "seq_ids = x.sequence_ids(0)" ], "metadata": { "id": "ImKRjKc4MSPB" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Determine where the context tokens start and end in the sequence." ], "metadata": { "id": "nUGpuj2gMlHP" } }, { "cell_type": "code", "source": [ "# These are the sequence ids\n", "print(\"Sequence IDs: \", seq_ids)" ], "metadata": { "id": "rdwfloVp06ka" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Get the start index position (i.e. the first occurrence of 1).\n", "context_pos_start = seq_ids.index(1)" ], "metadata": { "id": "XdYF-gAEXTo1" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Utility function to find the *last* occurrence of a sequence.\n", "def rindex(lst, value):\n", " return len(lst) - operator.indexOf(reversed(lst), value) - 1\n", "\n", "# Get the end index position (i.e. the last occurrence of 1).\n", "context_pos_end = rindex(seq_ids, 1)" ], "metadata": { "id": "rEZixsDMk3Ma" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Context tokens begin at position\", context_pos_start)\n", "print(\"Context tokens end at position\", context_pos_end)" ], "metadata": { "id": "5wRAB76PNHi8" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now that we know which tokens are part of the context, we can look at their corresponding offset mappings to check whether the start and end character positions are within the offsets." 
], "metadata": { "id": "qWqIyuEFNqag" } }, { "cell_type": "code", "source": [ "# These are the corresponding offsets.\n", "context_offsets = offset_mapping[context_pos_start:context_pos_end+1]\n", "print(context_offsets)" ], "metadata": { "id": "dfjPnHaWNn77" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Is the lowest offset value lower than or equal to the starting character position?\")\n", "print(\"Answer starting character position:\", answer_start)\n", "print(\"First offset:\", context_offsets[0])\n", "\n", "# Note how we're checking the first tuple value.\n", "print(context_offsets[0][0] <= answer_start)" ], "metadata": { "id": "qMhpN7h-l_kE" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Is the highest offset value higher than or equal to the ending character position?\")\n", "print(\"Answer ending character position:\", answer_end)\n", "print(\"Last offset:\", context_offsets[-1])\n", "\n", "# Note how how we're checking the second tuple value.\n", "print(context_offsets[-1][1] >= answer_end)" ], "metadata": { "id": "rTekBBhlyA0y" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "So the first sequence contains a part of the answer but the full answer gets truncated. This matches a visual inspection:" ], "metadata": { "id": "tDmHTRKXPqPj" } }, { "cell_type": "code", "source": [ "print(tokenizer.batch_decode(input_ids))" ], "metadata": { "id": "ewZnKzr8P5QM" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's now do the same with the third sequence." ], "metadata": { "id": "nb2TT8CtQNPY" } }, { "cell_type": "code", "source": [ "input_ids = x['input_ids'][2]\n", "offset_mapping = x['offset_mapping'][2]\n", "seq_ids = x.sequence_ids(2)\n", "\n", "context_pos_start = seq_ids.index(1)\n", "context_pos_end = rindex(seq_ids, 1)\n", "\n", "context_offsets = offset_mapping[context_pos_start:context_pos_end+1]\n", "\n", "print(\"Is the lowest offset value lower than or equal to the starting character position?\")\n", "print(\"Answer starting character position:\", answer_start)\n", "print(\"First offset:\", context_offsets[0])\n", "\n", "# Note how we're checking the first tuple value.\n", "print(context_offsets[0][0] <= answer_start)\n", "\n", "print(\"Is the highest offset value higher than or equal to the ending character position?\")\n", "print(\"Answer ending character position:\", answer_end)\n", "print(\"Last offset:\", context_offsets[-1])\n", "\n", "# Note how how we're checking the second tuple value.\n", "print(context_offsets[-1][1] >= answer_end)\n" ], "metadata": { "id": "rC_xDU-piiN1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now that we've confirmed the third sequence contains the full answer, we need to identify where the answer starts and ends in the *input_ids*. We can do this by scanning the offset_mapping from the left to find the start, and from the right to find the end." ], "metadata": { "id": "h2m1qy-AQ4kT" } }, { "cell_type": "code", "source": [ "s = e = 0\n", "\n", "# Start scanning the offset_mapping from the\n", "# left to find the token position where the answer starts.\n", "# It's not guaranteed a tokenizer will output a token where the\n", "# starting character matches the first answer character. 
When\n", "# this happens, we take the previous token's position as our start.\n", "i = context_pos_start\n", "while offset_mapping[i][0] < answer_start:\n", " i += 1\n", "if offset_mapping[i][0] == answer_start:\n", " s = i\n", "else:\n", " s = i - 1\n", "\n", "# Same idea when finding the ending token position.\n", "j = context_pos_end\n", "while offset_mapping[j][1] > answer_end:\n", " j -= 1 \n", "if offset_mapping[j][1] == answer_end:\n", " e = j\n", "else:\n", " e = j + 1" ], "metadata": { "id": "CcdK1rOlcRC6" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Answer start token position in context:\", s)\n", "print(\"Answer end token position in context:\", e)" ], "metadata": { "id": "Axs-I9C5f4tH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(\"Answer lifted from context:\")\n", "tokenizer.batch_decode(input_ids[s:e+1])" ], "metadata": { "id": "r8NDcNTQiruR" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "All the logic we stepped through so far is encapsulated in the following method. We'll use this to process our dataset." ], "metadata": { "id": "OzjLm4LwX2oh" } }, { "cell_type": "code", "source": [ "def prepare_dataset(examples):\n", " # Some tokenizers don't strip spaces. If there happens to be question text\n", " # with excessive spaces, the context may not get encoded at all.\n", " examples[\"question\"] = [q.lstrip() for q in examples[\"question\"]]\n", " examples[\"context\"] = [c.lstrip() for c in examples[\"context\"]]\n", "\n", " # Tokenize. \n", " tokenized_examples = tokenizer(\n", " examples['question'],\n", " examples['context'],\n", " truncation=\"only_second\",\n", " max_length = max_length,\n", " stride=stride,\n", " return_overflowing_tokens=True,\n", " return_offsets_mapping=True,\n", " padding=\"max_length\"\n", " )\n", "\n", " # We'll collect a list of starting positions and ending positions.\n", " tokenized_examples['start_positions'] = []\n", " tokenized_examples['end_positions'] = []\n", "\n", " # Work through every sequence.\n", " for seq_idx in range(len(tokenized_examples['input_ids'])):\n", " seq_ids = tokenized_examples.sequence_ids(seq_idx)\n", " offset_mappings = tokenized_examples['offset_mapping'][seq_idx]\n", "\n", " cur_example_idx = tokenized_examples['overflow_to_sample_mapping'][seq_idx]\n", " answer = examples['answers'][cur_example_idx]\n", " answer_text = answer['text'][0]\n", " answer_start = answer['answer_start'][0]\n", " answer_end = answer_start + len(answer_text)\n", "\n", " context_pos_start = seq_ids.index(1)\n", " context_pos_end = rindex(seq_ids, 1)\n", "\n", " s = e = 0\n", " if (offset_mappings[context_pos_start][0] <= answer_start and\n", " offset_mappings[context_pos_end][1] >= answer_end):\n", " i = context_pos_start\n", " while offset_mappings[i][0] < answer_start:\n", " i += 1\n", " if offset_mappings[i][0] == answer_start:\n", " s = i\n", " else:\n", " s = i - 1\n", "\n", " j = context_pos_end\n", " while offset_mappings[j][1] > answer_end:\n", " j -= 1 \n", " if offset_mappings[j][1] == answer_end:\n", " e = j\n", " else:\n", " e = j + 1\n", "\n", " tokenized_examples['start_positions'].append(s)\n", " tokenized_examples['end_positions'].append(e)\n", "\n", " return tokenized_examples" ], "metadata": { "id": "lz_PEjaZcROj" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Before we process, we'll set maximum sequence length, stride, and batch size values.
\n", "\n", "I arrived at these values through experimentation. Even though *distilroberta-base* has a maximum sequence length of 512, using the full capacity (or a large batch value) results in an out-of-memory error while the attention scores are being calculated. This is on Colab's free tier. On the premium tier, you can use larger sequence lengths or batch values.
\n", "\n", "The nature of the data will also influence the values." ], "metadata": { "id": "8eUOZJuCAJyV" } }, { "cell_type": "code", "source": [ "max_length = 400\n", "stride = 100\n", "batch_size = 32" ], "metadata": { "id": "rPtD9WBD_AhO" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can map over the **Dataset** objects and apply our prepare method to the examples in batches.
\n", "https://huggingface.co/docs/datasets/main/en/nlp_process
\n", "https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map
\n", "\n", "*remove_columns* removes the original data columns and leaves only the post-tokenization columns in place. We can also parallelize processing by using the *num_proc* parameter.
\n", "https://huggingface.co/docs/datasets/main/en/process#multiprocessing" ], "metadata": { "id": "DX_c5jEwlgku" } }, { "cell_type": "code", "source": [ "tokenized_datasets = data.map(\n", " prepare_dataset,\n", " batched=True,\n", " remove_columns=data[\"train\"].column_names,\n", " num_proc=2,\n", ")" ], "metadata": { "id": "yXPkihnn-47H" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Our tokenized dataset still contains two entries (*offset_mapping* and *overflow_to_sample_mapping*) our model won't expect, so we'll remove them.
\n", "https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.remove_columns" ], "metadata": { "id": "UKnxf6CnmfSP" } }, { "cell_type": "code", "source": [ "data = tokenized_datasets.remove_columns([\"offset_mapping\", \n", " \"overflow_to_sample_mapping\"])" ], "metadata": { "id": "XDs0Wv4kRCS8" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The last preparation step is to convert the Hugging Face **Dataset** objects into a Tensorflow-compatible datasets.
\n", "https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset
\n", "https://huggingface.co/docs/datasets/main/en/use_with_tensorflow#when-to-use-totfdataset" ], "metadata": { "id": "rMgS2KhMmpUq" } }, { "cell_type": "code", "source": [ "train_set = data['train'].to_tf_dataset(batch_size=batch_size)\n", "validation_set = data['validation'].to_tf_dataset(batch_size=batch_size)" ], "metadata": { "id": "8mRpM6VRIlha" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can now download a pre-trained model for fine-tuning. Just like we did with the tokenizer, we'll use an Auto Class to download the right model. In this case, we're using **TFAutoModelForQuestionAnswering**. This will download a Tensorflow implementation of the pre-trained model with a question answering head on it.
\n", "\n", "The head in this case is a dense layer that returns *start_logits* and *end_logits*. We can take the argmax of each to determine the start and end of the answer span (see model code for details).
\n", "https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.TFAutoModelForQuestionAnswering
\n", "https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_tf_roberta.py#L1629" ], "metadata": { "id": "pfVojvWUna1q" } }, { "cell_type": "code", "source": [ "model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)" ], "metadata": { "id": "tMibUfkknZBW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The following method attempts to answer a question given a context. It tokenizes the question and context, runs it through the model, takes the argmax of the start and end logits, and uses the result to extract an answer span from the context." ], "metadata": { "id": "VoBZvBiitBg0" } }, { "cell_type": "code", "source": [ "def get_answer(tokenizer, model, question, context):\n", " inputs = tokenizer([question], [context], return_tensors=\"np\")\n", " outputs = model(inputs)\n", " start_position = tf.argmax(outputs.start_logits, axis=1)\n", " end_position = tf.argmax(outputs.end_logits, axis=1)\n", " answer = inputs[\"input_ids\"][0, int(start_position) : int(end_position) + 1]\n", " return tokenizer.decode(answer).strip()" ], "metadata": { "id": "ACBHx7Lz4jDm" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "While the model body (the pre-trained *distilroberta-base* model) is trained, the head is not. So if we try to use our model to answer a question, it should fail or perform poorly (your output will differ because of different initial head weight values)." ], "metadata": { "id": "ig1h3CMlsxR8" } }, { "cell_type": "code", "source": [ "c = \"Sarah went to The Mirthless Cafe last night to meet her friend.\"\n", "q = \"Where did Sarah go?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "NpbqCksK3frD" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# https://www.tensorflow.org/guide/mixed_precision\n", "keras.mixed_precision.set_global_policy(\"mixed_float16\")\n", "\n", "# Use a learning rate recommended by the BERT authors.\n", "# https://github.com/google-research/bert\n", "model.compile(optimizer=keras.optimizers.Adam(learning_rate=3e-5))" ], "metadata": { "id": "Tp9LzfCHIlua" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We'll now fine-tune the model. Note that we didn't freeze the layers of the pre-trained body, so its weights will be tuned along with the head's weights.
\n", "\n", "Because the body is already pre-trained, we don't need a lot of epochs. 2-4 is typically enough (BERT authors recommend 4). Here, we're using 1 to demonstrate the power of pre-training.
\n", "\n", "**Note:** If you have GPU enabled and you're using Colab's free tier, the training time can be all over the place depending on which GPU you get assigned (anywhere from 20 minutes to an hour)." ], "metadata": { "id": "sgDFKpYQvh_J" } }, { "cell_type": "code", "source": [ "model.fit(train_set, validation_data=validation_set, epochs=1)" ], "metadata": { "id": "Wy0ZbVvoIlx_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "After completing our fine-tuning, we should now have a decent extractive question answering model." ], "metadata": { "id": "dJpdL5UY1jaV" } }, { "cell_type": "code", "source": [ "c = \"Sarah went to The Mirthless Cafe last night to meet her friend.\"\n", "q = \"Where did Sarah go?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "k8HcIvIX4y0k" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "q = \"Who did Sarah meet?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "4tSpF3Xx1T-C" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "q = \"When did Sarah meet her friend?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "xl5A9bOp1YTT" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "q = \"Who went to the restaurant?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "t-Cohp9q1YLS" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "But as we saw earlier, extractive question answering has its limits." ], "metadata": { "id": "vquVx3s212l8" } }, { "cell_type": "code", "source": [ "# Asking a logic teaser question is difficult despite the\n", "# answer being available. To be fair, there is ambiguity here.\n", "q = \"Who did Sarah's friend meet?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "iil1PWaA1YCQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# The model can't determine when a question can't be \n", "# answered. Some question answering datasets explicitly \n", "# train for this.\n", "q = \"How did Sarah get to the restaurant?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "FhhLzoGb1X1i" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# The model isn't generative, either.\n", "q = \"What is a possible reason for why Sarah met her friend?\"\n", "get_answer(tokenizer, model, q, c)" ], "metadata": { "id": "dG47Qq7P2dnw" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "But despite this model's limitations, I hope this shows the power of pre-training and how fast you can get something cool and useful up and running.
\n", "\n", "I encourage you to make an account on Hugging Face and push your model to the hub. Learn how here:\n", "https://huggingface.co/docs/transformers/model_sharing\n" ], "metadata": { "id": "NyXBhzQJ29eL" } }, { "cell_type": "markdown", "source": [ "We're just scratching the surface of question answering. Indeed you could dedicate an entire career to it. Areas to explore:\n", "- Right now, we have to supply the context along with the question. A more sophisticated system would load all relevant documents into some database and search over it for an appropriate context/passage and then extract the answer from it. If multiple answers are extracted, then maybe some ranking system can be included as well.\n", "- Another enhancement is extracting answers from different kinds of data. Beyond text, there are images, audio, graphs, tables, charts, etc.\n", "- Abstractive answering involves composing answers (possibly multiple lines) rather than extracting them. Open book means having the system search for the answer first in a database, then composing an answer based on what it's found. Closed book means the model relies on its internal knowledge only. This is what a large language model like GPT-3 would do.\n", "
\n", "\n", "Speaking of GPT-3, let's play with a few more prompts. At this point, we'll switch to the [OpenAI Playground](https://beta.openai.com/playground) and try out the prompts below. Check out the module video for commentary.
\n", "\n", "Before you can try out these prompts yourself, you'll need to open an account an [OpenAI](https://openai.com/api/). At the time of this recording, OpenAI was providing a few dollars of credit to get started. It's more than enough to run the prompts below." ], "metadata": { "id": "vJrXeeql29aS" } }, { "cell_type": "markdown", "source": [ "## GPT-3 prompts to try out:\n", "\n", "---\n", "He said hello
\n", "She said bonjour
\n", "He said goodbye
\n", "She said
\n", "\n", "---\n", "My name is Kilgore. I run a revenge-for-hire business called Delightful Reprisals. I billed my customer the wrong amount. Write an apology email to the customer and sign it with my name and company. Include a reason blaming Elon Musk.\n", "\n", "---\n", "He brought three pies to the office. He gave one to his co-workers, and threw one at his boss' face. How many pies did he have left?\n", "\n", "---\n", "Answer in the style of Jeopardy. He was the 26th President of the United States.\n", "\n", "---\n", "Sarah went to The Mirthless Cafe last night to meet her friend.
\n", "\n", "Print each answer to the following questions on separate lines.
\n", "\n", "Where did Sarah go?
\n", "Who did Sarah meet?
\n", "When did Sarah meet her friend?
\n", "How did Sarah get to the restaurant?
\n", "What's a possible reason why they met?
\n", "\n", "---\n", "What did Marie Antoinette say about avocado toast?\n", "\n", "---\n", "What's a word that rhymes with money?\n", "\n", "---\n", "What's a word that rhymes with Kafkaesque?\n", "\n", "---\n", "Write a haiku about The Terminator.\n", "\n", "---\n", "Write a limerick about The Terminator.\n", "\n", "---" ], "metadata": { "id": "BHicu3dCLzc5" } }, { "cell_type": "markdown", "source": [ "There are open-source generative decoder models available on Hugging Face as well. Among them are:\n", "- [GPT-2](https://huggingface.co/gpt2)\n", "- [EleutherAI/gpt-neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B)\n", "- [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m)\n" ], "metadata": { "id": "_-QEobMwcKeA" } }, { "cell_type": "markdown", "source": [ "**Be on the lookout for GPT-4!**" ], "metadata": { "id": "RqdpsjcIM4ik" } }, { "cell_type": "markdown", "source": [ "# Further Exploration" ], "metadata": { "id": "XKC7lzU_0-gM" } }, { "cell_type": "markdown", "source": [ "OpenAI API docs to learn how to build products using their models:
\n", "https://openai.com/api/
\n", "https://beta.openai.com/docs/introduction" ], "metadata": { "id": "ifJdRjQuKlBF" } }, { "cell_type": "markdown", "source": [ "A catalog of transformer models:
\n", "https://amatriain.net/blog/transformer-models-an-introduction-and-catalog-2d1e9039f376/\n" ], "metadata": { "id": "PGmU69mr14LC" } }, { "cell_type": "markdown", "source": [ "Wordpiece and Sentencepiece:
\n", "https://huggingface.co/course/chapter6/6?fw=pt
\n", "https://github.com/google/sentencepiece\n" ], "metadata": { "id": "6GyYhdkjIEx8" } }, { "cell_type": "markdown", "source": [ "**Papers**
\n", "Attention Is All You Need (original Transformer paper): https://arxiv.org/abs/1706.03762
\n", "\n", "The Annotated Transformer: http://nlp.seas.harvard.edu/annotated-transformer/
\n", "\n", "GPT-3: https://arxiv.org/abs/2005.14165
\n", "\n", "BERT: https://arxiv.org/abs/1810.04805
\n", "\n", "RoBERTa paper: https://arxiv.org/abs/1907.11692
\n", "\n", "ALBERT paper: https://arxiv.org/abs/1909.11942
\n", "\n", "DistilBert paper: https://arxiv.org/abs/1910.01108
\n", "\n", "Electra paper: https://arxiv.org/abs/2003.10555
\n", "\n", "XLM: https://arxiv.org/abs/1901.07291
" ], "metadata": { "id": "X4thcEcxIE4B" } } ] }