Source code for elphick.sklearn_viz.residuals.error

import logging
from typing import Optional, Union

import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted
import plotly.graph_objects as go
import plotly.express as px


[docs]class Errors:
[docs] def __init__(self, mdl, x_test: Optional[pd.DataFrame] = None, y_test: Optional[Union[pd.DataFrame, pd.Series]] = None): """ Args: mdl: The scikit-learn model or pipeline. x_test: X values provided to calculate residuals. y_test: y values provided to calculate residuals. """ self._logger = logging.getLogger(name=__class__.__name__) self.mdl = mdl self.X_test: Optional[pd.DataFrame] = x_test self.y_test: Optional[Union[pd.DataFrame, pd.Series]] = y_test self._data: Optional[pd.DataFrame] = None self.is_pipeline: bool = isinstance(mdl, Pipeline) check_is_fitted(mdl[-1]) if self.is_pipeline else check_is_fitted(mdl)
[docs] def plot(self, color: Optional[str] = None, title: Optional[str] = 'Error Plot', marginal: bool = False) -> go.Figure: """ Args: color: The variable name to color (group) by. title: title for the plot marginal: If True, marginal histograms are added. Returns: a plotly GraphObjects.Figure """ y_est = pd.Series(self.mdl.predict(self.X_test), name=f"{self.y_test.name}_est", index=self.X_test.index) y = self.y_test data = pd.concat([y, y_est], axis='columns') index_name: str = 'index' if data.index.name is None else data.index.name # Calculate the range of the data and manage padding. data_range = data.max().max() - data.min().min() padding = data_range * 0.05 lims = [data.min().min() - padding, data.max().max() + padding] if marginal: fig = make_subplots(rows=2, cols=2, column_widths=[0.8, 0.2], row_heights=[0.2, 0.8], subplot_titles=(False, False, False, False), vertical_spacing=0.02, # Adjust as needed horizontal_spacing=0.02) # Adjust as needed else: fig = make_subplots(rows=1, cols=1) fig.add_trace(go.Scatter( x=data[y.name], y=data[y_est.name], mode='markers', marker=dict(color=color), name='Data', hovertemplate=f'{index_name}: ' + '%{text}<br>x: %{x}<br>y: %{y}', text=data.index ), row=2 if marginal else 1, col=1) fig.add_trace(go.Scatter( x=lims, y=lims, mode='lines', line=dict(dash='dash', color='red'), name='y=x line', showlegend=False ), row=2 if marginal else 1, col=1) if marginal: fig.add_trace( go.Histogram(x=data[y.name], name='Top Histogram', marker=dict(color='grey'), nbinsx=20, ), row=1, col=1) fig.add_trace( go.Histogram(y=data[y_est.name], name='Right Histogram', marker=dict(color='grey'), nbinsy=20, ), row=2, col=2) # Set the range of the x-axis for the top histogram and the range of the y-axis for the right histogram fig.update_xaxes(range=lims, showgrid=True, gridcolor='white', title_text='', showticklabels=False, row=1, col=1) fig.update_yaxes(range=[0, 'data max'], showgrid=False, title_text='', showticklabels=False, row=1, col=1) fig.update_xaxes(showgrid=False, title_text='', showticklabels=False, row=2, col=2) fig.update_yaxes(range=lims, showgrid=True, gridcolor='white', title_text='', showticklabels=False, row=2, col=2) # The main plot fig.update_xaxes(range=lims, scaleanchor="y", scaleratio=1, title_text=y.name, row=2 if marginal else 1, col=1) fig.update_yaxes(range=lims, title_text=y_est.name, row=2 if marginal else 1, col=1) fig.update_layout( title=title, showlegend=False, width=800, height=800, ) return fig