Source code for omfpandas.utils.pandera_utils
"""Utilities for working with pandera DataFrameSchema metadata in omfpandas."""

from pathlib import Path

import pandas as pd
import yaml
from pandera import DataFrameSchema, Column, Check
from pandera.engines import pandas_engine
from pandera.io import _deserialize_check_stats

class DataFrameMetaProcessor:
    """Pre-processes and validates DataFrames using column metadata held in a pandera DataFrameSchema."""

    def __init__(self, schema: DataFrameSchema):
        self.schema: DataFrameSchema = schema
        # Column metadata keys understood by this processor.
        self.supported_column_meta_keys = ['alias', 'calculation', 'decimals']

    @property
    def alias_map(self):
        """Map of column alias -> schema column name, for columns defining an 'alias' metadata key."""
        return {col.metadata['alias']: col_name for col_name, col in self.schema.columns.items() if
                col.metadata and 'alias' in col.metadata}

    @property
    def calculation_map(self):
        """Map of schema column name -> calculation expression, for columns defining a 'calculation' metadata key."""
        return {col_name: col.metadata['calculation'] for col_name, col in self.schema.columns.items() if
                col.metadata and 'calculation' in col.metadata}

    @property
    def decimals_map(self):
        """Map of schema column name -> decimal places, for columns defining a 'decimals' metadata key."""
        return {col_name: col.metadata['decimals'] for col_name, col in self.schema.columns.items() if
                col.metadata and 'decimals' in col.metadata}

    def rename_from_meta_alias(self, df: pd.DataFrame) -> pd.DataFrame:
        """Rename incoming columns from their alias to the schema column name."""
        return df.rename(columns=self.alias_map)

    def calculate_from_meta_calculation(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add calculated columns by evaluating each 'calculation' expression against the DataFrame columns."""
        for col_name, calculation in self.calculation_map.items():
            # The expression is evaluated with the DataFrame columns available as local names.
            df[col_name] = eval(calculation, {}, df.to_dict('series'))
        return df

    def round_to_decimals(self, df: pd.DataFrame, columns: list = None) -> pd.DataFrame:
        """Round columns (default: all with 'decimals' metadata) to their configured number of decimals."""
        if columns is None:
            columns = self.decimals_map.keys()
        for col_name in columns:
            # Skip columns without decimals metadata and columns not (yet) present in the frame,
            # e.g. calculated columns that are only added later in preprocess.
            if col_name in self.decimals_map and col_name in df.columns:
                df[col_name] = df[col_name].round(self.decimals_map[col_name])
        return df

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Rename aliased columns, round, add calculated columns, then round the calculated columns."""
        df = self.rename_from_meta_alias(df)
        df = self.round_to_decimals(df)
        df = self.calculate_from_meta_calculation(df)
        df = self.round_to_decimals(df, columns=list(self.calculation_map.keys()))
        return df

    def validate(self, df: pd.DataFrame, return_calculated_columns: bool = True) -> pd.DataFrame:
        """Validate against the schema, optionally dropping the calculated columns from the result."""
        df = self.schema.validate(df)
        if not return_calculated_columns:
            return df.drop(columns=list(self.calculation_map.keys()))
        return df
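

# Illustrative usage sketch (not part of the omfpandas API): the schema, column names
# ('fe_pct', 'mass', 'fe_mass'), aliases and calculation below are hypothetical.
def _example_meta_processing() -> pd.DataFrame:
    """Build a hypothetical schema with alias/calculation/decimals metadata and run the processor over it."""
    schema = DataFrameSchema(
        columns={
            'fe_pct': Column(float, metadata={'alias': 'FE', 'decimals': 2}),
            'mass': Column(float, metadata={'alias': 'MASS'}),
            'fe_mass': Column(float, metadata={'calculation': 'fe_pct / 100 * mass', 'decimals': 2}),
        }
    )
    processor = DataFrameMetaProcessor(schema)
    raw = pd.DataFrame({'FE': [62.135], 'MASS': [10.0]})
    df = processor.preprocess(raw)  # rename FE -> fe_pct, round, derive and round fe_mass
    return processor.validate(df)   # calculated columns are kept by default
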

def load_schema_from_yaml(yaml_path: Path) -> DataFrameSchema:
    """Load a DataFrameSchema from a YAML file."""
    with open(yaml_path, "r", encoding="utf-8") as f:
        schema_dict = yaml.safe_load(f)
    columns = {
        col_name: Column(**_deserialize_component_stats(col_stats))
        for col_name, col_stats in schema_dict["columns"].items()
    }
    return DataFrameSchema(
        columns=columns,
        checks=schema_dict.get("checks"),
        index=schema_dict.get("index"),
        dtype=schema_dict.get("dtype"),
        coerce=schema_dict.get("coerce", False),
        strict=schema_dict.get("strict", False),
        name=schema_dict.get("name", None),
        ordered=schema_dict.get("ordered", False),
        unique=schema_dict.get("unique", None),
        report_duplicates=schema_dict.get("report_duplicates", "all"),
        unique_column_names=schema_dict.get("unique_column_names", False),
        add_missing_columns=schema_dict.get("add_missing_columns", False),
        title=schema_dict.get("title", None),
        description=schema_dict.get("description", None),
    )

def _deserialize_component_stats(serialized_component_stats):
    """Convert a serialized (YAML-derived) column specification into keyword arguments for Column."""
    dtype = serialized_component_stats.get("dtype")
    if dtype:
        dtype = pandas_engine.Engine.dtype(dtype)
    description = serialized_component_stats.get("description")
    title = serialized_component_stats.get("title")
    checks = serialized_component_stats.get("checks")
    if checks is not None:
        checks = [
            _deserialize_check_stats(
                getattr(Check, check_name), check_stats, dtype
            )
            for check_name, check_stats in checks.items()
        ]
    return {
        "title": title,
        "description": description,
        "dtype": dtype,
        "checks": checks,
        **{
            key: serialized_component_stats.get(key)
            for key in [
                "name",
                "nullable",
                "unique",
                "coerce",
                "required",
                "regex",
                "metadata",
            ]
            if key in serialized_component_stats
        },
    }
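

if __name__ == "__main__":
    # Illustrative sketch only (not part of omfpandas): a hand-written schema YAML with
    # hypothetical columns ('fe_pct', 'mass') showing the layout that load_schema_from_yaml reads.
    import tempfile
    import textwrap

    example_yaml = textwrap.dedent(
        """
        columns:
          fe_pct:
            dtype: float64
            nullable: false
            metadata:
              alias: FE
              decimals: 2
          mass:
            dtype: float64
            nullable: false
            metadata:
              alias: MASS
        coerce: true
        """
    )
    yaml_path = Path(tempfile.mkdtemp()) / "example_schema.yaml"
    yaml_path.write_text(example_yaml, encoding="utf-8")

    schema = load_schema_from_yaml(yaml_path)
    processor = DataFrameMetaProcessor(schema)
    df = processor.preprocess(pd.DataFrame({"FE": [62.135], "MASS": [10.0]}))
    print(processor.validate(df))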