{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Robust linear model estimation using RANSAC\n",
    "\n",
    "In this example, we see how to robustly fit a linear model to faulty data using\n",
    "the RANSAC (RANdom SAmple Consensus) algorithm.\n",
    "\n",
    "The ordinary linear regressor is sensitive to outliers, and the fitted line can\n",
    "easily be skewed away from the true underlying relationship in the data.\n",
    "\n",
    "The RANSAC regressor automatically splits the data into inliers and outliers,\n",
    "and the fitted line is determined only by the identified inliers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn import datasets, linear_model\n",
    "\n",
    "# Single seed shared by every stochastic step so that the notebook gives the\n",
    "# same output on each \"Restart & Run All\".\n",
    "SEED = 0\n",
    "\n",
    "n_samples = 1000\n",
    "n_outliers = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data with outliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X, y, coef = datasets.make_regression(\n",
    "    n_samples=n_samples,\n",
    "    n_features=1,\n",
    "    n_informative=1,\n",
    "    noise=10,\n",
    "    coef=True,\n",
    "    random_state=SEED,\n",
    ")\n",
    "\n",
    "# Add outlier data: overwrite the first `n_outliers` samples with points\n",
    "# centered on x = 3 and y = -3. A local generator is used so that NumPy's\n",
    "# global random state is left untouched.\n",
    "rng = np.random.RandomState(SEED)\n",
    "X[:n_outliers] = 3 + 0.5 * rng.normal(size=(n_outliers, 1))\n",
    "y[:n_outliers] = -3 + 10 * rng.normal(size=n_outliers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit ordinary least squares and RANSAC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fit line using all data\n",
    "lr = linear_model.LinearRegression()\n",
    "lr.fit(X, y)\n",
    "\n",
    "# Robustly fit linear model with RANSAC algorithm. RANSAC fits on randomly\n",
    "# drawn subsets of the samples, so it is seeded explicitly instead of relying\n",
    "# on whatever state the global NumPy generator happens to be in.\n",
    "ransac = linear_model.RANSACRegressor(random_state=SEED)\n",
    "ransac.fit(X, y)\n",
    "inlier_mask = ransac.inlier_mask_\n",
    "outlier_mask = np.logical_not(inlier_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare the estimated coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n",
    "print(coef, lr.coef_, ransac.estimator_.coef_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the fitted lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Predict data of estimated models. Both fitted models are straight lines, so\n",
    "# the two end points of the data range are enough to draw them across the\n",
    "# full extent of the scatter plot (including the right-most samples).\n",
    "line_X = np.linspace(X.min(), X.max(), num=2)[:, np.newaxis]\n",
    "line_y = lr.predict(line_X)\n",
    "line_y_ransac = ransac.predict(line_X)\n",
    "\n",
    "lw = 2\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(\n",
    "    X[inlier_mask], y[inlier_mask], color=\"yellowgreen\", marker=\".\", label=\"Inliers\"\n",
    ")\n",
    "ax.scatter(\n",
    "    X[outlier_mask], y[outlier_mask], color=\"gold\", marker=\".\", label=\"Outliers\"\n",
    ")\n",
    "ax.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n",
    "ax.plot(\n",
    "    line_X,\n",
    "    line_y_ransac,\n",
    "    color=\"cornflowerblue\",\n",
    "    linewidth=lw,\n",
    "    label=\"RANSAC regressor\",\n",
    ")\n",
    "ax.legend(loc=\"lower right\")\n",
    "ax.set_xlabel(\"Input\")\n",
    "ax.set_ylabel(\"Response\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}