.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_model_selection.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection.py:

Model Selection
===============

This example demonstrates a model selection plot using cross-validation.
The code has been adapted from the `machinelearningmastery example `_.

.. GENERATED FROM PYTHON SOURCE LINES 11-30

.. code-block:: default

    import logging
    from typing import Dict

    import numpy as np
    import pandas as pd
    import plotly
    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import LinearRegression
    from sklearn.pipeline import make_pipeline, Pipeline
    from sklearn.preprocessing import StandardScaler

    from elphick.sklearn_viz.model_selection import ModelSelection, plot_model_selection, metrics
    from elphick.sklearn_viz.model_selection.models import Models

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
                        datefmt='%Y-%m-%dT%H:%M:%S%z')

.. GENERATED FROM PYTHON SOURCE LINES 31-35

Load Data
---------

Once loaded, we'll separate the features and the target for a classification problem.

.. GENERATED FROM PYTHON SOURCE LINES 35-44

.. code-block:: default

    url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
    names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
    dataframe = pd.read_csv(url, names=names)
    array = dataframe.values
    x = pd.DataFrame(array[:, 0:8], columns=names[0:8])
    y = pd.Series(array[:, 8], name=names[8])
    xy: pd.DataFrame = pd.concat([x, y], axis=1)

.. GENERATED FROM PYTHON SOURCE LINES 45-49

Instantiate
-----------

Create an optional pre-processor as a sklearn Pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 49-55

.. code-block:: default

    np.random.seed(1234)
    pp: Pipeline = make_pipeline(StandardScaler())
    models_to_test: Dict = Models().fast_classifiers()
    pp

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('standardscaler', StandardScaler())])


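The results shown further below indicate that ``Models().fast_classifiers()`` supplies six
estimators keyed by short names (LR, LDA, KNN, CART, NB, SVM). A hand-rolled equivalent of
that dict is sketched below — note this is an approximation, and the actual helper may
configure hyperparameters differently:

```python
from typing import Dict

from sklearn.base import ClassifierMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Short-name -> estimator mapping, mirroring the keys seen in ms.results below.
models_to_test: Dict[str, ClassifierMixin] = {
    'LR': LogisticRegression(),
    'LDA': LinearDiscriminantAnalysis(),
    'KNN': KNeighborsClassifier(),
    'CART': DecisionTreeClassifier(),
    'NB': GaussianNB(),
    'SVM': SVC(),
}
```

Any mapping of name to estimator works in place of the helper, which is useful when you want
to add or drop candidates from the comparison.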
.. GENERATED FROM PYTHON SOURCE LINES 56-61

Plot using the function
-----------------------

The box colors are scaled to provide a relative indication of performance based on the score
(kudos to `Shah Newaz Khan `_).

.. GENERATED FROM PYTHON SOURCE LINES 61-66

.. code-block:: default

    fig = plot_model_selection(estimators=models_to_test, datasets=xy, target='class', pre_processor=pp)
    fig.update_layout(height=600)
    fig

.. interactive Plotly figure rendered here in the HTML build


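Conceptually, a plot like the one above boils down to running k-fold cross-validation per
estimator and box-plotting the fold scores. A minimal, self-contained sketch of that idea,
using a synthetic dataset rather than the pima data and plain scikit-learn rather than the
library's internals:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Stand-in for the pima dataset: 200 samples, 8 features.
X, y = make_classification(n_samples=200, n_features=8, random_state=1234)

fold_scores = {}
for name, model in {'LR': LogisticRegression(), 'NB': GaussianNB()}.items():
    # Same pattern as plot_model_selection: pre-processor + estimator, scored per fold.
    pipe = make_pipeline(StandardScaler(), model)
    fold_scores[name] = cross_val_score(pipe, X, y, cv=5)  # default scorer: accuracy

for name, scores in fold_scores.items():
    print(f"{name}: mean={scores.mean():.3f} +/- {scores.std():.3f}")
```

The per-fold arrays are exactly what feeds each box in the plot; the relative coloring then
only requires scaling the mean scores across estimators.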
.. GENERATED FROM PYTHON SOURCE LINES 67-74

Plot using the object
---------------------

The alternative to using the function is to instantiate a ModelSelection object. This has the
advantage of persisting the data, which provides greater flexibility and faster re-plotting.
If metrics are provided, additional subplots are displayed - however, since metrics have no
concept of "greater-is-good" like a scorer, they are not coloured.

.. GENERATED FROM PYTHON SOURCE LINES 74-82

.. code-block:: default

    ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='class',
                                        pre_processor=pp,
                                        k_folds=30, verbosity=0)
    fig = ms.plot(title='Model Selection', metrics='f1')
    fig.update_layout(height=600)
    # noinspection PyTypeChecker
    plotly.io.show(fig)  # this call to show will set the thumbnail for use in the gallery

.. raw:: html
    :file: images/sphx_glr_model_selection_001.html

.. GENERATED FROM PYTHON SOURCE LINES 83-84

View the data

.. GENERATED FROM PYTHON SOURCE LINES 84-87

.. code-block:: default

    ms.results

.. rst-class:: sphx-glr-script-out

 ..
code-block:: none {'dataset': {'LR': CrossValidationResult(test_scores=array([0.88461538, 0.76923077, 0.73076923, 0.73076923, 0.76923077, 0.84615385, 0.69230769, 0.80769231, 0.92307692, 0.73076923, 0.76923077, 0.80769231, 0.69230769, 0.92307692, 0.73076923, 0.65384615, 0.80769231, 0.73076923, 0.8 , 0.8 , 0.56 , 0.68 , 0.84 , 0.72 , 0.84 , 0.76 , 0.84 , 0.88 , 0.72 , 0.8 ]), train_scores=array([0.77897574, 0.78167116, 0.78301887, 0.78167116, 0.78436658, 0.78167116, 0.78301887, 0.78032345, 0.77628032, 0.78436658, 0.78436658, 0.78167116, 0.78301887, 0.77628032, 0.78167116, 0.7884097 , 0.78167116, 0.77762803, 0.78061911, 0.7833109 , 0.78734859, 0.7846568 , 0.78061911, 0.7846568 , 0.77792732, 0.78196501, 0.77792732, 0.77792732, 0.78600269, 0.78061911]), fit_times=array([0.00261998, 0.00266814, 0.00245738, 0.00247359, 0.00245833, 0.00245118, 0.00250554, 0.00248766, 0.00245142, 0.00247431, 0.00245357, 0.00256872, 0.00251532, 0.00249863, 0.00247025, 0.00249982, 0.0024519 , 0.00243473, 0.00250983, 0.00249267, 0.0024817 , 0.00278544, 0.00246525, 0.00253177, 0.00252366, 0.00251293, 0.00245667, 0.00244617, 0.00246143, 0.0024724 ]), score_times=array([0.00135255, 0.00126386, 0.00125623, 0.00126529, 0.00127745, 0.00124955, 0.00123978, 0.00124335, 0.00123501, 0.0012455 , 0.00125885, 0.00126028, 0.0012393 , 0.00125027, 0.00123453, 0.00125718, 0.00127292, 0.00125909, 0.00124621, 0.00124645, 0.00124764, 0.00126696, 0.00123858, 0.00123858, 0.00124431, 0.00125837, 0.00125289, 0.00127816, 0.00126195, 0.0012424 ]), estimator=[LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), 
LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression()], metrics={'f1': [0.8235294117647058, 0.75, 0.5882352941176471, 0.5333333333333333, 0.5714285714285714, 0.6, 0.6363636363636364, 0.7619047619047619, 0.875, 0.631578947368421, 0.7, 0.8, 0.42857142857142855, 0.875, 0.46153846153846156, 0.3076923076923077, 0.7058823529411765, 0.46153846153846156, 0.7058823529411765, 0.5454545454545454, 0.47619047619047616, 0.2, 0.5, 0.5882352941176471, 0.6666666666666666, 0.5, 0.5, 0.8695652173913043, 0.5882352941176471, 0.7058823529411765]}, metrics_group={}), 'LDA': CrossValidationResult(test_scores=array([0.80769231, 0.84615385, 0.80769231, 0.88461538, 0.61538462, 0.80769231, 0.73076923, 0.61538462, 0.84615385, 0.65384615, 0.84615385, 0.73076923, 0.76923077, 0.80769231, 0.76923077, 0.65384615, 0.73076923, 0.76923077, 0.76 , 0.8 , 0.76 , 0.72 , 0.88 , 0.88 , 0.76 , 0.84 , 0.8 , 0.72 , 0.8 , 0.8 ]), train_scores=array([0.77762803, 0.78032345, 0.78167116, 0.77628032, 0.78975741, 0.78301887, 0.78167116, 0.78436658, 0.77762803, 0.78571429, 0.78301887, 0.78301887, 0.77493261, 0.78167116, 0.78436658, 0.78167116, 0.78032345, 0.78167116, 0.7833109 , 0.7833109 , 0.7833109 , 0.77927322, 0.7833109 , 0.77927322, 0.77927322, 0.77927322, 0.77927322, 0.78196501, 0.78061911, 0.77927322]), fit_times=array([0.00188708, 0.00180793, 0.00179863, 0.00180387, 0.00179124, 0.00178361, 0.00179172, 0.00178409, 0.00177813, 0.00175285, 0.00184703, 0.00185871, 0.00179529, 0.00180244, 0.00181794, 0.00180173, 0.00179625, 0.00178432, 0.00177479, 0.00179648, 0.00176167, 0.00185084, 0.0017817 , 0.00182533, 0.00178695, 0.00178742, 0.00179029, 0.00184774, 0.00179434, 0.00177097]), score_times=array([0.00128984, 0.00123215, 0.00124812, 0.00124788, 0.00131893, 0.0012393 , 0.00125313, 0.00127029, 0.00125337, 0.00125813, 0.00126934, 0.00127673, 0.00125289, 0.00126362, 0.00124788, 
0.00126815, 0.00124836, 0.00125289, 0.0012908 , 0.00126147, 0.00128865, 0.00125313, 0.00126815, 0.00125289, 0.00126553, 0.00126076, 0.00125456, 0.00127029, 0.00123405, 0.00126123]), estimator=[LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis()], metrics={'f1': [0.6153846153846154, 0.75, 0.6666666666666666, 0.8, 0.16666666666666666, 0.7058823529411765, 0.6956521739130435, 0.2857142857142857, 0.8181818181818182, 0.5263157894736842, 0.7142857142857143, 0.631578947368421, 0.7272727272727273, 0.7058823529411765, 0.6666666666666666, 0.47058823529411764, 0.631578947368421, 0.4, 0.4, 0.7058823529411765, 0.25, 0.6956521739130435, 0.6666666666666666, 0.7692307692307693, 0.7272727272727273, 0.75, 0.7368421052631579, 0.5882352941176471, 0.5454545454545454, 0.5454545454545454]}, metrics_group={}), 'KNN': CrossValidationResult(test_scores=array([0.73076923, 0.76923077, 0.88461538, 0.65384615, 0.80769231, 0.61538462, 0.84615385, 0.65384615, 0.96153846, 0.73076923, 0.57692308, 0.69230769, 0.80769231, 0.69230769, 0.84615385, 0.73076923, 0.73076923, 0.80769231, 0.56 , 0.64 , 0.68 , 0.72 , 0.64 , 0.72 , 0.76 , 0.72 , 0.76 , 0.84 , 0.72 , 0.6 ]), 
train_scores=array([0.83153639, 0.83153639, 0.8328841 , 0.82884097, 0.82614555, 0.83153639, 0.82210243, 0.83018868, 0.82884097, 0.83692722, 0.83692722, 0.83018868, 0.82614555, 0.83423181, 0.82614555, 0.82884097, 0.82884097, 0.82614555, 0.82907133, 0.83176312, 0.82368775, 0.83580081, 0.82907133, 0.82772544, 0.83041723, 0.83580081, 0.82772544, 0.82772544, 0.83310902, 0.8371467 ]), fit_times=array([0.00173783, 0.00167918, 0.00174212, 0.00166988, 0.00168896, 0.001688 , 0.00166416, 0.00168133, 0.00220537, 0.00174189, 0.00175905, 0.0016942 , 0.00170112, 0.00169015, 0.0017159 , 0.00168204, 0.00166464, 0.00173283, 0.00167823, 0.00168204, 0.00166988, 0.00164223, 0.00167704, 0.00167942, 0.00167394, 0.00181675, 0.00170994, 0.00171351, 0.00170016, 0.00171971]), score_times=array([0.00317693, 0.00309372, 0.00301814, 0.00300336, 0.00305367, 0.00303578, 0.00301814, 0.00317574, 0.00420833, 0.00310493, 0.00311255, 0.00308704, 0.00307727, 0.00308394, 0.00308323, 0.00308442, 0.00318766, 0.0030694 , 0.00304437, 0.00304627, 0.0030458 , 0.00301719, 0.00301886, 0.00301075, 0.00299764, 0.00310397, 0.0030601 , 0.00305057, 0.00303197, 0.00307846]), estimator=[KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier()], metrics={'f1': [0.36363636363636365, 0.625, 0.7272727272727273, 0.4, 0.7619047619047619, 0.5454545454545454, 
0.7777777777777778, 0.4, 0.9473684210526315, 0.5882352941176471, 0.47619047619047616, 0.5555555555555556, 0.6153846153846154, 0.5555555555555556, 0.7777777777777778, 0.5882352941176471, 0.5333333333333333, 0.6666666666666666, 0.47619047619047616, 0.4, 0.5, 0.46153846153846156, 0.18181818181818182, 0.631578947368421, 0.5, 0.46153846153846156, 0.625, 0.8, 0.631578947368421, 0.4444444444444444]}, metrics_group={}), 'CART': CrossValidationResult(test_scores=array([0.57692308, 0.73076923, 0.73076923, 0.61538462, 0.57692308, 0.73076923, 0.65384615, 0.84615385, 0.65384615, 0.57692308, 0.73076923, 0.46153846, 0.76923077, 0.76923077, 0.61538462, 0.73076923, 0.73076923, 0.65384615, 0.68 , 0.6 , 0.8 , 0.8 , 0.68 , 0.72 , 0.64 , 0.8 , 0.76 , 0.68 , 0.8 , 0.84 ]), train_scores=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), fit_times=array([0.00459576, 0.0042994 , 0.00424004, 0.00427985, 0.00428796, 0.0041132 , 0.00411606, 0.00422359, 0.00420928, 0.00424671, 0.00423884, 0.00448918, 0.00425625, 0.00413632, 0.0041666 , 0.00428391, 0.00411654, 0.00445509, 0.0042913 , 0.00414419, 0.00438619, 0.00430059, 0.00449133, 0.00418448, 0.00416422, 0.00441861, 0.00426173, 0.00422931, 0.00429821, 0.00419974]), score_times=array([0.00130105, 0.00129461, 0.00126123, 0.00127268, 0.00127721, 0.00126052, 0.00126719, 0.00129294, 0.00127292, 0.00126648, 0.00129676, 0.00127292, 0.0012784 , 0.00126815, 0.00127435, 0.0012753 , 0.0012722 , 0.00127053, 0.00128722, 0.00133848, 0.00131106, 0.00136113, 0.00127411, 0.00126147, 0.0012939 , 0.0012784 , 0.00127459, 0.00127983, 0.00126672, 0.00126457]), estimator=[DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), 
DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier()], metrics={'f1': [0.47619047619047616, 0.6666666666666666, 0.5882352941176471, 0.5454545454545454, 0.26666666666666666, 0.5882352941176471, 0.6086956521739131, 0.7142857142857143, 0.4, 0.26666666666666666, 0.631578947368421, 0.36363636363636365, 0.625, 0.25, 0.2857142857142857, 0.631578947368421, 0.5882352941176471, 0.5263157894736842, 0.5, 0.4444444444444444, 0.7058823529411765, 0.7368421052631579, 0.6, 0.36363636363636365, 0.5714285714285714, 0.7058823529411765, 0.6666666666666666, 0.5555555555555556, 0.7619047619047619, 0.7777777777777778]}, metrics_group={}), 'NB': CrossValidationResult(test_scores=array([0.88461538, 0.73076923, 0.57692308, 0.73076923, 0.84615385, 0.88461538, 0.69230769, 0.76923077, 0.80769231, 0.65384615, 0.76923077, 0.65384615, 0.80769231, 0.76923077, 0.76923077, 0.76923077, 0.88461538, 0.57692308, 0.88 , 0.6 , 0.64 , 0.84 , 0.64 , 0.84 , 0.72 , 0.72 , 0.76 , 0.76 , 0.64 , 0.88 ]), train_scores=array([0.75876011, 0.76415094, 0.76684636, 0.76280323, 0.75876011, 0.7574124 , 0.76145553, 0.76280323, 0.76280323, 0.76280323, 0.76280323, 0.76280323, 0.76145553, 0.76549865, 0.76010782, 0.76280323, 0.75876011, 0.77088949, 0.75370121, 0.76716016, 0.76581427, 0.7577389 , 0.76716016, 0.75908479, 0.76177658, 0.76446837, 0.76581427, 0.76177658, 0.76716016, 0.756393 ]), fit_times=array([0.00159216, 0.00149226, 0.00149322, 0.00149155, 0.00150824, 0.00149894, 0.00150681, 0.00147939, 0.00149202, 0.00146031, 0.00146794, 0.00146317, 0.00151467, 0.00148416, 0.00147891, 0.00150561, 0.0014565 , 0.00145364, 
0.0014596 , 0.0014782 , 0.00150704, 0.00148034, 0.00147963, 0.00148511, 0.00147367, 0.00147033, 0.00147223, 0.00153303, 0.00152063, 0.00148678]), score_times=array([0.00130749, 0.00129128, 0.00122547, 0.00126052, 0.00124955, 0.00123858, 0.00122285, 0.00123143, 0.00122261, 0.00123096, 0.00121713, 0.00125051, 0.001261 , 0.00124884, 0.00122499, 0.00123668, 0.00126624, 0.00123262, 0.00124526, 0.00125742, 0.00124526, 0.00124097, 0.00123429, 0.00123572, 0.00123072, 0.00122857, 0.00123549, 0.00124097, 0.0012424 , 0.00124574]), estimator=[GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB()], metrics={'f1': [0.8421052631578947, 0.6666666666666666, 0.47619047619047616, 0.631578947368421, 0.7142857142857143, 0.8, 0.5555555555555556, 0.5, 0.7058823529411765, 0.5263157894736842, 0.7, 0.47058823529411764, 0.7368421052631579, 0.7, 0.5714285714285714, 0.7692307692307693, 0.8421052631578947, 0.42105263157894735, 0.7692307692307693, 0.375, 0.4, 0.75, 0.18181818181818182, 0.6666666666666666, 0.5882352941176471, 0.46153846153846156, 0.75, 0.5714285714285714, 0.47058823529411764, 0.7692307692307693]}, metrics_group={}), 'SVM': CrossValidationResult(test_scores=array([0.65384615, 0.65384615, 0.73076923, 0.65384615, 0.73076923, 0.80769231, 0.80769231, 0.76923077, 0.80769231, 0.69230769, 0.73076923, 0.92307692, 0.84615385, 0.92307692, 0.73076923, 0.57692308, 0.80769231, 0.88461538, 0.72 , 0.64 , 0.8 , 0.8 , 0.76 , 0.84 , 0.88 , 0.68 , 0.68 , 0.8 , 0.8 , 0.84 ]), train_scores=array([0.82614555, 0.82749326, 0.82749326, 0.82614555, 0.83423181, 0.82345013, 0.82345013, 0.82614555, 0.82479784, 0.83153639, 0.83153639, 0.82614555, 
0.82210243, 0.81671159, 0.82345013, 0.82749326, 0.82479784, 0.81940701, 0.83310902, 0.82368775, 0.82637954, 0.82234186, 0.82637954, 0.81830417, 0.81695828, 0.83041723, 0.82637954, 0.82503365, 0.82772544, 0.81965007]), fit_times=array([0.00962901, 0.0095489 , 0.009516 , 0.00968194, 0.0095284 , 0.00953197, 0.0097506 , 0.00934386, 0.00951433, 0.00955391, 0.00937533, 0.00947881, 0.0095849 , 0.00968003, 0.00972986, 0.00953293, 0.0094831 , 0.00957584, 0.00954723, 0.00949812, 0.00965214, 0.00978875, 0.00950575, 0.00960684, 0.00969839, 0.00948334, 0.00967646, 0.00970507, 0.00952339, 0.01302576]), score_times=array([0.00170159, 0.00171876, 0.0016849 , 0.001683 , 0.00168777, 0.00168967, 0.00171828, 0.0016644 , 0.00172305, 0.00168371, 0.00170541, 0.00167155, 0.00167918, 0.00170398, 0.00170851, 0.00169015, 0.00169396, 0.00167537, 0.00167084, 0.00169635, 0.00166512, 0.00167847, 0.00167823, 0.00167227, 0.001683 , 0.00169349, 0.00167465, 0.00167823, 0.00175214, 0.00173664]), estimator=[SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC()], metrics={'f1': [0.47058823529411764, 0.4, 0.5882352941176471, 0.6086956521739131, 0.5882352941176471, 0.5454545454545454, 0.6666666666666666, 0.7272727272727273, 0.6153846153846154, 0.5555555555555556, 0.5882352941176471, 0.8888888888888888, 0.8, 0.875, 0.46153846153846156, 0.15384615384615385, 0.6666666666666666, 0.6666666666666666, 0.5882352941176471, 0.6086956521739131, 0.6666666666666666, 0.6153846153846154, 0.625, 0.8, 0.7692307692307693, 0.6363636363636364, 0.5, 0.6153846153846154, 0.6666666666666666, 0.6666666666666666]}, metrics_group={})}} .. GENERATED FROM PYTHON SOURCE LINES 88-96 Regressor Model Selection ------------------------- Of course we're not limited to classification problems. We will demonstrate a regression problem, with multiple metrics. 
We prepare a `group` variable (a pd.Series) in order to calculate the metrics by group for each
fold. This cross-validation takes a bit longer, so we set n_jobs to -2 to fit in parallel, while
preserving a core to ensure the system can respond.

.. GENERATED FROM PYTHON SOURCE LINES 96-110

.. code-block:: default

    diabetes = load_diabetes(as_frame=True)
    x, y = diabetes.data, diabetes.target
    y.name = "progression"
    xy: pd.DataFrame = pd.concat([x, y], axis=1)
    group: pd.Series = pd.Series(x['sex'] > 0, name='grp_sex', index=x.index)

    pp: Pipeline = make_pipeline(StandardScaler())
    models_to_test: Dict = Models().fast_regressors()

    ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='progression',
                                        pre_processor=pp,
                                        k_folds=30, scorer='r2', group=group,
                                        metrics={'moe': metrics.moe_95, 'me': metrics.mean_error},
                                        n_jobs=-2, verbosity=2)

.. GENERATED FROM PYTHON SOURCE LINES 111-112

Next we'll view the plot, but we will not (yet) leverage the group variable.

.. GENERATED FROM PYTHON SOURCE LINES 112-117

.. code-block:: default

    fig = ms.plot(metrics=['moe', 'me'])
    fig.update_layout(height=700)
    fig

.. interactive Plotly figure rendered here in the HTML build


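The two metrics requested above are not part of scikit-learn. Assuming `mean_error` is the
signed mean of the residuals and `moe_95` is the 95% half-width of the residual distribution
(an assumption — check the library source for the exact definitions), they can be sketched as:

```python
import numpy as np

def mean_error(y_true, y_pred) -> float:
    # Signed bias: positive means the model over-predicts on average.
    return float(np.mean(np.asarray(y_pred) - np.asarray(y_true)))

def moe_95(y_true, y_pred) -> float:
    # Hypothetical definition: 1.96 * standard deviation of the residuals,
    # i.e. the +/- band expected to contain ~95% of errors if they are normal.
    residuals = np.asarray(y_pred) - np.asarray(y_true)
    return float(1.96 * np.std(residuals))

y_true = np.array([10.0, 12.0, 14.0, 16.0])
y_pred = np.array([11.0, 12.0, 13.0, 17.0])
print(mean_error(y_true, y_pred))  # 0.25
```

Unlike a scorer, neither metric implies "greater is better" — a mean error of zero with a
narrow margin is the desirable outcome — which is why these subplots are not coloured.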
.. GENERATED FROM PYTHON SOURCE LINES 118-122

Now, we will re-plot using the group. This is fast, since the fitting metrics were calculated
when the first plot was created, and do not need to be calculated again. Plotting by group can
(hopefully) provide evidence that metrics are consistent across groups.

.. GENERATED FROM PYTHON SOURCE LINES 122-127

.. code-block:: default

    fig = ms.plot(metrics=['moe', 'me'], show_group=True, col_wrap=2)
    fig.update_layout(height=700)
    fig

.. interactive Plotly figure rendered here in the HTML build


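Computing a metric per group within a single validation fold is a one-liner with a pandas
groupby. A sketch of the idea, with hypothetical column names and values standing in for one
fold's predictions:

```python
import pandas as pd

# One validation fold's actuals and predictions, with the boolean group variable.
fold = pd.DataFrame({
    'grp_sex': [True, True, False, False],
    'y_true': [100.0, 150.0, 120.0, 180.0],
    'y_pred': [110.0, 140.0, 125.0, 170.0],
})

# Signed mean error within each group of the fold.
me_by_group = (fold['y_pred'] - fold['y_true']).groupby(fold['grp_sex']).mean()
print(me_by_group.to_dict())  # {False: -2.5, True: 0.0}
```

Repeating this over the 30 folds yields one distribution per group, which is what the
grouped subplots above are drawing.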
.. GENERATED FROM PYTHON SOURCE LINES 128-129

Clearly, plot real estate will become a problem for more than 2 or 3 classes - here we used
col_wrap to mitigate that.

.. GENERATED FROM PYTHON SOURCE LINES 131-136

Comparing Datasets
------------------

Next we will demonstrate a single algorithm with multiple datasets. This is useful when
exploring features that improve model performance. We modify DS2 by removing a feature and
sampling 40% of the data.

.. GENERATED FROM PYTHON SOURCE LINES 136-143

.. code-block:: default

    datasets: Dict = {'DS1': xy, 'DS2': xy.drop(columns=['age']).sample(frac=0.4)}
    fig = plot_model_selection(estimators=LinearRegression(), datasets=datasets, target='progression',
                               pre_processor=pp, k_folds=30)
    fig.update_layout(height=600)
    fig

.. interactive Plotly figure rendered here in the HTML build


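The dataset-variant pattern above generalizes: each entry in the dict is just a DataFrame, so
candidate feature sets or sub-samples can be constructed with ordinary pandas. A small
reproducible sketch, with a synthetic frame standing in for ``xy`` and a ``random_state``
added so the 40% sample is repeatable:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1234)
xy = pd.DataFrame(rng.normal(size=(100, 4)),
                  columns=['age', 'bmi', 'bp', 'progression'])

datasets = {
    'DS1': xy,
    # Drop one feature and keep a repeatable 40% sub-sample.
    'DS2': xy.drop(columns=['age']).sample(frac=0.4, random_state=1234),
}
print(len(datasets['DS2']), list(datasets['DS2'].columns))
```

Because both variants are cross-validated with the same estimator, any gap between the DS1
and DS2 boxes reflects the data change (fewer rows, one fewer feature) rather than the model.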
.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 9.120 seconds)

.. _sphx_glr_download_auto_examples_model_selection.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: model_selection.py <model_selection.py>`

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: model_selection.ipynb <model_selection.ipynb>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_