Source code for elphick.pandera_utils.pandera_utils
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Optional

import pandas as pd
from pandera import DataFrameSchema

from elphick.pandera_utils.utils.pandera_io_pandas_io import from_yaml
def order_columns_to_match_schema(df: pd.DataFrame, schema: DataFrameSchema) -> pd.DataFrame:
"""Order DataFrame columns to match the schema order if coerce and metadata['order_columns'] are true.
Columns not specified in the schema will be retained at the end of the DataFrame.
Args:
df: The DataFrame to reorder.
schema: The Pandera DataFrameSchema defining the desired column order.
Returns:
A DataFrame with columns reordered to match the schema, retaining extra columns at the end.
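
    Example:
        An illustrative sketch (the column names and metadata values here are
        hypothetical)::

            import pandas as pd
            from pandera import Column, DataFrameSchema

            schema = DataFrameSchema(
                columns={'a': Column(float), 'b': Column(float)},
                coerce=True,
                metadata={'pandera_utils': {'order_columns': True}},
            )
            df = pd.DataFrame({'b': [1.0], 'extra': [2.0], 'a': [3.0]})
            df = order_columns_to_match_schema(df, schema)  # columns: a, b, extra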
"""
# Check if both 'coerce' and 'order_columns' are true
if schema.coerce and schema.metadata and schema.metadata.get('pandera_utils', {}).get("order_columns", False):
schema_columns = list(schema.columns.keys())
# Retain only columns present in the schema, in schema order
reordered_columns = [col for col in schema_columns if col in df.columns]
# Retain extra columns in their original order
extra_columns = [col for col in df.columns if col not in schema_columns]
return df[reordered_columns + extra_columns]
return df
class DataFrameMetaProcessor:
"""A class to preprocess and validate DataFrames based on metadata."""
def __init__(self, schema: DataFrameSchema):
"""Instantiate the DataFrameMetaProcessor object.
Args:
schema: The DataFrameSchema object to use for preprocessing and validation.
"""
self.schema: DataFrameSchema = schema
self.supported_column_meta_keys = ['unit_of_measure', 'aliases', 'decimals', 'missing_sentinels', 'category',
'calculation']
    @property
    def unit_of_measure_map(self) -> OrderedDict:
        """Map of column name to its unit_of_measure metadata."""
        return OrderedDict(
(col_name, col.metadata.get('pandera_utils', {}).get('unit_of_measure'))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'unit_of_measure' in col.metadata['pandera_utils']
)
    @property
    def alias_map(self) -> OrderedDict:
        """Map of column name to its list of aliases."""
        alias_dict = OrderedDict()
for col_name, col in self.schema.columns.items():
if col.metadata and 'pandera_utils' in col.metadata and 'aliases' in col.metadata['pandera_utils']:
alias_dict[col_name] = col.metadata['pandera_utils']['aliases']
return alias_dict
    @property
    def calculation_map(self) -> OrderedDict:
        """Map of column name to its calculation expression."""
        return OrderedDict(
(col_name, col.metadata.get('pandera_utils', {}).get('calculation'))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'calculation' in col.metadata['pandera_utils']
)
    @property
    def decimals_map(self) -> OrderedDict:
        """Map of column name to the number of decimal places for rounding."""
        return OrderedDict(
(col_name, col.metadata.get('pandera_utils', {}).get('decimals'))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'decimals' in col.metadata['pandera_utils']
)
    @property
    def missing_sentinels_map(self) -> OrderedDict:
        """Map of column name to its list of missing-value sentinels."""
        return OrderedDict(
(col_name, col.metadata.get('pandera_utils', {}).get('missing_sentinels'))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'missing_sentinels' in col.metadata['pandera_utils']
)
    @property
    def category_maps(self) -> OrderedDict:
        """Map of column name to its category sub-maps, excluding scalar keys such as 'ordered'."""
        cat_maps = OrderedDict(
(col_name, col.metadata.get('pandera_utils', {}).get('category'))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'category' in col.metadata['pandera_utils']
)
return OrderedDict(
(k, {sub_k: sub_v for sub_k, sub_v in v.items() if isinstance(sub_v, dict)})
for k, v in cat_maps.items() if v
)
    @property
    def category_ordered_map(self) -> OrderedDict:
        """Map of column name to the 'ordered' flag of its category metadata."""
        return OrderedDict(
            (col_name, col.metadata.get('pandera_utils', {}).get('category', {}).get('ordered', False))
for col_name, col in self.schema.columns.items()
if col.metadata and 'pandera_utils' in col.metadata and 'category' in col.metadata['pandera_utils']
)
def apply_rename_from_alias(self, df: pd.DataFrame) -> pd.DataFrame:
"""Rename columns in the DataFrame based on aliases."""
alias_map = self.alias_map
rename_map = {}
for col_name, aliases in alias_map.items():
for alias in aliases:
if alias in df.columns:
rename_map[alias] = col_name
return df.rename(columns=rename_map)
def apply_calculations(self, df: pd.DataFrame) -> pd.DataFrame:
"""Apply calculations based on the calculation metadata."""
for col_name, calculation in self.calculation_map.items():
# Check for input columns
inputs = self.schema.columns[col_name].metadata['pandera_utils'].get('inputs', [])
missing_columns = [dep for dep in inputs if dep not in df.columns]
required = self.schema.columns[col_name].required
if missing_columns:
if required:
raise KeyError(f"Missing columns for calculation '{col_name}': {missing_columns}")
else:
logging.warning(f"Missing columns for optional (non-required) calculation '{col_name}': {missing_columns}")
continue
            # Evaluate the calculation expression with the DataFrame columns as locals
            calculated_column = eval(calculation, {}, df.to_dict('series'))
            if col_name in df.columns:
                # Overwrite in place if the column already exists
                df[col_name] = calculated_column
            elif inputs:
                # Insert the calculated column immediately to the right of its rightmost input
                rightmost_input = max(df.columns.get_loc(dep) for dep in inputs)
                df.insert(rightmost_input + 1, col_name, calculated_column)
            else:
                df[col_name] = calculated_column
return df
def apply_rounding(self, df: pd.DataFrame, columns: Optional[list[str]] = None) -> pd.DataFrame:
"""Round columns based on the decimals metadata."""
        if columns is None:
            columns = list(self.decimals_map.keys())
for col_name in columns:
if col_name in self.decimals_map and col_name in df.columns:
df[col_name] = df[col_name].round(self.decimals_map[col_name])
return df
def _generate_category_columns(self, column: pd.Series, map_dict: dict,
retain_original_column: bool = True) -> dict:
"""Generate new columns based on the category metadata."""
new_columns = {}
schema_column = self.schema.columns[column.name]
        if retain_original_column:
            # Retrieve allowable categories from any isin check on the column
            allowable_categories = None
            for check in schema_column.checks:
                if check.name == 'isin':
                    allowable_categories = list(check.statistics['allowed_values'])
                    break
            new_columns[column.name] = column.astype(
                pd.CategoricalDtype(categories=allowable_categories,
                                    ordered=self.category_ordered_map.get(column.name, False)))
for k, v in map_dict.items():
new_columns[f"{column.name}_{k}"] = column.map(v['map']).astype(v['dtype'])
return new_columns
def apply_missing_sentinels(self, df: pd.DataFrame) -> pd.DataFrame:
"""Apply missing sentinels based on the missing_sentinels metadata."""
        for col_name, sentinels in self.missing_sentinels_map.items():
            if col_name in df.columns:
                # Validate the single column first, then convert sentinel values to missing
                df[col_name] = self.schema.columns[col_name].validate(df[[col_name]])[col_name]
                df[col_name] = df[col_name].replace(sentinels, pd.NA)
return df
def apply_category_maps(self, df: pd.DataFrame, maps_to_apply: Optional[list[str]] = None,
retain_orig_cat_col: bool = True) -> pd.DataFrame:
"""Apply category maps to create new columns based on the category metadata."""
        if not self.category_maps:
            return df
        # Assert that any supplied map names are valid
        all_map_keys: list[str] = list(next(iter(self.category_maps.values())).keys())
        if maps_to_apply is not None:
            for map_name in maps_to_apply:
                assert map_name in all_map_keys, f"Map name '{map_name}' not found in category_map"
        else:
            maps_to_apply = all_map_keys
        # Apply the maps
        for col in self.category_maps.keys():
            if col not in df.columns:
                if self.schema.columns[col].required:
                    raise KeyError(f"Column '{col}' not found in DataFrame")
                continue
            original_col_position = df.columns.get_loc(col)
            maps_for_col = {k: v for k, v in self.category_maps[col].items() if k in maps_to_apply}
            new_columns = self._generate_category_columns(df[col], maps_for_col, retain_orig_cat_col)
# Insert columns at the original position, and to the right of the original
for col_name, col_data in new_columns.items():
if col_name == col:
df[col_name] = col_data
else:
original_col_position += 1
df.insert(original_col_position, col_name, col_data)
return df
def preprocess(self, df: pd.DataFrame, round_before_calc: bool = False,
cat_maps_to_apply: Optional[list[str]] = None, cat_retain_orig_cat_col: bool = True) -> pd.DataFrame:
"""Preprocess a DataFrame based on the metadata.
Args:
df: The DataFrame to preprocess.
round_before_calc: A boolean indicating whether to round columns before applying calculations,
as well as after.
cat_maps_to_apply: A list of category maps to apply. If None, all maps will be applied.
cat_retain_orig_cat_col: A boolean indicating whether to retain the original category columns.
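
        Returns:
            The preprocessed DataFrame.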
"""
# Check for DataFrame-level metadata for column ordering
if self.schema.metadata and self.schema.metadata.get('pandera_utils', {}).get("order_columns", False):
df = order_columns_to_match_schema(df, self.schema)
if self.alias_map:
df = self.apply_rename_from_alias(df)
if self.missing_sentinels_map:
df = self.apply_missing_sentinels(df)
# Handle rounding before calculations if specified
if round_before_calc and self.decimals_map:
df = self.apply_rounding(df)
        # Determine all inputs to calculated columns
        calculation_inputs = {
            dep for col in self.calculation_map
            for dep in self.schema.columns[col].metadata['pandera_utils'].get('inputs', [])
        }
        # Process each column in schema order; calculations and category maps are
        # applied once only, since their apply methods handle all relevant columns
        calculations_applied = False
        category_maps_applied = False
        for col in self.schema.columns.keys():
            if col in self.calculation_map and not calculations_applied:
                df = self.apply_calculations(df)
                calculations_applied = True
            # Skip rounding for columns that are inputs to calculated columns
            if (not round_before_calc and col in self.decimals_map
                    and col not in self.calculation_map and col not in calculation_inputs):
                df = self.apply_rounding(df, columns=[col])
            if col in self.category_maps and not category_maps_applied:
                df = self.apply_category_maps(df, maps_to_apply=cat_maps_to_apply,
                                              retain_orig_cat_col=cat_retain_orig_cat_col)
                category_maps_applied = True
# Apply rounding after calculations if needed
if not round_before_calc and self.decimals_map:
df = self.apply_rounding(df)
        # Final column-ordering pass
if self.schema.metadata and self.schema.metadata.get('pandera_utils', {}).get("order_columns", False):
df = order_columns_to_match_schema(df, self.schema)
return df
def validate(self, df: pd.DataFrame, return_calculated_columns: bool = True) -> pd.DataFrame:
"""Validate a DataFrame based on the schema."""
df = self.schema.validate(df)
if not return_calculated_columns:
return df.drop(columns=list(self.calculation_map.keys()))
return df
def check_schema(self):
"""Check if the schema is valid."""
# Check the aliases are all unique
alias_map = self.alias_map
all_aliases = [alias for aliases in alias_map.values() for alias in aliases]
duplicate_aliases = {alias for alias in all_aliases if all_aliases.count(alias) > 1}
if duplicate_aliases:
raise ValueError(f"Duplicate aliases found: {duplicate_aliases}")
        # Check that all aliases are strings
for col_name, aliases in alias_map.items():
for alias in aliases:
if not isinstance(alias, str):
raise TypeError(
f"Alias '{alias}' in column '{col_name}' is not a string. All alias keys must be strings.")
# Check all columns with metadata.category values (maps) have the same keys
category_maps = self.category_maps
if category_maps:
# Get the set of keys from the first column's category map
reference_keys = set(next(iter(category_maps.values())).keys())
for col_name, category_map in category_maps.items():
if set(category_map.keys()) != reference_keys:
raise ValueError(
f"Inconsistent category map keys in column '{col_name}'. "
f"Expected keys: {reference_keys}, but got: {set(category_map.keys())}."
)
def load_schema_from_yaml(yaml_path: Path) -> DataFrameSchema:
"""Load a DataFrameSchema from a YAML file."""
return from_yaml(yaml_path)
def merge_schemas(list_of_schemas: list[DataFrameSchema]) -> DataFrameSchema:
"""Merge a list of DataFrameSchemas into a single DataFrameSchema.
The merged schema will contain all columns and checks from the input schemas.
The schema for root level properties or the index will be taken from the first schema in the list.
If there are multiple columns defined, the column from the first schema in the list will be used.
Args:
list_of_schemas: The list of DataFrameSchemas to merge.
Returns:
A DataFrameSchema that combines all the input schemas.
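
    Example:
        A minimal sketch (the schema definitions are hypothetical)::

            from pandera import Column, DataFrameSchema

            schema_a = DataFrameSchema(columns={'a': Column(float)}, name='base')
            schema_b = DataFrameSchema(columns={'a': Column(int), 'b': Column(str)})
            merged = merge_schemas([schema_a, schema_b])
            # merged has columns 'a' (float, from schema_a) and 'b' (str)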
"""
if not list_of_schemas:
raise ValueError("The list of schemas is empty")
# Start with the first schema
base_schema = list_of_schemas[0]
# Merge columns
merged_columns = base_schema.columns.copy()
for schema in list_of_schemas[1:]:
for col_name, col in schema.columns.items():
if col_name not in merged_columns:
merged_columns[col_name] = col
# Merge checks
merged_checks = base_schema.checks.copy()
for schema in list_of_schemas[1:]:
for check in schema.checks:
if check not in merged_checks:
merged_checks.append(check)
# Create the merged schema
merged_schema = DataFrameSchema(
columns=merged_columns,
checks=merged_checks,
index=base_schema.index,
dtype=base_schema.dtype,
coerce=base_schema.coerce,
strict=base_schema.strict,
name=base_schema.name,
ordered=base_schema.ordered,
unique=base_schema.unique,
report_duplicates=base_schema.report_duplicates,
unique_column_names=base_schema.unique_column_names,
add_missing_columns=base_schema.add_missing_columns,
title=base_schema.title,
description=base_schema.description,
)
return merged_schema
def load_merged_schema_from_yaml(yaml_paths: list[Path]) -> DataFrameSchema:
"""Load and merge DataFrameSchemas from a list of YAML files."""
schemas = [load_schema_from_yaml(yaml_path) for yaml_path in yaml_paths]
return merge_schemas(schemas)