{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Hashing feature transformation using Totally Random Trees\n",
    "\n",
    "RandomTreesEmbedding provides a way to map data to a\n",
    "very high-dimensional, sparse representation, which might\n",
    "be beneficial for classification.\n",
    "The mapping is completely unsupervised and very efficient.\n",
    "\n",
    "This example visualizes the partitions given by several\n",
    "trees and shows how the transformation can also be used for\n",
    "non-linear dimensionality reduction or non-linear classification.\n",
    "\n",
    "Points that are neighboring often share the same leaf of a tree and therefore\n",
    "share large parts of their hashed representation. This makes it possible to\n",
    "separate two concentric circles simply based on the principal components\n",
    "of the transformed data with truncated SVD.\n",
    "\n",
    "In high-dimensional spaces, linear classifiers often achieve\n",
    "excellent accuracy. For sparse binary data, BernoulliNB\n",
    "is particularly well-suited. The bottom row compares the\n",
    "decision boundary obtained by BernoulliNB in the transformed\n",
    "space with an ExtraTreesClassifier forest learned on the\n",
    "original data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.datasets import make_circles\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from sklearn.ensemble import ExtraTreesClassifier, RandomTreesEmbedding\n",
    "from sklearn.naive_bayes import BernoulliNB\n",
    "\n",
    "# make a synthetic dataset\n",
    "X, y = make_circles(factor=0.5, random_state=0, noise=0.05)\n",
    "\n",
    "# use RandomTreesEmbedding to transform data\n",
    "hasher = RandomTreesEmbedding(n_estimators=10, random_state=0, max_depth=3)\n",
    "X_transformed = hasher.fit_transform(X)\n",
    "\n",
    "# Visualize result after dimensionality reduction using truncated SVD.\n",
    "# TruncatedSVD uses a randomized solver by default, so it is seeded like the\n",
    "# other estimators to keep the figure identical across runs.\n",
    "svd = TruncatedSVD(n_components=2, random_state=0)\n",
    "X_reduced = svd.fit_transform(X_transformed)\n",
    "\n",
    "# Learn a Naive Bayes classifier on the transformed data\n",
    "nb = BernoulliNB()\n",
    "nb.fit(X_transformed, y)\n",
    "\n",
    "# Learn an ExtraTreesClassifier for comparison\n",
    "trees = ExtraTreesClassifier(max_depth=3, n_estimators=10, random_state=0)\n",
    "trees.fit(X, y)\n",
    "\n",
    "# 2x2 figure: data scatter plots on the top row, decision surfaces below\n",
    "fig, axes = plt.subplots(2, 2, figsize=(9, 8))\n",
    "\n",
    "# scatter plot of original and reduced data\n",
    "ax = axes[0, 0]\n",
    "ax.scatter(X[:, 0], X[:, 1], c=y, s=50, edgecolor=\"k\")\n",
    "ax.set_title(\"Original Data (2d)\")\n",
    "ax.set_xticks(())\n",
    "ax.set_yticks(())\n",
    "\n",
    "ax = axes[0, 1]\n",
    "ax.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y, s=50, edgecolor=\"k\")\n",
    "ax.set_title(\n",
    "    \"Truncated SVD reduction (2d) of transformed data (%dd)\" % X_transformed.shape[1]\n",
    ")\n",
    "ax.set_xticks(())\n",
    "ax.set_yticks(())\n",
    "\n",
    "# Plot the decision surfaces in original space. For that, we will assign a\n",
    "# color to each point in the mesh [x_min, x_max]x[y_min, y_max].\n",
    "h = 0.01\n",
    "x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
    "y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n",
    "xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))\n",
    "\n",
    "# Flatten the mesh once into an (n_points, 2) array shared by both classifiers\n",
    "grid = np.c_[xx.ravel(), yy.ravel()]\n",
    "\n",
    "# transform grid using RandomTreesEmbedding\n",
    "transformed_grid = hasher.transform(grid)\n",
    "\n",
    "# Naive Bayes was fitted on the hashed features, so it is evaluated on the\n",
    "# transformed grid; ExtraTrees was fitted on the original 2d coordinates and\n",
    "# is evaluated on the raw grid.\n",
    "panels = [\n",
    "    (axes[1, 0], \"Naive Bayes on Transformed data\", nb, transformed_grid),\n",
    "    (axes[1, 1], \"ExtraTrees predictions\", trees, grid),\n",
    "]\n",
    "for ax, title, classifier, grid_features in panels:\n",
    "    # column 1 of predict_proba is used as the color value of each mesh point\n",
    "    y_grid_pred = classifier.predict_proba(grid_features)[:, 1]\n",
    "\n",
    "    ax.set_title(title)\n",
    "    ax.pcolormesh(xx, yy, y_grid_pred.reshape(xx.shape))\n",
    "    ax.scatter(X[:, 0], X[:, 1], c=y, s=50, edgecolor=\"k\")\n",
    "    ax.set_ylim(-1.4, 1.4)\n",
    "    ax.set_xlim(-1.4, 1.4)\n",
    "    ax.set_xticks(())\n",
    "    ax.set_yticks(())\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }