Spaces:
Runtime error
| import torch | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| from PIL import Image | |
| import requests | |
| import gradio as gr | |
| import pandas as pd | |
| import subprocess | |
| import os | |
# Install flash-attn at startup, skipping the local CUDA kernel build.
import sys

# `env` must extend the current environment, not replace it: a bare dict would
# drop PATH/HOME etc. for the child process, so locating and running pip could fail.
# `sys.executable -m pip` installs into the interpreter actually running this app,
# and the argument list avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort, as before: a failed install is not raised here
)
# Load the model and processor.
model_id = "yifeihu/TB-OCR-preview-0.1"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    # flash-attention 2 and bitsandbytes 4-bit quantization both require a CUDA GPU.
    # BitsAndBytesConfig replaces the deprecated bare `load_in_4bit=True` kwarg.
    from transformers import BitsAndBytesConfig

    _model_kwargs = {
        "attn_implementation": "flash_attention_2",
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    }
else:
    # CPU fallback: plain (eager) attention and unquantized weights, instead of
    # crashing on the previously hard-coded device_map="cuda".
    _model_kwargs = {"attn_implementation": "eager"}

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=DEVICE.type,
    trust_remote_code=True,
    torch_dtype="auto",
    **_model_kwargs,
)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=16,
)
# Define the OCR function
def phi_ocr(image):
    """Transcribe the text in a single image to Markdown using TB-OCR.

    Args:
        image: A PIL image, or None.

    Returns:
        The newly generated text, with prompt tokens and special tokens removed
        and truncated at the first "<image_end>" marker. Returns "" for None.
    """
    if image is None:
        return ""
    question = "Convert the text to markdown format."
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    # Send inputs to wherever the model was placed instead of hard-coding
    # "cuda", so this also works when the model is loaded on CPU.
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)
    generation_args = {
        "max_new_tokens": 1024,
        # Greedy decoding. `temperature` is dropped: it is ignored when
        # do_sample=False and only produces a generation-config warning.
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
    )
    # Keep only the tokens generated after the prompt.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response.split("<image_end>")[0]
# Define the function to process multiple images and save results to a CSV
def process_images(input_images, *, ocr_fn=None):
    """OCR every image and save the results as a CSV file.

    Args:
        input_images: Iterable of images accepted by the OCR function, or None
            (Gradio passes None when the upload is cleared).
        ocr_fn: Optional callable ``image -> str``; defaults to ``phi_ocr``.

    Returns:
        ``(status_message, csv_path)``. When there is nothing to process,
        returns ``("No images uploaded.", None)`` and writes no file.
        The CSV has columns ``index`` (0-based position) and ``extracted_text``.
    """
    import tempfile

    if not input_images:
        return "No images uploaded.", None

    ocr = phi_ocr if ocr_fn is None else ocr_fn
    results = [
        {"index": index, "extracted_text": ocr(image)}
        for index, image in enumerate(input_images)
    ]

    # Convert to DataFrame and save to CSV. Use a fresh directory per call so
    # concurrent sessions never overwrite each other's file in the shared CWD.
    df = pd.DataFrame(results)
    output_csv = os.path.join(tempfile.mkdtemp(), "extracted_entities.csv")
    df.to_csv(output_csv, index=False)
    return f"Processed {len(results)} images and saved to {output_csv}", output_csv
def _load_uploads(files):
    """Open uploaded image files as RGB PIL images.

    Gradio 4+ passes file paths (str); Gradio 3 passes tempfile wrappers that
    expose the path as ``.name``. RGB conversion mirrors what gr.Image(type="pil")
    delivers by default.
    """
    images = []
    for f in files:
        with Image.open(getattr(f, "name", f)) as img:
            images.append(img.convert("RGB"))
    return images


def _on_upload(files):
    """Gradio callback: OCR the uploaded files, or reset the outputs when cleared."""
    if not files:
        return "No images uploaded.", None
    return process_images(_load_uploads(files))


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
    gr.Markdown("[Check out ](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
    with gr.Row():
        # gr.Image holds a single image and has no `multiple` argument (and
        # `tool`/`source` were removed in Gradio 4), which made the app fail
        # at startup; gr.File supports multi-file upload.
        input_images = gr.File(
            label="Upload Images",
            file_count="multiple",
            file_types=["image"],
        )
        output_text = gr.Textbox(label="Status")
        output_csv_link = gr.File(label="Download CSV")
    input_images.change(
        fn=_on_upload,
        inputs=input_images,
        outputs=[output_text, output_csv_link],
    )

demo.launch()