Source code for parq_tools.parq_filter

"""
parq_filter.py

Functions for filtering Parquet files with optional row and column selection, supporting chunked processing and
progress reporting.

Main API:

- filter_parquet_file: Filter a Parquet file using a pandas-like expression and/or column selection, with efficient
  chunked writing and optional progress bar.
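
Example (illustrative only; the paths, column names and filter expression below
are placeholders, not part of this module)::

    from pathlib import Path
    from parq_tools.parq_filter import filter_parquet_file

    filter_parquet_file(
        Path("input.parquet"),
        Path("filtered.parquet"),
        filter_expression="value > 0.5",
        columns=["id", "value"],
        show_progress=True,
    )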
"""

import logging
from pathlib import Path
from typing import List, Optional

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from parq_tools.utils import atomic_output_file
# noinspection PyProtectedMember
from parq_tools.utils._query_parser import build_filter_expression

try:
    # noinspection PyUnresolvedReferences
    from tqdm import tqdm

    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
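# tqdm is an optional dependency: when it is unavailable, progress reporting is
# skipped even if show_progress=True is requested.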


def filter_parquet_file(input_path: Path,
                        output_path: Path,
                        filter_expression: Optional[str] = None,
                        columns: Optional[List[str]] = None,
                        chunk_size: int = 100_000,
                        show_progress: bool = False) -> None:
    """
    Filter a Parquet file based on a pandas-like expression and a sub-selection of columns.

    Args:
        input_path (Path): Path to the input Parquet file.
        output_path (Path): Path to save the filtered Parquet file.
        filter_expression (Optional[str]): Pandas-like expression to filter rows.
            If None, no filtering is applied.
        columns (Optional[List[str]]): List of column names to include in the output.
            If None, all columns are included.
        chunk_size (int): Number of rows to process in each batch.
        show_progress (bool): Whether to show a progress bar during processing.
    """
    dataset = ds.dataset(input_path, format="parquet")
    filter_expr = build_filter_expression(filter_expression, dataset.schema) if filter_expression else None
    scanner = dataset.scanner(columns=columns, filter=filter_expr, batch_size=chunk_size)

    # Total rows in the input dataset, used as the progress bar total.
    total_rows = dataset.count_rows()
    progress = tqdm(total=total_rows, desc="Filtering", unit="rows") if HAS_TQDM and show_progress else None

    # Get the writer schema from the first batch
    batches = scanner.to_batches()
    try:
        first_batch = next(batches)
    except StopIteration:
        if progress:
            progress.close()
        return  # No data to write

    table = pa.Table.from_batches([first_batch])
    writer_schema = table.schema

    with atomic_output_file(output_path) as tmp_path, \
            pq.ParquetWriter(tmp_path, schema=writer_schema) as writer:
        writer.write_table(table)
        if progress:
            progress.update(len(first_batch))

        for batch in batches:
            table = pa.Table.from_batches([batch])
            writer.write_table(table)
            if progress:
                progress.update(len(batch))

    if progress:
        progress.close()
    logging.info(f"Filtered {total_rows} rows")