{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "vY4SK0xKAJgm" }, "source": [ "# Model Zoo -- RNN with LSTM with Own Dataset" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "sc6xejhY-NzZ" }, "source": [ "Example notebook showing how to use your own CSV text dataset to train a simple RNN with LSTM (Long Short-Term Memory) cells for sentiment classification (here: a binary classification problem with two labels, positive and negative)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "moNmVfuvnImW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.1.post2\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torchtext import data\n", "from torchtext import datasets\n", "import time\n", "import random\n", "import pandas as pd\n", "\n", "torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GSRL42Qgy8I8" }, "source": [ "## General Settings" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "OvW1RgfepCBq" }, "outputs": [], "source": [ "RANDOM_SEED = 123\n", "torch.manual_seed(RANDOM_SEED)\n", "\n", "VOCABULARY_SIZE = 20000\n", "LEARNING_RATE = 1e-4\n", "BATCH_SIZE = 128\n", "NUM_EPOCHS = 15\n", "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "EMBEDDING_DIM = 128\n", "HIDDEN_DIM = 
256\n", "OUTPUT_DIM = 1" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mQMmKUEisW4W" }, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cells will download the IMDB movie review dataset (http://ai.stanford.edu/~amaas/data/sentiment/) for positive-negative sentiment classification as a CSV-formatted file:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2019-11-28 19:47:46-- https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz\n", "Resolving github.com (github.com)... 140.82.113.3\n", "Connecting to github.com (github.com)|140.82.113.3|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz [following]\n", "--2019-11-28 19:47:46-- https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.184.133\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.184.133|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 26521894 (25M) [application/octet-stream]\n", "Saving to: ‘movie_data.csv.gz’\n", "\n", "movie_data.csv.gz 100%[===================>] 25.29M 10.5MB/s in 2.4s \n", "\n", "2019-11-28 19:47:49 (10.5 MB/s) - ‘movie_data.csv.gz’ saved [26521894/26521894]\n", "\n" ] } ], "source": [ "!wget https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "!gunzip -f movie_data.csv.gz " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check that the dataset looks okay:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr><th></th><th>review</th><th>sentiment</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>0</th><td>In 1974, the teenager Martha Moxley (Maggie Gr...</td><td>1</td></tr>\n",
" <tr><th>1</th><td>OK... so... I really like Kris Kristofferson a...</td><td>0</td></tr>\n",
" <tr><th>2</th><td>***SPOILER*** Do not read this, if you think a...</td><td>0</td></tr>\n",
" <tr><th>3</th><td>hi for all the people who have seen this wonde...</td><td>1</td></tr>\n",
" <tr><th>4</th><td>I recently bought the DVD, forgetting just how...</td><td>0</td></tr>\n",
" </tbody>\n",
"</table>
" ], "text/plain": [ " review sentiment\n", "0 In 1974, the teenager Martha Moxley (Maggie Gr... 1\n", "1 OK... so... I really like Kris Kristofferson a... 0\n", "2 ***SPOILER*** Do not read this, if you think a... 0\n", "3 hi for all the people who have seen this wonde... 1\n", "4 I recently bought the DVD, forgetting just how... 0" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('movie_data.csv')\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "del df" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4GnH64XvsV8n" }, "source": [ "Define the Label and Text field formatters:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "TEXT = data.Field(sequential=True,\n", " tokenize='spacy',\n", " include_lengths=True) # necessary for packed_padded_sequence\n", "\n", "LABEL = data.LabelField(dtype=torch.float)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Process the dataset:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "fields = [('review', TEXT), ('sentiment', LABEL)]\n", "\n", "dataset = data.TabularDataset(\n", " path=\"movie_data.csv\", format='csv',\n", " skip_header=True, fields=fields)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the dataset into training, validation, and test partitions:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "WZ_4jiHVnMxN", "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Num Train: 37500\n", "Num Valid: 10000\n", "Num Test: 2500\n" ] } ], "source": [ "train_data, valid_data, test_data = dataset.split(\n", " split_ratio=[0.75, 0.05, 0.2],\n", " 
random_state=random.seed(RANDOM_SEED))\n", "\n", "print(f'Num Train: {len(train_data)}')\n", "print(f'Num Valid: {len(valid_data)}')\n", "print(f'Num Test: {len(test_data)}')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L-TBwKWPslPa" }, "source": [ "Build the vocabulary based on the top \"VOCABULARY_SIZE\" words:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "e8uNrjdtn4A8", "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vocabulary size: 20002\n", "Number of classes: 2\n" ] } ], "source": [ "TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)\n", "LABEL.build_vocab(train_data)\n", "\n", "print(f'Vocabulary size: {len(TEXT.vocab)}')\n", "print(f'Number of classes: {len(LABEL.vocab)}')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Counter({'0': 18742, '1': 18758})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LABEL.vocab.freqs" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "JpEMNInXtZsb" }, "source": [ "The TEXT.vocab dictionary will contain the word counts and indices. The number of words is VOCABULARY_SIZE + 2 because the vocabulary also contains two special tokens: `<unk>` for unknown words and `<pad>` for padding." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "eIQ_zfKLwjKm" }, "source": [ "Make dataset iterators:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "id": "i7JiHR1stHNF" }, "outputs": [], "source": [ "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n", " (train_data, valid_data, test_data), \n", " batch_size=BATCH_SIZE,\n", " sort_within_batch=True, # necessary for packed_padded_sequence\n", " sort_key=lambda x: len(x.review),\n", " device=DEVICE)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "R0pT_dMRvicQ" }, "source": [ "Testing the iterators (note that the number of rows depends on the longest document in the respective batch):" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "colab_type": "code", "id": "y8SP_FccutT0", "outputId": "fe33763a-4560-4dee-adee-31cc6c48b0b2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train\n", "Text matrix size: torch.Size([512, 128])\n", "Target vector size: torch.Size([128])\n", "\n", "Valid:\n", "Text matrix size: torch.Size([52, 128])\n", "Target vector size: torch.Size([128])\n", "\n", "Test:\n", "Text matrix size: torch.Size([75, 128])\n", "Target vector size: torch.Size([128])\n" ] } ], "source": [ "print('Train')\n", "for batch in train_loader:\n", " print(f'Text matrix size: {batch.review[0].size()}')\n", " print(f'Target vector size: {batch.sentiment.size()}')\n", " break\n", " \n", "print('\\nValid:')\n", "for batch in valid_loader:\n", " print(f'Text matrix size: {batch.review[0].size()}')\n", " print(f'Target vector size: {batch.sentiment.size()}')\n", " break\n", " \n", "print('\\nTest:')\n", "for batch in test_loader:\n", " print(f'Text matrix size: {batch.review[0].size()}')\n", " print(f'Target vector size: {batch.sentiment.size()}')\n", " break" ] }, { "cell_type": "markdown", 
"metadata": { "colab_type": "text", "id": "G_grdW3pxCzz" }, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": {}, "colab_type": "code", "id": "nQIUm5EjxFNa" }, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class RNN(nn.Module):\n", " def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):\n", " \n", " super().__init__()\n", " \n", " self.embedding = nn.Embedding(input_dim, embedding_dim)\n", " self.rnn = nn.LSTM(embedding_dim, hidden_dim)\n", " self.fc = nn.Linear(hidden_dim, output_dim)\n", " \n", " def forward(self, text, text_length):\n", "\n", " #[sentence len, batch size] => [sentence len, batch size, embedding size]\n", " embedded = self.embedding(text)\n", " \n", " packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length)\n", " \n", " #[sentence len, batch size, embedding size] => \n", " # output: [sentence len, batch size, hidden size]\n", " # hidden: [1, batch size, hidden size]\n", " packed_output, (hidden, cell) = self.rnn(packed)\n", " \n", " return self.fc(hidden.squeeze(0)).view(-1)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": {}, "colab_type": "code", "id": "Ik3NF3faxFmZ" }, "outputs": [], "source": [ "INPUT_DIM = len(TEXT.vocab)\n", "\n", "torch.manual_seed(RANDOM_SEED)\n", "model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)\n", "model = model.to(DEVICE)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Lv9Ny9di6VcI" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", "id": "T5t1Afn4xO11" }, "outputs": [], "source": [ "def compute_binary_accuracy(model, data_loader, device):\n", " model.eval()\n", " correct_pred, num_examples = 0, 0\n", " with torch.no_grad():\n", " for batch_idx, batch_data in enumerate(data_loader):\n", " text, text_lengths = 
batch_data.review\n", " logits = model(text, text_lengths)\n", " predicted_labels = (torch.sigmoid(logits) > 0.5).long()\n", " num_examples += batch_data.sentiment.size(0)\n", " correct_pred += (predicted_labels.long() == batch_data.sentiment.long()).sum()\n", " return correct_pred.float()/num_examples * 100" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1836 }, "colab_type": "code", "id": "EABZM8Vo0ilB", "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/015 | Batch 000/293 | Cost: 0.6948\n", "Epoch: 001/015 | Batch 050/293 | Cost: 0.6868\n", "Epoch: 001/015 | Batch 100/293 | Cost: 0.6926\n", "Epoch: 001/015 | Batch 150/293 | Cost: 0.6788\n", "Epoch: 001/015 | Batch 200/293 | Cost: 0.6838\n", "Epoch: 001/015 | Batch 250/293 | Cost: 0.6563\n", "training accuracy: 64.83%\n", "valid accuracy: 64.88%\n", "Time elapsed: 0.29 min\n", "Epoch: 002/015 | Batch 000/293 | Cost: 0.5635\n", "Epoch: 002/015 | Batch 050/293 | Cost: 0.6154\n", "Epoch: 002/015 | Batch 100/293 | Cost: 0.5449\n", "Epoch: 002/015 | Batch 150/293 | Cost: 0.6161\n", "Epoch: 002/015 | Batch 200/293 | Cost: 0.5794\n", "Epoch: 002/015 | Batch 250/293 | Cost: 0.5190\n", "training accuracy: 75.81%\n", "valid accuracy: 75.02%\n", "Time elapsed: 0.57 min\n", "Epoch: 003/015 | Batch 000/293 | Cost: 0.5194\n", "Epoch: 003/015 | Batch 050/293 | Cost: 0.4679\n", "Epoch: 003/015 | Batch 100/293 | Cost: 0.5069\n", "Epoch: 003/015 | Batch 150/293 | Cost: 0.4728\n", "Epoch: 003/015 | Batch 200/293 | Cost: 0.4180\n", "Epoch: 003/015 | Batch 250/293 | Cost: 0.3722\n", "training accuracy: 77.14%\n", "valid accuracy: 76.48%\n", "Time elapsed: 0.85 min\n", "Epoch: 004/015 | Batch 000/293 | Cost: 0.4978\n", "Epoch: 004/015 | Batch 050/293 | Cost: 0.4959\n", "Epoch: 004/015 | Batch 100/293 | Cost: 0.4877\n", "Epoch: 004/015 | Batch 150/293 | Cost: 0.4808\n", 
"Epoch: 004/015 | Batch 200/293 | Cost: 0.4264\n", "Epoch: 004/015 | Batch 250/293 | Cost: 0.3528\n", "training accuracy: 82.63%\n", "valid accuracy: 80.89%\n", "Time elapsed: 1.14 min\n", "Epoch: 005/015 | Batch 000/293 | Cost: 0.3676\n", "Epoch: 005/015 | Batch 050/293 | Cost: 0.3325\n", "Epoch: 005/015 | Batch 100/293 | Cost: 0.4878\n", "Epoch: 005/015 | Batch 150/293 | Cost: 0.4481\n", "Epoch: 005/015 | Batch 200/293 | Cost: 0.4147\n", "Epoch: 005/015 | Batch 250/293 | Cost: 0.4270\n", "training accuracy: 84.73%\n", "valid accuracy: 82.78%\n", "Time elapsed: 1.42 min\n", "Epoch: 006/015 | Batch 000/293 | Cost: 0.4143\n", "Epoch: 006/015 | Batch 050/293 | Cost: 0.4586\n", "Epoch: 006/015 | Batch 100/293 | Cost: 0.3946\n", "Epoch: 006/015 | Batch 150/293 | Cost: 0.3729\n", "Epoch: 006/015 | Batch 200/293 | Cost: 0.3584\n", "Epoch: 006/015 | Batch 250/293 | Cost: 0.4089\n", "training accuracy: 86.25%\n", "valid accuracy: 84.17%\n", "Time elapsed: 1.71 min\n", "Epoch: 007/015 | Batch 000/293 | Cost: 0.3147\n", "Epoch: 007/015 | Batch 050/293 | Cost: 0.3494\n", "Epoch: 007/015 | Batch 100/293 | Cost: 0.2743\n", "Epoch: 007/015 | Batch 150/293 | Cost: 0.3913\n", "Epoch: 007/015 | Batch 200/293 | Cost: 0.2999\n", "Epoch: 007/015 | Batch 250/293 | Cost: 0.2530\n", "training accuracy: 86.61%\n", "valid accuracy: 84.16%\n", "Time elapsed: 1.99 min\n", "Epoch: 008/015 | Batch 000/293 | Cost: 0.3180\n", "Epoch: 008/015 | Batch 050/293 | Cost: 0.3589\n", "Epoch: 008/015 | Batch 100/293 | Cost: 0.3230\n", "Epoch: 008/015 | Batch 150/293 | Cost: 0.3192\n", "Epoch: 008/015 | Batch 200/293 | Cost: 0.3328\n", "Epoch: 008/015 | Batch 250/293 | Cost: 0.2283\n", "training accuracy: 87.09%\n", "valid accuracy: 84.59%\n", "Time elapsed: 2.29 min\n", "Epoch: 009/015 | Batch 000/293 | Cost: 0.3429\n", "Epoch: 009/015 | Batch 050/293 | Cost: 0.3042\n", "Epoch: 009/015 | Batch 100/293 | Cost: 0.2704\n", "Epoch: 009/015 | Batch 150/293 | Cost: 0.2430\n", "Epoch: 009/015 | Batch 200/293 | 
Cost: 0.4137\n", "Epoch: 009/015 | Batch 250/293 | Cost: 0.1736\n", "training accuracy: 74.11%\n", "valid accuracy: 72.36%\n", "Time elapsed: 2.59 min\n", "Epoch: 010/015 | Batch 000/293 | Cost: 0.5759\n", "Epoch: 010/015 | Batch 050/293 | Cost: 0.4807\n", "Epoch: 010/015 | Batch 100/293 | Cost: 0.2686\n", "Epoch: 010/015 | Batch 150/293 | Cost: 0.3420\n", "Epoch: 010/015 | Batch 200/293 | Cost: 0.2759\n", "Epoch: 010/015 | Batch 250/293 | Cost: 0.3928\n", "training accuracy: 89.58%\n", "valid accuracy: 86.27%\n", "Time elapsed: 2.88 min\n", "Epoch: 011/015 | Batch 000/293 | Cost: 0.2417\n", "Epoch: 011/015 | Batch 050/293 | Cost: 0.3175\n", "Epoch: 011/015 | Batch 100/293 | Cost: 0.2029\n", "Epoch: 011/015 | Batch 150/293 | Cost: 0.2389\n", "Epoch: 011/015 | Batch 200/293 | Cost: 0.3107\n", "Epoch: 011/015 | Batch 250/293 | Cost: 0.3486\n", "training accuracy: 90.21%\n", "valid accuracy: 86.52%\n", "Time elapsed: 3.17 min\n", "Epoch: 012/015 | Batch 000/293 | Cost: 0.2540\n", "Epoch: 012/015 | Batch 050/293 | Cost: 0.2851\n", "Epoch: 012/015 | Batch 100/293 | Cost: 0.1901\n", "Epoch: 012/015 | Batch 150/293 | Cost: 0.2286\n", "Epoch: 012/015 | Batch 200/293 | Cost: 0.3239\n", "Epoch: 012/015 | Batch 250/293 | Cost: 0.2856\n", "training accuracy: 90.72%\n", "valid accuracy: 86.78%\n", "Time elapsed: 3.47 min\n", "Epoch: 013/015 | Batch 000/293 | Cost: 0.1913\n", "Epoch: 013/015 | Batch 050/293 | Cost: 0.2547\n", "Epoch: 013/015 | Batch 100/293 | Cost: 0.3984\n", "Epoch: 013/015 | Batch 150/293 | Cost: 0.2294\n", "Epoch: 013/015 | Batch 200/293 | Cost: 0.2692\n", "Epoch: 013/015 | Batch 250/293 | Cost: 0.2132\n", "training accuracy: 91.51%\n", "valid accuracy: 87.13%\n", "Time elapsed: 3.76 min\n", "Epoch: 014/015 | Batch 000/293 | Cost: 0.1699\n", "Epoch: 014/015 | Batch 050/293 | Cost: 0.2611\n", "Epoch: 014/015 | Batch 100/293 | Cost: 0.2594\n", "Epoch: 014/015 | Batch 150/293 | Cost: 0.2062\n", "Epoch: 014/015 | Batch 200/293 | Cost: 0.2608\n", "Epoch: 014/015 | 
Batch 250/293 | Cost: 0.2881\n", "training accuracy: 91.43%\n", "valid accuracy: 86.93%\n", "Time elapsed: 4.05 min\n", "Epoch: 015/015 | Batch 000/293 | Cost: 0.2522\n", "Epoch: 015/015 | Batch 050/293 | Cost: 0.2753\n", "Epoch: 015/015 | Batch 100/293 | Cost: 0.2322\n", "Epoch: 015/015 | Batch 150/293 | Cost: 0.2361\n", "Epoch: 015/015 | Batch 200/293 | Cost: 0.3728\n", "Epoch: 015/015 | Batch 250/293 | Cost: 0.2895\n", "training accuracy: 89.71%\n", "valid accuracy: 85.54%\n", "Time elapsed: 4.34 min\n", "Total Training Time: 4.34 min\n", "Test accuracy: 86.88%\n" ] } ], "source": [ "start_time = time.time()\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " model.train()\n", " for batch_idx, batch_data in enumerate(train_loader):\n", " \n", " text, text_lengths = batch_data.review\n", " \n", " ### FORWARD AND BACK PROP\n", " logits = model(text, text_lengths)\n", " cost = F.binary_cross_entropy_with_logits(logits, batch_data.sentiment)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n", " f'Cost: {cost:.4f}')\n", "\n", " with torch.set_grad_enabled(False):\n", " print(f'training accuracy: '\n", " f'{compute_binary_accuracy(model, train_loader, DEVICE):.2f}%'\n", " f'\\nvalid accuracy: '\n", " f'{compute_binary_accuracy(model, valid_loader, DEVICE):.2f}%')\n", " \n", " print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n", " \n", "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n", "print(f'Test accuracy: {compute_binary_accuracy(model, test_loader, DEVICE):.2f}%')" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": {}, "colab_type": "code", "id": "jt55pscgFdKZ" }, "outputs": [], "source": [ "import spacy\n", "nlp = spacy.load('en')\n", "\n", "def 
predict_sentiment(model, sentence):\n", " # based on:\n", " # https://github.com/bentrevett/pytorch-sentiment-analysis/blob/\n", " # master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb\n", " model.eval()\n", " tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n", " indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n", " length = [len(indexed)]\n", " tensor = torch.LongTensor(indexed).to(DEVICE)\n", " tensor = tensor.unsqueeze(1)\n", " length_tensor = torch.LongTensor(length)\n", " prediction = torch.sigmoid(model(tensor, length_tensor))\n", " return prediction.item()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "O4__q0coFJyw", "outputId": "1a7f84ba-a977-455e-e248-3b7036d496d0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability positive:\n" ] }, { "data": { "text/plain": [ "0.8258040845394135" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Probability positive:')\n", "1-predict_sentiment(model, \"This is such an awesome movie, I really love it!\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability negative:\n" ] }, { "data": { "text/plain": [ "0.8462136387825012" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Probability negative:')\n", "predict_sentiment(model, \"I really hate this movie. 
It is really bad and sucks!\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": {}, "colab_type": "code", "id": "7lRusB3dF80X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch 1.0.1.post2\n", "pandas 0.23.4\n", "spacy 2.0.16\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "rnn_lstm_packed_imdb.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }