import os

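# Let PyTorch ops that are missing from Apple's MPS backend fall back to the
# CPU; this must be set before torch is first imported (via the marker imports below).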
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import json
import time
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from pathlib import Path

import click
import datasets
from tabulate import tabulate
from tqdm import tqdm

from marker.settings import settings
from benchmarks.table.inference import inference_tables
from scoring import wrap_table_html, similarity_eval_html
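
# `wrap_table_html` and `similarity_eval_html` come from the local scoring module;
# judging from the call sites below (an assumption, not verified against that file),
# the former wraps bare <table> markup in a complete HTML document and the latter
# returns a TEDS-style tree-similarity score.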


def update_teds_score(result, prefix: str = "marker"):
    """Compute a TEDS similarity score for one table and attach it to the result dict."""
    prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
    prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
    score = similarity_eval_html(prediction, ground_truth)
    result.update({f'{prefix}_score': score})
    return result
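
# Each result dict is expected to hold raw HTML strings, roughly (assumed shape,
# inferred from the lookups above):
#   {"gt_table": "<table>...</table>", "marker_table": "<table>...</table>"}
# plus a "gemini_table" entry when --use_gemini is passed.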


@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
def main(
    result_path: str,
    dataset: str,
    max_rows: int | None,
    max_workers: int,
    use_llm: bool,
    table_rec_batch_size: int | None,
    use_gemini: bool = False,
):
    start = time.time()

    dataset = datasets.load_dataset(dataset, split='train')
    # Shuffle with a fixed seed so truncated runs (--max_rows) sample a stable subset.
    dataset = dataset.shuffle(seed=0)

    results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)

    print(f"Total time: {time.time() - start}.")
    print(f"Could not align {total_unaligned} tables from fintabnet.")
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        marker_results = list(
            tqdm(
                executor.map(update_teds_score, results),
                desc='Computing alignment scores',
                total=len(results),
            )
        )

    avg_score = sum(r["marker_score"] for r in marker_results) / len(marker_results)
    headers = ["Avg score", "Total tables"]
    data = [f"{avg_score:.3f}", len(marker_results)]

    gemini_results = None
    if use_gemini:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            gemini_results = list(
                tqdm(
                    executor.map(update_teds_score, results, repeat("gemini")),
                    desc='Computing Gemini scores',
                    total=len(results),
                )
            )
        avg_gemini_score = sum(r["gemini_score"] for r in gemini_results) / len(gemini_results)
        headers.append("Avg Gemini score")
        data.append(f"{avg_gemini_score:.3f}")

    table = tabulate([data], headers=headers, tablefmt="github")
    print(table)
    print("Avg score computed by comparing marker's predicted HTML with the original HTML.")

    results = {
        "marker": marker_results,
        "gemini": gemini_results,
    }

    out_path = Path(result_path)
    out_path.mkdir(parents=True, exist_ok=True)
    with open(out_path / "table.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {out_path}.")


if __name__ == '__main__':
    main()
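
# One plausible invocation (the file path and PYTHONPATH setup are assumptions,
# inferred from the `benchmarks.table.inference` import above):
#   python benchmarks/table/table.py --max_rows 100 --max_workers 8 --use_llm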