Model Selection

This example demonstrates a model selection plot incorporating cross-validation and test error.

Code has been adapted from the machinelearningmastery example.
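The plots below summarise per-fold scores from k-fold cross-validation. The numpy sketch below runs k-fold cross-validation by hand. It is an illustration only, not the library's implementation, and it assumes three things:

- a synthetic regression dataset,
- an ordinary least-squares model fitted with `np.linalg.lstsq`,
- R² as the score.

It records a train and a test score for every fold, much like the `train_scores` and `test_scores` stored in the results further down.

```python
import numpy as np


def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def kfold_scores(x: np.ndarray, y: np.ndarray, k: int = 5, seed: int = 0):
    """Return (train_scores, test_scores) from k-fold CV of a least-squares fit."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(y))               # shuffle once, then split
    folds = np.array_split(indices, k)              # k nearly equal folds
    design = np.column_stack([np.ones(len(y)), x])  # add an intercept column
    train_scores, test_scores = [], []
    for i, test_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        # Fit on the training folds only, then score on both train and held-out rows
        coef, *_ = np.linalg.lstsq(design[train_idx], y[train_idx], rcond=None)
        train_scores.append(r2_score(y[train_idx], design[train_idx] @ coef))
        test_scores.append(r2_score(y[test_idx], design[test_idx] @ coef))
    return np.array(train_scores), np.array(test_scores)


# Synthetic data: a linear signal plus a little noise (illustrative only)
rng = np.random.default_rng(42)
x = rng.normal(size=(200, 3))
y = x @ np.array([1.5, -2.0, 0.5]) + 0.1 * rng.normal(size=200)

train_scores, test_scores = kfold_scores(x, y, k=5)
print(train_scores.round(3), test_scores.round(3))
```

The spread of the test scores across folds is what a box in the model selection plot displays for each model.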

import logging
from typing import Dict

import numpy as np
import pandas
import pandas as pd
import plotly
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler

from elphick.sklearn_viz.model_selection import ModelSelection, plot_model_selection, metrics
from elphick.sklearn_viz.model_selection.models import Models

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
                    datefmt='%Y-%m-%dT%H:%M:%S%z')

Load Data

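Once the data is loaded, we separate the features (x) from the binary target (y, the class column) of this classification problem, then recombine them into a single DataFrame, xy. No separate train-test split is created here. Instead, the data is split into folds during cross-validation, with the target column identified by name.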

url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pandas.read_csv(url, names=names)
array = dataframe.values
x = pd.DataFrame(array[:, 0:8], columns=names[0:8])
y = pd.Series(array[:, 8], name=names[8])
xy: pd.DataFrame = pd.concat([x, y], axis=1)

Instantiate

Create an optional pre-processor as a sklearn Pipeline, and retrieve a dictionary of fast classifiers to compare.
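StandardScaler rescales each feature to zero mean and unit variance, using statistics learned during `fit`. The numpy sketch below mirrors that transformation. The helper names are hypothetical, and it is not scikit-learn's code. Like StandardScaler, it uses the population standard deviation (ddof=0) and leaves constant features unscaled.

In general, a pre-processor should learn its statistics from the training rows only and reuse them on held-out rows. Otherwise information leaks from the test fold, and a Pipeline is a convenient way to keep those two steps together.

```python
import numpy as np


def fit_standardiser(x_train: np.ndarray):
    """Learn per-feature mean and population std (ddof=0) from training rows."""
    mean = x_train.mean(axis=0)
    scale = x_train.std(axis=0)   # ddof=0, the population standard deviation
    scale[scale == 0] = 1.0       # leave constant features unscaled
    return mean, scale


def transform(x: np.ndarray, mean: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Apply previously learned statistics to any rows."""
    return (x - mean) / scale


rng = np.random.default_rng(0)
x_train = rng.normal(loc=50.0, scale=10.0, size=(100, 2))
x_test = rng.normal(loc=50.0, scale=10.0, size=(20, 2))

mean, scale = fit_standardiser(x_train)   # statistics from the training rows only
z_train = transform(x_train, mean, scale)
z_test = transform(x_test, mean, scale)   # test rows reuse the training statistics
print(z_train.mean(axis=0).round(6), z_train.std(axis=0).round(6))
```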

np.random.seed(1234)
pp: Pipeline = make_pipeline(StandardScaler())
models_to_test: Dict = Models().fast_classifiers()
pp
Pipeline(steps=[('standardscaler', StandardScaler())])


Plot using the function

The box colors are scaled to provide a relative indication of performance based on the score (kudos to Shah Newaz Khan).
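The exact colour scaling used by the library is not shown on this page. The sketch below is purely an assumption about one plausible approach, and the scores in it are made up. It min-max normalises each model's median score to the range 0 to 1, so the best model sits at the top of a colour scale and the worst at the bottom.

```python
import numpy as np

# Hypothetical per-fold scores for three models (illustrative values only)
scores = {
    "LR": np.array([0.78, 0.74, 0.80, 0.77]),
    "KNN": np.array([0.70, 0.72, 0.69, 0.75]),
    "CART": np.array([0.66, 0.70, 0.64, 0.68]),
}

medians = {name: float(np.median(s)) for name, s in scores.items()}
lo, hi = min(medians.values()), max(medians.values())
span = hi - lo if hi > lo else 1.0   # avoid division by zero when all medians are equal

# 0.0 = worst median score, 1.0 = best; a colour scale could be sampled at these points
relative = {name: (m - lo) / span for name, m in medians.items()}
print(relative)
```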

fig = plot_model_selection(estimators=models_to_test, datasets=xy, target='class', pre_processor=pp)
fig.update_layout(height=600)
fig


Plot using the object

The alternative to using the function is to instantiate a ModelSelection object. This has the advantage of persisting the data, which provides greater flexibility and faster re-plotting. If metrics are provided, they are shown in additional subplots. However, since metrics have no concept of “greater-is-better” like a scorer, those subplots are not coloured.
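The `metrics='f1'` argument below adds an F1 subplot. For a binary target, F1 is the harmonic mean of precision and recall for the positive class. The sketch below computes it from scratch for one fold's predictions, using made-up labels. For binary labels it should agree with scikit-learn's `f1_score`.

```python
def f1_binary(y_true, y_pred, positive=1):
    """F1 = 2 * precision * recall / (precision + recall) for the positive class."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0  # precision or recall is zero (or undefined), so F1 is taken as 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]
print(f1_binary(y_true, y_pred))  # → 0.75 (3 true positives, 1 false positive, 1 false negative)
```

Whether a larger or smaller value of an arbitrary metric is preferable is not encoded in a plain function like this one. That is why metric subplots are not coloured.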

ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='class', pre_processor=pp,
                                    k_folds=30, verbosity=0)
fig = ms.plot(title='Model Selection', metrics='f1')
fig.update_layout(height=600)
# noinspection PyTypeChecker
plotly.io.show(fig)  # this call to show will set the thumbnail for use in the gallery

View the data

ms.results
{'dataset': {'LR': CrossValidationResult(test_scores=array([0.88461538, 0.76923077, 0.73076923, 0.73076923, 0.76923077,
       0.84615385, 0.69230769, 0.80769231, 0.92307692, 0.73076923,
       0.76923077, 0.80769231, 0.69230769, 0.92307692, 0.73076923,
       0.65384615, 0.80769231, 0.73076923, 0.8       , 0.8       ,
       0.56      , 0.68      , 0.84      , 0.72      , 0.84      ,
       0.76      , 0.84      , 0.88      , 0.72      , 0.8       ]), train_scores=array([0.77897574, 0.78167116, 0.78301887, 0.78167116, 0.78436658,
       0.78167116, 0.78301887, 0.78032345, 0.77628032, 0.78436658,
       0.78436658, 0.78167116, 0.78301887, 0.77628032, 0.78167116,
       0.7884097 , 0.78167116, 0.77762803, 0.78061911, 0.7833109 ,
       0.78734859, 0.7846568 , 0.78061911, 0.7846568 , 0.77792732,
       0.78196501, 0.77792732, 0.77792732, 0.78600269, 0.78061911]), fit_times=array([0.00288987, 0.00273037, 0.00272465, 0.00265408, 0.00273561,
       0.00267959, 0.00281501, 0.00265908, 0.00266528, 0.00267959,
       0.00266981, 0.0027225 , 0.00269055, 0.00268173, 0.00277352,
       0.00266933, 0.00280619, 0.00263047, 0.00266218, 0.00267982,
       0.00265217, 0.00270677, 0.00266981, 0.00262809, 0.00269747,
       0.00266743, 0.00282025, 0.00265098, 0.00268483, 0.00268197]), score_times=array([0.00119638, 0.00114226, 0.0011332 , 0.00116014, 0.00113368,
       0.00114918, 0.00116491, 0.00113392, 0.00115108, 0.00113773,
       0.00119519, 0.00113153, 0.00114226, 0.00115395, 0.0011549 ,
       0.00113964, 0.00114393, 0.00117874, 0.00112891, 0.00114202,
       0.00116014, 0.00125027, 0.00112772, 0.00114655, 0.00114727,
       0.0011344 , 0.0011518 , 0.00116515, 0.00114036, 0.00113392]), estimator=[LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression()], metrics={'f1': [0.8235294117647058, 0.75, 0.588235294117647, 0.5333333333333333, 0.5714285714285715, 0.6, 0.6363636363636365, 0.761904761904762, 0.8750000000000001, 0.631578947368421, 0.7, 0.8, 0.42857142857142855, 0.8750000000000001, 0.4615384615384615, 0.3076923076923077, 0.7058823529411764, 0.4615384615384615, 0.7058823529411765, 0.5454545454545454, 0.4761904761904762, 0.2, 0.5, 0.5882352941176471, 0.6666666666666666, 0.5, 0.5, 0.8695652173913044, 0.5882352941176471, 0.7058823529411764]}, metrics_group={}), 'LDA': CrossValidationResult(test_scores=array([0.80769231, 0.84615385, 0.80769231, 0.88461538, 0.61538462,
       0.80769231, 0.73076923, 0.61538462, 0.84615385, 0.65384615,
       0.84615385, 0.73076923, 0.76923077, 0.80769231, 0.76923077,
       0.65384615, 0.73076923, 0.76923077, 0.76      , 0.8       ,
       0.76      , 0.72      , 0.88      , 0.88      , 0.76      ,
       0.84      , 0.8       , 0.72      , 0.8       , 0.8       ]), train_scores=array([0.77762803, 0.78032345, 0.78167116, 0.77628032, 0.78975741,
       0.78301887, 0.78167116, 0.78436658, 0.77762803, 0.78571429,
       0.78301887, 0.78301887, 0.77493261, 0.78167116, 0.78436658,
       0.78167116, 0.78032345, 0.78167116, 0.7833109 , 0.7833109 ,
       0.7833109 , 0.77927322, 0.7833109 , 0.77927322, 0.77927322,
       0.77927322, 0.77927322, 0.78196501, 0.78061911, 0.77927322]), fit_times=array([0.00196075, 0.00183129, 0.00185728, 0.00187707, 0.00184894,
       0.00186849, 0.00189543, 0.00187898, 0.00185585, 0.0019474 ,
       0.00182939, 0.00185013, 0.00184226, 0.00182295, 0.00182986,
       0.00181532, 0.00182581, 0.0018208 , 0.00187159, 0.00184226,
       0.00183773, 0.00186396, 0.00185394, 0.00186348, 0.00186324,
       0.00188732, 0.00183773, 0.00184321, 0.00181603, 0.00181913]), score_times=array([0.00118876, 0.00117207, 0.0011518 , 0.00112724, 0.00114727,
       0.00112963, 0.00113606, 0.00114059, 0.00112605, 0.00112891,
       0.00113511, 0.00114107, 0.00111794, 0.0011363 , 0.0011363 ,
       0.00114536, 0.00114894, 0.00114703, 0.00113273, 0.00112677,
       0.00114417, 0.00112343, 0.00111103, 0.00111222, 0.00112486,
       0.00112963, 0.00112486, 0.00113606, 0.00112271, 0.00116634]), estimator=[LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis()], metrics={'f1': [0.6153846153846154, 0.75, 0.6666666666666666, 0.7999999999999999, 0.16666666666666666, 0.7058823529411764, 0.6956521739130435, 0.2857142857142857, 0.8181818181818181, 0.5263157894736842, 0.7142857142857143, 0.631578947368421, 0.7272727272727274, 0.7058823529411765, 0.6666666666666665, 0.47058823529411764, 0.631578947368421, 0.4, 0.4, 0.7058823529411764, 0.25, 0.6956521739130435, 0.6666666666666665, 0.7692307692307692, 0.7272727272727274, 0.7499999999999999, 0.7368421052631579, 0.588235294117647, 0.5454545454545454, 0.5454545454545454]}, metrics_group={}), 'KNN': CrossValidationResult(test_scores=array([0.73076923, 0.76923077, 0.88461538, 0.65384615, 0.80769231,
       0.61538462, 0.84615385, 0.65384615, 0.96153846, 0.73076923,
       0.57692308, 0.69230769, 0.80769231, 0.69230769, 0.84615385,
       0.73076923, 0.73076923, 0.80769231, 0.56      , 0.64      ,
       0.68      , 0.72      , 0.64      , 0.72      , 0.76      ,
       0.72      , 0.76      , 0.84      , 0.72      , 0.6       ]), train_scores=array([0.83153639, 0.83153639, 0.8328841 , 0.82884097, 0.82614555,
       0.83153639, 0.82210243, 0.83018868, 0.82884097, 0.83692722,
       0.83692722, 0.83018868, 0.82614555, 0.83423181, 0.82614555,
       0.82884097, 0.82884097, 0.82614555, 0.82907133, 0.83176312,
       0.82368775, 0.83580081, 0.82907133, 0.82772544, 0.83041723,
       0.83580081, 0.82772544, 0.82772544, 0.83310902, 0.8371467 ]), fit_times=array([0.00178933, 0.00174022, 0.00171304, 0.00213313, 0.00195336,
       0.0016911 , 0.00169492, 0.00170159, 0.00169897, 0.00167179,
       0.00166631, 0.00167513, 0.00166655, 0.00167441, 0.00197434,
       0.00178623, 0.00172162, 0.00171494, 0.00173044, 0.00171208,
       0.0017705 , 0.00172067, 0.00178027, 0.00173545, 0.00176334,
       0.00167632, 0.00167942, 0.00169587, 0.00167918, 0.00175929]), score_times=array([0.00221109, 0.00219941, 0.00200772, 0.00231314, 0.00220656,
       0.00203824, 0.00205851, 0.00206113, 0.00210524, 0.00203753,
       0.00246358, 0.00203824, 0.00201201, 0.00202656, 0.00283217,
       0.00208712, 0.00201631, 0.00200129, 0.00200844, 0.00211596,
       0.0020051 , 0.00202155, 0.00207305, 0.00205398, 0.00200486,
       0.00201297, 0.00203729, 0.00200319, 0.00215507, 0.00209975]), estimator=[KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier()], metrics={'f1': [0.36363636363636365, 0.625, 0.7272727272727272, 0.4, 0.761904761904762, 0.5454545454545455, 0.7777777777777777, 0.4, 0.9473684210526316, 0.588235294117647, 0.47619047619047616, 0.5555555555555556, 0.6153846153846154, 0.5555555555555556, 0.7777777777777778, 0.5882352941176471, 0.5333333333333333, 0.6666666666666666, 0.4761904761904762, 0.39999999999999997, 0.5, 0.4615384615384615, 0.18181818181818182, 0.631578947368421, 0.5, 0.4615384615384615, 0.625, 0.8, 0.631578947368421, 0.4444444444444444]}, metrics_group={}), 'CART': CrossValidationResult(test_scores=array([0.57692308, 0.73076923, 0.69230769, 0.65384615, 0.57692308,
       0.69230769, 0.69230769, 0.80769231, 0.65384615, 0.57692308,
       0.73076923, 0.42307692, 0.76923077, 0.80769231, 0.65384615,
       0.76923077, 0.73076923, 0.61538462, 0.68      , 0.6       ,
       0.76      , 0.76      , 0.68      , 0.68      , 0.6       ,
       0.72      , 0.76      , 0.64      , 0.8       , 0.8       ]), train_scores=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), fit_times=array([0.00441813, 0.00399709, 0.00403643, 0.00407481, 0.00410199,
       0.00396419, 0.00396347, 0.00402927, 0.00399613, 0.0040195 ,
       0.00399184, 0.00425124, 0.00403762, 0.00398684, 0.00396228,
       0.00411391, 0.0039835 , 0.00423288, 0.00410914, 0.00395465,
       0.00405264, 0.00403929, 0.00426602, 0.00398064, 0.00393915,
       0.00473619, 0.00471878, 0.00405717, 0.00411105, 0.00398374]), score_times=array([0.00120997, 0.00116444, 0.00120974, 0.00115657, 0.00115395,
       0.00119948, 0.00118065, 0.00115466, 0.00117445, 0.00115323,
       0.00115204, 0.00115633, 0.00117946, 0.00118184, 0.00116444,
       0.00116897, 0.00120068, 0.00116992, 0.00117493, 0.00116038,
       0.00116706, 0.0011487 , 0.00117469, 0.00115657, 0.00117016,
       0.00159001, 0.00163865, 0.00118852, 0.00115705, 0.00117183]), estimator=[DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier()], metrics={'f1': [0.47619047619047616, 0.631578947368421, 0.5555555555555556, 0.5714285714285713, 0.26666666666666666, 0.5555555555555556, 0.6666666666666667, 0.6666666666666666, 0.47058823529411764, 0.26666666666666666, 0.631578947368421, 0.34782608695652173, 0.625, 0.4444444444444444, 0.30769230769230765, 0.7000000000000001, 0.588235294117647, 0.5, 0.5, 0.4444444444444444, 0.625, 0.6666666666666665, 0.6, 0.3333333333333333, 0.5833333333333334, 0.5882352941176471, 0.6666666666666665, 0.47058823529411764, 0.761904761904762, 0.7368421052631579]}, metrics_group={}), 'NB': CrossValidationResult(test_scores=array([0.88461538, 0.73076923, 0.57692308, 0.73076923, 0.84615385,
       0.88461538, 0.69230769, 0.76923077, 0.80769231, 0.65384615,
       0.76923077, 0.65384615, 0.80769231, 0.76923077, 0.76923077,
       0.76923077, 0.88461538, 0.57692308, 0.88      , 0.6       ,
       0.64      , 0.84      , 0.64      , 0.84      , 0.72      ,
       0.72      , 0.76      , 0.76      , 0.64      , 0.88      ]), train_scores=array([0.75876011, 0.76415094, 0.76684636, 0.76280323, 0.75876011,
       0.7574124 , 0.76145553, 0.76280323, 0.76280323, 0.76280323,
       0.76280323, 0.76280323, 0.76145553, 0.76549865, 0.76010782,
       0.76280323, 0.75876011, 0.77088949, 0.75370121, 0.76716016,
       0.76581427, 0.7577389 , 0.76716016, 0.75908479, 0.76177658,
       0.76446837, 0.76581427, 0.76177658, 0.76716016, 0.756393  ]), fit_times=array([0.00178385, 0.0015738 , 0.00152946, 0.00158191, 0.00155091,
       0.0015409 , 0.00154948, 0.00155616, 0.00153613, 0.00153136,
       0.00153422, 0.00156784, 0.00157237, 0.00154543, 0.00153446,
       0.00155282, 0.00155568, 0.00153184, 0.00156522, 0.00152779,
       0.00152659, 0.00156522, 0.00152469, 0.00152421, 0.00153923,
       0.00152087, 0.00154853, 0.00157332, 0.00151706, 0.00151658]), score_times=array([0.00120425, 0.00114584, 0.00114202, 0.00114799, 0.00115514,
       0.0011363 , 0.00115585, 0.00115824, 0.00115013, 0.00114322,
       0.00113606, 0.00113964, 0.00113893, 0.0011518 , 0.0011549 ,
       0.00112987, 0.00114727, 0.0011375 , 0.00114989, 0.00114965,
       0.00112486, 0.00113821, 0.00113964, 0.00113082, 0.00113463,
       0.00112915, 0.00113726, 0.00112677, 0.00113249, 0.00113273]), estimator=[GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB()], metrics={'f1': [0.8421052631578948, 0.6666666666666667, 0.47619047619047616, 0.631578947368421, 0.7142857142857143, 0.7999999999999999, 0.5555555555555556, 0.5, 0.7058823529411764, 0.5263157894736842, 0.7, 0.47058823529411764, 0.7368421052631577, 0.7, 0.5714285714285714, 0.7692307692307693, 0.8421052631578948, 0.4210526315789473, 0.7692307692307693, 0.37499999999999994, 0.39999999999999997, 0.75, 0.18181818181818182, 0.6666666666666666, 0.588235294117647, 0.4615384615384615, 0.75, 0.5714285714285715, 0.47058823529411764, 0.7692307692307692]}, metrics_group={}), 'SVM': CrossValidationResult(test_scores=array([0.65384615, 0.65384615, 0.73076923, 0.65384615, 0.73076923,
       0.80769231, 0.80769231, 0.76923077, 0.80769231, 0.69230769,
       0.73076923, 0.92307692, 0.84615385, 0.92307692, 0.73076923,
       0.57692308, 0.80769231, 0.88461538, 0.72      , 0.64      ,
       0.8       , 0.8       , 0.76      , 0.84      , 0.88      ,
       0.68      , 0.68      , 0.8       , 0.8       , 0.84      ]), train_scores=array([0.82614555, 0.82749326, 0.82749326, 0.82614555, 0.83423181,
       0.82345013, 0.82345013, 0.82614555, 0.82479784, 0.83153639,
       0.83153639, 0.82614555, 0.82210243, 0.81671159, 0.82345013,
       0.82749326, 0.82479784, 0.81940701, 0.83310902, 0.82368775,
       0.82637954, 0.82234186, 0.82637954, 0.81830417, 0.81695828,
       0.83041723, 0.82637954, 0.82503365, 0.82772544, 0.81965007]), fit_times=array([0.01058912, 0.01053905, 0.01046419, 0.01059031, 0.01051593,
       0.01052547, 0.01087952, 0.01040816, 0.01073122, 0.01215649,
       0.01337957, 0.01146436, 0.0108254 , 0.01063704, 0.01059008,
       0.01044226, 0.01053882, 0.01063919, 0.01053476, 0.01036549,
       0.01067448, 0.01087451, 0.01044321, 0.01072121, 0.0107317 ,
       0.0105238 , 0.01067805, 0.01065302, 0.01051974, 0.01071835]), score_times=array([0.00163293, 0.00171232, 0.00158191, 0.00157595, 0.0015862 ,
       0.0016005 , 0.00159645, 0.00170302, 0.00162745, 0.00240469,
       0.00212526, 0.00184631, 0.0016253 , 0.00160146, 0.00161529,
       0.00162363, 0.00157905, 0.0016222 , 0.00159883, 0.00156951,
       0.0015769 , 0.00157881, 0.00157785, 0.00158381, 0.00159788,
       0.00160217, 0.00158501, 0.00168943, 0.00158262, 0.00156283]), estimator=[SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC()], metrics={'f1': [0.47058823529411764, 0.39999999999999997, 0.5882352941176471, 0.6086956521739131, 0.5882352941176471, 0.5454545454545454, 0.6666666666666666, 0.7272727272727272, 0.6153846153846153, 0.5555555555555556, 0.588235294117647, 0.888888888888889, 0.8, 0.8750000000000001, 0.4615384615384615, 0.15384615384615385, 0.6666666666666666, 0.6666666666666665, 0.588235294117647, 0.608695652173913, 0.6666666666666666, 0.6153846153846153, 0.6250000000000001, 0.8000000000000002, 0.7692307692307693, 0.6363636363636364, 0.5, 0.6153846153846153, 0.6666666666666667, 0.6666666666666666]}, metrics_group={})}}
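Each model has 30 test scores, one per fold. The Pima dataset has 768 rows, and 768 is not divisible by 30, so the folds cannot all be the same size. The recorded accuracies fit scikit-learn's KFold convention, in which the first `n_samples % n_splits` folds get one extra row:

- The first 18 test scores are multiples of 1/26 (for example 0.88461538 = 23/26).
- The last 12 are multiples of 1/25 (for example 0.56 = 14/25).

The sketch below reproduces that arithmetic using the LR test scores printed above. It assumes, rather than confirms, that the library splits folds this way.

```python
import numpy as np

n_samples, n_splits = 768, 30  # Pima rows and the k_folds used above

# KFold-style fold sizes: each fold gets n // k rows and the first n % k folds get one more
fold_sizes = np.full(n_splits, n_samples // n_splits)
fold_sizes[: n_samples % n_splits] += 1
train_sizes = n_samples - fold_sizes
print(fold_sizes[:18].tolist(), fold_sizes[18:].tolist())

# LR test accuracies copied from ms.results above
lr_test = np.array([
    0.88461538, 0.76923077, 0.73076923, 0.73076923, 0.76923077,
    0.84615385, 0.69230769, 0.80769231, 0.92307692, 0.73076923,
    0.76923077, 0.80769231, 0.69230769, 0.92307692, 0.73076923,
    0.65384615, 0.80769231, 0.73076923, 0.8, 0.8,
    0.56, 0.68, 0.84, 0.72, 0.84,
    0.76, 0.84, 0.88, 0.72, 0.8,
])

# Accuracy on a fold of m rows is (correct predictions) / m, so score * m should be whole
correct = lr_test * fold_sizes
print(np.round(correct).astype(int))
print(np.allclose(correct, np.round(correct), atol=1e-6))
```

Likewise, every training set has 742 or 743 rows. For example, the first LR train score, 0.77897574, is 578/742.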

Regressor Model Selection

Of course, we are not limited to classification problems. Here we demonstrate a regression problem with multiple metrics. We also prepare a group variable (a pd.Series) so that the metrics can be calculated by group for each fold.
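Calculating a metric by group means evaluating the same metric separately on the rows of each group level within a fold. The numpy sketch below shows this for a boolean group. The `mean_error` function is a hypothetical stand-in with an assumed sign convention (prediction minus observation). It is not necessarily how `metrics.mean_error` is defined, and the values are illustrative.

```python
import numpy as np


def mean_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Hypothetical metric: average of (prediction - observation), i.e. the bias."""
    return float(np.mean(y_pred - y_true))


def metric_by_group(metric, y_true, y_pred, group):
    """Evaluate `metric` on the rows belonging to each distinct group value."""
    return {g.item(): metric(y_true[group == g], y_pred[group == g]) for g in np.unique(group)}


# One fold's observations, predictions and a boolean group (illustrative values)
y_true = np.array([10.0, 12.0, 9.0, 11.0, 14.0, 13.0])
y_pred = np.array([11.0, 12.5, 9.5, 10.0, 13.0, 12.0])
group = np.array([True, True, True, False, False, False])

print(mean_error(y_true, y_pred))                          # overall bias
print(metric_by_group(mean_error, y_true, y_pred, group))  # bias per group
```

Here the overall bias of about -0.17 hides a positive bias in one group (about +0.67) and a negative bias in the other (-1.0). A by-group plot is meant to expose exactly this kind of difference.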

This cross-validation takes a little longer, so we set n_jobs to -2. This fits in parallel on all but one CPU core, which keeps the system responsive.
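In scikit-learn and joblib, a negative `n_jobs` counts back from the number of CPUs: `n_cpus + 1 + n_jobs` workers are used. So -1 means all CPUs and -2 means all but one. The helper below, whose name is hypothetical, illustrates that arithmetic. It assumes `n_jobs` is passed through to scikit-learn or joblib, as the parameter name suggests.

```python
import os


def workers_for(n_jobs: int, n_cpus: int) -> int:
    """Number of parallel workers implied by a joblib-style n_jobs value."""
    if n_jobs == 0:
        raise ValueError("n_jobs == 0 is not meaningful")
    if n_jobs < 0:
        return max(n_cpus + 1 + n_jobs, 1)  # -1 -> all CPUs, -2 -> all but one
    return n_jobs


print(workers_for(-2, os.cpu_count() or 1))
```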

diabetes = load_diabetes(as_frame=True)
x, y = diabetes.data, diabetes.target
y.name = "progression"
xy: pd.DataFrame = pd.concat([x, y], axis=1)
group: pd.Series = pd.Series(x['sex'] > 0, name='grp_sex', index=x.index)

pp: Pipeline = make_pipeline(StandardScaler())
models_to_test: Dict = Models().fast_regressors()

ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='progression', pre_processor=pp,
                                    k_folds=30, scorer='r2', group=group,
                                    metrics={'moe': metrics.moe_95, 'me': metrics.mean_error},
                                    n_jobs=-2, verbosity=2)

Next we’ll view the plot, but we will not (yet) leverage the group variable.

fig = ms.plot(metrics=['moe', 'me'])
fig.update_layout(height=700)
fig


Now, we will re-plot using the group. This is fast because the models were fitted and the metrics calculated when the first plot was created, so they do not need to be calculated again.

Plotting by group can (hopefully) provide evidence that metrics are consistent across groups.
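Re-plotting is fast because the expensive results are kept on the object rather than recomputed. The sketch below shows that general pattern with `functools.cached_property`. The `CachedSelection` class and its methods are hypothetical, and nothing is assumed about how ModelSelection stores its results internally.

```python
from functools import cached_property


class CachedSelection:
    """Hypothetical object that computes results once and reuses them for plots."""

    def __init__(self, data):
        self.data = data
        self.n_computations = 0  # counts how often the expensive step runs

    @cached_property
    def results(self):
        # Stand-in for cross-validation: runs only on first access, then is cached
        self.n_computations += 1
        return {"mean": sum(self.data) / len(self.data)}

    def plot(self, show_group: bool = False) -> str:
        # Every plot call reads the cached results instead of recomputing them
        label = "by group" if show_group else "overall"
        return f"plot {label}: mean={self.results['mean']:.2f}"


sel = CachedSelection([1.0, 2.0, 3.0])
print(sel.plot())
print(sel.plot(show_group=True))
print(sel.n_computations)  # → 1
```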

fig = ms.plot(metrics=['moe', 'me'], show_group=True, col_wrap=2)
fig.update_layout(height=700)
fig


Clearly, plot real estate will become a problem for more than 2 or 3 groups; here we used col_wrap to mitigate that.

Comparing Datasets

Next we will demonstrate a single algorithm with multiple datasets. This is useful when exploring which features improve model performance. We create DS2 from DS1 by removing a feature (age) and sampling 40% of the rows.
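The DS2 transformation is plain pandas: drop a column, then keep a random 40% of the rows. The sketch below applies it to a small synthetic frame whose column names are illustrative. Passing `random_state` to `DataFrame.sample` makes the sampled subset reproducible.

```python
import numpy as np
import pandas as pd

# Small synthetic stand-in for the diabetes data (illustrative columns and values)
rng = np.random.default_rng(0)
ds1 = pd.DataFrame({
    "age": rng.normal(size=100),
    "bmi": rng.normal(size=100),
    "progression": rng.normal(size=100),
})

# DS2: one feature removed and 40% of the rows sampled without replacement
ds2 = ds1.drop(columns=["age"]).sample(frac=0.4, random_state=42)
datasets = {"DS1": ds1, "DS2": ds2}

print({name: df.shape for name, df in datasets.items()})
```

Because DS2 has fewer rows, its cross-validation folds are smaller, so some extra spread in its scores is to be expected.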

datasets: Dict = {'DS1': xy, 'DS2': xy.drop(columns=['age']).sample(frac=0.4)}

fig = plot_model_selection(estimators=LinearRegression(), datasets=datasets, target='progression', pre_processor=pp,
                           k_folds=30)
fig.update_layout(height=600)
fig


Total running time of the script: ( 0 minutes 8.642 seconds)

Gallery generated by Sphinx-Gallery