{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Get Started with Keras 3.0 + MLflow\n", "\n", "This is an end-to-end tutorial on training an MNIST classifier with **Keras 3.0** and logging results with **MLflow**. It demonstrates how to use `mlflow.keras.MlflowCallback` and how to subclass it to implement custom logging logic.\n", "\n", "**Keras** is a high-level API designed to be simple, flexible, and powerful, allowing everyone from beginners to advanced users to quickly build, train, and evaluate models. **Keras 3.0** (previewed under the name Keras Core) is a full rewrite of the Keras codebase that rebases it on top of a modular backend architecture. This makes it possible to run Keras workflows on top of arbitrary frameworks — starting with TensorFlow, JAX, and PyTorch." ] }, { "cell_type": "raw", "metadata": {}, "source": [ " Download this Notebook
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install Packages\n", "\n", "`pip install -q keras mlflow jax jaxlib torch tensorflow matplotlib`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import Packages / Configure Backend\n", "Keras 3.0 is inherently multi-backend, so you will need to set the backend environment variable **before** importing the package." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# You can use 'tensorflow', 'torch' or 'jax' as backend. Make sure to set the environment variable before importing.\n", "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using TensorFlow backend\n" ] } ], "source": [ "import keras\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import mlflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Dataset\n", "We will use the MNIST dataset. This is a dataset of handwritten digits and will be used for an image classification task. There are 10 classes corresponding to the 10 digits." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(28, 28, 1)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", "# Append a trailing channel axis so each image has shape (28, 28, 1), matching the model input.\n", "x_train = np.expand_dims(x_train, axis=3)\n", "x_test = np.expand_dims(x_test, axis=3)\n", "x_train[0].shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the first grid x grid training images with their labels\n", "grid = 3\n", "fig, axes = plt.subplots(grid, grid, figsize=(6, 6))\n", "for i in range(grid):\n", "    for j in range(grid):\n", "        axes[i][j].imshow(x_train[i * grid + j])\n", "        axes[i][j].set_title(f\"label={y_train[i * grid + j]}\")\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build Model\n", "We will use the Keras 3.0 `Sequential` API to build a simple CNN." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
       "│ conv2d (Conv2D)                 │ (None, 26, 26, 32)        │        320 │\n",
       "├─────────────────────────────────┼───────────────────────────┼────────────┤\n",
       "│ conv2d_1 (Conv2D)               │ (None, 24, 24, 32)        │      9,248 │\n",
       "├─────────────────────────────────┼───────────────────────────┼────────────┤\n",
       "│ conv2d_2 (Conv2D)               │ (None, 22, 22, 32)        │      9,248 │\n",
       "├─────────────────────────────────┼───────────────────────────┼────────────┤\n",
       "│ global_average_pooling2d        │ (None, 32)                │          0 │\n",
       "│ (GlobalAveragePooling2D)        │                           │            │\n",
       "├─────────────────────────────────┼───────────────────────────┼────────────┤\n",
       "│ dense (Dense)                   │ (None, 10)                │        330 │\n",
       "└─────────────────────────────────┴───────────────────────────┴────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", "│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n", "├─────────────────────────────────┼───────────────────────────┼────────────┤\n", "│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m24\u001b[0m, \u001b[38;5;34m24\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m9,248\u001b[0m │\n", "├─────────────────────────────────┼───────────────────────────┼────────────┤\n", "│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m22\u001b[0m, \u001b[38;5;34m22\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m9,248\u001b[0m │\n", "├─────────────────────────────────┼───────────────────────────┼────────────┤\n", "│ global_average_pooling2d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mGlobalAveragePooling2D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼───────────────────────────┼────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m330\u001b[0m │\n", "└─────────────────────────────────┴───────────────────────────┴────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 19,146 (74.79 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,146\u001b[0m (74.79 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 19,146 (74.79 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,146\u001b[0m (74.79 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "NUM_CLASSES = 10\n", "INPUT_SHAPE = (28, 28, 1)\n", "\n", "\n", "def initialize_model():\n", "    \"\"\"Build a new, uncompiled CNN classifier.\n", "\n", "    The network takes INPUT_SHAPE images, applies three 3x3 ReLU convolutions and\n", "    global average pooling, and ends in a NUM_CLASSES-way softmax layer. Every call\n", "    creates fresh layers, so each returned model starts untrained and must be compiled.\n", "    \"\"\"\n", "    return keras.Sequential(\n", "        [\n", "            keras.Input(shape=INPUT_SHAPE),\n", "            keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n", "            keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n", "            keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n", "            keras.layers.GlobalAveragePooling2D(),\n", "            keras.layers.Dense(NUM_CLASSES, activation=\"softmax\"),\n", "        ]\n", "    )\n", "\n", "\n", "model = initialize_model()\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Model (Default Callback)\n", "We will fit the model on the dataset, using MLflow's `mlflow.keras.MlflowCallback` to log metrics during training." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 64  # adjust this based on the memory of your machine\n", "EPOCHS = 3\n", "SEED = 42\n", "\n", "# Seed Python, NumPy and the backend framework so weight initialization and shuffling\n", "# repeat across runs (some GPU kernels may still be nondeterministic).\n", "keras.utils.set_random_seed(SEED)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Log Per Epoch\n", "An epoch is defined as one pass through the entire training dataset."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 34ms/step - accuracy: 0.5922 - loss: 1.2862 - val_accuracy: 0.9427 - val_loss: 0.2075\n", "Epoch 2/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m28s\u001b[0m 33ms/step - accuracy: 0.9330 - loss: 0.2286 - val_accuracy: 0.9348 - val_loss: 0.2020\n", "Epoch 3/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m28s\u001b[0m 33ms/step - accuracy: 0.9499 - loss: 0.1671 - val_accuracy: 0.9558 - val_loss: 0.1491\n" ] } ], "source": [ "model = initialize_model()\n", "\n", "model.compile(\n", "    loss=keras.losses.SparseCategoricalCrossentropy(),\n", "    optimizer=keras.optimizers.Adam(),\n", "    metrics=[\"accuracy\"],\n", ")\n", "\n", "# Using the run as a context manager ends it even if training raises,\n", "# so a failed fit() never leaves an active run behind.\n", "with mlflow.start_run() as run:\n", "    model.fit(\n", "        x_train,\n", "        y_train,\n", "        batch_size=BATCH_SIZE,\n", "        epochs=EPOCHS,\n", "        validation_split=0.1,\n", "        callbacks=[mlflow.keras.MlflowCallback(run)],\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Log Results\n", "The callback logs the run's **parameters**, **metrics**, and **artifacts** to the MLflow dashboard.\n", "\n", "![run page](https://i.imgur.com/YLGFDJEl.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Log Per Batch\n", "Within each epoch, the training dataset is broken down into batches based on the defined `BATCH_SIZE`. If we configure the callback not to log per epoch (`log_every_epoch=False`) and to log every 5 batches (`log_every_n_steps=5`), logging becomes based on batches instead of epochs."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 34ms/step - accuracy: 0.6151 - loss: 1.2100 - val_accuracy: 0.9373 - val_loss: 0.2144\n", "Epoch 2/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 34ms/step - accuracy: 0.9274 - loss: 0.2459 - val_accuracy: 0.9608 - val_loss: 0.1338\n", "Epoch 3/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m28s\u001b[0m 34ms/step - accuracy: 0.9477 - loss: 0.1738 - val_accuracy: 0.9577 - val_loss: 0.1454\n" ] } ], "source": [ "model = initialize_model()\n", "\n", "model.compile(\n", "    loss=keras.losses.SparseCategoricalCrossentropy(),\n", "    optimizer=keras.optimizers.Adam(),\n", "    metrics=[\"accuracy\"],\n", ")\n", "\n", "with mlflow.start_run() as run:\n", "    model.fit(\n", "        x_train,\n", "        y_train,\n", "        batch_size=BATCH_SIZE,\n", "        epochs=EPOCHS,\n", "        validation_split=0.1,\n", "        callbacks=[mlflow.keras.MlflowCallback(run, log_every_epoch=False, log_every_n_steps=5)],\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Log Results\n", "\n", "If we **log per epoch**, we will only have three datapoints, since there are only 3 epochs:\n", "\n", "![log per epoch](https://i.imgur.com/rFDj8SHl.png)\n", "\n", "By **logging per batch**, we can get more datapoints, but they can be noisier:\n", "\n", "![log per batch](https://i.imgur.com/ZCYXLqll.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Customize Logging\n", "\n", "To change what gets logged and when, subclass `mlflow.keras.MlflowCallback` and override its Keras callback hooks. The subclass below overrides `on_batch_end`: every `log_every_n_steps` batches it records the batch metrics with `metrics_logger.record_metrics`, using a step counter (`_log_step`) that advances by `log_every_n_steps` after each record." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class MlflowCallbackLogPerBatch(mlflow.keras.MlflowCallback):\n", "    \"\"\"MlflowCallback variant with custom per-batch logging.\n", "\n", "    Every `log_every_n_steps` training batches, the current batch `logs` are recorded\n", "    under an internal step counter that advances by `log_every_n_steps` per record.\n", "    Relies on the parent's `metrics_logger` and `_log_step` attributes.\n", "    \"\"\"\n", "\n", "    def on_batch_end(self, batch, logs=None):\n", "        # Skip when batch logging is disabled or Keras supplied no metrics for this batch.\n", "        if self.log_every_n_steps is None or logs is None:\n", "            return\n", "        # `batch` is the 0-based batch index within the current epoch, hence the `+ 1`.\n", "        if (batch + 1) % self.log_every_n_steps == 0:\n", "            self.metrics_logger.record_metrics(logs, self._log_step)\n", "            self._log_step += self.log_every_n_steps" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 34ms/step - accuracy: 0.5645 - loss: 1.4105 - val_accuracy: 0.9187 - val_loss: 0.2826\n", "Epoch 2/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 34ms/step - accuracy: 0.9257 - loss: 0.2615 - val_accuracy: 0.9602 - val_loss: 0.1368\n", "Epoch 3/3\n", "\u001b[1m844/844\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 34ms/step - accuracy: 0.9456 - loss: 0.1800 - val_accuracy: 0.9678 - val_loss: 0.1037\n" ] } ], "source": [ "model = initialize_model()\n", "\n", "model.compile(\n", "    loss=keras.losses.SparseCategoricalCrossentropy(),\n", "    optimizer=keras.optimizers.Adam(),\n", "    metrics=[\"accuracy\"],\n", ")\n", "\n", "with mlflow.start_run() as run:\n", "    model.fit(\n", "        x_train,\n", "        y_train,\n", "        batch_size=BATCH_SIZE,\n", "        epochs=EPOCHS,\n", "        validation_split=0.1,\n", "        callbacks=[MlflowCallbackLogPerBatch(run, log_every_epoch=False, log_every_n_steps=5)],\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "Similar to training, you can use the callback to log the evaluation results."
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9541 - loss: 0.1487\n" ] } ], "source": [ "# Same `mlflow.keras.MlflowCallback` used during training; here it logs the evaluation result.\n", "with mlflow.start_run() as run:\n", "    model.evaluate(x_test, y_test, callbacks=[mlflow.keras.MlflowCallback(run)])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 2 }