{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Beam search decoder\n",
  "\n",
  "From this blog: [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/)\n",
  "\n",
  "Sequences are scored by their summed negative log-probability, so a lower score means a more likely sequence."
 ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
  "import math\n",
  "import numpy as np\n",
  "from pprint import pprint"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "# Toy input: 10 steps, 5 values per step (used as per-index probabilities by the decoders).\n",
  "data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],\n",
  "                 [0.5, 0.4, 0.3, 0.2, 0.1],\n",
  "                 [0.1, 0.2, 0.3, 0.4, 0.5],\n",
  "                 [0.5, 0.4, 0.3, 0.2, 0.1],\n",
  "                 [0.1, 0.2, 0.3, 0.4, 0.5],\n",
  "                 [0.5, 0.4, 0.3, 0.2, 0.1],\n",
  "                 [0.1, 0.2, 0.3, 0.4, 0.5],\n",
  "                 [0.5, 0.4, 0.3, 0.2, 0.1],\n",
  "                 [0.1, 0.2, 0.3, 0.4, 0.5],\n",
  "                 [0.5, 0.4, 0.3, 0.2, 0.1]])"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "def greedy_decoder(data):\n",
  "    \"\"\"Return, for each row of `data`, the index of its largest value.\"\"\"\n",
  "    return [np.argmax(s) for s in data]"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
  "# beam search\n",
  "def beam_search_decoder(data, k, verbose=True):\n",
  "    \"\"\"Return the `k` best index sequences through `data` with their scores.\n",
  "\n",
  "    Each row of `data` holds one value per index for a single step; the\n",
  "    values are used as probabilities. A sequence's score is the sum of the\n",
  "    negative log-probabilities of its indices, i.e. -log of the product of\n",
  "    the chosen values, so a lower score is better.\n",
  "\n",
  "    Parameters:\n",
  "        data: iterable of rows, each row an iterable of probabilities.\n",
  "        k: beam width, the number of sequences kept after every step.\n",
  "        verbose: when true (default), print the surviving beam after each step.\n",
  "\n",
  "    Returns:\n",
  "        A list of at most `k` `[sequence, score]` pairs, best (lowest score) first.\n",
  "    \"\"\"\n",
  "    eps = 1e-100  # keeps math.log defined when a probability is exactly 0\n",
  "    # Start from the empty sequence; 0.0 is the identity for summed log scores.\n",
  "    sequences = [[[], 0.0]]\n",
  "    for row in data:\n",
  "        # Expand every surviving sequence by every possible next index.\n",
  "        # Scores are added, not multiplied: multiplying negative logs does not\n",
  "        # track the joint probability and collapses to 0 when a probability is 1.\n",
  "        all_candidates = [\n",
  "            [seq + [j], score - math.log(prob + eps)]\n",
  "            for seq, score in sequences\n",
  "            for j, prob in enumerate(row)\n",
  "        ]\n",
  "        # Keep the k lowest-scoring candidates; sorted() is stable, so ties\n",
  "        # keep their expansion order.\n",
  "        sequences = sorted(all_candidates, key=lambda cand: cand[1])[:k]\n",
  "        if verbose:\n",
  "            print(\"=\"*50)\n",
  "            pprint((\"sequence: \", sequences))\n",
  "    return sequences"
 ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "greedy_decoder(data)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [
  "beam_search_decoder(data, 3)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }
 ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 