{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.special import logsumexp\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.naive_bayes import GaussianNB as skGaussianNB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GaussianNB():\n",
    "    \"\"\"Gaussian naive Bayes classifier implemented with NumPy.\n",
    "\n",
    "    Every feature is modelled as an independent normal distribution per class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    var_smoothing : float, default=1e-9\n",
    "        Fraction of the largest per-feature variance of the training data that\n",
    "        is added to every per-class variance. It keeps the variance strictly\n",
    "        positive when a feature is constant within a class.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    classes_ : ndarray of shape (n_classes,)\n",
    "        Sorted unique labels seen by ``fit``.\n",
    "    theta_ : ndarray of shape (n_classes, n_features)\n",
    "        Per-class feature means.\n",
    "    sigma_ : ndarray of shape (n_classes, n_features)\n",
    "        Per-class feature variances, ``epsilon_`` included.\n",
    "    epsilon_ : float\n",
    "        Absolute value that was added to the variances.\n",
    "    class_count_ : ndarray of shape (n_classes,)\n",
    "        Number of training samples in each class.\n",
    "    class_prior_ : ndarray of shape (n_classes,)\n",
    "        Empirical class frequencies.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, var_smoothing=1e-9):\n",
    "        self.var_smoothing = var_smoothing\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Estimate per-class means, variances and priors.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like of shape (n_samples, n_features)\n",
    "        y : array-like of shape (n_samples,)\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self : GaussianNB\n",
    "            The fitted estimator.\n",
    "        \"\"\"\n",
    "        # Accept lists as well as arrays; boolean masking below needs ndarrays.\n",
    "        X = np.asarray(X, dtype=float)\n",
    "        y = np.asarray(y)\n",
    "        self.classes_ = np.unique(y)\n",
    "        n_features = X.shape[1]\n",
    "        n_classes = len(self.classes_)\n",
    "        self.theta_ = np.zeros((n_classes, n_features))\n",
    "        self.sigma_ = np.zeros((n_classes, n_features))\n",
    "        self.class_count_ = np.zeros(n_classes)\n",
    "        for i, c in enumerate(self.classes_):\n",
    "            X_c = X[y == c]\n",
    "            self.theta_[i] = np.mean(X_c, axis=0)\n",
    "            self.sigma_[i] = np.var(X_c, axis=0)\n",
    "            self.class_count_[i] = X_c.shape[0]\n",
    "        # Variance floor: without it a feature that is constant inside one class\n",
    "        # has zero variance, and the log-likelihood becomes log(0) and 0/0 (NaN).\n",
    "        # The floor is still 0 when every feature is constant over the whole\n",
    "        # training set, because it scales with the largest overall variance.\n",
    "        self.epsilon_ = self.var_smoothing * np.var(X, axis=0).max()\n",
    "        self.sigma_ += self.epsilon_\n",
    "        self.class_prior_ = self.class_count_ / np.sum(self.class_count_)\n",
    "        return self\n",
    "\n",
    "    def _joint_log_likelihood(self, X):\n",
    "        \"\"\"Return log P(c) + log P(x | c) for every sample and class.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like of shape (n_samples, n_features)\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        ndarray of shape (n_samples, n_classes)\n",
    "            Unnormalised log posterior; columns follow ``classes_``.\n",
    "        \"\"\"\n",
    "        X = np.asarray(X, dtype=float)\n",
    "        joint_log_likelihood = np.zeros((X.shape[0], len(self.classes_)))\n",
    "        for i in range(len(self.classes_)):\n",
    "            log_prior = np.log(self.class_prior_[i])\n",
    "            # Log density of an independent normal per feature.\n",
    "            log_density = (\n",
    "                -0.5 * np.log(2 * np.pi * self.sigma_[i])\n",
    "                - 0.5 * (X - self.theta_[i]) ** 2 / self.sigma_[i]\n",
    "            )\n",
    "            joint_log_likelihood[:, i] = log_prior + np.sum(log_density, axis=1)\n",
    "        return joint_log_likelihood\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return the label with the highest joint log-likelihood per sample.\"\"\"\n",
    "        joint_log_likelihood = self._joint_log_likelihood(X)\n",
    "        return self.classes_[np.argmax(joint_log_likelihood, axis=1)]\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"Return posterior class probabilities, shape (n_samples, n_classes).\"\"\"\n",
    "        joint_log_likelihood = self._joint_log_likelihood(X)\n",
    "        # Normalise in log space so very negative log-likelihoods do not underflow.\n",
    "        log_norm = logsumexp(joint_log_likelihood, axis=1)[:, np.newaxis]\n",
    "        return np.exp(joint_log_likelihood - log_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit both implementations on iris and check that they agree.\n",
    "X, y = load_iris(return_X_y=True)\n",
    "clf1 = GaussianNB().fit(X, y)\n",
    "clf2 = skGaussianNB().fit(X, y)\n",
    "\n",
    "# scikit-learn 1.0 renamed `sigma_` to `var_` and 1.2 removed the old name,\n",
    "# so read whichever attribute the installed version provides.\n",
    "sk_var = clf2.var_ if hasattr(clf2, 'var_') else clf2.sigma_\n",
    "\n",
    "# Fitted parameters\n",
    "assert np.allclose(clf1.theta_, clf2.theta_)\n",
    "assert np.allclose(clf1.sigma_, sk_var)\n",
    "assert np.allclose(clf1.class_count_, clf2.class_count_)\n",
    "assert np.allclose(clf1.class_prior_, clf2.class_prior_)\n",
    "\n",
    "# Unnormalised log posterior\n",
    "jll1 = clf1._joint_log_likelihood(X)\n",
    "jll2 = clf2._joint_log_likelihood(X)\n",
    "assert np.allclose(jll1, jll2)\n",
    "\n",
    "# Predicted labels\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)\n",
    "\n",
    "# Posterior probabilities\n",
    "prob1 = clf1.predict_proba(X)\n",
    "prob2 = clf2.predict_proba(X)\n",
    "assert np.allclose(prob1, prob2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}