{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": true }, "source": [ "

Table of Contents

\n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# code for loading the format for the notebook\n", "import os\n", "\n", "# path : store the current path to convert back to it later\n", "path = os.getcwd()\n", "os.chdir(os.path.join('..', '..', 'notebook_format'))\n", "\n", "from formats import load_style\n", "load_style(css_style='custom2.css', plot_style=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ethen 2018-08-21 18:38:17 \n", "\n", "CPython 3.6.4\n", "IPython 6.4.0\n", "\n", "numpy 1.14.1\n", "sklearn 0.19.1\n", "matplotlib 2.2.2\n", "tqdm 4.24.0\n", "nmslib 1.7.3.4\n" ] } ], "source": [ "os.chdir(path)\n", "\n", "# 1. magic for inline plot\n", "# 2. magic to print version\n", "# 3. magic so that the notebook will reload external python modules\n", "# 4. 
magic to enable retina (high resolution) plots\n", "# https://gist.github.com/minrk/3301035\n", "%matplotlib inline\n", "%load_ext watermark\n", "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format='retina'\n", "\n", "import time\n", "import nmslib  # pip install nmslib>=1.7.3.2 pybind11>=2.2.3\n", "import zipfile\n", "import requests\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from tqdm import trange\n", "from joblib import dump, load\n", "from sklearn.preprocessing import normalize\n", "from sklearn.neighbors import NearestNeighbors\n", "from sklearn.model_selection import train_test_split\n", "\n", "# change default style figure and font size\n", "plt.rcParams['figure.figsize'] = 8, 6\n", "plt.rcParams['font.size'] = 12\n", "\n", "%watermark -a 'Ethen' -d -t -v -p numpy,sklearn,matplotlib,tqdm,nmslib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Approximate Nearest Neighbor Search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Approximate nearest neighbor (ANN) search is useful when we have a large dataset with hundreds of thousands, millions, or even billions of data points, and for a given data point we wish to find its nearest neighbors. There are many use cases for this type of method; the one we'll focus on here is finding similar vector representations, so think of algorithms such as matrix factorization or word2vec that compress our original data into embeddings, or so-called latent factors. Throughout the notebook, the notion of similarity refers to the cosine distance between two vectors.\n", "\n", "There are already many open-source implementations that we can try to see whether they solve our problem, but the question is always: which one is better? The following GitHub repo contains a thorough benchmark of various open-source implementations. 
[Github: Benchmarking nearest neighbors](https://github.com/erikbern/ann-benchmarks).\n", "\n", "The goal of this notebook is to show how to run a quicker benchmark ourselves without all the complexity. The repo listed above benchmarks multiple algorithms on multiple datasets using multiple hyperparameters, which can take a really long time. We will pick one of the open-source implementations that has been identified as a solid choice and walk through the process step by step using one dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting Up the Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first step is to get our hands on some data and split it into a training and a test set. Here we'll be using the GloVe vector representations trained on a Twitter dataset." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def download(url, filename):\n", "    with open(filename, 'wb') as file:\n", "        response = requests.get(url)\n", "        file.write(response.content)\n", "\n", "# we'll download the data to DATA_DIR location\n", "DATA_DIR = './datasets/'\n", "URL = 'http://nlp.stanford.edu/data/glove.twitter.27B.zip'\n", "filename = os.path.join(DATA_DIR, 'glove.twitter.27B.zip')\n", "if not os.path.exists(DATA_DIR):\n", "    os.makedirs(DATA_DIR)\n", "\n", "if not os.path.exists(filename):\n", "    download(URL, filename)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "training data shape: (954811, 25)\n", "testing data shape: (238703, 25)\n" ] } ], "source": [ "def get_train_test_data(filename, dimension=25, test_size=0.2, random_state=1234):\n", "    \"\"\"\n", "    dimension : int, {25, 50, 100, 200}, default 25\n", "        The dataset contains embeddings of different sizes.\n", "    \"\"\"\n", "    with zipfile.ZipFile(filename) as f:\n", "        X = []\n", "        zip_filename = 'glove.twitter.27B.{}d.txt'.format(dimension)\n", "        for line in 
f.open(zip_filename):\n", "            # remove the first index, id field and only get the vectors\n", "            vector = np.array([float(x) for x in line.strip().split()[1:]])\n", "            X.append(vector)\n", "\n", "    X_train, X_test = train_test_split(\n", "        np.array(X), test_size=test_size, random_state=random_state)\n", "\n", "    # we can downsample for experimentation purpose\n", "    # X_train = X_train[:50000]\n", "    # X_test = X_test[:10000]\n", "    return X_train, X_test\n", "\n", "\n", "X_train, X_test = get_train_test_data(filename)\n", "print('training data shape: ', X_train.shape)\n", "print('testing data shape: ', X_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Benchmarking an approximate nearest neighbor method involves measuring how much faster it is than an exact nearest neighbor method and how much precision/recall we lose in exchange for that speed. To measure this, we first run an exact nearest neighbor method to see how long it takes and store its results as the ground truth. For example, if our exact nearest neighbor method says that the top 3 nearest neighbors of a query point are [2, 4, 1], and our approximate nearest neighbor method returns [2, 1, 5], then our precision/recall (the two coincide here because both lists have the same length) would be 2/3, roughly 66%, since 2 and 1 are both in the ground truth set whereas 5 is not."
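, "\n", "\n", "
The overlap computation in the example above can be sketched in a few lines. This is only an illustration of the metric: the helper name `overlap_at_k` is hypothetical and is not used anywhere else in this notebook.

```python
def overlap_at_k(found, truth):
    # fraction of the ground truth neighbors that the
    # approximate search also returned (order is ignored)
    return len(set(found) & set(truth)) / len(truth)


exact_neighbors = [2, 4, 1]   # what the exact method returned
approx_neighbors = [2, 1, 5]  # what the approximate method returned
score = overlap_at_k(approx_neighbors, exact_neighbors)
print(round(score, 2))  # 0.67
```

Because both lists contain the same number of neighbors, dividing by either length gives the same value, which is why precision and recall are interchangeable here.
"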
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class BruteForce:\n", "    \"\"\"\n", "    Brute force way of computing cosine distance. This\n", "    is meant to clarify what we're trying to accomplish;\n", "    don't actually use it, as it will take extremely long.\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        pass\n", "\n", "    def fit(self, X):\n", "        lens = (X ** 2).sum(axis=-1)\n", "        index = X / np.sqrt(lens)[:, np.newaxis]\n", "        self.index_ = np.ascontiguousarray(index, dtype=np.float32)\n", "        return self\n", "\n", "    def query(self, vector, topn):\n", "        \"\"\"Find indices of most similar vectors for a given query vector.\"\"\"\n", "\n", "        # argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)\n", "        dists = -np.dot(self.index_, vector)\n", "        indices = np.argpartition(dists, topn)[:topn]\n", "        return sorted(indices, key=lambda index: dists[index])\n", "\n", "\n", "class KDTree:\n", "\n", "    def __init__(self, topn=10, n_jobs=-1):\n", "        self.topn = topn\n", "        self.n_jobs = n_jobs\n", "\n", "    def fit(self, X):\n", "\n", "        # for unit-length vectors, squared euclidean distance equals\n", "        # 2 * cosine distance, so both produce the same neighbor ranking;\n", "        # thus we normalize the item vectors and use euclidean metric so\n", "        # we can use the more efficient kd-tree for nearest neighbor search\n", "        X_normed = normalize(X)\n", "        index = NearestNeighbors(\n", "            n_neighbors=self.topn, metric='euclidean', n_jobs=self.n_jobs)\n", "        index.fit(X_normed)\n", "        self.index_ = index\n", "        return self\n", "\n", "    def query_batch(self, X):\n", "        X_normed = normalize(X)\n", "        _, indices = self.index_.kneighbors(X_normed)\n", "        return indices\n", "\n", "    def query(self, vector):\n", "        vector_normed = normalize(vector.reshape(1, -1))\n", "        _, indices = self.index_.kneighbors(vector_normed)\n", "        return indices.ravel()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def get_ground_truth(X_train, X_test, kdtree_params):\n", "    \"\"\"\n", "    Compute the ground truth 
or so-called gold standard, during\n", "    which we'll measure the time to build the index using the\n", "    training set and the time to query the nearest neighbors for all\n", "    the data points in the test set. The ground_truth returned\n", "    will be of type list[(ndarray, ndarray)], where the first\n", "    ndarray will be the query vector, and the second ndarray will\n", "    be the corresponding nearest neighbors.\n", "    \"\"\"\n", "    start = time.time()\n", "    kdtree = KDTree(**kdtree_params)\n", "    kdtree.fit(X_train)\n", "    build_time = time.time() - start\n", "\n", "    start = time.time()\n", "    indices = kdtree.query_batch(X_test)\n", "    query_time = time.time() - start\n", "\n", "    ground_truth = [(vector, index) for vector, index in zip(X_test, indices)]\n", "    return build_time, query_time, ground_truth" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ground truth filepath: model/ground_truth.pkl\n", "build time: 5.02460503578186\n", "query time: 5105.871987104416\n" ] }, { "data": { "text/plain": [ "(array([ 0.84227, 0.19005, 1.5346 , 0.88995, -1.6548 , -0.60046,\n", " -1.3206 , -1.5521 , -0.30763, -0.56361, 1.5054 , 3.2881 ,\n", " 1.7582 , -0.63313, -0.48781, 2.0016 , -2.5334 , 1.0601 ,\n", " -0.19666, -0.38252, 0.65653, 0.89475, 2.7882 , 2.4109 ,\n", " -0.72981]),\n", " array([213945, 566700, 232533, 673941, 79801, 932371, 59183, 318977,\n", " 649659, 871934]))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we'll compute the ground truth for the first time and\n", "# store it on disk to prevent computing it over and over again\n", "MODEL_DIR = 'model'\n", "if not os.path.exists(MODEL_DIR):\n", "    os.makedirs(MODEL_DIR)\n", "\n", "ground_truth_filename = 'ground_truth.pkl'\n", "ground_truth_filepath = os.path.join(MODEL_DIR, ground_truth_filename)\n", "print('ground truth filepath: ', ground_truth_filepath)\n", "\n", "if 
os.path.exists(ground_truth_filepath):\n", "    ground_truth = load(ground_truth_filepath)\n", "else:\n", "    # using a setting of kdtree_params = {'topn': 10, 'n_jobs': -1},\n", "    # it took at least 1 hour to finish on an 8 core machine\n", "    kdtree_params = {'topn': 10, 'n_jobs': -1}\n", "    build_time, query_time, ground_truth = get_ground_truth(X_train, X_test, kdtree_params)\n", "    print('build time: ', build_time)\n", "    print('query time: ', query_time)\n", "    dump(ground_truth, ground_truth_filepath)\n", "\n", "ground_truth[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmarking ANN Methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The library that we'll be leveraging here is [`nmslib`](https://github.com/nmslib/nmslib), specifically the HNSW (Hierarchical Navigable Small World) algorithm, a graph-based approximate nearest neighbor search method. We will only be using the library and will not introduce the details of the algorithm in this notebook."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class Hnsw:\n", "\n", "    def __init__(self, space='cosinesimil', index_params=None,\n", "                 query_params=None, print_progress=True):\n", "        self.space = space\n", "        self.index_params = index_params\n", "        self.query_params = query_params\n", "        self.print_progress = print_progress\n", "\n", "    def fit(self, X):\n", "        index_params = self.index_params\n", "        if index_params is None:\n", "            index_params = {'M': 16, 'post': 0, 'efConstruction': 400}\n", "\n", "        query_params = self.query_params\n", "        if query_params is None:\n", "            query_params = {'ef': 90}\n", "\n", "        # this is the actual nmslib part, hopefully the syntax should\n", "        # be pretty readable, the documentation also has a more detailed\n", "        # introduction: https://nmslib.github.io/nmslib/quickstart.html\n", "        index = nmslib.init(space=self.space, method='hnsw')\n", "        index.addDataPointBatch(X)\n", "        index.createIndex(index_params, print_progress=self.print_progress)\n", "        index.setQueryTimeParams(query_params)\n", "\n", "        self.index_ = index\n", "        self.index_params_ = index_params\n", "        self.query_params_ = query_params\n", "        return self\n", "\n", "    def query(self, vector, topn):\n", "        # the knnQuery returns indices and corresponding distance\n", "        # we will throw the distance away for now\n", "        indices, _ = self.index_.knnQuery(vector, k=topn)\n", "        return indices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like many machine learning algorithms, HNSW has hyperparameters that we can tune. We will pick an arbitrary setting for now and look at the influence of each hyperparameter in a later section."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "42.73225116729736" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "index_params = {'M': 5, 'post': 0, 'efConstruction': 100}\n", "\n", "start = time.time()\n", "hnsw = Hnsw(index_params=index_params)\n", "hnsw.fit(X_train)\n", "build_time = time.time() - start\n", "build_time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll first use the first element from the ground truth to showcase what we'll be doing before scaling it to all the data points." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "query time: 0.0002560615539550781\n", "correct indices: [213945 566700 232533 673941 79801 932371 59183 318977 649659 871934]\n", "found indices: [213945 566700 232533 673941 79801 318977 871934 221617 107727 705332]\n" ] } ], "source": [ "topn = 10\n", "\n", "query_vector, correct_indices = ground_truth[0]\n", "start = time.time()\n", "\n", "# use the query_vector to find its corresponding\n", "# approximate nearest neighbors\n", "found_indices = hnsw.query(query_vector, topn)\n", "\n", "query_time = time.time() - start\n", "print('query time:', query_time)\n", "\n", "print('correct indices: ', correct_indices)\n", "print('found indices: ', found_indices)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compute the proportion of data points that overlap between the\n", "# two sets\n", "precision = len(set(found_indices).intersection(correct_indices)) / topn\n", "precision" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
"def run_algo(X_train, X_test, topn, ground_truth, algo_type='hnsw', algo_params=None):\n", "    \"\"\"\n", "    We can extend this benchmark across multiple algorithms or hyperparameter settings\n", "    by adding more algo_type options. The algo_params can be a dictionary that is passed\n", "    to the algorithm's __init__ method.\n", "    Here only 1 method is included.\n", "    \"\"\"\n", "\n", "    if algo_type == 'hnsw':\n", "        algo = Hnsw()\n", "        if algo_params is not None:\n", "            algo = Hnsw(**algo_params)\n", "    else:\n", "        raise ValueError('unknown algo_type: {}'.format(algo_type))\n", "\n", "    start = time.time()\n", "    algo.fit(X_train)\n", "    build_time = time.time() - start\n", "\n", "    total_correct = 0\n", "    total_query_time = 0.0\n", "    n_queries = len(ground_truth)\n", "    for i in trange(n_queries):\n", "        query_vector, correct_indices = ground_truth[i]\n", "\n", "        start = time.time()\n", "        found_indices = algo.query(query_vector, topn)\n", "        query_time = time.time() - start\n", "        total_query_time += query_time\n", "\n", "        n_correct = len(set(found_indices).intersection(correct_indices))\n", "        total_correct += n_correct\n", "\n", "    avg_query_time = total_query_time / n_queries\n", "    avg_precision = total_correct / (n_queries * topn)\n", "    return build_time, avg_query_time, avg_precision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next few code chunks experiment with different parameters to see which setting works better for this use case.\n", "\n", "According to the author of the package, the most influential parameters are `M` and `efConstruction`.\n", "\n", "- `efConstruction`: Increasing this value improves the quality of the constructed graph and leads to a higher search accuracy, at the cost of longer indexing time. The same idea applies to the `ef` or `efSearch` parameter that we can pass to `query_params`. A reasonable range for this parameter is 100-2000.\n", "- `M`: This parameter controls the maximum number of neighbors for each layer. Increasing the value of this parameter (to a certain degree) leads to better recall and shorter retrieval times (at the expense of longer indexing time). 
A reasonable range for this parameter is 5-100.\n", "\n", "Other parameters include `indexThreadQty` (we can explicitly set the number of threads) and `post`. The `post` parameter controls the amount of post-processing done to the graph. A value of 0 means no post-processing; additional options are 1 and 2 (2 means more post-processing)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 238703/238703 [00:42<00:00, 5642.16it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "build time: 96.66552662849426\n", "average search time: 0.00014930271766456666\n", "average precision: 0.971047284701072\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# we will be running four combinations, higher/lower\n", "# efConstruction/M parameters and comparing the performance\n", "algo_type = 'hnsw'\n", "algo_params = {\n", "    'index_params': {'M': 16, 'post': 0, 'efConstruction': 100}\n", "}\n", "\n", "build_time1, avg_query_time1, avg_precision1 = run_algo(\n", "    X_train, X_test, topn, ground_truth, algo_type, algo_params)\n", "\n", "print('build time: ', build_time1)\n", "print('average search time: ', avg_query_time1)\n", "print('average precision: ', avg_precision1)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 238703/238703 [00:45<00:00, 5257.06it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "build time: 312.3543019294739\n", "average search time: 0.0001598479527504984\n", "average precision: 0.9770271006229498\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "algo_params = {\n", "    'index_params': {'M': 16, 'post': 0, 'efConstruction': 400}\n", "}\n", "\n", "build_time2, avg_query_time2, avg_precision2 = run_algo(\n", "    X_train, X_test, topn, ground_truth, algo_type, 
algo_params)\n", "\n", "print('build time: ', build_time2)\n", "print('average search time: ', avg_query_time2)\n", "print('average precision: ', avg_precision2)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 238703/238703 [00:23<00:00, 10106.25it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "build time: 40.743391036987305\n", "average search time: 7.756965745461889e-05\n", "average precision: 0.7929644788712333\n" ] } ], "source": [ "algo_params = {\n", "    'index_params': {'M': 5, 'post': 0, 'efConstruction': 100}\n", "}\n", "\n", "build_time3, avg_query_time3, avg_precision3 = run_algo(\n", "    X_train, X_test, topn, ground_truth, algo_type, algo_params)\n", "\n", "print('build time: ', build_time3)\n", "print('average search time: ', avg_query_time3)\n", "print('average precision: ', avg_precision3)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 238703/238703 [00:24<00:00, 9667.32it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "build time: 135.51102113723755\n", "average search time: 8.058855207119328e-05\n", "average precision: 0.8155624353275828\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "algo_params = {\n", "    'index_params': {'M': 5, 'post': 0, 'efConstruction': 400}\n", "}\n", "\n", "build_time4, avg_query_time4, avg_precision4 = run_algo(\n", "    X_train, X_test, topn, ground_truth, algo_type, algo_params)\n", "\n", "print('build time: ', build_time4)\n", "print('average search time: ', avg_query_time4)\n", "print('average precision: ', avg_precision4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the results, we can see that larger values of the parameters `M` and `efConstruction` do give better precision scores. 
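
To make the comparison concrete, the four runs can be placed side by side. The snippet below is only a sketch: the numbers are copied (rounded) from the outputs printed by the four cells in this section, and the names `runs` and `build_ratio` are introduced here purely for illustration.

```python
# (M, efConstruction) -> (build time in seconds, average precision),
# rounded from the outputs printed by the four runs in this notebook
runs = {
    (16, 100): (96.67, 0.9710),
    (16, 400): (312.35, 0.9770),
    (5, 100): (40.74, 0.7930),
    (5, 400): (135.51, 0.8156),
}

for (M, ef_construction), (build_time, precision) in sorted(runs.items()):
    print(f'M={M:2d} efConstruction={ef_construction} '
          f'build={build_time:7.2f}s precision={precision:.4f}')

# build time with efConstruction=100 relative to efConstruction=400, per M
build_ratio = {M: runs[(M, 100)][0] / runs[(M, 400)][0] for M in (5, 16)}
print(build_ratio)
```

For both values of `M`, the ratio comes out at roughly 0.3, while the precision gap between the two `efConstruction` settings stays within about two percentage points.
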
Another observation is that the precision for `efConstruction` = 100 is on par with `efConstruction` = 400 while taking only about one third of the time to build the index." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- [Blog: New approximate nearest neighbor benchmarks](https://erikbern.com/2018/06/17/new-approximate-nearest-neighbor-benchmarks.html)\n", "- [Nmslib Documentation: Basic parameter tuning for NMSLIB](https://github.com/nmslib/nmslib/blob/master/manual/methods.md)\n", "- [Nmslib Documentation: Jupyter Notebook - Search Dense Vector](http://nbviewer.jupyter.org/github/nmslib/nmslib/blob/master/python_bindings/notebooks/search_vector_dense_optim.ipynb)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "280px" }, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }