import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
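# Assumed runtime dependencies (not pinned in this file): gradio, torch,
# transformers with Qwen2.5-VL support, and qwen-vl-utils, which provides
# process_vision_info for packing image inputs.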

# Model path configuration - can be loaded from environment variable or default path
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")
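# For example (assumed usage): `export MODEL_PATH=/path/to/local/checkpoint`
# points the app at a local copy instead of the Hugging Face Hub repo.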

# Global variables to store model and processor
model = None
processor = None

def load_model():
    """Load model and processor"""
    global model, processor
    
    if model is None or processor is None:
        print(f"Loading model: {MODEL_PATH}")
        
        # Load model
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        
        # Load processor
        processor = AutoProcessor.from_pretrained(MODEL_PATH)
        
        print("Model loaded successfully!")
    
    return model, processor

def inference(image, question, max_new_tokens=1024, temperature=0.7):
    """Perform inference"""
    try:
        # Ensure model is loaded
        model, processor = load_model()
        
        # Validate multimodal inputs
        if image is None:
            return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and text input."
        
        if not question or question.strip() == "":
            return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and text input."
        
        # Build multimodal messages (image + text)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,  # Image input
                    },
                    {"type": "text", "text": question},  # Text input
                ],
            }
        ]
        
        # Prepare inputs
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        # Move inputs to the device where the model is located
        device = next(model.parameters()).device
        inputs = inputs.to(device)
        
        # Generate response
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True if temperature > 0 else False,
        )
        
        # Decode output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] 
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        
        return output_text[0]
        
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## Citation
        The following is a BibTeX reference:
        
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required): the image you want to ask questions about",
                height=400,
            )
            question_input = gr.Textbox(
                label="💬 Your Question (Required)",
                placeholder="e.g., Describe the content of this image",
                lines=3,
                info="Enter your question about the uploaded image"
            )
            
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=512,
                    step=64,
                    label="Max Generation Length"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            
            submit_btn = gr.Button("Submit", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")
        
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model Response",
                lines=15,
                interactive=False
            )
    
    # Examples
    gr.Examples(
        examples=[
            ["What is the name of the Garage?\n0. polo\n1. imam\n2. leke\n3. akd\nFirst output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags, and thenoutput what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END>tags, and then sunmmarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END>tags,provides the user with the answer briefly in<ANSWER> <ANSWER_END>.i.e., <TYPE> degradation type here <TYPE_END>\n<INFLUENCE> influence here<INFLUENCE_END>\n<REASONING> reasoning process here<REASONING_END>\n<CONCLUSION>summary here<CONCLUSION_END>\n<ANSWER>final answer<ANSWER_END>."],
        ],
        inputs=[question_input],
        label="Example Questions"
    )
    
    # Bind events
    submit_btn.click(
        fn=inference,
        inputs=[image_input, question_input, max_tokens, temperature],
        outputs=output
    )
    
    clear_btn.click(
        fn=lambda: (None, "", 512, 0.7, ""),
        outputs=[image_input, question_input, max_tokens, temperature, output]
    )
    
    # Show a status message when the page loads (the model itself loads lazily on the first request)
    demo.load(
        fn=lambda: "The model loads on the first request, so the first response may take a while.",
        outputs=output
    )

if __name__ == "__main__":
    # Hugging Face Spaces serves Gradio apps on 0.0.0.0:7860; the same settings work for local runs
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
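
# Local usage sketch (an assumption; the original Space relies on the Spaces runtime):
#   pip install gradio torch transformers qwen-vl-utils
#   python app.py
# Then open http://localhost:7860, upload an image, and ask a question.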