Spaces: Running on Zero
import os
import subprocess
import sys

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
# Install flash-attn at startup; the model below is loaded with
# attn_implementation="flash_attention_2".
# - Invoke pip through the running interpreter (`sys.executable -m pip`) so the
#   package lands in the same environment that imports it, and pass an argument
#   list instead of a shell string.
# - Extend os.environ rather than replacing it: a bare `env={...}` would strip
#   PATH, HOME, CUDA_HOME, etc. from the child process.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is read by flash-attn's setup script;
# presumably it skips compiling the CUDA extension locally — confirm against
# the flash-attn README.
# check=False keeps this best-effort: a failed install is not fatal here and
# would presumably surface when the model is loaded with flash_attention_2.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Hugging Face Hub id of the OCR checkpoint; also used to load the processor.
model_id = "yifeihu/TB-OCR-preview-0.1"
# NOTE(review): DEVICE is not referenced anywhere in this file; model loading and
# inference target "cuda" explicitly. Kept for compatibility.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint 4-bit quantized on the GPU. The quantization is expressed
# through an explicit BitsAndBytesConfig: passing `load_in_4bit=True` directly
# to `from_pretrained` is deprecated in transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,  # the checkpoint ships custom modeling code
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # requires the flash-attn install above
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# Processor paired with the checkpoint. It provides `.tokenizer` (chat template,
# EOS id), preprocesses prompt + images, and batch-decodes generated ids.
# num_crops is forwarded to the checkpoint's remote-code processor; presumably
# it sets how many crops each image is split into — confirm on the model card.
processor = AutoProcessor.from_pretrained(
    model_id,
    num_crops=16,
    trust_remote_code=True,
)
# The Space runs on ZeroGPU hardware: a GPU is attached only while a function
# decorated with @spaces.GPU is executing. Outside ZeroGPU the decorator has no
# effect.
@spaces.GPU
def phi_ocr(image):
    """Transcribe the text in *image* to Markdown with the TB-OCR model.

    Args:
        image: The image to read. The Gradio UI supplies a ``PIL.Image``
            (``gr.Image(type="pil")``).

    Returns:
        The decoded text of the newly generated tokens (special tokens
        skipped), truncated at the first ``"<image_end>"`` marker if present.
    """
    question = "Convert the text to markdown format."
    # `<|image_1|>` presumably refers to the first (only) image passed to the
    # processor below.
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda")

    # Greedy decoding. No `temperature` is passed: it has no effect when
    # do_sample=False and only makes transformers emit a warning.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Keep only the text before the first "<image_end>" marker, if any.
    return response.split("<image_end>")[0]
def process_image(input_image):
    """Gradio handler: OCR the current image and return the Markdown text.

    The image component's ``.change`` event also fires when the image is
    cleared, in which case *input_image* is ``None``. Return an empty string
    then instead of passing ``None`` to the model; this also avoids requesting
    a GPU for a clear event.
    """
    if input_image is None:
        return ""
    return phi_ocr(input_image)
with gr.Blocks() as demo:
    # Page header: title, usage hint, and a link to the model card.
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload an image to extract and convert text to markdown format.")
    gr.Markdown("[Check out the model here](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")

    # Image input and OCR text output laid out side by side in one row.
    with gr.Row():
        input_image = gr.Image(type="pil")
        output_text = gr.Textbox()

    # Re-run OCR every time the image value changes; the result fills the textbox.
    input_image.change(process_image, input_image, output_text)

demo.launch()