Model Selection
This example demonstrates a model selection plot incorporating cross-validation and test error.
The code has been adapted from a Machine Learning Mastery example.
import logging
from typing import Dict
import numpy as np
import pandas
import pandas as pd
import plotly
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from elphick.sklearn_viz.model_selection import ModelSelection, plot_model_selection, metrics
from elphick.sklearn_viz.model_selection.models import Models
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
datefmt='%Y-%m-%dT%H:%M:%S%z')
Load Data
We'll load the Pima Indians diabetes dataset and separate the features from the target for a classification problem.
url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
dataframe = pd.read_csv(url, names=names)
array = dataframe.values
x = pd.DataFrame(array[:, 0:8], columns=names[0:8])
y = pd.Series(array[:, 8], name=names[8])
xy: pd.DataFrame = pd.concat([x, y], axis=1)
Instantiate
Create an optional pre-processor as a sklearn Pipeline.
np.random.seed(1234)
pp: Pipeline = make_pipeline(StandardScaler())
models_to_test: Dict = Models().fast_classifiers()
pp
Plot using the function
The box colors are scaled to provide a relative indication of performance based on the score (Kudos to Shah Newaz Khan)
fig = plot_model_selection(estimators=models_to_test, datasets=xy, target='class', pre_processor=pp)
fig.update_layout(height=600)
fig
Plot using the object
The alternative to using the function is to instantiate a ModelSelection object. This has the advantage of persisting the data, which provides greater flexibility and faster re-plotting. If metrics are provided, additional subplots are added - however, since metrics have no concept of "greater-is-better" like a scorer, they are not coloured.
ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='class', pre_processor=pp,
k_folds=30, verbosity=0)
fig = ms.plot(title='Model Selection', metrics='f1')
fig.update_layout(height=600)
# noinspection PyTypeChecker
plotly.io.show(fig) # this call to show will set the thumbnail for use in the gallery
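Under the hood, plots like these are driven by cross-validated scores for each estimator. A minimal sketch of producing comparable numbers directly with scikit-learn's standard API (using a synthetic stand-in for the pima data, so the snippet is self-contained):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the CSV data loaded above.
x, y = make_classification(n_samples=200, n_features=8, random_state=1234)

# Pre-processor and estimator combined, as ModelSelection does internally.
pipe = make_pipeline(StandardScaler(), LogisticRegression())
cv = cross_validate(pipe, x, y, cv=5, return_train_score=True)
print(cv['test_score'].mean(), cv['test_score'].std())
```

Each estimator in models_to_test is evaluated this way across the folds; the box plots summarise the resulting score distributions.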
View the data
ms.results
{'dataset': {'LR': CrossValidationResult(test_scores=array([0.88461538, 0.76923077, 0.73076923, 0.73076923, 0.76923077,
0.84615385, 0.69230769, 0.80769231, 0.92307692, 0.73076923,
0.76923077, 0.80769231, 0.69230769, 0.92307692, 0.73076923,
0.65384615, 0.80769231, 0.73076923, 0.8 , 0.8 ,
0.56 , 0.68 , 0.84 , 0.72 , 0.84 ,
0.76 , 0.84 , 0.88 , 0.72 , 0.8 ]), train_scores=array([0.77897574, 0.78167116, 0.78301887, 0.78167116, 0.78436658,
0.78167116, 0.78301887, 0.78032345, 0.77628032, 0.78436658,
0.78436658, 0.78167116, 0.78301887, 0.77628032, 0.78167116,
0.7884097 , 0.78167116, 0.77762803, 0.78061911, 0.7833109 ,
0.78734859, 0.7846568 , 0.78061911, 0.7846568 , 0.77792732,
0.78196501, 0.77792732, 0.77792732, 0.78600269, 0.78061911]), fit_times=array([0.00288987, 0.00273037, 0.00272465, 0.00265408, 0.00273561,
0.00267959, 0.00281501, 0.00265908, 0.00266528, 0.00267959,
0.00266981, 0.0027225 , 0.00269055, 0.00268173, 0.00277352,
0.00266933, 0.00280619, 0.00263047, 0.00266218, 0.00267982,
0.00265217, 0.00270677, 0.00266981, 0.00262809, 0.00269747,
0.00266743, 0.00282025, 0.00265098, 0.00268483, 0.00268197]), score_times=array([0.00119638, 0.00114226, 0.0011332 , 0.00116014, 0.00113368,
0.00114918, 0.00116491, 0.00113392, 0.00115108, 0.00113773,
0.00119519, 0.00113153, 0.00114226, 0.00115395, 0.0011549 ,
0.00113964, 0.00114393, 0.00117874, 0.00112891, 0.00114202,
0.00116014, 0.00125027, 0.00112772, 0.00114655, 0.00114727,
0.0011344 , 0.0011518 , 0.00116515, 0.00114036, 0.00113392]), estimator=[LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression(), LogisticRegression()], metrics={'f1': [0.8235294117647058, 0.75, 0.588235294117647, 0.5333333333333333, 0.5714285714285715, 0.6, 0.6363636363636365, 0.761904761904762, 0.8750000000000001, 0.631578947368421, 0.7, 0.8, 0.42857142857142855, 0.8750000000000001, 0.4615384615384615, 0.3076923076923077, 0.7058823529411764, 0.4615384615384615, 0.7058823529411765, 0.5454545454545454, 0.4761904761904762, 0.2, 0.5, 0.5882352941176471, 0.6666666666666666, 0.5, 0.5, 0.8695652173913044, 0.5882352941176471, 0.7058823529411764]}, metrics_group={}), 'LDA': CrossValidationResult(test_scores=array([0.80769231, 0.84615385, 0.80769231, 0.88461538, 0.61538462,
0.80769231, 0.73076923, 0.61538462, 0.84615385, 0.65384615,
0.84615385, 0.73076923, 0.76923077, 0.80769231, 0.76923077,
0.65384615, 0.73076923, 0.76923077, 0.76 , 0.8 ,
0.76 , 0.72 , 0.88 , 0.88 , 0.76 ,
0.84 , 0.8 , 0.72 , 0.8 , 0.8 ]), train_scores=array([0.77762803, 0.78032345, 0.78167116, 0.77628032, 0.78975741,
0.78301887, 0.78167116, 0.78436658, 0.77762803, 0.78571429,
0.78301887, 0.78301887, 0.77493261, 0.78167116, 0.78436658,
0.78167116, 0.78032345, 0.78167116, 0.7833109 , 0.7833109 ,
0.7833109 , 0.77927322, 0.7833109 , 0.77927322, 0.77927322,
0.77927322, 0.77927322, 0.78196501, 0.78061911, 0.77927322]), fit_times=array([0.00196075, 0.00183129, 0.00185728, 0.00187707, 0.00184894,
0.00186849, 0.00189543, 0.00187898, 0.00185585, 0.0019474 ,
0.00182939, 0.00185013, 0.00184226, 0.00182295, 0.00182986,
0.00181532, 0.00182581, 0.0018208 , 0.00187159, 0.00184226,
0.00183773, 0.00186396, 0.00185394, 0.00186348, 0.00186324,
0.00188732, 0.00183773, 0.00184321, 0.00181603, 0.00181913]), score_times=array([0.00118876, 0.00117207, 0.0011518 , 0.00112724, 0.00114727,
0.00112963, 0.00113606, 0.00114059, 0.00112605, 0.00112891,
0.00113511, 0.00114107, 0.00111794, 0.0011363 , 0.0011363 ,
0.00114536, 0.00114894, 0.00114703, 0.00113273, 0.00112677,
0.00114417, 0.00112343, 0.00111103, 0.00111222, 0.00112486,
0.00112963, 0.00112486, 0.00113606, 0.00112271, 0.00116634]), estimator=[LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis(), LinearDiscriminantAnalysis()], metrics={'f1': [0.6153846153846154, 0.75, 0.6666666666666666, 0.7999999999999999, 0.16666666666666666, 0.7058823529411764, 0.6956521739130435, 0.2857142857142857, 0.8181818181818181, 0.5263157894736842, 0.7142857142857143, 0.631578947368421, 0.7272727272727274, 0.7058823529411765, 0.6666666666666665, 0.47058823529411764, 0.631578947368421, 0.4, 0.4, 0.7058823529411764, 0.25, 0.6956521739130435, 0.6666666666666665, 0.7692307692307692, 0.7272727272727274, 0.7499999999999999, 0.7368421052631579, 0.588235294117647, 0.5454545454545454, 0.5454545454545454]}, metrics_group={}), 'KNN': CrossValidationResult(test_scores=array([0.73076923, 0.76923077, 0.88461538, 0.65384615, 0.80769231,
0.61538462, 0.84615385, 0.65384615, 0.96153846, 0.73076923,
0.57692308, 0.69230769, 0.80769231, 0.69230769, 0.84615385,
0.73076923, 0.73076923, 0.80769231, 0.56 , 0.64 ,
0.68 , 0.72 , 0.64 , 0.72 , 0.76 ,
0.72 , 0.76 , 0.84 , 0.72 , 0.6 ]), train_scores=array([0.83153639, 0.83153639, 0.8328841 , 0.82884097, 0.82614555,
0.83153639, 0.82210243, 0.83018868, 0.82884097, 0.83692722,
0.83692722, 0.83018868, 0.82614555, 0.83423181, 0.82614555,
0.82884097, 0.82884097, 0.82614555, 0.82907133, 0.83176312,
0.82368775, 0.83580081, 0.82907133, 0.82772544, 0.83041723,
0.83580081, 0.82772544, 0.82772544, 0.83310902, 0.8371467 ]), fit_times=array([0.00178933, 0.00174022, 0.00171304, 0.00213313, 0.00195336,
0.0016911 , 0.00169492, 0.00170159, 0.00169897, 0.00167179,
0.00166631, 0.00167513, 0.00166655, 0.00167441, 0.00197434,
0.00178623, 0.00172162, 0.00171494, 0.00173044, 0.00171208,
0.0017705 , 0.00172067, 0.00178027, 0.00173545, 0.00176334,
0.00167632, 0.00167942, 0.00169587, 0.00167918, 0.00175929]), score_times=array([0.00221109, 0.00219941, 0.00200772, 0.00231314, 0.00220656,
0.00203824, 0.00205851, 0.00206113, 0.00210524, 0.00203753,
0.00246358, 0.00203824, 0.00201201, 0.00202656, 0.00283217,
0.00208712, 0.00201631, 0.00200129, 0.00200844, 0.00211596,
0.0020051 , 0.00202155, 0.00207305, 0.00205398, 0.00200486,
0.00201297, 0.00203729, 0.00200319, 0.00215507, 0.00209975]), estimator=[KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier(), KNeighborsClassifier()], metrics={'f1': [0.36363636363636365, 0.625, 0.7272727272727272, 0.4, 0.761904761904762, 0.5454545454545455, 0.7777777777777777, 0.4, 0.9473684210526316, 0.588235294117647, 0.47619047619047616, 0.5555555555555556, 0.6153846153846154, 0.5555555555555556, 0.7777777777777778, 0.5882352941176471, 0.5333333333333333, 0.6666666666666666, 0.4761904761904762, 0.39999999999999997, 0.5, 0.4615384615384615, 0.18181818181818182, 0.631578947368421, 0.5, 0.4615384615384615, 0.625, 0.8, 0.631578947368421, 0.4444444444444444]}, metrics_group={}), 'CART': CrossValidationResult(test_scores=array([0.57692308, 0.73076923, 0.69230769, 0.65384615, 0.57692308,
0.69230769, 0.69230769, 0.80769231, 0.65384615, 0.57692308,
0.73076923, 0.42307692, 0.76923077, 0.80769231, 0.65384615,
0.76923077, 0.73076923, 0.61538462, 0.68 , 0.6 ,
0.76 , 0.76 , 0.68 , 0.68 , 0.6 ,
0.72 , 0.76 , 0.64 , 0.8 , 0.8 ]), train_scores=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), fit_times=array([0.00441813, 0.00399709, 0.00403643, 0.00407481, 0.00410199,
0.00396419, 0.00396347, 0.00402927, 0.00399613, 0.0040195 ,
0.00399184, 0.00425124, 0.00403762, 0.00398684, 0.00396228,
0.00411391, 0.0039835 , 0.00423288, 0.00410914, 0.00395465,
0.00405264, 0.00403929, 0.00426602, 0.00398064, 0.00393915,
0.00473619, 0.00471878, 0.00405717, 0.00411105, 0.00398374]), score_times=array([0.00120997, 0.00116444, 0.00120974, 0.00115657, 0.00115395,
0.00119948, 0.00118065, 0.00115466, 0.00117445, 0.00115323,
0.00115204, 0.00115633, 0.00117946, 0.00118184, 0.00116444,
0.00116897, 0.00120068, 0.00116992, 0.00117493, 0.00116038,
0.00116706, 0.0011487 , 0.00117469, 0.00115657, 0.00117016,
0.00159001, 0.00163865, 0.00118852, 0.00115705, 0.00117183]), estimator=[DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier(), DecisionTreeClassifier()], metrics={'f1': [0.47619047619047616, 0.631578947368421, 0.5555555555555556, 0.5714285714285713, 0.26666666666666666, 0.5555555555555556, 0.6666666666666667, 0.6666666666666666, 0.47058823529411764, 0.26666666666666666, 0.631578947368421, 0.34782608695652173, 0.625, 0.4444444444444444, 0.30769230769230765, 0.7000000000000001, 0.588235294117647, 0.5, 0.5, 0.4444444444444444, 0.625, 0.6666666666666665, 0.6, 0.3333333333333333, 0.5833333333333334, 0.5882352941176471, 0.6666666666666665, 0.47058823529411764, 0.761904761904762, 0.7368421052631579]}, metrics_group={}), 'NB': CrossValidationResult(test_scores=array([0.88461538, 0.73076923, 0.57692308, 0.73076923, 0.84615385,
0.88461538, 0.69230769, 0.76923077, 0.80769231, 0.65384615,
0.76923077, 0.65384615, 0.80769231, 0.76923077, 0.76923077,
0.76923077, 0.88461538, 0.57692308, 0.88 , 0.6 ,
0.64 , 0.84 , 0.64 , 0.84 , 0.72 ,
0.72 , 0.76 , 0.76 , 0.64 , 0.88 ]), train_scores=array([0.75876011, 0.76415094, 0.76684636, 0.76280323, 0.75876011,
0.7574124 , 0.76145553, 0.76280323, 0.76280323, 0.76280323,
0.76280323, 0.76280323, 0.76145553, 0.76549865, 0.76010782,
0.76280323, 0.75876011, 0.77088949, 0.75370121, 0.76716016,
0.76581427, 0.7577389 , 0.76716016, 0.75908479, 0.76177658,
0.76446837, 0.76581427, 0.76177658, 0.76716016, 0.756393 ]), fit_times=array([0.00178385, 0.0015738 , 0.00152946, 0.00158191, 0.00155091,
0.0015409 , 0.00154948, 0.00155616, 0.00153613, 0.00153136,
0.00153422, 0.00156784, 0.00157237, 0.00154543, 0.00153446,
0.00155282, 0.00155568, 0.00153184, 0.00156522, 0.00152779,
0.00152659, 0.00156522, 0.00152469, 0.00152421, 0.00153923,
0.00152087, 0.00154853, 0.00157332, 0.00151706, 0.00151658]), score_times=array([0.00120425, 0.00114584, 0.00114202, 0.00114799, 0.00115514,
0.0011363 , 0.00115585, 0.00115824, 0.00115013, 0.00114322,
0.00113606, 0.00113964, 0.00113893, 0.0011518 , 0.0011549 ,
0.00112987, 0.00114727, 0.0011375 , 0.00114989, 0.00114965,
0.00112486, 0.00113821, 0.00113964, 0.00113082, 0.00113463,
0.00112915, 0.00113726, 0.00112677, 0.00113249, 0.00113273]), estimator=[GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB(), GaussianNB()], metrics={'f1': [0.8421052631578948, 0.6666666666666667, 0.47619047619047616, 0.631578947368421, 0.7142857142857143, 0.7999999999999999, 0.5555555555555556, 0.5, 0.7058823529411764, 0.5263157894736842, 0.7, 0.47058823529411764, 0.7368421052631577, 0.7, 0.5714285714285714, 0.7692307692307693, 0.8421052631578948, 0.4210526315789473, 0.7692307692307693, 0.37499999999999994, 0.39999999999999997, 0.75, 0.18181818181818182, 0.6666666666666666, 0.588235294117647, 0.4615384615384615, 0.75, 0.5714285714285715, 0.47058823529411764, 0.7692307692307692]}, metrics_group={}), 'SVM': CrossValidationResult(test_scores=array([0.65384615, 0.65384615, 0.73076923, 0.65384615, 0.73076923,
0.80769231, 0.80769231, 0.76923077, 0.80769231, 0.69230769,
0.73076923, 0.92307692, 0.84615385, 0.92307692, 0.73076923,
0.57692308, 0.80769231, 0.88461538, 0.72 , 0.64 ,
0.8 , 0.8 , 0.76 , 0.84 , 0.88 ,
0.68 , 0.68 , 0.8 , 0.8 , 0.84 ]), train_scores=array([0.82614555, 0.82749326, 0.82749326, 0.82614555, 0.83423181,
0.82345013, 0.82345013, 0.82614555, 0.82479784, 0.83153639,
0.83153639, 0.82614555, 0.82210243, 0.81671159, 0.82345013,
0.82749326, 0.82479784, 0.81940701, 0.83310902, 0.82368775,
0.82637954, 0.82234186, 0.82637954, 0.81830417, 0.81695828,
0.83041723, 0.82637954, 0.82503365, 0.82772544, 0.81965007]), fit_times=array([0.01058912, 0.01053905, 0.01046419, 0.01059031, 0.01051593,
0.01052547, 0.01087952, 0.01040816, 0.01073122, 0.01215649,
0.01337957, 0.01146436, 0.0108254 , 0.01063704, 0.01059008,
0.01044226, 0.01053882, 0.01063919, 0.01053476, 0.01036549,
0.01067448, 0.01087451, 0.01044321, 0.01072121, 0.0107317 ,
0.0105238 , 0.01067805, 0.01065302, 0.01051974, 0.01071835]), score_times=array([0.00163293, 0.00171232, 0.00158191, 0.00157595, 0.0015862 ,
0.0016005 , 0.00159645, 0.00170302, 0.00162745, 0.00240469,
0.00212526, 0.00184631, 0.0016253 , 0.00160146, 0.00161529,
0.00162363, 0.00157905, 0.0016222 , 0.00159883, 0.00156951,
0.0015769 , 0.00157881, 0.00157785, 0.00158381, 0.00159788,
0.00160217, 0.00158501, 0.00168943, 0.00158262, 0.00156283]), estimator=[SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC(), SVC()], metrics={'f1': [0.47058823529411764, 0.39999999999999997, 0.5882352941176471, 0.6086956521739131, 0.5882352941176471, 0.5454545454545454, 0.6666666666666666, 0.7272727272727272, 0.6153846153846153, 0.5555555555555556, 0.588235294117647, 0.888888888888889, 0.8, 0.8750000000000001, 0.4615384615384615, 0.15384615384615385, 0.6666666666666666, 0.6666666666666665, 0.588235294117647, 0.608695652173913, 0.6666666666666666, 0.6153846153846153, 0.6250000000000001, 0.8000000000000002, 0.7692307692307693, 0.6363636363636364, 0.5, 0.6153846153846153, 0.6666666666666667, 0.6666666666666666]}, metrics_group={})}}
Regressor Model Selection
Of course we’re not limited to classification problems. We will demonstrate a regression problem, with multiple metrics. We prepare a group variable (a pd.Series) in order to calculate the metrics by group for each fold.
This cross-validation takes a bit longer, so we set n_jobs=-2 to fit in parallel while leaving one core free so the system remains responsive.
diabetes = load_diabetes(as_frame=True)
x, y = diabetes.data, diabetes.target
y.name = "progression"
xy: pd.DataFrame = pd.concat([x, y], axis=1)
group: pd.Series = pd.Series(x['sex'] > 0, name='grp_sex', index=x.index)
pp: Pipeline = make_pipeline(StandardScaler())
models_to_test: Dict = Models().fast_regressors()
ms: ModelSelection = ModelSelection(estimators=models_to_test, datasets=xy, target='progression', pre_processor=pp,
k_folds=30, scorer='r2', group=group,
metrics={'moe': metrics.moe_95, 'me': metrics.mean_error},
n_jobs=-2, verbosity=2)
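The metrics argument accepts plain callables keyed by name. The sketches below are hypothetical equivalents of mean_error and moe_95, written for illustration only - they are not necessarily the library's own definitions:

```python
import numpy as np

def mean_error(y_true, y_est):
    """Signed mean of the residuals (a measure of bias)."""
    return float(np.mean(np.asarray(y_est) - np.asarray(y_true)))

def moe_95(y_true, y_est):
    """Illustrative 95% margin of error: 1.96 x the residual std."""
    resid = np.asarray(y_est) - np.asarray(y_true)
    return float(1.96 * np.std(resid))

print(mean_error([1.0, 2.0, 3.0], [1.5, 2.0, 3.5]))
```

Because a metric carries no greater-is-better convention, ModelSelection plots metrics as extra subplots rather than colouring them like the score.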
Next we’ll view the plot, but we will not (yet) leverage the group variable.
fig = ms.plot(metrics=['moe', 'me'])
fig.update_layout(height=700)
fig
Now, we will re-plot using group. This is fast, since the fitting metrics were calculated when the first plot was created, and do not need to be calculated again.
Plotting by group can (hopefully) provide evidence that metrics are consistent across groups.
fig = ms.plot(metrics=['moe', 'me'], show_group=True, col_wrap=2)
fig.update_layout(height=700)
fig
Clearly, plot real estate will become a problem for more than 2 or 3 classes - here we used col_wrap to mitigate that.
Comparing Datasets
Next we will demonstrate a single Algorithm with multiple datasets. This is useful when exploring features that improve model performance. We modify DS2 by removing a feature and sampling 40% of the data.
datasets: Dict = {'DS1': xy, 'DS2': xy.drop(columns=['age']).sample(frac=0.4)}
fig = plot_model_selection(estimators=LinearRegression(), datasets=datasets, target='progression', pre_processor=pp,
k_folds=30)
fig.update_layout(height=600)
fig
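Note that DS2 is a fresh random 40% sample on every run; if reproducible comparisons matter, pandas sample accepts a random_state. A quick sketch:

```python
import pandas as pd

df = pd.DataFrame({'a': range(10)})
# With a fixed random_state the same rows are drawn each time.
s1 = df.sample(frac=0.4, random_state=42)
s2 = df.sample(frac=0.4, random_state=42)
print(s1.equals(s2))
```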
Total running time of the script: ( 0 minutes 8.642 seconds)