{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Language Model Exercises\n", "In these exercises you will extend and develop language models. We will use the code from the notes, packaged as the Python module [`lm`](http://localhost:8888/edit/statnlpbook/lm.py)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='green'>Setup 1</font>: Load Libraries" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:57:44.612226", "start_time": "2016-10-21T16:57:44.592461" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import sys, os\n", "_snlp_book_dir = \"..\"\n", "sys.path.append(_snlp_book_dir) \n", "from statnlpbook.lm import *\n", "from statnlpbook.ohhla import *\n", "# %cd .. 
\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "matplotlib.rcParams['figure.figsize'] = (10.0, 6.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<!---\n", "LaTeX Macros\n", "-->\n", "$$\n", "\\newcommand{\\prob}{p}\n", "\\newcommand{\\vocab}{V}\n", "\\newcommand{\\params}{\\boldsymbol{\\theta}}\n", "\\newcommand{\\param}{\\theta}\n", "\\DeclareMathOperator{\\perplexity}{PP}\n", "\\DeclareMathOperator{\\argmax}{argmax}\n", "\\newcommand{\\train}{\\mathcal{D}}\n", "\\newcommand{\\counts}[2]{\\#_{#1}(#2) }\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='green'>Setup 2</font>: Load Data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:57:44.672863", "start_time": "2016-10-21T16:57:44.627958" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [], "source": [ "docs = load_all_songs(\"../data/ohhla/train/www.ohhla.com/anonymous/j_live/\")\n", "assert len(docs) == 50, \"Your ohhla corpus is corrupted, please download it again!\"\n", "trainDocs, testDocs = docs[:len(docs)//2], docs[len(docs)//2:] \n", "train = words(trainDocs)\n", "test = words(testDocs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='blue'>Task 1</font>: Optimal Pseudo Count \n", "\n", "Plot the perplexity of Laplace smoothing on the given data as a function of alpha over the interval [0.001, 0.1], in steps of 0.001. Is it fair to assume that this is a convex function? Write a method that finds the optimal pseudo count `alpha` for [Laplace smoothing](https://github.com/uclmr/stat-nlp-book/blob/python/statnlpbook/lm.py#L180) on the given data, up to some predefined numerical precision `epsilon`, under the assumption that the perplexity is a convex function of alpha. 
How often did you have to call `perplexity` to find the optimum?\n", "\n", "Tips:\n", "<font color='white'>\n", "You don't need 1st or 2nd order derivatives in this case, only the gradient descent direction. Think about recursively slicing up the problem.\n", "</font>" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:57:55.326750", "start_time": "2016-10-21T16:57:44.674646" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0 0.0\n", "0.5 0.0\n", "0.25 0.0\n", "0.125 0.0\n", "0.0625 0.0\n", "0.03125 0.0\n", "0.03125 0.015625\n", "0.0234375 0.015625\n", "0.0234375 0.01953125\n", "0.0234375 0.021484375\n", "0.0234375 0.0224609375\n", "0.02294921875 0.0224609375\n", "0.022705078125 0.0224609375\n", "0.0225830078125 0.0224609375\n", "0.02252197265625 0.0224609375\n", "0.022491455078125 0.0224609375\n", "0.022491455078125 0.0224761962890625\n", "0.02248382568359375 0.0224761962890625\n", "0.02248382568359375 0.022480010986328125\n", "0.02248382568359375 0.022481918334960938\n", "0.022482872009277344 0.022481918334960938\n" ] }, { "data": { "text/plain": [ "(0.022482872009277344, 65.3568849083981)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlMAAAFpCAYAAAC4SK2+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl0nHd97/HPV7NIM1pGthZrdWTZjnfJSYxJ0iQUQkkJ\nCQHapsmBNFAgbstlu6W9lHtuoae3PbQnXC7ltvSY7aQtSymEkqYphZsWKPQS4iTeHe+OLVnWYluj\nfRnN7/4xI1k2TjTWrHrm/TpnjqyZZ6yvzmNbH/9+3+f7mHNOAAAAWJySfBcAAACwlBGmAAAA0kCY\nAgAASANhCgAAIA2EKQAAgDQQpgAAANJAmAIAAEgDYQoAACANKYUpM/ugme03swNm9qHkc58ws24z\n25183J3dUgEAAAqPf6EDzGyzpPdK2i5pStJ3zezJ5Mufds49msX6AAAACtqCYUrSBknPOOfGJMnM\nfijpbYv5YrW1ta6trW0xbwUAAMip5557bsA5V7fQcamEqf2S/sTMaiSNS7pb0i5J5yW938x+I/n5\n7zrnLr7Sb9TW1qZdu3al8CUBAADyy8xeSuW4BXumnHOHJP2ZpO9J+q6k3ZJmJH1OUrukrZJ6JH3q\nZQp5xMx2mdmu/v7+1KoHAABYIlJqQHfOfdE5d5Nz7g5JFyUdcc71OudmnHNxSZ9Xoqfqau/d6Zzb\n5pzbVle34EoZAADAkpLq1Xz1yY8rleiX+qqZNc475K1KbAcCAAAUlVR6piTpW8meqWlJ73PODZrZ\nZ81sqyQn6ZSkHVmqEQAAoGClFKacc7df5bmHMl8OAADA0sIEdAAAgDQQpgAAANJAmAIAAEgDYQoA\nACANhCkAAIA0EKYAAADS4KkwdebCmJ4+1KvYTDzfpQAAgCLhqTD1rwfO6d2P7dLo5Ey+SwEAAEXC\nU2EqHEzMIB2bjuW5EgAAUCw8FqZ8kqSxKVamAABAbngqTIWSYWqcMAUAAHLEU2GKlSkAAJBrHg1T\n9EwBAIDc8FSYCgUSDehs8wEAgFzxVJiaXZkanyZMAQCA3PBkmKJnCgAA5IqnwhRX8wEAgFzzVJia\nG9pJmAIAADniqTDlKzEF/SVMQAcAADnjqTAlJfqm2OYDAAC54rkwFQr42OYDAAA5470wxcoUAADI\nIc+FqXDQxwR0AACQM94LUwE/23wAACBnPBemQkEfE9ABAEDOeC5MJbb5CFMAACA3PBemaEAHAAC5\n5LkwRQM6AADIJQ+GKRrQAQBA7nguTIUCPk3G4pqJu3yXAgAAioDnwlQ46JMkTXBFHwAAyAHPhim2\n+gAAQC54LkyFgn5J4oo+AACQE54LU3MrU9Nc0QcAALLPc2EqxDYfAADIIc+FqXAgEabY5gMAALng\nvTCV7JliZQoAAOSC58LUpW0+eqYAAED2eTZMsc0HAABywXNharZnim0+AACQC54LU3MrU0xABwAA\nOeC5MFXqL1GJ0TMFAAByw3NhyswUDvrZ5gMAADmRUpgysw+a2X4zO2BmH0o+t9zMvm9mR5Mfl2W3\n1NSFgj4a0AEAQE4sGKbMbLOk90raLqlT0j1mtkbSRyU97ZxbK+np5OcFIRz0sTIFAAByIpWVqQ2S\nnnHOjTnnYpJ+KOltku6T9FjymMckvSU7JV67UIAwBQAAciOVMLVf0u1mVmNmYUl3S2qVtMI515M8\n5pykFVmq8ZqFgz5NcDUfAADIAf9CBzjnDpnZn0n6nqRRSbslzVxxjDMzd7X3m9kjkh6RpJUrV6Zd\ncCoSDehczQcAALIvpQZ059wXnXM3OefukHRR0hFJvWbWKEnJj30v896dzrltzrltdXV1mar7FYXo\nmQIAADmS6tV89cmPK5Xol/qqpCckPZw85GFJ38lGgYsRDvoY2gkAAHJiwW2+pG+ZWY2kaUnvc84N\nmtknJX3DzN4t6SVJ92eryGvF1XwAACBXUgpTzrnbr/LceUl3Z
ryiDAgF/MyZAgAAOeG5CejS7MpU\nTM5dtSceAAAgYzwZpkJBn+JOmozF810KAADwOE+GqXDQJ0ls9QEAgKzzdJga44o+AACQZZ4MU2WB\n2ZUpBncCAIDs8mSYCgcTFykyHgEAAGSbR8NUcpuPMAUAALLMk2EqRAM6AADIEU+GKVamAABArngz\nTAVme6ZoQAcAANnlyTA1u803wWgEAACQZZ4MU2zzAQCAXPFkmAoFCFMAACA3PBmmSkpMZYESjbPN\nBwAAssyTYUpKDO6kAR0AAGSbZ8NUKOBjmw8AAGSdZ8NUOOhjaCcAAMg6T4cpVqYAAEC2eTZMhViZ\nAgAAOeDZMBUO+jU2TQM6AADILs+GqRDbfAAAIAe8G6YCbPMBAIDs82yYogEdAADkgmfDFA3oAAAg\nFzwbpsIBv6Zm4orNxPNdCgAA8DDvhqlg8mbH3J8PAABkkWfDVCgZptjqAwAA2eTZMBUmTAEAgBzw\nfJjiij4AAJBNng1ToaBfkjTOFHQAAJBFng1TrEwBAIBc8GyYCgUIUwAAIPs8G6ZoQAcAALng4TCV\n6JliZQoAAGSTZ8NUaK5nigZ0AACQPZ4NU2zzAQCAXPBsmAr4ShTwGbeTAQAAWeXZMCUlruhjZQoA\nAGSTp8NUOOinZwoAAGSVp8NUKOjjaj4AAJBV3g5TbPMBAIAs83SYCrMyBQAAsszTYSoU9HE1HwAA\nyKqUwpSZfdjMDpjZfjP7mpmVmdknzKzbzHYnH3dnu9hrFQ76NMHKFAAAyCL/QgeYWbOkD0ja6Jwb\nN7NvSHog+fKnnXOPZrPAdISDfo1NczUfAADInlS3+fySQmbmlxSWdDZ7JWVOKEgDOgAAyK4Fw5Rz\nrlvSo5JOS+qRFHXOfS/58vvNbK+ZfcnMlmWxzkUJB2hABwAA2bVgmEqGpPskrZLUJKnczN4h6XOS\n2iVtVSJkfepl3v+Ime0ys139/f0ZKzwV4aBP49Mzcs7l9OsCAIDikco23+slnXTO9TvnpiU9LulW\n51yvc27GOReX9HlJ26/2ZufcTufcNufctrq6usxVnoJQ0C/npInpeE6/LgAAKB6phKnTkm42s7CZ\nmaQ7JR0ys8Z5x7xV0v5sFJiOcNAnSdxSBgAAZM2CV/M5554xs29Kel5STNILknZK+oKZbZXkJJ2S\ntCOLdS5KaC5Mzagmz7UAAABvWjBMSZJz7uOSPn7F0w9lvpzMml2ZGmdwJwAAyBJPT0APz1uZAgAA\nyAZPh6lQILHwRs8UAADIFk+HqbltPlamAABAlhRFmGKbDwAAZIunw1RZgJUpAACQXZ4OU8yZAgAA\n2ebxMJVsQGc0AgAAyBJPh6myQInMpAm2+QAAQJZ4OkyZmUIBHw3oAAAgazwdpqRE3xTbfAAAIFs8\nH6ZCQR9X8wEAgKzxfJgKB/xczQcAALLG82EqFKRnCgAAZI/nw1SYbT4AAJBFRRGmWJkCAADZ4vkw\nFQr6Nc7VfAAAIEs8H6bCAR8N6AAAIGs8H6ZoQAcAANnk+TBFAzoAAMimoghTsbjTVCye71IAAIAH\neT5MhYJ+SWJ1CgAAZIXnw1Q46JMkjU3ThA4AADLP82EqFEiGKVamAABAFng/TCVXptjmAwAA2eD5\nMFWe7JkanWSbDwAAZJ7nw1RNRVCSdGF0Ks+VAAAAL/J8mKqrLJUk9Q1P5rkSAADgRZ4PU8vCQflK\nTP2EKQAAkAWeD1O+ElNNeZAwBQAAssLzYUpKbPX1jxCmAABA5hVFmKqvLGVlCgAAZEVRhKk6whQA\nAMiSoglTAyOTisddvksBAAAeUxxhqqJUsbjTxTFmTQEAgMwqjjBVWSZJNKEDAICMK5IwlRjcSd8U\nAADINMIUAABAGghTAAAAaSiKMFVR6lc46CNMAQCAjCuKMCUxBR0AAGRH8YSpCgZ3AgCAzCueMFVZ\nqj7CFAAAyLCiClOsTAEAg
ExLKUyZ2YfN7ICZ7Tezr5lZmZktN7Pvm9nR5Mdl2S42HXUVpYqOT2sy\nNpPvUgAAgIcsGKbMrFnSByRtc85tluST9ICkj0p62jm3VtLTyc8L1ux4hIERbikDAAAyJ9VtPr+k\nkJn5JYUlnZV0n6THkq8/JuktmS8vc+qrmDUFAAAyb8Ew5ZzrlvSopNOSeiRFnXPfk7TCOdeTPOyc\npBVZqzID6iqS9+cjTAEAgAxKZZtvmRKrUKskNUkqN7N3zD/GOeckuZd5/yNmtsvMdvX392eg5MVh\nCjoAAMiGVLb5Xi/ppHOu3zk3LelxSbdK6jWzRklKfuy72pudczudc9ucc9vq6uoyVfc1q6kISpL6\nhifyVgMAAPCeVMLUaUk3m1nYzEzSnZIOSXpC0sPJYx6W9J3slJgZAV+JlpcHWZkCAAAZ5V/oAOfc\nM2b2TUnPS4pJekHSTkkVkr5hZu+W9JKk+7NZaCYwBR0AAGTagmFKkpxzH5f08SuenlRilWrJ4P58\nAAAg04pmArrEFHQAAJB5RRWm6pNhKnHxIQAAQPqKKkzVVZZqMhbX8GQs36UAAACPKLowJUl9Q2z1\nAQCAzCiuMFXB4E4AAJBZxRWmZqegc0UfAADIkOIMU6xMAQCADCmqMBUJBRTwGWEKAABkTFGFKTNj\nCjoAAMioogpTklRXVUbPFAAAyJjiC1OsTAEAgAwqvjBVWar+4Yl8lwEAADyiKMPU+dEpxWbi+S4F\nAAB4QFGGKeekC6NT+S4FAAB4QPGFqeQU9D76pgAAQAYUX5hiCjoAAMigogtT9UxBBwAAGVR0YYpb\nygAAgEwqujBVFvCpssxPmAIAABlRdGFKmp01RZgCAADpK84wxRR0AACQIcUZpipLuZoPAABkRPGG\nKVamAABABhRtmBqZjGlsKpbvUgAAwBJXlGGqvrJMkjQwzC1lAABAeooyTF2agj6R50oAAMBSV5xh\navb+fEP0TQEAgPQUZZiqr0qEqZ4oK1MAACA9RRmmasqDqizz68TASL5LAQAAS1xRhikz05r6Ch3r\nI0wBAID0FGWYkqQ1dRU61jea7zIAAMASV7xhqr5CAyOTio5N57sUAACwhBV1mJKkY/3Dea4EAAAs\nZYQp+qYAAEAaijZMtSwLK+gvIUwBAIC0FG2Y8pWY2mvLCVMAACAtRRumpMRW37F+whQAAFi8og9T\nXRfHNTE9k+9SAADAElX0Yco56TirUwAAYJGKPkxJXNEHAAAWr6jD1KracpWYdJwwBQAAFqmow1Sp\n36eVy8M0oQMAgEUr6jAliRseAwCAtBR9mFpdX6GTA6OKzcTzXQoAAFiCFgxTZrbOzHbPewyZ2YfM\n7BNm1j3v+btzUXCmramr0PSM0+kLY/kuBQAALEELhinn3GHn3Fbn3FZJN0kak/Tt5Mufnn3NOfdU\nNgvNFq7oAwAA6bjWbb47JR13zr2UjWLyYfVsmKIJHQAALMK1hqkHJH1t3ufvN7O9ZvYlM1t2tTeY\n2SNmtsvMdvX39y+60GypKgtoRVUpK1MAAGBRUg5TZhaU9GZJ/5B86nOS2iVtldQj6VNXe59zbqdz\nbptzbltdXV2a5WbHmvoKZk0BAIBFuZaVqTdKet451ytJzrle59yMcy4u6fOStmejwFxYU1eh4/2j\ncs7luxQAALDEXEuYelDztvjMrHHea2+VtD9TReXamvoKjUzGdG5oIt+lAACAJcafykFmVi7plyTt\nmPf0n5vZVklO0qkrXltSVs+7oq8xEspzNQAAYClJKUw550Yl1Vzx3ENZqSgP5o9HuH1tYfZ1AQCA\nwlT0E9Alqa6iVFVlfq7oAwAA14wwJcnMuEcfAABYFMJU0pr6xBV9AAAA14IwlbSmvkIDI5OKjk3n\nuxQAALCEEKaSVtfN3lZmOM+VAACApYQwlcQNjwEAwGIQppJaloVV6i/RkV7CFAAASB1hKsl
XYtrc\nHNELpy/muxQAALCEEKbm2b5qufZ2RTU+NZPvUgAAwBJBmJpne9tyxeKO1SkAAJAywtQ8N7Utk5n0\ns1MX8l0KAABYIghT81SVBbSxsUo/O0mYAgAAqSFMXWH7quV6/vRFTcXi+S4FAAAsAYSpK7x61XJN\nTMe1rzua71IAAMASQJi6wra25ZLEVh8AAEgJYeoKtRWlWl1XrmdpQgcAACkgTF3F9lU1evbUBc3E\nXb5LAQAABY4wdRWvXrVcwxMxvXhuKN+lAACAAkeYuortq+ibAgAAqSFMXUVTdUjN1SHCFAAAWBBh\n6mW8etVyPXvqgpyjbwoAALw8wtTL2L5quQZGpnRiYDTfpQAAgKTo2LR+fHRA50cm813KHH++CyhU\n8/umVtdV5LkaAACKz/DEtPZ3D2lf96D2dkW1rzuql86PSZI+88BW3be1Oc8VJhCmXsaq2nLVVpTq\nZycv6MHtK/NdDgAAnjY2FdPBs0Pa2xXV3q5B7e2O6kT/pd2h5uqQOloi+vVXtaqjuVqdrZE8Vns5\nwtTLMDNtX7WMJnQAADJsYnpGh3qGtK87mlhx6orqaN+wZsc7NlSVaUtLRG/Z2qwtLRF1NEdUU1Ga\n36JfAWHqFWxvW66n9p1T18UxtSwL57scAACWnKlYXEd6h5PbdIntusPnhhVLJqea8qA6WiK6a3OD\nOpoj6miJqL6qLM9VXxvC1CvYvqpGkvTsqQuEKQAAFjATdzrWN6I9XYPa1xXV3u6oDvUMaSoWlyRF\nQgF1tES04zXt2tIcUUdLtRojZTKzPFeeHsLUK1jXUKmqMr9+dvKC3npDS77LAQCgYMTjTqfOjyZ7\nnBKrTvu7hzQ+PSNJqij1a3Nzld51a5s2N0fU2VKt1uWhJR+croYw9Qp8JaZXt9foh4f7FY87lZR4\n7w8AAAALcc6p6+J4Ijh1D2rvmaj2d0c1PBmTJJUFSrSpKdEc3tka0ZbmarXXlhfNz03C1ALu3tKg\n7x/s1QtnLuqm65bnuxwAALKub2hi7qq6PcmRBBdGpyRJAZ9pQ2OV7t3apM6WxFbd2voK+X3FO7qS\nMLWA129YoaC/RP+0p4cwBQDwnIujU8mr6gbntuzODU1IkkpMun5Fpe5cX6+O1mp1tkS0rqFSpX5f\nnqsuLISpBVSWBfTadXV6al+P/sc9G+UrkiVLAID3jEzGtL973opTV1SnL4zNvd5eW65Xty9XR0si\nOG1qiigUJDgthDCVgns6mvSvB3r17KkLurm9Jt/lAACwoNlZTnu7otqTXHU63j+i2VvOzg7BfHD7\nykRwao4oEgrkt+glijCVgjs31CsU8OnJvWcJUwCAghObieto38jcitPerkG92HNpllNtRVCdLdW6\nt6NJHS0RbWmJqLaAh2AuNYSpFISDfr1uQ73+Zd85feLeTUXdZAcAyK/5IwlmV5wOnI1qYjoxy6my\nzK+Olojee0f7XIO4F2Y5FTLCVIru2dKof97bo5+euKDb1tbmuxwAQBFwzqknOnHZitPerqiGJy4f\nSZDYqqtWR0tEbTXFM5KgUBCmUvTa9fUqDya2+ghTAIBsuDg6NbfatLdrULvPRDUwMilJ8peY1jdW\n6t5ORhIUGsJUisoCPr1+4wp998A5/fFbNivAH14AQBpGJ2NzIwlmV53OXBiXJJlJq+sqdMf1tXMr\nThsaq1QW4Mq6QkSYugb3dDTpO7vP6sfHBvTadfX5LgcAsERMxmb0Ys/wZcHpWN+I4vOurOtsjejt\nr74u0SDeHFFlGVfWLRWEqWtwx/W1qizz68k9PYQpAMBVzcSdTvSPaPeZS9t1h3qGNTWTaBCvKQ+q\noyWiN25u1NbWaq6s8wDC1DUo9fv0ho0N+t7Bc5qMbWYCLAAUucvuWdc1qN1nBrW/O6rRqStu9ntb\n29x2XXO1N2/2W8wIU9fons5Gfev5Lv3oyIB+aeOKfJc
DAMih8yOT2tsVTa46JVaezifvWRf0lWhD\nU5V+5aYWdbRUa2trRO21FVxZVwQIU9fotjW1qg4H9OTes4QpAPCwkcmY9s0bR7D7zKC6By81iK+t\nr9Br19ersyWiztZqrW+oUtDPxUnFaMEwZWbrJP39vKfaJf2hpL9JPt8m6ZSk+51zFzNfYmEJ+Er0\nxs0N+scXzio6Nq1ImAZBAFjq5jeI7z6TbBCfd+uVlmUhbV1ZrYdvvU4dLdXa3BxRRSnrEUhY8E+C\nc+6wpK2SZGY+Sd2Svi3po5Keds590sw+mvz8v2Wx1oLx0M1t+trPzujvnnlJ73vtmnyXAwC4BvG4\n04mBEe05k5ggvqcrqkNnhy5rEO9srdabOhrV2VqtjuaIamgQxyu41lh9p6TjzrmXzOw+Sb+YfP4x\nST9QkYSpjU1Vun1trb78k1N6922rmPsBAAVqdoL4njOJ0LQn2SA+PJmYIF4e9Glzc0Tv/IU2bW2l\nQRyLc61h6gFJX0v+eoVzrif563OSiqqBaMcdq/WOLz6jf3yhWw9sX5nvcgAAkgbHpuZC05UTxAM+\n0/qGKt13Q5M6W6rV2Vqt1XUV8tEgjjSlHKbMLCjpzZL+4MrXnHPOzNzLvO8RSY9I0sqV3gkdv7Cm\nRhsbq7TzP07o/m2tXK0BADk2PjWjA2ejl4WnU+fH5l5fXVc+N0G8s7VaGxorGWmDrLiWlak3Snre\nOdeb/LzXzBqdcz1m1iip72pvcs7tlLRTkrZt23bVwLUUmZl2vKZdH/z6bj39Yh9X9gFAFsVm4jrS\nO5KcIJ5YcTrSO6yZ5AjxhqoybW2t1v2vatXWlmptbomoigniyJFrCVMP6tIWnyQ9IelhSZ9MfvxO\nButaEu7e0qg//+5h7fzRccIUAGSIc05nLoxrd9eg9p5JhKd93VFNTCcaxKvK/Opsrdad61ers7Va\nnS0R1VeV5blqFLOUwpSZlUv6JUk75j39SUnfMLN3S3pJ0v2ZL6+wBXwl+s3bVumPnzyo509f1I0r\nl+W7JABYcgZGJuf6m2a36y6OTUuSSv0l2twc0YPbV2pra7U6W6p1XU2YBnEUlJTClHNuVFLNFc+d\nV+LqvqL2wKta9RdPH9XOH57QXz90U77LAYCCNjoZ077u2dB0+SDMEpOuX1GpN2xsUEdrRJ0t1VrX\nUKmAj0GYKGxMHEtTealf77h5pf7qB8d1cmBUq2rL810SABSEqVhch88NX7Zdd7Tv0iDM1uUh3bCy\nWu+8tU0dLRFtbo6onEGYWIL4U5sBD9/aps//6KQ+/x8n9Kdv3ZLvcgAg5+Jxp5PnRxMN4mcSK04H\ne4Y0Fbt8EObdWxrnrq5bXh7Mc9VAZhCmMqC+skxvu7FZ33yuSx943Vo1RGiEBOBtvUMT2n1mcG67\nbk/XoIYnEoMww7ODMJMrTp0t1WpZxiBMeBdhKkPe99o1evyFbv3JU4f02QdvyHc5AJAx0fFp7UsG\npj3J7breocQgTH+JaV1Dpe7tbFJnS0RbW5dpTT2DMFFcCFMZ0ro8rN9+zWp95umjenB7q25dXZvv\nkgDgmk1Mz+hgz9ClFaczgzoxMDr3+qract3SXqOO5FbdpqYqbqmFokeYyqDf/sXVevyFLn38Owf0\n1Adv5woUAAVtJu50rG9kbsVpb1dUh3qGFEsOwqyvLFVna7XedmNz8oa/1YqEGYQJXIkwlUFlAZ/+\n8J5Neu/f7NJj/3lK77m9Pd8lAYCkxCDMrovj2tN1aSTB/u6oxqZmJEmVpX5taYnovXe0q7OlWltb\nq+n/BFJEmMqw12+o12vX1enT3z+iezubtIKpvADyYHYQ5p4z0eQtWKK6MDolSQr6S7SxsUr3b2tN\nNIi3VmtVTTn3GAUWiTCVYWamj9+7SW/49I/0p08d0mceoBkdQHaNTMa0rys6d9+6PWeic4MwzaS1\n9RW6c3198tYriUG
YQT9tCECmEKayoK22XDte067P/tsxPbh9pW5ur1n4TQCQgsnYjF7sGZ5bbdpz\nZlDH+i8NwmxZFtJWBmECOcXfsCz5nV9co8ef79bHv3NAT37gNprRAVyzmbjT8f6RuXEEe7uierFn\nWFMzlw/CfFNHY7JBPKKaitI8Vw0UH8JUloSCPv3hvRu142+f02f+71F95K51+S4JQAG7skF8T7JB\nfDTZIF5R6teW5ojedVubOluq1dESUXM1gzCBQkCYyqK7NjXo17e16v/8+zHdeF21Xrd+Rb5LAlAg\n+oYntHdec/i+7nkN4r4SbWiq0q/e1JKc5xRRe20FDeJAgSJMZdkf3bdJe7uj+vDf79E/f+A2tSwL\n57skADkWHZvWvu5octUpsfLUE52QJJWYdP2KSr1+Q30iONEgDiw5hKksKwv49Lm336h7P/tj/c5X\nntc//NYtKvUzLRjwqrGpmPZ3D82Fpr1dgzp1fmzu9baasF7VtnxuJMGmpiqFg/xTDCxl/A3Ogbba\ncj16f6d2/O1z+p9PHtIfv2VzvksCkAGTsRkd6hnWvtmtuq6ojvYNKzlAXI2RMnW0RPRryXlOTBAH\nvIkwlSN3bWrQI3e0a+ePTmhb2zLdt7U53yUBuAbTM3Ed6R1OzHPqTqw4HT43rOmZRHKqKQ+qoyWi\nuzY3qLMloi0tEdVXMrQXKAaEqRz6vbvW6YXTF/XRb+3T+oYqrWuozHdJAK5idiTB3q6o9nUNam93\nVAfPDmkylhhJUFnmV0dLRO+5vV0dzRF1tFarKVLGlXVAkTI3O+ktB7Zt2+Z27dqVs69XiHqHJnTP\nZ38sk/SNHbeorbY83yUBRS0edzp1flT7uqPJ8BTV/rOX7lkXDvq0uSmx0tTRElFHS7WuWx7myjqg\nCJjZc865bQseR5jKvSO9w3pg509V5i/R3++4Ra3LucIPyAXnnF46P6a93VHtT27VHege0vBkTJJU\n6i/RpqYqbWlOhKaOloja6yrkIzgBRYkwVeAOnI3qwZ0/VSQc0Dd23KLGSCjfJQGe4pzT6Qtj2ted\nWG3alwxQQxOJ4BT0lWhDY6U2N0fU2VKtLS0Rra2vkJ+7FQBIIkwtAXvODOrtX3hG9ZWl+vqOm2lW\nBRZpdsVpNjDtP5sIUPOD07qG2eCUuF/d9SuY5QTglRGmlohdpy7oN770MzVXh/T1R27mvlrAAuJx\np5cuzAtOV6w4BXym9Q1V2twcSW7XEZwALA5hagn5z+MDeteXn1VzdUhfeHib2usq8l0SUBBm4k4n\nB0YvhaYiQEcaAAAM+0lEQVSz0ct6nIK+Eq1PbtVtST4ITgAyhTC1xPzs5AX91t89p9hMXH/59ht1\n+9q6fJcE5NT0TFxHe0d04GxUB84OaX93VAd7huauqiv1l2hDY6I5fHNz1dxWXYAeJwBZQphags5c\nGNN7HtulY/0j+h9v2qCHb21jbg08aWJ6Rod6hnTg7NBceHrx3LCmknOcwkGfNjYmAtPmZHhaU0dz\nOIDcSjVMMbSzgLQuD+tbv3OrPvT13frEPx3U4d4R/dGbN7FlgSUtOjatA2cTq0yz4elY38jcLVci\noYA2NVXpnbe2aVNTIkC11ZQzjgDAkkGYKjAVpX7tfOgmPfq9w/qrHxzXkd5hferXOhnuiYLnnFP3\n4LgOnh3SwZ4hHTybCE/dg+NzxzRUlWlTU5V+eVODNjYlVpyaq0OswAJY0tjmK2BP7Dmr//7tfZqe\niev37lqvd93axtRlFISpWFxH+4Z18OyQDvUM62BPVId6hhUdn5YkmUmrasu1sbFKm5oi2tRUpU1N\nVVytCmBJoWfKI85FJ/Sxb+/Tv73Yp1e1LdOf/2qnVrFKhRw6PzKpQz3DOtQzpEM9iVWn4/0jczf4\nLQuUaF1DlTY2VmpjMjitb6hUOMjCN4CljTDlIc45Pf58t/7onw5oMhbXR96wTg/f2
kYvFTJqKhbX\niYERvTgbnM4N68WeIfUNT84ds6KqVOsbqrSxqUobG6u0obFKq2rpbwLgTYQpD+odmtDHHt+np1/s\n08rlYf3uG67XvR1NbP3hmjjn1BOd0OFzw3rx3LBePDekw+eGL1ttCvpKtKa+QusbK+dC0/qGSrbp\nABQVwpRHOef0wyP9+rPvHtahniFtaqrS7//yet2xtpYmXvyci6NTOtw7rCO9wzp8LvnoHdZwclq4\nJDVFyrSuoVLrGqq0obFybrWJ+U0Aih1hyuPicacn9pzVp75/WGcujOuW9hrteE277lhbx0pVEYqO\nT+tY37CO9I7o8LlhHe0b1uFzIxoYubRFV1Xm1/qGKl3fUKF1Kyq1vrFK16+oVCQUyGPlAFC4CFNF\nYioW11efeUl/+YPj6h+eVHtdud55a5t+5cYWlZfSAOw150cmdaxvRMf6R3S0d0TH+kZ0tG9YvUOX\nQlM46NPa+gqtXVGpdSsqtXZFhdY3VGlFVSmrlwBwDQhTRWYqFtdT+3r05Z+c1J6uqCrL/Lp/W6t+\nbVuL1jdU5bs8XIOZuFPXxTEd7x/Rif5RHe9PhKZjfSO6ODY9d9xsaFpTX6nrV1RoTX2Frl9Rqebq\nEKuTAJABhKki9vzpi/ryT07pX/b1KBZ3un5Fhd7c2aR7O5t0XQ1jFQqBc04XRqd0cmBUJwZGdXJg\nVCf7R3ViYESnBsY0NROfO3Z5eVBr6iq0uj4RmGYfjVVlhCYAyCLCFDQwMqmn9vXoid1nteuli5Kk\nztZqvWHjCt2xtk6bmqr4YZxFzjkNjEzp9IVRvXR+TKcGRnXy/JheOp8IT/ObwAM+08rlYa2qLdfq\nuorEo75c7bUVWlYezON3AQDFizCFy3QPjuvJPWf15N4e7euOSkqseNy2pla3r63VLatruK3HIoxN\nxdR1cVxnLozNfTw97zE2NTN3bIlJzctCaqspTzxqy9VeW672unI1V4e4iS8AFBjCFF5W3/CEfnJs\nQD86MqD/ONqvgZEpSVJdZam2tlZra2u1bmit1paWiCrLivdKr5m40/mRSfVEJ9QTHVf34ITODo6r\n++K4zkYTH8+PTl32nlJ/iVYuD+u6mrBal4d13fKwrqsp18qasFqXhRm0CgBLCGEKKYnHnQ6dG9Ku\nUxe1+8ygdp8Z1MmB0bnXGyNlWruiMnF1WLJXp2VZWPWVpUt2izA2E9eFsSmdH5lS//Ckeocm1Dc8\nqf7hSfUNT6h3aFLnohPqHZpQLH75349QwKem6jI1LwurubpMLcvCalkWUuvyxMe6Cq6YAwCvSDVM\nce18kSspseSNaCN6OPncxdEp7e4a1MGzQzrWN6IjvcN65sR5TcYuNUUHfKam6pCaq0Nqqg6ptqJU\ntRVB1VQEtby8VDXlQVWVBVRe6lNFmV+lfl9G656KxTU+PaPxqRkNT0xraCKm4YlpjUzGNDwR0+DY\ntAbHpjQ4Nq2LyY+JADV52RVx81WW+VVfWaoVVWV6dftyNUbK1BAJqaGqTI2RMjVVh7QsHCAsAQAu\nw8oUUjITd+q+OK7jAyPqvjiurovj6h4cV/fFMZ0dnND50cm5W5FcTcBnKi/1q8zvU9BfooDPFPCV\nKOgvUYmZrswn8bjT9IzT9ExcsbjTVCyuqZm4JqZmND4983MrRlcT9JWoOhzQsnBQkXBANeVB1VaU\nqqYiqJqKUtWWB1VbWaoVlWWqqyxVKJjZwAcAWNoyujJlZtWSviBpsyQn6Tcl3SXpvZL6k4d9zDn3\n1OLKRaHzlZhW1oS1siZ81dedcxqaiOnCaGL15/zolIYnYhqdjGlk9jER02RsRtMzTlMzcU3H4nNh\n6Uollghbs6HL7zMFfSUKBX0KBXwKB30qC/gUCvpUWRZQZZlflaX+uV9XhwMKBXysIgEAsi7Vbb7P\nSPquc+5XzSwoKaxEmPq0c+7RrFWHJcPMFAkFF
AkFtKqWWVYAgOKxYJgys4ikOyS9U5Kcc1OSpvgf\nPwAAgJTKddqrlNjK+7KZvWBmXzCz2aWH95vZXjP7kpkty16ZAAAAhSmVMOWXdKOkzznnbpA0Kumj\nkj4nqV3SVkk9kj51tTeb2SNmtsvMdvX391/tEAAAgCUrlTDVJanLOfdM8vNvSrrROdfrnJtxzsUl\nfV7S9qu92Tm30zm3zTm3ra6uLjNVAwAAFIgFw5Rz7pykM2a2LvnUnZIOmlnjvMPeKml/FuoDAAAo\naKlezfd+SV9JXsl3QtK7JP2FmW1VYlTCKUk7slIhAABAAUspTDnndku6cmjVQ5kvBwAAYGnhrqsA\nAABpIEwBAACkgTAFAACQBsIUAABAGghTAAAAaSBMAQAApMGcc7n7Ymb9kl7K4G9ZK2kgg78fMovz\nU7g4N4WN81O4ODeFLdPn5zrn3IK3b8lpmMo0M9vlnLty/hUKBOencHFuChvnp3Bxbgpbvs4P23wA\nAABpIEwBAACkYamHqZ35LgCviPNTuDg3hY3zU7g4N4UtL+dnSfdMAQAA5NtSX5kCAADIq4INU2b2\ny2Z22MyOmdlHr/K6mdlfJF/fa2Y3pvpepGex58bMWs3s383soJkdMLMP5r5670vn707ydZ+ZvWBm\nT+au6uKQ5r9r1Wb2TTN70cwOmdktua3e+9I8Px9O/ru238y+ZmZlua3e21I4N+vN7P+Z2aSZfeRa\n3psRzrmCe0jySTouqV1SUNIeSRuvOOZuSf8iySTdLOmZVN/LI2/nplHSjclfV0o6wrkpnPMz7/X/\nKumrkp7M9/fjpUe650bSY5Lek/x1UFJ1vr8nLz3S/LetWdJJSaHk59+Q9M58f09eeaR4buolvUrS\nn0j6yLW8NxOPQl2Z2i7pmHPuhHNuStLXJd13xTH3Sfobl/BTSdVm1pjie7F4iz43zrke59zzkuSc\nG5Z0SIl/hJA56fzdkZm1SHqTpC/ksugisehzY2YRSXdI+qIkOeemnHODuSy+CKT1d0eSX1LIzPyS\nwpLO5qrwIrDguXHO9TnnnpU0fa3vzYRCDVPNks7M+7xLP/9D9+WOSeW9WLx0zs0cM2uTdIOkZzJe\nYXFL9/z8b0m/LymerQKLWDrnZpWkfklfTm7BfsHMyrNZbBFa9PlxznVLelTSaUk9kqLOue9lsdZi\nk87P9ZxkgkINU/AwM6uQ9C1JH3LODeW7HiSY2T2S+pxzz+W7Fvwcv6QbJX3OOXeDpFFJ9IMWCDNb\npsRqxypJTZLKzewd+a0KuVSoYapbUuu8z1uSz6VyTCrvxeKlc25kZgElgtRXnHOPZ7HOYpXO+fkF\nSW82s1NKLIW/zsz+LnulFp10zk2XpC7n3OxK7jeVCFfInHTOz+slnXTO9TvnpiU9LunWLNZabNL5\nuZ6TTFCoYepZSWvNbJWZBSU9IOmJK455QtJvJK+uuFmJZdWeFN+LxVv0uTEzU6Ln45Bz7n/ltuyi\nsejz45z7A+dci3OuLfm+f3PO8b/rzEnn3JyTdMbM1iWPu1PSwZxVXhzS+blzWtLNZhZO/jt3pxI9\nociMdH6u5yQT+DP9G2aCcy5mZv9F0r8q0Yn/JefcATP7reTrfy3pKSWurDgmaUzSu17pvXn4Njwp\nnXOjxMrHQ5L2mdnu5HMfc849lcvvwcvSPD/Iogycm/dL+kryB8IJcd4yKs2fO8+Y2TclPS8pJukF\nMSk9Y1I5N2bWIGmXpCpJcTP7kBJX7Q3lIhMwAR0AACANhbrNBwAAsCQQpgAAANJAmAIAAEgDYQoA\nACANhCkAAIA0EKYAAADSQJgCAABIA2EKAAAgDf8fr+L5vk/vEOIAAAAASUVORK5CYII=\n", "text/plain": [ "<matplotlib.figure.Figure at 0x7fc9a8c0de48>" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "oov_train = inject_OOVs(train)\n", "oov_vocab = set(oov_train)\n", "oov_test = replace_OOVs(oov_vocab, test)\n", "bigram = NGramLM(oov_train,2)\n", "\n", "interval = [x/1000.0 for x in range(1, 100, 1)]\n", "perplexit_at_1 = perplexity(LaplaceLM(bigram, alpha=1.0), oov_test)\n", "\n", "def plot_perplexities(interval):\n", " \"\"\"Plots the perplexity of LaplaceLM for every alpha in interval.\"\"\"\n", " perplexities = [perplexity(LaplaceLM(bigram, alpha), oov_test) for alpha in interval] # todo\n", " plt.plot(interval, perplexities)\n", " \n", "def find_optimal(low, high, epsilon=1e-6):\n", " \"\"\"Returns the optimal pseudo count alpha within the interval [low, high] and the perplexity.\"\"\"\n", " print(high, low)\n", " if high - low < epsilon:\n", " return high, perplexity(LaplaceLM(bigram, high), oov_test)\n", " else:\n", " mid = (high+low) / 2.0\n", " left = perplexity(LaplaceLM(bigram, mid-epsilon), oov_test)\n", " right = perplexity(LaplaceLM(bigram, mid+epsilon), oov_test)\n", " if left < right:\n", " return find_optimal(low, mid, epsilon)\n", " else:\n", " return find_optimal(mid, high, epsilon)\n", "\n", "plot_perplexities(interval) \n", "find_optimal(0.0, 1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='blue'>Task 2</font>: Sanity Check LM\n", "Implement a method that tests whether a language model provides a valid probability distribution." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:57:55.397116", "start_time": "2016-10-21T16:57:55.328580" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0647115579930901\n" ] }, { "ename": "AssertionError", "evalue": "1.0647115579930901", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m<ipython-input-8-9bd0c3fbec6b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstupid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprobability\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'the'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mword\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mstupid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mOOV\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mstupid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0msanity_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstupid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'the'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m<ipython-input-8-9bd0c3fbec6b>\u001b[0m in \u001b[0;36msanity_check\u001b[0;34m(lm, *history)\u001b[0m\n\u001b[1;32m 3\u001b[0m in the vocabulary.\"\"\" \n\u001b[1;32m 4\u001b[0m \u001b[0mprobability_mass\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprobability\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mword\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobability_mass\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m1e-6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprobability_mass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0munigram\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNGramLM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moov_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAssertionError\u001b[0m: 1.0647115579930901" ] } ], "source": [ "def sanity_check(lm, *history):\n", "    \"\"\"Raises an AssertionError if lm does not define a valid probability distribution over all words\n", "    in the vocabulary.\"\"\"\n", "    probability_mass = sum([lm.probability(word, *history) for word in lm.vocab])\n", "    assert abs(probability_mass - 1.0) < 1e-6, probability_mass\n", "\n", "unigram = NGramLM(oov_train, 1)\n", "stupid = StupidBackoff(bigram, unigram, 0.1)\n", "print(sum([stupid.probability(word, 'the') for word in stupid.vocab]))\n", "assert OOV in stupid.vocab\n", "sanity_check(stupid, 'the')  # fails here: stupid backoff is not normalised, so the mass exceeds 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='blue'>Task 3</font>: Subtract Count LM\n", "Develop and implement a language model that subtracts a count $d\\in[0,1]$ from each non-zero count 
in the training set. Let's first formalise this:\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\\begin{align}\n", "\\#_{w=0}(h_n) &= \\sum_{w \\in V} \\mathbf{1}[\\counts{\\train}{h_n,w} = 0]\\\\\n", "\\#_{w>0}(h_n) &= \\sum_{w \\in V} \\mathbf{1}[\\counts{\\train}{h_n,w} > 0]\\\\\n", "\\prob(w|h_n) &= \n", "\\begin{cases}\n", "\\frac{\\counts{\\train}{h_n,w} - d}{\\counts{\\train}{h_n}} & \\mbox{if }\\counts{\\train}{h_n,w} > 0 \\\\\\\\\n", "\\frac{d \\cdot \\#_{w>0}(h_n)/\\#_{w=0}(h_n)}{\\counts{\\train}{h_n}} & \\mbox{otherwise}\n", "\\end{cases}\n", "\\end{align}" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:58:08.278572", "start_time": "2016-10-21T16:57:58.756197" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.1742331911436041\n" ] }, { "data": { "text/plain": [ "91.4414922652717" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class SubtractCount(CountLM):\n", "    def __init__(self, base_lm, d):\n", "        super().__init__(base_lm.vocab, base_lm.order)\n", "        self.base_lm = base_lm\n", "        self.d = d\n", "        self._counts = base_lm._counts  # not good style since it is a protected member\n", "        self.vocab = base_lm.vocab\n", "\n", "    def counts(self, word_and_history):\n", "        # TODO: these two totals could be cached per history instead of being recomputed on every call\n", "        history = word_and_history[1:]\n", "        # Number of vocabulary words seen (and not seen) after this history in training.\n", "        num_non_zero_histories = len([x for x in self.vocab if self._counts[(x, ) + history] > 0])\n", "        num_zero_histories = len(self.vocab) - num_non_zero_histories\n", "        if num_zero_histories == 0:\n", "            # No unseen words to give mass to, so keep the raw counts.\n", "            return self._counts[word_and_history]\n", "        else:\n", "            if self._counts[word_and_history] > 0:\n", "                return self._counts[word_and_history] - self.d\n", "            else:\n", "                # Share the subtracted mass equally among the unseen words.\n", "                return self.d * num_non_zero_histories / num_zero_histories\n", "\n", "    def norm(self, history):\n", "        return 
self.base_lm.norm(history) \n", "\n", "subtract_lm = SubtractCount(unigram, 0.1)\n", "oov_prob = subtract_lm.probability(OOV, 'the')\n", "rest_prob = sum([subtract_lm.probability(word, 'the') for word in subtract_lm.vocab])\n", "# OOV is already in subtract_lm.vocab, so rest_prob includes it and this printed sum counts it twice.\n", "print(oov_prob + rest_prob)\n", "sanity_check(subtract_lm, 'the')\n", "perplexity(subtract_lm, oov_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <font color='blue'>Task 4</font>: Normalisation of Stupid LM\n", "Develop and implement a version of the [stupid language model](https://github.com/uclmr/stat-nlp-book/blob/python/statnlpbook/lm.py#L205) whose probabilities sum to 1." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2016-10-21T16:58:08.402515", "start_time": "2016-10-21T16:58:08.280467" }, "run_control": { "frozen": false, "read_only": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.000000000000002\n" ] }, { "data": { "text/plain": [ "60.032236179798886" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class StupidBackoffNormalized(LanguageModel):\n", "    def __init__(self, main, backoff, alpha):\n", "        super().__init__(main.vocab, main.order)\n", "        self.main = main\n", "        self.backoff = backoff\n", "        self.alpha = alpha\n", "\n", "    def probability(self, word, *history):\n", "        main_counts = self.main.counts((word,) + tuple(history))\n", "        main_norm = self.main.norm(history)\n", "        # Assuming histories are ordered oldest word first (as NGramLM builds them),\n", "        # the backoff model keeps only the most recent words of the history.\n", "        backoff_order_diff = self.main.order - self.backoff.order\n", "        backoff_history = tuple(history[backoff_order_diff:])\n", "        backoff_counts = self.backoff.counts((word,) + backoff_history)\n", "        backoff_norm = self.backoff.norm(backoff_history)\n", "        # Interpolating counts and norms with the same weight keeps the distribution normalised.\n", "        counts = main_counts + self.alpha * backoff_counts\n", "        norm = main_norm + self.alpha * backoff_norm\n", "        return counts / norm\n", "\n", "less_stupid = StupidBackoffNormalized(bigram, unigram, 0.1)\n", "print(sum([less_stupid.probability(word, 'the') for word in 
less_stupid.vocab]))\n", "sanity_check(less_stupid, 'the')\n", "perplexity(less_stupid, oov_test)" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" } }, "nbformat": 4, "nbformat_minor": 1 }