{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text classification with a transformer model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_transformer_text_classification.ipynb)\n", "\n", "In this short tutorial, we will perform sentiment analysis on movie reviews. We would like\n", "to know whether a review is positive or negative overall, which makes this a binary classification task.\n", "\n", "For that purpose, we're going to build a model composed of a transformer block and dense layers.\n", "What makes transformers special is their attention mechanism, which lets the model focus on\n", "the parts of the input sequence that are most relevant when processing each token.\n", "\n", "We will follow the standard ML model-building process: split the data into training and test sets,\n", "construct the model, run training, and report metrics." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Required packages\n", "# !pip install -U jax flax optax\n", "# !pip install -U grain tqdm requests matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tools overview\n", "\n", "Here's a list of the key packages we will use in this example, most of which belong to the JAX AI stack:\n", "\n", "- [JAX](https://github.com/jax-ml/jax) for array computations.\n", "- [Flax](https://github.com/google/flax) for constructing neural networks.\n", "- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.\n", "- [Grain](https://github.com/google/grain/) for defining data sources and data loaders.\n", "- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress." 
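] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The attention mechanism described in the introduction can be illustrated with a few lines of JAX.\n", "The next cell is a minimal, single-head sketch of scaled dot-product attention,\n", "`softmax(Q K^T / sqrt(d)) V`, written for illustration only. The helper name\n", "`scaled_dot_product_attention` and the random toy inputs are assumptions; this is not the\n", "internal code of `nnx.MultiHeadAttention`, which additionally projects the inputs into\n", "several heads. Each row of `weights` shows how strongly one position attends to every\n", "position in the sequence, and each row sums to 1." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def scaled_dot_product_attention(q, k, v):\n", "    # Hypothetical single-head helper: q is (seq_q, d), k is (seq_k, d), v is (seq_k, d_v)\n", "    d = q.shape[-1]\n", "    scores = q @ k.T / jnp.sqrt(d)  # Scaled similarity of every query to every key\n", "    weights = jax.nn.softmax(scores, axis=-1)  # Normalize over the keys\n", "    return weights @ v, weights\n", "\n", "\n", "# Toy inputs: 4 positions with feature size 8, drawn at random\n", "key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)\n", "q = jax.random.normal(key_q, (4, 8))\n", "k = jax.random.normal(key_k, (4, 8))\n", "v = jax.random.normal(key_v, (4, 8))\n", "out, weights = scaled_dot_product_attention(q, k, v)\n", "print(out.shape, weights.shape)" 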
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import io\n", "import json\n", "import textwrap\n", "import typing\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx\n", "import optax\n", "import grain.python as grain\n", "import tqdm\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import requests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We're going to use the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).\n", "Each review is encoded as a list of word indices, where each word is indexed by its overall\n", "frequency rank in the dataset (more frequent words get lower indices). The label is the\n", "sentiment of the review: 1 for positive and 0 for negative." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def prepare_imdb_dataset(num_words: int, index_from: int, oov_char: int = 2) -> tuple:\n", "    response = requests.get(\"https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\")\n", "    response.raise_for_status()\n", "    with np.load(io.BytesIO(response.content), allow_pickle=True) as f:\n", "        x_train, y_train = f[\"x_train\"], f[\"y_train\"]\n", "        x_test, y_test = f[\"x_test\"], f[\"y_test\"]\n", "\n", "    # Shuffle both splits with a fixed seed for reproducibility\n", "    rng = np.random.RandomState(113)\n", "    indices = np.arange(len(x_train))\n", "    rng.shuffle(indices)\n", "    x_train = x_train[indices]\n", "    y_train = y_train[indices]\n", "\n", "    indices = np.arange(len(x_test))\n", "    rng.shuffle(indices)\n", "    x_test = x_test[indices]\n", "    y_test = y_test[indices]\n", "\n", "    # Shift every word index so that the indices below index_from stay reserved\n", "    x_train = [[w + index_from for w in x] for x in x_train]\n", "    x_test = [[w + index_from for w in x] for x in x_test]\n", "\n", "    xs = x_train + x_test\n", "    labels = np.concatenate([y_train, y_test])\n", "    # Replace rare words (index >= num_words) with the out-of-vocabulary marker\n", "    xs = [\n", "        [w if w < num_words else oov_char for w in x] for x in xs\n", "    ]\n", "\n", "    # Split back into the original train and test partitions\n", "    idx = len(x_train)\n", "    x_train, y_train = np.array(xs[:idx], dtype=\"object\"), labels[:idx]\n", "    x_test, y_test = np.array(xs[idx:], dtype=\"object\"), 
labels[idx:]\n", "\n", "    return (x_train, y_train), (x_test, y_test)\n", "\n", "\n", "def pad_sequences(arrs: typing.Iterable, max_len: int) -> np.ndarray:\n", "    # Make every sample exactly max_len long: left-pad short sequences with 0\n", "    # and keep only the last max_len tokens of longer ones\n", "    result = []\n", "    for arr in arrs:\n", "        arr_len = len(arr)\n", "        if arr_len < max_len:\n", "            padded_arr = np.pad(arr, (max_len - arr_len, 0), 'constant', constant_values=0)\n", "        else:\n", "            padded_arr = np.array(arr[arr_len - max_len:])\n", "        result.append(padded_arr)\n", "\n", "    return np.asarray(result)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25000 Training sequences\n", "25000 Test sequences\n" ] } ], "source": [ "index_from = 3  # Reserve 0 for padding and 2 for out-of-vocabulary words\n", "vocab_size = 20000  # Only consider the top 20k words\n", "maxlen = 200  # Only keep the last 200 words of each movie review\n", "(x_train, y_train), (x_test, y_test) = prepare_imdb_dataset(num_words=vocab_size, index_from=index_from)\n", "print(len(x_train), \"Training sequences\")\n", "print(len(x_test), \"Test sequences\")\n", "x_train = pad_sequences(x_train, max_len=maxlen)\n", "x_test = pad_sequences(x_test, max_len=maxlen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To handle input data we're going to use Grain, a pure Python package developed for JAX and\n", "Flax models. Grain supports custom setups where data sources can come in different forms, but\n", "they all need to implement the `grain.RandomAccessDataSource` interface. 
See\n", "[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)\n", "for more details.\n", "\n", "Our dataset consists of relatively small NumPy arrays, so our `DataSource` is simple:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class DataSource(grain.RandomAccessDataSource):\n", "    def __init__(self, x, y):\n", "        self._x = x\n", "        self._y = y\n", "\n", "    def __getitem__(self, idx):\n", "        # Return one example as a dictionary of its token indices and its label\n", "        return {\"encoded_indices\": self._x[idx], \"label\": self._y[idx]}\n", "\n", "    def __len__(self):\n", "        return len(self._x)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_source = DataSource(x_train, y_train)\n", "test_source = DataSource(x_test, y_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "seed = 12\n", "train_batch_size = 128\n", "test_batch_size = 2 * train_batch_size\n", "\n", "train_sampler = grain.IndexSampler(\n", "    len(train_source),\n", "    shuffle=True,\n", "    seed=seed,\n", "    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup\n", "    num_epochs=1,  # Iterate over the dataset for one epoch\n", ")\n", "\n", "test_sampler = grain.IndexSampler(\n", "    len(test_source),\n", "    shuffle=False,\n", "    seed=seed,\n", "    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup\n", "    num_epochs=1,  # Iterate over the dataset for one epoch\n", ")\n", "\n", "\n", "train_loader = grain.DataLoader(\n", "    data_source=train_source,\n", "    sampler=train_sampler,  # Sampler to determine how to access the data\n", "    worker_count=4,  # Number of child processes used to parallelize the data transformations\n", "    worker_buffer_size=2,  # Count of output batches to produce in advance per worker\n", "    operations=[\n", "        grain.Batch(train_batch_size, drop_remainder=True),\n", "    ]\n", ")\n", "\n", "test_loader = grain.DataLoader(\n", "    
data_source=test_source,\n", "    sampler=test_sampler,  # Sampler to determine how to access the data\n", "    worker_count=4,  # Number of child processes used to parallelize the data transformations\n", "    worker_buffer_size=2,  # Count of output batches to produce in advance per worker\n", "    operations=[\n", "        grain.Batch(test_batch_size),\n", "    ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "\n", "Here we construct the model from a token-and-position embedding layer, a transformer block, and dense layers:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class TransformerBlock(nnx.Module):\n", "    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, rngs: nnx.Rngs, rate: float = 0.1):\n", "        self.attention = nnx.MultiHeadAttention(\n", "            num_heads=num_heads, in_features=embed_dim, qkv_features=embed_dim, decode=False, rngs=rngs\n", "        )\n", "\n", "        self.dense_1 = nnx.Linear(in_features=embed_dim, out_features=ff_dim, rngs=rngs)\n", "        # Project back to embed_dim so the output can be added to the residual x_norm_1\n", "        self.dense_2 = nnx.Linear(in_features=ff_dim, out_features=embed_dim, rngs=rngs)\n", "\n", "        self.layer_norm_1 = nnx.LayerNorm(num_features=embed_dim, epsilon=1e-6, rngs=rngs)\n", "        self.layer_norm_2 = nnx.LayerNorm(num_features=embed_dim, epsilon=1e-6, rngs=rngs)\n", "\n", "        self.dropout_1 = nnx.Dropout(rate, rngs=rngs)\n", "        self.dropout_2 = nnx.Dropout(rate, rngs=rngs)\n", "\n", "    def __call__(self, inputs: jax.Array):\n", "        # Self-attention sublayer with a residual connection and layer normalization\n", "        x = self.attention(inputs, inputs)\n", "        x = self.dropout_1(x)\n", "        x_norm_1 = self.layer_norm_1(inputs + x)\n", "        # Position-wise feed-forward sublayer with a residual connection and layer normalization\n", "        x = self.dense_1(x_norm_1)\n", "        x = jax.nn.relu(x)\n", "        x = self.dense_2(x)\n", "        x = self.dropout_2(x)\n", "        x = self.layer_norm_2(x_norm_1 + x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class TokenAndPositionEmbedding(nnx.Module):\n", "    def __init__(self, max_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):\n", "        self.token_emb = nnx.Embed(num_embeddings=vocab_size, 
features=embed_dim, rngs=rngs)\n", "        self.pos_emb = nnx.Embed(num_embeddings=max_length, features=embed_dim, rngs=rngs)\n", "\n", "    def __call__(self, x: jax.Array):\n", "        # Add a learned embedding for every position 0..seq_len-1; it broadcasts over the batch\n", "        seq_len = jnp.shape(x)[-1]\n", "        positions = jnp.arange(start=0, stop=seq_len, step=1)\n", "        positions = self.pos_emb(positions)\n", "        x = self.token_emb(x)\n", "        return x + positions" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "embed_dim = 32  # Embedding size for each token\n", "num_heads = 2  # Number of attention heads\n", "ff_dim = 32  # Hidden layer size in the feed-forward network inside the transformer block\n", "\n", "class MyModel(nnx.Module):\n", "    def __init__(self, rngs: nnx.Rngs):\n", "        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)\n", "        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim, rngs=rngs)\n", "        self.dropout1 = nnx.Dropout(0.1, rngs=rngs)\n", "        self.dense1 = nnx.Linear(in_features=embed_dim, out_features=20, rngs=rngs)\n", "        self.dropout2 = nnx.Dropout(0.1, rngs=rngs)\n", "        self.dense2 = nnx.Linear(in_features=20, out_features=2, rngs=rngs)\n", "\n", "    def __call__(self, x: jax.Array):\n", "        x = self.embedding_layer(x)\n", "        x = self.transformer_block(x)\n", "        x = jnp.mean(x, axis=(1,))  # Global average pooling over the sequence axis\n", "        x = self.dropout1(x)\n", "        x = self.dense1(x)\n", "        x = jax.nn.relu(x)\n", "        x = self.dropout2(x)\n", "        x = self.dense2(x)\n", "        x = jax.nn.softmax(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MyModel(\n", " embedding_layer=TokenAndPositionEmbedding(\n", " token_emb=Embed(\n", " embedding=Param(\n", " value=Array(shape=(20000, 32), dtype=float32)\n", " ),\n", " num_embeddings=20000,\n", " features=32,\n", " dtype=dtype('float32'),\n", " param_dtype=,\n", " embedding_init=.init at 0x7f0cab141e10>\n", " ),\n", " pos_emb=Embed(\n", " embedding=Param(\n", " 
value=Array(shape=(200, 32), dtype=float32)\n", " ),\n", " num_embeddings=200,\n", " features=32,\n", " dtype=dtype('float32'),\n", " param_dtype=,\n", " embedding_init=.init at 0x7f0cab141e10>\n", " )\n", " ),\n", " transformer_block=TransformerBlock(\n", " attention=MultiHeadAttention(\n", " num_heads=2,\n", " in_features=32,\n", " qkv_features=32,\n", " out_features=32,\n", " dtype=None,\n", " param_dtype=,\n", " broadcast_dropout=True,\n", " dropout_rate=0.0,\n", " deterministic=None,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " out_kernel_init=None,\n", " bias_init=,\n", " out_bias_init=None,\n", " use_bias=True,\n", " attention_fn=,\n", " decode=False,\n", " normalize_qk=False,\n", " qkv_dot_general=None,\n", " out_dot_general=None,\n", " qkv_dot_general_cls=None,\n", " out_dot_general_cls=None,\n", " head_dim=16,\n", " query=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " key=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " value=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " 
dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " query_ln=None,\n", " key_ln=None,\n", " out=LinearGeneral(\n", " in_features=(2, 16),\n", " out_features=(32,),\n", " axis=(-2, -1),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(2, 16, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " )\n", " ),\n", " rngs=None,\n", " cached_key=None,\n", " cached_value=None,\n", " cache_index=None\n", " ),\n", " dense_1=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=32,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " dense_2=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=32,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " layer_norm_1=LayerNorm(\n", " scale=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " num_features=32,\n", " epsilon=1e-06,\n", " dtype=None,\n", " param_dtype=,\n", 
" use_bias=True,\n", " use_scale=True,\n", " bias_init=,\n", " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", " axis_index_groups=None,\n", " use_fast_variance=True\n", " ),\n", " layer_norm_2=LayerNorm(\n", " scale=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " num_features=32,\n", " epsilon=1e-06,\n", " dtype=None,\n", " param_dtype=,\n", " use_bias=True,\n", " use_scale=True,\n", " bias_init=,\n", " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", " axis_index_groups=None,\n", " use_fast_variance=True\n", " ),\n", " dropout_1=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(\n", " default=RngStream(\n", " key=RngKey(\n", " value=Array((), dtype=key) overlaying:\n", " [0 0],\n", " tag='default'\n", " ),\n", " count=RngCount(\n", " value=Array(22, dtype=uint32),\n", " tag='default'\n", " )\n", " )\n", " )),\n", " dropout_2=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...))\n", " ),\n", " dropout1=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...)),\n", " dense1=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 20), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(20,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=20,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " dropout2=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...)),\n", " dense2=Linear(\n", " kernel=Param(\n", " value=Array(shape=(20, 2), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2,), dtype=float32)\n", " ),\n", " in_features=20,\n", " out_features=2,\n", 
" use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " )\n", ")\n" ] } ], "source": [ "model = MyModel(rngs=nnx.Rngs(0))\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "To train our model, we construct an `nnx.Optimizer` object from the model and an Optax\n", "optimization algorithm. We're going to use the Adam optimizer, which is a popular choice for\n", "deep learning models. Adam adapts the effective step size of each parameter using running\n", "estimates of the first and second moments of the gradients. The decay rate of the first-moment\n", "estimate (`b1`, passed below as `momentum`) acts as a momentum term that helps accelerate convergence." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", "learning_rate = 0.0001\n", "momentum = 0.9  # Passed to optax.adam as b1, the first-moment decay rate\n", "\n", "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array):\n", "    logits = model(batch_tokens)\n", "\n", "    loss = optax.softmax_cross_entropy_with_integer_labels(\n", "        logits=logits, labels=labels\n", "    ).mean()\n", "    return loss, logits" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def train_step(\n", "    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, jax.Array]\n", "):\n", "    batch_tokens = jnp.array(batch[\"encoded_indices\"])\n", "    labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n", "\n", "    # Differentiate the loss w.r.t. the model parameters; the logits are auxiliary output\n", "    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)\n", "    (loss, logits), grads = grad_fn(model, batch_tokens, labels)\n", "\n", "    optimizer.update(grads)  # In-place updates.\n", "\n", "    return loss\n", "\n", "@nnx.jit\n", "def eval_step(\n", "    model: nnx.Module, batch: dict[str, jax.Array], eval_metrics: nnx.MultiMetric\n", "):\n", "    batch_tokens 
= jnp.array(batch[\"encoded_indices\"])\n", "    labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n", "    loss, logits = compute_losses_and_logits(model, batch_tokens, labels)\n", "\n", "    eval_metrics.update(\n", "        loss=loss,\n", "        logits=logits,\n", "        labels=labels,\n", "    )" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "eval_metrics = nnx.MultiMetric(\n", "    loss=nnx.metrics.Average('loss'),\n", "    accuracy=nnx.metrics.Accuracy(),\n", ")\n", "\n", "train_metrics_history = {\n", "    \"train_loss\": [],\n", "}\n", "\n", "eval_metrics_history = {\n", "    \"test_loss\": [],\n", "    \"test_accuracy\": [],\n", "}" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "bar_format = \"{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]\"\n", "train_total_steps = len(x_train) // train_batch_size\n", "\n", "def train_one_epoch(epoch):\n", "    model.train()  # Enable dropout\n", "    with tqdm.tqdm(\n", "        desc=f\"[train] epoch: {epoch}/{num_epochs}, \",\n", "        total=train_total_steps,\n", "        bar_format=bar_format,\n", "        leave=True,\n", "    ) as pbar:\n", "        for batch in train_loader:\n", "            loss = train_step(model, optimizer, batch)\n", "            train_metrics_history[\"train_loss\"].append(loss.item())\n", "            pbar.set_postfix({\"loss\": loss.item()})\n", "            pbar.update(1)\n", "\n", "\n", "def evaluate_model(epoch):\n", "    # Compute the metrics on the test set after each training epoch.\n", "    model.eval()  # Disable dropout\n", "\n", "    eval_metrics.reset()  # Reset the eval metrics\n", "    for test_batch in test_loader:\n", "        eval_step(model, test_batch, eval_metrics)\n", "\n", "    for metric, value in eval_metrics.compute().items():\n", "        eval_metrics_history[f'test_{metric}'].append(value)\n", "\n", "    print(f\"[test] epoch: {epoch + 1}/{num_epochs}\")\n", "    print(f\"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}\")\n", "    print(f\"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}\")" ] }, { "cell_type": "code", 
"execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 0/10, [192/195], loss=0.693 [00:36<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 1/10\n", "- total loss: 0.6879\n", "- Accuracy: 0.5661\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 1/10, [192/195], loss=0.678 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 2/10\n", "- total loss: 0.6734\n", "- Accuracy: 0.6507\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 2/10, [192/195], loss=0.594 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 3/10\n", "- total loss: 0.6177\n", "- Accuracy: 0.7316\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 3/10, [192/195], loss=0.511 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 4/10\n", "- total loss: 0.5404\n", "- Accuracy: 0.7890\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 4/10, [192/195], loss=0.466 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 5/10\n", "- total loss: 0.4995\n", "- Accuracy: 0.8159\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 5/10, [192/195], loss=0.463 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 6/10\n", "- total loss: 0.4806\n", "- Accuracy: 0.8280\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 6/10, [192/195], loss=0.435 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 7/10\n", "- total loss: 0.4714\n", "- Accuracy: 0.8349\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 7/10, [192/195], loss=0.421 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "[test] epoch: 8/10\n", "- total loss: 0.4665\n", "- Accuracy: 0.8394\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 8/10, [192/195], loss=0.409 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 9/10\n", "- total loss: 0.4602\n", "- Accuracy: 0.8453\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 9/10, [192/195], loss=0.389 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 10/10\n", "- total loss: 0.4560\n", "- Accuracy: 0.8486\n", "CPU times: user 27min 44s, sys: 12min 14s, total: 39min 59s\n", "Wall time: 8min 19s\n" ] } ], "source": [ "%%time\n", "for epoch in range(num_epochs):\n", "    train_one_epoch(epoch)\n", "    evaluate_model(epoch)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(train_metrics_history[\"train_loss\"], label=\"Loss value during the training\")\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "axs[0].set_title(\"Loss value on test set\")\n", "axs[0].plot(eval_metrics_history[\"test_loss\"])\n", "axs[1].set_title(\"Accuracy on test set\")\n", "axs[1].plot(eval_metrics_history[\"test_accuracy\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the model reaches almost 85% accuracy (0.8486) on the test set after 10 epochs\n", "of training on this sentiment classification problem. The test loss improved only slightly over\n", "the last few epochs (from 0.4714 after epoch 7 to 0.4560 after epoch 10), which suggests that\n", "the model is close to convergence.\n", "\n", "The JAX AI stack allowed us to build, train, and evaluate the model with JAX ecosystem\n", "libraries that interoperate smoothly in a multitude of computing environments.\n", "\n", "Let's now inspect a few predictions ourselves. First we're going to download the IMDB word\n", "index and invert it into an index-to-word dictionary to decode our samples. Then we will decode\n", "a few test reviews and print each of them together with its predicted and actual label." 
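] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before decoding real reviews, the index arithmetic can be illustrated on a toy example.\n", "The next cell is a minimal sketch that assumes the encoding conventions of\n", "`prepare_imdb_dataset` above: word ranks from the word index are shifted by `index_from = 3`,\n", "`0` marks padding, and `2` marks an out-of-vocabulary word. The `toy_word_index` dictionary\n", "and the `decode_toy` helper are hypothetical stand-ins for the downloaded word index and the\n", "decoding loop in `show_reviews`; unlike that loop, the sketch renders reserved indices as\n", "placeholder tokens instead of looking them up in the word map. Running it prints\n", "`the movie was [UNK] great`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "TOY_INDEX_FROM = 3  # Assumed to match index_from above\n", "TOY_OOV_CHAR = 2  # Assumed to match the oov_char default of prepare_imdb_dataset\n", "toy_word_index = {'the': 1, 'movie': 2, 'was': 3, 'great': 4}  # Hypothetical word -> rank map\n", "toy_index_to_word = {rank: word for word, rank in toy_word_index.items()}\n", "\n", "\n", "def decode_toy(encoded):\n", "    # Drop padding, mark reserved indices explicitly, and undo the index_from shift\n", "    words = []\n", "    for w in (int(t) for t in encoded):\n", "        if w == 0:\n", "            continue\n", "        if w < TOY_INDEX_FROM:\n", "            words.append('[UNK]' if w == TOY_OOV_CHAR else '[RESERVED]')\n", "        else:\n", "            words.append(toy_index_to_word.get(w - TOY_INDEX_FROM, '[UNK]'))\n", "    return ' '.join(words)\n", "\n", "\n", "# Two padding tokens followed by: the movie was [UNK] great\n", "toy_encoded = np.array([0, 0, 4, 5, 6, 2, 7])\n", "print(decode_toy(toy_encoded))" 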
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "response = requests.get(\"https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json\")\n", "response.raise_for_status()\n", "# Invert the word -> index mapping into an index -> word mapping\n", "word_map = {v: k for k, v in json.loads(response.content).items()}" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def _label_to_str(label: int) -> str:\n", "    return \"positive\" if label else \"negative\"\n", "\n", "def show_reviews(indices: list[int]) -> None:\n", "    for idx in indices:\n", "        x = x_test[idx][x_test[idx] != 0]  # Drop the padding tokens\n", "        y = y_test[idx]\n", "        y_pred = model(x_test[idx][None, :]).argmax()  # Predict on the padded sequence, as in training\n", "        review = \"\"\n", "        for w_x in x:\n", "            # Undo the index_from shift; reserved indices below index_from are looked up as-is\n", "            word_idx = w_x - index_from if w_x >= index_from else w_x\n", "            review += f\"{word_map[word_idx]} \"\n", "\n", "        print(\"Review:\")\n", "        for line in textwrap.wrap(review):\n", "            print(line)\n", "        print(\"Predicted sentiment: \", _label_to_str(y_pred))\n", "        print(\"Actual sentiment: \", _label_to_str(y), \"\\n\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Review:\n", "please give this one a miss br br kristy swanson and the rest of the\n", "cast rendered terrible performances the show is flat flat flat br br i\n", "don't know how michael madison could have allowed this one on his\n", "plate he almost seemed to know this wasn't going to work out and his\n", "performance was quite lacklustre so all you madison fans give this a\n", "miss\n", "Predicted sentiment: negative\n", "Actual sentiment: negative \n", "\n", "Review:\n", "this is a funny movie the bob eddie show feel of it could lead to a\n", "sequel but i doubt it will make enough money br br deniro proves he\n", "can be a great straight man again with some hilarious and spontaneous\n", "moments eddie was fun to watch working with people instead of cgi\n", "animals and and rene russo well she's just 
fun to watch anyway and\n", "she's played her part excellent br br some wild and unusual stunts\n", "especially the garbage truck scene this was worth seeing in the\n", "theater we needed a good laugh and got many from the movie and the\n", "great out takes at the end do not leave at the start of the credits br\n", "br at least a 7\n", "Predicted sentiment: positive\n", "Actual sentiment: positive \n", "\n", "Review:\n", "terrible absolutely terrible long confusing and and after about three\n", "hours of this painful mess the ending truly is the final nail in the\n", "coffin not even the magnificent sexy beautiful goddess and and can\n", "save this poor adaptation of agatha and work the plot drags and drags\n", "and time goes by slowly and suddenly you realize that you don't even\n", "have any idea of what's going on anymore by the end even with the\n", "usual explanation by the villain there's still a lot that's left\n", "unexplained and and it's over a complete waste of time and without a\n", "doubt one of the worst to bear the name of agatha christie\n", "Predicted sentiment: negative\n", "Actual sentiment: negative \n", "\n", "Review:\n", "whoa this is one of the worst movies i have ever seen the packaging\n", "for the film is better than the film itself my girlfriend and i\n", "watched it this past weekend and we only continued to watch it in the\n", "hopes that it would get better it didn't br br the picture quality is\n", "poor it looks like it was shot on video and transferred to film the\n", "lighting is not great which makes it harder to read the actors' facial\n", "expressions the acting itself was cheesy but i guess it's acceptable\n", "for yet another teenage horror flick the sound was a huge problem\n", "sometimes you have to rewind the video because the sound is unclear\n", "and or and br br it holds no real merit of it's own trying to ride on\n", "the and of sleepy hollow don't bother with this one\n", "Predicted sentiment: negative\n", "Actual 
sentiment: negative \n", "\n", "Review:\n", "i am a big gone with the wind nut but i was disappointed that both\n", "gone with the wind the movie and scarlett the mini series are so\n", "different from the books gone with the wind left so many things out in\n", "the movie that were in the book and they did the same with scarlett\n", "both were good movies but i really liked both books better there were\n", "so many characters left out of scarlett and the ages of some\n", "characters didn't seem to match up with the book the time lines don't\n", "match up either scarlett realizes she is pregnant on the ship to\n", "ireland in the book but she realizes it when she is throwing up while\n", "in also sally is made out to be an ugly monkey like woman in the book\n", "and the movie casted jean smart to play her who is obviously not an\n", "ugly woman over all scarlett is a good movie and it helps anyone who\n", "was disappointed in the way gone with the wind ended to see what might\n", "have happened if margaret mitchell had lived to write a sequel herself\n", "Predicted sentiment: negative\n", "Actual sentiment: positive \n", "\n", "Review:\n", "if you liked the richard chamberlain version of the bourne identity\n", "then you will like this too aiden quinn does this one brilliantly you\n", "can't help but wonder if he is really out there i reckon he and the\n", "other main cast members probably had nightmares for weeks after doing\n", "this movie as it's so intense when i first saw it i was just and\n", "channels on the remote late one evening i got hooked within minutes\n", "look up www answers com for and and who is the character that carlos\n", "the is based on for both i remember reading about and arrest in the\n", "paper in 1997 it was front page for weeks through the trial after his\n", "arrest\n", "Predicted sentiment: positive\n", "Actual sentiment: positive \n", "\n" ] } ], "source": [ "show_reviews([0, 500, 600, 1000, 1800, 2000])" ] } ], "metadata": { 
"jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "jax-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }