import json
import time

from rich.console import Console
from rich.json import JSON
from rich.markdown import Markdown
from rich.panel import Panel
from rich.pretty import Pretty

from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from openai import RateLimitError
def print_conversation(messages):
    """Render a list of {'role', 'content'} messages as color-coded rich panels."""
    console = Console(width=200, soft_wrap=True)
    for msg in messages:
        role = msg.get("role", "unknown").capitalize()
        content = msg.get("content", "")
        try:
            if isinstance(content, str):
                content = json.loads(content)
            elif isinstance(content, dict) and 'output' in content:
                # Executor responses may wrap a HumanMessage under 'output'
                if isinstance(content['output'], HumanMessage):
                    content['output'] = content['output'].content
            elif isinstance(content, HumanMessage):
                content = content.content
            rendered_content = JSON.from_data(content)
        except (json.JSONDecodeError, TypeError):
            try:
                # Plain text: render it as Markdown
                rendered_content = Markdown(content.strip())
            except AttributeError:
                # Multimodal content (e.g. from Gemini): extract the text part
                try:
                    rendered_content = JSON.from_data({
                        'query': content.get('query', 'QueryKeyNotFound').content[0]['text'],
                        'output': content.get('output', 'OutputKeyNotFound'),
                    })
                except Exception as e:
                    print(f"Failed to render content for role: {role}. Content: {content}")
                    print("Error:", e)
                    # Fall back to a pretty repr so rendered_content is never unbound
                    rendered_content = Pretty(content)
        border_style_color = "red"
        if "Assistant" in role:
            border_style_color = "magenta"
        elif "User" in role:
            border_style_color = "green"
        elif "System" in role:
            border_style_color = "blue"
        elif "Tool" in role:
            border_style_color = "yellow"
        elif "Token" in role:
            border_style_color = "white"
        panel = Panel(
            rendered_content,
            title=f"[bold blue]{role}[/]",
            border_style=border_style_color,
            expand=True
        )
        console.print(panel)
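
# A minimal, illustrative use of print_conversation; the messages below are
# made up, and any dicts with 'role' and 'content' keys will render. A plain
# string renders as Markdown, a JSON string as a syntax-highlighted tree.
def _demo_print_conversation():
    print_conversation([
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'What is the capital of France?'},
        {'role': 'assistant', 'content': '{"output": "Paris"}'},
    ])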
def generate_final_answer(qa: dict) -> str:
    """Invokes gpt-4o-mini to extract a final answer from the query, response, and metadata."""
    final_answer_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_retries=5)
    system_prompt = (
        "You will be given a JSON object containing a user's query, a response from an AI assistant, and optional metadata. "
        "Your task is to extract and return a final answer to the query as a plain string, strictly suitable for exact match evaluation. "
        "Do NOT answer the query yourself. Use the response as the source of truth. "
        "Use the query only as context to interpret the response and extract a final, normalized answer. "
        "Your output must be:\n"
        "- A **single plain string** with **no prefixes, labels, or explanations**.\n"
        "- Suitable for exact string comparison.\n"
        "- Clean and deterministic: no variation in formatting, casing, or punctuation.\n"
        "Special rules:\n"
        "- If the response shows inability to process attached media (images, audio, video), return: **'File not found'**.\n"
        "- If the response is a list of search results, aggregate the information before constructing an answer.\n"
        "- If the query is quantitative (How many...?), **aggregate the results of the tool call(s) and return the numeric answer** only.\n"
        "- If the query is unanswerable from the response, return: **'No answer found: <brief reason>'**.\n"
        "Examples:\n"
        "- Query: 'What’s in the attached image?'\n"
        "  Response: 'I'm unable to view images directly...'\n"
        "  Output: 'File not found'\n\n"
        "- Query: 'What’s the total population of X'\n"
        "  Response: '{title: demographics of X, content: 1. City A: 2M, 2. City B: 3M, title: history of X, content: currently there are Y number of inhabitants in X...'\n"
        "  Output: '5000000'\n"
        "Strictly follow these rules. Some final answers will require more analysis of the provided response. "
        "You can reason to get to the answer but always treat the response as the base_knowledge (keep coherence). "
        "Return only the final string answer. Do not include any other content."
    )
    system_message = SystemMessage(content=system_prompt)
    # The executor echoes the query back; a HumanMessage is not JSON-serializable,
    # so in that case keep only the serializable 'output' field.
    if isinstance(qa['response'], dict) and isinstance(qa['response'].get('query'), HumanMessage):
        qa['response'] = qa['response']['output']
    messages = [
        system_message,
        HumanMessage(content=f'Generate the final answer for the following query:\n\n{json.dumps(qa)}')
    ]
    response = final_answer_llm.invoke(messages)
    return response.content
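
# Illustrative sketch of calling generate_final_answer; the query/response
# pair is fabricated, and 'response' mirrors the dict an AgentExecutor
# returns. Actually calling this requires OPENAI_API_KEY in the environment.
def _demo_generate_final_answer() -> str:
    return generate_final_answer({
        'query': 'How many continents are there?',
        'response': {'query': 'How many continents are there?',
                     'output': 'There are seven continents on Earth.'},
    })  # expected to normalize to the bare string '7'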
class ToolAgent:
    """Basic custom agent class prompted for the tool-use pattern."""
    def __init__(self, tools: list, model='gpt-4o', backstory: str = "", streaming=False):
        self.name = "GAIA Tool-Use Agent"
        self.tools = tools
        self.llm = ChatOpenAI(model=model, temperature=0, streaming=streaming, max_retries=5)
        self.executor = None
        self.backstory = backstory if backstory else "You are a helpful assistant that can use tools to answer questions. Your name is Gaia."

    def initialize(self):
        """Builds the prompt, creates the tool-calling agent, and compiles the executor."""
        # create_tool_calling_agent binds the tools to the LLM internally,
        # so no separate bind_tools call is needed here.
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.backstory),
                MessagesPlaceholder(variable_name="history", optional=True),
                ("human", "{query}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
        agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
        self.executor = AgentExecutor(
            agent=agent,
            tools=self.tools,
            early_stopping_method='force',
            max_iterations=10
        )
    def chat(self, query: str, metadata):
        """Perform a single step in the conversation with the tool agent executor."""
        if metadata is None:
            metadata = {}
        # Attachment paths are forwarded as extra text parts so tools can
        # locate the files; image and file paths may coexist on one query.
        content_parts = [{"type": "text", "text": query}]
        if "image_path" in metadata:
            content_parts.append({"type": "text", "text": f"image_path: {metadata['image_path']}"})
        if "file_path" in metadata:
            content_parts.append({"type": "text", "text": f"file_path: {metadata['file_path']}"})
        with_attachments = len(content_parts) > 1
        query_message = HumanMessage(content=content_parts)
        user_message = {'role': 'user', 'content': query if not with_attachments else query_message}
        print_conversation([user_message])
        response = self.executor.invoke({
            "query": query if not with_attachments else query_message,
        })
        response_message = {'role': 'assistant', 'content': response}
        print_conversation([response_message])
        final_answer = generate_final_answer({
            'query': query,
            'response': response,
        })
        final_answer_message = {'role': 'Final Answer', 'content': final_answer}
        print_conversation([final_answer_message])
        return final_answer
    def invoke(self, q_data):
        """Invoke the executor on the input data, retrying once on rate limits."""
        query = q_data.get("query", "")
        metadata = q_data.get("metadata", None)
        try:
            response = self.chat(query, metadata)
            time.sleep(3)  # brief pause to stay under rate limits
        except RateLimitError:
            response = 'Rate limit error encountered. Retrying after a short pause...'
            error_message = {'role': 'Rate-limit-hit', 'content': response}
            print_conversation([error_message])
            time.sleep(5)
            try:
                response = self.chat(query, metadata)
            except RateLimitError:
                response = 'Rate limit error encountered again. Skipping this query.'
                error_message = {'role': 'Rate-limit-hit', 'content': response}
                print_conversation([error_message])
        print()
        return response

    def __call__(self, q_data):
        """Call the invoke method from the agent executor."""
        return self.invoke(q_data)