{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import SGDClassifier as skSGDClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- scikit-learn loss = \"hinge\", penalty=\"l2\"/\"none\"\n",
    "- similar to sklearn.svm.LinearSVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _loss(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    if z <= 1:\n",
    "        return 1 - z\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def _grad(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    if z <= 1:\n",
    "        dloss = -y\n",
    "    else:\n",
    "        dloss = 0\n",
    "    # clip gradient (consistent with scikit-learn)\n",
    "    dloss = np.clip(dloss, -1e12, 1e12)\n",
    "    coef_grad = dloss * x\n",
    "    intercept_grad = dloss\n",
    "    return coef_grad, intercept_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SGDClassifier():\n",
    "    def __init__(self, penalty=\"l2\", alpha=0.0001, max_iter=1000, tol=1e-3,\n",
    "                 shuffle=True, random_state=0,\n",
    "                 # use learning_rate = 'invscaling' for simplicity\n",
    "                 eta0=0, power_t=0.5, n_iter_no_change=5):\n",
    "        self.penalty = penalty\n",
    "        self.alpha = alpha\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.shuffle = shuffle\n",
    "        self.random_state = random_state\n",
    "        self.eta0 = eta0\n",
    "        self.power_t = power_t\n",
    "        self.n_iter_no_change = n_iter_no_change\n",
    "\n",
    "    def _encode(self, y):\n",
    "        classes = np.unique(y)\n",
    "        y_train = np.full((y.shape[0], len(classes)), -1)\n",
    "        for i, c in enumerate(classes):\n",
    "            y_train[y == c, i] = 1\n",
    "        if len(classes) == 2:\n",
    "            y_train = y_train[:, 1].reshape(-1, 1)\n",
    "        return classes, y_train\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.classes_, y_train = self._encode(y)\n",
    "        if len(self.classes_) == 2:\n",
    "            coef = np.zeros((1, X.shape[1]))\n",
    "            intercept = np.zeros(1)\n",
    "        else:\n",
    "            coef = np.zeros((len(self.classes_), X.shape[1]))\n",
    "            intercept = np.zeros(len(self.classes_))\n",
    "        n_iter = 0\n",
    "        rng = np.random.RandomState(self.random_state)\n",
    "        for class_ind in range(y_train.shape[1]):\n",
    "            cur_y = y_train[:, class_ind]\n",
    "            cur_coef = np.zeros(X.shape[1])\n",
    "            cur_intercept = 0\n",
    "            best_loss = np.inf\n",
    "            no_improvement_count = 0\n",
    "            t = 1\n",
    "            for epoch in range(self.max_iter):\n",
    "                # different from how data is shuffled in scikit-learn\n",
    "                if self.shuffle:\n",
    "                    ind = rng.permutation(X.shape[0])\n",
    "                    X, cur_y = X[ind], cur_y[ind]\n",
    "                sumloss = 0\n",
    "                for i in range(X.shape[0]):\n",
    "                    sumloss += _loss(X[i], cur_y[i], cur_coef, cur_intercept)\n",
    "                    eta = self.eta0 / np.power(t, self.power_t)\n",
    "                    coef_grad, intercept_grad = _grad(X[i], cur_y[i], cur_coef, cur_intercept)\n",
    "                    if self.penalty == \"l2\":\n",
    "                        cur_coef *= 1 - eta * self.alpha\n",
    "                    cur_coef -= eta * coef_grad\n",
    "                    cur_intercept -= eta * intercept_grad\n",
    "                    t += 1\n",
    "                if sumloss > best_loss - self.tol * X.shape[0]:\n",
    "                    no_improvement_count += 1\n",
    "                else:\n",
    "                    no_improvement_count = 0\n",
    "                if no_improvement_count == self.n_iter_no_change:\n",
    "                    break\n",
    "                if sumloss < best_loss:\n",
    "                    best_loss = sumloss\n",
    "            coef[class_ind] = cur_coef\n",
    "            intercept[class_ind] = cur_intercept\n",
    "            n_iter = max(n_iter, epoch + 1)\n",
    "        self.coef_ = coef\n",
    "        self.intercept_ = intercept\n",
    "        self.n_iter_ = n_iter\n",
    "        return self\n",
    "\n",
    "    def decision_function(self, X):\n",
    "        scores = np.dot(X, self.coef_.T) + self.intercept_\n",
    "        if scores.shape[1] == 1:\n",
    "            return scores.ravel()\n",
    "        else:\n",
    "            return scores\n",
    "\n",
    "    def predict(self, X):\n",
    "        scores = self.decision_function(X)\n",
    "        if len(scores.shape) == 1:\n",
    "            indices = (scores > 0).astype(int)\n",
    "        else:\n",
    "            indices = np.argmax(scores, axis=1)\n",
    "        return self.classes_[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# binary classification\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X, y = X[y != 2], y[y != 2]\n",
    "X = StandardScaler().fit_transform(X)\n",
    "clf1 = SGDClassifier(eta0=0.1, shuffle=False).fit(X, y)\n",
    "clf2 = skSGDClassifier(learning_rate='invscaling', eta0=0.1, shuffle=False).fit(X, y)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"none\"\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X = StandardScaler().fit_transform(X)\n",
    "clf1 = SGDClassifier(eta0=0.1, shuffle=False).fit(X, y)\n",
    "clf2 = skSGDClassifier(learning_rate='invscaling', eta0=0.1, shuffle=False).fit(X, y)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"l2\"\n",
    "for alpha in [0.1, 1, 10]:\n",
    "    X, y = load_iris(return_X_y=True)\n",
    "    X = StandardScaler().fit_transform(X)\n",
    "    clf1 = SGDClassifier(eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    clf2 = skSGDClassifier(learning_rate='invscaling', eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "    assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "    prob1 = clf1.decision_function(X)\n",
    "    prob2 = clf2.decision_function(X)\n",
    "    assert np.allclose(prob1, prob2)\n",
    "    pred1 = clf1.predict(X)\n",
    "    pred2 = clf2.predict(X)\n",
    "    assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 2\n",
    "- scikit-learn loss = \"squared_hinge\", penalty=\"l2\"/\"none\"\n",
    "- similar to sklearn.svm.LinearSVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _loss(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = 1 - p * y\n",
    "    if z > 0:\n",
    "        return z * z\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def _grad(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = 1 - p * y\n",
    "    if z > 0:\n",
    "        dloss = -2 * y * z\n",
    "    else:\n",
    "        dloss = 0\n",
    "    # clip gradient (consistent with scikit-learn)\n",
    "    dloss = np.clip(dloss, -1e12, 1e12)\n",
    "    coef_grad = dloss * x\n",
    "    intercept_grad = dloss\n",
    "    return coef_grad, intercept_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"none\"\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X = StandardScaler().fit_transform(X)\n",
    "clf1 = SGDClassifier(eta0=0.1, shuffle=False).fit(X, y)\n",
    "clf2 = skSGDClassifier(loss=\"squared_hinge\", learning_rate='invscaling', eta0=0.1, shuffle=False).fit(X, y)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"l2\"\n",
    "for alpha in [0.1, 1, 10]:\n",
    "    X, y = load_iris(return_X_y=True)\n",
    "    X = StandardScaler().fit_transform(X)\n",
    "    clf1 = SGDClassifier(eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    clf2 = skSGDClassifier(loss=\"squared_hinge\", learning_rate='invscaling', eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "    assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "    prob1 = clf1.decision_function(X)\n",
    "    prob2 = clf2.decision_function(X)\n",
    "    assert np.allclose(prob1, prob2)\n",
    "    pred1 = clf1.predict(X)\n",
    "    pred2 = clf2.predict(X)\n",
    "    assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 3\n",
    "- scikit-learn loss = \"modified_huber\", penalty=\"l2\"/\"none\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _loss(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    if z > 1:\n",
    "        return 0\n",
    "    elif z > -1:\n",
    "        return (1 - z) * (1 - z)\n",
    "    else:\n",
    "        return -4 * z\n",
    "\n",
    "def _grad(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    if z > 1:\n",
    "        dloss = 0\n",
    "    elif z > -1:\n",
    "        dloss = -2 * (1 - z) * y\n",
    "    else:\n",
    "        dloss = -4 * y\n",
    "    # clip gradient (consistent with scikit-learn)\n",
    "    dloss = np.clip(dloss, -1e12, 1e12)\n",
    "    coef_grad = dloss * x\n",
    "    intercept_grad = dloss\n",
    "    return coef_grad, intercept_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"none\"\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X = StandardScaler().fit_transform(X)\n",
    "clf1 = SGDClassifier(eta0=0.1, shuffle=False).fit(X, y)\n",
    "clf2 = skSGDClassifier(loss=\"modified_huber\", learning_rate='invscaling', eta0=0.1, shuffle=False).fit(X, y)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"l2\"\n",
    "for alpha in [0.1, 1, 10]:\n",
    "    X, y = load_iris(return_X_y=True)\n",
    "    X = StandardScaler().fit_transform(X)\n",
    "    clf1 = SGDClassifier(eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    clf2 = skSGDClassifier(loss=\"modified_huber\", learning_rate='invscaling', eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "    assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "    prob1 = clf1.decision_function(X)\n",
    "    prob2 = clf2.decision_function(X)\n",
    "    assert np.allclose(prob1, prob2)\n",
    "    pred1 = clf1.predict(X)\n",
    "    pred2 = clf2.predict(X)\n",
    "    assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 4\n",
    "- scikit-learn loss = \"log\", penalty=\"l2\"/\"none\"\n",
    "- similar to sklearn.linear_model.LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _loss(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    # follow scikit-learn\n",
    "    if z > 18:\n",
    "        return np.exp(-z)\n",
    "    elif z < -18:\n",
    "        return -z\n",
    "    else:\n",
    "        return np.log(1 + np.exp(-z))\n",
    "\n",
    "def _grad(x, y, coef, intercept):\n",
    "    p = np.dot(x, coef) + intercept\n",
    "    z = p * y\n",
    "    if z > 18:\n",
    "        dloss = -np.exp(-z) * y\n",
    "    elif z < -18:\n",
    "        dloss =  -y\n",
    "    else:\n",
    "        dloss = -y / (1 + np.exp(z))\n",
    "    # clip gradient (consistent with scikit-learn)\n",
    "    dloss = np.clip(dloss, -1e12, 1e12)\n",
    "    coef_grad = dloss * x\n",
    "    intercept_grad = dloss\n",
    "    return coef_grad, intercept_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"none\"\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X = StandardScaler().fit_transform(X)\n",
    "clf1 = SGDClassifier(eta0=0.1, shuffle=False).fit(X, y)\n",
    "clf2 = skSGDClassifier(loss=\"log\", learning_rate='invscaling', eta0=0.1, shuffle=False).fit(X, y)\n",
    "assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "prob1 = clf1.decision_function(X)\n",
    "prob2 = clf2.decision_function(X)\n",
    "assert np.allclose(prob1, prob2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle=False penalty=\"l2\"\n",
    "for alpha in [0.1, 1, 10]:\n",
    "    X, y = load_iris(return_X_y=True)\n",
    "    X = StandardScaler().fit_transform(X)\n",
    "    clf1 = SGDClassifier(eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    clf2 = skSGDClassifier(loss=\"log\", learning_rate='invscaling', eta0=0.1, alpha=alpha, shuffle=False).fit(X, y)\n",
    "    assert np.allclose(clf1.coef_, clf2.coef_)\n",
    "    assert np.allclose(clf1.intercept_, clf2.intercept_)\n",
    "    prob1 = clf1.decision_function(X)\n",
    "    prob2 = clf2.decision_function(X)\n",
    "    assert np.allclose(prob1, prob2)\n",
    "    pred1 = clf1.predict(X)\n",
    "    pred2 = clf2.predict(X)\n",
    "    assert np.array_equal(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}