"""Gradio UI for DeepCritical agent with MCP server support."""

import os
from collections.abc import AsyncGenerator
from typing import Any

import gradio as gr
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
from src.orchestrator_factory import create_orchestrator
from src.tools.clinicaltrials import ClinicalTrialsTool
from src.tools.europepmc import EuropePMCTool
from src.tools.pubmed import PubMedTool
from src.tools.search_handler import SearchHandler
from src.utils.config import settings
from src.utils.models import AgentEvent, OrchestratorConfig


def configure_orchestrator(
    use_mock: bool = False,
    mode: str = "simple",
    oauth_token: str | None = None,
) -> tuple[Any, str]:
    """
    Create an orchestrator instance.

    Args:
        use_mock: If True, use MockJudgeHandler (no API key needed)
        mode: Orchestrator mode ("simple" or "advanced")
        oauth_token: Optional OAuth token from HuggingFace login

    Returns:
        Tuple of (Orchestrator instance, backend_name)
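
    Example:
        >>> # Illustrative sketch: the mock judge needs no API key or network.
        >>> orchestrator, backend = configure_orchestrator(use_mock=True)
        >>> backend
        'Mock (Testing)'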
    """
    # Create orchestrator config
    config = OrchestratorConfig(
        max_iterations=10,
        max_results_per_tool=10,
    )

    # Create search tools
    search_handler = SearchHandler(
        tools=[PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()],
        timeout=config.search_timeout,
    )

    # Create judge (mock, real, or free tier)
    judge_handler: JudgeHandler | MockJudgeHandler | HFInferenceJudgeHandler
    backend_info = "Unknown"

    # Priority: forced mock > OAuth token > env vars > free tier.
    # OAuth provides an HF token, so only HuggingFace backends apply here.
    effective_api_key = oauth_token

    # 1. Forced Mock (Unit Testing)
    if use_mock:
        judge_handler = MockJudgeHandler()
        backend_info = "Mock (Testing)"

    # 2. API Key (OAuth or Env) - HuggingFace only
    # (elif, so a forced mock is never overwritten when env keys are also set)
    elif effective_api_key or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY"):
        model: HuggingFaceModel | None = None
        if effective_api_key:
            model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
            hf_provider = HuggingFaceProvider(api_key=effective_api_key)
            model = HuggingFaceModel(model_name, provider=hf_provider)
            backend_info = "API (HuggingFace OAuth)"
        else:
            # model=None: JudgeHandler is expected to pick up env-configured credentials.
            backend_info = "API (Env Config)"

        judge_handler = JudgeHandler(model=model)

    # 3. Free Tier (HuggingFace Inference)
    else:
        judge_handler = HFInferenceJudgeHandler()
        backend_info = "Free Tier (Llama 3.1 / Mistral)"

    orchestrator = create_orchestrator(
        search_handler=search_handler,
        judge_handler=judge_handler,
        config=config,
        mode=mode,  # type: ignore
    )

    return orchestrator, backend_info


def event_to_chat_message(event: AgentEvent) -> gr.ChatMessage:
    """
    Convert AgentEvent to gr.ChatMessage with metadata for accordion display.

    Args:
        event: The AgentEvent to convert

    Returns:
        ChatMessage with metadata for collapsible accordion
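
    Example:
        >>> # Illustrative; assumes AgentEvent accepts these fields as keywords,
        >>> # matching the attributes read in the function body.
        >>> event = AgentEvent(type="searching", message="Querying PubMed...", iteration=1)
        >>> msg = event_to_chat_message(event)
        >>> msg.metadata["title"]
        'πŸ” Searching Literature'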
    """
    # Map event types to accordion titles and determine if pending
    event_configs: dict[str, dict[str, Any]] = {
        "started": {"title": "πŸš€ Starting Research", "status": "done", "icon": "πŸš€"},
        "searching": {"title": "πŸ” Searching Literature", "status": "pending", "icon": "πŸ”"},
        "search_complete": {"title": "πŸ“š Search Results", "status": "done", "icon": "πŸ“š"},
        "judging": {"title": "🧠 Evaluating Evidence", "status": "pending", "icon": "🧠"},
        "judge_complete": {"title": "βœ… Evidence Assessment", "status": "done", "icon": "βœ…"},
        "looping": {"title": "πŸ”„ Research Iteration", "status": "pending", "icon": "πŸ”„"},
        "synthesizing": {"title": "πŸ“ Synthesizing Report", "status": "pending", "icon": "πŸ“"},
        "hypothesizing": {"title": "πŸ”¬ Generating Hypothesis", "status": "pending", "icon": "πŸ”¬"},
        "analyzing": {"title": "πŸ“Š Statistical Analysis", "status": "pending", "icon": "πŸ“Š"},
        "analysis_complete": {"title": "πŸ“ˆ Analysis Results", "status": "done", "icon": "πŸ“ˆ"},
        "streaming": {"title": "πŸ“‘ Processing", "status": "pending", "icon": "πŸ“‘"},
        "complete": {"title": None, "status": "done", "icon": "πŸŽ‰"},  # Main response, no accordion
        "error": {"title": "❌ Error", "status": "done", "icon": "❌"},
    }

    config = event_configs.get(
        event.type, {"title": f"β€’ {event.type}", "status": "done", "icon": "β€’"}
    )

    # For complete events, return main response without accordion
    if event.type == "complete":
        return gr.ChatMessage(
            role="assistant",
            content=event.message,
        )

    # Build metadata for accordion
    metadata: dict[str, Any] = {}
    if config["title"]:
        metadata["title"] = config["title"]

    # Set status (pending shows spinner, done is collapsed)
    if config["status"] == "pending":
        metadata["status"] = "pending"

    # Add duration if available in data
    if event.data and isinstance(event.data, dict) and "duration" in event.data:
        metadata["duration"] = event.data["duration"]

    # Add log info (iteration number, etc.)
    log_parts: list[str] = []
    if event.iteration > 0:
        log_parts.append(f"Iteration {event.iteration}")
    if event.data and isinstance(event.data, dict):
        if "tool" in event.data:
            log_parts.append(f"Tool: {event.data['tool']}")
        if "results_count" in event.data:
            log_parts.append(f"Results: {event.data['results_count']}")
    if log_parts:
        metadata["log"] = " | ".join(log_parts)

    return gr.ChatMessage(
        role="assistant",
        content=event.message,
        metadata=metadata if metadata else None,
    )


def extract_oauth_info(request: gr.Request | None) -> tuple[str | None, str | None]:
    """
    Extract OAuth token and username from Gradio request.

    Args:
        request: Gradio request object containing OAuth information

    Returns:
        Tuple of (oauth_token, oauth_username)
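
    Example:
        >>> # Without a request (e.g. called outside Gradio) both values are None.
        >>> extract_oauth_info(None)
        (None, None)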
    """
    oauth_token: str | None = None
    oauth_username: str | None = None

    if request is None:
        return oauth_token, oauth_username

    # Try multiple ways to access OAuth token (Gradio API may vary)
    # Pattern 1: request.oauth_token.token
    if hasattr(request, "oauth_token") and request.oauth_token is not None:
        if hasattr(request.oauth_token, "token"):
            oauth_token = request.oauth_token.token
        elif isinstance(request.oauth_token, str):
            oauth_token = request.oauth_token
    # Pattern 2: request.headers (fallback)
    elif hasattr(request, "headers"):
        # OAuth token might be in headers
        auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # removeprefix strips only the leading "Bearer ", not every occurrence
            oauth_token = auth_header.removeprefix("Bearer ")

    # Access username from request
    if hasattr(request, "username") and request.username:
        oauth_username = request.username
    # Also try accessing via oauth_profile if available
    elif hasattr(request, "oauth_profile") and request.oauth_profile is not None:
        if hasattr(request.oauth_profile, "username"):
            oauth_username = request.oauth_profile.username
        elif hasattr(request.oauth_profile, "name"):
            oauth_username = request.oauth_profile.name

    return oauth_token, oauth_username


async def yield_auth_messages(
    oauth_username: str | None,
    oauth_token: str | None,
    has_huggingface: bool,
    mode: str,
) -> AsyncGenerator[gr.ChatMessage, None]:
    """
    Yield authentication and mode status messages.

    Args:
        oauth_username: OAuth username if available
        oauth_token: OAuth token if available
        has_huggingface: Whether HuggingFace credentials are available
        mode: Orchestrator mode

    Yields:
        ChatMessage objects with authentication status
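
    Example (illustrative; must run inside an event loop):
        async for msg in yield_auth_messages("alice", None, False, "simple"):
            print(msg.content)  # greeting, then the free-tier notice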
    """
    # Show user greeting if logged in via OAuth
    if oauth_username:
        yield gr.ChatMessage(
            role="assistant",
            content=f"πŸ‘‹ **Welcome, {oauth_username}!** Using your HuggingFace account.\n\n",
        )

    # Advanced mode is not supported without OpenAI (which requires manual setup)
    # For now, we only support simple mode with HuggingFace
    if mode == "advanced":
        yield gr.ChatMessage(
            role="assistant",
            content=(
                "⚠️ **Warning**: Advanced mode requires OpenAI API key configuration. "
                "Falling back to simple mode.\n\n"
            ),
        )

    # Inform user about authentication status
    if oauth_token:
        yield gr.ChatMessage(
            role="assistant",
            content=(
                "πŸ” **Using HuggingFace OAuth token** - "
                "Authenticated via your HuggingFace account.\n\n"
            ),
        )
    elif not has_huggingface:
        # No keys at all - will use FREE HuggingFace Inference (public models)
        yield gr.ChatMessage(
            role="assistant",
            content=(
                "πŸ€— **Free Tier**: Using HuggingFace Inference (Llama 3.1 / Mistral) for AI analysis.\n"
                "For premium models or higher rate limits, sign in with HuggingFace above.\n\n"
            ),
        )


async def handle_orchestrator_events(
    orchestrator: Any,
    message: str,
) -> AsyncGenerator[gr.ChatMessage, None]:
    """
    Handle orchestrator events and yield ChatMessages.

    Args:
        orchestrator: The orchestrator instance
        message: The research question

    Yields:
        ChatMessage objects from orchestrator events
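
    Example (illustrative; must run inside an event loop, and ``render`` is a
    hypothetical sink for messages):
        async for msg in handle_orchestrator_events(orchestrator, "Is metformin safe?"):
            render(msg)  # pending accordions are re-yielded as their content grows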
    """
    # Track pending accordions for real-time updates
    pending_accordions: dict[str, str] = {}  # title -> accumulated content

    async for event in orchestrator.run(message):
        # Convert event to ChatMessage with metadata
        chat_msg = event_to_chat_message(event)

        # Handle complete events (main response)
        if event.type == "complete":
            # Close any pending accordions first
            if pending_accordions:
                for title, content in pending_accordions.items():
                    yield gr.ChatMessage(
                        role="assistant",
                        content=content.strip(),
                        metadata={"title": title, "status": "done"},
                    )
                pending_accordions.clear()

            # Yield final response (no accordion for main response)
            yield chat_msg
            continue

        # Handle events with metadata (accordions)
        if chat_msg.metadata:
            title = chat_msg.metadata.get("title")
            status = chat_msg.metadata.get("status")

            if title:
                # For pending operations, accumulate content and show spinner
                if status == "pending":
                    if title not in pending_accordions:
                        pending_accordions[title] = ""
                    pending_accordions[title] += chat_msg.content + "\n"
                    # Yield updated accordion with accumulated content
                    yield gr.ChatMessage(
                        role="assistant",
                        content=pending_accordions[title].strip(),
                        metadata=chat_msg.metadata,
                    )
                elif title in pending_accordions:
                    # Combine pending content with final content
                    final_content = pending_accordions[title] + chat_msg.content
                    del pending_accordions[title]
                    yield gr.ChatMessage(
                        role="assistant",
                        content=final_content.strip(),
                        metadata={"title": title, "status": "done"},
                    )
                else:
                    # New done accordion (no pending state)
                    yield chat_msg
            else:
                # No title, yield as-is
                yield chat_msg
        else:
            # No metadata, yield as plain message
            yield chat_msg


async def research_agent(
    message: str,
    history: list[dict[str, Any]],
    mode: str = "simple",
    request: gr.Request | None = None,
) -> AsyncGenerator[gr.ChatMessage | list[gr.ChatMessage], None]:
    """
    Gradio chat function that runs the research agent.

    Args:
        message: User's research question
        history: Chat history (Gradio format)
        mode: Orchestrator mode ("simple" or "advanced")
        request: Gradio request object containing OAuth information

    Yields:
        ChatMessage objects with metadata for accordion display
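
    Example (illustrative; Gradio normally invokes this itself):
        async for msg in research_agent("What treats Long COVID?", [], "simple", None):
            print(msg.content)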
    """
    if not message.strip():
        yield gr.ChatMessage(
            role="assistant",
            content="Please enter a research question.",
        )
        return

    # Extract OAuth token from request if available
    oauth_token, oauth_username = extract_oauth_info(request)

    # Check available keys
    has_huggingface = bool(os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") or oauth_token)

    # Adjust mode if needed
    effective_mode = mode
    if mode == "advanced":
        effective_mode = "simple"

    # Yield authentication and mode status messages
    async for msg in yield_auth_messages(oauth_username, oauth_token, has_huggingface, mode):
        yield msg

    # Run the agent and stream events
    try:
        # use_mock=False - let configure_orchestrator decide based on available keys
        # It will use: OAuth token > Env vars > HF Inference (free tier)
        orchestrator, backend_name = configure_orchestrator(
            use_mock=False,  # Never use mock in production - HF Inference is the free fallback
            mode=effective_mode,
            oauth_token=oauth_token,
        )

        yield gr.ChatMessage(
            role="assistant",
            content=f"🧠 **Backend**: {backend_name}\n\n",
        )

        # Handle orchestrator events
        async for msg in handle_orchestrator_events(orchestrator, message):
            yield msg

    except Exception as e:
        yield gr.ChatMessage(
            role="assistant",
            content=f"❌ **Error**: {e!s}",
            metadata={"title": "❌ Error", "status": "done"},
        )


def create_demo() -> gr.Blocks:
    """
    Create the Gradio demo interface with MCP support and OAuth login.

    Returns:
        Configured Gradio Blocks interface with MCP server and OAuth enabled
    """
    with gr.Blocks(title="🧬 DeepCritical") as demo:
        # Add login button at the top
        with gr.Row():
            gr.LoginButton()

        # Chat interface
        gr.ChatInterface(
            fn=research_agent,
            title="🧬 DeepCritical",
            description=(
                "*AI-Powered Drug Repurposing Agent β€” searches PubMed, "
                "ClinicalTrials.gov & Europe PMC*\n\n"
                "---\n"
                "*Research tool only β€” not for medical advice.*  \n"
                "**MCP Server Active**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n"
                "**Sign in with HuggingFace** above to use your account's API token automatically."
            ),
            examples=[
                ["What drugs could be repurposed for Alzheimer's disease?", "simple"],
                ["Is metformin effective for treating cancer?", "simple"],
                ["What medications show promise for Long COVID treatment?", "simple"],
            ],
            additional_inputs_accordion=gr.Accordion(label="βš™οΈ Settings", open=False),
            additional_inputs=[
                gr.Radio(
                    choices=["simple", "advanced"],
                    value="simple",
                    label="Orchestrator Mode",
                    info=(
                        "Simple: Linear (Free Tier Friendly) | Advanced: Multi-Agent (Requires OpenAI - not available without manual config)"
                    ),
                ),
            ],
        )

    return demo


def main() -> None:
    """Run the Gradio app with MCP server enabled."""
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        mcp_server=True,
        ssr_mode=False,  # Fix for intermittent loading/hydration issues in HF Spaces
    )


if __name__ == "__main__":
    main()
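
# Usage note (illustrative): assuming this module is saved as app.py, running
# `python app.py` serves the UI at http://localhost:7860; with mcp_server=True,
# Gradio also exposes the MCP endpoint at /gradio_api/mcp/ for clients such as
# Claude Desktop.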