Submit
Path:
~
/
/
opt
/
alt
/
python35
/
share
/
doc
/
alt-python35-scikit-learn-0.18.1
/
examples
/
linear_model
/
File Content:
plot_robust_fit.py
"""
Robust linear estimator fitting
===============================

Here a sine function is fit with a polynomial of order 3, for values
close to zero.

Robust fitting is demoed in different situations:

- No measurement errors, only modelling errors (fitting a sine with a
  polynomial)

- Measurement errors in X

- Measurement errors in y

The mean squared error on new, non-corrupt data is used to judge the
quality of the prediction.

What we can see is that:

- RANSAC is good for strong outliers in the y direction

- TheilSen is good for small outliers, in both the X and y directions, but
  has a break point above which it performs worse than OLS.

- The scores of HuberRegressor may not be compared directly to those of
  TheilSen and RANSAC, because it does not attempt to completely filter out
  the outliers but only lessens their effect.
"""

from matplotlib import pyplot as plt
import numpy as np

from sklearn.linear_model import (
    LinearRegression, TheilSenRegressor, RANSACRegressor, HuberRegressor)
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

# Fixed seed so the generated data (and therefore the plots) are reproducible.
np.random.seed(42)

X = np.random.normal(size=400)
y = np.sin(X)
# Reshape X into a single-feature 2D column of shape (n_samples, 1).
X = X[:, np.newaxis]

# Clean held-out data, used only to score each fitted model.
X_test = np.random.normal(size=200)
y_test = np.sin(X_test)
X_test = X_test[:, np.newaxis]

# Corrupted training sets: every third sample is overwritten with a constant,
# either a small deviant (3) or a large one (10), in y or in X.
y_errors = y.copy()
y_errors[::3] = 3

X_errors = X.copy()
X_errors[::3] = 3

y_errors_large = y.copy()
y_errors_large[::3] = 10

X_errors_large = X.copy()
X_errors_large[::3] = 10

estimators = [('OLS', LinearRegression()),
              ('Theil-Sen', TheilSenRegressor(random_state=42)),
              ('RANSAC', RANSACRegressor(random_state=42)),
              ('HuberRegressor', HuberRegressor())]
colors = {'OLS': 'turquoise', 'Theil-Sen': 'gold', 'RANSAC': 'lightgreen',
          'HuberRegressor': 'black'}
linestyle = {'OLS': '-', 'Theil-Sen': '-.', 'RANSAC': '--',
             'HuberRegressor': '--'}
lw = 3

# Evaluation grid spanning the clean training inputs, used to draw each fit.
x_plot = np.linspace(X.min(), X.max())

for title, this_X, this_y in [
        ('Modeling Errors Only', X, y),
        ('Corrupt X, Small Deviants', X_errors, y),
        ('Corrupt y, Small Deviants', X, y_errors),
        ('Corrupt X, Large Deviants', X_errors_large, y),
        ('Corrupt y, Large Deviants', X, y_errors_large)]:
    plt.figure(figsize=(5, 4))
    plt.plot(this_X[:, 0], this_y, 'b+')

    for name, estimator in estimators:
        # Expand X into cubic polynomial features so each linear estimator
        # fits a degree-3 polynomial.
        model = make_pipeline(PolynomialFeatures(3), estimator)
        model.fit(this_X, this_y)
        # Score against the clean test set; arguments are (y_true, y_pred).
        mse = mean_squared_error(y_test, model.predict(X_test))
        y_plot = model.predict(x_plot[:, np.newaxis])
        plt.plot(x_plot, y_plot, color=colors[name], linestyle=linestyle[name],
                 linewidth=lw, label='%s: error = %.3f' % (name, mse))

    # The legend title names the metric actually computed above (MSE).
    legend_title = 'Mean Squared Error\nto Non-corrupt Data'
    plt.legend(loc='upper right', frameon=False, title=legend_title,
               prop=dict(size='x-small'))
    plt.xlim(-4, 10.2)
    plt.ylim(-2, 10.2)
    plt.title(title)

plt.show()
Submit
FILE
FOLDER
Name
Size
Permission
Action
README.txt
135 bytes
0644
lasso_dense_vs_sparse_data.py
1862 bytes
0644
plot_ard.py
2828 bytes
0644
plot_bayesian_ridge.py
2733 bytes
0644
plot_huber_vs_ridge.py
2206 bytes
0644
plot_iris_logistic.py
1679 bytes
0644
plot_lasso_and_elasticnet.py
2074 bytes
0644
plot_lasso_coordinate_descent_path.py
2945 bytes
0644
plot_lasso_lars.py
1080 bytes
0644
plot_lasso_model_selection.py
5431 bytes
0644
plot_logistic.py
1568 bytes
0644
plot_logistic_l1_l2_sparsity.py
2601 bytes
0644
plot_logistic_multinomial.py
2480 bytes
0644
plot_logistic_path.py
1195 bytes
0644
plot_multi_task_lasso_support.py
2319 bytes
0644
plot_ols.py
1936 bytes
0644
plot_ols_3d.py
2040 bytes
0644
plot_ols_ridge_variance.py
2060 bytes
0644
plot_omp.py
2263 bytes
0644
plot_polynomial_interpolation.py
2088 bytes
0644
plot_ransac.py
1859 bytes
0644
plot_ridge_coeffs.py
2785 bytes
0644
plot_ridge_path.py
2138 bytes
0644
plot_robust_fit.py
3050 bytes
0644
plot_sgd_comparison.py
1819 bytes
0644
plot_sgd_iris.py
2202 bytes
0644
plot_sgd_loss_functions.py
1232 bytes
0644
plot_sgd_penalties.py
1877 bytes
0644
plot_sgd_separating_hyperplane.py
1221 bytes
0644
plot_sgd_weighted_samples.py
1458 bytes
0644
plot_sparse_recovery.py
7486 bytes
0644
plot_theilsen.py
3846 bytes
0644
N4ST4R_ID | Naxtarrr