{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SCat Algorithm\n", "\n", "I have pushed several iterations of a simple sequence labeling algorithm called SCat. This iteration is the first version that will be implemented in Kadot as part of the *Bot Engine*.\n", "\n", "The final algorithm uses a unidirectional LSTM (because we read text in only one direction) to encode features from the text, then decodes those features with a single linear layer to produce the final per-token classification.\n", "\n", "![SCat Algorithm](https://raw.githubusercontent.com/the-new-sky/Kadot/1.0dev/RaD/SCat.jpg \"SCat Algorithm\")\n", "\n", "Let's start by defining our small training dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Find the city in a weather-related query\n", "\n", "train_x = [\n", "    \"What is the weather like in Paris ?\",\n", "    \"What kind of weather will it do in London ?\",\n", "    \"Give me the weather forecast in Berlin please .\",\n", "    \"Tell me the forecast in New York !\",\n", "    \"Give me the weather in San Francisco ...\",\n", "    \"I want the forecast in Dublin .\"\n", "]\n", "\n", "train_y = [\n", "    ('Paris',),\n", "    ('London',),\n", "    ('Berlin',),\n", "    ('New', 'York'),\n", "    ('San', 'Francisco'),\n", "    ('Dublin',)\n", "]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Split on single spaces; punctuation is already space-separated in the data\n", "tokenizer = lambda x: x.split(' ')\n", "# The trailing ' ' adds the empty string '' to the vocabulary, which is used for unknown tokens\n", "vocabulary = sorted(set(tokenizer(' '.join(train_x) + ' ')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now import all the modules we will need." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable\n", "from torch import nn, optim\n", "import torch.nn.functional as F\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's now time to write the network described above."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class SCatNetwork(nn.Module):\n", "    def __init__(self, vocabulary_size, embedding_dimension, hidden_size):\n", "        super(SCatNetwork, self).__init__()\n", "\n", "        self.embeddings = nn.Embedding(vocabulary_size, embedding_dimension)\n", "        self.encoder = nn.LSTM(  # an LSTM layer to encode features\n", "            embedding_dimension,\n", "            hidden_size,\n", "            batch_first=True,\n", "        )\n", "        # Two classes per token: 0 (outside) and 1 (part of the target)\n", "        self.decoder = nn.Linear(hidden_size, 2)\n", "\n", "    def forward(self, inputs):\n", "        # Initial hidden and cell states, shaped (num_layers, batch, hidden_size)\n", "        state_shape = (1, inputs.size(0), self.encoder.hidden_size)\n", "        hc = (Variable(torch.ones(*state_shape)), Variable(torch.ones(*state_shape)))\n", "\n", "        outputs = self.embeddings(inputs)\n", "        outputs, _ = self.encoder(outputs, hc)\n", "        outputs = self.decoder(outputs)\n", "        return outputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also need a function to transform a sentence into a LongTensor of vocabulary indices." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def vectorizer(sentence):\n", "    tokens = tokenizer(sentence)\n", "    # Map each token to its vocabulary index; unknown tokens fall back to the '' entry\n", "    vector = [vocabulary.index(token) if token in vocabulary else vocabulary.index('') for token in tokens]\n", "\n", "    return torch.LongTensor(vector)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before training, let's test the whole network." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "What is the weather like in Paris ?\n", "Variable containing:\n", " 15 20 27 29 23 19 12 4\n", "[torch.LongTensor of size 1x8]\n", "\n", "SCatNetwork(\n", "  (embeddings): Embedding(31, 20)\n", "  (encoder): LSTM(20, 10, batch_first=True)\n", "  (decoder): Linear(in_features=10, out_features=2, bias=True)\n", ")\n", "\n" ] }, { "data": { "text/plain": [ "Variable containing:\n", "(0 ,.,.) 
= \n", " 0.1471 0.0173\n", " 0.0750 0.0669\n", " -0.0199 -0.0658\n", " -0.0773 -0.2239\n", " 0.0530 -0.2327\n", " -0.1172 -0.2223\n", " -0.0828 -0.2067\n", " 0.0318 -0.1589\n", "[torch.FloatTensor of size 1x8x2]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(train_x[0])\n", "\n", "test_vec = Variable(vectorizer(train_x[0]).view(1, len(tokenizer(train_x[0]))))\n", "test_net = SCatNetwork(len(vocabulary), 20, 10)\n", "\n", "print(test_vec)\n", "print(str(test_net))\n", "print()\n", "\n", "test_net(test_vec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's train the network with these hyperparameters:\n", "\n", "- **epochs**: 800\n", "- **learning rate**: 0.01\n", "- **optimizer**: Adam\n", "- **loss function**: Cross-Entropy\n", "- **embedding dimension**: 20\n", "- **hidden layer size**: 10\n", "- **min probability tolerance**: 0.0001\n", "- **mean output ratio**: 75%\n", "- **N-gram range**: 1 to 10" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 - Loss : 0.638474315404892\n", "Epoch 80 - Loss : 0.0003119074899586849\n", "Epoch 160 - Loss : 9.471087954201114e-05\n", "Epoch 240 - Loss : 4.483968010996856e-05\n", "Epoch 320 - Loss : 2.5974186125192016e-05\n", "Epoch 400 - Loss : 1.6693455563654425e-05\n", "Epoch 480 - Loss : 1.1401202148893693e-05\n", "Epoch 560 - Loss : 8.116763865473331e-06\n", "Epoch 640 - Loss : 5.943054247836699e-06\n", "Epoch 720 - Loss : 4.4373889143874594e-06\n", "Epoch 800 - Loss : 3.360752922768976e-06\n" ] } ], "source": [ "n_epoch = 801  # epochs 0..800, so that epoch 800 is logged\n", "learning_rate = 0.01\n", "mean_ratio = 0.75\n", "min_tolerance = 1e-04\n", "n_range = (1, 10 + 1)  # n-gram sizes 1..10 (the upper bound is exclusive)\n", "\n", "model = SCatNetwork(len(vocabulary), 20, 10)\n", "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "for epoch in range(n_epoch):\n", "    epoch_losses = []\n", "\n", "    for sentence, goal in zip(train_x, train_y):\n", "        sentence_length = len(tokenizer(sentence))\n", "        # One label per token: 1 if the token belongs to the target, 0 otherwise\n", "        labels = [1 if word in goal else 0 for word in tokenizer(sentence)]\n", "\n", "        x = Variable(vectorizer(sentence).view(1, sentence_length))\n", "        y = Variable(torch.LongTensor(labels))\n", "\n", "        model.zero_grad()\n", "\n", "        preds = model(x)[0]  # drop the batch dimension: (sentence_length, 2)\n", "        loss = criterion(preds, y)\n", "        epoch_losses.append(float(loss))\n", "\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    if epoch % 80 == 0:\n", "        mean_loss = torch.FloatTensor(epoch_losses).mean()\n", "        print(\"Epoch {} - Loss : {}\".format(epoch, float(mean_loss)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's test the model on a sentence containing a city it has never seen: *Give me the latest weather forecast in Los Angeles*" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "x = \"Give me the latest weather forecast in Los Angeles\"\n", "tokens = tokenizer(x)\n", "x_vec = Variable(vectorizer(x).view(1, len(tokens)))\n", "\n", "# Probability of the positive class (label 1) for each token\n", "pred = F.softmax(model(x_vec), dim=2)[0, :, 1].data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualize the network's output (the per-token probability of the positive class), then apply a special transformation to it: subtract the minimum tolerance and 75% of the mean, then divide by the standard deviation."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEACAYAAACj0I2EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAADwFJREFUeJzt3G2spGV9x/Hvb1mhIj5VW2l3AatosbxBbZCotWMg3dUSNk36wNpGaoxt0lKNjQ3ENxzSV9g01kYTNaVUBV0j1kobrKulY2oqsBUouu6yq0TYBSQVn9KamhX+fTE329nj7MwszDn37LXfT7I5M3Oumfufs2e/e881Z06qCknS8W9D3wNIkhbDoEtSIwy6JDXCoEtSIwy6JDXCoEtSI2YGPcm1SR5OcveUNX+dZH+Su5Kct9gRJUnzmOcM/Tpgy9E+meR1wAur6kXAHwLvX9BskqRjMDPoVfVF4LtTlmwDPtytvQ14ZpLnLWY8SdK8FrGHvgk4MHb9ge42SdI6WkTQM+E2f5+AJK2zjQt4jIPAGWPXNwMPTlqYxNBL0hNQVZNOno8w7xl6mHwmDnAT8EaAJBcA36uqh6cMtVR/rrrqqt5nOB5mWta5nOn4nakrQg9/pndoeb9Ws808Q0/yUWAAPCfJ/cBVwMmjr0l9sKpuTvL6JF8H/gd409xHlyQtzMygV9Ub5lhz+WLGkSQ9USf8O0UHg0HfI/yEZZwJlnMuZ5rPMs60rI7nr1WOZX/mSR8sqfU8nqTllYR+fiAux7QvvQySUAt8UVSStOQMuiQ1wqBLUiMMuiSNOf3055NkXf+cfvrzFzK7L4pK6sWyvijaz1yzZ/JFUUk6gRh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRswV9CRbk+xNsi/JFRM+f0aSW5LckeSuJK9b/KiSpGlSVdMXJBuAfcCFwIPALuDSqto7tuYDwB1V9YEkLwFurqpfmPBYNet4kk4MSYA+ehCmdaifuWbPVFWZ9SjznKGfD+yvqvuq6hCwA9i2as1jwDO6y88CHpjjcSVJC7RxjjWbgANj1w8yivy4q4GdSd4KnApctJjxJEnzmifok07zVz832A5cV1XvTnIBcD1w7qQHW1lZOXx5MBgwGAzmGlSSThTD4ZDhcHjM95tnD/0CYKWqtnbXrwSqqq4ZW/NVYEtVPdBd/wbwiqr69qrHcg9dEuAe+qqjrtse+i7g7CRnJTkZuBS4adWa++i2WboXRU9ZHXNJ0tqaGfSqehS4HNgJ7AZ2VNWeJFcnubhb9g7gLUnuAm4ALlurgSVJk83cclnowdxykdRxy+WIo67blosk6Thg0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhoxV9CTbE2yN8m+JFccZc1vJ9md5CtJrl/smJKkWVJV0xckG4B9wIXAg8Au4NKq2ju25mzg48Brq+oHSZ5bVd+e8Fg163iSTgxJgD56EKZ1qJ+5Zs9UVZn1KPOcoZ8P7K+q+6rqELAD2LZqzVuA91XVDwAmxVyStLbmCfom4MDY9YPdbeNeDPxiki
8m+fckWxY1oCRpPhvnWDPpNH/1c4ONwNnAa4AzgX9Lcu7jZ+ySpLU3T9APMor04zYz2ktfveZLVfUY8M0k9wAvAr68+sFWVlYOXx4MBgwGg2ObWJIaNxwOGQ6Hx3y/eV4UPQm4h9GLog8BtwPbq2rP2Jot3W2/n+S5jEJ+XlV9d9Vj+aKoJMAXRVcddX1eFK2qR4HLgZ3AbmBHVe1JcnWSi7s1nwUeSbIb+BfgHatjLklaWzPP0Bd6MM/QJXU8Qz/iqOv2Y4uSpOOAQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRswV9CRbk+xNsi/JFVPW/WaSx5K8bHEjSpLmMTPoSTYA7wW2AOcC25OcM2HdacCfALcuekhJ0mzznKGfD+yvqvuq6hCwA9g2Yd2fA9cAP1rgfJKkOc0T9E3AgbHrB7vbDktyHrC5qm5e4GySpGOwcY41mXBbHf5kEuDdwGUz7iNJWkPzBP0gcObY9c3Ag2PXn85ob33Yxf104NNJLqmqO1Y/2MrKyuHLg8GAwWBw7FNLUsOGwyHD4fCY75eqmr4gOQm4B7gQeAi4HdheVXuOsv5fgT+tqjsnfK5mHU/SiWF0/tdHD8K0DvUz1+yZqmrmzsfMPfSqehS4HNgJ7AZ2VNWeJFcnuXjSXXDLRZLW3cwz9IUezDN0SR3P0I846vqcoUuSjg8GXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaMVfQk2xNsjfJviRXTPj825PsTnJXks8lOWPxo0qSppkZ9CQbgPcCW4Bzge1Jzlm17A7g5VV1HvBJ4C8WPagkabp5ztDPB/ZX1X1VdQjYAWwbX1BVX6iq/+2u3gpsWuyYkqRZ5gn6JuDA2PWDTA/2m4HPPJmhJEnHbuMcazLhtpq4MPk94OXArx7twVZWVg5fHgwGDAaDOUaQpBPHcDhkOBwe8/1SNbHN/78guQBYqaqt3fUrgaqqa1atuwh4D/CaqnrkKI9Vs44n6cSQhKOcG671kZnWoX7mmj1TVU06uT7CPFsuu4Czk5yV5GTgUuCmVQd7KfB+4JKjxVyStLZmBr2qHgUuB3YCu4EdVbUnydVJLu6WvQt4GvCJJHcm+Yc1m1iSNNHMLZeFHswtF0kdt1yOOOq6bblIko4DBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRcwU9ydYke5PsS3LFhM+fnGRHkv1JvpTkzMWPKkmaZmbQk2wA3gtsAc4Ftic5Z9WyNwPfqaoXAX8FvGvRg66V4XDY9wg/YRlnguWcy5nms4wzafHmOUM/H9hfVfdV1SFgB7Bt1ZptwIe6yzcCFy5uxLW1jN/oyzgTLOdczjSfZZxJizdP0DcBB8auH+xum7imqh4FvpfkpxcyoSRpLvMEPRNuqxlrMmGNJGkNpWp6d5NcAKxU1dbu+pVAVdU1Y2s+0625LclJwENV9bMTHsvIS9ITUFWTTq6PsHGOx9kFnJ
3kLOAh4FJg+6o1/whcBtwG/BZwyxMdSJL0xMwMelU9muRyYCejLZprq2pPkquBXVX1T8C1wEeS7AceYRR9SdI6mrnlIkk6PqzbO0VnvTlpvSW5NsnDSe7ue5bHJdmc5JYkX0vylSRvXYKZTklyW5I7u5mu6numxyXZkOSOJDf1PcvjknwzyX92X6/b+54HIMkzk3wiyZ4ku5O8oud5Xtx9fe7oPn5/Sb7X357kq0nuTnJDkpOXYKa3df/u5utBVa35H0b/cXwdOAt4CnAXcM56HHvKTK8GzgPu7nOOVTOdDpzXXT4NuKfvr1M3y6ndx5OAW4Hz+56pm+ftwPXATX3PMjbTvcCz+55j1Ux/B7ypu7wReEbfM43NtgF4EDij5zl+vvu7O7m7/nHgjT3PdC5wN3BK92/vc8ALp91nvc7Q53lz0rqqqi8C3+1zhtWq6ltVdVd3+b+BPfzkz/yvu6r6YXfxFEZB6H2fLslm4PXA3/Q9yyphiX5HUpKnA79SVdcBVNWPq+oHPY817iLgG1V1YObKtXcS8LQkG4FTGf1H06eXALdW1Y9q9P6eLwC/Me0O6/WNN8+bkzQmyfMZPYO4rd9JDm9t3Al8C/hcVe3qeybg3cCfsQT/uaxSwGeT7Erylr6HAV4AfDvJdd0WxweTPLXvocb8DvCxvoeoqgeBvwTuBx4AvldVn+93Kr4KvCbJs5OcyugE5oxpd1ivoM/z5iR1kpzG6FcovK07U+9VVT1WVS8FNgOvSPJLfc6T5NeBh7tnM2Hy91dfXllVv8zoH98fJ3l1z/NsBF4GvK+qXgb8ELiy35FGkjwFuAT4xBLM8ixGuwZnMdp+OS3JG/qcqar2AtcAnwduZrRV/eNp91mvoB8Exn8D42b6fzqzlLqnezcCH6mqT/c9z7juqfoQ2NrzKK8CLklyL6Ozu9cm+XDPMwGjbbPu438Bn2K03ding8CBqvqP7vqNjAK/DF4HfLn7WvXtIuDeqvpOt73x98Are56Jqrquql5eVQNGW8T7p61fr6AffnNS98rxpcAy/GTCsp3dAfwt8LWqek/fgwAkeW6SZ3aXn8roG39vnzNV1Tur6syqegGj76VbquqNfc4EkOTU7tkVSZ4G/Bqjp829qaqHgQNJXtzddCHwtR5HGredJdhu6dwPXJDkp5KE0ddpT88zkeRnuo9nMto/n/r1muedok9aHeXNSetx7KNJ8lFgADwnyf3AVY+/cNTjTK8Cfhf4SrdnXcA7q+qfexzr54APdb9GeQPw8aq6ucd5ltnzgE91v+JiI3BDVe3seSaAtwI3dFsc9wJv6nme8ZODP+h7FoCquj3JjcCdwKHu4wf7nQqAT3a/6PAQ8EdV9f1pi31jkSQ1Yml+vEqS9OQYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqxP8BAAaXAafVOMIAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXwAAAEACAYAAACwB81wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAD2RJREFUeJzt3X+s3Xddx/HnayubjunkZyvt1sov+RHNADMqoF4DkW0kqxgJIAkwjRBlGcGEgISk5U9MDD/EBKdzbAgOmPzYcMrAeTSobBVWO1wLxcHWMlp/bMPAjJb69o/z7bzr7u09t+fb+/2Wz/ORnPT8+NzzeeXce1/ncz7nfk9TVUiSvv+dNnQASdLasPAlqREWviQ1wsKXpEZY+JLUCAtfkhoxd+En2ZTk5iR3JLk9yeXLjHtvkn1JdiU5f955JUmrs66H+/ge8FtVtSvJ2cAXk9xUVXuPDkhyEfCkqnpKkucC7we29jC3JGlGc6/wq+pgVe3qzn8H2ANsPGbYNuCabswtwDlJ1s87tyRpdr3u4SfZApwP3HLMTRuB/Ysuf5OHPylIkk6i3gq/2865Dnhjt9J/yM1LfImf6SBJa6iPPXySrGNa9h+sqk8tMeQAcO6iy5uAe5a4H58EJOkEVNVSC+uH6GuF/8fAHVX1nmVuvx54NUCSrcD9VXVoqYFVNbrT9u3bB89gJjO1mMtMs51mNfcKP8nzgVcBtye5jelWzduAzdP+riuq6sYkFyf5GvBd4NJ555Ukrc7chV9VfwecPsO4y+adS5J04jzSdgYLCwtDR3gYM83GTLMbYy4z9Sur2f852ZLUmPJIGtaGDVs4dOiuNZ93/frNHDz4jTWf90QloWZ409bClzRaSRjmL7izqjdDhzZr4bulI0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SVqFDRu2kGTNTxs2bJk7uwdeSRqtMR54NdZMHnglSXqQhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSI3op/CRXJjmUZPcyt/9ckvuTfKk7vb2PeSVJs1vX0/1cBfwecM1xxvxtVV3S03ySpFXqZYVfVZ8H7lthWPqYS5J0YtZyD39rktuS/HmSZ6zhvJIk+tvSWckXgc1V9UCSi4BPAk9dauCOHTsePL+wsMDCwsJa5JOkU8ZkMmEymaz661JVvQRIshm4oap+coaxXweeU1X3HnN99ZVH0qkvCTBEJ4TlumismapqxW3zPrd0wjL79EnWLzp/AdMnmnuXGitJOjl62dJJ8mFgAXhMkruB7cAZQFXVFcAvJ/kN4DDwX8DL+5hXkjS73rZ0+uCWjqTFxrp9MsZMa72lI0kaMQtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiN6KfwkVyY5lGT3cca8N8m+JLuSnN/HvJKk2fW1wr8KePFyNya5CHhSVT0FeD3w/p7mlSTNqJfCr6rPA/cdZ8g24Jpu7C3AOUnW9zG3JGk2a7WHvxHYv+jyN7vrJElrZK0KP0tcV2s0tyQJWLdG8xwAzl10eRNwz1IDd+zY8eD5hYUFFhYWTmYuSTrlTCYTJpPJqr8uVf0stJNsAW6oqp9Y4raLgTdU1UuSbAXeXVVblxhXfeWRdOpLwjCbAWG5LhprpqpaaiflIXpZ4Sf5MLAAPCbJ3cB24AygquqKqroxycVJvgZ8F7i0j3klSbPrbYXfB1f4khYb62p6jJlmWeF7pK0kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpe
kRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGtFL4Se5MMneJF9N8pYlbn9Nkn9N8qXu9Kt9zCtJmt26ee8gyWnA+4AXAvcAO5N8qqr2HjP02qq6fN75JEknpo8V/gXAvqq6q6oOA9cC25YYlx7mkiSdoD4KfyOwf9HlA911x/qlJLuSfDTJph7mlSStQh+Fv9TKvY65fD2wparOB/4KuLqHeSVJqzD3Hj7TFf15iy5vYrqX/6Cqum/RxT8E3rncne3YsePB8wsLCywsLPQQUZK+f0wmEyaTyaq/LlXHLsZXeQfJ6cBXmL5p+y3gVuCVVbVn0ZgNVXWwO/9S4M1V9bwl7qvmzSPp+0cSHr5hsCYzs1wXjTVTVa34PuncK/yqOpLkMuAmpltEV1bVniTvAHZW1aeBy5NcAhwG7gVeO++8kqTVmXuF3ydX+JIWG+tqeoyZZlnhe6StJDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhrRS+EnuTDJ3iRfTfKWJW4/I8m1SfYl+Yck5/UxryRpdnMXfpLTgPcBLwaeCbwyydOOGfZrwL1V9RTg3cDvzDuvJGl1+ljhXwDsq6q7quowcC2w7Zgx24Cru/PXAS/sYV5J0ir0Ufgbgf2LLh/orltyTFUdAe5P8uge5pYkzWhdD/eRJa6rFcZkiTHTG7LU3Z1c69dv5uDBbyx7+4YNWzh06K61C8Q4M8Hxc5np/51qmWCcP+fr12/m0KFhOuF4tw2daTKZMJlMVn0fqVqyd2e/g2QrsKOqLuwuvxWoqnrnojF/0Y25JcnpwLeq6vFL3Fct8zxwkoXjPQ7TJ6G1zjXGTHC8XGZ6yMynVCYY58+5ZpOEqlrxWaiPLZ2dwJOTbE5yBvAK4PpjxtwAvKY7/zLg5h7mlSStwtxbOlV1JMllwE1Mn0CurKo9Sd4B7KyqTwNXAh9Msg/4D6ZPCpKkNTT3lk6f3NJ5yKwjzASn2laFmR4y8wh/ptzS6cNabulIkk4BFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaMcLCz5qf1q/ffNxE09vNtFIuM526mYbKtVIm9StVNXSGByWpMeWRpFNBEqoqK40b4QpfknQyWPiS1AgLX5IaYeFLUiPmKvwkj0pyU5KvJPlMknOWGXckyZeS3Jbkk/PMKUk6MfOu8N8KfK6qfhy4GfjtZcZ9t6qeXVXPqqpfnHPONTeZTIaO8DBmmo2ZZjfGXGbq17yFvw24ujt/NbBcma/450JjNsZvsJlmY6bZjTGXmfo1b+E/vqoOAVTVQeBxy4w7M8mtSf4+ybY555QknYB1Kw1I8llg/eKrgALevop5zquqg0l+DLg5ye6q+vrqokqS5jHXkbZJ9gALVXUoyQbgr6vq6St8zVXADVX18SVu8zBbSToBsxxpu+IKfwXXA68F3gm8BvjUsQOS/AjwQFX9T5LHAs/rxj/MLIElSSdm3hX+o4GPAucCdwMvq6r7kzwHeH1VvS7JTwN/ABxh+p7Bu6rqA3MnlyStyqg+PE2SdPKM5kjbJBcm2Zvkq0n
eMnQegCRXJjmUZPfQWQCSbEpyc5I7ktye5PKhMwEkOTPJLd2Bdbcn2T50pqOSnNYd9Hf90FkAknwjyT91j9WtQ+cBSHJOko8l2ZPkn5M8d+A8T+0en6MHa357RD/rb0ry5SS7k3woyRkjyPTG7vdu5U6oqsFPTJ94vgZsBh4B7AKeNoJcLwDOB3YPnaXLswE4vzt/NvCVMTxOXZ6zun9PB74AXDB0pi7Pm4A/Aa4fOkuX507gUUPnOCbTB4BLu/PrgB8eOtOibKcB9wDnjiDLE7rv3xnd5Y8Arx440zOB3cCZ3e/eZ4EnLTd+LCv8C4B9VXVXVR0GrmV6UNegqurzwH1D5ziqqg5W1a7u/HeAPcDGYVNNVdUD3dkzmZbG4HuFSTYBFwN/NHSWRcK4Xln/EPAzVXUVQFV9r6r+c+BYi70I+Jeq2j90kM7pwCOTrAPOYvpkNKSnA1+oqv+uqiPA3wAvXW7wWH7wNgKLv6EHGEmRjVWSLUxffdwybJKpbuvkNuAg8Nmq2jl0JuBdwJsZwZPPIgV8JsnOJL8+dBjgicC/J7mq20K5IskPDh1qkZcDfzp0CICqugf4XaZ/oPJN4P6q+tywqfgy8LPd55qdxXSBc+5yg8dS+Ev9OeaYfklHJcnZwHXAG7uV/uCq6n+r6lnAJuC5SZ4xZJ4kLwEOda+Ijv6femPwvKr6Kaa/mG9I8oKB86wDng38flU9G3iA6WdkDS7JI4BLgI8NnQUe/BPzbUy3np8AnJ3kV4bMVFV7mf6Z++eAG5luh39vufFjKfwDwHmLLm9i+JdKo9S9lLwO+GBVPey4h6F12wET4MKBozwfuCTJnUxXiD+f5JqBM1HTjyChqv4N+ATT7cwhHQD2V9U/dpevY/oEMAYXAV/sHqsxeBFwZ1Xd222ffJzpcUWDqqqrquo5VbXAdAt633Jjx1L4O4EnJ9ncvev9CqYHdY3BmFaHAH8M3FFV7xk6yFFJHnv0o7G77YAXAXuHzFRVb6uq86rqiUx/nm6uqlcPmSnJWd2rM5I8EvgFpi/JB1PTz8Lan+Sp3VUvBO4YMNJir2Qk2zmdu4GtSX4gSZg+VnsGzkSSx3X/nsd0/37Zx2zeI217UVVHklwG3MT0SejKqhrDA/lhYAF4TJK7ge1H39waKM/zgVcBt3f75QW8rar+cqhMnR8Frk5yGtPv30eq6saBM43ReuAT3UeIrAM+VFU3DZwJ4HLgQ90Wyp3ApQPnWbxweN3QWY6qqluTXAfcBhzu/r1i2FQA/Fl3EOxh4Der6tvLDfTAK0lqxFi2dCRJJ5mFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSI/4Pd5ssr+xQBXYAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.figure()\n", "plt.bar(range(len(pred)), pred.tolist())\n", "\n", "# Apply the special transformation\n", "pred = (pred - min_tolerance - mean_ratio * pred.mean()) / pred.std()\n", "\n", "plt.figure()\n", "plt.bar(range(len(pred)), pred.tolist())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Give', -0.37811943888664246),\n", " ('me', -0.37811994552612305),\n", " ('the', 
-0.37812015414237976),\n", " ('latest', -0.37789979577064514),\n", " ('weather', -0.3781139552593231),\n", " ('forecast', -0.37811869382858276),\n", " ('in', -0.37810349464416504),\n", " ('Los', 1.9355530738830566),\n", " ('Angeles', 1.8426263332366943)]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word_with_scores = list(zip(tokens, pred.tolist()))\n", "word_with_scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, find the n-gram with the highest total score in the sentence. Because the tokens outside the city name have negative scores after the transformation, the best n-gram covers only the positively scored tokens." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Los', 1.9355530738830566), ('Angeles', 1.8426263332366943)]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Enumerate every contiguous n-gram (n = 1..10) of (word, score) pairs\n", "grams_with_scores = sum([list(zip(*[word_with_scores[i:] for i in range(n)])) for n in range(*n_range)], [])\n", "# Fallback: an empty gram with score 0 wins when no n-gram has a positive total\n", "grams_with_scores.append([('', 0)])\n", "\n", "summed_gram_scores = [sum(list(zip(*gram))[1]) for gram in grams_with_scores]\n", "\n", "best_gram = list(grams_with_scores[summed_gram_scores.index(max(summed_gram_scores))])\n", "best_gram" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }