{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from scipy.stats.mstats import mode\n", "from sklearn.datasets import load_iris\n", "from sklearn.impute import SimpleImputer as skSimpleImputer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class SimpleImputer():\n", " def __init__(self, strategy='mean', fill_value=None):\n", " self.strategy = strategy\n", " self.fill_value = fill_value # only used when strategy == 'constant'\n", "\n", " def fit(self, X):\n", " mask = np.isnan(X)\n", " masked_X = np.ma.masked_array(X, mask=mask)\n", " if self.strategy == \"mean\":\n", " self.statistics_ = np.array(np.ma.mean(masked_X, axis=0))\n", " elif self.strategy == \"median\":\n", " self.statistics_ = np.array(np.ma.median(masked_X, axis=0))\n", " elif self.strategy == \"most_frequent\":\n", " self.statistics_ = np.array(mode(masked_X, axis=0)[0])\n", " elif self.strategy == \"constant\":\n", " self.statistics_ = np.full(X.shape[1], self.fill_value)\n", " return self\n", "\n", " def transform(self, X):\n", " mask = np.isnan(X)\n", " n_missing = np.sum(mask, axis=0)\n", " values = np.repeat(self.statistics_, n_missing)\n", " coordinates = np.where(mask.T)[::-1]\n", " Xt = X.copy()\n", " Xt[coordinates] = values\n", " return Xt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X, _ = load_iris(return_X_y=True)\n", "rng = np.random.RandomState(0)\n", "missing_samples = np.arange(X.shape[0])\n", "missing_features = rng.choice(X.shape[1], X.shape[0])\n", "X[missing_samples, missing_features] = np.nan" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "est1 = SimpleImputer(strategy=\"mean\").fit(X)\n", "est2 = skSimpleImputer(strategy=\"mean\").fit(X)\n", "assert np.allclose(est1.statistics_, est2.statistics_)\n", "Xt1 = est1.transform(X)\n", "Xt2 = est2.transform(X)\n", "assert np.allclose(Xt1, Xt2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "est1 = SimpleImputer(strategy=\"median\").fit(X)\n", "est2 = skSimpleImputer(strategy=\"median\").fit(X)\n", "assert np.allclose(est1.statistics_, est2.statistics_)\n", "Xt1 = est1.transform(X)\n", "Xt2 = est2.transform(X)\n", "assert np.allclose(Xt1, Xt2)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "est1 = SimpleImputer(strategy=\"most_frequent\").fit(X)\n", "est2 = skSimpleImputer(strategy=\"most_frequent\").fit(X)\n", "assert np.allclose(est1.statistics_, est2.statistics_)\n", "Xt1 = est1.transform(X)\n", "Xt2 = est2.transform(X)\n", "assert np.allclose(Xt1, Xt2)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "est1 = SimpleImputer(strategy=\"constant\", fill_value=0).fit(X)\n", "est2 = skSimpleImputer(strategy=\"constant\", fill_value=0).fit(X)\n", "assert np.allclose(est1.statistics_, est2.statistics_)\n", "Xt1 = est1.transform(X)\n", "Xt2 = est2.transform(X)\n", "assert np.allclose(Xt1, Xt2)" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "dev" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }