{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Class Multi-Label classification\n", "\n", "We now turn to multi-label classification, in which several labels can be assigned to each example. \n", "As a first example of the reach of LTNs, we show how the previous example extends naturally to multiple labels, an extension that is not always trivial for ML algorithms. \n", "\n", "The standard approach to the multi-label problem is to provide explicit negative examples for each class. By contrast, LTN can use background knowledge to relate classes directly to each other, which makes it a powerful tool for multi-label problems, where labelled data is typically scarce.\n", "\n", "We use the *Leptograpsus crabs* data set, which consists of 200 examples with 5 morphological measurements each: 50 crabs for every combination of colour and sex. The task is to classify the crabs according to their colour and sex. There are four labels: blue, orange, male and female. \n", "\n", "The colour labels are mutually exclusive, and so are the labels for sex. LTN will be used to specify such information logically.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging; logging.basicConfig(level=logging.INFO)\n", "import tensorflow as tf\n", "import ltn\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data\n", "\n", "Crabs dataset from: http://www.stats.ox.ac.uk/pub/PRNN/\n", "\n", "The crabs data frame has 200 rows and 8 columns, describing 5 morphological measurements on 50 crabs each of two colour forms and both sexes, of the species Leptograpsus variegatus collected at Fremantle, W. 
Australia.\n", "\n", "- Multi-class: Male, Female, Blue, Orange.\n", "- Multi-label: Only Male-Female and Blue-Orange are mutually exclusive.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " sp sex index FL RW CL CW BD\n", "28 B M 29 15.9 12.7 34.0 38.9 14.2\n", "11 B M 12 12.3 11.0 26.8 31.5 11.4\n", "72 B F 23 12.8 12.2 27.9 31.9 11.5\n", "190 O F 41 20.3 16.0 39.4 44.1 18.0\n", "70 B F 21 12.8 11.7 27.1 31.2 11.9\n" ] } ], "source": [ "df = pd.read_csv(\"crabs.dat\",sep=\" \", skipinitialspace=True)\n", "df = df.sample(frac=1) #shuffle\n", "print(df.head(5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use 160 samples for training and 40 samples for testing." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metal device set to: Apple M1\n", "\n", "systemMemory: 16.00 GB\n", "maxCacheSize: 5.33 GB\n", "\n" ] } ], "source": [ "features = df[['FL','RW','CL','CW','BD']]\n", "labels_sex = df['sex']\n", "labels_color = df['sp']\n", "\n", "batch_size=64\n", "# Slice by position with .iloc: the index was shuffled above, and plain [i:j] slicing on a Series with an integer index is deprecated.\n", "ds_train = tf.data.Dataset.from_tensor_slices((features.iloc[:160],labels_sex.iloc[:160],labels_color.iloc[:160])).batch(batch_size)\n", "ds_test = tf.data.Dataset.from_tensor_slices((features.iloc[160:],labels_sex.iloc[160:],labels_color.iloc[160:])).batch(batch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# LTN\n", "\n", "### Predicate\n", "\n", "| index | class | \n", "| --- | --- |\n", "| 0 | Male |\n", "| 1 | Female |\n", "| 2 | Blue |\n", "| 3 | Orange |\n", "\n", "Note that, since the classes are not mutually exclusive, the last layer of the model is a `sigmoid` and not a `softmax`."
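, "\n", "\n", "To illustrate the difference (a worked example with made-up logits, not values produced by this model), write $z_i$ for the logit of class $i = 0, \\dots, 3$, indexed as in the table above:\n", "\n", "$$\\mathrm{sigmoid}(z_i) = \\frac{1}{1+e^{-z_i}}, \\qquad \\mathrm{softmax}(z)_i = \\frac{e^{z_i}}{\\sum_{j=0}^{3} e^{z_j}}$$\n", "\n", "For hypothetical logits $z = (2, -2, 2, -2)$, the sigmoid gives approximately $(0.88, 0.12, 0.88, 0.12)$, so both Male and Blue can score high at the same time. The softmax gives approximately $(0.49, 0.01, 0.49, 0.01)$: its outputs must sum to 1, so two labels that both apply have to share the probability mass and neither exceeds $0.5$."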
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class MLP(tf.keras.Model):\n", "    \"\"\"Model that returns logits.\"\"\"\n", "    def __init__(self, n_classes, hidden_layer_sizes=(16,16,8)):\n", "        super(MLP, self).__init__()\n", "        self.denses = [tf.keras.layers.Dense(s, activation=\"elu\") for s in hidden_layer_sizes]\n", "        self.dense_class = tf.keras.layers.Dense(n_classes)\n", "\n", "    def call(self, inputs):\n", "        x = inputs[0]\n", "        for dense in self.denses:\n", "            x = dense(x)\n", "        return self.dense_class(x)\n", "\n", "logits_model = MLP(4)\n", "p = ltn.Predicate.FromLogits(logits_model, activation_function=\"sigmoid\", with_class_indexing=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Constants to index the classes." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class_male = ltn.Constant(0, trainable=False)\n", "class_female = ltn.Constant(1, trainable=False)\n", "class_blue = ltn.Constant(2, trainable=False)\n", "class_orange = ltn.Constant(3, trainable=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Axioms\n", "\n", "```\n", "forall x_blue: p(x_blue,blue)\n", "forall x_orange: p(x_orange,orange)\n", "forall x_male: p(x_male,male)\n", "forall x_female: p(x_female,female)\n", "forall x: ~(p(x,male) & p(x,female))\n", "forall x: ~(p(x,blue) & p(x,orange))\n", "```" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())\n", "And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())\n", "Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())\n", "Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())\n", "Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(p=2),semantics=\"forall\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "formula_aggregator = 
ltn.Wrapper_Formula_Aggregator(ltn.fuzzy_ops.Aggreg_pMeanError(p=2))\n", "\n", "@tf.function\n", "def axioms(features,labels_sex,labels_color):\n", "    # one variable per label, each with its own name\n", "    x = ltn.Variable(\"x\",features)\n", "    x_blue = ltn.Variable(\"x_blue\",features[labels_color==\"B\"])\n", "    x_orange = ltn.Variable(\"x_orange\",features[labels_color==\"O\"])\n", "    x_male = ltn.Variable(\"x_male\",features[labels_sex==\"M\"])\n", "    x_female = ltn.Variable(\"x_female\",features[labels_sex==\"F\"])\n", "    axioms = [\n", "        Forall(x_blue, p([x_blue,class_blue])),\n", "        Forall(x_orange, p([x_orange,class_orange])),\n", "        Forall(x_male, p([x_male,class_male])),\n", "        Forall(x_female, p([x_female,class_female])),\n", "        Forall(x,Not(And(p([x,class_blue]),p([x,class_orange])))),\n", "        Forall(x,Not(And(p([x,class_male]),p([x,class_female]))))\n", "    ]\n", "    sat_level = formula_aggregator(axioms).tensor\n", "    return sat_level" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize all layers and the static graph." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:35:18.622522: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n", "2023-03-30 16:35:18.624912: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Initial sat level 0.52923\n" ] } ], "source": [ "for features, labels_sex, labels_color in ds_train:\n", "    print(\"Initial sat level %.5f\"%axioms(features,labels_sex,labels_color))\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training\n", "\n", "Define the metrics. While training, we measure:\n", " 1. The level of satisfiability of the Knowledge Base on the training data.\n", " 2. The level of satisfiability of the Knowledge Base on the test data.\n", " 3. 
The training accuracy.\n", " 4. The test accuracy.\n", " 5. The level of satisfiability of a formula $\\phi_1$ we expect to be true. \n", "    `forall x (p(x,blue)->~p(x,orange))` (no blue crab is orange, and vice versa)\n", " 6. The level of satisfiability of a formula $\\phi_2$ we expect to be false.\n", "    `forall x (p(x,blue)->p(x,orange))` (every blue crab is also orange)\n", " 7. The level of satisfiability of a formula $\\phi_3$ we expect to be false. \n", "    `forall x (p(x,blue)->p(x,male))` (every blue crab is male)\n", "\n", "\n", "For the last 3 queries, we use $p=5$ when approximating the universal quantifier. A higher $p$ denotes a stricter universal quantification with a stronger focus on outliers (see the tutorial on operators for more details).\n", "Training should usually not focus on outliers, as optimizers would struggle to generalize and tend to get stuck in local minima. However, when querying $\\phi_1$, $\\phi_2$, $\\phi_3$, we wish to be more careful about the interpretation of our statement."
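, "\n", "\n", "As a worked illustration of the effect of $p$ (the truth values below are hypothetical, chosen only for the arithmetic), recall that `Aggreg_pMeanError` aggregates truth values $u_1,\\dots,u_n$ through a generalized mean of their errors, following the definition in the LTN paper:\n", "\n", "$$A_{pME}(u_1,\\dots,u_n) = 1 - \\left(\\frac{1}{n}\\sum_{i=1}^{n}(1-u_i)^p\\right)^{\\frac{1}{p}}$$\n", "\n", "For $u = (0.9, 0.9, 0.1)$, the errors are $(0.1, 0.1, 0.9)$. With $p=2$ the aggregate is $1 - \\sqrt{0.83/3} \\approx 0.47$, whereas with $p=5$ it is $1 - (0.59051/3)^{1/5} \\approx 0.28$: the single outlier weighs more heavily as $p$ grows, and in the limit $p \\to \\infty$ the aggregate tends to $\\min_i u_i = 0.1$."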
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "metrics_dict = {\n", "    'train_sat_kb': tf.keras.metrics.Mean(name='train_sat_kb'),\n", "    'test_sat_kb': tf.keras.metrics.Mean(name='test_sat_kb'),\n", "    'train_accuracy': tf.keras.metrics.Mean(name=\"train_accuracy\"),\n", "    'test_accuracy': tf.keras.metrics.Mean(name=\"test_accuracy\"),\n", "    'test_sat_phi1': tf.keras.metrics.Mean(name='test_sat_phi1'),\n", "    'test_sat_phi2': tf.keras.metrics.Mean(name='test_sat_phi2'),\n", "    'test_sat_phi3': tf.keras.metrics.Mean(name='test_sat_phi3')\n", "}\n", "\n", "@tf.function()\n", "def sat_phi1(features):\n", "    x = ltn.Variable(\"x\",features)\n", "    phi1 = Forall(x, Implies(p([x,class_blue]),Not(p([x,class_orange]))),p=5)\n", "    return phi1.tensor\n", "\n", "@tf.function()\n", "def sat_phi2(features):\n", "    x = ltn.Variable(\"x\",features)\n", "    phi2 = Forall(x, Implies(p([x,class_blue]),p([x,class_orange])),p=5)\n", "    return phi2.tensor\n", "\n", "@tf.function()\n", "def sat_phi3(features):\n", "    x = ltn.Variable(\"x\",features)\n", "    phi3 = Forall(x, Implies(p([x,class_blue]),p([x,class_male])),p=5)\n", "    return phi3.tensor\n", "\n", "def multilabel_hamming_loss(y_true, y_pred, threshold=0.5,from_logits=False):\n", "    if from_logits:\n", "        y_pred = tf.math.sigmoid(y_pred)\n", "    y_pred = y_pred > threshold\n", "    y_true = tf.cast(y_true, tf.int32)\n", "    y_pred = tf.cast(y_pred, tf.int32)\n", "    nonzero = tf.cast(tf.math.count_nonzero(y_true-y_pred,axis=-1),tf.float32)\n", "    return nonzero/y_true.get_shape()[-1]\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n", "@tf.function\n", "def train_step(features, labels_sex, labels_color):\n", "    # sat and update\n", "    with tf.GradientTape() as tape:\n", "        sat = axioms(features, labels_sex, labels_color)\n", "        loss = 1.-sat\n", "    gradients = tape.gradient(loss, 
p.trainable_variables)\n", "    optimizer.apply_gradients(zip(gradients, p.trainable_variables))\n", "    metrics_dict['train_sat_kb'](sat)\n", "    # accuracy\n", "    predictions = logits_model([features])\n", "    labels_male = (labels_sex == \"M\")\n", "    labels_female = (labels_sex == \"F\")\n", "    labels_blue = (labels_color == \"B\")\n", "    labels_orange = (labels_color == \"O\")\n", "    onehot = tf.stack([labels_male,labels_female,labels_blue,labels_orange],axis=-1)\n", "    metrics_dict['train_accuracy'](1-multilabel_hamming_loss(onehot,predictions,from_logits=True))\n", "\n", "@tf.function\n", "def test_step(features, labels_sex, labels_color):\n", "    # sat\n", "    sat_kb = axioms(features, labels_sex, labels_color)\n", "    metrics_dict['test_sat_kb'](sat_kb)\n", "    metrics_dict['test_sat_phi1'](sat_phi1(features))\n", "    metrics_dict['test_sat_phi2'](sat_phi2(features))\n", "    metrics_dict['test_sat_phi3'](sat_phi3(features))\n", "    # accuracy\n", "    predictions = logits_model([features])\n", "    labels_male = (labels_sex == \"M\")\n", "    labels_female = (labels_sex == \"F\")\n", "    labels_blue = (labels_color == \"B\")\n", "    labels_orange = (labels_color == \"O\")\n", "    onehot = tf.stack([labels_male,labels_female,labels_blue,labels_orange],axis=-1)\n", "    metrics_dict['test_accuracy'](1-multilabel_hamming_loss(onehot,predictions,from_logits=True))\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:35:37.988464: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:35:39.606891: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:35:41.016587: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Epoch 0, train_sat_kb: 0.5375, test_sat_kb: 0.5510, train_accuracy: 0.4875, test_accuracy: 0.5000, test_sat_phi1: 0.7549, test_sat_phi2: 0.2181, test_sat_phi3: 0.5221\n", "Epoch 20, train_sat_kb: 0.6129, test_sat_kb: 0.6166, train_accuracy: 0.5641, test_accuracy: 0.5750, test_sat_phi1: 0.5040, test_sat_phi2: 0.5107, test_sat_phi3: 0.6251\n", "Epoch 40, train_sat_kb: 0.6338, test_sat_kb: 0.6346, train_accuracy: 0.5063, test_accuracy: 0.5063, test_sat_phi1: 0.5404, test_sat_phi2: 0.6947, test_sat_phi3: 0.7039\n", "Epoch 60, train_sat_kb: 0.6866, test_sat_kb: 0.6897, train_accuracy: 0.7312, test_accuracy: 0.7375, test_sat_phi1: 0.5891, test_sat_phi2: 0.5913, test_sat_phi3: 0.5883\n", "Epoch 80, train_sat_kb: 0.7761, test_sat_kb: 0.7689, train_accuracy: 0.9250, test_accuracy: 0.9187, test_sat_phi1: 0.6952, test_sat_phi2: 0.4343, test_sat_phi3: 0.4886\n", "Epoch 100, train_sat_kb: 0.8385, test_sat_kb: 0.8230, train_accuracy: 0.9656, test_accuracy: 0.9500, test_sat_phi1: 0.7713, test_sat_phi2: 0.3052, test_sat_phi3: 0.4203\n", "Epoch 120, train_sat_kb: 0.8728, test_sat_kb: 0.8543, train_accuracy: 0.9750, test_accuracy: 0.9500, test_sat_phi1: 0.8096, test_sat_phi2: 0.2400, test_sat_phi3: 0.3618\n", "Epoch 140, train_sat_kb: 0.8915, test_sat_kb: 0.8675, train_accuracy: 0.9797, test_accuracy: 0.9563, test_sat_phi1: 0.8240, test_sat_phi2: 0.2111, test_sat_phi3: 0.3272\n", "Epoch 160, train_sat_kb: 0.9037, test_sat_kb: 0.8731, train_accuracy: 0.9812, test_accuracy: 0.9625, test_sat_phi1: 0.8311, test_sat_phi2: 0.1953, test_sat_phi3: 0.3067\n", "Epoch 180, train_sat_kb: 0.9130, test_sat_kb: 0.8758, train_accuracy: 0.9875, test_accuracy: 0.9688, test_sat_phi1: 0.8361, test_sat_phi2: 0.1851, test_sat_phi3: 0.2934\n" ] } ], "source": [ "import commons\n", "\n", "EPOCHS = 200\n", "\n", "commons.train(\n", " EPOCHS,\n", " metrics_dict,\n", " ds_train,\n", " ds_test,\n", " train_step,\n", " test_step,\n", " csv_path=\"crabs_results.csv\",\n", " 
track_metrics=20\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "889985fd10eb245a43f2ae5f5aa0c555254f5b898fe16071f1c89d06fa8d76a2" }, "kernelspec": { "display_name": "ltn", "language": "python", "name": "ltn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }