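"""Gradio demo for Robust-R1: answers questions about an uploaded image using a Qwen2.5-VL-based model."""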
import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
# Model path configuration - can be loaded from environment variable or default path
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")
# Global variables to store model and processor
model = None
processor = None
def load_model():
    """Load model and processor (lazily, on first call)"""
    global model, processor
    if model is None or processor is None:
        print(f"Loading model: {MODEL_PATH}")
        # Load model
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Load processor
        processor = AutoProcessor.from_pretrained(MODEL_PATH)
        print("Model loaded successfully!")
    return model, processor

def inference(image, question, max_new_tokens=1024, temperature=0.7):
    """Perform one round of multimodal inference"""
    try:
        # Ensure model is loaded
        model, processor = load_model()

        # Validate multimodal inputs
        if image is None:
            return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and text input."
        if not question or question.strip() == "":
            return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and text input."

        # Build multimodal messages (image + text)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,  # Image input (PIL image from the UI)
                    },
                    {"type": "text", "text": question},  # Text input
                ],
            }
        ]

        # Prepare inputs: render the chat template, then extract the vision inputs
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the device where the model is located
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Generate response
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

        # Decode output: generate() returns prompt + completion,
        # so slice off the prompt tokens before decoding
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## Citation
        The following is a BibTeX reference:
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Note: gr.Image does not accept an `info` kwarg (unlike gr.Textbox)
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required)",
                height=400,
            )
            question_input = gr.Textbox(
                label="💬 Your Question (Required)",
                placeholder="e.g., Describe the content of this image",
                lines=3,
                info="Enter your question about the uploaded image",
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=512,
                    step=64,
                    label="Max Generation Length",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
            submit_btn = gr.Button("Submit", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model Response",
                lines=15,
                interactive=False,
            )
    # Examples
    gr.Examples(
        examples=[
            ["What is the name of the Garage?\n0. polo\n1. imam\n2. leke\n3. akd\nFirst output the types of degradation in the image briefly in <TYPE> <TYPE_END> tags, then output what effects these degradations have on the image in <INFLUENCE> <INFLUENCE_END> tags, then, based on the strength of the degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags, then summarize the content of the reasoning and give the answer in <CONCLUSION> <CONCLUSION_END> tags, and provide the user with the answer briefly in <ANSWER> <ANSWER_END> tags, i.e., <TYPE> degradation type here <TYPE_END>\n<INFLUENCE> influence here <INFLUENCE_END>\n<REASONING> reasoning process here <REASONING_END>\n<CONCLUSION> summary here <CONCLUSION_END>\n<ANSWER> final answer <ANSWER_END>."],
        ],
        inputs=[question_input],
        label="Example Questions",
    )
    # Bind events
    submit_btn.click(
        fn=inference,
        inputs=[image_input, question_input, max_tokens, temperature],
        outputs=output,
    )
    clear_btn.click(
        fn=lambda: (None, "", 512, 0.7, ""),
        outputs=[image_input, question_input, max_tokens, temperature, output],
    )

    # Show message when page loads
    demo.load(
        fn=lambda: "Model is loading, please wait...",
        outputs=output,
    )
if __name__ == "__main__":
    # When running in a Space, Gradio handles the port automatically
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
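
# Local-run sketch (assumptions: this file is saved as app.py and dependencies
# are installed, e.g. `pip install gradio torch transformers accelerate qwen-vl-utils`):
#   python app.py   # then open http://localhost:7860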