{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import lstsq\n",
    "from scipy.special import expit\n",
    "from sklearn.datasets import load_breast_cancer, load_iris\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as skLinearDiscriminantAnalysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- similar to scikit-learn solver='lsqr'\n",
    "- reference: Pattern Recognition and Machine Learning Section 4.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearDiscriminantAnalysis():\n",
    "    def fit(self, X, y):\n",
    "        self.classes_ = np.unique(y)\n",
    "        n_features = X.shape[1]\n",
    "        n_classes = len(self.classes_)\n",
    "        self.priors_ = np.zeros(n_classes)\n",
    "        self.means_ = np.zeros((n_classes, n_features))\n",
    "        self.covariance_ = np.zeros((n_features, n_features))\n",
    "        for i, c in enumerate(self.classes_):\n",
    "            X_c = X[y == c]\n",
    "            self.priors_[i] = X_c.shape[0] / X.shape[0]\n",
    "            self.means_[i] = np.mean(X_c, axis=0)\n",
    "            self.covariance_ += self.priors_[i] * np.cov(X_c.T, bias=True)\n",
    "        self.coef_ = lstsq(self.covariance_, self.means_.T)[0].T\n",
    "        self.intercept_ = (-0.5 * np.diag(np.dot(self.means_, self.coef_.T)) +\n",
    "                           np.log(self.priors_))\n",
    "        if len(self.classes_) == 2:\n",
    "            self.coef_ = np.atleast_2d(self.coef_[1] - self.coef_[0])\n",
    "            self.intercept_ = np.atleast_1d(self.intercept_[1] - self.intercept_[0])\n",
    "        return self\n",
    "\n",
    "    def decision_function(self, X):\n",
    "        scores = np.dot(X, self.coef_.T) + self.intercept_\n",
    "        if scores.shape[1] == 1:\n",
    "            return scores.ravel()\n",
    "        else:\n",
    "            return scores\n",
    "\n",
    "    def predict(self, X):\n",
    "        scores = self.decision_function(X)\n",
    "        if len(scores.shape) == 1:\n",
    "            indices = (scores > 0).astype(int)\n",
    "        else:\n",
    "            indices = np.argmax(scores, axis=1)\n",
    "        return self.classes_[indices]\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        scores = self.decision_function(X)\n",
    "        if len(scores.shape) == 1:\n",
    "            prob = expit(scores)\n",
    "            prob = np.vstack((1 - prob, prob)).T\n",
    "        else:\n",
    "            scores -= np.max(scores, axis=1)[:, np.newaxis]\n",
    "            prob = np.exp(scores)\n",
    "            prob /= np.sum(prob, axis=1)[:, np.newaxis]\n",
    "        return prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = load_breast_cancer(return_X_y=True)\n",
    "clf1 = LinearDiscriminantAnalysis().fit(X, y)\n",
    "clf2 = skLinearDiscriminantAnalysis(solver='lsqr').fit(X, y)\n",
    "assert np.allclose(clf1.priors_, clf2.priors_)\n",
    "assert np.allclose(clf1.means_, clf2.means_)\n",
    "assert np.allclose(clf1.covariance_, clf2.covariance_)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "prob1 = clf1.predict_proba(X)\n",
    "prob2 = clf2.predict_proba(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = load_iris(return_X_y=True)\n",
    "clf1 = LinearDiscriminantAnalysis().fit(X, y)\n",
    "clf2 = skLinearDiscriminantAnalysis(solver='lsqr').fit(X, y)\n",
    "assert np.allclose(clf1.priors_, clf2.priors_)\n",
    "assert np.allclose(clf1.means_, clf2.means_)\n",
    "assert np.allclose(clf1.covariance_, clf2.covariance_)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "prob1 = clf1.predict_proba(X)\n",
    "prob2 = clf2.predict_proba(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}