{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Relational Learning with TensorFlow\n", "Brian Jones, FireEye Inc.
\n", "March 2016
\n", "Accompanying code: [GitHub site](http://github.com/fireeye/tf_rl_tutorial)\n", "
\n", "\n", "## Table of Contents\n", "* [Relational Learning](#relational-learning)\n", "* [Loading the Data](#loading-the-data)\n", "* [Preprocessing](#preprocessing)\n", "* [Background: Matrix Factorization](#background-matrix-factorization)\n", "* [Tensor Decomposition](#tensor-decomposition)\n", "* [Optimization](#optimization-and-negative-sampling)\n", "* [Max-Norm Regularization](#max-norm-regularization)\n", "* [Loss Functions](#loss-functions)\n", "* [Bilinear Model (RESCAL)](#bilinear-model-rescal)\n", "* [TransE](#transe)\n", "* [Take It For a Spin](#take-it-for-a-spin)\n", "* [Visualization](#visualization)\n", "* [References](#references)\n", "\n", "
\n", "Like many machine learning researchers, I recently decided to take some time to learn Google's new [TensorFlow](http://www.tensorflow.org/) framework. For those who have not taken a look yet, TensorFlow provides a well-designed API in both C and Python for building and training graph-based models like deep neural networks. Most of the existing introductory material focuses on computer vision and sequence modeling, so I thought I'd take a different route and use it to build statistical relational learning models. I turned that experience into this tutorial in case it is useful for others who are also learning the framework or interested in the topic.\n", "\n", "In this tutorial we'll build three different types of learning models, implement a max-norm constraint regularization technique, and evaluate them on a popular dataset\n", "\n", "Before diving in, make sure you've gone over the [background material](http://www.tensorflow.org/how_tos) provided on the official site, as I won't be covering the absolute basics. In addition to TensorFlow, we'll also be using some standard Python numerical libraries: [NumPy](http://www.numpy.org/) and [Pandas](http://pandas.pydata.org/). Finally, we'll make use of some utilities that can be found in the [accompanying code](http://github.com/fireeye/tf_rl_tutorial) for this tutorial." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Relational Learning\n", "\n", "Relational learning models try to predict unseen or future relationships between entities based on past observations. Item recommendation is a common example, where the model predicts how much a person will enjoy a particular item based on past rating (or consumption) data collected from a user base. A similar task is link prediction, where the goal is to predict new edges, or edge values, in a graph that are likely to be observed in the future. When applied to a knowledge database like [WordNet](http://wordnet.princeton.edu) or [FreeBase](http://www.freebase.com), it is sometimes used to suggest missing facts for \"knowledge base completion\".\n", "\n", "![link prediction](pics/link.svg)\n", "\n", "Many approaches have been developed for relational learning, and here we will focus on a subset called latent factor models. While TensorFlow is primarily geared towards neural network construction with layers of weights and nonlinear activations, it is also a nice framework for building latent factor models and even offers some specific functionality for them with its \"embedding\" methods.\n", "\n", "What makes frameworks like [TensorFlow](http://www.tensorflow.org/) , [Theano](http://deeplearning.net/software/theano/), and [Torch](http://torch.ch/) so great for productivity is their automatic differentiation feature, which calculates all of the partial derivatives necessary for running optimization procedures like stochastic gradient descent [(SGD)](http://en.wikipedia.org/wiki/Stochastic_gradient_descent). One simply needs to declare the computational graph structure and the framework takes care of all of the messy details for backpropagation. While calculating derivatives by hand can be fun (for some!), having to do it for every model variation you are developing is incredibly time-consuming.\n", "\n", "Let's get started!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Loading the Data\n", "\n", "For this tutorial, we'll be using a subset of the WordNet database that is popular for benchmarking relational learning models [[Bordes14]](#Bordes14). 
You can download it [here](https://everest.hds.utc.fr/doku.php?id=en:transe). WordNet is a large knowledge base of English words, their definitions, and their relationships to each other. This particular dataset contains one file with word definitions, and three files with relationships for standard train / validate / test purposes. The code below loads the data from text files into Pandas DataFrame objects and also replaces the identifier numbers with actual words to make them more readable." ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train shape: (141442, 3)\n", "Validation shape: (5000, 3)\n", "Test shape: (5000, 3)\n", "Training entity count: 40943\n", "Training relationship type count: 18\n", "Example training triples:\n" ] }, { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
headreltail
9565__incinerate_VB_2_verb_group__incinerate_VB_1
132173__outside_door_NN_1_has_part__panelling_NN_1
108291__negative_JJ_2_derivationally_related_form__negativity_NN_2
82581__typography_NN_1_derivationally_related_form__typographer_NN_1
127266__dominican_republic_NN_1_part_of__hispaniola_NN_1
\n", "
" ], "text/plain": [ " head rel \\\n", "9565 __incinerate_VB_2 _verb_group \n", "132173 __outside_door_NN_1 _has_part \n", "108291 __negative_JJ_2 _derivationally_related_form \n", "82581 __typography_NN_1 _derivationally_related_form \n", "127266 __dominican_republic_NN_1 _part_of \n", "\n", " tail \n", "9565 __incinerate_VB_1 \n", "132173 __panelling_NN_1 \n", "108291 __negativity_NN_2 \n", "82581 __typographer_NN_1 \n", "127266 __hispaniola_NN_1 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import os\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "from IPython.display import display\n", "\n", "from tf_rl_tutorial import util\n", "\n", "def read_and_replace(fpath, def_df):\n", " df = pd.read_table(fpath, names=['head', 'rel', 'tail'])\n", " df['head'] = def_df.loc[df['head']]['word'].values\n", " df['tail'] = def_df.loc[df['tail']]['word'].values\n", " return df\n", "\n", "data_dir = '~/data/wordnet-mlj12' # change to where you extracted the data\n", "definitions = pd.read_table(os.path.join(data_dir, 'wordnet-mlj12-definitions.txt'), \n", " index_col=0, names=['word', 'definition'])\n", "train = read_and_replace(os.path.join(data_dir, 'wordnet-mlj12-train.txt'), definitions)\n", "val = read_and_replace(os.path.join(data_dir, 'wordnet-mlj12-valid.txt'), definitions)\n", "test = read_and_replace(os.path.join(data_dir, 'wordnet-mlj12-test.txt'), definitions)\n", "\n", "print('Train shape:', train.shape)\n", "print('Validation shape:', val.shape)\n", "print('Test shape:', test.shape)\n", "all_train_entities = set(train['head']).union(train['tail'])\n", "print('Training entity count: {}'.format(len(all_train_entities)))\n", "print('Training relationship type count: {}'.format(len(set(train['rel']))))\n", "print('Example training triples:')\n", "display(train.sample(5))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "As we can see, this dataset contains relationship triples where the left-hand side (head) and right-hand side (tail) entries are things (or concepts) and the middle entry is a relationship between them. It contains 18 different relationship types, making this a multi-relational dataset. The extra suffix on each word serves to differentiate between homographs (e.g. fish the animal vs. fish the verb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Preprocessing\n", "\n", "The validation and test sets also contain true statements in the form shown above. One thing I noticed is that most of the triples in these two sets are mirror images of triples in the training set. This is due to most relationships having an opposite counterpart, for example \"fishing_rod has_part reel\" and \"reel part_of fishing_rod\". I think this makes the problem a bit too easy, so let's remove the mirror-image triples from the training set. We'll also remove triples in the validation and test sets where the head or tail entity is not present in that position in the training set, because our first model will not be able to handle those." 
] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train shape: (131974, 3)\n", "Validation shape: (4389, 3)\n", "Test shape: (4394, 3)\n" ] } ], "source": [ "from collections import defaultdict\n", "\n", "mask = np.zeros(len(train)).astype(bool)\n", "lookup = defaultdict(list)\n", "for idx,h,r,t in train.itertuples():\n", " lookup[(h,t)].append(idx) \n", "for h,r,t in pd.concat((val,test)).itertuples(index=False):\n", " mask[lookup[(h,t)]] = True\n", " mask[lookup[(t,h)]] = True\n", "train = train.loc[~mask]\n", "heads, tails = set(train['head']), set(train['tail'])\n", "val = val.loc[val['head'].isin(heads) & val['tail'].isin(tails)]\n", "test = test.loc[test['head'].isin(heads) & test['tail'].isin(tails)]\n", "print('Train shape:', train.shape)\n", "print('Validation shape:', val.shape)\n", "print('Test shape:', test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's add some false statements to the validation and test sets to make this a classification problem, as was done in [[Socher13]](#Socher13). For each true statement, we'll corrupt it by replacing either the head or tail entity with a random one (checking to make sure that the resulting triple isn't in our data). Some code to create these true/false pairs is provided in the *util* module function *create_tf_pairs()*." ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation shape: (8778, 4)\n", "Test shape: (8788, 4)\n" ] } ], "source": [ "rng = np.random.RandomState(42)\n", "combined_df = pd.concat((train, val, test))\n", "val = util.create_tf_pairs(val, combined_df, rng)\n", "test = util.create_tf_pairs(test, combined_df, rng)\n", "print('Validation shape:', val.shape)\n", "print('Test shape:', test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what kind of prediction task we're up against, let's examine the training and test data for relationships involving the entity \"brain cell\":" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train:\n" ] }, { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
headreltail
1409__encephalon_NN_1_has_part__brain_cell_NN_1
22089__brain_cell_NN_1_hyponym__golgi_cell_NN_1
44047__golgi_cell_NN_1_hypernym__brain_cell_NN_1
59003__brain_cell_NN_1_part_of__encephalon_NN_1
\n", "
" ], "text/plain": [ " head rel tail\n", "1409 __encephalon_NN_1 _has_part __brain_cell_NN_1\n", "22089 __brain_cell_NN_1 _hyponym __golgi_cell_NN_1\n", "44047 __golgi_cell_NN_1 _hypernym __brain_cell_NN_1\n", "59003 __brain_cell_NN_1 _part_of __encephalon_NN_1" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Test:\n" ] }, { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
headreltailtruth_flag
4320__neuron_NN_1_hyponym__brain_cell_NN_1True
4321__lobbyist_NN_1_hyponym__brain_cell_NN_1False
4505__spice_up_VB_2_derivationally_related_form__brain_cell_NN_1False
\n", "
" ], "text/plain": [ " head rel tail \\\n", "4320 __neuron_NN_1 _hyponym __brain_cell_NN_1 \n", "4321 __lobbyist_NN_1 _hyponym __brain_cell_NN_1 \n", "4505 __spice_up_VB_2 _derivationally_related_form __brain_cell_NN_1 \n", "\n", " truth_flag \n", "4320 True \n", "4321 False \n", "4505 False " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "example_entity = '__brain_cell_NN_1'\n", "example_train_rows = (train['head'] == example_entity) | (train['tail'] == example_entity)\n", "print('Train:')\n", "display(train.loc[example_train_rows])\n", "example_test_rows = (test['head'] == example_entity) | (test['tail'] == example_entity)\n", "print('Test:')\n", "display(test.loc[example_test_rows])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We want to build models that can learn \"a neuron **is** a type of brain cell\", and \"a lobbyist **is not** a type of brain cell\", having never seen those statements before." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Background: Matrix Factorization\n", "\n", "Before describing the multi-relational models we will build, it is useful to consider a popular single-relationship modeling technique called matrix factorization. This approach became very popular during the \\$1M Netflix Prize movie recommendation contest, as it was one of the primary techniques used by the top teams [[Koren09]](#Koren09). \n", "\n", "In a factorization model, we will represent our training data for a single relationship type as a matrix, $\\textbf{X}$. Each row in $\\textbf{X}$ corresponds to an entity appearing in the head of the relationships, each column a tail entity, and the matrix entries contain the values for all possible pairings.\n", "\n", "For movie ratings each row would be a person, each column a movie, the values integers between 1-5, and the matrix only partially observed (each person has not rated every movie). Our training data consists of true statements only, which raises the issue of how we will represent all other possible statements. This is known as the [open](https://en.wikipedia.org/wiki/Open-world_assumption) vs. [closed world](https://en.wikipedia.org/wiki/Closed-world_assumption) assumption, and for now we will go with the latter and treat all other statements as false. This is clearly not perfect, as some true statements have been intentionally withheld in the validation and test sets, and knowledge bases like WordNet are seldom complete. In some previous work the unobserved entries are weighted less to account for this [[Hu08]](#Hu08).\n", "\n", "A single matrix is sufficient for datasets with only one relationship type, but for multi-relational data like we are handling with WordNet we would need to do this for each of the 18 relationship types. Here is what part of our matrix looks like for the relationship type *has_part*:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tail__acculturate_VB_1__national_weather_service_NN_1__rod_cell_NN_1__singapore_NN_2__vascular_tissue_NN_1
head
__noaa_NN_101000
__retina_NN_100100
__vascular_plant_NN_100001
\n", "
" ], "text/plain": [ "tail __acculturate_VB_1 __national_weather_service_NN_1 \\\n", "head \n", "__noaa_NN_1 0 1 \n", "__retina_NN_1 0 0 \n", "__vascular_plant_NN_1 0 0 \n", "\n", "tail __rod_cell_NN_1 __singapore_NN_2 \\\n", "head \n", "__noaa_NN_1 0 0 \n", "__retina_NN_1 1 0 \n", "__vascular_plant_NN_1 0 0 \n", "\n", "tail __vascular_tissue_NN_1 \n", "head \n", "__noaa_NN_1 0 \n", "__retina_NN_1 0 \n", "__vascular_plant_NN_1 1 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "has_part_triples = val.loc[val['rel'] == '_has_part']\n", "query_entities = ['__noaa_NN_1', '__vascular_plant_NN_1', '__retina_NN_1']\n", "has_part_example = has_part_triples.loc[has_part_triples['head'].isin(query_entities)]\n", "matrix_view = pd.pivot_table(has_part_example, 'truth_flag', 'head', 'tail', \n", " fill_value=False).astype(int)\n", "display(matrix_view)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The factorization model approximates the observation matrix $\\textbf{X}$ as the product of two matrices. We'll call them $\\textbf{H}$ and $\\textbf{T}$ for \"head\" and \"tail\". If $\\textbf{X}$ contains $I$ rows and $J$ columns, then $\\textbf{H} = [h_1, h_2, ..., h_I] \\in \\unicode{x211D}^{DxI}$ and $\\textbf{T} = [t_1, t_2, ..., t_J]\\in \\unicode{x211D}^{DxJ}$. $D$ is typically much smaller than $I$ or $J$.\n", "\n", "$$ \\textbf{X} \\approx \\textbf{H}^T \\textbf{T} $$\n", "\n", "![matrix factorization](pics/mf1.svg)\n", "\n", "For relational learning, one can think of the matrix $\\textbf{H}$ as containing latent factor vectors in its columns for the entities appearing in the head entry of our relationship triples and likewise for $\\textbf{T}$ and the tail entities. This is the same concept that has become popular in [word2vec](http://www.tensorflow.org/tutorials/word2vec/index.html) models [[Mikolov13]](#Mikolov13), where the vector representation is instead called an embedding.\n", "\n", "The model output for the relationship between head entity $i$ and tail entity $j$ is the dot product of their latent vectors:\n", "\n", "$$ f(i,j) = h_i^T t_j = \\sum_{d=1}^D \\textbf{H}_{di} \\textbf{T}_{dj} $$\n", "\n", "The figure below contains a simple geometric illustration in a *D*=2 space for some of the words shown in the *has_part* matrix above. Since $ h_i^T t_j = \\lVert h_i \\rVert \\lVert t_j \\rVert cos \\left( \\theta_{ij} \\right)$, the model score for (retina, has_part, singapore) will be 0 due to the vectors being orthogonal, whereas the score for (retina, has_part, rod_cell) will be positive and based on the length of each vector and their angle $\\theta$.\n", "\n", "![dot product](pics/mf_dot.svg)\n", "\n", "A common optimization objective for this model is to find latent vectors for all head and tail entities such that the total squared approximation error is minimized. This is similar to a truncated [SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition) but without orthogonality constraints:\n", "\n", "$$ min_{\\textbf{H},\\textbf{T}} \\ \\lVert \\textbf{X} - \\textbf{H}^T \\textbf{T} \\rVert_{Fro}^{2} \\ = \\ min_{\\textbf{H},\\textbf{T}} \\ \\sum_{ij} \\left( \\textbf{X}_{ij} - h_i^T t_j \\right) ^2 $$\n", "\n", "Note that squared error is not the only possible choice, and regularization is often important to control model complexity. We will explore both of these topics as we move through the tutorial. 
\n", "\n", "Implementing matrix factorization in TensorFlow is pretty simple, but we're going to hold off because we'd rather address multi-relational learning, and it is also a special case of our next model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Tensor Decomposition\n", "\n", "A major limitation of the matrix factorization described above is that it only applies to a single relationship type at a time. For a multi-relational dataset like WordNet, this results in 18 independent models, and anything one model \"learns\" cannot influence the others.\n", "\n", "One way to encourage knowledge transfer across relationships is to stack their matrices into a 3D array called a [tensor](https://en.wikipedia.org/wiki/Tensor). We can then build one factorization model for this tensor which shares the head and tail entity representations across all relationships.\n", "\n", "A natural fit for what we have described is the CANDECOMP / PARAFAC (CP) tensor decomposition [[Hitchcock27]](#Hitchcock27) [[Kolda09]](#Kolda09). We will add a new factor matrix for our relationship types, $\\textbf{R} = [r_1, r_2, ..., r_K] \\in \\unicode{x211D}^{DxK}$, and CP then approximates the target tensor $\\bf{X}$ as the sum of $D$ rank-one tensors, each formed by the outer product of one dimension from the factor matrices. \n", "\n", "![CP](pics/tf1.svg)\n", "\n", "$$ \\textbf{X} \\approx \\sum_{d=1}^{D} \\textbf{H}_{d,:} \\circ \\textbf{T}_{d,:} \\circ \\textbf{R}_{d,:} $$\n", "\n", "where $\\circ$ denotes the outer product, and the subscript $d,:$ denotes a matrix row in Matlab/NumPy slice syntax. The model output for a single cell is similar to matrix factorization, but now each entry in the dot product between head and tail is also multiplied by the corresponding entry in the relationship's vector:\n", "\n", "$$ f(i,j,k) = \\sum_{d=1}^{D} \\textbf{H}_{di} \\textbf{T}_{dj} \\textbf{R}_{dk} $$\n", "\n", "or as a matrix multiplication:\n", "\n", "$$ f(i,j,k) = h_i^T diag(r_k) \\ t_j $$\n", "\n", "where $diag(r_k)$ denotes a $D \\ x \\ D$ matrix with the elements from vector $r_k$ on the diagonal.\n", "\n", "Also as before, when the decomposition is not full-rank a common objective is least-squares approximation:\n", "\n", "$$ min_{\\textbf{H},\\textbf{T},\\textbf{R}} \\ \\sum_{ijk} \\left( \\textbf{X}_{ijk} - f(i,j,k) \\right)^2 $$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Defining a model like this in TensorFlow is straightforward. First, we create placeholders for the triples that the model takes as input, as well as the target (desired output) placeholder for each input triple during training. In our case the target values are boolean and will be represented by 1 and 0. Using *None* as the dimension size means that it will be determined at runtime by the size of the mini-batch passed in for each training step." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "graph = tf.Graph()\n", "with graph.as_default():\n", " # input and target placeholders\n", " head_input = tf.placeholder(tf.int32, shape=[None])\n", " rel_input = tf.placeholder(tf.int32, shape=[None])\n", " tail_input = tf.placeholder(tf.int32, shape=[None])\n", " target = tf.placeholder(tf.float32, shape=[None])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CP decomposition requires two separate embeddings for the head and tail entities, even if the same entity appears in both sets (other models do not require this, as we'll see). One way of implementing the input and embedding layer would be to one-hot encode each of our three inputs into long vectors where only the entry for the correpsonding item is set to 1. These input vectors would then be connected to the embedding layer in a fully-connected fashion as is typical in neural networks, and the weight matrix would contain the embedding for each item. Instead, TensorFlow provides a special method for this use case, *tf.nn.embedding_lookup()*, which will accomplish the same goal with less code. We'll also need to convert our strings into integers in order to feed them into the input placeholders." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "embedding_size = 20\n", "head_cnt = len(set(train['head']))\n", "rel_cnt = len(set(train['rel']))\n", "tail_cnt = len(set(train['tail']))\n", " \n", "with graph.as_default():\n", " # embedding variables\n", " init_sd = 1.0 / np.sqrt(embedding_size)\n", " head_embedding_vars = tf.Variable(tf.truncated_normal([head_cnt, embedding_size], \n", " stddev=init_sd))\n", " rel_embedding_vars = tf.Variable(tf.truncated_normal([rel_cnt, embedding_size], \n", " stddev=init_sd))\n", " tail_embedding_vars = tf.Variable(tf.truncated_normal([tail_cnt, embedding_size], \n", " stddev=init_sd))\n", " # embedding layer for the (h, r, t) triple being fed in as input\n", " head_embed = tf.nn.embedding_lookup(head_embedding_vars, head_input)\n", " rel_embed = tf.nn.embedding_lookup(rel_embedding_vars, rel_input)\n", " tail_embed = tf.nn.embedding_lookup(tail_embedding_vars, tail_input)\n", " # CP model output\n", " output = tf.reduce_sum(tf.mul(tf.mul(head_embed, rel_embed), tail_embed), 1)\n", " \n", "# TensorFlow requires integer indices\n", "field_categories = (set(train['head']), set(train['rel']), set(train['tail']))\n", "train, train_idx_array = util.make_categorical(train, field_categories)\n", "val, val_idx_array = util.make_categorical(val, field_categories)\n", "test, test_idx_array = util.make_categorical(test, field_categories)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### Optimization and Negative Sampling\n", "\n", "With our model structure defined, we now need to determine how to optimize it on the training data. A popular method for the squared-loss CP model above is alternating least squares, which can efficiently handle the huge number of zero entries typical in a relationship tensor [[Kolda09]](#Kolda09). We'd like to use the nice SGD optimizers that come with TensorFlow, however, and may also want to use other objective functions where alternating least squares is not appropriate. 
To do so, we must decide how to form our random training mini-batches.\n", "\n", "For this tutorial, we will use an approach that has become popular for training embedding models called negative sampling [[Mikolov13]](#Mikolov13) [[Bordes11]](#Bordes11) [[Socher13]](#Socher13). We will augment each positive example in the mini-batch by creating a negative counterpart with a corruption operation. Corruption is accomplished by replacing either the head or the tail (but not the relationship type) with a random entity, just as we did to create the negative examples in the validation and test sets. While this method is usually used for logistic loss or pairwise loss functions (which we'll examine soon), we will first give it a shot on squared loss.\n", "\n", "Note that we have deviated a bit from the original CP formulation above, as these mini-batches are not a random sample from the overall objective function. Our model will therefore not be exactly the CP decomposition described above, but instead \"CP-like\".\n", "\n", "For brevity I have omitted the *ContrastiveTrainingProvider* class which provides these training mini-batches, but it can be found in the sample code. Because this model creates a separate embedding for head and tail entities, we will tell the batch provider to honor this by setting *separate_head_tail=True*. Let's take a look at an example mini-batch it provides:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[35150, 6, 11760],\n", " [35150, 6, 9035],\n", " [ 4505, 17, 27691],\n", " [ 4505, 17, 10438],\n", " [ 739, 10, 22320],\n", " [ 739, 10, 26591]])" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "which encodes:\n" ] }, { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
headreltaillabel
0__tractable_JJ_1_also_see__manageable_JJ_11
1__tractable_JJ_1_also_see__true_flycatcher_NN_10
2__trash_VB_2_hypernym__pick_at_VB_31
3__trash_VB_2_hypernym__factor_NN_30
4__wish_well_VB_1_derivationally_related_form__wish_NN_21
5__wish_well_VB_1_derivationally_related_form__athletics_NN_20
\n", "
" ], "text/plain": [ " head rel tail \\\n", "0 __tractable_JJ_1 _also_see __manageable_JJ_1 \n", "1 __tractable_JJ_1 _also_see __true_flycatcher_NN_1 \n", "2 __trash_VB_2 _hypernym __pick_at_VB_3 \n", "3 __trash_VB_2 _hypernym __factor_NN_3 \n", "4 __wish_well_VB_1 _derivationally_related_form __wish_NN_2 \n", "5 __wish_well_VB_1 _derivationally_related_form __athletics_NN_2 \n", "\n", " label \n", "0 1 \n", "1 0 \n", "2 1 \n", "3 0 \n", "4 1 \n", "5 0 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from tf_rl_tutorial.util import ContrastiveTrainingProvider\n", "\n", "batch_provider = ContrastiveTrainingProvider(train_idx_array, batch_pos_cnt=3, \n", " separate_head_tail=True)\n", "batch_triples, batch_labels = batch_provider.next_batch()\n", "batch_df = pd.DataFrame()\n", "batch_df['head'] = pd.Categorical.from_codes(batch_triples[:,0], train['head'].cat.categories)\n", "batch_df['rel'] = pd.Categorical.from_codes(batch_triples[:,1], train['rel'].cat.categories)\n", "batch_df['tail'] = pd.Categorical.from_codes(batch_triples[:,2], train['tail'].cat.categories)\n", "batch_df['label'] = batch_labels\n", "display(batch_triples)\n", "print('which encodes:')\n", "display(batch_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to train. We'll monitor progress by periodically printing the average loss on both the current training mini-batch and the validation set. We'll also track how many of the contrastive pairs in the validation set are ranked properly, which means that the positive triple receives a higher score than the negative triple. This is accomplished using *util.pair_ranking_accuracy()*." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter 0, batch loss: 0.5, val loss: 0.5, val pair ranking acc: 0.51\n", "iter 2000, batch loss: 0.5, val loss: 0.54, val pair ranking acc: 0.49\n", "iter 4000, batch loss: 0.12, val loss: 0.55, val pair ranking acc: 0.55\n", "iter 6000, batch loss: 0.11, val loss: 0.55, val pair ranking acc: 0.56\n", "iter 8000, batch loss: 0.077, val loss: 0.53, val pair ranking acc: 0.57\n", "iter 10000, batch loss: 0.072, val loss: 0.52, val pair ranking acc: 0.58\n", "iter 12000, batch loss: 0.069, val loss: 0.51, val pair ranking acc: 0.58\n", "iter 14000, batch loss: 0.077, val loss: 0.51, val pair ranking acc: 0.59\n", "iter 16000, batch loss: 0.061, val loss: 0.5, val pair ranking acc: 0.6\n", "iter 18000, batch loss: 0.055, val loss: 0.5, val pair ranking acc: 0.6\n", "iter 20000, batch loss: 0.036, val loss: 0.49, val pair ranking acc: 0.6\n", "iter 22000, batch loss: 0.043, val loss: 0.49, val pair ranking acc: 0.61\n", "iter 24000, batch loss: 0.05, val loss: 0.48, val pair ranking acc: 0.62\n", "iter 26000, batch loss: 0.059, val loss: 0.48, val pair ranking acc: 0.62\n", "iter 28000, batch loss: 0.039, val loss: 0.48, val pair ranking acc: 0.62\n", "iter 29999, batch loss: 0.047, val loss: 0.47, val pair ranking acc: 0.63\n" ] } ], "source": [ "max_iter = 30000\n", "\n", "batch_provider = ContrastiveTrainingProvider(train_idx_array, batch_pos_cnt=100, \n", " separate_head_tail=True)\n", "opt = tf.train.AdagradOptimizer(1.0)\n", "\n", "sess = tf.Session(graph=graph)\n", "with graph.as_default():\n", " loss = tf.reduce_sum(tf.square(output - target))\n", " train_step = opt.minimize(loss)\n", " sess.run(tf.initialize_all_variables())\n", "\n", "# feed dict for monitoring progress on validation 
set\n", "val_labels = np.array(val['truth_flag'], dtype=np.float)\n", "val_feed_dict = {head_input: val_idx_array[:,0],\n", " rel_input: val_idx_array[:,1],\n", " tail_input: val_idx_array[:,2],\n", " target: val_labels}\n", "\n", "for i in range(max_iter):\n", " batch_triples, batch_labels = batch_provider.next_batch()\n", " feed_dict = {head_input: batch_triples[:,0],\n", " rel_input: batch_triples[:,1],\n", " tail_input: batch_triples[:,2],\n", " target: batch_labels}\n", " if i % 2000 == 0 or i == (max_iter-1):\n", " batch_avg_loss = sess.run(loss, feed_dict) / len(batch_labels)\n", " val_output, val_loss = sess.run((output,loss), val_feed_dict)\n", " val_avg_loss = val_loss / len(val_labels)\n", " val_pair_ranking_acc = util.pair_ranking_accuracy(val_output)\n", " msg = 'iter {}, batch loss: {:.2}, val loss: {:.2}, val pair ranking acc: {:.2}'\n", " print(msg.format(i, batch_avg_loss, val_avg_loss, val_pair_ranking_acc))\n", " sess.run(train_step, feed_dict)\n", "\n", "sess.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Not so good. It is achieving a very tight fit on the training set but making little progress on the validation set, a sign of overfitting.\n", "\n", "\n", "\n", "### Max-Norm Regularization\n", "\n", "There are a variety of things we could do here, for example lowering the number of model parameters by reducing the embedding dimension. A popular approach for latent factor models is instead to add a regularization term to our objective function that operates on the model parameters $\\Omega$:\n", "\n", "$$ min_{\\Omega} \\ \\ Loss+ \\lambda Reg(\\Omega) $$\n", "\n", "The $L_2$ (ridge) penalty is a common choice which encourages the sum of the squares of the embeddings to be small:\n", "\n", "$$ Reg_{L_{2}} ( \\textbf{H},\\textbf{T},\\textbf{R} ) = \\lVert \\textbf{H} \\rVert_{Fro}^2 + \\lVert \\textbf{T} \\rVert_{Fro}^2 + \\lVert \\textbf{R} \\rVert_{Fro}^2 $$\n", "\n", "We could easily add this form of regularization to our CP model with a single line of code, but it creates a bit of a nuisance: it destroys the nice sparse updates we get in SGD and can slow down the optimization process significantly. This is because SGD subsamples the overall loss using a subset of training examples, and doing this creates 0 gradient and no updates for embeddings not present in the mini-batch (**note**: this is not true for all optimizers, see caveat below). The regularization term, however, is data-independent and creates a gradient for all embeddings on every SGD step. There are some clever approaches to handle this issue [[Bottou12]](#Bottou12) [[Carpenter08]](#Carpenter08), but they would add a bit of extra complexity to our code. Other researchers have opted to only apply the regularization gradient to parameters active in the mini-batch [[Koren09]](#Koren09).\n", "\n", "An alternative way to control model complexity, which will also preserve the sparsity of the SGD gradients, is to instead constrain all embedding vectors to lie within an $L_2$ ball of radius $C$. This is called max-norm regularization, and has been shown to be effective in a variety of settings including relational learning models [[Sbrero04]](#Sbrero04) [[Srivastava14]](#Srivastava14) [[Bordes13]](#Bordes13). Our optimization problem is now:\n", "\n", "$$ min_{\\textbf{H},\\textbf{T},\\textbf{R}} \\ Loss \\\\\n", " s.t. 
\\\\\n", " \\forall i \\ \\ \\lVert h_i \\rVert <= C, \\ \\ \n", " \\forall j \\ \\ \\lVert t_j \\rVert <= C, \\ \\\n", " \\forall k \\ \\ \\lVert r_k \\rVert <= C\n", "$$\n", "\n", "We can enforce these constraints by using projected gradient descent: after each step, the embedding vectors that have moved outside of the $L_2$ ball are projected back onto it. If the optimizer only updates embedding vectors with nonzero gradient, the ones present in each mini-batch, we only need to check this small subset after each step. \n", "![MaxNorm](pics/l2_maxnorm.svg)\n", "This form of constraint is not directly supported by TensorFlow, but we can implement it by tracking which rows in the embedding matrices could have changed, and using the *scatter_update()* function to apply sparse projection updates. I'm not sure if this is the best way to accomplish the goal, it's just the one I came up with when learning the TensorFlow API.\n", "\n", "**Caveat**: As of when this tutorial was written, all of the optimizers I tried in TensorFlow ignore variables with zero gradient except *AdamOptimizer*, which can apply momentum updates to them. The full code example therefore uses a dense projection operation when Adam is used, and an efficient sparse update otherwise." ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "collapsed": true }, "outputs": [], "source": [ "maxnorm = 1.5\n", "\n", "def sparse_maxnorm_update(var_matrix, indices, maxnorm=1.0):\n", " selected_rows = tf.nn.embedding_lookup(var_matrix, indices)\n", " row_norms = tf.sqrt(tf.reduce_sum(tf.square(selected_rows), 1))\n", " scaling = maxnorm / tf.maximum(row_norms, maxnorm)\n", " scaled = selected_rows * tf.expand_dims(scaling, 1)\n", " return tf.scatter_update(var_matrix, indices, scaled)\n", "\n", "def dense_maxnorm_update(var_matrix, maxnorm=1.0):\n", " row_norms = tf.sqrt(tf.reduce_sum(tf.square(var_matrix), 1))\n", " scaling = maxnorm / tf.maximum(row_norms, maxnorm)\n", " scaled = var_matrix * tf.expand_dims(scaling, 1)\n", " return tf.assign(var_matrix, scaled)\n", "\n", "with graph.as_default():\n", " # tf.unique used to gather the head, tail, and rel indices that are active in the minibatch\n", " head_constraint = sparse_maxnorm_update(head_embedding_vars, \n", " tf.unique(head_input)[0], maxnorm)\n", " rel_constraint = sparse_maxnorm_update(rel_embedding_vars, \n", " tf.unique(rel_input)[0], maxnorm)\n", " tail_constraint = sparse_maxnorm_update(tail_embedding_vars, \n", " tf.unique(tail_input)[0], maxnorm)\n", " postprocess_step = [head_constraint, rel_constraint, tail_constraint]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now run the max-norm constraint operation after each training step and see if we get a better result:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter 0, batch loss: 0.5, val loss: 0.5, val pair ranking acc: 0.5, max norm: 1.35\n", "iter 2000, batch loss: 0.39, val loss: 0.5, val pair ranking acc: 0.55, max norm: 1.5\n", "iter 4000, batch loss: 0.24, val loss: 0.43, val pair ranking acc: 0.65, max norm: 1.5\n", "iter 6000, batch loss: 0.12, val loss: 0.4, val pair ranking acc: 0.68, max norm: 1.5\n", "iter 8000, batch loss: 0.15, val loss: 0.38, val pair ranking acc: 0.71, max norm: 1.5\n", "iter 10000, batch loss: 0.1, val loss: 0.37, val pair ranking acc: 0.73, max norm: 1.5\n", "iter 12000, batch loss: 0.11, val loss: 0.36, val pair ranking acc: 0.73, max norm: 
1.5\n", "iter 14000, batch loss: 0.088, val loss: 0.35, val pair ranking acc: 0.75, max norm: 1.5\n", "iter 16000, batch loss: 0.11, val loss: 0.34, val pair ranking acc: 0.76, max norm: 1.5\n", "iter 18000, batch loss: 0.091, val loss: 0.34, val pair ranking acc: 0.76, max norm: 1.5\n", "iter 20000, batch loss: 0.11, val loss: 0.33, val pair ranking acc: 0.77, max norm: 1.5\n", "iter 22000, batch loss: 0.081, val loss: 0.33, val pair ranking acc: 0.77, max norm: 1.5\n", "iter 24000, batch loss: 0.096, val loss: 0.33, val pair ranking acc: 0.77, max norm: 1.5\n", "iter 26000, batch loss: 0.076, val loss: 0.33, val pair ranking acc: 0.76, max norm: 1.5\n", "iter 28000, batch loss: 0.085, val loss: 0.32, val pair ranking acc: 0.77, max norm: 1.5\n", "iter 29999, batch loss: 0.085, val loss: 0.32, val pair ranking acc: 0.76, max norm: 1.5\n" ] } ], "source": [ "max_iter = 30000\n", "\n", "sess = tf.Session(graph=graph)\n", "with graph.as_default():\n", " # init variables and ensure constraints are initially satisfied\n", " sess.run(tf.initialize_all_variables())\n", " sess.run(dense_maxnorm_update(head_embedding_vars, maxnorm))\n", " sess.run(dense_maxnorm_update(rel_embedding_vars, maxnorm))\n", " sess.run(dense_maxnorm_update(tail_embedding_vars, maxnorm))\n", "\n", "for i in range(max_iter):\n", " batch_triples, batch_labels = batch_provider.next_batch()\n", " feed_dict = {head_input: batch_triples[:,0],\n", " rel_input: batch_triples[:,1],\n", " tail_input: batch_triples[:,2],\n", " target: batch_labels}\n", " if i % 2000 == 0 or i == (max_iter-1):\n", " batch_avg_loss = sess.run(loss, feed_dict) / len(batch_labels)\n", " val_output, val_loss = sess.run((output,loss), val_feed_dict)\n", " val_avg_loss = val_loss / len(val_labels)\n", " val_pair_ranking_acc = util.pair_ranking_accuracy(val_output)\n", " # check on embedding norm constraints\n", " all_embed = np.vstack(sess.run([head_embedding_vars, \n", " rel_embedding_vars, \n", " tail_embedding_vars]))\n", " norms = np.linalg.norm(all_embed, axis=1)\n", " msg = 'iter {}, batch loss: {:.2}, val loss: {:.2}, val pair ranking acc: {:.2}, max norm: {:.3}'\n", " print(msg.format(i, batch_avg_loss, val_avg_loss, val_pair_ranking_acc, np.max(norms)))\n", "\n", " sess.run(train_step, feed_dict)\n", " sess.run(postprocess_step, feed_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While there is still a large gap between training and validation loss, this variant is faring better than its unregularized counterpart. Let's see how it does on the test set.\n", "\n", "For evaluation on the test set we will not just report the pair ranking accuracy, but will also turn our model into a boolean classifier and measure prediction accuracy. To do this we'll need to threshold the real-valued model output. As is done in [[Socher13]](#Socher13), the thresholds for each relationship type will be found on the validation set, and this code is in *util.threshold_and_eval()*." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test set accuracy: 0.71\n" ] } ], "source": [ "test_feed_dict = {head_input: test_idx_array[:,0],\n", " rel_input: test_idx_array[:,1],\n", " tail_input: test_idx_array[:,2]}\n", "\n", "test_scores = sess.run(output, test_feed_dict)\n", "val_scores = sess.run(output, val_feed_dict)\n", "acc, pred, scores, thresh_map = util.threshold_and_eval(test, test_scores, \n", " val, val_scores)\n", "print('Test set accuracy: {:.2}'.format(acc))\n", "sess.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Is that the best it can do? Probably not, as there are a number of things we could experiment with like the max-norm constraint, embedding dimension, optimizer type, optimizer parameters, alternate loss functions, how many negatives we sample per positive, etc. For now, let's move on and look at some alternative loss functions and more models." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Loss Functions\n", "\n", "### Per-Instance Loss\n", "So far we have used squared loss as the objective for model training which implies a normally-distributed likelihood. For binary data, a Bernoulli model is a natural choice and only requires two changes: we squash the output with a sigmoid function, and switch to cross-entropy loss. Another popular choice is the hinge loss from SVM and other max-margin models. \n", "\n", "### Pairwise Ranking Margin\n", "The loss functions above operate on individual positive and negative training triples, and encourage the model to map known positive examples to 1 and all others to 0. The contrastive sampling we are utilizing suggests an alternative objective function we can optimize. Instead of trying to map each positive instance to 1 and its corrupted partner to 0, we could instead ask the model to just try to score the positive instance higher by a certain margin. If this could be done for all contrastive pairs, then our positive training examples would all be ranked above the others. \n", "\n", "Instead of attempting to satisfy this constraint exactly, we can optimize the so-called soft margin of hinge-loss penalties in a similar fashion as RankSVM [[Joachims06]](#Joachims06).\n", "\n", "$$ RankingLoss\\left((i,j,k), (i',j',k)\\right) = [\\gamma - f(i,j,k) + f(i',j',k) ]_{+} $$\n", "\n", "where $(i,j,k)$ is the positive example, $(i',j',k)$ is the corrupted triple where either the head $i$ or the tail $j$ has been modified, $\\gamma$ is the margin, and $[x]_{+} = max(0, x)$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is where automatic differentiation really shines, as we can easily experiment with these loss function variations with just a few lines of code. Since we've already used squared loss in our first attempt, we'll switch to others in the upcoming models." 
] }, { "cell_type": "code", "execution_count": 51, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def least_squares_objective(output, target, add_bias=True):\n", " y = output\n", " if add_bias:\n", " bias = tf.Variable([0.0])\n", " y = output + bias\n", " loss = tf.reduce_sum(tf.square(y - target))\n", " return y, loss\n", "\n", "def logistic_objective(output, target, add_bias=True):\n", " y = output\n", " if add_bias:\n", " bias = tf.Variable([0.0])\n", " y = output + bias\n", " squashed_y = tf.clip_by_value(tf.sigmoid(y), 0.001, 0.999) # avoid NaNs\n", " loss = -tf.reduce_sum(target*tf.log(squashed_y) + (1-target)*tf.log(1-squashed_y))\n", " return squashed_y, loss\n", "\n", "def ranking_margin_objective(output, margin=1.0):\n", " ''' This only works when given model output on alternating\n", " positive/negative pairs: [pos,neg,pos,neg,...] '''\n", " y_pairs = tf.reshape(output, [-1,2]) # fold: 1 x n -> [n/2 x 2]\n", " pos_scores, neg_scores = tf.split(1, 2, y_pairs) # separate the pairs\n", " hinge_losses = tf.nn.relu(margin - pos_scores + neg_scores)\n", " total_hinge_loss = tf.reduce_sum(hinge_losses)\n", " return output, total_hinge_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Bilinear Model (RESCAL)\n", "\n", "RESCAL [[Nickel11]](#Nickel11) employs two changes to the basic CP decomposition:\n", "\n", "* Entities have a single embedding space; each entity is mapped to the same embedding vector no matter its position in the relationship triple\n", "* Each relationship type is embedded into a larger $D^2$ space which operates as a $D \\ x \\ D$ bilinear operator on the entity embeddings\n", "\n", "The figure below is based on one from the original paper that I think does a nice job depicting this decomposition:\n", "\n", "![RESCAL](pics/rescal.svg)\n", "\n", "The two separate $\\textbf{H}$ and $\\textbf{T}$ matrices from before have been replaced with single embedding matrix for all entities, $\\textbf{E} = [e_1,e_2,...,e_N] \\in \\unicode{x211D}^{DxN}$, and the model output for an input triple is:\n", "\n", "$$ f(i,j,k) = e_i^T \\textbf{R}_k e_j $$\n", "\n", "where $\\textbf{R}_k$ is a $D \\ x \\ D$ matrix for relationship type $k$.\n", "\n", "This is a more flexible tensor decomposition than CP, as the relationship matrix introduces interaction terms between each entry in the embedding vectors. As shown before, CP can be seen as a special case of this formulation where each relationship type operates as a diagonal scaling matrix instead of a full bilinear one.\n", "\n", "Implementing a bilinear model requires only a few changes to our code above. We are going to start doing things a little differently, however, and use a common base class for our models which contains reusable code. Each model only has to implement the *_create_model()* method to define its structure, training step, and postprocessing step.\n", "\n", "We will also deviate a bit from the original RESCAL paper which proposes $L_2$ regularization, and instead stick to our max-norm constraints. A bit of experimentation has shown that it can help to use a larger max-norm for the relationship embeddings, which can be specified using *rel_maxnorm_mult*. Finally, we'll switch to logistic loss instead of squared loss, and again utilize a contrastive negative sampling approach." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from tf_rl_tutorial.models import BaseModel, dense_maxnorm \n", "\n", "class Bilinear(BaseModel):\n", " \n", " def __init__(self, embedding_size, maxnorm=1.0, rel_maxnorm_mult=3.0, \n", " batch_pos_cnt=100, max_iter=1000, \n", " model_type='least_squares', add_bias=True, opt=None):\n", " super(Bilinear, self).__init__(\n", " embedding_size=embedding_size,\n", " maxnorm=maxnorm,\n", " batch_pos_cnt=batch_pos_cnt,\n", " max_iter=max_iter,\n", " model_type=model_type,\n", " opt=opt)\n", " self.rel_maxnorm_mult = rel_maxnorm_mult\n", " \n", " def _create_model(self, train_triples):\n", " # Count unique items to determine embedding matrix sizes\n", " entity_cnt = len(set(train_triples[:,0]).union(train_triples[:,2]))\n", " rel_cnt = len(set(train_triples[:,1]))\n", " init_sd = 1.0 / np.sqrt(self.embedding_size)\n", " # Embedding variables for all entities and relationship types\n", " entity_embedding_shape = [entity_cnt, self.embedding_size]\n", " # Relationship embeddings will be stored in flattened format to make \n", " # applying maxnorm constraints easier\n", " rel_embedding_shape = [rel_cnt, self.embedding_size * self.embedding_size]\n", " entity_init = tf.truncated_normal(entity_embedding_shape, stddev=init_sd)\n", " rel_init = tf.truncated_normal(rel_embedding_shape, stddev=init_sd)\n", " if self.maxnorm is not None:\n", " # Ensure maxnorm constraints are initially satisfied\n", " entity_init = dense_maxnorm(entity_init, self.maxnorm)\n", " rel_init = dense_maxnorm(rel_init, self.maxnorm)\n", " self.entity_embedding_vars = tf.Variable(entity_init)\n", " self.rel_embedding_vars = tf.Variable(rel_init)\n", " # Embedding layer for each (head, rel, tail) triple being fed in as input\n", " head_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.head_input)\n", " tail_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.tail_input)\n", " rel_embed = tf.nn.embedding_lookup(self.rel_embedding_vars, self.rel_input)\n", " # Reshape rel_embed into square D x D matrices\n", " rel_embed_square = tf.reshape(rel_embed, (-1, self.embedding_size, self.embedding_size))\n", " # Reshape head_embed and tail_embed to be suitable for the matrix multiplication\n", " head_embed_row = tf.expand_dims(head_embed, 1) # embeddings as row vectors\n", " tail_embed_col = tf.expand_dims(tail_embed, 2) # embeddings as column vectors\n", " head_rel_mult = tf.batch_matmul(head_embed_row, rel_embed_square)\n", " # Output needs a squeeze into a 1d vector\n", " raw_output = tf.squeeze(tf.batch_matmul(head_rel_mult, tail_embed_col)) \n", " self.output, self.loss = self._create_output_and_loss(raw_output)\n", " # Optimization\n", " self.train_step = self.opt.minimize(self.loss)\n", " if self.maxnorm is not None:\n", " # Post-processing to limit embedding vars to L2 ball\n", " rel_maxnorm = self.maxnorm * self.rel_maxnorm_mult\n", " unique_ent_indices = tf.unique(tf.concat(0, [self.head_input, self.tail_input]))[0]\n", " unique_rel_indices = tf.unique(self.rel_input)[0]\n", " entity_constraint = self._norm_constraint_op(self.entity_embedding_vars, \n", " unique_ent_indices, \n", " self.maxnorm)\n", " rel_constraint = self._norm_constraint_op(self.rel_embedding_vars, \n", " unique_rel_indices, \n", " rel_maxnorm)\n", " self.post_step = [entity_constraint, rel_constraint]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the rest of our models will use a single embedding space 
for all entities, we need to redo the categorical mappings:" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "collapsed": true }, "outputs": [], "source": [ "all_entities = set(train['head']).union(set(train['tail']))\n", "field_categories = (all_entities, set(train['rel']), all_entities)\n", "train, train_idx_array = util.make_categorical(train, field_categories)\n", "val, val_idx_array = util.make_categorical(val, field_categories)\n", "test, test_idx_array = util.make_categorical(test, field_categories)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "itr 0, batch loss: 0.69, val loss: 0.69, val pair ranking acc: 0.5\n", "itr 2000, batch loss: 0.66, val loss: 0.7, val pair ranking acc: 0.53\n", "itr 4000, batch loss: 0.58, val loss: 0.68, val pair ranking acc: 0.61\n", "itr 6000, batch loss: 0.53, val loss: 0.66, val pair ranking acc: 0.65\n", "itr 8000, batch loss: 0.48, val loss: 0.65, val pair ranking acc: 0.69\n", "itr 10000, batch loss: 0.52, val loss: 0.64, val pair ranking acc: 0.7\n", "itr 12000, batch loss: 0.38, val loss: 0.64, val pair ranking acc: 0.71\n", "itr 14000, batch loss: 0.43, val loss: 0.62, val pair ranking acc: 0.72\n", "itr 16000, batch loss: 0.43, val loss: 0.62, val pair ranking acc: 0.74\n", "itr 18000, batch loss: 0.39, val loss: 0.61, val pair ranking acc: 0.73\n", "itr 20000, batch loss: 0.4, val loss: 0.61, val pair ranking acc: 0.76\n", "itr 22000, batch loss: 0.39, val loss: 0.6, val pair ranking acc: 0.78\n", "itr 24000, batch loss: 0.4, val loss: 0.59, val pair ranking acc: 0.78\n", "itr 26000, batch loss: 0.4, val loss: 0.58, val pair ranking acc: 0.78\n", "itr 28000, batch loss: 0.35, val loss: 0.58, val pair ranking acc: 0.79\n", "itr 29999, batch loss: 0.35, val loss: 0.56, val pair ranking acc: 0.8\n", "Test set accuracy: 0.75\n" ] } ], "source": [ "bilinear = Bilinear(embedding_size=20,\n", " maxnorm=1.0,\n", " rel_maxnorm_mult=6.0,\n", " batch_pos_cnt=100, \n", " max_iter=30000,\n", " model_type='logistic')\n", "\n", "val_feed_dict = bilinear.create_feed_dict(val_idx_array, val_labels)\n", "\n", "def train_step_callback(itr, batch_feed_dict):\n", " if (itr % 2000) == 0 or (itr == (bilinear.max_iter-1)):\n", " batch_size = len(batch_feed_dict[bilinear.target])\n", " batch_avg_loss = bilinear.sess.run(bilinear.loss, batch_feed_dict) / batch_size\n", " val_output, val_loss = bilinear.sess.run((bilinear.output, bilinear.loss), \n", " val_feed_dict)\n", " val_avg_loss = val_loss / len(val_labels)\n", " val_pair_ranking_acc = util.pair_ranking_accuracy(val_output)\n", " msg = 'itr {}, batch loss: {:.2}, val loss: {:.2}, val pair ranking acc: {:.2}'\n", " print(msg.format(itr, batch_avg_loss, val_avg_loss, val_pair_ranking_acc))\n", " return True\n", "\n", "bilinear.fit(train_idx_array, train_step_callback)\n", "\n", "acc, pred, scores, thresh_map = util.model_threshold_and_eval(bilinear, test, val)\n", "print('Test set accuracy: {:.2}'.format(acc))\n", "bilinear.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## TransE\n", "\n", "The TransE model [[Bordes13]](#Bordes13) is different than the ones we have built so far because it is not based on tensor decomposition. It also embeds entities into a vector space, but instead treats relationship types as translations between the entity vectors. 
The model scores a triple based on a dissimilarity function between the head and tail entity after applying the relationship's translation to the head. This formulation was inspired by previous work on word embedding models, where researchers discovered after the fact that translations in the embedding space captured certain types of relationships between words.\n", "\n", "TransE is shown below in a 2D embedding space for both a perfect match (left) and a depiction of what *has_part* might look like (right):\n", "\n", "![TransE](pics/transe.svg)\n", "\n", "The model output for a triple is:\n", "\n", "$$ f(i,j,k) = -d(e_i+r_k, e_j) $$\n", "\n", "and for Euclidean distance:\n", "\n", "$$ d(e_i+r_k, e_j) = \lVert e_i + r_k - e_j \rVert $$\n", "\n", "Other distances can be used, such as Manhattan, squared Euclidean, etc. The distance is negated in keeping with our convention that models should output higher scores for true statements, here 0 being the best possible score. For training, the pairwise ranking loss described above is used with the margin $\gamma$ as a configurable hyperparameter. The loss across all training triples $S$ and their corrupted counterparts $S'$ is:\n", "\n", "$$ Loss = \sum_{(i,j,k) \in S} \ \sum_{(i',j',k) \in S_{(i,j,k)}'} [\gamma + d(e_i+r_k, e_j) - d(e_{i'}+r_k, e_{j'})]_{+} $$\n", "\n", "which is optimized by sampling positive and negative pairs in mini-batch SGD, just as we have been doing with our previous two models. In the paper, the authors propose constraining the embedding vectors to lie on the unit sphere: $\forall i \ ||e_i|| = 1$. Here we will continue to use a max-norm constraint, which is less restrictive and allows them to instead lie within the unit ball. The relationship translation vectors are left unconstrained."
] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class TransE(BaseModel):\n", "\n", " def __init__(self, embedding_size, batch_pos_cnt=100, \n", " max_iter=1000, dist='euclidean', \n", " margin=1.0, opt=None):\n", " super(TransE, self).__init__(embedding_size=embedding_size,\n", " maxnorm=1.0,\n", " batch_pos_cnt=batch_pos_cnt,\n", " max_iter=max_iter,\n", " model_type='ranking_margin',\n", " opt=opt)\n", " self.dist = dist\n", " self.margin = margin\n", " self.EPS = 1e-3 # for sqrt gradient when dist='euclidean'\n", " \n", " def _create_model(self, train_triples):\n", " # Count unique items to determine embedding matrix sizes\n", " entity_cnt = len(set(train_triples[:,0]).union(train_triples[:,2]))\n", " rel_cnt = len(set(train_triples[:,1]))\n", " init_sd = 1.0 / np.sqrt(self.embedding_size)\n", " # Embedding variables\n", " entity_var_shape = [entity_cnt, self.embedding_size]\n", " rel_var_shape = [rel_cnt, self.embedding_size]\n", " entity_init = tf.truncated_normal(entity_var_shape, stddev=init_sd)\n", " rel_init = tf.truncated_normal(rel_var_shape, stddev=init_sd)\n", " # Ensure maxnorm constraints are initially satisfied\n", " entity_init = dense_maxnorm(entity_init, self.maxnorm)\n", " self.entity_embedding_vars = tf.Variable(entity_init)\n", " self.rel_embedding_vars = tf.Variable(rel_init)\n", " # Embedding layer for each (head, rel, tail) triple being fed in as input\n", " head_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.head_input)\n", " tail_embed = tf.nn.embedding_lookup(self.entity_embedding_vars, self.tail_input)\n", " rel_embed = tf.nn.embedding_lookup(self.rel_embedding_vars, self.rel_input)\n", " # Relationship vector acts as a translation in entity embedding space\n", " diff_vec = tail_embed - (head_embed + rel_embed)\n", " # negative dist so higher scores are better (important for pairwise loss)\n", " if self.dist == 'manhattan':\n", " raw_output = -tf.reduce_sum(tf.abs(diff_vec), 1)\n", " elif self.dist == 'euclidean':\n", " # +eps because gradients can misbehave for small values in sqrt\n", " raw_output = -tf.sqrt(tf.reduce_sum(tf.square(diff_vec), 1) + self.EPS)\n", " elif self.dist == 'sqeuclidean':\n", " raw_output = -tf.reduce_sum(tf.square(diff_vec), 1)\n", " else:\n", " raise Exception('Unknown distance type')\n", " # Model output\n", " self.output, self.loss = ranking_margin_objective(raw_output, self.margin)\n", " # Optimization with postprocessing to limit embedding vars to L2 ball\n", " self.train_step = self.opt.minimize(self.loss)\n", " unique_ent_indices = tf.unique(tf.concat(0, [self.head_input, self.tail_input]))[0]\n", " self.post_step = self._norm_constraint_op(self.entity_embedding_vars, \n", " unique_ent_indices, \n", " self.maxnorm)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "itr 0, batch loss: 0.48, val loss: 0.5, val pair ranking acc: 0.5\n", "itr 2000, batch loss: 0.36, val loss: 0.41, val pair ranking acc: 0.63\n", "itr 4000, batch loss: 0.28, val loss: 0.38, val pair ranking acc: 0.69\n", "itr 6000, batch loss: 0.24, val loss: 0.34, val pair ranking acc: 0.77\n", "itr 8000, batch loss: 0.15, val loss: 0.31, val pair ranking acc: 0.82\n", "itr 10000, batch loss: 0.17, val loss: 0.29, val pair ranking acc: 0.86\n", "itr 12000, batch loss: 0.1, val loss: 0.27, val pair ranking acc: 0.88\n", "itr 14000, batch loss: 0.12, val loss: 0.25, val pair 
ranking acc: 0.89\n", "itr 16000, batch loss: 0.094, val loss: 0.24, val pair ranking acc: 0.91\n", "itr 18000, batch loss: 0.082, val loss: 0.23, val pair ranking acc: 0.91\n", "itr 20000, batch loss: 0.11, val loss: 0.22, val pair ranking acc: 0.92\n", "itr 22000, batch loss: 0.1, val loss: 0.22, val pair ranking acc: 0.92\n", "itr 24000, batch loss: 0.089, val loss: 0.21, val pair ranking acc: 0.92\n", "itr 26000, batch loss: 0.094, val loss: 0.21, val pair ranking acc: 0.93\n", "itr 28000, batch loss: 0.096, val loss: 0.2, val pair ranking acc: 0.92\n", "itr 29999, batch loss: 0.083, val loss: 0.2, val pair ranking acc: 0.92\n", "Test set accuracy: 0.87\n" ] } ], "source": [ "transE = TransE(embedding_size=20,\n", " margin=1.0,\n", " dist='euclidean',\n", " batch_pos_cnt=100, \n", " max_iter=30000)\n", "\n", "val_feed_dict = transE.create_feed_dict(val_idx_array, val_labels)\n", "\n", "def train_step_callback(itr, batch_feed_dict):\n", " if (itr % 2000) == 0 or (itr == (transE.max_iter-1)):\n", " batch_size = len(batch_feed_dict[transE.target])\n", " batch_avg_loss = transE.sess.run(transE.loss, batch_feed_dict) / batch_size\n", " val_output, val_loss = transE.sess.run((transE.output, transE.loss), val_feed_dict)\n", " val_avg_loss = val_loss / len(val_labels)\n", " val_pair_ranking_acc = util.pair_ranking_accuracy(val_output)\n", " msg = 'itr {}, batch loss: {:.2}, val loss: {:.2}, val pair ranking acc: {:.2}'\n", " print(msg.format(itr, batch_avg_loss, val_avg_loss, val_pair_ranking_acc))\n", " return True\n", "\n", "transE.fit(train_idx_array, train_step_callback)\n", "\n", "acc, pred, scores, thresh_map = util.model_threshold_and_eval(transE, test, val)\n", "print('Test set accuracy: {:.2}'.format(acc))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is not too bad, 87% is quite a bit better than our previous attempts." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Take It For a Spin\n", "\n", "Let's take a look at the top 10 scoring triples for a single head entity and relationship type. We'll use (brain cell, part of, ?):" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "26308 ['__brain_cell_NN_1']\n", "4 ['_part_of']\n", "\n", "Top 10 matches:\n", "['__brain_cell_NN_1' '__encephalon_NN_1' '__temporal_lobe_NN_1'\n", " '__funiculus_NN_2' '__neural_structure_NN_1' '__pallium_NN_1'\n", " '__systema_nervosum_NN_1' '__external_body_part_NN_1'\n", " '__male_reproductive_system_NN_1' '__hypothalamus_NN_1']\n", "\n", "Training triples for entity:\n" ] }, { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
headreltail
1325__encephalon_NN_1_has_part__brain_cell_NN_1
20593__brain_cell_NN_1_hyponym__golgi_cell_NN_1
41067__golgi_cell_NN_1_hypernym__brain_cell_NN_1
55004__brain_cell_NN_1_part_of__encephalon_NN_1
\n", "
" ], "text/plain": [ " head rel tail\n", "1325 __encephalon_NN_1 _has_part __brain_cell_NN_1\n", "20593 __brain_cell_NN_1 _hyponym __golgi_cell_NN_1\n", "41067 __golgi_cell_NN_1 _hypernym __brain_cell_NN_1\n", "55004 __brain_cell_NN_1 _part_of __encephalon_NN_1" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "entity = '__brain_cell_NN_1'\n", "relationship_type = '_part_of'\n", "entity_cats = train['head'].cat.categories\n", "rel_cats = train['rel'].cat.categories\n", "entity_code = (entity_cats == entity).argmax()\n", "has_part_code = (rel_cats == relationship_type).argmax()\n", "\n", "print(entity_code, pd.Categorical.from_codes([entity_code], entity_cats).astype('str'))\n", "print(has_part_code, pd.Categorical.from_codes([has_part_code], rel_cats).astype('str'))\n", "\n", "query_triples = np.zeros((len(entity_cats), 3))\n", "query_triples[:,0] = entity_code\n", "query_triples[:,1] = has_part_code\n", "query_triples[:,2] = range(len(entity_cats))\n", "\n", "scores = transE.predict(query_triples)\n", "sorted_idxs = np.argsort(scores)[::-1]\n", "print()\n", "print('Top 10 matches:')\n", "print(np.array(entity_cats)[sorted_idxs][:10])\n", "\n", "example_train_rows = (train['head'] == entity) | (train['tail'] == entity)\n", "print('\\nTraining triples for entity:')\n", "display(train.loc[example_train_rows])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some good predictions in there, especially considering how few training examples included the word directly (shown at the bottom). One error is the object itself, and the other errors are at least body parts. Not all queries I've tried perform as well as this one, however!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Visualization\n", "\n", "Another useful feature of these relational learning models is that the embedding vectors can be used for clustering and visualization. Let's take a look at some of them laid out in a 2D scatter using Barnes-Hut T-SNE [[VanDerMaaten14]](#VanDerMaaten14). \n", "This requires [scikit-learn](http://scikit-learn.org/) 0.17 or above and the [bokeh](http://bokeh.pydata.org/) plotting library.\n", "\n", "We will color each embedding point by the word's top-level category in WordNet, called its \"lexname\". I have included a file with these in the data directory. You should be able to mouse over the plot and see the words, which lets you get a feel for how things are being clustered in embedding space. Note that we are only showing 10% of the data here in order to save some time and keep the web browser happy. You should be able to run it on all of the embeddings if you would like, as Barnes-Hut T-SNE is an $O(n \\ log(n))$ algorithm." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " Loading BokehJS ...\n", "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "(function(global) {\n", " function now() {\n", " return new Date();\n", " }\n", "\n", " if (typeof (window._bokeh_onload_callbacks) === \"undefined\") {\n", " window._bokeh_onload_callbacks = [];\n", " }\n", "\n", " function run_callbacks() {\n", " window._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n", " delete window._bokeh_onload_callbacks\n", " console.info(\"Bokeh: all callbacks have finished\");\n", " }\n", "\n", " function load_libs(js_urls, callback) {\n", " window._bokeh_onload_callbacks.push(callback);\n", " if (window._bokeh_is_loading > 0) {\n", " console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n", " return null;\n", " }\n", " if (js_urls == null || js_urls.length === 0) {\n", " run_callbacks();\n", " return null;\n", " }\n", " console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n", " window._bokeh_is_loading = js_urls.length;\n", " for (var i = 0; i < js_urls.length; i++) {\n", " var url = js_urls[i];\n", " var s = document.createElement('script');\n", " s.src = url;\n", " s.async = false;\n", " s.onreadystatechange = s.onload = function() {\n", " window._bokeh_is_loading--;\n", " if (window._bokeh_is_loading === 0) {\n", " console.log(\"Bokeh: all BokehJS libraries loaded\");\n", " run_callbacks()\n", " }\n", " };\n", " s.onerror = function() {\n", " console.warn(\"failed to load library \" + url);\n", " };\n", " console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", " document.getElementsByTagName(\"head\")[0].appendChild(s);\n", " }\n", " };\n", "\n", " var js_urls = ['https://cdn.pydata.org/bokeh/release/bokeh-0.11.1.min.js', 'https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.11.1.min.js', 'https://cdn.pydata.org/bokeh/release/bokeh-compiler-0.11.1.min.js'];\n", "\n", " var inline_js = [\n", " function(Bokeh) {\n", " Bokeh.set_log_level(\"info\");\n", " },\n", " \n", " function(Bokeh) {\n", " Bokeh.$(\"#768dc2d8-3d5d-4711-9bdb-0be12c4fe65f\").text(\"BokehJS successfully loaded\");\n", " },\n", " function(Bokeh) {\n", " console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-0.11.1.min.css\");\n", " Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-0.11.1.min.css\");\n", " console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.11.1.min.css\");\n", " Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.11.1.min.css\");\n", " }\n", " ];\n", "\n", " function run_inline_js() {\n", " for (var i = 0; i < inline_js.length; i++) {\n", " inline_js[i](window.Bokeh);\n", " }\n", " }\n", "\n", " if (window._bokeh_is_loading === 0) {\n", " console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n", " run_inline_js();\n", " } else {\n", " load_libs(js_urls, function() {\n", " console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n", " run_inline_js();\n", " });\n", " }\n", "}(this));" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Embeddings shape: (39996, 20)\n", "Using subset: (4000, 20)\n", "Running T-SNE, may take a while...\n", "Plotting...\n" ] }, { "data": { "text/html": [ "\n", "\n", "
\n", "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.manifold import TSNE\n", "from bokeh.models import HoverTool, BoxSelectTool\n", "from bokeh.plotting import figure, output_notebook, show, ColumnDataSource, reset_output\n", "from bokeh.palettes import Spectral11\n", "\n", "subsample = 10\n", "\n", "reset_output()\n", "output_notebook()\n", "\n", "lexnames = pd.read_table(os.path.join(data_dir, 'wordnet_lexnames.txt'), index_col=0)\n", "entity_embeddings = transE.sess.run(transE.entity_embedding_vars)\n", "entity_cats = train['head'].cat.categories\n", "entity_names = pd.Categorical.from_codes(range(len(entity_embeddings)), \n", " entity_cats).astype(str)\n", "entity_lexnames = lexnames.loc[entity_names].values\n", "\n", "# Run on just a subset of the data to save some time\n", "emb_subset = entity_embeddings[::subsample, :] \n", "emb_subset_names = entity_names[::subsample]\n", "emb_subset_lexnames = entity_lexnames[::subsample]\n", "\n", "print('Embeddings shape:', entity_embeddings.shape)\n", "print('Using subset:', emb_subset.shape)\n", "print('Running T-SNE, may take a while...')\n", "tsne = TSNE(n_iter=1000, method='barnes_hut')\n", "lowdim = tsne.fit_transform(emb_subset)\n", "\n", "print('Plotting...')\n", "source = ColumnDataSource(\n", " data=dict(x=lowdim[:,0],\n", " y=lowdim[:,1],\n", " name=emb_subset_names,\n", " lexname=emb_subset_lexnames)\n", ")\n", "colormap = {}\n", "for i,ln in enumerate(set(emb_subset_lexnames.flat)):\n", " colormap[ln] = Spectral11[i % len(Spectral11)]\n", "colors = [colormap[ln] for ln in emb_subset_lexnames.flat]\n", "tools = 'pan,wheel_zoom,box_zoom,reset,resize,hover'\n", "fig = figure(title=\"T-SNE of WordNet TransE Embeddings\", \n", " plot_width=800, plot_height=600, tools=tools)\n", "fig.scatter('x', 'y', source=source, alpha=0.5, fill_color=colors, line_color=None)\n", "hover = fig.select(dict(type=HoverTool))\n", "hover.tooltips = [('','@name, @lexname')]\n", "h = show(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Concluding Thoughts\n", "\n", "This tutorial has covered three relatively simple models for relational learning: two based on the CP and RESCAL tensor decompositions, and TransE which treats relationships as translations in embedding space. We trained each with contrastive negative sampling and max-norm regularization, which of course are not the only possibilities.\n", "\n", "TransE performed the best for me on this particular dataset, although in fairness not much time was spent optimizing hyperparameters for each model. Based on the brief time I have spent with it on the WordNet data, it appears to cluster things well, but does not necessarily possess high precision for all of the fine-grained relationships.\n", "\n", "Many other models have also been proposed in the literature, for example TransH [[Wang14]](#Wang14), Neural Tensor Network (NTN) [[Socher13]](#Socher13), and models using a more traditional neural network architecture [Dong14](#Dong14). I believe that all of these are good candidates for TensorFlow and it would be fun to add them to the framework we have developed here.\n", "\n", "One thing to note is that this evaluation is measuring classification accuracy on a balanced test set with an equal number of positive and negative instances. 
Other performance measures have been proposed which take more of an information-retrieval approach, for example comparing the score of each positive test triple against all of its possible corruptions and reporting things like mean rank and precision@k. I did not use these for this tutorial because they are a bit more involved to code up, and evaluation takes much longer. They may be a better metric than accuracy on a balanced set, however, depending on application.\n", "\n", "Lastly, this is a relatively small dataset for the complex task the model is attempting to learn, with only about 3 triples per entity on average. Many recent publications on knowledge-base completion are using much larger data sets, but I decided to go with this particular one to keep training times reasonable for a tutorial.\n", "\n", "I hope you have enjoyed this tutorial and found it useful! I know I had a lot of fun putting it together." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## References\n", "\n", " **[Bordes11]** Bordes, Antoine, et al. \"Learning structured embeddings of knowledge bases.\" Conference on Artificial Intelligence. No. EPFL-CONF-192344. 2011.\n", "\n", " **[Bordes13]** Bordes, Antoine, et al. \"Translating embeddings for modeling multi-relational data.\" Advances in Neural Information Processing Systems. 2013.\n", "\n", " **[Bordes14]** Bordes, Antoine, et al. \"A semantic matching energy function for learning with multi-relational data.\" Machine Learning 94.2 (2014): 233-259.\n", "\n", " **[Bottou12]** Bottou, Léon. \"Stochastic gradient descent tricks.\" Neural Networks: Tricks of the Trade. Springer Berlin Heidelberg, 2012. 421-436.\n", "\n", " **[Carpenter08]** Carpenter, Bob. \"Lazy sparse stochastic gradient descent for regularized multinomial logistic regression.\" Alias-i, Inc., Tech. Rep (2008): 1-20.\n", "\n", " **[Dong14]** Dong, Xin, et al. \"Knowledge vault: A web-scale approach to probabilistic knowledge fusion.\" Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2014.\n", "\n", " **[Gutmann10]** Gutmann, Michael, and Aapo Hyvärinen. \"Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.\" International Conference on Artificial Intelligence and Statistics. 2010.\n", "\n", " **[Hitchcock27]** Hitchcock, Frank L. \"The expression of a tensor or a polyadic as a sum of products.\" Journal of Mathematics and Physics 6.1 (1927): 164-189.\n", "\n", " **[Hu08]** Hu, Yifan, Yehuda Koren, and Chris Volinsky. \"Collaborative filtering for implicit feedback datasets.\" Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on. Ieee, 2008.\n", "\n", " **[Joachims06]** Joachims, Thorsten. \"Training linear SVMs in linear time.\" Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2006.\n", "\n", " **[Kolda09]** Kolda, Tamara G., and Brett W. Bader. \"Tensor decompositions and applications.\" SIAM review 51.3 (2009): 455-500.\n", "\n", " **[Koren09]** Koren, Yehuda, Robert Bell, and Chris Volinsky. \"Matrix factorization techniques for recommender systems.\" Computer 8 (2009): 30-37.\n", "\n", " **[Mikolov13]** Mikolov, Tomas, et al. \"Distributed representations of words and phrases and their compositionality.\" Advances in neural information processing systems. 2013.\n", "\n", " **[Mnih13]** Mnih, Andriy, and Koray Kavukcuoglu. 
\"Learning word embeddings efficiently with noise-contrastive estimation.\" Advances in Neural Information Processing Systems. 2013.\n", "\n", " **[Nickel11]** Nickel, Maximilian, Volker Tresp, and Hans-Peter Kriegel. \"A three-way model for collective learning on multi-relational data.\" Proceedings of the 28th international conference on machine learning (ICML-11). 2011.\n", "\n", " **[Nickel15]** Nickel, Maximilian, et al. \"A review of relational machine learning for knowledge graphs: From multi-relational link prediction to automated knowledge graph construction.\" arXiv preprint arXiv:1503.00759 (2015).\n", "\n", " **[Rendle04]** Rendle, Steffen, et al. \"BPR: Bayesian personalized ranking from implicit feedback.\" Proceedings of the twenty-fifth conference on uncertainty in artificial intelligence. AUAI Press, 2009.\n", "\n", " **[Sbrero04]** Srebro, Nathan, Jason Rennie, and Tommi S. Jaakkola. \"Maximum-margin matrix factorization.\" Advances in neural information processing systems. 2004.\n", "\n", " **[Socher13]** Socher, Richard, et al. \"Reasoning with neural tensor networks for knowledge base completion.\" Advances in Neural Information Processing Systems. 2013.\n", "\n", " **[Srivastava14]** Srivastava, Nitish, et al. \"Dropout: A simple way to prevent neural networks from overfitting.\" The Journal of Machine Learning Research 15.1 (2014): 1929-1958.\n", "\n", " **[VanDerMaaten14]** Van Der Maaten, Laurens. \"Accelerating t-sne using tree-based algorithms.\" The Journal of Machine Learning Research 15.1 (2014): 3221-3245.\n", "\n", " **[Wang14]** Wang, Zhen, et al. \"Knowledge Graph Embedding by Translating on Hyperplanes.\" AAAI. 2014." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }