Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import VisionEncoderDecoderModel, AutoTokenizer | |
| from datasets import load_dataset, concatenate_datasets | |
| from texteller.api.load import load_model, load_tokenizer | |
| from texteller.api.inference import img2latex | |
| from skimage.metrics import structural_similarity as ssim | |
| from modules.cdm.evaluation import compute_cdm_score | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import io | |
| from io import BytesIO | |
| import base64 | |
| import pandas as pd | |
| import re | |
| import os | |
| import evaluate | |
| import time | |
| from collections import defaultdict | |
| import shutil | |
# Configure Streamlit layout
st.set_page_config(layout="wide")
st.title("TeXTeller Demo: LaTeX Code Prediction from Math Images")


# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the TeXTeller model and its tokenizer from the "OleehyO/TexTeller" checkpoint.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``: the checkpoint is loaded once
    and the same objects are reused by every rerun and every session.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by texteller's ``load_model``
        and ``load_tokenizer``.
    """
    checkpoint = "OleehyO/TexTeller"
    model = load_model(checkpoint)
    tokenizer = load_tokenizer(checkpoint)
    return model, tokenizer
@st.cache_resource
def load_data():
    """Load LaTeX_OCR ("small" config), merge all its splits and add derived columns.

    Adds ``complexity``, ``latex_length`` and ``latex_depth`` (all computed from
    the raw LaTeX string), then replaces ``text`` with its normalized form
    (``normalize_latex``).

    Cached with ``st.cache_resource`` so the download, concatenation and ``map``
    run once instead of on every Streamlit rerun; the returned dataset object is
    shared by all reruns and sessions.

    Returns:
        datasets.Dataset: every split concatenated, with the extra columns.
    """
    dataset = load_dataset("linxy/LaTeX_OCR", "small")
    dataset = concatenate_datasets(list(dataset.values()))

    def add_derived_columns(text):
        # Statistics use the raw string; "text" is overwritten only at the end.
        return {
            "complexity": estimate_complexity(text),
            "latex_length": len(text),
            "latex_depth": max_brace_depth(text),
            "text": normalize_latex(text),
        }

    # input_columns="text" hands the function only the LaTeX string instead of
    # the full example (which would include the decoded image).
    return dataset.map(add_derived_columns, input_columns="text")
@st.cache_resource
def load_metrics():
    """Load the Hugging Face ``evaluate`` BLEU metric.

    Cached with ``st.cache_resource`` so it is loaded once rather than on
    every Streamlit rerun.
    """
    return evaluate.load("bleu")
# Utilities to evaluate LaTeX complexity
def count_occurrences(pattern, text):
    """Return the number of non-overlapping matches of regex ``pattern`` in ``text``."""
    return sum(1 for _ in re.finditer(pattern, text))
def max_brace_depth(latex):
    """Return the deepest ``{`` nesting level reached while scanning ``latex``.

    Every ``{`` opens a level and every ``}`` closes one; the result is never
    below 0 (a string without braces yields 0).
    """
    deepest = 0
    level = 0
    for ch in latex:
        if ch == '}':
            level -= 1
            continue
        if ch == '{':
            level += 1
            if level > deepest:
                deepest = level
    return deepest
# A LaTeX control word ends at the first non-letter character, so every command
# pattern below carries a (?![A-Za-z]) guard. Without it, \rightarrow and
# \leftarrow were scored as \right/\left delimiters and \intercal as \int.
_END_OF_COMMAND = r"(?![A-Za-z])"
_FRAC_SQRT_RE = re.compile(r"\\(frac|sqrt)" + _END_OF_COMMAND)
_BIG_OPERATOR_RE = re.compile(r"\\(sum|prod|int)" + _END_OF_COMMAND)
_DELIMITER_ENV_RE = re.compile(r"\\(left|right|begin|end)" + _END_OF_COMMAND)
_MATRIX_ENV_RE = re.compile(r"\\begin\{(bmatrix|matrix|pmatrix)\}")
_GREEK_LETTER_RE = re.compile(
    r"\\(alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|mu|nu|xi|pi|rho|sigma|tau|upsilon|phi|chi|psi|omega|"
    r"Gamma|Delta|Theta|Lambda|Xi|Pi|Sigma|Upsilon|Phi|Psi|Omega)"
    + _END_OF_COMMAND
)


def estimate_complexity(latex):
    r"""Bucket a LaTeX formula into one of five heuristic complexity labels.

    The score is the sum of:
      * 1 per \frac / \sqrt;
      * 2 per \sum / \prod / \int;
      * 2 per \left / \right / \begin / \end;
      * 3 more per matrix / bmatrix / pmatrix environment (on top of its \begin);
      * 0.5 per Greek-letter command;
      * the maximum brace nesting depth (max_brace_depth);
      * len(latex) / 20.

    Returns:
        str: "very simple" (score < 4), "simple" (< 8), "medium" (< 12),
        "complex" (< 20) or "very complex".
    """
    score = 0
    score += count_occurrences(_FRAC_SQRT_RE, latex)
    score += count_occurrences(_BIG_OPERATOR_RE, latex) * 2
    score += count_occurrences(_DELIMITER_ENV_RE, latex) * 2
    score += count_occurrences(_MATRIX_ENV_RE, latex) * 3
    score += count_occurrences(_GREEK_LETTER_RE, latex) * 0.5
    score += max_brace_depth(latex)
    score += len(latex) / 20
    if score < 4:
        return "very simple"
    elif score < 8:
        return "simple"
    elif score < 12:
        return "medium"
    elif score < 20:
        return "complex"
    return "very complex"
# \begin{align}, \begin{align*}, \end{align}, ... (whitespace allowed after the command).
_ALIGN_ENV_RE = re.compile(r"\\(?:begin|end)\s*\{align\**\}")
# One LaTeX token relevant to whitespace handling, scanned left to right:
#   group 1 : a control word such as \sin;
#   group 2 : whitespace right after it that is followed by a letter — this one
#             is significant ("\sin x" must not become the undefined "\sinx");
#   \\.     : a control symbol such as \\, \{ or the control space "\ ";
#   \s+     : any other whitespace run (insignificant in math mode).
_LATEX_SPACING_RE = re.compile(r"(\\[A-Za-z]+)(\s+(?=[A-Za-z]))?|\\.|\s+", re.DOTALL)


def _squeeze_whitespace(match):
    """Replacement callback for _LATEX_SPACING_RE: drop whitespace that carries no meaning."""
    control_word, separator = match.group(1), match.group(2)
    if control_word is not None:
        # Keep exactly one space when it is what terminates the control word.
        return control_word + " " if separator else control_word
    token = match.group(0)
    return "" if token.isspace() else token


def normalize_latex(latex_code):
    r"""Canonicalize a LaTeX formula so predictions and references compare fairly.

    - Removes \displaystyle and align/align* environment markers.
    - Removes whitespace, except a single space that terminates a control word
      followed by a letter (so "\sin x" stays valid instead of becoming "\sinx").
    - Leaves control symbols such as "\ " and "\\" intact.

    Example: normalize_latex(r"\displaystyle \sin x + 1") returns r"\sin x+1".
    """
    # Removed markers become a space (not "") so that a control word before
    # them cannot fuse with a letter after them; the squeeze below drops it.
    latex_code = latex_code.replace("\\displaystyle", " ")
    latex_code = _ALIGN_ENV_RE.sub(" ", latex_code)
    return _LATEX_SPACING_RE.sub(_squeeze_whitespace, latex_code)
def compute_ssim(image1, image2):
    """Return the SSIM between two PIL images.

    Both images are converted to grayscale ("L") before comparison; they
    presumably need identical dimensions for the SSIM call.
    """
    gray_first, gray_second = (np.array(img.convert("L")) for img in (image1, image2))
    return ssim(gray_first, gray_second)
# Convert LaTeX to image
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render ``latex_expression`` with matplotlib mathtext and return it as a PIL image.

    Args:
        latex_expression: math-mode LaTeX without surrounding ``$`` delimiters.
        image_size_in: figure size in inches ``(width, height)`` before tight cropping.
        fontsize: font size of the rendered text.
        dpi: figure resolution.

    Returns:
        PIL.Image.Image: the PNG rendering, cropped to the text bounding box
        plus a 0.1 inch margin.

    Raises:
        Whatever ``fig.savefig`` raises propagates; matplotlib's mathtext
        reports LaTeX it cannot parse as ``ValueError``.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        buf = io.BytesIO()
        # Save this exact figure instead of pyplot's implicit "current" figure.
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        # Close the figure even when rendering fails; otherwise every failed
        # render leaves an open figure in pyplot's registry.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# --- Convert PIL image to base64 ---
def image_to_base64(pil_img: Image.Image) -> str:
    """Encode ``pil_img`` as PNG and return those bytes as an ASCII base64 string."""
    buffer = BytesIO()
    try:
        pil_img.copy().save(buffer, 'png')
        encoded = base64.b64encode(buffer.getvalue())
    finally:
        buffer.close()
    return encoded.decode()
# --- Formatter for HTML rendering ---
def image_formatter(pil_img: Image.Image) -> str:
    """Return an HTML ``<img>`` tag embedding ``pil_img`` as a base64 PNG data URI."""
    return '<img src="data:image/png;base64,{}">'.format(image_to_base64(pil_img))
# --- Build HTML table from dictionary ---
def build_html_table(metrics_dico):
    """Render a column-oriented mapping (column name -> list of values) as an HTML table.

    Cell values are not HTML-escaped, so the "CDM Image" column, passed
    through image_formatter, shows up as an embedded image.
    """
    column_formatters = {"CDM Image": image_formatter}
    table = pd.DataFrame(metrics_dico)
    return table.to_html(escape=False, formatters=column_formatters)
# Load the model, the dataset and the BLEU metric used by the sections below.
model, tokenizer = load_model_and_tokenizer()
dataset = load_data()
bleu_metric = load_metrics()

# Section 1: Dataset Overview
st.markdown("---")
st.markdown("## 📚 Dataset Overview")
st.markdown("""
This demo uses the [LaTeX_OCR dataset](https://huggingface.co/datasets/linxy/LaTeX_OCR) from Hugging Face 🤗.
Below are 10 examples showing input images and their corresponding LaTeX code.
""")
# Preview the first 10 examples.
sample_dataset = dataset.select(range(10))
# Column weights 1:2:1 keep the preview "table" in a centered ~50%-wide area.
col_left, col_center, col_right = st.columns([1, 2, 1])
with col_center:
    header_cells = (
        "<p style='text-align: center; font-size: 24px; font-weight: bold;'>Image</p>",
        "<p style='text-align: center; font-size: 24px; font-weight: bold;'>LaTeX Code</p>",
    )
    for header_column, header_html in zip(st.columns(2, border=True), header_cells):
        with header_column:
            st.markdown(header_html, unsafe_allow_html=True)
    for row in range(10):
        image_cell, code_cell = st.columns(2, border=True)
        example = sample_dataset[row]
        with image_cell:
            st.image(example["image"])
        with code_cell:
            st.markdown(f"`{example['text']}`")
# ---- Section 2: Exploratory Data Analysis ----
st.markdown("---")
st.header("📊 Exploratory Data Analysis")
st.markdown("We analyze the distribution of LaTeX expressions in terms of complexity, length, and depth.")
# The plots only use text-derived columns. Dropping "image" before converting
# avoids pd.DataFrame(dataset) iterating every example (decoding each image)
# just to build rows.
df = dataset.remove_columns("image").to_pandas()
sns.set_theme()
COMPLEXITY_ORDER = ["very simple", "simple", "medium", "complex", "very complex"]
# Layout: 3 plots in a row
col1, col2, col3 = st.columns(3)
with col1:
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.countplot(data=df, x="complexity", order=COMPLEXITY_ORDER, palette="flare", ax=ax)
    # Restyle the existing tick labels in place rather than replacing them.
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right", fontsize=8)
    ax.set_title("LaTeX Formula Complexity", fontsize=8)
    ax.set_xlabel("")
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    # st.pyplot leaves a figure passed to it open; close it so figures do not
    # pile up in pyplot's registry on every rerun.
    plt.close(fig)
with col2:
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.histplot(df["latex_length"], bins=20, kde=True, ax=ax)
    ax.set_title("Length of LaTeX Code", fontsize=8)
    ax.set_xlabel("Characters", fontsize=8)
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    plt.close(fig)
with col3:
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.histplot(df["latex_depth"], bins=5, kde=True, color="forestgreen", ax=ax)
    ax.set_title("Max Brace Depth of LaTeX Code", fontsize=8)
    ax.set_xlabel("Depth", fontsize=8)
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    plt.close(fig)
# ---- Section 3: Prediction ----
st.markdown("---")
st.header("🔍 TeXTeller Inference")
st.markdown("Upload a math image below to predict the LaTeX code using the TeXTeller model.")
# Radio button to select input source
input_option = st.radio(
    "Choose an input method:",
    options=["Upload your own image", "Use a sample from the dataset"],
    horizontal=True,
)
image = None
selected_index = None
if input_option == "Use a sample from the dataset":
    nb_cols = 5
    for i in range(10):  # Show the first 10 dataset images
        if i % nb_cols == 0:
            cols = st.columns(nb_cols, border=True)
        with cols[i % nb_cols]:
            if st.button("Select this sample", key=f"btn_{i}"):
                # st.button is True only on the rerun triggered by the click.
                # Remember the choice in session_state so it survives later
                # reruns, just as an uploaded file does.
                st.session_state["selected_sample_index"] = i
            st.image(dataset[i]["image"], use_container_width=True)
    selected_index = st.session_state.get("selected_sample_index")
    if selected_index is not None:
        # Convert to RGB like the upload branch so both inputs reach the model in the same mode.
        image = dataset[selected_index]["image"].convert("RGB")
elif input_option == "Upload your own image":
    uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
# Once we have a valid image
if image is not None:
    st.divider()
    st.markdown("### TeXTeller Prediction Output")
    col1, col2, col3 = st.columns(3, border=True)
    with col1:
        st.image(image, caption="Input Image", use_container_width=True)
    with st.spinner("Running TeXTeller..."):
        try:
            # Column name -> list of values (one entry each), rendered by build_html_table.
            dico_result = defaultdict(list)
            # perf_counter is a monotonic clock, suited to measuring durations.
            start = time.perf_counter()
            predicted_latex = img2latex(model, tokenizer, [np.array(image)], out_format="katex")[0]
            eval_time = time.perf_counter() - start
            dico_result["Inference Time (s)"].append(f"{eval_time:.2f}")
            with col2:
                st.markdown("**Predicted LaTeX Code:**")
                # Give the widget a real label for accessibility, but hide it:
                # the markdown line above already serves as the visible caption.
                st.text_area(
                    label="Predicted LaTeX Code",
                    value=predicted_latex,
                    height=80,
                    label_visibility="collapsed",
                )
            with col3:
                # Matplotlib mathtext only supports a subset of LaTeX. A prediction it
                # cannot render should not prevent the text-based metrics below.
                try:
                    rendered_image = latex2image(predicted_latex)
                except ValueError as render_error:
                    rendered_image = None
                    st.warning(f"Could not render the predicted LaTeX: {render_error}")
                else:
                    st.image(rendered_image, caption="Rendered from Prediction", use_container_width=True)
            # Reference-based metrics are only available for dataset samples.
            if selected_index is not None:
                ref_latex = dataset[selected_index]["text"]
                # The reference "text" column was normalized in load_data; normalize
                # the prediction the same way before comparing.
                predicted_latex = normalize_latex(predicted_latex)
                # Compute BLEU score
                bleu_results = bleu_metric.compute(predictions=[predicted_latex], references=[[ref_latex]])
                dico_result["BLEU Score"].append(bleu_results["bleu"])
                # Compute SSIM between the input image and the rendering resized to match it.
                if rendered_image is not None:
                    pred_image = rendered_image.resize(image.size)
                    ssim_score = compute_ssim(image, pred_image)
                else:
                    ssim_score = float("nan")
                dico_result["SSIM Score"].append(ssim_score)
                # Compute CDM
                cdm_score, cdm_recall, cdm_precision, compare_img = compute_cdm_score(ref_latex, predicted_latex)
                dico_result["CDM Image"].append(compare_img)
                dico_result["CDM Score"].append(cdm_score)
            # Display metrics
            html = build_html_table(dico_result)
            st.markdown("### TeXTeller Metrics")
            # CSS making the table span the full width, with centered cells
            st.markdown("""
            <style>
            table {
                width: 100% !important;
            }
            th, td {
                text-align: center !important;
                vertical-align: middle !important;
            }
            </style>
            """, unsafe_allow_html=True)
            st.markdown(html, unsafe_allow_html=True)
        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing the app.
            st.error(f"Error during prediction: {e}")