Spaces:
Running
Running
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| import pytesseract | |
| import plotly.express as px | |
| from torch.utils.data import Dataset, DataLoader, Subset | |
| import os | |
| import io | |
| import pytesseract | |
| import fitz | |
| from typing import List | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification | |
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint name passed to LayoutLMv3TokenizerFast.from_pretrained.
TOKENIZER = "microsoft/layoutlmv3-base"

# Checkpoint name passed to LayoutLMv3ForSequenceClassification.from_pretrained.
MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"

# Automatic page segmentation for Tesseract
TESS_OPTIONS = "--psm 3"
def create_ocr_reader():
    """
    Build the OCR front end used before classification.

    Returns:
        A callable ``prepare_image(image)`` that runs Tesseract on the image
        and returns ``(words, boxes)``: the recognised words and, for each
        word, an ``[x0, y0, x1, y1]`` box rescaled so that the image width
        and height both map onto 1000 units.
    """
    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0):
        """Scale an ``[x0, y0, x1, y1]`` box, truncating every value to ``int``."""
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale),
        ]

    def ocr_page(image) -> dict:
        """
        OCR a given image. Return a dictionary of words and the bounding boxes
        for each word. For each word, there is a corresponding bounding box.
        """
        ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
        # First pass: discard rows that contain any missing value.
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Turn empty / whitespace-only cells into NaN so the second dropna
        # removes those rows as well.
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)

        # Words and boxes are built from the same rows, so they always have
        # the same length.
        words = [str(w) for w in ocr_df["text"]]
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in zip(
                ocr_df["left"], ocr_df["top"], ocr_df["width"], ocr_df["height"]
            )
        ]
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        """
        OCR ``image`` and rescale every word box to the 1000-unit grid.

        Args:
            image: an object exposing ``.size`` as ``(width, height)`` that
                pytesseract can read.

        Returns:
            ``(words, boxes)`` with one box per word.

        Raises:
            RuntimeError: if any scaled coordinate is greater than 1000.
        """
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height

        words = list(ocr_data["words"])
        boxes = [
            scale_bounding_box(box, width_scale, height_scale)
            for box in ocr_data["bbox"]
        ]

        for box in boxes:
            if any(coord > 1000 for coord in box):
                # A bare `raise` here has no active exception to re-raise and
                # fails with an unrelated-looking RuntimeError; raise the same
                # exception type explicitly with a message that names the box.
                raise RuntimeError(
                    f"Scaled bounding box {box} exceeds the 1000-unit range "
                    f"(image size {width}x{height})"
                )
        return words, boxes

    return prepare_image
def create_model():
    """
    Load the sequence-classification checkpoint named by MODEL_NAME.

    Returns:
        The model, switched to evaluation mode and moved to DEVICE.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()
    return classifier.to(DEVICE)
def create_processor():
    """
    Build the LayoutLMv3 processor used to encode an image with its words and boxes.

    The feature extractor is created with ``apply_ocr=False`` because words and
    boxes are supplied by the caller; the tokenizer is loaded from TOKENIZER.
    """
    return LayoutLMv3Processor(
        feature_extractor=LayoutLMv3FeatureExtractor(apply_ocr=False),
        tokenizer=LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER),
    )
def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """
    Classify a single document image.

    Args:
        image: the image to classify; passed to both ``reader`` and ``processor``.
        reader: callable returning ``(words, boxes)`` for the image.
        processor: encodes image, words and boxes into model inputs, padded
            or truncated to 512 tokens.
        model: the classifier; it is called with tensors moved to DEVICE.

    Returns:
        ``(predicted_index, probabilities)`` where ``predicted_index`` is the
        argmax of the logits as a Python int and ``probabilities`` is the
        flattened softmax of the logits as a list of floats.
    """
    words, boxes = reader(image)

    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_names = ("input_ids", "attention_mask", "bbox", "pixel_values")
    with torch.inference_mode():
        model_inputs = {name: encoding[name].to(DEVICE) for name in input_names}
        logits = model(**model_inputs).logits
        best_class = logits.argmax()
        class_probs = F.softmax(logits, dim=-1).flatten().tolist()

    return best_class.detach().item(), class_probs
# Build the OCR reader, processor and model that predict() needs.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Choose a JPG file", ["jpg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.read())
    image = Image.open(bytes_data)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # PNG uploads can be palette, greyscale or RGBA. Hand OCR and the
    # processor a plain 3-channel RGB copy; the displayed image is untouched.
    # NOTE(review): the LayoutLMv3 image pipeline is assumed to expect RGB
    # input -- confirm against the transformers documentation.
    rgb_image = image.convert("RGB")
    predicted, probabilities = predict(rgb_image, reader, processor, model)

    id2label = model.config.id2label
    predicted_label = id2label[predicted]
    st.markdown(f"Predicted Label: {predicted_label}")

    # Look each label up by class index so that probability i is always
    # paired with label i, independent of the dict's iteration order.
    df = pd.DataFrame({
        "Label": [id2label[index] for index in range(len(probabilities))],
        "Probability": probabilities,
    })
    fig = px.bar(df, x="Label", y="Probability")
    st.plotly_chart(fig, use_container_width=True)