{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Class Single-Label classification\n", "\n", "The natural extension of binary classification is a multi-class classification task.\n", "We first approach multi-class single-label classification, which makes the assumption that each example is assigned to one and only one label.\n", "\n", "We use the *Iris flower* data set, which consists of a classification into three mutually-exclusive classes; call these $A$, $B$ and $C$.\n", "\n", "While one could train three unary predicates $A(x)$, $B(x)$ and $C(x)$, it turns out to be more effective if this problem is modelled by a single binary predicate $P(x,l)$, where $l$ is a variable denoting a multi-class label, in this case classes $A$, $B$ or $C$.\n", "- This syntax allows one to write statements quantifying over the classes, e.g. $\\forall x ( \\exists l ( P(x,l)))$.\n", "- Since the classes are mutually-exclusive in this case, the output layer of the $\\mathtt{MLP}$ representing $P(x,l)$ will be a $\\mathtt{softmax}$ layer, instead of a $\\mathtt{sigmoid}$ function, to learn the probability of $A$, $B$ and $C$. This avoids writing additional constraints $\\lnot (A(x) \\land B(x))$, $\\lnot (A(x) \\land C(x))$, ..." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Init Plugin\n", "Init Graph Optimizer\n", "Init Kernel\n" ] } ], "source": [ "import logging; logging.basicConfig(level=logging.INFO)\n", "import tensorflow as tf\n", "import pandas as pd\n", "import ltn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data\n", "\n", "Load the iris dataset: 50 samples from each of three species of iris flowers (setosa, virginica, versicolor), measured with four features." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " sepal_length sepal_width petal_length petal_width species\n", "0 6.4 2.8 5.6 2.2 2\n", "1 5.0 2.3 3.3 1.0 1\n", "2 4.9 2.5 4.5 1.7 2\n", "3 4.9 3.1 1.5 0.1 0\n", "4 5.7 3.8 1.7 0.3 0\n" ] } ], "source": [ "df_train = pd.read_csv(\"iris_training.csv\")\n", "df_test = pd.read_csv(\"iris_test.csv\")\n", "print(df_train.head(5))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metal device set to: Apple M1\n", "\n", "systemMemory: 16.00 GB\n", "maxCacheSize: 5.33 GB\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2021-08-30 14:38:15.642262: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# LTN\n", "\n", "Predicate `P(x,class)` with a softmax output layer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class MLP(tf.keras.Model):\n", "    \"\"\"Model that returns logits.\"\"\"\n", "    def __init__(self, n_classes, hidden_layer_sizes=(16,16,8)):\n", "        super(MLP, self).__init__()\n", "        self.denses = [tf.keras.layers.Dense(s, activation=\"elu\") for s in hidden_layer_sizes]\n", "        self.dense_class = tf.keras.layers.Dense(n_classes)\n", "        self.dropout = tf.keras.layers.Dropout(0.2)\n", "\n", "    def call(self, inputs, training=False):\n", "        x = inputs[0]\n", "        for dense in self.denses:\n", "            x = dense(x)\n", "            x = self.dropout(x, training=training)\n", "        return self.dense_class(x)\n", "\n", "logits_model = MLP(3)\n", "p = ltn.Predicate.FromLogits(logits_model, activation_function=\"softmax\", with_class_indexing=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Constants to index/iterate over the classes" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class_A = ltn.Constant(0, trainable=False)\n", "class_B = ltn.Constant(1, trainable=False)\n", "class_C = ltn.Constant(2, trainable=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Operators and axioms" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())\n", "And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())\n", "Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())\n", "Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())\n", "Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(p=2),semantics=\"forall\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "formula_aggregator = ltn.Wrapper_Formula_Aggregator(ltn.fuzzy_ops.Aggreg_pMeanError(p=2))\n", "\n", "@tf.function\n", "def axioms(features, labels, training=False):\n", "    # group the examples of the batch by class\n", "    x_A = ltn.Variable(\"x_A\",features[labels==0])\n", "    x_B = ltn.Variable(\"x_B\",features[labels==1])\n", "    x_C = ltn.Variable(\"x_C\",features[labels==2])\n", "    # one axiom per class: every example of the class should satisfy P(x,class)\n", "    axioms = [\n", "        Forall(x_A,p([x_A,class_A],training=training)),\n", "        Forall(x_B,p([x_B,class_B],training=training)),\n", "        Forall(x_C,p([x_C,class_C],training=training))\n", "    ]\n", "    sat_level = formula_aggregator(axioms).tensor\n", "    return sat_level" ] },
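{ "cell_type": "markdown", "metadata": {}, "source": [ "Since `p` is an ordinary LTN predicate, it can also be queried outside of the axioms. The following cell is an illustrative addition (not needed for training): it prints the truth degrees of $P(x,A)$ for a few test samples, which are still uninformative before training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added query cell: truth degrees of P(x, class_A) for five test samples\n", "# (the predicate is still untrained at this point).\n", "x = ltn.Variable(\"x\", df_test.to_numpy()[:5])\n", "print(p([x, class_A]).tensor)  # truth degrees in [0,1], one per sample" ] },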
{ "cell_type": "markdown", "metadata": {}, "source": [ "Run one batch to initialize all layers and build the static graph" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial sat level 0.25581\n" ] } ], "source": [ "for features, labels in ds_test:\n", "    print(\"Initial sat level %.5f\"%axioms(features,labels))\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training\n", "\n", "Define the metrics. While training, we measure:\n", "1. The satisfiability level of the Knowledge Base on the training data.\n", "2. The satisfiability level of the Knowledge Base on the test data.\n", "3. The training accuracy.\n", "4. The test accuracy." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "metrics_dict = {\n", "    'train_sat_kb': tf.keras.metrics.Mean(name='train_sat_kb'),\n", "    'test_sat_kb': tf.keras.metrics.Mean(name='test_sat_kb'),\n", "    'train_accuracy': tf.keras.metrics.CategoricalAccuracy(name=\"train_accuracy\"),\n", "    'test_accuracy': tf.keras.metrics.CategoricalAccuracy(name=\"test_accuracy\")\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the training and test steps" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n", "@tf.function\n", "def train_step(features, labels):\n", "    # compute satisfiability and update the weights to maximize it\n", "    with tf.GradientTape() as tape:\n", "        sat = axioms(features, labels, training=True)\n", "        loss = 1.-sat\n", "    gradients = tape.gradient(loss, p.trainable_variables)\n", "    optimizer.apply_gradients(zip(gradients, p.trainable_variables))\n", "    sat = axioms(features, labels) # compute sat without dropout\n", "    metrics_dict['train_sat_kb'](sat)\n", "    # accuracy\n", "    predictions = logits_model([features])\n", "    metrics_dict['train_accuracy'](tf.one_hot(labels,3),predictions)\n", "\n", "@tf.function\n", "def test_step(features, labels):\n", "    # sat\n", "    sat = axioms(features, labels)\n", "    metrics_dict['test_sat_kb'](sat)\n", "    # accuracy\n", "    predictions = logits_model([features])\n", "    metrics_dict['test_accuracy'](tf.one_hot(labels,3),predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 500 epochs; metrics are reported every 20 epochs and logged to `iris_results.csv`." ] },
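{ "cell_type": "markdown", "metadata": {}, "source": [ "`commons` is a small utility module distributed alongside the LTN example notebooks. If it is unavailable, a loop along the following lines reproduces the printed metrics. This is only a minimal sketch inferred from the call signature and the logs below, not the actual implementation; the real helper also writes the tracked metrics to `csv_path`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal stand-in for the `commons.train` helper used below -- a sketch, not\n", "# the actual implementation (the real helper also writes metrics to a CSV file).\n", "def train_sketch(epochs, metrics_dict, ds_train, ds_test, train_step, test_step, track_metrics=20):\n", "    for epoch in range(epochs):\n", "        for metric in metrics_dict.values():\n", "            metric.reset_states()  # metrics accumulate over one epoch\n", "        for features, labels in ds_train:\n", "            train_step(features, labels)\n", "        for features, labels in ds_test:\n", "            test_step(features, labels)\n", "        if epoch % track_metrics == 0:\n", "            logs = \", \".join(\"%s: %.4f\" % (k, float(m.result())) for k, m in metrics_dict.items())\n", "            print(\"Epoch %d, %s\" % (epoch, logs))" ] },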
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, train_sat_kb: 0.2620, test_sat_kb: 0.2640, train_accuracy: 0.3000, test_accuracy: 0.4667\n", "Epoch 20, train_sat_kb: 0.4088, test_sat_kb: 0.4085, train_accuracy: 0.7333, test_accuracy: 0.5667\n", "Epoch 40, train_sat_kb: 0.5422, test_sat_kb: 0.5404, train_accuracy: 0.9417, test_accuracy: 0.9000\n", "Epoch 60, train_sat_kb: 0.6432, test_sat_kb: 0.6381, train_accuracy: 0.9417, test_accuracy: 0.9000\n", "Epoch 80, train_sat_kb: 0.7105, test_sat_kb: 0.7041, train_accuracy: 0.9583, test_accuracy: 0.9000\n", "Epoch 100, train_sat_kb: 0.7486, test_sat_kb: 0.7443, train_accuracy: 0.9667, test_accuracy: 0.9333\n", "Epoch 120, train_sat_kb: 0.7888, test_sat_kb: 0.7884, train_accuracy: 0.9667, test_accuracy: 0.9667\n", "Epoch 140, train_sat_kb: 0.8182, test_sat_kb: 0.8197, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 160, train_sat_kb: 0.8356, test_sat_kb: 0.8374, train_accuracy: 0.9750, test_accuracy: 1.0000\n", "Epoch 180, train_sat_kb: 0.8525, test_sat_kb: 0.8457, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 200, train_sat_kb: 0.8561, test_sat_kb: 0.8563, train_accuracy: 0.9833, test_accuracy: 0.9667\n", "Epoch 220, train_sat_kb: 0.8706, test_sat_kb: 0.8541, train_accuracy: 0.9833, test_accuracy: 0.9667\n", "Epoch 240, train_sat_kb: 0.8739, test_sat_kb: 0.8587, train_accuracy: 0.9833, test_accuracy: 0.9667\n", "Epoch 260, train_sat_kb: 0.8694, test_sat_kb: 0.8635, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 280, train_sat_kb: 0.8709, test_sat_kb: 0.8625, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 300, train_sat_kb: 0.8782, test_sat_kb: 0.8429, train_accuracy: 0.9833, test_accuracy: 0.9667\n", "Epoch 320, train_sat_kb: 0.8780, test_sat_kb: 0.8387, train_accuracy: 0.9833, test_accuracy: 0.9667\n", "Epoch 340, train_sat_kb: 0.8791, test_sat_kb: 0.8614, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 360, train_sat_kb: 0.8880, test_sat_kb: 0.8497, train_accuracy: 0.9833, test_accuracy: 0.9333\n", "Epoch 380, train_sat_kb: 0.8894, test_sat_kb: 0.8541, train_accuracy: 0.9750, test_accuracy: 0.9333\n", "Epoch 400, train_sat_kb: 0.8870, test_sat_kb: 0.8401, train_accuracy: 0.9917, test_accuracy: 0.9667\n", "Epoch 420, train_sat_kb: 0.8894, test_sat_kb: 0.8402, train_accuracy: 0.9917, test_accuracy: 0.9667\n", "Epoch 440, train_sat_kb: 0.8912, test_sat_kb: 0.8557, train_accuracy: 0.9750, test_accuracy: 0.9667\n", "Epoch 460, train_sat_kb: 0.8953, test_sat_kb: 0.8519, train_accuracy: 0.9750, test_accuracy: 0.9333\n", "Epoch 480, train_sat_kb: 0.8810, test_sat_kb: 0.8593, train_accuracy: 0.9750, test_accuracy: 0.9667\n" ] } ], "source": [ "import commons\n", "\n", "EPOCHS = 500\n", "\n", "commons.train(\n", "    EPOCHS,\n", "    metrics_dict,\n", "    ds_train,\n", "    ds_test,\n", "    train_step,\n", "    test_step,\n", "    csv_path=\"iris_results.csv\",\n", "    track_metrics=20\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "889985fd10eb245a43f2ae5f5aa0c555254f5b898fe16071f1c89d06fa8d76a2" }, "kernelspec": { "display_name": "Python 3.9.6 64-bit ('tf-py39': conda)", "language": "python", "name": "python396jvsc74a57bd0889985fd10eb245a43f2ae5f5aa0c555254f5b898fe16071f1c89d06fa8d76a2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }