{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Regression with LTN\n", "\n", "This basic example uses LTN to train a regression estimator. The essence of regression is to approximate a function $f(x)=y$ where $y$ can be any real value or tensor (as opposed to classifiers, which project values to $[0,1]$). We are given examples $(x_i, y_i)$ of this function with $f(x_i)=y_i$. From the examples, we need to estimate a function $f^\\ast$ that approximates $f$.\n", "\n", "In LTN we can model this directly by defining $f^\\ast$ as a learnable (trainable) function whose parameters are constrained through data. Additionally, we need a notion of $=$. Here, we use a similarity based on the Euclidean distance to get a smooth $=$ function. We define the following language and theory:\n", "- a set of points $x_i$ and $y_i$\n", "- a predicate for equality, modeled as a smooth equality function $\\mathrm{eq}(x,y)=\\exp\\bigg(-\\sqrt{\\sum_j (x_j-y_j)^2}\\bigg)$, which equals 1 when $x=y$ and decays toward 0 as the distance grows\n", "- a learnable function $f^\\ast$ for which the following constraint holds: $\\mathrm{eq}(f^\\ast(x_i),y_i)$. The function $f^\\ast$ thus approximates $f$ on the examples, where $f(x_i)=y_i$\n", "\n", "Here, $f^\\ast$ is modeled using a simple MLP." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Init Plugin\n", "Init Graph Optimizer\n", "Init Kernel\n" ] } ], "source": [ "import logging; logging.basicConfig(level=logging.INFO)\n", "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import ltn\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data\n", "\n", "Load the [real estate dataset](https://www.kaggle.com/quantbruce/real-estate-price-prediction): \n", "- 414 samples,\n", "- 6 float features: transaction date (converted to float, e.g. 
`2012.917` is equivalent to December), house age, distance to the nearest MRT station, number of convenience stores nearby, latitude, longitude,\n", "- y: house price per unit area" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 | No | X1 transaction date | X2 house age | X3 distance to the nearest MRT station | X4 number of convenience stores | X5 latitude | X6 longitude | Y house price of unit area |
---|---|---|---|---|---|---|---|---|
126 | 127 | 2013.083 | 38.6 | 804.68970 | 4 | 24.97838 | 121.53477 | 62.9 |
124 | 125 | 2012.917 | 9.9 | 279.17260 | 7 | 24.97528 | 121.54541 | 57.4 |
264 | 265 | 2013.167 | 32.6 | 493.65700 | 7 | 24.96968 | 121.54522 | 40.6 |
403 | 404 | 2012.667 | 30.9 | 161.94200 | 9 | 24.98353 | 121.53966 | 39.7 |
192 | 193 | 2013.167 | 43.8 | 57.58945 | 7 | 24.96750 | 121.54069 | 42.7 |