{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3aef95bd-19b9-406e-a860-b1720c26d307",
   "metadata": {},
   "source": [
    "Sascha Spors,\n",
    "Professorship Signal Theory and Digital Signal Processing,\n",
    "Institute of Communications Engineering (INT),\n",
    "Faculty of Computer Science and Electrical Engineering (IEF),\n",
    "University of Rostock,\n",
    "Germany\n",
    "\n",
    "# Data Driven Audio Signal Processing - A Tutorial with Computational Examples\n",
    "\n",
    "Winter Semester 2023/24 (Master Course #24512)\n",
    "\n",
    "- lecture: https://github.com/spatialaudio/data-driven-audio-signal-processing-lecture\n",
    "- tutorial: https://github.com/spatialaudio/data-driven-audio-signal-processing-exercise\n",
    "\n",
    "Feel free to contact lecturer frank.schultz@uni-rostock.de"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46f6f98a-12a4-4bcd-975e-18a7d404cb9c",
   "metadata": {},
   "source": [
    "# Multiclass Classification with Hyper Parameter Tuning\n",
    "\n",
    "- **Softmax** activation Function at Output\n",
    "- Categorical crossentropy loss\n",
    "- Split data into training, validating, testing data sets\n",
    "    - training, validating used for hyper parameter tuning (validate serves as the unseen test data here)\n",
    "    - training, testing used for train/test the best model (test data was never used before! and is here only and once used to finally check model performance) \n",
    "- Avoid over-/underfit by\n",
    "    - deploying early stopping\n",
    "    - deploying hyper parameter tuning using Keras tuner\n",
    "- we use convenient stuff from scikit-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fdebeca-af44-46ef-89d3-f2c229abf073",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdf6b91-8051-4962-b5ee-c67cd6916135",
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras_tuner as kt\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, LabelBinarizer\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import time\n",
    "\n",
    "print(\n",
    "    \"TF version\",\n",
    "    tf.__version__,\n",
    "    \"\\nKeras Tuner version\",\n",
    "    kt.__version__,\n",
    ")\n",
    "verbose = 1  # plot training status\n",
    "\n",
    "CI_flag = True  # use toy parameters to check if this notebooks runs in CI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5249aa49-d2da-4aec-be0a-769001eb010d",
   "metadata": {},
   "source": [
    "## Folder Structure For Log Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de910024-77ed-4f72-8bfe-d12a2c929c39",
   "metadata": {},
   "outputs": [],
   "source": [
    "ex_str = \"ex12_\"\n",
    "time_str = \"%Y_%m_%d_%H_%M_\"\n",
    "\n",
    "\n",
    "def get_kt_logdir():\n",
    "    run_id = time.strftime(time_str + ex_str + \"kt\")\n",
    "    return os.path.join(root_logdir, run_id)\n",
    "\n",
    "\n",
    "def get_tf_kt_logdir():\n",
    "    run_id = time.strftime(time_str + ex_str + \"tf_kt\")\n",
    "    return os.path.join(root_logdir, run_id)\n",
    "\n",
    "\n",
    "def get_tf_logdir():\n",
    "    run_id = time.strftime(time_str + ex_str + \"tf\")\n",
    "    return os.path.join(root_logdir, run_id)\n",
    "\n",
    "\n",
    "root_logdir = os.path.join(os.curdir, \"tf_keras_logs\")\n",
    "kt_logdir = get_kt_logdir()\n",
    "tf_kt_logdir = get_tf_kt_logdir()\n",
    "tf_logdir = get_tf_logdir()\n",
    "print(root_logdir)\n",
    "print(kt_logdir)  # folder for keras tuner results\n",
    "print(tf_kt_logdir)  # folder for TF checkpoints while keras tuning\n",
    "print(tf_logdir)  # folder for TF checkpoint for best model training\n",
    "\n",
    "os.makedirs(tf_logdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0872515f-abb8-4222-9b01-a49a76245047",
   "metadata": {},
   "source": [
    "## Data Synthesis / One Hot Encoding / Splitting "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3696b9b2-a2e5-467c-a7b4-3c85fad5cd78",
   "metadata": {},
   "outputs": [],
   "source": [
    "nlabels = 3  # number of classes\n",
    "labels = np.arange(nlabels)  # we encode as integers\n",
    "nx = 2 * nlabels  # number of features, here we use 6\n",
    "m = 100000  # data examples\n",
    "\n",
    "train_size = 7 / 10  # 7/10 of the whole data set\n",
    "validate_size = 3 / 10 * 2 / 3  # 1/5 of the whole data set\n",
    "test_size = 1 - train_size - validate_size  # remaining data, must be > 0\n",
    "\n",
    "X, Y = make_classification(\n",
    "    n_samples=m,\n",
    "    n_features=nx,\n",
    "    n_informative=nx,\n",
    "    n_redundant=0,\n",
    "    n_classes=nlabels,\n",
    "    n_clusters_per_class=1,\n",
    "    class_sep=1,\n",
    "    flip_y=1e-2,\n",
    "    random_state=None,\n",
    ")\n",
    "encoder = OneHotEncoder(sparse_output=False)\n",
    "# we encode as one-hot for TF model\n",
    "Y = encoder.fit_transform(Y.reshape(-1, 1))\n",
    "\n",
    "# split into train, val, test data:\n",
    "X_train, X_tmp, Y_train, Y_tmp = train_test_split(\n",
    "    X, Y, train_size=train_size, random_state=None\n",
    ")\n",
    "val_size = (validate_size * m) / ((1 - train_size) * m)\n",
    "X_val, X_test, Y_val, Y_test = train_test_split(\n",
    "    X_tmp, Y_tmp, train_size=val_size, random_state=None\n",
    ")\n",
    "\n",
    "m_train, m_val, m_test = X_train.shape[0], X_val.shape[0], X_test.shape[0]\n",
    "\n",
    "print(train_size, validate_size, test_size)\n",
    "print(m_train, m_val, m_test, m_train + m_val + m_test == m)\n",
    "print(X_train.shape, X_val.shape, X_test.shape)\n",
    "print(Y_train.shape, Y_val.shape, Y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad248f4d-cebe-4eda-a0e3-d2887948a4d5",
   "metadata": {},
   "source": [
    "## Model Preparation / Hyper Parameter Range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8b7a07f0-7151-4914-93ca-6efb3c22779c",
   "metadata": {},
   "outputs": [],
   "source": [
    "earlystopping_cb = keras.callbacks.EarlyStopping(\n",
    "    monitor=\"val_loss\", patience=2, restore_best_weights=True  # on val data!\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1eae32c-83b4-4b69-837f-31e3d87a470a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# as homework we might also consider dropout and regularization in the model\n",
    "def build_model(hp):  # with hyper parameter ranges\n",
    "    model = keras.Sequential()\n",
    "    # input layer\n",
    "    model.add(keras.Input(shape=(nx, )))\n",
    "    # hidden layers\n",
    "    for layer in range(hp.Int(\"no_layers\", 1, 4)):\n",
    "        model.add(\n",
    "            keras.layers.Dense(\n",
    "                units=hp.Int(\n",
    "                    f\"no_perceptrons_{layer}\", min_value=2, max_value=16, step=2\n",
    "                ),\n",
    "                activation=hp.Choice(\"activation\", [\"tanh\", \"relu\", \"sigmoid\", \"softmax\"]),\n",
    "                # sigmoid and softmax could be choice that we want to check as well\n",
    "                # they are not restricted to be used in a classifiction problem\n",
    "                # output layer\n",
    "            )\n",
    "        )\n",
    "    # softmax output layer\n",
    "    model.add(keras.layers.Dense(nlabels, activation=\"softmax\"))\n",
    "    # learning_rate = hp.Float('learning_rate', min_value=1e-5, max_value=1e-1,\n",
    "    #                          sampling='log')\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(),  # learning_rate=learning_rate\n",
    "        loss=keras.losses.CategoricalCrossentropy(\n",
    "            from_logits=False, label_smoothing=0\n",
    "        ),\n",
    "        metrics=[\"CategoricalCrossentropy\", \"CategoricalAccuracy\"],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b2f345f-8393-4b77-a8bd-aa66f7b8d478",
   "metadata": {},
   "source": [
    "## Hyper Parameter Tuner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf18d595-cfcc-41d4-8f91-550ea25fbb92",
   "metadata": {},
   "outputs": [],
   "source": [
    "if CI_flag:\n",
    "    max_trials = 5  # number of models to build and try\n",
    "else:\n",
    "    max_trials = 20  # number of models to build and try\n",
    "executions_per_trial = 2\n",
    "model = build_model(kt.HyperParameters())\n",
    "hptuner = kt.RandomSearch(\n",
    "    hypermodel=build_model,\n",
    "    objective='val_loss',  # check performance on val data!\n",
    "    max_trials=max_trials,\n",
    "    executions_per_trial=executions_per_trial,\n",
    "    overwrite=True,\n",
    "    directory=kt_logdir,\n",
    "    project_name=None,\n",
    ")\n",
    "print(hptuner.search_space_summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c41a8d7b-03a3-41f5-b7a0-849d3bd2af3e",
   "metadata": {},
   "source": [
    "## Training of Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a42299d-ee62-4b8f-97ec-c97a246ee4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "if CI_flag:\n",
    "    epochs = 3\n",
    "else:\n",
    "    epochs = 50\n",
    "tensorboard_cb = keras.callbacks.TensorBoard(tf_kt_logdir)\n",
    "hptuner.search(\n",
    "    X_train,\n",
    "    Y_train,\n",
    "    validation_data=(X_val, Y_val),\n",
    "    epochs=epochs,\n",
    "    callbacks=[earlystopping_cb, tensorboard_cb],\n",
    "    verbose=verbose,\n",
    ")\n",
    "print(hptuner.results_summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d897d67-7b85-4f14-aa25-007b961c0788",
   "metadata": {},
   "source": [
    "## Best Model Selection / Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3027d14-6dcd-4d4d-80a3-e888f26aeaf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we might check the best XX models in detail\n",
    "# for didactical purpose we choose only the very best one, located in [0]:\n",
    "model = hptuner.get_best_models(num_models=1)[0]\n",
    "model.save(tf_logdir + \"/best_model.keras\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e76a4f45-6f5b-4755-9f98-7ad228efdf97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# taken from https://github.com/keras-team/keras/issues/341\n",
    "# 183amir commented on 7 Oct 2019:\n",
    "# \"If you are using tensorflow 2, you can use this:\"\n",
    "def reset_weights(model):\n",
    "    for layer in model.layers:\n",
    "        if isinstance(layer, tf.keras.Model):\n",
    "            reset_weights(layer)\n",
    "            continue\n",
    "        for k, initializer in layer.__dict__.items():\n",
    "            if \"initializer\" not in k:\n",
    "                continue\n",
    "            # find the corresponding variable\n",
    "            var = getattr(layer, k.replace(\"_initializer\", \"\"))\n",
    "            var.assign(initializer(var.shape, var.dtype))\n",
    "\n",
    "\n",
    "# 183amir: \"I am not sure if it works in all cases, I have only tested the Dense and Conv2D layers.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "184bbba1-e1d4-4dfc-a239-10699d24af2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load best model and reset weights\n",
    "model = keras.models.load_model(tf_logdir + \"/best_model.keras\")\n",
    "reset_weights(model)  # start training from scratch\n",
    "print(model.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18850dd2-6ead-4f4b-8354-2648d3246ea6",
   "metadata": {},
   "source": [
    "## Training of Best Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d1ef6b-90a1-47be-b8ac-949a930bc8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "if CI_flag:\n",
    "    epochs = 3\n",
    "else:\n",
    "    epochs = 50\n",
    "tensorboard_cb = keras.callbacks.TensorBoard(tf_logdir)\n",
    "history = model.fit(\n",
    "    X_train,\n",
    "    Y_train,\n",
    "    epochs=epochs,\n",
    "    batch_size=batch_size,\n",
    "    validation_data=(X_val, Y_val),\n",
    "    callbacks=[earlystopping_cb, tensorboard_cb],\n",
    "    verbose=verbose,\n",
    ")\n",
    "model.save(tf_logdir + \"/trained_best_model.keras\")\n",
    "print(model.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bfe4c18-44fc-42e0-aef2-e513db4db646",
   "metadata": {},
   "source": [
    "## Evaluation of Best Model on Unseen Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce5b8a8b-fc24-43b8-adcf-3b27d731ec7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_results(X, Y):\n",
    "    # https://stackoverflow.com/questions/48908641/how-to-get-a-single-value-from-softmax-instead-of-probability-get-confusion-ma:\n",
    "    lb = LabelBinarizer()\n",
    "    lb.fit(labels)\n",
    "\n",
    "    m = X.shape[0]\n",
    "    results = model.evaluate(X, Y, batch_size=m, verbose=verbose)\n",
    "    Y_pred = model.predict(X)\n",
    "    cm = tf.math.confusion_matrix(\n",
    "        labels=lb.inverse_transform(Y),\n",
    "        predictions=lb.inverse_transform(Y_pred),\n",
    "        num_classes=nlabels,\n",
    "    )\n",
    "    print(\"data entries\", m)\n",
    "    print(\n",
    "        \"Cost\",\n",
    "        results[0],\n",
    "        \"\\nCategoricalCrossentropy\",\n",
    "        results[1],\n",
    "        \"\\nCategoricalAccuracy\",\n",
    "        results[2],\n",
    "    )\n",
    "    print(\n",
    "        \"nCategoricalAccuracy from Confusion Matrix = \",\n",
    "        np.sum(np.diag(cm.numpy())) / m,\n",
    "    )\n",
    "    print(\"Confusion Matrix in %\\n\", cm / m * 100)\n",
    "\n",
    "\n",
    "print(\"\\n\\nmetrics on train data:\")\n",
    "print_results(X_train, Y_train)\n",
    "\n",
    "print(\"\\n\\nmetrics on val data:\")\n",
    "print_results(X_val, Y_val)\n",
    "\n",
    "print(\"\\n\\nmetrics on never seen test data:\")\n",
    "print_results(X_test, Y_test)\n",
    "# recall: the model should generalize well on never before seen data\n",
    "# so after hyper parameter tuning finding the best model, re-train this best\n",
    "# model to optimized state we can check with test data (X_test, Y_test), which\n",
    "# we never used in above training steps!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34a66cab-62fd-4554-a88b-56f77fcd0cf1",
   "metadata": {},
   "source": [
    "## Copyright\n",
    "\n",
    "- the notebooks are provided as [Open Educational Resources](https://en.wikipedia.org/wiki/Open_educational_resources)\n",
    "- the text is licensed under [Creative Commons Attribution 4.0](https://creativecommons.org/licenses/by/4.0/)\n",
    "- the code of the IPython examples is licensed under the [MIT license](https://opensource.org/licenses/MIT)\n",
    "- feel free to use the notebooks for your own purposes\n",
    "- please attribute the work as follows: *Frank Schultz, Data Driven Audio Signal Processing - A Tutorial Featuring Computational Examples, University of Rostock* ideally with relevant file(s), github URL https://github.com/spatialaudio/data-driven-audio-signal-processing-exercise, commit number and/or version tag, year."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myddasp",
   "language": "python",
   "name": "myddasp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}