import logging
from pathlib import Path
from typing import List
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc
import pyarrow.dataset as ds
from parq_tools.utils import atomic_output_file
from parq_tools.utils.optional_imports import get_tqdm
from parq_tools.utils.progress import get_batch_progress_bar


def validate_index_alignment(datasets: List[ds.Dataset],
                             index_columns: List[str],
                             batch_size: int = 100_000) -> None:
    """
    Validate that the index columns are identical across all datasets.

    Args:
        datasets (List[ds.Dataset]): List of PyArrow datasets to validate.
        index_columns (List[str]): List of index column names to compare.
        batch_size (int, optional): Number of rows per batch to process. Defaults to 100_000.

    Raises:
        ValueError: If the index columns are not identical across datasets.
    """
    logging.info("Validating index alignment across datasets")
    scanners = [dataset.scanner(columns=index_columns, batch_size=batch_size) for dataset in datasets]
    iterators = [scanner.to_batches() for scanner in scanners]
    pbar = get_batch_progress_bar(datasets, batch_size, desc="Validating index alignment")

    while True:
        # Pull the next batch of index columns from every dataset in lockstep
        current_batches = []
        all_exhausted = True
        for iterator in iterators:
            try:
                batch = next(iterator)
                current_batches.append(pa.Table.from_batches([batch]))
                all_exhausted = False
            except StopIteration:
                current_batches.append(None)

        if all_exhausted:
            break

        # Compare every dataset's batch against the first dataset's batch
        reference_batch = current_batches[0]
        for i, current_batch in enumerate(current_batches[1:], start=1):
            # One dataset running out of batches before another is also a misalignment
            if (current_batch is None) != (reference_batch is None):
                raise ValueError(
                    f"Index columns are not aligned across datasets. Dataset {i} has a different number of rows."
                )
            if current_batch is not None and not current_batch.equals(reference_batch):
                raise ValueError(
                    f"Index columns are not aligned across datasets. Mismatch found in dataset {i}."
                )
        pbar.update(1)

    pbar.close()
    logging.info("Index alignment validated successfully")
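
# Usage sketch (hypothetical file names and index column; illustrative only, never called by the module):
def _example_validate_index_alignment() -> None:
    # Two datasets that are expected to share the same "id" index, row for row.
    left = ds.dataset(Path("left.parquet"), format="parquet")
    right = ds.dataset(Path("right.parquet"), format="parquet")
    # Raises ValueError as soon as a batch of index values diverges.
    validate_index_alignment([left, right], index_columns=["id"])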


def sort_parquet_file(
        input_path: Path,
        output_path: Path,
        columns: List[str],
        chunk_size: int = 100_000
) -> None:
    """
    Globally sort a Parquet file by the specified columns.

    Args:
        input_path (Path): Path to the input Parquet file.
        output_path (Path): Path to save the sorted Parquet file.
        columns (List[str]): List of column names to sort by.
        chunk_size (int, optional): Number of rows to process per chunk. Defaults to 100_000.
    """
    dataset: ds.Dataset = ds.dataset(input_path, format="parquet")
    sorted_batches: List[pa.Table] = []
    pbar = get_batch_progress_bar([dataset], chunk_size, desc="Sorting parquet file")

    # Read and sort each chunk
    for batch in dataset.to_batches(batch_size=chunk_size):
        table: pa.Table = pa.Table.from_batches([batch])
        sort_indices: pa.Array = pc.sort_indices(
            table, sort_keys=[(col, "ascending") for col in columns]
        )
        sorted_table: pa.Table = table.take(sort_indices)
        sorted_batches.append(sorted_table)
        pbar.update(1)
    pbar.close()

    # Merge all sorted chunks and sort globally (the full table is held in memory for this step)
    merged_table: pa.Table = pa.concat_tables(sorted_batches).combine_chunks()
    sort_indices: pa.Array = pc.sort_indices(
        merged_table, sort_keys=[(col, "ascending") for col in columns]
    )
    sorted_table: pa.Table = merged_table.take(sort_indices)

    # Write the globally sorted table to a new Parquet file
    with atomic_output_file(output_path) as tmp_file:
        pq.write_table(sorted_table, tmp_file)
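
# Usage sketch (hypothetical paths and sort keys; illustrative only, never called by the module):
def _example_sort_parquet_file() -> None:
    # Sort by "timestamp", then "id"; the output file is written atomically.
    sort_parquet_file(
        input_path=Path("events.parquet"),
        output_path=Path("events_sorted.parquet"),
        columns=["timestamp", "id"],
    )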


def reindex_parquet(sparse_parquet_path: Path, output_path: Path,
                    new_index: pa.Table, chunk_size: int = 100_000,
                    sort_after_reindex: bool = True) -> None:
    """
    Reindex a sparse Parquet file to align with a new index, processing in chunks.

    Args:
        sparse_parquet_path (Path): Path to the sparse Parquet file.
        output_path (Path): Path to save the re-indexed Parquet file.
        new_index (pa.Table): New index as a PyArrow table.
        chunk_size (int): Number of rows to process per chunk. Defaults to 100_000.
        sort_after_reindex (bool): Whether to sort the output after reindexing. Defaults to True.
    """
    # Read the sparse Parquet file as a dataset
    sparse_dataset = ds.dataset(sparse_parquet_path, format="parquet")
    index_columns = [field.name for field in new_index.schema if field.name in sparse_dataset.schema.names]

    # Initialize the writer with the schema of the reindexed table, inferred from the first chunk
    first_batch = next(sparse_dataset.to_batches(batch_size=chunk_size))
    sparse_table = pa.Table.from_batches([first_batch])
    reindexed_table = new_index.join(sparse_table, keys=index_columns, join_type="left outer")
    writer_schema = reindexed_table.schema

    with atomic_output_file(output_path) as tmp_file, pq.ParquetWriter(tmp_file, schema=writer_schema) as writer:
        pbar = get_batch_progress_bar([sparse_dataset], chunk_size, desc="Reindexing parquet file")

        # Process the sparse dataset in chunks
        for batch in sparse_dataset.to_batches(batch_size=chunk_size):
            sparse_table = pa.Table.from_batches([batch])

            # Perform a left join with the new index
            reindexed_table = new_index.join(sparse_table, keys=index_columns, join_type="left outer")

            # Fill null values dynamically based on column types
            columns = []
            for field in reindexed_table.schema:
                column = reindexed_table[field.name]
                if pa.types.is_floating(field.type):
                    # Floats: represent missing values as NaN
                    column = pc.if_else(pc.is_null(column), pa.scalar(float('nan'), type=field.type), column)
                elif pa.types.is_string(field.type):
                    # Strings: keep nulls
                    column = pc.if_else(pc.is_null(column), pa.scalar(None, type=field.type), column)
                elif pa.types.is_dictionary(field.type):  # Categorical
                    # Categoricals: keep nulls
                    column = pc.if_else(pc.is_null(column), pa.scalar(None, type=field.type), column)
                elif pa.types.is_integer(field.type):
                    # Integers: keep nulls, using the column's own type so the writer schema is preserved
                    column = pc.if_else(pc.is_null(column), pa.scalar(None, type=field.type), column)
                columns.append(column)

            reindexed_table = pa.table(columns, schema=reindexed_table.schema)
            writer.write_table(reindexed_table)
            logging.info(f"Wrote {reindexed_table.num_rows} rows to {output_path}")
            pbar.update(1)
        pbar.close()

    if sort_after_reindex:
        with atomic_output_file(output_path) as tmp_file:
            sort_parquet_file(
                input_path=output_path,
                output_path=tmp_file,
                columns=index_columns,
                chunk_size=chunk_size
            )
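
# Usage sketch (hypothetical paths and index values; illustrative only, never called by the module):
def _example_reindex_parquet() -> None:
    # Dense target index; "id" rows missing from the sparse file appear as nulls/NaNs in the output.
    new_index = pa.table({"id": pa.array(range(1_000), type=pa.int64())})
    reindex_parquet(
        sparse_parquet_path=Path("sparse.parquet"),
        output_path=Path("dense.parquet"),
        new_index=new_index,
        sort_after_reindex=True,
    )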


def dedup_index_parquet(
        input_path: Path,
        output_path: Path,
        index_columns: List[str],
        chunk_size: int = 100_000) -> None:
    """
    Remove duplicate rows based on index columns from a Parquet file.

    Args:
        input_path (Path): Path to the input Parquet file.
        output_path (Path): Path to save the deduplicated Parquet file.
        index_columns (List[str]): Columns to use as the index for deduplication.
        chunk_size (int): Number of rows to process per chunk. Defaults to 100_000.
    """
    dataset = ds.dataset(input_path, format="parquet")
    seen = set()

    # Infer the output schema from the first chunk
    first_batch = next(dataset.to_batches(batch_size=chunk_size))
    schema = pa.Table.from_batches([first_batch]).schema

    with atomic_output_file(output_path) as tmp_file, pq.ParquetWriter(tmp_file, schema=schema) as writer:
        tqdm = get_tqdm()
        pbar = tqdm(total=None, desc="Deduplicating index")
        for batch in dataset.to_batches(batch_size=chunk_size):
            table = pa.Table.from_batches([batch])

            # Keep only the first occurrence of each index tuple
            mask = []
            num_rows = table.num_rows
            for i in range(num_rows):
                idx = tuple(table[col][i].as_py() for col in index_columns)
                if idx not in seen:
                    seen.add(idx)
                    mask.append(True)
                else:
                    mask.append(False)

            if any(mask):
                filtered_table = table.filter(pa.array(mask))
                writer.write_table(filtered_table)
            pbar.update(1)
        pbar.close()
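
# Usage sketch (hypothetical paths and index columns; illustrative only, never called by the module):
def _example_dedup_index_parquet() -> None:
    # Keeps only the first row seen for each ("id", "timestamp") pair.
    dedup_index_parquet(
        input_path=Path("raw.parquet"),
        output_path=Path("deduped.parquet"),
        index_columns=["id", "timestamp"],
    )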