{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST Digit Addition Problem\n", "\n", "Consider the task of learning a classifier $\\mathtt{addition(X,Y,N)}$, where $\\mathtt{X}$ and $\\mathtt{Y}$ are images of digits (taken from the MNIST data set) and $\\mathtt{N}$ is a natural number corresponding to the sum of these digits. The classifier should return an estimate of the validity of the addition ($0$ is invalid, $1$ is valid).\n", "\n", "For instance, if $\\mathtt{X}$ is an image of a 0 and $\\mathtt{Y}$ is an image of a 9:\n", "- if $\\mathtt{N} = 9$, then the addition is valid;\n", "- if $\\mathtt{N} = 4$, then the addition is not valid.\n", "\n", "A natural approach is to 1) learn a single-digit classifier, then 2) exploit readily available knowledge about the properties of addition.\n", "For instance, if a predicate $\\mathrm{digit}(x,d)$ gives the likelihood that an image $x$ shows the digit $d$, one can query with LTN:\n", "$$\n", "\\exists d_1,d_2 : d_1+d_2= \\mathtt{N} \\ (\\mathrm{digit}(\\mathtt{X},d_1)\\land \\mathrm{digit}(\\mathtt{Y},d_2))\n", "$$\n", "and use the satisfaction level of this query as the output of $\\mathtt{addition(X,Y,N)}$.\n", "\n", "The challenge is the following:\n", "- We provide, in the data, pairs of images $\\mathtt{X}$, $\\mathtt{Y}$ and the result of the addition $\\mathtt{N}$ (final label).\n", "- We do **not** provide the intermediate labels, i.e., the correct digits $d_1$, $d_2$.\n", "\n", "Even so, the formula above can be used as background knowledge to train $\\mathrm{digit}$ with LTN.\n", "In contrast, a standard neural network baseline, trained end to end on the final label only, cannot incorporate such intermediate components as naturally." 
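] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The query above can be evaluated by hand on toy inputs. The next cell is a NumPy-only sketch, not the LTN API. It makes three simplifying choices. First, it assumes the operators chosen in the LTN section below: the product t-norm for $\\land$, the generalised mean for $\\exists$, and one minus the generalised mean of the errors for $\\forall$. Second, it ignores any numerical-stability adjustments the library may apply. Third, its helper names (`p_mean`, `p_mean_error`, `addition_sat`) are hypothetical. The sketch shows how the guard $d_1+d_2=\\mathtt{N}$ restricts the existential aggregation to compatible digit pairs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def p_mean(values, p):\n", "    # Generalised mean, assumed here as the semantics of 'exists' (cf. Aggreg_pMean)\n", "    return np.mean(values ** p) ** (1.0 / p)\n", "\n", "def p_mean_error(values, p):\n", "    # One minus the generalised mean of the errors, assumed as 'forall' (cf. Aggreg_pMeanError)\n", "    return 1.0 - np.mean((1.0 - values) ** p) ** (1.0 / p)\n", "\n", "def addition_sat(probs_x, probs_y, n, p=2.0):\n", "    # Product t-norm for every candidate pair (d1, d2): a 10x10 table of truth values\n", "    conj = np.outer(probs_x, probs_y)\n", "    # Guard: keep only the pairs whose sum equals n, then aggregate them with 'exists'\n", "    d1, d2 = np.meshgrid(np.arange(10), np.arange(10), indexing='ij')\n", "    return p_mean(conj[d1 + d2 == n], p)\n", "\n", "# Toy, perfectly confident predictions: X shows a 0 and Y shows a 9\n", "probs_x = np.eye(10)[0]\n", "probs_y = np.eye(10)[9]\n", "sat_valid = addition_sat(probs_x, probs_y, 9)    # guard keeps 10 pairs, one of them true\n", "sat_invalid = addition_sat(probs_x, probs_y, 4)  # guard keeps 5 pairs, none of them true\n", "# 'forall' over a batch of these two examples, with p=2 as in the axioms cell\n", "sat_batch = p_mean_error(np.array([sat_valid, sat_invalid]), p=2.0)\n", "print(sat_valid, sat_invalid, sat_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With perfectly confident predictions and $p=2$, `sat_valid` is $\\sqrt{0.1}\\approx 0.32$, because the single true pair is averaged with nine false ones, while `sat_invalid` is $0$. Raising $p$ moves the generalised mean towards a maximum ($p=6$ gives $0.1^{1/6}\\approx 0.68$). This is one way to read the increasing `p_schedule` used during training." 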
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import ltn\n", "import baselines, data, commons\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "Each example consists of two digit images X and Y and a label Z, the sum of the digits they show (X+Y=Z). The dataset yields only the images and Z, not the individual digit labels." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metal device set to: Apple M1\n", "\n", "systemMemory: 16.00 GB\n", "maxCacheSize: 5.33 GB\n", "\n", "Result label is 2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:28.884429: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n", "2023-03-30 16:40:28.884615: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: )\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ds_train, ds_test = data.get_mnist_op_dataset(\n", " count_train=3000,\n", " count_test=1000,\n", " buffer_size=3000,\n", " batch_size=16,\n", " n_operands=2,\n", " op=lambda args: args[0]+args[1])\n", "\n", "# Visualize one example\n", "x, y, z = next(ds_train.as_numpy_iterator())\n", "plt.subplot(121)\n", "plt.imshow(x[0][:,:,0])\n", "plt.subplot(122)\n", "plt.imshow(y[0][:,:,0])\n", "print(\"Result label is %i\" % z[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LTN" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Single-digit classifier: Digit(x, d) is the softmax probability that image x shows digit d\n", "logits_model = baselines.SingleDigit(inputs_as_a_list=True)\n", "Digit = ltn.Predicate.FromLogits(logits_model, activation_function=\"softmax\")\n", "\n", "# d1 and d2 range over the ten digit classes\n", "d1 = ltn.Variable(\"digits1\", range(10))\n", "d2 = ltn.Variable(\"digits2\", range(10))\n", "\n", "# Fuzzy semantics of the connectives and quantifiers\n", "Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())\n", "And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())\n", "Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())\n", "Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())\n", "Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(),semantics=\"forall\")\n", "Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(),semantics=\"exists\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice the use of `ltn.diag`: $x$, $y$ and $z$ are grounded with three sequences of values whose $i$-th elements belong together.\n", "That is, `(images_x[i],images_y[i],labels_z[i])` is a tuple from our dataset of valid additions.\n", "With diagonal quantification, LTN aggregates only these matching triples of images and results, rather than every combination of images and results. 
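\n", "\n", "As a toy illustration, the next cell contrasts the two groundings with plain NumPy rather than LTN. It assumes a hypothetical batch of 3 examples and a random 3x3x3 tensor `truth`. That tensor stands for the truth values of some formula over every combination $(x_i, y_j, z_k)$. This mirrors how LTN combines distinct variables by default, with one axis per variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "batch_size = 3\n", "rng = np.random.default_rng(0)\n", "# Hypothetical truth values phi(x_i, y_j, z_k): one axis per variable\n", "truth = rng.uniform(size=(batch_size, batch_size, batch_size))\n", "\n", "# Without diag: every combination (i, j, k) would be aggregated\n", "all_combinations = truth.reshape(-1)\n", "# With diag: only the matching triples (i, i, i) are aggregated\n", "idx = np.arange(batch_size)\n", "diagonal = truth[idx, idx, idx]\n", "print(all_combinations.size, diagonal.size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `batch_size=16`, as in the data cell, diagonal quantification therefore aggregates 16 triples per batch instead of $16^3 = 4096$ combinations.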
\n", " \n", "Notice also the guarded quantification: by quantifying only over the \"intermediate labels\" (not given during training) that could add up to the result label (given during training), we inject symbolic knowledge about addition into the system." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:38.033631: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n", "2023-03-30 16:40:38.035507: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Guard (mask): keep only the digit pairs (d1, d2) with d1 + d2 == z\n", "add = ltn.Function.Lambda(lambda inputs: inputs[0]+inputs[1])\n", "equals = ltn.Predicate.Lambda(lambda inputs: inputs[0] == inputs[1])\n", "\n", "# Axiom: for each (x, y, z) triple of the batch, some pair with d1 + d2 == z satisfies Digit(x, d1) and Digit(y, d2)\n", "# p_schedule is the exponent of the existential p-mean; the universal quantifier uses p=2\n", "@tf.function\n", "def axioms(images_x, images_y, labels_z, p_schedule=tf.constant(2.)):\n", " images_x = ltn.Variable(\"x\", images_x)\n", " images_y = ltn.Variable(\"y\", images_y)\n", " labels_z = ltn.Variable(\"z\", labels_z)\n", " axiom = Forall(\n", " ltn.diag(images_x,images_y,labels_z),\n", " Exists(\n", " (d1,d2),\n", " And(Digit([images_x,d1]),Digit([images_y,d2])),\n", " mask=equals([add([d1,d2]), labels_z]),\n", " p=p_schedule\n", " ),\n", " p=2\n", " )\n", " sat = axiom.tensor\n", " return sat\n", "\n", "# Evaluate the axiom on one batch\n", "images_x, images_y, labels_z = next(ds_train.as_numpy_iterator())\n", "axioms(images_x, images_y, labels_z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizer, training steps and metrics.\n", "\n", "The loss is $1-\\mathrm{sat}$, where $\\mathrm{sat}$ is the truth value of the axiom on a batch. The reported accuracy checks whether the digits predicted for $x$ and $y$ add up to the label $z$; the individual digit labels are never used." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(0.001)\n", "metrics_dict = {\n", " 'train_loss': tf.keras.metrics.Mean(name=\"train_loss\"),\n", " 'train_accuracy': 
tf.keras.metrics.Mean(name=\"train_accuracy\"),\n", " 'test_loss': tf.keras.metrics.Mean(name=\"test_loss\"),\n", " 'test_accuracy': tf.keras.metrics.Mean(name=\"test_accuracy\")\n", "}\n", "\n", "@tf.function\n", "def train_step(images_x, images_y, labels_z, **parameters):\n", " # loss = 1 - satisfaction of the axiom\n", " with tf.GradientTape() as tape:\n", " loss = 1.- axioms(images_x, images_y, labels_z, **parameters)\n", " gradients = tape.gradient(loss, logits_model.trainable_variables)\n", " optimizer.apply_gradients(zip(gradients, logits_model.trainable_variables))\n", " metrics_dict['train_loss'](loss)\n", " # accuracy: does the sum of the predicted digits match the label z?\n", " predictions_x = tf.argmax(logits_model([images_x]),axis=-1)\n", " predictions_y = tf.argmax(logits_model([images_y]),axis=-1)\n", " predictions_z = predictions_x + predictions_y\n", " match = tf.equal(predictions_z,tf.cast(labels_z,predictions_z.dtype))\n", " metrics_dict['train_accuracy'](tf.reduce_mean(tf.cast(match,tf.float32)))\n", " \n", "@tf.function\n", "def test_step(images_x, images_y, labels_z, **parameters):\n", " # loss = 1 - satisfaction of the axiom\n", " loss = 1.- axioms(images_x, images_y, labels_z, **parameters)\n", " metrics_dict['test_loss'](loss)\n", " # accuracy: does the sum of the predicted digits match the label z?\n", " predictions_x = tf.argmax(logits_model([images_x]),axis=-1)\n", " predictions_y = tf.argmax(logits_model([images_y]),axis=-1)\n", " predictions_z = predictions_x + predictions_y\n", " match = tf.equal(predictions_z,tf.cast(labels_z,predictions_z.dtype))\n", " metrics_dict['test_accuracy'](tf.reduce_mean(tf.cast(match,tf.float32)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training\n", "\n", "The exponent `p_schedule` of the existential aggregator increases over the epochs: $p=1$ for epochs 0 to 3, $p=2$ for 4 to 7, $p=4$ for 8 to 11 and $p=6$ for 12 to 19. With $p=1$ the generalised mean is the arithmetic mean, so every digit pair allowed by the guard contributes equally to the gradient. Larger values of $p$ bring it closer to a maximum, which focuses on the most likely pair. Because the loss is computed with the current $p$, loss values are only comparable between epochs that share the same $p$. This accounts for the drops in the loss at epochs 4, 8 and 12 in the log below." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "\n", "scheduled_parameters = defaultdict(lambda: {})\n", "for epoch in range(0,4):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(1.)}\n", "for epoch in range(4,8):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(2.)}\n", "for epoch in 
range(8,12):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(4.)}\n", "for epoch in range(12,20):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(6.)}\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:54.797507: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:06.041586: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:06.798203: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:08.252734: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, train_loss: 0.9356, train_accuracy: 0.3793, test_loss: 0.8768, test_accuracy: 0.6567\n", "Epoch 1, train_loss: 0.8566, train_accuracy: 0.8321, test_loss: 0.8480, test_accuracy: 0.8403\n", "Epoch 2, train_loss: 0.8443, train_accuracy: 0.8989, test_loss: 0.8430, test_accuracy: 0.8562\n", "Epoch 3, train_loss: 0.8383, train_accuracy: 0.9249, test_loss: 0.8442, test_accuracy: 0.8492\n", "Epoch 4, train_loss: 0.6416, train_accuracy: 0.9279, test_loss: 0.6523, test_accuracy: 0.8730\n", "Epoch 5, train_loss: 0.6285, train_accuracy: 0.9458, test_loss: 0.6476, test_accuracy: 0.8859\n", "Epoch 6, train_loss: 0.6237, train_accuracy: 0.9525, test_loss: 0.6345, test_accuracy: 0.9077\n", "Epoch 7, train_loss: 0.6179, train_accuracy: 0.9624, test_loss: 0.6329, test_accuracy: 0.9147\n", "Epoch 8, train_loss: 0.4284, train_accuracy: 0.9525, test_loss: 0.4674, test_accuracy: 0.8889\n", "Epoch 9, train_loss: 0.4167, train_accuracy: 0.9618, test_loss: 0.4652, test_accuracy: 
0.8929\n", "Epoch 10, train_loss: 0.4127, train_accuracy: 0.9651, test_loss: 0.4669, test_accuracy: 0.8869\n", "Epoch 11, train_loss: 0.4054, train_accuracy: 0.9697, test_loss: 0.4590, test_accuracy: 0.9028\n", "Epoch 12, train_loss: 0.3186, train_accuracy: 0.9688, test_loss: 0.3705, test_accuracy: 0.9147\n", "Epoch 13, train_loss: 0.3156, train_accuracy: 0.9697, test_loss: 0.3740, test_accuracy: 0.9167\n", "Epoch 14, train_loss: 0.3204, train_accuracy: 0.9664, test_loss: 0.4056, test_accuracy: 0.8869\n", "Epoch 15, train_loss: 0.3228, train_accuracy: 0.9661, test_loss: 0.3548, test_accuracy: 0.9286\n", "Epoch 16, train_loss: 0.3112, train_accuracy: 0.9731, test_loss: 0.3741, test_accuracy: 0.9127\n", "Epoch 17, train_loss: 0.3071, train_accuracy: 0.9747, test_loss: 0.3577, test_accuracy: 0.9276\n", "Epoch 18, train_loss: 0.3042, train_accuracy: 0.9771, test_loss: 0.3682, test_accuracy: 0.9167\n", "Epoch 19, train_loss: 0.2983, train_accuracy: 0.9801, test_loss: 0.3588, test_accuracy: 0.9246\n" ] } ], "source": [ "commons.train(\n", " 20,\n", " metrics_dict,\n", " ds_train,\n", " ds_test,\n", " train_step,\n", " test_step,\n", " scheduled_parameters=scheduled_parameters\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "12eaedf9b9a64329743e8900a3192e3d75dbaaa78715534825922e4a4f7d9137" }, "kernelspec": { "display_name": "ltn", "language": "python", "name": "ltn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }