{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Smokes/Friends/Cancer with Logic Tensor Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a classic example from relational learning, implemented in LTN and grounded in a domain $\\mathbb{R}^n$." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging;logging.basicConfig(level=logging.INFO)\n", "import tensorflow as tf\n", "\n", "# add the parent directory to sys.path (ltnw is in the parent dir)\n", "import sys, os\n", "sys.path.insert(0,os.path.normpath(os.path.join(os.path.abspath(''),os.path.pardir)))\n", "import logictensornetworks as ltn\n", "import logictensornetworks_wrapper as ltnw" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some initial settings, including the dimensionality of the embedding/grounding space." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "ltn.LAYERS = 4\n", "ltn.BIAS_factor = 1e-3\n", "\n", "max_epochs=1000\n", "embedding_size = 10 # embedding space dimensionality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use three predicates: two unary predicates (`Smokes` and `Cancer`) and one binary predicate (`Friends`). The binary predicate is declared with twice the embedding size, `embedding_size*2`, because it takes two constants as arguments." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ltnw.predicate('Smokes',embedding_size);\n", "ltnw.predicate('Friends',embedding_size*2);\n", "ltnw.predicate('Cancer',embedding_size);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate the constants a to n, in two groups: a to h (`g1`) and i to n (`g2`)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "g1,g2='abcdefgh','ijklmn'\n", "g=g1+g2\n", "for l in g:\n", " ltnw.constant(l,min_value=[0.]*embedding_size,max_value=[1.]*embedding_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we define the facts known about the world, i.e. 
who we know is friends, smokes and has cancer" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cancer(a)\n", "Cancer(e)\n", "Friends(a,b)\n", "Friends(a,e)\n", "Friends(a,f)\n", "Friends(a,g)\n", "Friends(b,c)\n", "Friends(c,d)\n", "Friends(e,f)\n", "Friends(g,h)\n", "Friends(i,j)\n", "Friends(j,m)\n", "Friends(k,l)\n", "Friends(m,n)\n", "Smokes(a)\n", "Smokes(e)\n", "Smokes(f)\n", "Smokes(g)\n", "Smokes(j)\n", "Smokes(n)\n", "~Cancer(b)\n", "~Cancer(c)\n", "~Cancer(d)\n", "~Cancer(f)\n", "~Cancer(g)\n", "~Cancer(h)\n", "~Cancer(i)\n", "~Cancer(j)\n", "~Cancer(k)\n", "~Cancer(l)\n", "~Cancer(m)\n", "~Cancer(n)\n", "~Friends(a,c)\n", "~Friends(a,d)\n", "~Friends(a,h)\n", "~Friends(b,d)\n", "~Friends(b,e)\n", "~Friends(b,f)\n", "~Friends(b,g)\n", "~Friends(b,h)\n", "~Friends(c,e)\n", "~Friends(c,f)\n", "~Friends(c,g)\n", "~Friends(c,h)\n", "~Friends(d,e)\n", "~Friends(d,f)\n", "~Friends(d,g)\n", "~Friends(d,h)\n", "~Friends(e,g)\n", "~Friends(e,h)\n", "~Friends(f,g)\n", "~Friends(f,h)\n", "~Friends(i,k)\n", "~Friends(i,l)\n", "~Friends(i,m)\n", "~Friends(i,n)\n", "~Friends(j,k)\n", "~Friends(j,l)\n", "~Friends(j,n)\n", "~Friends(k,m)\n", "~Friends(k,n)\n", "~Friends(l,m)\n", "~Friends(l,n)\n", "~Smokes(b)\n", "~Smokes(c)\n", "~Smokes(d)\n", "~Smokes(h)\n", "~Smokes(i)\n", "~Smokes(k)\n", "~Smokes(l)\n", "~Smokes(m)\n" ] } ], "source": [ "friends = [('a','b'),('a','e'),('a','f'),('a','g'),('b','c'),('c','d'),('e','f'),('g','h'),\n", " ('i','j'),('j','m'),('k','l'),('m','n')]\n", "[ltnw.formula(\"Friends(%s,%s)\" %(x,y)) for (x,y) in friends]\n", "[ltnw.formula(\"~Friends(%s,%s)\" %(x,y)) for x in g1 for y in g1 if (x,y) not in friends and x < y]\n", "[ltnw.formula(\"~Friends(%s,%s)\" %(x,y)) for x in g2 for y in g2 if (x,y) not in friends and x < y]\n", "\n", "smokes = ['a','e','f','g','j','n']\n", "[ltnw.formula(\"Smokes(%s)\" % x) for x in smokes]\n", 
"[ltnw.formula(\"~Smokes(%s)\" % x) for x in g if x not in smokes]\n", "\n", "cancer = ['a','e']\n", "[ltnw.formula(\"Cancer(%s)\" % x) for x in cancer]\n", "[ltnw.formula(\"~Cancer(%s)\" % x) for x in g if x not in cancer]\n", "\n", "print(\"\\n\".join(sorted(ltnw.FORMULAS.keys())))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we add general knowledge that holds for every constant/individual. For this we use two variables, `p` and `q`, that range over all constants." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "forall p,q:Friends(p,q) -> (Smokes(p)->Smokes(q))\n", "forall p,q:Friends(p,q) -> Friends(q,p)\n", "forall p: Smokes(p) -> Cancer(p)\n", "forall p: exists q: Friends(p,q)\n", "forall p: ~Friends(p,p)\n" ] } ], "source": [ "ltnw.variable(\"p\",tf.concat(list(ltnw.CONSTANTS.values()),axis=0))\n", "ltnw.variable(\"q\",tf.concat(list(ltnw.CONSTANTS.values()),axis=0))\n", "\n", "ltnw.formula(\"forall p: ~Friends(p,p)\")\n", "ltnw.formula(\"forall p,q:Friends(p,q) -> Friends(q,p)\")\n", "ltnw.formula(\"forall p: exists q: Friends(p,q)\")\n", "ltnw.formula(\"forall p,q:Friends(p,q) -> (Smokes(p)->Smokes(q))\")\n", "ltnw.formula(\"forall p: Smokes(p) -> Cancer(p)\")\n", "print(\"\\n\".join(sorted(filter(lambda x: x.startswith(\"forall\"), ltnw.FORMULAS.keys()))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the knowledge base and optimize." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:logictensornetworks_wrapper:Initializing knowledgebase\n", "INFO:logictensornetworks_wrapper:Initializing optimizer\n", "INFO:logictensornetworks_wrapper:Assembling feed dict\n", "INFO:logictensornetworks_wrapper:Initializing Tensorflow session\n", "INFO:logictensornetworks_wrapper:INITIALIZED with sat level = 0.4619353\n", "INFO:logictensornetworks_wrapper:TRAINING 0 sat 
level -----> 0.4619353\n", "INFO:logictensornetworks_wrapper:TRAINING finished after 999 epochs with sat level 0.9683184\n" ] } ], "source": [ "optimizer=tf.train.RMSPropOptimizer(learning_rate=.01,decay=.9)\n", "formula_aggregator=lambda *x: tf.reduce_mean(tf.concat(x,axis=0))+ltn.BIAS\n", "ltnw.initialize_knowledgebase(initial_sat_level_threshold=.1,optimizer=optimizer,formula_aggregator=formula_aggregator)\n", "ltnw.train(track_sat_levels=1000,sat_level_epsilon=.99,max_epochs=max_epochs);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check some formulas" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = [ 0.40694648 0.13683926 1.1569124 1.1818535 0.14294973 0.9392728\n", " 1.7961706 -0.82798356 0.61106527 1.3495604 ]\n", "Cancer(a): 1.00\n", "Smokes(a): 1.00\n", "Friends(a,a): 0.00\n", "Friends(a,b): 0.96\n", "Friends(a,c): 0.00\n", "Friends(a,d): 0.00\n", "Friends(a,e): 1.00\n", "Friends(a,f): 1.00\n", "Friends(a,g): 1.00\n", "Friends(a,h): 0.00\n", "Friends(a,i): 0.00\n", "Friends(a,j): 1.00\n", "Friends(a,k): 0.96\n", "Friends(a,l): 0.00\n", "Friends(a,m): 0.26\n", "Friends(a,n): 0.26\n", "b = [ 0.9147028 1.0019453 0.5234545 0.3568315 0.9653716 -0.7044046\n", " 0.29647025 1.2177726 0.7527934 0.6989008 ]\n", "Cancer(b): 0.00\n", "Smokes(b): 0.00\n", "Friends(b,a): 0.26\n", "Friends(b,b): 0.00\n", "Friends(b,c): 0.96\n", "Friends(b,d): 0.00\n", "Friends(b,e): 0.00\n", "Friends(b,f): 0.00\n", "Friends(b,g): 0.00\n", "Friends(b,h): 0.00\n", "Friends(b,i): 0.00\n", "Friends(b,j): 0.00\n", "Friends(b,k): 0.00\n", "Friends(b,l): 0.00\n", "Friends(b,m): 0.00\n", "Friends(b,n): 0.00\n", "c = [ 0.17684561 0.10682261 -0.3228023 0.18780398 2.2169807 0.07184859\n", " 0.67084223 0.1279161 0.03808275 0.2801455 ]\n", "Cancer(c): 0.00\n", "Smokes(c): 0.00\n", "Friends(c,a): 0.00\n", "Friends(c,b): 0.00\n", "Friends(c,c): 0.00\n", "Friends(c,d): 1.00\n", "Friends(c,e): 
0.00\n", "Friends(c,f): 0.00\n", "Friends(c,g): 0.00\n", "Friends(c,h): 0.00\n", "Friends(c,i): 0.00\n", "Friends(c,j): 0.00\n", "Friends(c,k): 0.00\n", "Friends(c,l): 0.00\n", "Friends(c,m): 0.00\n", "Friends(c,n): 0.00\n", "d = [ 1.8529487 1.5305904 -0.00983319 -0.43959376 1.1610166 -0.19935457\n", " 0.8084609 0.8635864 -0.37106708 -0.21966477]\n", "Cancer(d): 0.00\n", "Smokes(d): 0.00\n", "Friends(d,a): 0.00\n", "Friends(d,b): 0.00\n", "Friends(d,c): 0.96\n", "Friends(d,d): 0.00\n", "Friends(d,e): 0.00\n", "Friends(d,f): 0.00\n", "Friends(d,g): 0.00\n", "Friends(d,h): 0.00\n", "Friends(d,i): 0.00\n", "Friends(d,j): 0.00\n", "Friends(d,k): 0.96\n", "Friends(d,l): 0.00\n", "Friends(d,m): 0.00\n", "Friends(d,n): 0.00\n", "e = [0.6416602 0.66157794 0.12425633 0.67886037 0.09569763 1.7586013\n", " 1.1241968 0.92352915 0.47239357 0.8050211 ]\n", "Cancer(e): 1.00\n", "Smokes(e): 1.00\n", "Friends(e,a): 1.00\n", "Friends(e,b): 0.00\n", "Friends(e,c): 0.00\n", "Friends(e,d): 0.00\n", "Friends(e,e): 0.00\n", "Friends(e,f): 1.00\n", "Friends(e,g): 0.00\n", "Friends(e,h): 0.00\n", "Friends(e,i): 0.00\n", "Friends(e,j): 0.96\n", "Friends(e,k): 0.00\n", "Friends(e,l): 0.00\n", "Friends(e,m): 0.26\n", "Friends(e,n): 0.00\n", "f = [-0.08516412 1.3672957 0.72731215 0.5337677 0.24218766 -0.17402795\n", " 0.7505686 1.0381958 -0.35125494 1.0763618 ]\n", "Cancer(f): 0.00\n", "Smokes(f): 1.00\n", "Friends(f,a): 1.00\n", "Friends(f,b): 0.00\n", "Friends(f,c): 0.00\n", "Friends(f,d): 0.00\n", "Friends(f,e): 1.00\n", "Friends(f,f): 0.00\n", "Friends(f,g): 0.00\n", "Friends(f,h): 0.00\n", "Friends(f,i): 0.00\n", "Friends(f,j): 0.00\n", "Friends(f,k): 0.00\n", "Friends(f,l): 0.00\n", "Friends(f,m): 0.26\n", "Friends(f,n): 0.04\n", "g = [ 0.8539798 0.761432 0.75761944 0.69546074 0.30572608 0.59105855\n", " -0.343103 1.1347185 0.53520316 -0.08765783]\n", "Cancer(g): 0.00\n", "Smokes(g): 1.00\n", "Friends(g,a): 1.00\n", "Friends(g,b): 0.00\n", "Friends(g,c): 0.00\n", "Friends(g,d): 0.00\n", 
"Friends(g,e): 0.00\n", "Friends(g,f): 0.00\n", "Friends(g,g): 0.00\n", "Friends(g,h): 0.96\n", "Friends(g,i): 1.00\n", "Friends(g,j): 0.00\n", "Friends(g,k): 0.00\n", "Friends(g,l): 0.00\n", "Friends(g,m): 0.26\n", "Friends(g,n): 0.00\n", "h = [ 0.33865392 0.84384173 0.70018744 0.1345554 1.1172969 0.09839492\n", " 0.81501096 -0.15736319 0.56537557 0.7202723 ]\n", "Cancer(h): 0.00\n", "Smokes(h): 0.00\n", "Friends(h,a): 0.00\n", "Friends(h,b): 0.00\n", "Friends(h,c): 0.00\n", "Friends(h,d): 0.00\n", "Friends(h,e): 0.00\n", "Friends(h,f): 0.00\n", "Friends(h,g): 0.00\n", "Friends(h,h): 0.00\n", "Friends(h,i): 0.00\n", "Friends(h,j): 0.00\n", "Friends(h,k): 0.00\n", "Friends(h,l): 0.00\n", "Friends(h,m): 0.00\n", "Friends(h,n): 0.00\n", "i = [ 0.8514193 1.4208581 0.40041497 -0.19588503 -0.200309 -0.2671704\n", " 1.292986 -0.57946455 1.1746347 0.6758825 ]\n", "Cancer(i): 0.00\n", "Smokes(i): 0.00\n", "Friends(i,a): 0.00\n", "Friends(i,b): 0.00\n", "Friends(i,c): 0.00\n", "Friends(i,d): 0.00\n", "Friends(i,e): 0.00\n", "Friends(i,f): 0.00\n", "Friends(i,g): 0.96\n", "Friends(i,h): 0.00\n", "Friends(i,i): 0.00\n", "Friends(i,j): 0.96\n", "Friends(i,k): 0.00\n", "Friends(i,l): 0.00\n", "Friends(i,m): 0.00\n", "Friends(i,n): 0.00\n", "j = [ 0.66317457 0.4312353 0.76770747 0.26138163 -0.5081676 0.71022123\n", " 0.3749792 0.8656672 1.5517639 0.17871857]\n", "Cancer(j): 0.00\n", "Smokes(j): 1.00\n", "Friends(j,a): 1.00\n", "Friends(j,b): 0.00\n", "Friends(j,c): 0.00\n", "Friends(j,d): 0.00\n", "Friends(j,e): 0.96\n", "Friends(j,f): 0.00\n", "Friends(j,g): 0.00\n", "Friends(j,h): 0.00\n", "Friends(j,i): 0.96\n", "Friends(j,j): 0.00\n", "Friends(j,k): 0.00\n", "Friends(j,l): 0.00\n", "Friends(j,m): 1.00\n", "Friends(j,n): 0.00\n", "k = [ 0.8996744 0.8115943 0.30769414 0.25508854 1.6205107 -1.1135386\n", " -0.31606957 0.27544066 1.0825695 0.51878643]\n", "Cancer(k): 0.00\n", "Smokes(k): 0.00\n", "Friends(k,a): 0.26\n", "Friends(k,b): 0.00\n", "Friends(k,c): 0.00\n", 
"Friends(k,d): 0.96\n", "Friends(k,e): 0.00\n", "Friends(k,f): 0.00\n", "Friends(k,g): 0.00\n", "Friends(k,h): 0.00\n", "Friends(k,i): 0.00\n", "Friends(k,j): 0.00\n", "Friends(k,k): 0.00\n", "Friends(k,l): 1.00\n", "Friends(k,m): 0.00\n", "Friends(k,n): 0.00\n", "l = [ 1.0598446 0.3733673 1.0997576 -0.74925214 0.9516005 0.44306272\n", " 0.15078382 -0.45237944 0.21813329 0.13396993]\n", "Cancer(l): 0.00\n", "Smokes(l): 0.00\n", "Friends(l,a): 0.00\n", "Friends(l,b): 0.00\n", "Friends(l,c): 0.00\n", "Friends(l,d): 0.00\n", "Friends(l,e): 0.00\n", "Friends(l,f): 0.00\n", "Friends(l,g): 0.00\n", "Friends(l,h): 0.00\n", "Friends(l,i): 0.00\n", "Friends(l,j): 0.00\n", "Friends(l,k): 0.96\n", "Friends(l,l): 0.00\n", "Friends(l,m): 0.00\n", "Friends(l,n): 0.00\n", "m = [-0.2774569 1.4202287 1.1293546 0.9845873 0.48729423 0.41340557\n", " 1.2946917 0.7260294 0.6903299 1.0546893 ]\n", "Cancer(m): 0.00\n", "Smokes(m): 1.00\n", "Friends(m,a): 0.26\n", "Friends(m,b): 0.00\n", "Friends(m,c): 0.00\n", "Friends(m,d): 0.00\n", "Friends(m,e): 0.26\n", "Friends(m,f): 0.26\n", "Friends(m,g): 0.26\n", "Friends(m,h): 0.00\n", "Friends(m,i): 0.00\n", "Friends(m,j): 1.00\n", "Friends(m,k): 0.00\n", "Friends(m,l): 0.00\n", "Friends(m,m): 0.26\n", "Friends(m,n): 1.00\n", "n = [0.50134546 0.961617 0.84858876 0.2242519 0.9684452 1.192804\n", " 0.06254435 0.4865449 0.0592797 0.5465894 ]\n", "Cancer(n): 0.00\n", "Smokes(n): 1.00\n", "Friends(n,a): 0.26\n", "Friends(n,b): 0.00\n", "Friends(n,c): 0.00\n", "Friends(n,d): 0.00\n", "Friends(n,e): 0.00\n", "Friends(n,f): 0.00\n", "Friends(n,g): 0.00\n", "Friends(n,h): 0.00\n", "Friends(n,i): 0.00\n", "Friends(n,j): 0.00\n", "Friends(n,k): 0.00\n", "Friends(n,l): 0.00\n", "Friends(n,m): 0.26\n", "Friends(n,n): 0.00\n" ] } ], "source": [ "for x in g:\n", " print(x,\" = \",ltnw.ask(x).squeeze())\n", " print(\"Cancer(\"+x+\"): %.2f\" % ltnw.ask(\"Cancer(%s)\" % x).squeeze())\n", " print(\"Smokes(\"+x+\"): %.2f\" % ltnw.ask(\"Smokes(%s)\" % 
x).squeeze())\n", "\n", " for y in g:\n", "  print(\"Friends(\"+x+\",\"+y+\"): %.2f\" % ltnw.ask(\"Friends(%s,%s)\" % (x,y)).squeeze())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the truth values of some of the general axioms after training." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "forall p: ~Friends(p,p) : 0.97\n", "forall p,q: Friends(p,q) -> Friends(q,p) : 0.77\n", "forall p: exists q: Friends(p,q) : 0.99\n", "forall p,q: Friends(p,q) -> (Smokes(p)->Smokes(q)) : 0.00\n" ] } ], "source": [ "for formula in [\"forall p: ~Friends(p,p)\",\n", " \"forall p,q: Friends(p,q) -> Friends(q,p)\",\n", " \"forall p: exists q: Friends(p,q)\",\n", " \"forall p,q: Friends(p,q) -> (Smokes(p)->Smokes(q))\"]:\n", " print(formula,\": %.2f\" % ltnw.ask(formula).squeeze())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 4 }