{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from scipy.sparse.linalg import svds\n", "from sklearn.datasets import load_iris\n", "from sklearn.decomposition import TruncatedSVD as skTruncatedSVD" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class TruncatedSVD():\n", " def __init__(self, n_components=2):\n", " self.n_components = n_components\n", "\n", " def fit(self, X):\n", " U, Sigma, VT = svds(X, k=self.n_components)\n", " U, Sigma, VT = U[:, ::-1], Sigma[::-1], VT[::-1]\n", " self.components_ = VT\n", " self.singular_values_ = Sigma\n", " self.explained_variance_ = np.var(U * Sigma, axis=0)\n", " self.explained_variance_ratio_ = self.explained_variance_ / np.var(X, axis=0).sum()\n", " return self\n", "\n", " def transform(self, X):\n", " return np.dot(X, self.components_.T)\n", "\n", " def inverse_transform(self, X):\n", " return np.dot(X, self.components_)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X, _ = load_iris(return_X_y=True)\n", "svd1 = TruncatedSVD().fit(X)\n", "Xt1 = svd1.transform(X)\n", "Xinv1 = svd1.inverse_transform(Xt1)\n", "svd2 = skTruncatedSVD(n_components=2).fit(X)\n", "Xt2 = svd2.transform(X)\n", "Xinv2 = svd2.inverse_transform(Xt2)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "for i in range(svd1.components_.shape[0]):\n", " assert np.allclose(svd1.components_[i], svd2.components_[i]) or np.allclose(svd1.components_[i], -svd2.components_[i])\n", "assert np.allclose(svd1.singular_values_, svd2.singular_values_)\n", "assert np.allclose(svd1.explained_variance_, svd2.explained_variance_)\n", "assert np.allclose(svd1.explained_variance_ratio_, svd2.explained_variance_ratio_)\n", "for i in range(Xt1.shape[1]):\n", " assert np.allclose(Xt1[:, i], Xt2[:, i]) or np.allclose(Xt1[:, i], -Xt2[:, i])\n", "assert np.allclose(Xinv1, Xinv2)" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "dev" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }