{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Language Model Exercises\n",
"In these exercises you will extend and develop language models. We will use the code from the notes, but within a python package [`lm`](http://localhost:8888/edit/statnlpbook/lm.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup 1: Load Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:18.569772",
"start_time": "2016-10-21T16:59:18.552156"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"import sys, os\n",
"_snlp_book_dir = \"..\"\n",
"sys.path.append(_snlp_book_dir) \n",
"from statnlpbook.lm import *\n",
"from statnlpbook.ohhla import *\n",
"# %cd .. \n",
"import sys\n",
"sys.path.append(\"..\")\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"matplotlib.rcParams['figure.figsize'] = (10.0, 6.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"$$\n",
"\\newcommand{\\prob}{p}\n",
"\\newcommand{\\vocab}{V}\n",
"\\newcommand{\\params}{\\boldsymbol{\\theta}}\n",
"\\newcommand{\\param}{\\theta}\n",
"\\DeclareMathOperator{\\perplexity}{PP}\n",
"\\DeclareMathOperator{\\argmax}{argmax}\n",
"\\newcommand{\\train}{\\mathcal{D}}\n",
"\\newcommand{\\counts}[2]{\\#_{#1}(#2) }\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup 2: Load Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:18.613748",
"start_time": "2016-10-21T16:59:18.575886"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [],
"source": [
"docs = load_all_songs(\"../data/ohhla/train/www.ohhla.com/anonymous/j_live/\")\n",
"assert len(docs) == 50, \"Your ohhla corpus is corrupted, please download it again!\"\n",
"trainDocs, testDocs = docs[:len(docs)//2], docs[len(docs)//2:] \n",
"train = words(trainDocs)\n",
"test = words(testDocs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 1: Optimal Pseudo Count \n",
"\n",
"Plot the perplexity for laplace smoothing on the given data as a function of alpha in the interval [0.001, 0.1] in steps by 0.001. Is it fair to assume that this is a convex function? Write a method that finds the optimal pseudo count `alpha` number for [laplace smoothing](https://github.com/uclmr/stat-nlp-book/blob/python/statnlpbook/lm.py#L180) for the given data up to some predefined numerical precision `epsilon` under the assumption that the perplexity is a convex function of alpha. How often did you have to call `perplexity` to find the optimum?\n",
"\n",
"Tips:\n",
"\n",
"You don't need 1st or 2nd order derivatives in this case, only the gradient descent direction. Think about recursively slicing up the problem.\n",
""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:19.151308",
"start_time": "2016-10-21T16:59:18.615252"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0 0.0\n"
]
},
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmUAAAFpCAYAAADdpV/BAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEgdJREFUeJzt3F+I5fd53/HP012LxvlTOUix17tSd9suDZtSsJiqag29\niJQiKY43lxIkdlzCYqiC3SYY2b4ovSgEWtJgKiwWx0UmboVxXLI1GxTHyWVlNPIfuWtF8VZJLMnr\naBOoHSqoIvz0Yo7LZD3SjPacnXlm9vWCg87v/L6/mWf4sjNvnXNmqrsDAMDe+ht7PQAAAKIMAGAE\nUQYAMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGCAw3s9wNW46aab+vjx43s9BgDA\ntp588sk/7+6bt1u3L6Ps+PHjWV9f3+sxAAC2VVV/upN1Xr4EABhAlAEADCDKAAAGEGUAAAOIMgCA\nAUQZAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgD\nABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOI\nMgCAAUQZAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAK4myqrq7qp6p\nqotV9eAW56uqPrI4/1RV3XbF+UNV9aWq+uwq5gEA2G+WjrKqOpTkoST3JDmV5P6qOnXFsnuSnFzc\nziT56BXn35fk6WVnAQDYr1bxTNntSS5297Pd/XKSR5OcvmLN6SSf6A2PJ7mxqo4kSVUdS/LTST62\nglkAAPalVUTZ0STPbTp+fvHYTtf8epIPJPnuCmYBANiX9vSN/lX1jiQvdveTO1h7pqrWq2r98uXL\nuzAdAMDuWUWUvZDklk3HxxaP7WTN25O8s6r+JBsve/5kVf3mVp+ku89291p3r918880rGBsAYI5V\nRNkTSU5W1YmquiHJfUnOXbHmXJJ3LX4L844k3+7uS939we4+1t3HF9f9fnf/3ApmAgDYVw4v+wG6\n+5WqeiDJY0kOJfl4d1+oqvcuzj+c5HySe5NcTPJSkvcs+3kBAA6S6u69nuF1W1tb6/X19b0eAwBg\nW1X1ZHevbbfOX/QHABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgDABhAlAEA\nDCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZ\nAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgDABhA\nlAEADCDKAAAGEGUAAAOIMgCAAUQZAMAAogwAYABRBgAwgCgDABhAlAEADCDKAAAGEGUAAAOIMgCA\nAUQZAMAAogwAYABRBgAwgCgDABhAlAEADLCSKKuqu6vqmaq6WFUPbnG+quoji/NPVdVti8dvqao/\nqKqvVdWFqnrfKuYBANhvlo6yqjqU5KEk9yQ5leT+qjp1xbJ7kpxc3M4k+eji8VeS/HJ3n0pyR5J/\nucW1AAAH3iqeKbs9ycXufra7X07yaJLTV6w5neQTveHxJDdW1ZHuvtTdX0yS7v7LJE8nObqCmQAA\n9pVVRNnRJM9tOn4+3x9W266pquNJ3pbkCyuYCQBgXxnxRv+q+qEkv5Xk/d39nVdZc6aq1qtq/fLl\ny7s7IADANbaKKHshyS2bjo8tHtvRmqp6QzaC7JPd/ZlX+yTdfba717p77eabb17B2AAAc6wiyp5I\ncrKqTlTVDUnuS3LuijXnkrxr8VuYdyT5dndfqqpK8htJnu7uX1vBLAAA+9LhZT9Ad79SVQ8keSzJ\noSQf7+4LVfXexfmHk5xPcm+Si0leSvKexeVv
T/LzSb5aVV9ePPah7j6/7FwAAPtJdfdez/C6ra2t\n9fr6+l6PAQCwrap6srvXtls34o3+AADXO1EGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkA\nwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMAGECU\nAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIAB\nRBkAwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMA\nGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMAGGAlUVZVd1fVM1V1saoe3OJ8\nVdVHFuefqqrbdnotAMD1YOkoq6pDSR5Kck+SU0nur6pTVyy7J8nJxe1Mko++jmsBAA68VTxTdnuS\ni939bHe/nOTRJKevWHM6ySd6w+NJbqyqIzu8FgDgwDu8go9xNMlzm46fT/KPd7Dm6A6v3XX/9r9f\nyNe++Z29HgMAuMZOvfVH8m9+5if2eowk++iN/lV1pqrWq2r98uXLez0OAMBKreKZsheS3LLp+Nji\nsZ2secMOrk2SdPfZJGeTZG1trZcb+bVNKWYA4PqximfKnkhysqpOVNUNSe5Lcu6KNeeSvGvxW5h3\nJPl2d1/a4bUAAAfe0s+UdfcrVfVAkseSHEry8e6+UFXvXZx/OMn5JPcmuZjkpSTvea1rl50JAGC/\nqe5r+krgNbG2ttbr6+t7PQYAwLaq6snuXttu3b55oz8AwEEmygAABhBlAAADiDIAgAFEGQDAAKIM\nAGAAUQYAMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwg\nygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDA\nAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQB\nAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABlgqyqrqR6vqc1X1\n9cV/3/Qq6+6uqmeq6mJVPbjp8X9fVX9YVU9V1X+rqhuXmQcAYL9a9pmyB5N8vrtPJvn84vivqapD\nSR5Kck+SU0nur6pTi9OfS/IPuvsfJvmjJB9cch4AgH1p2Sg7neSRxf1HkvzsFmtuT3Kxu5/t7peT\nPLq4Lt39u939ymLd40mOLTkPAMC+tGyUvbm7Ly3ufyvJm7dYczTJc5uOn188dqV/keR3lpwHAGBf\nOrzdgqr6vSRv2eLUhzcfdHdXVV/NEFX14SSvJPnka6w5k+RMktx6661X82kAAMbaNsq6+65XO1dV\nf1ZVR7r7UlUdSfLiFsteSHLLpuNji8e+9zF+Ick7ktzZ3a8add19NsnZJFlbW7uq+AMAmGrZly/P\nJXn34v67k/z2FmueSHKyqk5U1Q1J7ltcl6q6O8kHkryzu19achYAgH1r2Sj71SQ/VVVfT3LX4jhV\n9daqOp8kizfyP5DksSRPJ/lUd19YXP+fkvxwks9V1Zer6uEl5wEA2Je2ffnytXT3XyS5c4vHv5nk\n3k3H55Oc32Ld31vm8wMAHBT+oj8AwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIAB\nRBkAwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMA\nGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gy\nAIABRBkA
wACiDABgAFEGADCAKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABgAFEGADCA\nKAMAGECUAQAMIMoAAAYQZQAAA4gyAIABRBkAwACiDABggKWirKp+tKo+V1VfX/z3Ta+y7u6qeqaq\nLlbVg1uc/+Wq6qq6aZl5AAD2q2WfKXswyee7+2SSzy+O/5qqOpTkoST3JDmV5P6qOrXp/C1J/nmS\nbyw5CwDAvrVslJ1O8sji/iNJfnaLNbcnudjdz3b3y0keXVz3Pf8xyQeS9JKzAADsW8tG2Zu7+9Li\n/reSvHmLNUeTPLfp+PnFY6mq00le6O6vLDkHAMC+dni7BVX1e0nessWpD28+6O6uqh0/21VVb0zy\noWy8dLmT9WeSnEmSW2+9daefBgBgX9g2yrr7rlc7V1V/VlVHuvtSVR1J8uIWy15Icsum42OLx/5u\nkhNJvlJV33v8i1V1e3d/a4s5ziY5myRra2te6gQADpRlX748l+Tdi/vvTvLbW6x5IsnJqjpRVTck\nuS/Jue7+anf/WHcf7+7j2XhZ87atggwA4KBbNsp+NclPVdXXk9y1OE5VvbWqzidJd7+S5IEkjyV5\nOsmnuvvCkp8XAOBA2fbly9fS3X+R5M4tHv9mkns3HZ9Pcn6bj3V8mVkAAPYzf9EfAGAAUQYAMIAo\nAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABhBlAAAD\niDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYA\nMIAoAwAYQJQBAAwgygAABhBlAAADiDIAgAFEGQDAAKIMAGAAUQYAMIAoAwAYQJQBAAwgygAABhBl\nAAADiDIAgAFEGQDAANXdez3D61ZVl5P86Qo/5E1J/nyFH4/Vsj9z2ZvZ7M9c9ma2Ve/P3+7um7db\ntC+jbNWqar271/Z6DrZmf+ayN7PZn7nszWx7tT9evgQAGECUAQAMIMo2nN3rAXhN9mcuezOb/ZnL\n3sy2J/vjPWUAAAN4pgwAYIADH2VVdXdVPVNVF6vqwS3OV1V9ZHH+qaq6bafXspyr3ZuquqWq/qCq\nvlZVF6rqfbs//cG3zL+dxflDVfWlqvrs7k19fVjy+9qNVfXpqvrDqnq6qv7J7k5/8C25P/9q8X3t\nf1bVf62qv7m70x9sO9ibH6+q/1FV/7eqfuX1XLsS3X1gb0kOJflfSf5OkhuSfCXJqSvW3Jvkd5JU\nkjuSfGGn17rt2d4cSXLb4v4PJ/kjezNnfzad/9dJ/kuSz+7113OQbsvuTZJHkvzi4v4NSW7c66/p\nIN2W/N52NMkfJ/mBxfGnkvzCXn9NB+W2w735sST/KMm/S/Irr+faVdwO+jNltye52N3PdvfLSR5N\ncvqKNaeTfKI3PJ7kxqo6ssNruXpXvTfdfam7v5gk3f2XSZ7OxjczVmeZfzupqmNJfjrJx3Zz6OvE\nVe9NVf2tJP8syW8kSXe/3N3/ezeHvw4s9W8nyeEkP1BVh5O8Mck3d2vw68C2e9PdL3b3E0n+6vVe\nuwoHPcqOJnlu0/Hz+f4f3q+2ZifXcvWW2Zv/r6qOJ3lbki+sfMLr27L78+tJPpDku9dqwOvYMntz\nIsnlJP958dLyx6rqB6/lsNehq96f7n4hyX9I8o0kl5J8u7t/9xrOer1Z5uf6rjTBQY8yDrCq+qEk\nv5Xk/d39nb2ehw1V9Y4kL3b3k3s9C9/ncJLbkny0u9+W5P8k8X7ZIarqTdl49uVEkrcm+cGq+rm9\nnYrddNCj7IUkt2w6PrZ4bCdrdnItV2+ZvUlVvSEbQfbJ7v7MNZzzerXM/rw9yTur6k+y8RT/T1bV\nb167Ua87y+zN80me7+7vPbP86WxEGquzzP7cleSPu/tyd/9Vks8k+afXcN
brzTI/13elCQ56lD2R\n5GRVnaiqG5Lcl+TcFWvOJXnX4rdh7sjG08WXdngtV++q96aqKhvviXm6u39td8e+blz1/nT3B7v7\nWHcfX1z3+93t//ZXZ5m9+VaS56rq7y/W3Znka7s2+fVhmZ8730hyR1W9cfF97s5svGeW1Vjm5/qu\nNMHhVX/ASbr7lap6IMlj2fjNiY9394Wqeu/i/MNJzmfjN2EuJnkpyXte69o9+DIOpGX2JhvPxPx8\nkq9W1ZcXj32ou8/v5tdwkC25P1xDK9ibX0ryycUPlmdj31ZqyZ87X6iqTyf5YpJXknwp/vL/yuxk\nb6rqLUnWk/xIku9W1fuz8VuW39mNJvAX/QEABjjoL18CAOwLogwAYABRBgAwgCgDABhAlAEADCDK\nAAAGEGUAAAOIMgCAAf4fo1Ld896nlp8AAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"oov_train = inject_OOVs(train)\n",
"oov_vocab = set(oov_train)\n",
"oov_test = replace_OOVs(oov_vocab, test)\n",
"bigram = NGramLM(oov_train,2)\n",
"\n",
"interval = [x/1000.0 for x in range(1, 100, 1)]\n",
"perplexity_at_1 = perplexity(LaplaceLM(bigram, alpha=1.0), oov_test)\n",
"\n",
"def plot_perplexities(interval):\n",
" \"\"\"Plots the perplexity of LaplaceLM for every alpha in interval.\"\"\"\n",
" perplexities = [0.0 for alpha in interval] # todo\n",
" plt.plot(interval, perplexities)\n",
" \n",
"def find_optimal(low, high, epsilon=1e-6):\n",
" \"\"\"Returns the optimal pseudo count alpha within the interval [low, high] and the perplexity.\"\"\"\n",
" print(high, low)\n",
" if high - low < epsilon:\n",
" return 0.0 # todo\n",
" else:\n",
" return 0.0 # todo\n",
"\n",
"plot_perplexities(interval) \n",
"find_optimal(0.0, 1.0)"
]
},
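{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recursive slicing hinted at in the tip amounts to ternary search: evaluate the function at two interior points and discard the third of the interval that cannot contain the minimum of a convex function. Below is a minimal sketch on a generic convex function; `ternary_search` is an illustrative name, and you still need to plug in `LaplaceLM` and `perplexity` yourself:\n",
"```python\n",
"def ternary_search(f, low, high, epsilon=1e-6):\n",
"    \"\"\"Minimise a convex function f on [low, high] to precision epsilon.\"\"\"\n",
"    while high - low > epsilon:\n",
"        third = (high - low) / 3.0\n",
"        m1, m2 = low + third, high - third\n",
"        if f(m1) < f(m2):\n",
"            high = m2  # the minimum cannot lie in (m2, high]\n",
"        else:\n",
"            low = m1  # the minimum cannot lie in [low, m1)\n",
"    return (low + high) / 2.0\n",
"\n",
"ternary_search(lambda x: (x - 0.25) ** 2, 0.0, 1.0)  # close to 0.25\n",
"```\n",
"Each iteration keeps two thirds of the interval, so the number of function evaluations grows only logarithmically in `(high - low) / epsilon`."
]
},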
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 2: Sanity Check LM\n",
"Implement a method that tests whether a language model provides a valid probability distribution."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:19.237379",
"start_time": "2016-10-21T16:59:19.153304"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0647115579930904\n"
]
}
],
"source": [
"def sanity_check(lm, *history):\n",
" \"\"\"Throws an AssertionError if lm does not define a valid probability distribution for all words \n",
" in the vocabulary.\"\"\" \n",
" probability_mass = 1.0 # todo\n",
" assert abs(probability_mass - 1.0) < 1e-6, probability_mass\n",
"\n",
"unigram = NGramLM(oov_train,1)\n",
"stupid = StupidBackoff(bigram, unigram, 0.1)\n",
"print(sum([stupid.probability(word, 'the') for word in stupid.vocab]))\n",
"sanity_check(stupid, 'the')"
]
},
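{
"cell_type": "markdown",
"metadata": {},
"source": [
"The check boils down to summing the probability of every word in the vocabulary for a fixed history and comparing the total to 1. The same idea on a toy word-to-probability dict, independent of the `lm` classes (`is_normalised` is an illustrative name):\n",
"```python\n",
"def is_normalised(probs, tol=1e-6):\n",
"    \"\"\"True if a word -> probability mapping sums to 1 within tol.\"\"\"\n",
"    return abs(sum(probs.values()) - 1.0) < tol\n",
"\n",
"is_normalised({'a': 0.5, 'b': 0.3, 'c': 0.2})  # True\n",
"is_normalised({'a': 0.5, 'b': 0.3})            # False: mass is missing\n",
"```"
]
},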
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 3: Subtract Count LM\n",
"Develop and implement a language model that subtracts a count $d\\in[0,1]$ from each non-zero count in the training set. Let's first formalise this:\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\\begin{align}\n",
"\\#_{w=0}(h_n) &= \\sum_{w \\in V} \\mathbf{1}[\\counts{\\train}{h_n,w} = 0]\\\\\n",
"\\#_{w>0}(h_n) &= \\sum_{w \\in V} \\mathbf{1}[\\counts{\\train}{h_n,w} > 0]\\\\\n",
"\\prob(w|h_n) &= \n",
"\\begin{cases}\n",
"\\frac{\\counts{\\train}{h_n,w} - d}{\\counts{\\train}{h_n}} & \\mbox{if }\\counts{\\train}{h_n,w} > 0 \\\\\\\\\n",
"\\frac{???}{\\counts{\\train}{h_n}} & \\mbox{otherwise}\n",
"\\end{cases}\n",
"\\end{align}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:19.337884",
"start_time": "2016-10-21T16:59:19.240468"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
},
{
"data": {
"text/plain": [
"inf"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class SubtractCount(CountLM): \n",
" def __init__(self, base_lm, d):\n",
" super().__init__(base_lm.vocab, base_lm.order)\n",
" self.base_lm = base_lm\n",
" self.d = d \n",
" self._counts = base_lm._counts # not good style since it is a protected member\n",
" self.vocab = base_lm.vocab\n",
"\n",
" def counts(self, word_and_history):\n",
" if self._counts[word_and_history] > 0:\n",
" return 0.0 # todo\n",
" else:\n",
" return 0.0 # todo\n",
"\n",
" def norm(self, history):\n",
" return self.base_lm.norm(history) \n",
" \n",
"subtract_lm = SubtractCount(unigram, 0.1)\n",
"oov_prob = subtract_lm.probability(OOV, 'the')\n",
"rest_prob = sum([subtract_lm.probability(word, 'the') for word in subtract_lm.vocab])\n",
"print(oov_prob + rest_prob)\n",
"sanity_check(subtract_lm, 'the')\n",
"perplexity(subtract_lm, oov_test)"
]
},
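{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where the subtracted mass can go, here is the same idea on a toy unigram count table: a sketch of absolute discounting with hypothetical names, not the `SubtractCount` class you are asked to write. One standard choice is to spread the freed mass $d \\cdot \\#_{w>0}(h_n)$ evenly over the zero-count words:\n",
"```python\n",
"from collections import Counter\n",
"\n",
"def discounted_probs(counts, vocab, d=0.5):\n",
"    \"\"\"Subtract d from every non-zero count and spread the freed mass\n",
"    evenly over the zero-count words in vocab.\"\"\"\n",
"    total = sum(counts.values())\n",
"    zero = [w for w in vocab if counts[w] == 0]\n",
"    freed = d * sum(1 for w in vocab if counts[w] > 0)\n",
"    return {w: (counts[w] - d) / total if counts[w] > 0\n",
"            else freed / (len(zero) * total)\n",
"            for w in vocab}\n",
"\n",
"probs = discounted_probs(Counter(a=3, b=1), vocab={'a', 'b', 'c'})\n",
"sum(probs.values())  # 1.0: the discounted mass exactly covers the zeros\n",
"```"
]
},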
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 4: Normalisation of Stupid LM\n",
"Develop and implement a version of the [stupid language model](https://github.com/uclmr/stat-nlp-book/blob/python/statnlpbook/lm.py#L205) that provides probabilities summing up to 1."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2016-10-21T16:59:19.398354",
"start_time": "2016-10-21T16:59:19.339446"
},
"run_control": {
"frozen": false,
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
},
{
"data": {
"text/plain": [
"inf"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class StupidBackoffNormalized(LanguageModel):\n",
" def __init__(self, main, backoff, alpha):\n",
" super().__init__(main.vocab, main.order)\n",
" self.main = main\n",
" self.backoff = backoff\n",
" self.alpha = alpha \n",
"\n",
" def probability(self, word, *history):\n",
" return 0.0 # todo\n",
" \n",
"less_stupid = StupidBackoffNormalized(bigram, unigram, 0.1)\n",
"print(sum([less_stupid.probability(word, 'the') for word in less_stupid.vocab]))\n",
"sanity_check(less_stupid, 'the')\n",
"perplexity(less_stupid, oov_test)"
]
}
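,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, normalisation is easy to verify for linear interpolation, a related but different way of combining two models: a convex combination of two distributions is again a distribution. A toy sketch with illustrative names and plain dicts rather than the `lm` classes:\n",
"```python\n",
"def interpolate(p_main, p_backoff, lam):\n",
"    \"\"\"Mix two distributions; the result sums to 1 whenever both inputs do.\"\"\"\n",
"    return {w: lam * p_main.get(w, 0.0) + (1 - lam) * p_backoff[w]\n",
"            for w in p_backoff}\n",
"\n",
"mixed = interpolate({'a': 1.0}, {'a': 0.6, 'b': 0.4}, lam=0.3)\n",
"sum(mixed.values())  # 1.0 up to float rounding\n",
"```\n",
"Stupid backoff itself needs a different correction, since it consults the backoff model only for words unseen after the history."
]
}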
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}