{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from collections import Counter\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy training corpus: each document is a list of whitespace-split tokens.\n",
    "traind = [\n",
    "    \"just plain boring\".split(),\n",
    "    \"entirely predictable and lacks energy\".split(),\n",
    "    \"no surprises and very few laughs\".split(),\n",
    "    \"very powerful\".split(),\n",
    "    \"the most fun film of the summer\".split()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class label of each document in traind, in the same order.\n",
    "# NOTE(review): presumably 0 = negative and 1 = positive sentiment -- confirm.\n",
    "yd = [0, 0, 0, 1, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenized document to classify. \"with\" never occurs in traind, so it is\n",
    "# outside the vocabulary and is skipped at prediction time.\n",
    "testd = [ \"predictable\", \"with\", \"no\", \"fun\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_naive_bayes(X, y):\n",
    "    \"\"\"Fit a multinomial naive Bayes model with add-one (Laplace) smoothing.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : list of list of str\n",
    "        Tokenized training documents.\n",
    "    y : list\n",
    "        Class label of each document, aligned with X by position.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    logpc : dict\n",
    "        class -> log(N_c / N_doc), the log prior of the class.\n",
    "    logpwc : dict\n",
    "        (word, class) -> log((count(word, class) + 1) / (tokens in class + |V|)).\n",
    "    V : set\n",
    "        Every distinct token seen in X.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If X and y have different lengths.\n",
    "    \"\"\"\n",
    "    if len(X) != len(y):\n",
    "        raise ValueError(\"X and y must have the same length\")\n",
    "\n",
    "    n_docs = len(X)\n",
    "    V = set(itertools.chain.from_iterable(X))\n",
    "    logpc = {}\n",
    "    logpwc = {}\n",
    "    for c in set(y):\n",
    "        class_docs = [doc for doc, label in zip(X, y) if label == c]\n",
    "        logpc[c] = np.log(len(class_docs) / n_docs)\n",
    "        # Count the class's tokens once; the counts and the smoothing\n",
    "        # denominator do not depend on the word being scored.\n",
    "        word_counts = Counter(itertools.chain.from_iterable(class_docs))\n",
    "        denom = sum(word_counts.values()) + len(V)\n",
    "        for w in V:\n",
    "            # Counter returns 0 for unseen words, so +1 keeps the log finite.\n",
    "            logpwc[w, c] = np.log((word_counts[w] + 1) / denom)\n",
    "    return logpc, logpwc, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "logpc, logpwc, V = train_naive_bayes(traind, yd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_naive_bayes(testd, logpc, logpwc, y ,V):\n",
    "    \"\"\"Return the most probable class for one tokenized document.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    testd : list of str\n",
    "        Tokens of the document to classify.\n",
    "    logpc : dict\n",
    "        class -> log prior, as returned by train_naive_bayes.\n",
    "    logpwc : dict\n",
    "        (word, class) -> log likelihood, as returned by train_naive_bayes.\n",
    "    y : list\n",
    "        Training labels; only the distinct values are used, to enumerate classes.\n",
    "    V : set\n",
    "        Training vocabulary; tokens outside it are ignored.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The class whose summed log prior and log likelihoods is largest.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If y is empty, so there is no class to choose from.\n",
    "    \"\"\"\n",
    "    scores = {}\n",
    "    for c in set(y):\n",
    "        score = logpc[c]\n",
    "        for w in testd:\n",
    "            if w in V:\n",
    "                score += logpwc[w, c]\n",
    "        scores[c] = score\n",
    "    # argmax over classes: an ascending sort followed by [0] would pick the\n",
    "    # least likely class instead.\n",
    "    return max(scores, key=scores.get)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_naive_bayes(testd, logpc, logpwc, yd, V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}