{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.special import logsumexp\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.naive_bayes import MultinomialNB as skMultinomialNB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultinomialNB:\n",
    "    \"\"\"Multinomial naive Bayes classifier for dense feature matrices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    alpha : float, default=1.0\n",
    "        Additive smoothing constant added to every per-class feature count.\n",
    "\n",
    "    Attributes (set by ``fit``)\n",
    "    ---------------------------\n",
    "    classes_ : ndarray, shape (n_classes,)\n",
    "        Sorted unique labels found in ``y``.\n",
    "    feature_count_ : ndarray, shape (n_classes, n_features)\n",
    "        Sum of every feature over the training rows of each class.\n",
    "    class_count_ : ndarray, shape (n_classes,)\n",
    "        Number of training rows in each class.\n",
    "    feature_log_prob_ : ndarray, shape (n_classes, n_features)\n",
    "        Log of the smoothed, row-normalised feature counts.\n",
    "    class_log_prior_ : ndarray, shape (n_classes,)\n",
    "        Log of the empirical class frequencies.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, alpha=1.0):\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def _encode(self, y):\n",
    "        \"\"\"Return ``(classes, one_hot)`` for the 1-D label sequence ``y``.\n",
    "\n",
    "        ``classes`` holds the sorted unique labels; ``one_hot`` is a float\n",
    "        matrix of shape (n_samples, n_classes) with a 1 in the column of\n",
    "        each sample's class.\n",
    "        \"\"\"\n",
    "        # Lists and tuples have no .shape and compare as a whole with ==,\n",
    "        # so coerce to an ndarray before doing anything element-wise.\n",
    "        y = np.asarray(y)\n",
    "        classes = np.unique(y)\n",
    "        # (n_samples, 1) == (n_classes,) broadcasts to the indicator matrix.\n",
    "        y_train = (y[:, np.newaxis] == classes).astype(np.float64)\n",
    "        return classes, y_train\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Estimate feature log-probabilities and class log-priors.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_features)\n",
    "            Dense feature matrix; it is passed straight to ``np.dot``.\n",
    "        y : array-like, shape (n_samples,)\n",
    "            One class label per row of ``X``.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self\n",
    "        \"\"\"\n",
    "        self.classes_, y_train = self._encode(y)\n",
    "        # One-hot transpose times X sums the feature rows belonging to each class.\n",
    "        self.feature_count_ = np.dot(y_train.T, X)\n",
    "        self.class_count_ = y_train.sum(axis=0)\n",
    "        smoothed_fc = self.feature_count_ + self.alpha\n",
    "        smoothed_cc = smoothed_fc.sum(axis=1)\n",
    "        # Normalise in log space: log(count) - log(row total).\n",
    "        self.feature_log_prob_ = (np.log(smoothed_fc) -\n",
    "                                  np.log(smoothed_cc.reshape(-1, 1)))\n",
    "        self.class_log_prior_ = np.log(self.class_count_) - np.log(self.class_count_.sum())\n",
    "        return self\n",
    "\n",
    "    def _joint_log_likelihood(self, X):\n",
    "        \"\"\"Return the unnormalised log posterior, shape (n_samples, n_classes).\"\"\"\n",
    "        return np.dot(X, self.feature_log_prob_.T) + self.class_log_prior_\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return, for each row of ``X``, the label with the highest joint log-likelihood.\"\"\"\n",
    "        joint_log_likelihood = self._joint_log_likelihood(X)\n",
    "        return self.classes_[np.argmax(joint_log_likelihood, axis=1)]\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"Return class probabilities, shape (n_samples, n_classes); rows sum to 1.\"\"\"\n",
    "        joint_log_likelihood = self._joint_log_likelihood(X)\n",
    "        # Subtract the per-row log-sum-exp so the exponentials are normalised.\n",
    "        log_prob = joint_log_likelihood - logsumexp(joint_log_likelihood, axis=1)[:, np.newaxis]\n",
    "        return np.exp(log_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups = fetch_20newsgroups()\n",
    "documents, y = newsgroups.data, newsgroups.target\n",
    "# Densify the count matrix: our implementation does not handle sparse input well.\n",
    "X = CountVectorizer(min_df=0.001).fit_transform(documents).toarray()\n",
    "clf1 = MultinomialNB().fit(X, y)\n",
    "clf2 = skMultinomialNB().fit(X, y)\n",
    "\n",
    "# The fitted parameters must agree with scikit-learn's.\n",
    "for attr in ('feature_log_prob_', 'class_log_prior_'):\n",
    "    assert np.allclose(getattr(clf1, attr), getattr(clf2, attr))\n",
    "\n",
    "# So must every prediction output; hard labels are compared exactly,\n",
    "# floating-point scores up to np.allclose tolerance.\n",
    "for method, same in (('_joint_log_likelihood', np.allclose),\n",
    "                     ('predict', np.array_equal),\n",
    "                     ('predict_proba', np.allclose)):\n",
    "    ours = getattr(clf1, method)(X)\n",
    "    theirs = getattr(clf2, method)(X)\n",
    "    assert same(ours, theirs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}