{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "hX4n9TsbGw-f" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "0nbI5DtDGw-i" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "9TnJztDZGw-n" }, "source": [ "# Text classification with an RNN" ] }, { "cell_type": "markdown", "metadata": { "id": "AfN3bMR5Gw-o" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook\n
" ] }, { "cell_type": "markdown", "metadata": { "id": "lUWearf0Gw-p" }, "source": [ "This text classification tutorial trains a [recurrent neural network](https://developers.google.com/machine-learning/glossary/#recurrent_neural_network) on the [IMDB large movie review dataset](http://ai.stanford.edu/~amaas/data/sentiment/) for sentiment analysis." ] }, { "cell_type": "markdown", "metadata": { "id": "_2VQo4bajwUU" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vH_FAfIz5dEw" }, "outputs": [], "source": [ "!pip install -q tfds-nightly" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z682XYsrjkY9" }, "outputs": [], "source": [ "import tensorflow_datasets as tfds\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "1rXHa-w9JZhb" }, "source": [ "Import `matplotlib` and create a helper function to plot graphs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mp1Z7P9pYRSK" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "def plot_graphs(history, metric):\n", " plt.plot(history.history[metric])\n", " plt.plot(history.history['val_'+metric], '')\n", " plt.xlabel(\"Epochs\")\n", " plt.ylabel(metric)\n", " plt.legend([metric, 'val_'+metric])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "pRmMubr0jrE2" }, "source": [ "## Setup input pipeline\n", "\n", "\n", "The IMDB large movie review dataset is a *binary classification* dataset—all the reviews have either a *positive* or *negative* sentiment.\n", "\n", "Download the dataset using [TFDS](https://www.tensorflow.org/datasets).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SHRwRoP2nVHX" }, "outputs": [], "source": [ "dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,\n", " as_supervised=True)\n", "train_dataset, test_dataset = dataset['train'], dataset['test']" ] }, { "cell_type": "markdown", "metadata": { "id": "MCorLciXSDJE" }, "source": [ " The dataset `info` includes the encoder (a `tfds.deprecated.text.SubwordTextEncoder`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EplYp5pNnW1S" }, "outputs": [], "source": [ "encoder = info.features['text'].encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e7ACuHM5hFp3" }, "outputs": [], "source": [ "print('Vocabulary size: {}'.format(encoder.vocab_size))" ] }, { "cell_type": "markdown", "metadata": { "id": "tAfGg8YRe6fu" }, "source": [ "This text encoder will reversibly encode any string, falling back to byte-encoding if necessary." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bq6xDmf2SAs-" }, "outputs": [], "source": [ "sample_string = 'Hello TensorFlow.'\n", "\n", "encoded_string = encoder.encode(sample_string)\n", "print('Encoded string is {}'.format(encoded_string))\n", "\n", "original_string = encoder.decode(encoded_string)\n", "print('The original string: \"{}\"'.format(original_string))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TN7QbKaM4-5H" }, "outputs": [], "source": [ "assert original_string == sample_string" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MDVc6UGO5Dh6" }, "outputs": [], "source": [ "for index in encoded_string:\n", "  print('{} ----> {}'.format(index, encoder.decode([index])))" ] }, { "cell_type": "markdown", "metadata": { "id": "GlYWqhTVlUyQ" }, "source": [ "## Prepare the data for training" ] }, { "cell_type": "markdown", "metadata": { "id": "z2qVJzcEluH_" }, "source": [ "Next, create batches of these encoded strings. Use the `padded_batch` method to zero-pad the sequences to the length of the longest string in the batch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dDsCaZCDYZgm" }, "outputs": [], "source": [ "BUFFER_SIZE = 10000\n", "BATCH_SIZE = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VznrltNOnUc5" }, "outputs": [], "source": [ "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", "train_dataset = train_dataset.padded_batch(BATCH_SIZE)\n", "\n", "test_dataset = test_dataset.padded_batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "bjUqGVBxGw-t" }, "source": [ "## Create the model" ] }, { "cell_type": "markdown", "metadata": { "id": "bgs6nnSTGw-t" }, "source": [ "Build a `tf.keras.Sequential` model and start with an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices into sequences of vectors. These vectors are trainable; after training (on enough data), words with similar meanings often have similar vectors.\n", "\n", "This index lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a `tf.keras.layers.Dense` layer.\n", "\n", "A recurrent neural network (RNN) processes sequence input by iterating through the elements. An RNN passes the output from one timestep as input to the next timestep.\n", "\n", "The `tf.keras.layers.Bidirectional` wrapper can also be used with an RNN layer. This propagates the input forward and backward through the RNN layer and then concatenates the outputs. This helps the RNN learn long-range dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwfoBkmRYcP3" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),\n", "    tf.keras.layers.Dense(64, activation='relu'),\n", "    tf.keras.layers.Dense(1)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "QIGmIGkkouUb" }, "source": [ "Note that a Keras sequential model is used here because every layer in the model has a single input and produces a single output. If you want to use a stateful RNN layer, build your model with the Keras functional API or model subclassing instead, so that you can retrieve and reuse the RNN layer states. See the [Keras RNN guide](https://www.tensorflow.org/guide/keras/rnn#rnn_state_reuse) for more details."
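] }, { "cell_type": "markdown", "metadata": { "id": "functional_api_note" }, "source": [ "For reference, the cell below is a minimal functional-API sketch of the same stack (an illustrative addition; it is not used for training in this tutorial). This form is a natural starting point if you later need to expose or reuse RNN layer states:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "functional_api_sketch" }, "outputs": [], "source": [ "# Illustrative sketch: the same architecture expressed with the\n", "# Keras functional API instead of tf.keras.Sequential.\n", "inputs = tf.keras.Input(shape=(None,))\n", "x = tf.keras.layers.Embedding(encoder.vocab_size, 64)(inputs)\n", "x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)\n", "x = tf.keras.layers.Dense(64, activation='relu')(x)\n", "outputs = tf.keras.layers.Dense(1)(x)\n", "functional_model = tf.keras.Model(inputs, outputs)"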
] }, { "cell_type": "markdown", "metadata": { "id": "sRI776ZcH3Tf" }, "source": [ "Compile the Keras model to configure the training process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj2xei41YZjC" }, "outputs": [], "source": [ "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", "              optimizer=tf.keras.optimizers.Adam(1e-4),\n", "              metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "zIwH3nto596k" }, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hw86wWS4YgR2" }, "outputs": [], "source": [ "history = model.fit(train_dataset, epochs=10,\n", "                    validation_data=test_dataset,\n", "                    validation_steps=30)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BaNbXi43YgUT" }, "outputs": [], "source": [ "test_loss, test_acc = model.evaluate(test_dataset)\n", "\n", "print('Test Loss: {}'.format(test_loss))\n", "print('Test Accuracy: {}'.format(test_acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "DwSE_386uhxD" }, "source": [ "The above model does not mask the padding applied to the sequences. This can lead to skew if the model is trained on padded sequences and tested on un-padded sequences. Ideally you would [use masking](../../guide/keras/masking_and_padding) to avoid this, but as you can see below, it only has a small effect on the output.\n", "\n", "If the prediction is >= 0.5, it is positive; otherwise it is negative." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8w0dseJMiEUh" }, "outputs": [], "source": [ "def pad_to_size(vec, size):\n", "  zeros = [0] * (size - len(vec))\n", "  vec.extend(zeros)\n", "  return vec" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-E4cgkIvmVu" }, "outputs": [], "source": [ "def sample_predict(sample_pred_text, pad):\n", "  encoded_sample_pred_text = encoder.encode(sample_pred_text)\n", "\n", "  if pad:\n", "    encoded_sample_pred_text = pad_to_size(encoded_sample_pred_text, 64)\n", "  encoded_sample_pred_text = tf.cast(encoded_sample_pred_text, tf.float32)\n", "  predictions = model.predict(tf.expand_dims(encoded_sample_pred_text, 0))\n", "\n", "  return predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O41gw3KfWHus" }, "outputs": [], "source": [ "# Predict on a sample text without padding.\n", "\n", "sample_pred_text = ('The movie was cool. The animation and the graphics '\n", "                    'were out of this world. I would recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=False)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kFh4xLARucTy" }, "outputs": [], "source": [ "# Predict on a sample text with padding.\n", "\n", "sample_pred_text = ('The movie was cool. The animation and the graphics '\n", "                    'were out of this world. I would recommend this movie.')\n",
"predictions = sample_predict(sample_pred_text, pad=True)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZfIVoxiNmKBF" }, "outputs": [], "source": [ "plot_graphs(history, 'accuracy')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IUzgkqnhmKD2" }, "outputs": [], "source": [ "plot_graphs(history, 'loss')" ] }, { "cell_type": "markdown", "metadata": { "id": "7g1evcaRpTKm" }, "source": [ "## Stack two or more LSTM layers\n", "\n", "Keras recurrent layers have two available modes that are controlled by the `return_sequences` constructor argument:\n", "\n", "* Return the full sequences of successive outputs for each timestep (a 3D tensor of shape `(batch_size, timesteps, output_features)`).\n", "* Return only the last output for each input sequence (a 2D tensor of shape `(batch_size, output_features)`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jo1jjO3vn0jo" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),\n", "    tf.keras.layers.Dense(64, activation='relu'),\n", "    tf.keras.layers.Dropout(0.5),\n", "    tf.keras.layers.Dense(1)\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hEPV5jVGp-is" }, "outputs": [], "source": [ "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", "              optimizer=tf.keras.optimizers.Adam(1e-4),\n", "              metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LeSE-YjdqAeN" }, "outputs": [], "source": [ "history = model.fit(train_dataset, epochs=10,\n", "                    validation_data=test_dataset,\n", "                    validation_steps=30)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LdwilM1qPM3" }, "outputs": [], "source": [ "test_loss, test_acc = model.evaluate(test_dataset)\n", "\n", "print('Test Loss: {}'.format(test_loss))\n", "print('Test Accuracy: {}'.format(test_acc))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ykUKnAoqbycW" }, "outputs": [], "source": [ "# Predict on a sample text without padding.\n", "\n", "sample_pred_text = ('The movie was not good. The animation and the graphics '\n", "                    'were terrible. I would not recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=False)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2RiC-94zvdZO" }, "outputs": [], "source": [ "# Predict on a sample text with padding.\n", "\n", "sample_pred_text = ('The movie was not good. The animation and the graphics '\n", "                    'were terrible. I would not recommend this movie.')\n",
"predictions = sample_predict(sample_pred_text, pad=True)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_YYub0EDtwCu" }, "outputs": [], "source": [ "plot_graphs(history, 'accuracy')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DPV3Nn9xtwFM" }, "outputs": [], "source": [ "plot_graphs(history, 'loss')" ] }, { "cell_type": "markdown", "metadata": { "id": "9xvpE3BaGw_V" }, "source": [ "Check out other existing recurrent layers such as [GRU layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU).\n", "\n", "If you're interested in building custom RNNs, see the [Keras RNN Guide](../../guide/keras/rnn.ipynb).\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "text_classification_rnn.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }