{ "cells": [ { "cell_type": "markdown", "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744", "metadata": {}, "source": [ "# Machine Translation with encoder-decoder transformer model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)" ] }, { "cell_type": "markdown", "id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b", "metadata": {}, "source": [ "This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself adapted from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition).\n", "\n", "We step through an encoder-decoder transformer in JAX and train a model for English-to-Spanish translation." ] }, { "cell_type": "code", "execution_count": 1, "id": "dd506ffa-3b91-44f1-92d1-a08ed933e78e", "metadata": {}, "outputs": [], "source": [ "import pathlib\n", "import random\n", "import string\n", "import re\n", "import numpy as np\n", "\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "from flax import nnx\n", "\n", "import tiktoken\n", "import grain.python as grain\n", "import tqdm" ] }, { "cell_type": "markdown", "id": "e1f324b0-140a-48fa-9fcb-d6308f098343", "metadata": {}, "source": [ "## Pull down data to temp and extract into memory\n", "\n", "There are many ways to do this. For simplicity and clear visibility into what's happening, the archive is downloaded to a temporary directory, extracted there, and read into a Python list of sentence pairs." 
] }, { "cell_type": "code", "execution_count": 2, "id": "102943a5-8724-48e0-8d6a-f56069f03426", "metadata": {}, "outputs": [], "source": [
"import requests\n",
"import zipfile\n",
"import tempfile\n",
"\n",
"url = \"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\"\n",
"\n",
"with tempfile.TemporaryDirectory() as temp_dir:\n",
"    temp_path = pathlib.Path(temp_dir)\n",
"    zip_file_path = temp_path / \"spa-eng.zip\"\n",
"\n",
"    # Download the archive and fail early on an HTTP error\n",
"    response = requests.get(url)\n",
"    response.raise_for_status()\n",
"    zip_file_path.write_bytes(response.content)\n",
"\n",
"    with zipfile.ZipFile(zip_file_path, \"r\") as zip_ref:\n",
"        zip_ref.extractall(temp_path)\n",
"\n",
"    text_file = temp_path / \"spa-eng\" / \"spa.txt\"\n",
"\n",
"    # Read everything into memory before the temporary directory is removed\n",
"    with open(text_file, encoding=\"utf-8\") as f:\n",
"        lines = f.read().split(\"\\n\")[:-1]\n",
"\n",
"# Each line is tab-separated: English first, then Spanish; wrap the target in [start] / [end] markers\n",
"text_pairs = []\n",
"for line in lines:\n",
"    eng, spa = line.split(\"\\t\")\n",
"    spa = \"[start] \" + spa + \" [end]\"\n",
"    text_pairs.append((eng, spa))" ] }, { "cell_type": "markdown", "id": "9524904b-fa17-493f-bcfa-335963cb7c45", "metadata": {}, "source": [ "## Build train/validate/test pair sets\n", "We stay close to the original tutorial so it's easy to see what is the same and what is different. One early difference is the choice of an off-the-shelf tokenizer from tiktoken, specifically `cl100k_base`, which covers a wide range of languages and is fast." 
] }, { "cell_type": "code", "execution_count": 3, "id": "bee9f1b0-5f74-47dc-a7e1-a4ea3be1ef7f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "118964 total pairs\n", "83276 training pairs\n", "17844 validation pairs\n", "17844 test pairs\n" ] } ], "source": [ "random.shuffle(text_pairs)\n", "num_val_samples = int(0.15 * len(text_pairs))\n", "num_train_samples = len(text_pairs) - 2 * num_val_samples\n", "train_pairs = text_pairs[:num_train_samples]\n", "val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]\n", "test_pairs = text_pairs[num_train_samples + num_val_samples :]\n", "\n", "print(f\"{len(text_pairs)} total pairs\")\n", "print(f\"{len(train_pairs)} training pairs\")\n", "print(f\"{len(val_pairs)} validation pairs\")\n", "print(f\"{len(test_pairs)} test pairs\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "fd1f61fe-e4b7-479d-917e-f609ebe482e9", "metadata": {}, "outputs": [], "source": [ "tokenizer = tiktoken.get_encoding(\"cl100k_base\")" ] }, { "cell_type": "markdown", "id": "a714c4ea-9ff6-4dab-ae9c-1a884d4857e7", "metadata": {}, "source": [ "We strip out punctuation to keep things simple and in line with the original tutorial - the `[` `]` are kept in so that our `[start]` and `[end]` formatting is preserved." 
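,
"\n",
"\n",
"As a small illustration of the stripping rule described above, the sketch below applies the same character class to a made-up sentence. The sample string is an assumption for illustration only; the regex mirrors the `custom_standardization` helper defined in the cells that follow.\n",
"\n",
"```python\n",
"import re\n",
"import string\n",
"\n",
"# Same rule as in this notebook: drop punctuation and \"¿\", but keep \"[\" and \"]\"\n",
"strip_chars = (string.punctuation + \"¿\").replace(\"[\", \"\").replace(\"]\", \"\")\n",
"\n",
"sample = \"[start] ¿Dónde está, Tom? [end]\"  # hypothetical example sentence\n",
"cleaned = re.sub(f\"[{re.escape(strip_chars)}]\", \"\", sample.lower())\n",
"print(cleaned)  # prints: [start] dónde está tom [end]\n",
"```\n",
"\n",
"Lowercasing happens before stripping, and the brackets survive, so the markers stay intact for the tokenizer."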
] }, { "cell_type": "code", "execution_count": 5, "id": "07e054d3-a20c-4aed-8f8a-fb5158df8e5b", "metadata": {}, "outputs": [], "source": [
"strip_chars = string.punctuation + \"¿\"\n",
"strip_chars = strip_chars.replace(\"[\", \"\")\n",
"strip_chars = strip_chars.replace(\"]\", \"\")\n",
"\n",
"vocab_size = tokenizer.n_vocab\n",
"sequence_length = 20" ] }, { "cell_type": "code", "execution_count": 6, "id": "e2b3e5b3-8466-4c81-99da-0559c88b25ef", "metadata": {}, "outputs": [], "source": [
"def custom_standardization(input_string):\n",
"    lowercase = input_string.lower()\n",
"    return re.sub(f\"[{re.escape(strip_chars)}]\", \"\", lowercase)" ] }, { "cell_type": "code", "execution_count": 7, "id": "5bdc0673-9723-45b5-8a42-2152295df69b", "metadata": {}, "outputs": [], "source": [
"def tokenize_and_pad(text, tokenizer, max_length):\n",
"    # Assumes encode() returns a plain list of ints, as in current tiktoken\n",
"    # (https://github.com/openai/tiktoken/blob/main/tiktoken/core.py#L81)\n",
"    tokens = tokenizer.encode(text)[:max_length]\n",
"    # Right-pad with token id 0 up to max_length; nothing is added when the list is already full\n",
"    return tokens + [0] * (max_length - len(tokens))" ] }, { "cell_type": "code", "execution_count": 8, "id": "235b1221-e72d-4793-addd-7bb870bd8e75", "metadata": {}, "outputs": [], "source": [
"def format_dataset(eng, spa, tokenizer, sequence_length):\n",
"    eng = custom_standardization(eng)\n",
"    spa = custom_standardization(spa)\n",
"    eng = tokenize_and_pad(eng, tokenizer, sequence_length)\n",
"    spa = tokenize_and_pad(spa, tokenizer, sequence_length)\n",
"    return {\n",
"        \"encoder_inputs\": eng,\n",
"        \"decoder_inputs\": spa[:-1],\n",
"        \"target_output\": spa[1:],\n",
"    }" ] }, { "cell_type": "code", "execution_count": 9, "id": "ca013d07-1504-42cc-906f-2fcacc757008", "metadata": {}, "outputs": [], "source": [ "train_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in train_pairs]\n", "val_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in val_pairs]\n", "test_data = 
[format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in test_pairs]" ] }, { "cell_type": "markdown", "id": "90bbae98-48dd-4ae4-99bb-92336d7c0a1c", "metadata": {}, "source": [ "At this point we've extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in train/validate/test sets that each have dictionary entries, which look like the following:" ] }, { "cell_type": "code", "execution_count": 10, "id": "dcbfa780-553f-41f6-8b3e-55955db78b2a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'encoder_inputs': [72, 1390, 311, 617, 264, 3137, 449, 1461, 922, 856, 3938, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0], 'target_output': [60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0, 0]}\n" ] } ], "source": [ "## data selection example\n", "print(train_data[135])" ] }, { "cell_type": "markdown", "id": "24c6271b-e359-4aba-a583-f18c40eddba9", "metadata": {}, "source": [ "The output should look something like\n", "\n", "{'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]}" ] }, { "cell_type": "markdown", "id": "7a906a05-bd17-4a47-afe0-4422d2ea0f50", "metadata": {}, "source": [ "## Define Transformer components: Encoder, Decoder, Positional Embed\n", "\n", "In many ways this is very similar to the original source, with `ops` changing to `jnp` and `keras` or `layers` becoming `nnx`. Certain module-specific arguments come and go, like the rngs attached to most things in the updated version, and decode=False in the MultiHeadAttention call." 
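,
"\n",
"\n",
"To make the two NNX-specific arguments mentioned above concrete, here is a minimal standalone sketch of an attention layer. The sizes are hypothetical and chosen only to keep it quick; the constructor arguments are the same ones the encoder and decoder classes in the next cell pass.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from flax import nnx\n",
"\n",
"# Hypothetical small sizes, not the values used for training in this notebook\n",
"attention = nnx.MultiHeadAttention(\n",
"    num_heads=2,\n",
"    in_features=8,\n",
"    decode=False,      # no autoregressive cache; full sequences are passed every time\n",
"    rngs=nnx.Rngs(0),  # NNX modules receive their RNG streams at construction time\n",
")\n",
"\n",
"x = jnp.ones((1, 4, 8))  # (batch, sequence, features)\n",
"out = attention(inputs_q=x, inputs_k=x, inputs_v=x)\n",
"print(out.shape)\n",
"```\n",
"\n",
"With default settings the output is expected to keep the shape of the query input, which is what lets the residual connections in the classes below add `inputs + attention_output` directly."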
] }, { "cell_type": "code", "execution_count": 11, "id": "121bf138-34b3-4be9-a0fc-3bbac81f469a", "metadata": {}, "outputs": [], "source": [
"class TransformerEncoder(nnx.Module):\n",
"    def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n",
"        self.embed_dim = embed_dim\n",
"        self.dense_dim = dense_dim\n",
"        self.num_heads = num_heads\n",
"\n",
"        self.attention = nnx.MultiHeadAttention(num_heads=num_heads,\n",
"                                                in_features=embed_dim,\n",
"                                                decode=False,\n",
"                                                rngs=rngs)\n",
"        self.dense_proj = nnx.Sequential(\n",
"            nnx.Linear(embed_dim, dense_dim, rngs=rngs),\n",
"            nnx.relu,\n",
"            nnx.Linear(dense_dim, embed_dim, rngs=rngs),\n",
"        )\n",
"\n",
"        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
"        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
"\n",
"    def __call__(self, inputs, mask=None):\n",
"        if mask is not None:\n",
"            padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n",
"        else:\n",
"            padding_mask = None\n",
"\n",
"        attention_output = self.attention(\n",
"            inputs_q=inputs, inputs_k=inputs, inputs_v=inputs, mask=padding_mask, decode=False\n",
"        )\n",
"        proj_input = self.layernorm_1(inputs + attention_output)\n",
"        proj_output = self.dense_proj(proj_input)\n",
"        return self.layernorm_2(proj_input + proj_output)\n",
"\n",
"\n",
"class PositionalEmbedding(nnx.Module):\n",
"    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):\n",
"        self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)\n",
"        self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)\n",
"        self.sequence_length = sequence_length\n",
"        self.vocab_size = vocab_size\n",
"        self.embed_dim = embed_dim\n",
"\n",
"    def __call__(self, inputs):\n",
"        length = inputs.shape[1]\n",
"        positions = jnp.arange(0, length)[None, :]\n",
"        embedded_tokens = self.token_embeddings(inputs)\n",
"        embedded_positions = self.position_embeddings(positions)\n",
"        return embedded_tokens + embedded_positions\n",
"\n",
"    def compute_mask(self, inputs, mask=None):\n",
"        if mask is None:\n",
"            return None\n",
"        else:\n",
"            return jnp.not_equal(inputs, 0)\n",
"\n",
"class TransformerDecoder(nnx.Module):\n",
"    def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n",
"        self.embed_dim = embed_dim\n",
"        self.latent_dim = latent_dim\n",
"        self.num_heads = num_heads\n",
"        self.attention_1 = nnx.MultiHeadAttention(num_heads=num_heads,\n",
"                                                  in_features=embed_dim,\n",
"                                                  decode=False,\n",
"                                                  rngs=rngs)\n",
"        self.attention_2 = nnx.MultiHeadAttention(num_heads=num_heads,\n",
"                                                  in_features=embed_dim,\n",
"                                                  decode=False,\n",
"                                                  rngs=rngs)\n",
"\n",
"        self.dense_proj = nnx.Sequential(\n",
"            nnx.Linear(embed_dim, latent_dim, rngs=rngs),\n",
"            nnx.relu,\n",
"            nnx.Linear(latent_dim, embed_dim, rngs=rngs),\n",
"        )\n",
"        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
"        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
"        self.layernorm_3 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
"\n",
"    def __call__(self, inputs, encoder_outputs, mask=None):\n",
"        causal_mask = self.get_causal_attention_mask(inputs.shape[1])\n",
"        if mask is not None:\n",
"            padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n",
"            padding_mask = jnp.minimum(padding_mask, causal_mask)\n",
"        else:\n",
"            padding_mask = None\n",
"        # Self-attention over the target sequence, restricted to earlier positions\n",
"        attention_output_1 = self.attention_1(\n",
"            inputs_q=inputs, inputs_v=inputs, inputs_k=inputs, mask=causal_mask\n",
"        )\n",
"        out_1 = self.layernorm_1(inputs + attention_output_1)\n",
"\n",
"        # Cross-attention: queries from the decoder, keys and values from the encoder\n",
"        attention_output_2 = self.attention_2( ## https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L403-L405\n",
"            inputs_q=out_1,\n",
"            inputs_v=encoder_outputs,\n",
"            inputs_k=encoder_outputs,\n",
"            mask=padding_mask,\n",
"        )\n",
"        out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
"\n",
"        proj_output = self.dense_proj(out_2)\n",
"        return self.layernorm_3(out_2 + proj_output)\n",
"\n",
"    def get_causal_attention_mask(self, sequence_length):\n",
"        i = jnp.arange(sequence_length)[:, None]\n",
"        j = jnp.arange(sequence_length)\n",
"        mask = (i >= j).astype(jnp.int32)\n",
"        mask = jnp.reshape(mask, (1, 1, sequence_length, sequence_length))\n",
"        return mask" ] }, { "cell_type": "markdown", "id": "d033ae31-cc43-4e61-8d7f-cdc6d55b8bf9", "metadata": {}, "source": [ "Here we finally use our earlier encoder, decoder, and positional embed classes to construct the Model that we'll train and later use for inference." ] }, { "cell_type": "code", "execution_count": 12, "id": "c5dcfaf6-f5cd-40f4-bbf0-2754c0193327", "metadata": {}, "outputs": [], "source": [
"class TransformerModel(nnx.Module):\n",
"    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, latent_dim: int, num_heads: int, dropout_rate: float, rngs: nnx.Rngs):\n",
"        self.sequence_length = sequence_length\n",
"        self.vocab_size = vocab_size\n",
"        self.embed_dim = embed_dim\n",
"        self.latent_dim = latent_dim\n",
"        self.num_heads = num_heads\n",
"        self.dropout_rate = dropout_rate\n",
"\n",
"        self.encoder = TransformerEncoder(embed_dim, latent_dim, num_heads, rngs=rngs)\n",
"        self.positional_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim, rngs=rngs)\n",
"        self.decoder = TransformerDecoder(embed_dim, latent_dim, num_heads, rngs=rngs)\n",
"        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)\n",
"        self.dense = nnx.Linear(embed_dim, vocab_size, rngs=rngs)\n",
"\n",
"    def __call__(self, encoder_inputs: jnp.ndarray, decoder_inputs: jnp.ndarray, mask: jnp.ndarray | None = None, deterministic: bool | None = None):\n",
"        x = self.positional_embedding(encoder_inputs)\n",
"        encoder_outputs = self.encoder(x, mask=mask)\n",
"\n",
"        x = self.positional_embedding(decoder_inputs)\n",
"        decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)\n",
"        # per nnx.Dropout - an explicit deterministic value takes priority over the module's mode:\n",
"        # True disables dropout (eval), False keeps it (training), None defers to model.train() / model.eval()\n",
"        decoder_outputs = self.dropout(decoder_outputs, deterministic=deterministic)\n",
"\n",
"        logits = self.dense(decoder_outputs)\n",
"        return logits" ] }, { "cell_type": "markdown", "id": "1744cd95-afcc-4a82-9a00-18fef4f6f7df", "metadata": {}, "source": [ "## Build out Data Loader and Training Definitions\n", "We use Grain for the data loading stage, set up so that it's abundantly clear what's happening: tokenized pairs go in and batches of NumPy arrays come out, keyed in step with our original dictionaries: `encoder_inputs`, `decoder_inputs` and `target_output`. The arrays are converted to `jnp` arrays inside the step functions." ] }, { "cell_type": "code", "execution_count": 13, "id": "1fb8cb44-9012-4802-9286-1efc19dd2ba1", "metadata": {}, "outputs": [], "source": [
"batch_size = 512 #set here for the loader and model train later on\n",
"\n",
"class CustomPreprocessing(grain.MapTransform):\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def map(self, data):\n",
"        return {\n",
"            \"encoder_inputs\": np.array(data[\"encoder_inputs\"]),\n",
"            \"decoder_inputs\": np.array(data[\"decoder_inputs\"]),\n",
"            \"target_output\": np.array(data[\"target_output\"]),\n",
"        }\n",
"\n",
"train_sampler = grain.IndexSampler(\n",
"    len(train_data),\n",
"    shuffle=True,\n",
"    seed=12, # Seed for reproducibility\n",
"    shard_options=grain.NoSharding(), # No sharding since it's a single-device setup\n",
"    num_epochs=1, # Iterate over the dataset for one epoch\n",
")\n",
"\n",
"val_sampler = grain.IndexSampler(\n",
"    len(val_data),\n",
"    shuffle=False,\n",
"    seed=12,\n",
"    shard_options=grain.NoSharding(),\n",
"    num_epochs=1,\n",
")\n",
"\n",
"train_loader = grain.DataLoader(\n",
"    data_source=train_data,\n",
"    sampler=train_sampler, # Sampler to determine how to access the data\n",
"    worker_count=4, # Number of child processes launched to parallelize the transformations\n",
"    worker_buffer_size=2, # Count of output batches to produce in advance per worker\n",
"    operations=[\n",
"        CustomPreprocessing(),\n",
"        grain.Batch(batch_size=batch_size, drop_remainder=True),\n",
"    ]\n",
")\n",
"\n",
"val_loader = grain.DataLoader(\n",
"    data_source=val_data,\n",
"    sampler=val_sampler,\n",
"    worker_count=4,\n",
"    worker_buffer_size=2,\n",
"    operations=[\n",
"        CustomPreprocessing(),\n",
"        grain.Batch(batch_size=batch_size),\n",
"    ]\n",
")" ] }, { "cell_type": "markdown", "id": "40d9707d-a73c-47f5-8c12-1f336e526e61", "metadata": {}, "source": [ "Optax doesn't have the identical loss function that the source tutorial uses, but this softmax cross entropy works well here - you can one-hot encode the labels yourself if you don't use the `_with_integer_labels` version of the loss." ] }, { "cell_type": "code", "execution_count": 14, "id": "d2f8e06f-1126-41cc-b8d8-de6bd7a5255a", "metadata": {}, "outputs": [], "source": [
"def compute_loss(logits, labels):\n",
"    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)\n",
"    return jnp.mean(loss)" ] }, { "cell_type": "markdown", "id": "0a1b625a-d9e7-4028-bc98-521ce1632450", "metadata": {}, "source": [ "While in the original tutorial most of the model and training details happen inside keras, we make them explicit here in our step functions, which are later used in `train_one_epoch` and `evaluate_model`." 
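,
"\n",
"\n",
"As a rough picture of what such an explicit step involves, the sketch below runs one update in plain JAX and Optax on a toy linear classifier. The parameters, batch, and learning rate are made up for illustration; the notebook's own step functions use the `nnx` wrappers instead, so treat this as an analogy, not as their implementation.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"# Toy stand-in for the model: one weight matrix mapping 4 features to 3 classes\n",
"params = {\"w\": jnp.zeros((4, 3))}\n",
"x = jnp.ones((2, 4))        # hypothetical batch of 2 examples\n",
"labels = jnp.array([0, 2])  # integer class labels\n",
"\n",
"def loss_fn(params):\n",
"    logits = x @ params[\"w\"]\n",
"    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n",
"\n",
"tx = optax.adamw(1e-2)\n",
"opt_state = tx.init(params)\n",
"\n",
"loss, grads = jax.value_and_grad(loss_fn)(params)         # forward and backward pass\n",
"updates, opt_state = tx.update(grads, opt_state, params)  # optimizer step\n",
"params = optax.apply_updates(params, updates)             # apply the update\n",
"print(float(loss))  # about 1.0986, i.e. ln(3), because all-zero logits give a uniform prediction\n",
"```\n",
"\n",
"The same three stages (loss and gradients, optimizer update, new parameters) appear in the training step, with the model and optimizer objects carrying the state."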
] }, { "cell_type": "code", "execution_count": 15, "id": "279d991f-f129-48b3-9b7e-d143019c18a8", "metadata": {}, "outputs": [], "source": [
"@nnx.jit\n",
"def train_step(model, optimizer, batch):\n",
"    def loss_fn(model, train_encoder_input, train_decoder_input, train_target_input):\n",
"        logits = model(train_encoder_input, train_decoder_input)\n",
"        loss = compute_loss(logits, train_target_input)\n",
"        return loss\n",
"\n",
"    grad_fn = nnx.value_and_grad(loss_fn)\n",
"    loss, grads = grad_fn(model, jnp.array(batch[\"encoder_inputs\"]), jnp.array(batch[\"decoder_inputs\"]), jnp.array(batch[\"target_output\"]))\n",
"    optimizer.update(grads)\n",
"    return loss\n",
"\n",
"@nnx.jit\n",
"def eval_step(model, batch, eval_metrics):\n",
"    # deterministic=True switches dropout off so that evaluation is not perturbed by dropout noise\n",
"    logits = model(jnp.array(batch[\"encoder_inputs\"]), jnp.array(batch[\"decoder_inputs\"]), deterministic=True)\n",
"    labels = jnp.array(batch[\"target_output\"])\n",
"    loss = compute_loss(logits, labels)\n",
"\n",
"    eval_metrics.update(\n",
"        loss=loss,\n",
"        logits=logits,\n",
"        labels=labels,\n",
"    )" ] }, { "cell_type": "markdown", "id": "04e53ee9-6da1-431c-8b3f-f619d3fee68f", "metadata": {}, "source": [ "Here, `nnx.MultiMetric` helps us keep track of general training statistics, while we make our own dictionaries to hold historical values." ] }, { "cell_type": "code", "execution_count": 16, "id": "32a17edc-33d0-41bc-a516-8b8ce45c3ad7", "metadata": {}, "outputs": [], "source": [
"eval_metrics = nnx.MultiMetric(\n",
"    loss=nnx.metrics.Average('loss'),\n",
"    accuracy=nnx.metrics.Accuracy(),\n",
")\n",
"\n",
"train_metrics_history = {\n",
"    \"train_loss\": [],\n",
"}\n",
"\n",
"eval_metrics_history = {\n",
"    \"test_loss\": [],\n",
"    \"test_accuracy\": [],\n",
"}" ] }, { "cell_type": "code", "execution_count": 17, "id": "1189a6a6-2cc6-4c87-9f87-b4b800a1513d", "metadata": {}, "outputs": [], "source": [
"## Hyperparameters\n",
"rng = nnx.Rngs(0)\n",
"embed_dim = 256\n",
"latent_dim = 2048\n",
"num_heads = 8\n",
"dropout_rate = 0.5\n",
"vocab_size = tokenizer.n_vocab\n",
"sequence_length = 20\n",
"learning_rate = 1.5e-3\n",
"num_epochs = 10" ] }, { "cell_type": "code", "execution_count": 18, "id": "fbeb6101-be11-4a33-9650-a3efd3656855", "metadata": {}, "outputs": [], "source": [
"bar_format = \"{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]\"\n",
"train_total_steps = len(train_data) // batch_size\n",
"\n",
"def train_one_epoch(epoch):\n",
"    model.train()  # Set model to training mode: e.g. enable dropout\n",
"    with tqdm.tqdm(\n",
"        desc=f\"[train] epoch: {epoch}/{num_epochs}, \",\n",
"        total=train_total_steps,\n",
"        bar_format=bar_format,\n",
"        leave=True,\n",
"    ) as pbar:\n",
"        for batch in train_loader:\n",
"            loss = train_step(model, optimizer, batch)\n",
"            train_metrics_history[\"train_loss\"].append(loss.item())\n",
"            pbar.set_postfix({\"loss\": loss.item()})\n",
"            pbar.update(1)\n",
"\n",
"\n",
"def evaluate_model(epoch):\n",
"    # Compute the metrics on the validation set after each training epoch.\n",
"    model.eval()  # Set model to evaluation mode: e.g. disable dropout\n",
"\n",
"    eval_metrics.reset()  # Reset the eval metrics\n",
"    for val_batch in val_loader:\n",
"        eval_step(model, val_batch, eval_metrics)\n",
"\n",
"    for metric, value in eval_metrics.compute().items():\n",
"        eval_metrics_history[f'test_{metric}'].append(value)\n",
"\n",
"    print(f\"[test] epoch: {epoch + 1}/{num_epochs}\")\n",
"    print(f\"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}\")\n",
"    print(f\"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "49a1d33a-c2e4-4d48-821b-519f5c0192c7", "metadata": {}, "outputs": [], "source": [
"model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)\n",
"optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))" ] }, { "cell_type": "markdown", "id": "fa7d5601-60c1-4131-a40c-c670f055ce68", "metadata": {}, "source": [ "## Start the Training!\n", "With the data loaders, model, optimizer, and per-epoch train/eval functions in place, it's time to finally press go. On a 3090, this uses roughly 19GB of VRAM and each epoch takes roughly 18 seconds with `batch_size` set to 512." 
] }, { "cell_type": "code", "execution_count": 20, "id": "c764510c-4d98-46ad-b877-8cfc2fa5a9ea", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 0/10, [160/162], loss=1.98 [00:27<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 1/10\n", "- total loss: 1.9655\n", "- Accuracy: 0.6774\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 1/10, [160/162], loss=1.16 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 2/10\n", "- total loss: 1.1961\n", "- Accuracy: 0.7903\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 2/10, [160/162], loss=0.846 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 3/10\n", "- total loss: 1.0054\n", "- Accuracy: 0.8167\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 3/10, [160/162], loss=0.695 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 4/10\n", "- total loss: 0.9351\n", "- Accuracy: 0.8289\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 4/10, [160/162], loss=0.593 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 5/10\n", "- total loss: 0.8976\n", "- Accuracy: 0.8369\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 5/10, [160/162], loss=0.511 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 6/10\n", "- total loss: 0.8876\n", "- Accuracy: 0.8396\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 6/10, [160/162], loss=0.454 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 7/10\n", "- total loss: 0.8857\n", "- Accuracy: 0.8426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 7/10, [160/162], 
loss=0.421 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 8/10\n", "- total loss: 0.8959\n", "- Accuracy: 0.8427\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 8/10, [160/162], loss=0.371 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 9/10\n", "- total loss: 0.9128\n", "- Accuracy: 0.8434\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 9/10, [160/162], loss=0.341 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 10/10\n", "- total loss: 0.9227\n", "- Accuracy: 0.8452\n" ] } ], "source": [
"for epoch in range(num_epochs):\n",
"    train_one_epoch(epoch)\n",
"    evaluate_model(epoch)" ] }, { "cell_type": "markdown", "id": "f922eac4-8338-4a0d-bc6d-1f5880079bde", "metadata": {}, "source": [ "We can then plot the loss over training time. The log scale comes in handy here; without it, it's hard to appreciate the progress after the first 1000 steps or so." ] }, { "cell_type": "code", "execution_count": 21, "id": "a79ecfa5-d74a-4956-9ee2-cbed86d5a82f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(train_metrics_history[\"train_loss\"], label=\"Loss value during the training\")\n",
"plt.yscale('log')\n",
"plt.legend()" ] }, { "cell_type": "markdown", "id": "66250bf2-3d88-40ad-87e3-7d2b906fd860", "metadata": {}, "source": [ "Next, the loss and accuracy on the eval set. Accuracy does continue to rise, though it's hard-earned progress after about the 5th epoch. The eval loss flattens out around the same point and starts rising after the 7th epoch while the training loss keeps falling, so it's fair to say the process starts overfitting somewhere after roughly that 5th epoch." ] }, { "cell_type": "code", "execution_count": 22, "id": "64d54051-358b-4de8-b5b3-04bebf18018f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig, axs = plt.subplots(1, 2, figsize=(10, 10))\n",
"axs[0].set_title(\"Loss value on eval set\")\n",
"axs[0].plot(eval_metrics_history[\"test_loss\"])\n",
"axs[1].set_title(\"Accuracy on eval set\")\n",
"axs[1].plot(eval_metrics_history[\"test_accuracy\"])" ] }, { "cell_type": "markdown", "id": "a3f7b0ad-ddfa-4ab3-b56f-6ea99385ff6a", "metadata": {}, "source": [ "## Use Model for Inference\n", "After all that, the product of what we were working for: a trained model we can save and load for inference. For people who have used LLMs recently, this pattern may look rather familiar: an input sentence is tokenized into an array and the output is computed token by token, each time predicting the 'next' token. While many recent LLMs are decoder-only, this is an encoder/decoder architecture with the very specific English-to-Spanish pattern baked in.\n", "\n", "We've changed a couple of things from the source 'use' function. Because of the tokenizer used, `[start]` and `[end]` are no longer single tokens: `[start]` is `[29563, 60] = \"[start\" + \"]\"` and `[end]` is `[58308, 60] = \"[end\" + \"]\"`. We therefore start decoding from the single token `[start`, and we can't stop by comparing the last sampled token against `[end]`; instead we check whether the decoded string ends with `[end]`. Otherwise, the main change is that the input is assumed to be a single sentence, rather than a batch." 
] }, { "cell_type": "code", "execution_count": 23, "id": "e4589706-cfd6-4efb-9975-bfa0df75d4f0", "metadata": {}, "outputs": [], "source": [
"def decode_sequence(input_sentence):\n",
"\n",
"    input_sentence = custom_standardization(input_sentence)\n",
"    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)\n",
"\n",
"    decoded_sentence = \"[start\"\n",
"    # The decoder sees sequence_length - 1 positions, so that is the most tokens we can predict\n",
"    for i in range(sequence_length - 1):\n",
"        tokenized_target_sentence = tokenize_and_pad(decoded_sentence, tokenizer, sequence_length)[:-1]\n",
"        # deterministic=True turns dropout off for inference\n",
"        predictions = model(jnp.array([tokenized_input_sentence]), jnp.array([tokenized_target_sentence]), deterministic=True)\n",
"\n",
"        # Greedy decoding: take the most likely token at the current position\n",
"        sampled_token_index = np.argmax(predictions[0, i, :]).item(0)\n",
"        sampled_token = tokenizer.decode([sampled_token_index])\n",
"        decoded_sentence += sampled_token\n",
"\n",
"        # [end] spans several tokens, so test the decoded text and not the last token\n",
"        if decoded_sentence.endswith(\"[end]\"):\n",
"            break\n",
"    return decoded_sentence" ] }, { "cell_type": "code", "execution_count": 24, "id": "554c2f72-0bd3-4ed1-804b-5f1a4cc13851", "metadata": {}, "outputs": [], "source": [ "test_eng_texts = [pair[0] for pair in test_pairs]" ] }, { "cell_type": "code", "execution_count": 25, "id": "c1d6edbb-af89-42c9-90c3-d61612b75da3", "metadata": {}, "outputs": [], "source": [
"test_result_pairs = []\n",
"for _ in range(10):\n",
"    input_sentence = random.choice(test_eng_texts)\n",
"    translated = decode_sequence(input_sentence)\n",
"\n",
"    test_result_pairs.append(f\"[Input]: {input_sentence} [Translation]: {translated}\")" ] }, { "cell_type": "markdown", "id": "258c2172-5a0f-4dee-9b21-f65433183c62", "metadata": {}, "source": [ "## Test Results\n", "For the model and the data, not too shabby - it's definitely Spanish-ish. Though when 'making' friends, please don't confuse 'hacer' (to make) with 'comer' (to eat)." 
] }, { "cell_type": "code", "execution_count": 26, "id": "4f0ae018-b7cd-4849-b245-c5c647ad1a95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"[Input]: We're both way too busy to help you right now. [Translation]: [start] los dos estamos demasiado para ayudar esta mañana [end]\n",
"[Input]: Have you eaten dinner? [Translation]: [start] has comido la cena [end]\n",
"[Input]: That is the poet I met in Paris. [Translation]: [start] ese es el poeta que conocí en parís [end]\n",
"[Input]: It doesn't make sense to me. [Translation]: [start] no me hace falta sentido [end]\n",
"[Input]: We're happy. [Translation]: [start] estamos felices [end]\n",
"[Input]: What about me? [Translation]: [start] de qué me [end]\n",
"[Input]: Make a decision and make it with the confidence that you are right. [Translation]: [start] haz una decisión y tomará la confianza en el confian [end]\n",
"[Input]: Put some salt on your meat. [Translation]: [start] ponte algo de sal [end]\n",
"[Input]: Tom's deaf. [Translation]: [start] tom es sordo [end]\n",
"[Input]: How old are your brothers and sisters? [Translation]: [start] qué edad son tus hermanos [end]\n" ] } ], "source": [
"for i in test_result_pairs:\n",
"    print(i)" ] }, { "cell_type": "markdown", "id": "5ca18d4c-b3c0-4abb-b5fc-fc96a2264b53", "metadata": {}, "source": [
"Example output from the above cell:\n",
"\n",
"    [Input]: We're going to have a baby. [Translation]: [start] nosotros vamos a tener un bebé [end]\n",
"    [Input]: You drive too fast. [Translation]: [start] conducís demasiado rápido [end]\n",
"    [Input]: Let me know if there's anything I can do. [Translation]: [start] déjame saber si hay cualquier cosa que yo pueda hacer [end]\n",
"    [Input]: Let's go to the kitchen. [Translation]: [start] vayamos a la cocina [end]\n",
"    [Input]: Tom gasped. [Translation]: [start] tom se quedó sin aliento [end]\n",
"    [Input]: I was just hanging out with some of my friends. [Translation]: [start] estaba escquieto con algunos de mi amigos [end]\n",
"    [Input]: Tom is in the bathroom. [Translation]: [start] tom está en el cuarto de baño [end]\n",
"    [Input]: I feel safe here. [Translation]: [start] me siento segura [end]\n",
"    [Input]: I'm going to need you later. [Translation]: [start] me voy a necesitar después [end]\n",
"    [Input]: A party is a good place to make friends with other people. [Translation]: [start] una fiesta es un buen lugar de comer amigos con otras personas [end]" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }