{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.preprocessing import KernelCenterer as skKernelCenterer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\tilde{K}_{ij} = (\\varphi(x_i) - \\frac{1}{n}\\sum_{r=1}^n{\\varphi(x_r)})^T(\\varphi(x_j) - \\frac{1}{n}\\sum_{r=1}^n{\\varphi(x_r)})$$\n",
    "$$= \\varphi(x_i)^T\\varphi(x_j) - \\frac{1}{n}\\sum_{r=1}^n{\\varphi(x_i)^T\\varphi(x_r)}\n",
    "- \\frac{1}{n}\\sum_{r=1}^n{\\varphi(x_r)^T\\varphi(x_j)} + \\frac{1}{n^2}\\sum_{r,s=1}^n{\\varphi(x_r)^T\\varphi(x_s)}$$\n",
    "$$= K_{ij} - \\frac{1}{n}\\sum_{r=1}^n{K_{ir}} - \\frac{1}{n}\\sum_{r=1}^n{K_{rj}} + \\frac{1}{n^2}\\sum_{r,s=1}^n{K_{rs}}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KernelCenterer:\n",
    "    \"\"\"Center a kernel matrix in the implicit feature space.\n",
    "\n",
    "    Computes K_ij - mean_r(K_ir) - mean_r(K_rj) + mean_rs(K_rs), where the\n",
    "    column means and the overall mean come from the kernel passed to ``fit``.\n",
    "    For a linear kernel this equals the kernel of the features after the\n",
    "    training mean has been subtracted.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    K_fit_rows_ : ndarray of shape (n_samples,)\n",
    "        Mean of each column of the kernel passed to ``fit``.\n",
    "    K_fit_all_ : float\n",
    "        Mean of all entries of the kernel passed to ``fit``.\n",
    "    \"\"\"\n",
    "\n",
    "    def fit(self, K):\n",
    "        \"\"\"Store the centering statistics of a training kernel.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        K : array-like of shape (n_samples, n_samples)\n",
    "            Kernel evaluated between all pairs of training samples.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self : KernelCenterer\n",
    "            The fitted instance, so calls can be chained.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``K`` is not a square 2-D matrix.\n",
    "        \"\"\"\n",
    "        K = np.asarray(K)\n",
    "        # A training kernel compares the samples with themselves, so anything\n",
    "        # non-square would yield statistics that transform() cannot use.\n",
    "        if K.ndim != 2 or K.shape[0] != K.shape[1]:\n",
    "            raise ValueError(\n",
    "                f\"Kernel passed to fit must be a square matrix, got shape {K.shape}.\"\n",
    "            )\n",
    "        n_samples = K.shape[0]\n",
    "        self.K_fit_rows_ = np.sum(K, axis=0) / n_samples\n",
    "        self.K_fit_all_ = self.K_fit_rows_.sum() / n_samples\n",
    "        return self\n",
    "\n",
    "    def transform(self, K):\n",
    "        \"\"\"Center a kernel using the statistics stored by ``fit``.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        K : array-like of shape (n_queries, n_samples)\n",
    "            Kernel between query samples (rows) and the samples used in\n",
    "            ``fit`` (columns).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        Kt : ndarray of shape (n_queries, n_samples)\n",
    "            The centered kernel.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        AttributeError\n",
    "            If called before ``fit``.\n",
    "        ValueError\n",
    "            If ``K`` is not 2-D or its column count differs from the number\n",
    "            of samples seen in ``fit``.\n",
    "        \"\"\"\n",
    "        K = np.asarray(K)\n",
    "        n_fit = self.K_fit_rows_.shape[0]\n",
    "        # Check explicitly: a single-column K would otherwise broadcast\n",
    "        # against K_fit_rows_ and return a wrong result without any error.\n",
    "        if K.ndim != 2 or K.shape[1] != n_fit:\n",
    "            raise ValueError(\n",
    "                f\"K must have shape (n_queries, {n_fit}), got shape {K.shape}.\"\n",
    "            )\n",
    "        # Mean of every query row over the training samples, kept as a column\n",
    "        # vector so it broadcasts across the columns of K.\n",
    "        K_pred_rows = (np.sum(K, axis=1) / n_fit)[:, np.newaxis]\n",
    "        return K - K_pred_rows - self.K_fit_rows_ + self.K_fit_all_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_kernel(X, Y=None):\n",
    "    \"\"\"Return the matrix of dot products between the rows of X and of Y.\n",
    "\n",
    "    When ``Y`` is omitted, ``X`` is compared with itself.\n",
    "    \"\"\"\n",
    "    other = X if Y is None else Y\n",
    "    return np.dot(X, other.T)\n",
    "\n",
    "def rbf_kernel(X, Y=None, gamma=None):\n",
    "    \"\"\"Compute the RBF kernel exp(-gamma * ||x - y||^2) for all row pairs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like of shape (n_x, n_features)\n",
    "    Y : array-like of shape (n_y, n_features), optional\n",
    "        Defaults to ``X``.\n",
    "    gamma : float, optional\n",
    "        Kernel coefficient; defaults to ``1 / n_features``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    K : ndarray of shape (n_x, n_y)\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, dtype=float)\n",
    "    Y = X if Y is None else np.asarray(Y, dtype=float)\n",
    "    if gamma is None:\n",
    "        gamma = 1 / X.shape[1]\n",
    "    # Broadcast to an (n_x, n_y, n_features) array of pairwise differences\n",
    "    # instead of looping over every pair in Python. The squared distances are\n",
    "    # computed directly, so they can never become slightly negative.\n",
    "    diff = X[:, np.newaxis, :] - Y[np.newaxis, :, :]\n",
    "    sq_dist = np.sum(np.square(diff), axis=-1)\n",
    "    return np.exp(-gamma * sq_dist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# definition: centering a linear kernel must give the same matrix as the\n",
    "# linear kernel of features with the training mean subtracted\n",
    "X, _ = load_iris(return_X_y=True)\n",
    "X_train, X_test = X[:100], X[100:]\n",
    "# The mean comes from the training split only and is applied to both splits\n",
    "X_mean = X_train.mean(axis=0)\n",
    "Xt_train, Xt_test = X_train - X_mean, X_test - X_mean\n",
    "# Kernels of the raw features, which are the input to the centerer\n",
    "K_train, K_test = linear_kernel(X_train), linear_kernel(X_test, X_train)\n",
    "# Kernels of the explicitly centered features, which are the expected output\n",
    "Kt_train, Kt_test = linear_kernel(Xt_train), linear_kernel(Xt_test, Xt_train)\n",
    "trans = KernelCenterer().fit(K_train)\n",
    "assert np.allclose(trans.transform(K_train), Kt_train)\n",
    "assert np.allclose(trans.transform(K_test), Kt_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear kernel: compare this notebook's KernelCenterer with scikit-learn's\n",
    "X, _ = load_iris(return_X_y=True)\n",
    "X_train, X_test = X[:100], X[100:]\n",
    "K_train, K_test = linear_kernel(X_train), linear_kernel(X_test, X_train)\n",
    "# Both centerers are fitted on the same training kernel\n",
    "trans1, trans2 = KernelCenterer().fit(K_train), skKernelCenterer().fit(K_train)\n",
    "# Train-vs-train kernel\n",
    "Kt1, Kt2 = trans1.transform(K_train), trans2.transform(K_train)\n",
    "assert np.allclose(Kt1, Kt2)\n",
    "# Test-vs-train kernel\n",
    "Kt1, Kt2 = trans1.transform(K_test), trans2.transform(K_test)\n",
    "assert np.allclose(Kt1, Kt2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rbf kernel: compare this notebook's KernelCenterer with scikit-learn's\n",
    "X, _ = load_iris(return_X_y=True)\n",
    "X_train, X_test = X[:100], X[100:]\n",
    "K_train, K_test = rbf_kernel(X_train), rbf_kernel(X_test, X_train)\n",
    "# Both centerers are fitted on the same training kernel\n",
    "trans1, trans2 = KernelCenterer().fit(K_train), skKernelCenterer().fit(K_train)\n",
    "# Train-vs-train kernel\n",
    "Kt1, Kt2 = trans1.transform(K_train), trans2.transform(K_train)\n",
    "assert np.allclose(Kt1, Kt2)\n",
    "# Test-vs-train kernel\n",
    "Kt1, Kt2 = trans1.transform(K_test), trans2.transform(K_test)\n",
    "assert np.allclose(Kt1, Kt2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}