{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_boston\n",
    "from sklearn.dummy import DummyRegressor as skDummyRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DummyRegressor():\n",
    "    def __init__(self, strategy=\"mean\", constant=None, quantile=None):\n",
    "        self.strategy = strategy\n",
    "        self.constant = constant\n",
    "        self.quantile = quantile\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        if self.strategy == \"mean\":\n",
    "            self.constant_ = np.mean(y)\n",
    "        elif self.strategy == \"median\":\n",
    "            self.constant_ = np.median(y)\n",
    "        elif self.strategy == \"quantile\":\n",
    "            self.constant_ = np.quantile(y, quantile);\n",
    "        elif self.strategy == \"constant\":\n",
    "            self.constant_ = self.constant\n",
    "        # keep consistent with scikit-learn\n",
    "        self.constant_ = np.reshape(self.constant_, (1, -1))\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        return np.full(X.shape[0], self.constant_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the pair returned by load_boston(return_X_y=True) into X and y;\n",
    "# presumably X is the feature matrix and y the regression target.\n",
    "# NOTE(review): load_boston was removed in scikit-learn 1.2 - confirm the pinned version still ships it.\n",
    "X, y = load_boston(return_X_y=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the local implementation and the scikit-learn reference with the same strategy.\n",
    "clf1 = DummyRegressor(strategy=\"mean\").fit(X, y)\n",
    "clf2 = skDummyRegressor(strategy=\"mean\").fit(X, y)\n",
    "assert np.allclose(clf1.constant_, clf2.constant_)\n",
    "# Predict with each model separately; comparing clf2 with itself would always pass.\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "# np.allclose broadcasts, so check the shapes explicitly as well.\n",
    "assert pred1.shape == pred2.shape\n",
    "assert np.allclose(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the local implementation and the scikit-learn reference with the same strategy.\n",
    "clf1 = DummyRegressor(strategy=\"median\").fit(X, y)\n",
    "clf2 = skDummyRegressor(strategy=\"median\").fit(X, y)\n",
    "assert np.allclose(clf1.constant_, clf2.constant_)\n",
    "# Predict with each model separately; comparing clf2 with itself would always pass.\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "# np.allclose broadcasts, so check the shapes explicitly as well.\n",
    "assert pred1.shape == pred2.shape\n",
    "assert np.allclose(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the local implementation and the scikit-learn reference with the same fixed constant.\n",
    "clf1 = DummyRegressor(strategy=\"constant\", constant=0).fit(X, y)\n",
    "clf2 = skDummyRegressor(strategy=\"constant\", constant=0).fit(X, y)\n",
    "assert np.allclose(clf1.constant_, clf2.constant_)\n",
    "# Predict with each model separately; comparing clf2 with itself would always pass.\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "# np.allclose broadcasts, so check the shapes explicitly as well.\n",
    "assert pred1.shape == pred2.shape\n",
    "assert np.allclose(pred1, pred2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the local implementation with scikit-learn at a lower and an upper quantile.\n",
    "for quantile in [0.25, 0.75]:\n",
    "    clf1 = DummyRegressor(strategy=\"quantile\", quantile=quantile).fit(X, y)\n",
    "    clf2 = skDummyRegressor(strategy=\"quantile\", quantile=quantile).fit(X, y)\n",
    "    assert np.allclose(clf1.constant_, clf2.constant_)\n",
    "    # Predict with each model separately; comparing clf2 with itself would always pass.\n",
    "    pred1 = clf1.predict(X)\n",
    "    pred2 = clf2.predict(X)\n",
    "    # np.allclose broadcasts, so check the shapes explicitly as well.\n",
    "    assert pred1.shape == pred2.shape\n",
    "    assert np.allclose(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}