{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize JAX model metrics with TensorBoard\n",
"\n",
"[](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_visualizing_models_metrics.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) - if you haven't read that yet and want the primer, start there before returning.\n",
"\n",
"All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import io\n",
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "hKhPLnNxfOHU",
"outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6"
},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"digits = load_digits()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we set the location of the tensorflow writer - the organization is somewhat arbitrary, though keeping a folder for each training run can make later navigation more straightforward."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"file_path = \"runs/test/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"test_summary_writer = tf.summary.create_file_writer(file_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pulled from the official tensorboard examples, this convert function makes it simple to drop matplotlib figures directly into tensorboard"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def plot_to_image(figure):\n",
" \"\"\"Sourced from https://www.tensorflow.org/tensorboard/image_summaries\n",
" Converts the matplotlib plot specified by 'figure' to a PNG image and\n",
" returns it. The supplied figure is closed and inaccessible after this call.\"\"\"\n",
" # Save the plot to a PNG in memory.\n",
" buf = io.BytesIO()\n",
" plt.savefig(buf, format='png')\n",
" # Closing the figure prevents it from being displayed directly inside\n",
" # the notebook.\n",
" plt.close(figure)\n",
" buf.seek(0)\n",
" # Convert PNG buffer to TF image\n",
" image = tf.image.decode_png(buf.getvalue(), channels=4)\n",
" # Add the batch dimension\n",
" image = tf.expand_dims(image, 0)\n",
" return image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whereas previously the example displays the training data snapshot in the notebook, here we stash it in the tensorboard images. If a given training is to be repeated many, many times it can save space to stash the training data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager. We are able to step into and out of this type of context through the run without losing the same file/folder experiment."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Y8cMntSdfyyT",
"outputId": "9343a558-cd8c-473c-c109-aa8015c7ae7e"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"fig, axes = plt.subplots(10, 10, figsize=(6, 6),\n",
" subplot_kw={'xticks':[], 'yticks':[]},\n",
" gridspec_kw=dict(hspace=0.1, wspace=0.1))\n",
"\n",
"for i, ax in enumerate(axes.flat):\n",
" ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')\n",
" ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')\n",
"with test_summary_writer.as_default():\n",
" tf.summary.image(\"Training Data\", plot_to_image(fig), step=0)"
]
},
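{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sketch of re-entering the writer context (the `placeholder_metric` values below are invented for illustration and are not produced by the model in this tutorial), each re-entry appends to the same experiment folder:\n",
"\n",
"```python\n",
"with test_summary_writer.as_default():\n",
"    for step in range(3):\n",
"        # Invented values standing in for a real training metric.\n",
"        tf.summary.scalar(\"placeholder_metric\", 1.0 / (step + 1), step=step)\n",
"```"
]
},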
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL:\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "6jrYisoPh6TL"
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"splits = train_test_split(digits.images, digits.target, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "oMRcwKd4hqOo",
"outputId": "0ad36290-397b-431d-eba2-ef114daf5ea6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"images_train.shape=(1347, 8, 8) label_train.shape=(1347,)\n",
"images_test.shape=(450, 8, 8) label_test.shape=(450,)\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"images_train, images_test, label_train, label_test = map(jnp.asarray, splits)\n",
"print(f\"{images_train.shape=} {label_train.shape=}\")\n",
"print(f\"{images_test.shape=} {label_test.shape=}\")"
]
},
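{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside (not required for the rest of the walkthrough), the same writer can also record free-form text, which is one way to keep a note of the split sizes alongside the run:\n",
"\n",
"```python\n",
"with test_summary_writer.as_default():\n",
"    tf.summary.text(\n",
"        \"dataset_split\",\n",
"        f\"train={images_train.shape[0]} examples, test={images_test.shape[0]} examples\",\n",
"        step=0,\n",
"    )\n",
"```"
]
},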
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "U77VMQwRjTfH",
"outputId": "345fed7a-4455-4036-85ed-57e673a4de01"
},
"outputs": [
{
"data": {
"text/html": [
"