{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST Digit Addition Problem\n", "\n", "Consider a task where one needs to learn a classifier $\\mathtt{addition(X,Y,N)}$ where $\\mathtt{X}$ and $\\mathtt{Y}$ are images of digits (the MNIST data set will be used), and $\\mathtt{N}$ is a natural number corresponding to the sum of these digits. The classifier should return an estimate of the validity of the addition ($0$ is invalid, $1$ is valid). \n", "\n", "For instance, if $\\mathtt{X}$ is an image of a 0 and $\\mathtt{Y}$ is an image of a 9:\n", "- if $\\mathtt{N} = 9$, then the addition is valid; \n", "- if $\\mathtt{N} = 4$, then the addition is not valid. \n", "\n", "A natural approach is to 1) learn a single-digit classifier, then 2) exploit the knowledge readily available about the properties of addition.\n", "For instance, given a predicate $\\mathrm{digit}(x,d)$ that gives the likelihood of an image $x$ being of digit $d$, one could query with LTN: \n", "$$\n", "\\exists d_1,d_2 : d_1+d_2= \\mathtt{N} \\ (\\mathrm{digit}(\\mathtt{X},d_1)\\land \\mathrm{digit}(\\mathtt{Y},d_2))\n", "$$\n", "and use the satisfaction of this query as the output of $\\mathtt{addition(X,Y,N)}$.\n", "\n", "\n", "The challenge is the following:\n", "- We provide, in the data, pairs of images $\\mathtt{X}$, $\\mathtt{Y}$ and the result of the addition $\\mathtt{N}$ (final label),\n", "- We do **not** provide the intermediate labels, i.e. the correct digits for $d_1$, $d_2$.\n", "\n", "Nevertheless, it is possible to use the formula above as background knowledge to train $\\mathrm{digit}$ with LTN.\n", "In contrast, a standard neural network baseline cannot incorporate such intermediate components as easily."
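, "\n", "\n", "As a rough sketch of how such a query can be scored, assume that the conjunction is interpreted as a product and the guarded existential quantifier as a generalized mean with exponent $p \\geq 1$. These operator choices are an assumption at this point of the text; they correspond to the `And_Prod` and `Aggreg_pMean` operators configured further down, and $G_{\\mathtt{N}}$ is notation introduced here for illustration only. The satisfaction of the query for one example would then read:\n", "$$\n", "\\mathrm{sat}(\\mathtt{X},\\mathtt{Y},\\mathtt{N}) = \\left( \\frac{1}{|G_{\\mathtt{N}}|} \\sum_{(d_1,d_2) \\in G_{\\mathtt{N}}} \\big( \\mathrm{digit}(\\mathtt{X},d_1) \\cdot \\mathrm{digit}(\\mathtt{Y},d_2) \\big)^{p} \\right)^{1/p}, \\qquad G_{\\mathtt{N}} = \\{ (d_1,d_2) \\in \\{0,\\dots,9\\}^2 : d_1 + d_2 = \\mathtt{N} \\}\n", "$$\n", "For $\\mathtt{N} = 9$, the guard $G_{\\mathtt{N}}$ contains the ten pairs $(0,9), (1,8), \\dots, (9,0)$; for $\\mathtt{N} = 0$, only $(0,0)$ remains."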
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import ltn\n", "import baselines, data, commons\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "Dataset of images for the digits X and Y, and their label Z s.t. X+Y=Z." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metal device set to: Apple M1\n", "\n", "systemMemory: 16.00 GB\n", "maxCacheSize: 5.33 GB\n", "\n", "Result label is 2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:28.884429: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.\n", "2023-03-30 16:40:28.884615: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: )\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAh8AAAEOCAYAAAApP3VyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYFklEQVR4nO3df3DU9Z3H8dcGyBIg2RgwWfYINqJCRwr2EGLEMiA5QpxhADPWH20PnF5taWAGch3bzClW60wKvWupmsI/HdCZIkgrMHA2HRpMGGoSjwjDMNYc4VKJB4mKJhsCWWL2c394brMSvskmm8/uJs/HzHfG/b6/u/v2y/Dmlc9+812XMcYIAADAkqRYNwAAAEYXwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAqrGxbuDLgsGgLly4oNTUVLlcrli3A4xKxhh1dHTI5/MpKSkxfkZhdgCxFdHcMMPkpZdeMrfccotxu91mwYIFpq6ubkDPa25uNpLY2NjiYGtubh6uEdGnwc4NY5gdbGzxsg1kbgzLysfevXtVUlKiHTt2KDc3V9u2bVNBQYEaGhqUmZnp+NzU1FRJ0n16QGM1bjjaA9CPz9St43oj9PfRhqHMDYnZAcRaJHPDZUz0v1guNzdX8+fP10svvSTp8+XQ7OxsbdiwQT/5yU8cn+v3++XxeLRYKzXWxQABYuEz060qHVR7e7vS0tKsvOdQ5obE7ABiLZK5EfUPc69du6b6+nrl5+f//U2SkpSfn6+amprrjg8EAvL7/WEbgNEl0rkhMTuARBb18PHxxx+rp6dHWVlZYfuzsrLU0tJy3fFlZWXyeDyhLTs7O9otAYhzkc4NidkBJLKYX8ZeWlqq9vb20Nbc3BzrlgAkAGYHkLiifsHplClTNGbMGLW2tobtb21tldfrve54t9stt9sd7TYAJJBI54bE7AASWdRXPpKTkzVv3jxVVlaG9gWDQVVWViovLy/abwdgBGBuAKPLsPyqbUlJidasWaO7775bCxYs0LZt29TZ2anHH398ON4OwAjA3ABGj2EJHw8//LA++ugjbd68WS0tLbrrrrtUUVFx3cVkAPAF5gYwegzLfT6Ggt/VB2IvFvf5GCpmBxBbMb3PBwAAgBPCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAq8bGugGgL01leY714C1XHeszHjsVxW4A2HL21/c41v/noR2O9ZyDTzjW71j3dsQ9IfqivvLx05/+VC6XK2ybNWtWtN8GwAjC3ABGl2FZ+bjzzjv15z//+e9vMpYFFgDOmBvA6DEsf7vHjh0rr9c7HC8NYIRibgCjx7BccHr27Fn5fD7deuut+ta3vqXz58/f8NhAICC/3x+2ARh9IpkbErMDSGRRDx+5ubnatWuXKioqtH37djU1Nekb3/iGOjo6+jy+rKxMHo8ntGVnZ0e7JQBxLtK5ITE7gEQW9fBRWFiohx56SHPmzFFBQYHeeOMNtbW16bXXXuvz+NLSUrW3t4e25ubmaLcEIM5FOjckZgeQyIb9iq709HTdcccdamxs7LPudrvldruHuw0ACaS/uSExO4BENuzh4/Llyzp37py+853vDPdbIYG4+vlHY1n+O471C1fTHOudEXeEeMLcGL0euPekY73HBB3r37zH+T4epyJtCMMi6h+7/OhHP1J1dbX+9re/6a233tLq1as1ZswYPfroo9F+KwAjBHMDGF2ivvLxwQcf6NFHH9WlS5d0880367777lNtba1uvvnmaL8VgBGCuQGMLlEPH3v27In2SwIY4ZgbwOjCF8s
BAACrCB8AAMAqwgcAALCK8AEAAKziayMRE2N8zl8g9mvffsf6Q+cKotkOAEvG3DnTsb7IU2GpE8QSKx8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAq7jJGBLS6eZpjvUZ+shSJwAi8emcmxzrRRM/tdQJYomVDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWcZ8PxETrUp9jfYzLORffvuWqYz0YcUcARoI/vPt1x/oMnbTUCZyw8gEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKu7zgZj4ZK7znTh6jHM9ePq9aLYDYITIPOSOdQsYgIhXPo4dO6YVK1bI5/PJ5XLpwIEDYXVjjDZv3qypU6cqJSVF+fn5Onv2bLT6BZCAmBsAeos4fHR2dmru3LkqLy/vs75161a98MIL2rFjh+rq6jRx4kQVFBSoq6tryM0CSEzMDQC9RfyxS2FhoQoLC/usGWO0bds2PfXUU1q5cqUk6ZVXXlFWVpYOHDigRx555LrnBAIBBQKB0GO/3x9pSwDiXLTnhsTsABJZVC84bWpqUktLi/Lz80P7PB6PcnNzVVNT0+dzysrK5PF4Qlt2dnY0WwIQ5wYzNyRmB5DIoho+WlpaJElZWVlh+7OyskK1LystLVV7e3toa25ujmZLAOLcYOaGxOwAElnMf9vF7XbL7ebqZACRYXYAiSuqKx9er1eS1NraGra/tbU1VAOA3pgbwOgT1ZWPnJwceb1eVVZW6q677pL0+UVgdXV1WrduXTTfCnHOdfdsx/pfVv6HYz3v1OOO9ZvEr2GOFMwNYPSJOHxcvnxZjY2NocdNTU06deqUMjIyNH36dG3cuFHPP/+8br/9duXk5Ojpp5+Wz+fTqlWrotk3gATC3ADQW8Th48SJE1qyZEnocUlJiSRpzZo12rVrl5588kl1dnbqiSeeUFtbm+677z5VVFRo/Pjx0esaQEJhbgDoLeLwsXjxYhljblh3uVx67rnn9Nxzzw2pMQAjB3MDQG98sRwAALCK8AEAAKwifAAAAKsIHwAAwKqY3+EUI1NPyjjHeuaYCY71cS9P7ucduM8HMBo9/7HzPYTSj/y3Y70nms1g0Fj5AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV9/nAsPjwX7sc638JOOfeSa/VRrMdACPEhYDHsd5z6RNLnWAoWPkAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBX3+cCguO6e7Vj/z3/c4Vj/l8aH+3mH/42wIwCJwJ/Dz7xg5QMAAFhG+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVdznA4PinzHJsZ6aNMax3vXvPse6m/t8ACPSyoeOx7oFxIGIVz6OHTumFStWyOfzyeVy6cCBA2H1tWvXyuVyhW3Lly+PVr8AEhBzA0BvEYePzs5OzZ07V+Xl5Tc8Zvny5bp48WJoe/XVV4fUJIDExtwA0FvEH7sUFhaqsLDQ8Ri32y2v1zvopgCMLMwNAL0NywWnVVVVyszM1MyZM7Vu3TpdunTphscGAgH5/f6wDcDoE8nckJgdQCKLevhYvny5XnnlFVVWVmrLli2qrq5WYWGhenp6+jy+rKxMHo8ntGVnZ0e7JQBxLtK5ITE7gEQW9d92eeSRR0L//bWvfU1z5szRjBkzVFVVpaVLl153fGlpqUpKSkKP/X4/QwQYZSKdGxKzA0hkw36fj1tvvVVTpkxRY2Njn3W32620tLSwDcDo1t/ckJgdQCIb9vt8fPDBB7p06ZKmTp063G8Fiy4uCTrWX/zk64519xv/Fc12MMIwNxJXoHC+Y31txrZ+XiHFsfrW751ni09v9fP6iAcRh4/Lly+H/TTS1NSkU6dOKSMjQxkZGXr22WdVVFQkr9erc+fO6ck
nn9Rtt92mgoKCqDYOIHEwNwD0FnH4OHHihJYsWRJ6/MVnrmvWrNH27dt1+vRpvfzyy2pra5PP59OyZcv0s5/9TG63O3pdA0gozA0AvUUcPhYvXixjzA3rf/rTn4bUEICRh7kBoDe+WA4AAFhF+AAAAFYRPgAAgFWEDwAAYNWw3+cDI9OW+/c61hu7+IIwYDS6kuX8z8qMsc738ehPeuONb7mPxMHKBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACruM8HBqVo4qeO9ZkHH3Osz1BNNNsBECeSLwcd65eCVx3rk5OGdh8QJAZWPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYxX0+0Ke2f87r54h3HKu37e1wrJsI+wGQGCb+vs6xvu3fnGfLzzJPRbEbxCtWPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYxX0+0KdAusuxXtU1zrGe9H6rY70n4o4AACNFRCsfZWVlmj9/vlJTU5WZmalVq1apoaEh7Jiuri4VFxdr8uTJmjRpkoqKitTa6vwPEYCRjdkBoLeIwkd1dbWKi4tVW1urI0eOqLu7W8uWLVNnZ2fomE2bNunQoUPat2+fqqurdeHCBT344INRbxxA4mB2AOgtoo9dKioqwh7v2rVLmZmZqq+v16JFi9Te3q7f/va32r17t+6//35J0s6dO/XVr35VtbW1uueee6LXOYCEwewA0NuQLjhtb2+XJGVkZEiS6uvr1d3drfz8/NAxs2bN0vTp01VTU9PnawQCAfn9/rANwMjG7ABGt0GHj2AwqI0bN2rhwoWaPXu2JKmlpUXJyclKT08POzYrK0stLS19vk5ZWZk8Hk9oy87OHmxLABIAswPAoMNHcXGxzpw5oz179gypgdLSUrW3t4e25ubmIb0egPjG7AAwqF+1Xb9+vQ4fPqxjx45p2rRpof1er1fXrl1TW1tb2E8wra2t8nq9fb6W2+2W2+0eTBsAEgyzA4AUYfgwxmjDhg3av3+/qqqqlJOTE1afN2+exo0bp8rKShUVFUmSGhoadP78eeXl5UWvawy7K1ONY337hSWO9Z6PPopmO0hwzI7RI2nCBMf6hDGfWOoE8Syi8FFcXKzdu3fr4MGDSk1NDX0W6/F4lJKSIo/Ho+9+97sqKSlRRkaG0tLStGHDBuXl5XG1OjCKMTsA9BZR+Ni+fbskafHixWH7d+7cqbVr10qSfvWrXykpKUlFRUUKBAIqKCjQb37zm6g0CyAxMTsA9Bbxxy79GT9+vMrLy1VeXj7opgCMLMwOAL3xxXIAAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwKpB3eEUI9/shY2O9aBxWeoEQCK59M25jvXSyfw2E1j5AAAAlhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV9/nAoHwamOBYT7bUB4D4MvlUu2P9yNUUx/o/pVyNZjuIU6x8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK+3xgUD499A+O9Sy9b6kTAPEkeOpdx/q6I2sc68/f/wfHempDm2O9x7GKeMHKBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrIrrPR1lZmV5//XW99957SklJ0b333qstW7Zo5syZoWMWL16s6urqsOd9//vf144dO6LTMazoXPSRYz1LznWgN2YHvnDHurcd668ou59XaIheM4iZiFY+qqurVVxcrNraWh05ckTd3d1atmyZOjs7w4773ve+p4sXL4a2rVu3RrVpAImF2QGgt4hWPioqKsIe79q1S5mZmaqvr9eiRYtC+ydMmCCv1xudDgEkPGYHgN6GdM1He3u7JCkjIyNs/+9+9ztNmTJFs2fPVmlpqa5cuXLD1wgEAvL7/WEbgJGN2QGMboP+bpdgMKi
NGzdq4cKFmj17dmj/Y489pltuuUU+n0+nT5/Wj3/8YzU0NOj111/v83XKysr07LPPDrYNAAmG2QHAZYwxg3niunXr9Mc//lHHjx/XtGnTbnjc0aNHtXTpUjU2NmrGjBnX1QOBgAKBQOix3+9Xdna2FmulxrrGDaY1AEP0melWlQ6qvb1daWlpUX1tZgcwMkUyNwa18rF+/XodPnxYx44dcxwekpSbmytJNxwgbrdbbrd7MG0ASDDMDgBShOHDGKMNGzZo//79qqqqUk5OTr/POXXqlCRp6tSpg2oQQOJjdgDoLaLwUVxcrN27d+vgwYNKTU1VS0uLJMnj8SglJUXnzp3T7t279cADD2jy5Mk6ffq0Nm3apEWLFmnOnDnD8j8AIP4xOwD0FtE1Hy6Xq8/9O3fu1Nq1a9Xc3Kxvf/vbOnPmjDo7O5Wdna3Vq1frqaeeGvDnxn6/Xx6Ph89tgRiK9jUfzA5g5Bu2az76yynZ2dnX3aEQAJgdAHrju10AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWRfTFcjZ88QVUn6lbGvD37QKIps/ULan/L4SLJ8wOILYimRtxFz46OjokScf1Row7AdDR0SGPxxPrNgaE2QHEh4HMDZeJsx9tgsGgLly4oNTUVLlcLvn9fmVnZ6u5uVlpaWmxbi8hcQ6HZjSeP2OMOjo65PP5lJSUGJ/OMjuii/M3dKPtHEYyN+Ju5SMpKUnTpk27bn9aWtqo+MMbTpzDoRlt5y9RVjy+wOwYHpy/oRtN53CgcyMxfqQBAAAjBuEDAABYFffhw+1265lnnpHb7Y51KwmLczg0nL/ExJ/b0HD+ho5zeGNxd8EpAAAY2eJ+5QMAAIwshA8AAGAV4QMAAFhF+AAAAFYRPgAAgFVxHz7Ky8v1la98RePHj1dubq7efvvtWLcUt44dO6YVK1bI5/PJ5XLpwIEDYXVjjDZv3qypU6cqJSVF+fn5Onv2bGyajUNlZWWaP3++UlNTlZmZqVWrVqmhoSHsmK6uLhUXF2vy5MmaNGmSioqK1NraGqOOcSPMjYFjbgwNc2Nw4jp87N27VyUlJXrmmWf0zjvvaO7cuSooKNCHH34Y69biUmdnp+bOnavy8vI+61u3btULL7ygHTt2qK6uThMnTlRBQYG6urosdxqfqqurVVxcrNraWh05ckTd3d1atmyZOjs7Q8ds2rRJhw4d0r59+1RdXa0LFy7owQcfjGHX+DLmRmSYG0PD3BgkE8cWLFhgiouLQ497enqMz+czZWVlMewqMUgy+/fvDz0OBoPG6/WaX/ziF6F9bW1txu12m1dffTUGHca/Dz/80Egy1dXVxpjPz9e4cePMvn37Qsf89a9/NZJMTU1NrNrElzA3Bo+5MXTMjYGJ25WPa9euqb6+Xvn5+aF9SUlJys/PV01NTQw7S0xNTU1qaWkJO58ej0e5ubmczxtob2+XJGVkZEiS6uvr1d3dHXYOZ82apenTp3MO4wRzI7qYG5FjbgxM3IaPjz/+WD09PcrKygrbn5WVpZaWlhh1lbi+OGecz4EJBoPauHGjFi5cqNmzZ0v6/BwmJycrPT097FjOYfxgbkQXcyMyzI2BGxvrBoB4VFxcrDNnzuj48eOxbgVAgmBuDFzcrnxMmTJFY8aMue6K4NbWVnm93hh1lbi+OGecz/6tX79ehw8f1ptvvqlp06aF9nu9Xl27dk1tbW1hx3MO4wdzI7qYGwPH3IhM3IaP5ORkzZs3T5WVlaF9wWBQlZWVysvLi2FniSknJ0derzfsfPr9ftXV1XE+/58xRuvXr9f+/ft19OhR5eTkhNXnzZuncePGhZ3DhoYGnT9/nnMYJ5gb0cXc6B9zY5BifcWrkz179hi322127dpl3n33XfPEE0+Y9PR009LSEuvW4lJHR4c5efKkOXnypJFkfvnLX5qTJ0+
a999/3xhjzM9//nOTnp5uDh48aE6fPm1WrlxpcnJyzNWrV2PceXxYt26d8Xg8pqqqyly8eDG0XblyJXTMD37wAzN9+nRz9OhRc+LECZOXl2fy8vJi2DW+jLkRGebG0DA3Bieuw4cxxrz44otm+vTpJjk52SxYsMDU1tbGuqW49eabbxpJ121r1qwxxnz+a3NPP/20ycrKMm632yxdutQ0NDTEtuk40te5k2R27twZOubq1avmhz/8obnpppvMhAkTzOrVq83Fixdj1zT6xNwYOObG0DA3BsdljDH21lkAAMBoF7fXfAAAgJGJ8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACr/g9DhBV48dmDIwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ds_train, ds_test = data.get_mnist_op_dataset(\n", " count_train=3000,\n", " count_test=1000,\n", " buffer_size=3000,\n", " batch_size=16,\n", " n_operands=2,\n", " op=lambda args: args[0]+args[1])\n", "\n", "# Visualize one example\n", "x, y, z = next(ds_train.as_numpy_iterator())\n", "plt.subplot(121)\n", "plt.imshow(x[0][:,:,0])\n", "plt.subplot(122)\n", "plt.imshow(y[0][:,:,0])\n", "print(\"Result label is %i\" % z[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LTN" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "logits_model = baselines.SingleDigit(inputs_as_a_list=True)\n", "Digit = ltn.Predicate.FromLogits(logits_model, activation_function=\"softmax\")\n", "\n", "d1 = ltn.Variable(\"digits1\", range(10))\n", "d2 = ltn.Variable(\"digits2\", range(10))\n", "\n", "Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())\n", "And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())\n", "Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())\n", "Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())\n", "Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(),semantics=\"forall\")\n", "Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(),semantics=\"exists\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice the use of `Diag`: when grounding $x$,$y$,$n$ with three sequences of values, the $i$-th examples of each variable are matching. \n", "That is, `(images_x[i],images_y[i],labels[i])` is a tuple from our dataset of valid additions.\n", "Using the diagonal quantification, LTN aggregates pairs of images and their corresponding result, rather than any combination of images and results. 
\n", " \n", "Notice also the guarded quantification: by quantifying only on the \"intermediate labels\" (not given during training) that could add up to the result label (given during training), we incorporate symbolic information into the system." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:38.033631: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n", "2023-03-30 16:40:38.035507: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# mask\n", "add = ltn.Function.Lambda(lambda inputs: inputs[0]+inputs[1])\n", "equals = ltn.Predicate.Lambda(lambda inputs: inputs[0] == inputs[1])\n", "\n", "### Axioms\n", "@tf.function\n", "def axioms(images_x, images_y, labels_z, p_schedule=tf.constant(2.)):\n", " images_x = ltn.Variable(\"x\", images_x)\n", " images_y = ltn.Variable(\"y\", images_y)\n", " labels_z = ltn.Variable(\"z\", labels_z)\n", " axiom = Forall(\n", " ltn.diag(images_x,images_y,labels_z),\n", " Exists(\n", " (d1,d2),\n", " And(Digit([images_x,d1]),Digit([images_y,d2])),\n", " mask=equals([add([d1,d2]), labels_z]),\n", " p=p_schedule\n", " ),\n", " p=2\n", " )\n", " sat = axiom.tensor\n", " return sat\n", "\n", "images_x, images_y, labels_z = next(ds_train.as_numpy_iterator())\n", "axioms(images_x, images_y, labels_z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizer, training steps and metrics" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(0.001)\n", "metrics_dict = {\n", " 'train_loss': tf.keras.metrics.Mean(name=\"train_loss\"),\n", " 'train_accuracy': 
tf.keras.metrics.Mean(name=\"train_accuracy\"),\n", " 'test_loss': tf.keras.metrics.Mean(name=\"test_loss\"),\n", " 'test_accuracy': tf.keras.metrics.Mean(name=\"test_accuracy\") \n", "}\n", "\n", "@tf.function\n", "def train_step(images_x, images_y, labels_z, **parameters):\n", "    # loss\n", "    with tf.GradientTape() as tape:\n", "        loss = 1.- axioms(images_x, images_y, labels_z, **parameters)\n", "    gradients = tape.gradient(loss, logits_model.trainable_variables)\n", "    optimizer.apply_gradients(zip(gradients, logits_model.trainable_variables))\n", "    metrics_dict['train_loss'](loss)\n", "    # accuracy\n", "    predictions_x = tf.argmax(logits_model([images_x]),axis=-1)\n", "    predictions_y = tf.argmax(logits_model([images_y]),axis=-1)\n", "    predictions_z = predictions_x + predictions_y\n", "    match = tf.equal(predictions_z,tf.cast(labels_z,predictions_z.dtype))\n", "    metrics_dict['train_accuracy'](tf.reduce_mean(tf.cast(match,tf.float32)))\n", " \n", "@tf.function\n", "def test_step(images_x, images_y, labels_z, **parameters):\n", "    # loss\n", "    loss = 1.- axioms(images_x, images_y, labels_z, **parameters)\n", "    metrics_dict['test_loss'](loss)\n", "    # accuracy\n", "    predictions_x = tf.argmax(logits_model([images_x]),axis=-1)\n", "    predictions_y = tf.argmax(logits_model([images_y]),axis=-1)\n", "    predictions_z = predictions_x + predictions_y\n", "    match = tf.equal(predictions_z,tf.cast(labels_z,predictions_z.dtype))\n", "    metrics_dict['test_accuracy'](tf.reduce_mean(tf.cast(match,tf.float32)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "\n", "scheduled_parameters = defaultdict(lambda: {})\n", "for epoch in range(0,4):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(1.)}\n", "for epoch in range(4,8):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(2.)}\n", "for epoch in 
range(8,12):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(4.)}\n", "for epoch in range(12,20):\n", " scheduled_parameters[epoch] = {\"p_schedule\":tf.constant(6.)}\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-03-30 16:40:54.797507: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:06.041586: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:06.798203: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", "2023-03-30 16:41:08.252734: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, train_loss: 0.9356, train_accuracy: 0.3793, test_loss: 0.8768, test_accuracy: 0.6567\n", "Epoch 1, train_loss: 0.8566, train_accuracy: 0.8321, test_loss: 0.8480, test_accuracy: 0.8403\n", "Epoch 2, train_loss: 0.8443, train_accuracy: 0.8989, test_loss: 0.8430, test_accuracy: 0.8562\n", "Epoch 3, train_loss: 0.8383, train_accuracy: 0.9249, test_loss: 0.8442, test_accuracy: 0.8492\n", "Epoch 4, train_loss: 0.6416, train_accuracy: 0.9279, test_loss: 0.6523, test_accuracy: 0.8730\n", "Epoch 5, train_loss: 0.6285, train_accuracy: 0.9458, test_loss: 0.6476, test_accuracy: 0.8859\n", "Epoch 6, train_loss: 0.6237, train_accuracy: 0.9525, test_loss: 0.6345, test_accuracy: 0.9077\n", "Epoch 7, train_loss: 0.6179, train_accuracy: 0.9624, test_loss: 0.6329, test_accuracy: 0.9147\n", "Epoch 8, train_loss: 0.4284, train_accuracy: 0.9525, test_loss: 0.4674, test_accuracy: 0.8889\n", "Epoch 9, train_loss: 0.4167, train_accuracy: 0.9618, test_loss: 0.4652, test_accuracy: 
0.8929\n", "Epoch 10, train_loss: 0.4127, train_accuracy: 0.9651, test_loss: 0.4669, test_accuracy: 0.8869\n", "Epoch 11, train_loss: 0.4054, train_accuracy: 0.9697, test_loss: 0.4590, test_accuracy: 0.9028\n", "Epoch 12, train_loss: 0.3186, train_accuracy: 0.9688, test_loss: 0.3705, test_accuracy: 0.9147\n", "Epoch 13, train_loss: 0.3156, train_accuracy: 0.9697, test_loss: 0.3740, test_accuracy: 0.9167\n", "Epoch 14, train_loss: 0.3204, train_accuracy: 0.9664, test_loss: 0.4056, test_accuracy: 0.8869\n", "Epoch 15, train_loss: 0.3228, train_accuracy: 0.9661, test_loss: 0.3548, test_accuracy: 0.9286\n", "Epoch 16, train_loss: 0.3112, train_accuracy: 0.9731, test_loss: 0.3741, test_accuracy: 0.9127\n", "Epoch 17, train_loss: 0.3071, train_accuracy: 0.9747, test_loss: 0.3577, test_accuracy: 0.9276\n", "Epoch 18, train_loss: 0.3042, train_accuracy: 0.9771, test_loss: 0.3682, test_accuracy: 0.9167\n", "Epoch 19, train_loss: 0.2983, train_accuracy: 0.9801, test_loss: 0.3588, test_accuracy: 0.9246\n" ] } ], "source": [ "commons.train(\n", " 20,\n", " metrics_dict,\n", " ds_train,\n", " ds_test,\n", " train_step,\n", " test_step,\n", " scheduled_parameters=scheduled_parameters\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "12eaedf9b9a64329743e8900a3192e3d75dbaaa78715534825922e4a4f7d9137" }, "kernelspec": { "display_name": "ltn", "language": "python", "name": "ltn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }