{ "cells": [ { "cell_type": "markdown", "id": "13165aa1-fe26-481f-bb08-3d1fac826041", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": "73ec4c5e-654b-4690-8e78-4741d28639c3", "diskcache": false, "headerColor": "inherit", "id": "b85b379c-83f2-4c6c-996c-ded984bab2c8", "isComponent": false, "name": "", "parents": [] }, "tags": [] }, "source": [ "# Sequential Data Classification\n", "Adapted from : https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/2_lstm.ipynb\n", "- Link Component Color Annotations\n", " - Yellow : data load / preprocessing\n", " - Violet : model train / predict" ] }, { "cell_type": "markdown", "id": "ace54a3c-e4ad-4b0b-9d9a-697742de73b7", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "b57118ae-bfa2-46fe-8b9d-fee2d726c371", "isComponent": false, "name": "", "parents": [] } }, "source": [ "### Required Python Packages\n", "- `tqdm`\n", "- `torch`\n", "- `torchtext`\n", "- `datasets`\n", "- `matplotlib`\n", "\n", "Run the following cell to install the packages." 
] }, { "cell_type": "code", "execution_count": null, "id": "bd51f254-866d-444a-abb9-9de45249c3c0", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "c83a6d5a-1d50-441e-8dbd-e69bb82b6bde", "isComponent": false, "name": "", "parents": [] } }, "outputs": [], "source": [ "#\n", "# Required Packages\n", "# Run this cell to install required packages.\n", "#\n", "%pip install \"datasets>=2.2\" \"matplotlib>=2.0\" \"torch>=1.9\" \"torchtext>=0.12\" \"tqdm>=4.64\" " ] }, { "cell_type": "markdown", "id": "9f57ecf9-db3e-4ec7-8c83-89a07b760fdb", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": "bf338eaf-cf47-4680-bc99-9cf3bd0fa5fd", "diskcache": false, "headerColor": "inherit", "id": "a5334a1c-ddda-4f4b-8f5b-6f40dc035d3d", "isComponent": false, "name": "", "parents": [] } }, "source": [ "### 0. Global Parameters\n", "- Global parameters of the Link pipeline\n", " - seed : torch random seed\n", " - max_length : maximum sequence length (in tokens)\n", " - test_size : validation split ratio\n", " - min_freq : minimum token frequency for inclusion in the vocabulary\n", " - embedding_dim : dimension of the embedding layer\n", " - hidden_dim : hidden size of the RNN layer\n", " - lr : training learning rate\n", " - batch_size : training batch size\n", " - n_epochs : number of training epochs" ] }, { "cell_type": "markdown", "id": "f0d86820-9565-4f47-90a0-da2e0b982cc2", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": "a16d8c78-8820-4f89-87c5-dfbdd31e7189", "diskcache": false, "headerColor": "inherit", "id": "f0b3d714-4564-405a-8034-6ce10eda9622", "isComponent": false, "name": "", "parents": [] } }, "source": ["### 1. 
Load package,data"] }, { "cell_type": "code", "execution_count": null, "id": "b672f09d-fcc2-4331-bb83-3e4483673e87", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "8d792aa6-cac6-4ba6-a4a4-1a5566cfbb81", "isComponent": true, "name": "Import packages", "parents": [] }, "tags": [] }, "outputs": [], "source": [ "import functools\n", "import sys\n", "\n", "import datasets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torchtext\n", "import tqdm\n", "from datasets import Dataset, DatasetDict\n", "\n", "_ = torch.manual_seed(seed)" ] }, { "cell_type": "code", "execution_count": null, "id": "2685f1a6-8e78-4479-aef7-b5bf36b5b2d0", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "e6449370-f87c-470a-813c-6686f1988e13", "isComponent": true, "name": "Load data", "parents": [ { "id": "8d792aa6-cac6-4ba6-a4a4-1a5566cfbb81", "name": "Import packages" } ] }, "tags": [] }, "outputs": [], "source": [ "train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])" ] }, { "cell_type": "markdown", "id": "b329be11-5e4d-4083-9eaa-3a170aceb753", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "14b952e3-b3ad-429b-be52-9b1954883f27", "isComponent": false, "name": "", "parents": [] }, "tags": [] }, "source": ["### 2. 
Prepare Data"] }, { "cell_type": "code", "execution_count": null, "id": "6c4e6dfb-f6ae-4341-9e84-8f9885bed7f3", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "c7543bc0-670a-49bd-a070-25c610f2c36f", "isComponent": true, "name": "Load tokenizer", "parents": [ { "id": "8d792aa6-cac6-4ba6-a4a4-1a5566cfbb81", "name": "Import packages" } ] }, "tags": [] }, "outputs": [], "source": [ "tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d53a51eb-8730-41cd-8e85-d084ffa3f070", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "3568006d-7cb0-47e4-b56a-59c1d5ed5f9d", "isComponent": true, "name": "Tokenize data", "parents": [ { "id": "e6449370-f87c-470a-813c-6686f1988e13", "name": "Load data" }, { "id": "c7543bc0-670a-49bd-a070-25c610f2c36f", "name": "Load tokenizer" } ] }, "tags": [] }, "outputs": [], "source": [ "def tokenize_data(example, tokenizer, max_length):\n", " tokens = tokenizer(example[\"text\"])[:max_length]\n", " length = len(tokens)\n", " return {\"tokens\": tokens, \"length\": length}\n", "\n", "\n", "train_data = train_data.map(tokenize_data, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length})\n", "test_data = test_data.map(tokenize_data, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length})" ] }, { "cell_type": "code", "execution_count": null, "id": "5f520507-a21c-4701-a8ed-47054ed3ad49", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "2f143cce-7b1e-4d14-b3ad-93ba9afc436b", "isComponent": true, "name": "Sampling and split data", "parents": [ { "id": "3568006d-7cb0-47e4-b56a-59c1d5ed5f9d", "name": "Tokenize data" } ] }, "tags": [] }, "outputs": [], "source": [ 
"# Subsample IMDB for a quick run. Sampling and the train/valid split are seeded\n",
"# with `seed` so the subset (and therefore the vocabulary) is reproducible.\n",
"train_data_df = train_data.to_pandas().sample(n=3000, random_state=seed)\n",
"train_data = Dataset.from_pandas(train_data_df, preserve_index=False)\n",
"\n",
"test_data_df = test_data.to_pandas().sample(n=2000, random_state=seed)\n",
"test_data = Dataset.from_pandas(test_data_df, preserve_index=False)\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size, seed=seed)\n",
"train_data = train_valid_data[\"train\"]\n",
"valid_data = train_valid_data[\"test\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "57d7f30e-c940-4c9c-b780-d567331a773a", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "87489719-b806-424b-90b4-55bbfadd1102", "isComponent": true, "name": "Set vocab", "parents": [ { "id": "2f143cce-7b1e-4d14-b3ad-93ba9afc436b", "name": "Sampling and split data" } ] }, "tags": [] }, "outputs": [], "source": [
"# The special tokens must be distinct strings; otherwise unk_index and pad_index\n",
"# collide. build_vocab_from_iterator places them first (special_first=True).\n",
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
"\n",
"# Build the vocabulary from the training split only, so no valid/test tokens leak in.\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
"    train_data[\"tokens\"],\n",
"    min_freq=min_freq,\n",
"    specials=special_tokens,\n",
")\n",
"\n",
"unk_index = vocab[\"<unk>\"]\n",
"pad_index = vocab[\"<pad>\"]\n",
"\n",
"# Out-of-vocabulary tokens map to <unk> instead of raising.\n",
"vocab.set_default_index(unk_index)" ] }, { "cell_type": "code", "execution_count": null, "id": "1694f0f3-b040-49d9-8948-ab400f94aadc", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "77e482c4-435f-40e3-832f-7b8487dec28e", "isComponent": true, "name": "Preprocessing data", "parents": [ { "id": "87489719-b806-424b-90b4-55bbfadd1102", "name": "Set vocab" } ] }, "tags": [] }, "outputs": [], "source": [
"def numericalize_data(example, vocab):\n",
"    \"\"\"Map each token in example[\\\"tokens\\\"] to its vocabulary index.\"\"\"\n",
"    ids = [vocab[token] for token in example[\"tokens\"]]\n",
"    return {\"ids\": ids}\n",
"\n",
"\n",
"train_data = train_data.map(numericalize_data, fn_kwargs={\"vocab\": vocab})\n",
"valid_data = valid_data.map(numericalize_data, fn_kwargs={\"vocab\": vocab})\n",
"test_data = test_data.map(numericalize_data, 
fn_kwargs={\"vocab\": vocab})\n", "\n", "train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\", \"length\"])\n", "valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\", \"length\"])\n", "test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\", \"length\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "3a3881e6-3d16-4995-b977-5e6900454070", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#FAFF00", "id": "24f8e6d3-56bf-4f50-8561-2408cc91ce4d", "isComponent": true, "name": "Set dataloader", "parents": [ { "id": "77e482c4-435f-40e3-832f-7b8487dec28e", "name": "Preprocessing data" } ] }, "tags": [] }, "outputs": [], "source": [ "def collate(batch, pad_index):\n", " batch_ids = [i[\"ids\"] for i in batch]\n", " batch_ids = nn.utils.rnn.pad_sequence(batch_ids, padding_value=pad_index, batch_first=True)\n", " batch_length = [i[\"length\"] for i in batch]\n", " batch_length = torch.stack(batch_length)\n", " batch_label = [i[\"label\"] for i in batch]\n", " batch_label = torch.stack(batch_label)\n", " batch = {\"ids\": batch_ids, \"length\": batch_length, \"label\": batch_label}\n", " return batch\n", "\n", "\n", "collate = functools.partial(collate, pad_index=pad_index)\n", "\n", "train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, collate_fn=collate, shuffle=True)\n", "\n", "valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate)\n", "test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate)" ] }, { "cell_type": "markdown", "id": "279ad833-6a00-442a-b8a8-2b48cdb93cee", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "3c3814f9-c014-4d53-86ce-d73f3a8cff5c", "isComponent": false, "name": "", 
"parents": [] } }, "source": ["### 3. Modeling"] }, { "cell_type": "code", "execution_count": null, "id": "840f8f40-e433-40c5-af35-67747d50fe0b", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#6C00FF", "id": "c9662ed2-0038-442e-86b9-42a8df99ae79", "isComponent": true, "name": "Define model", "parents": [ { "id": "24f8e6d3-56bf-4f50-8561-2408cc91ce4d", "name": "Set dataloader" } ] }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": ["The model has 1,809,398 trainable parameters\n"] } ], "source": [
"class RNN(nn.Module):\n",
"    \"\"\"Single-layer vanilla RNN (nn.RNN) sentiment classifier.\n",
"\n",
"    Args:\n",
"        input_dim: vocabulary size (number of embedding rows).\n",
"        embedding_dim: size of each token embedding.\n",
"        hidden_dim: size of the RNN hidden state.\n",
"        output_dim: number of output classes.\n",
"        pad_index: vocabulary index of the padding token; its embedding row is\n",
"            initialized to zero and receives no gradient.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, pad_index):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_index)\n",
"        # batch_first matches the [batch, seq] layout produced by `collate`.\n",
"        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)\n",
"        self.fc = nn.Linear(hidden_dim, output_dim)\n",
"\n",
"    def forward(self, ids, length):\n",
"        \"\"\"Return class logits of shape [batch size, output_dim].\n",
"\n",
"        ids: padded token ids, [batch size, seq len].\n",
"        length: true sequence lengths, [batch size]; must be on the CPU.\n",
"        \"\"\"\n",
"        embedded = self.embedding(ids)\n",
"        # embedded = [batch size, seq len, emb dim]\n",
"        # Packing makes the RNN stop at each sequence's last real token, so\n",
"        # `hidden` is never computed from padding.\n",
"        packed_embedded = nn.utils.rnn.pack_padded_sequence(\n",
"            embedded, length, batch_first=True, enforce_sorted=False\n",
"        )\n",
"        _, hidden = self.rnn(packed_embedded)\n",
"        # hidden = [1, batch size, hid dim]\n",
"        return self.fc(hidden.squeeze(0))\n",
"\n",
"\n",
"vocab_size = len(vocab)\n",
"output_dim = 2  # binary sentiment\n",
"\n",
"model = RNN(vocab_size, embedding_dim, hidden_dim, output_dim, pad_index)\n",
"\n",
"\n",
"def count_parameters(model):\n",
"    \"\"\"Return the number of trainable (requires_grad) parameters in `model`.\"\"\"\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")" ] }, { "cell_type": "code", 
"execution_count": null, "id": "5cdb0b6c-48f9-4b67-8d9e-90823628614f", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#6C00FF", "id": "d4b1ae9a-a2cb-4ff1-a56b-61070b61e7b5", "isComponent": true, "name": "Initialize weights", "parents": [ { "id": "c9662ed2-0038-442e-86b9-42a8df99ae79", "name": "Define model" } ] }, "tags": [] }, "outputs": [], "source": [
"def initialize_weights(m):\n",
"    \"\"\"Initialize a module's weights in place; intended for use with `model.apply`.\n",
"\n",
"    Linear layers get Xavier-normal weights and zero biases. Recurrent layers\n",
"    (nn.RNN / nn.LSTM / nn.GRU) get orthogonal weight matrices and zero biases.\n",
"    Other modules (e.g. the embedding) keep their default initialization.\n",
"    \"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.xavier_normal_(m.weight)\n",
"        if m.bias is not None:\n",
"            nn.init.zeros_(m.bias)\n",
"    # The model uses nn.RNN; checking only nn.LSTM (as in the LSTM notebook this\n",
"    # was adapted from) silently skipped the recurrent layer.\n",
"    elif isinstance(m, (nn.RNN, nn.LSTM, nn.GRU)):\n",
"        for name, param in m.named_parameters():\n",
"            if \"bias\" in name:\n",
"                nn.init.zeros_(param)\n",
"            elif \"weight\" in name:\n",
"                nn.init.orthogonal_(param)\n",
"\n",
"\n",
"_ = model.apply(initialize_weights)" ] }, { "cell_type": "code", "execution_count": null, "id": "e95f1ea6-3eef-443a-a606-ed31fd58d92f", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#6C00FF", "id": "7ac2a4c4-5258-40b0-b942-9ed914507e23", "isComponent": true, "name": "Define training functions", "parents": [ { "id": "8d792aa6-cac6-4ba6-a4a4-1a5566cfbb81", "name": "Import packages" } ] }, "tags": [] }, "outputs": [], "source": [
"def train(dataloader, model, criterion, optimizer, device):\n",
"    \"\"\"Run one training epoch; return per-batch lists of losses and accuracies.\"\"\"\n",
"    model.train()\n",
"    epoch_losses = []\n",
"    epoch_accs = []\n",
"\n",
"    for batch in tqdm.tqdm(dataloader, desc=\"training...\", file=sys.stdout):\n",
"        ids = batch[\"ids\"].to(device)\n",
"        # Lengths stay on the CPU, as pack_padded_sequence requires.\n",
"        length = batch[\"length\"]\n",
"        label = batch[\"label\"].to(device)\n",
"        prediction = model(ids, length)\n",
"        loss = criterion(prediction, label)\n",
"        accuracy = get_accuracy(prediction, label)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        epoch_losses.append(loss.item())\n",
"        epoch_accs.append(accuracy.item())\n",
"\n",
"    return epoch_losses, epoch_accs\n",
"\n",
"\n",
"def 
evaluate(dataloader, model, criterion, device):\n", "\n", " model.eval()\n", " epoch_losses = []\n", " epoch_accs = []\n", "\n", " with torch.no_grad():\n", " for batch in tqdm.tqdm(dataloader, desc=\"evaluating...\", file=sys.stdout):\n", " ids = batch[\"ids\"].to(device)\n", " length = batch[\"length\"]\n", " label = batch[\"label\"].to(device)\n", " prediction = model(ids, length)\n", " loss = criterion(prediction, label)\n", " accuracy = get_accuracy(prediction, label)\n", " epoch_losses.append(loss.item())\n", " epoch_accs.append(accuracy.item())\n", "\n", " return epoch_losses, epoch_accs\n", "\n", "\n", "def get_accuracy(prediction, label):\n", " batch_size, _ = prediction.shape\n", " predicted_classes = prediction.argmax(dim=-1)\n", " correct_predictions = predicted_classes.eq(label).sum()\n", " accuracy = correct_predictions / batch_size\n", " return accuracy" ] }, { "cell_type": "code", "execution_count": null, "id": "d888856d-f8d1-4c0f-9817-de4a81f38c48", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#6C00FF", "id": "49b7e440-bc55-4a07-8e56-14ac7a5c3f9a", "isComponent": true, "name": "Train", "parents": [ { "id": "d4b1ae9a-a2cb-4ff1-a56b-61070b61e7b5", "name": "Initialize weights" }, { "id": "7ac2a4c4-5258-40b0-b942-9ed914507e23", "name": "Define training functions" } ] }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:58<00:00, 23.71s/it]\n", "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.60it/s]\n", "epoch: 1\n", "train_loss: 0.788, train_acc: 0.521\n", "valid_loss: 0.749, valid_acc: 0.545\n", "training...: 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:55<00:00, 23.05s/it]\n", "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.60it/s]\n", "epoch: 2\n", "train_loss: 0.682, train_acc: 0.592\n", "valid_loss: 0.736, valid_acc: 0.534\n", "training...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:55<00:00, 23.04s/it]\n", "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.59it/s]\n", "epoch: 3\n", "train_loss: 0.641, train_acc: 0.637\n", "valid_loss: 0.708, valid_acc: 0.553\n", "training...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:55<00:00, 23.19s/it]\n", "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.48it/s]\n", "epoch: 4\n", "train_loss: 0.604, train_acc: 0.671\n", "valid_loss: 0.707, valid_acc: 0.586\n", "training...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:55<00:00, 23.16s/it]\n", "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.59it/s]\n", "epoch: 5\n", "train_loss: 0.567, train_acc: 0.706\n", "valid_loss: 0.717, valid_acc: 0.576\n" ] } ], "source": [ "optimizer = optim.Adam(model.parameters(), lr=lr)\n", "criterion = nn.CrossEntropyLoss()\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = model.to(device)\n", "\n", "best_valid_loss = float(\"inf\")\n", "\n", "train_losses = []\n", 
"train_accs = []\n", "valid_losses = []\n", "valid_accs = []\n", "\n", "for epoch in range(n_epochs):\n", "\n", " train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n", " valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n", "\n", " train_losses.extend(train_loss)\n", " train_accs.extend(train_acc)\n", " valid_losses.extend(valid_loss)\n", " valid_accs.extend(valid_acc)\n", "\n", " epoch_train_loss = np.mean(train_loss)\n", " epoch_train_acc = np.mean(train_acc)\n", " epoch_valid_loss = np.mean(valid_loss)\n", " epoch_valid_acc = np.mean(valid_acc)\n", "\n", " if epoch_valid_loss < best_valid_loss:\n", " best_valid_loss = epoch_valid_loss\n", " torch.save(model.state_dict(), \"rnn.pt\")\n", "\n", " print(f\"epoch: {epoch+1}\")\n", " print(f\"train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}\")\n", " print(f\"valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5dd7cecb-ed99-46eb-8c08-dcc8a9e92f0e", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "#6C00FF", "id": "6459e1c5-6660-4663-9d00-d39212d4670a", "isComponent": true, "name": "Plot loss and accuracy", "parents": [ { "id": "49b7e440-bc55-4a07-8e56-14ac7a5c3f9a", "name": "Train" } ] }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": ["
"] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": ["
"] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": ["Text(17.200000000000003, 0.5, 'accuracy')"] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fig = plt.figure(figsize=(10, 6))\n", "ax = fig.add_subplot(1, 1, 1)\n", "ax.plot(train_losses, label=\"train loss\")\n", "ax.plot(valid_losses, label=\"valid loss\")\n", "plt.legend()\n", "ax.set_xlabel(\"updates\")\n", "ax.set_ylabel(\"loss\")\n", "\n", "fig = plt.figure(figsize=(10, 6))\n", "ax = fig.add_subplot(1, 1, 1)\n", "ax.plot(train_accs, label=\"train accuracy\")\n", "ax.plot(valid_accs, label=\"valid accuracy\")\n", "plt.legend()\n", "ax.set_xlabel(\"updates\")\n", "ax.set_ylabel(\"accuracy\");" ] }, { "cell_type": "code", "execution_count": null, "id": "7d0e4458-64ea-47c4-b082-2f0fe5ce8d19", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "51f1e556-f037-46b1-a60b-477bd3188e2b", "isComponent": true, "name": "Test model", "parents": [ { "id": "49b7e440-bc55-4a07-8e56-14ac7a5c3f9a", "name": "Train" } ] }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "evaluating...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.17it/s]\n", "test_loss: 0.748, test_acc: 0.532\n" ] } ], "source": [ "# model.load_state_dict(torch.load('rnn.pt'))\n", "\n", "test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n", "\n", "epoch_test_loss = np.mean(test_loss)\n", "epoch_test_acc = np.mean(test_acc)\n", "\n", "print(f\"test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "57f8be70-5aaa-46a0-979b-e25ad9a985c7", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": 
false, "headerColor": "inherit", "id": "8e5a5c08-e0f6-4581-874f-6c60fa12cc8a", "isComponent": true, "name": "Define predict sentiment func", "parents": [ { "id": "8d792aa6-cac6-4ba6-a4a4-1a5566cfbb81", "name": "Import packages" } ] }, "tags": [] }, "outputs": [], "source": [ "def predict_sentiment(text, model, tokenizer, vocab, device):\n", " tokens = tokenizer(text)\n", " ids = [vocab[t] for t in tokens]\n", " length = torch.LongTensor([len(ids)])\n", " tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n", " prediction = model(tensor, length).squeeze(dim=0)\n", " probability = torch.softmax(prediction, dim=-1)\n", " predicted_class = prediction.argmax(dim=-1).item()\n", " predicted_probability = probability[predicted_class].item()\n", " return predicted_class, predicted_probability" ] }, { "cell_type": "code", "execution_count": null, "id": "5daf5a5c-8efb-4c61-a676-1e40c3017542", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "30454386-fc46-4851-813b-110a3a23a334", "isComponent": true, "name": "Test example1", "parents": [ { "id": "8e5a5c08-e0f6-4581-874f-6c60fa12cc8a", "name": "Define predict sentiment func" }, { "id": "51f1e556-f037-46b1-a60b-477bd3188e2b", "name": "Test model" } ] }, "tags": [] }, "outputs": [ { "data": { "text/plain": ["(0, 0.523607611656189)"] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"This film is terrible!\"\n", "\n", "predict_sentiment(text, model, tokenizer, vocab, device)" ] }, { "cell_type": "code", "execution_count": null, "id": "e4eeded2-c2c3-42d9-8b3a-50fd2458b64d", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "a3978ff9-806c-4387-9552-3d9e572ea1c6", "isComponent": true, "name": "Test example2", "parents": [ { "id": "8e5a5c08-e0f6-4581-874f-6c60fa12cc8a", "name": 
"Define predict sentiment func" }, { "id": "51f1e556-f037-46b1-a60b-477bd3188e2b", "name": "Test model" } ] }, "tags": [] }, "outputs": [ { "data": { "text/plain": ["(1, 0.7643947005271912)"] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"This film is great!\"\n", "\n", "predict_sentiment(text, model, tokenizer, vocab, device)" ] }, { "cell_type": "code", "execution_count": null, "id": "fbe9b425-7cae-45f3-80ca-0fa31da24c98", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "be312074-9a73-42cc-a24d-b5f533a0dc8a", "isComponent": true, "name": "Test example3", "parents": [ { "id": "51f1e556-f037-46b1-a60b-477bd3188e2b", "name": "Test model" }, { "id": "8e5a5c08-e0f6-4581-874f-6c60fa12cc8a", "name": "Define predict sentiment func" } ] }, "tags": [] }, "outputs": [ { "data": { "text/plain": ["(1, 0.8956218957901001)"] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"This film is not terrible, it's great!\"\n", "\n", "predict_sentiment(text, model, tokenizer, vocab, device)" ] }, { "cell_type": "code", "execution_count": null, "id": "2fb6b36b-723b-49d2-bd91-c3b69daf8a98", "metadata": { "canvas": { "comments": [], "componentType": "CodeCell", "copiedOriginId": null, "diskcache": false, "headerColor": "inherit", "id": "631f2528-8252-4339-9224-35ae145047d4", "isComponent": true, "name": "Test example4", "parents": [ { "id": "51f1e556-f037-46b1-a60b-477bd3188e2b", "name": "Test model" }, { "id": "8e5a5c08-e0f6-4581-874f-6c60fa12cc8a", "name": "Define predict sentiment func" } ] }, "tags": [] }, "outputs": [ { "data": { "text/plain": ["(1, 0.6588938236236572)"] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"This film is not great, it's terrible!\"\n", "\n", "predict_sentiment(text, model, tokenizer, vocab, device)" ] 
} ], "metadata": { "canvas": { "colorPalette": [ "inherit", "inherit", "inherit", "inherit", "inherit", "inherit", "inherit", "inherit", "inherit", "inherit" ], "parameters": [ { "name": "seed", "type": "int", "value": "0" }, { "name": "max_length", "type": "int", "value": "256" }, { "name": "test_size", "type": "float", "value": "0.25" }, { "name": "min_freq", "type": "int", "value": "5" }, { "name": "embedding_dim", "type": "int", "value": "300" }, { "name": "hidden_dim", "type": "int", "value": "256" }, { "name": "lr", "type": "float", "value": "5e-4" }, { "name": "batch_size", "type": "int", "value": "512" }, { "name": "n_epochs", "type": "int", "value": "5" } ], "version": "1.0" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }