| """ | |
| Autonomous AI Agent with MCP Tool Calling using Granite 4.0 H-1B (Open Source) | |
| This agent uses IBM Granite 4.0 H-1B (1.5B params) loaded locally via transformers | |
| to autonomously decide which MCP tools to call. | |
| Granite 4.0 H-1B is optimized for tool calling and function calling tasks. | |
| Uses ReAct (Reasoning + Acting) prompting pattern for reliable tool calling. | |
| """ | |
| import os | |
| import re | |
| import json | |
| import uuid | |
| import logging | |
| import asyncio | |
| from typing import List, Dict, Any, AsyncGenerator | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| from mcp.tools.definitions import MCP_TOOLS, list_all_tools | |
| from mcp.registry import MCPRegistry | |
| logger = logging.getLogger(__name__) | |


class AutonomousMCPAgentGranite:
    """
    AI Agent that autonomously uses MCP servers as tools using Granite 4.

    Uses the ReAct (Reasoning + Acting) pattern:
    1. Thought: the AI reasons about what to do next
    2. Action: the AI decides which tool to call
    3. Observation: the AI sees the tool result
    4. Repeat until the task is complete
    """

    def __init__(self, mcp_registry: MCPRegistry, hf_token: str = None):
        """
        Initialize the autonomous agent with Granite 4.0 H-1B

        Args:
            mcp_registry: MCP registry with all servers
            hf_token: HuggingFace token (optional, for accessing private models)
        """
        self.mcp_registry = mcp_registry
        self.hf_token = hf_token or os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")

        # Use Granite 4.0 H-1B (1.5B params, optimized for tool calling)
        self.model_name = "ibm-granite/granite-4.0-h-1b"
        logger.info("Loading Granite 4.0 H-1B model locally...")

        # Load model with optimizations for CPU/limited memory
        try:
            logger.info(f"📥 Downloading tokenizer from {self.model_name}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token,
                trust_remote_code=True
            )
            logger.info("✓ Tokenizer loaded successfully")

            # Check device availability (bfloat16 on GPU for efficiency, float32 fallback for CPU)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            logger.info(f"💻 Device: {device}, dtype: {dtype}")

            logger.info("📥 Downloading model weights (~1.5GB)...")
            # For hybrid models like Granite H-1B, we need explicit device placement
            if torch.cuda.is_available():
                # GPU available - use device_map
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    token=self.hf_token,
                    torch_dtype=dtype,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                    trust_remote_code=True
                )
            else:
                # CPU only - load with 8-bit quantization to reduce memory
                logger.info("⚠️ Loading on CPU (no GPU available)")
                logger.info("💾 Using 8-bit quantization to reduce memory usage")
                try:
                    # Try loading with 8-bit quantization (requires bitsandbytes)
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_8bit=True,
                        llm_int8_threshold=6.0
                    )
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        token=self.hf_token,
                        quantization_config=quantization_config,
                        low_cpu_mem_usage=False,
                        trust_remote_code=True
                    )
                    logger.info("✓ Loaded with 8-bit quantization (~50% memory reduction)")
                except Exception as e:
                    # Fallback to float32 if bitsandbytes is missing or 8-bit loading fails
                    logger.warning(f"⚠️ 8-bit quantization failed: {e}")
                    logger.info("⚠️ Falling back to float32 (may use ~4-6GB RAM)")
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        token=self.hf_token,
                        torch_dtype=torch.float32,   # Use float32 for CPU
                        low_cpu_mem_usage=False,     # Disable to avoid meta device
                        trust_remote_code=True
                    )
                # Verify all parameters are on CPU, not meta
                logger.info("🔍 Verifying model is materialized on CPU...")
                param_devices = set()
                for param in self.model.parameters():
                    param_devices.add(str(param.device))
                if 'meta' in param_devices:
                    logger.error("❌ Model still has parameters on meta device!")
                    raise RuntimeError("Model not properly materialized. Try upgrading transformers: pip install --upgrade transformers")
                logger.info(f"✓ All parameters on: {param_devices}")

            logger.info("✓ Model weights loaded")

            # Set model to eval mode
            self.model.eval()
            logger.info("✓ Model set to evaluation mode")

            # Get model device and memory info
            try:
                model_device = next(self.model.parameters()).device
                logger.info(f"✓ Model loaded successfully on device: {model_device}")
            except StopIteration:
                logger.warning("⚠️ Could not determine model device (no parameters)")

            # Memory info if available
            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / 1024**3
                logger.info(f"📊 GPU Memory allocated: {memory_allocated:.2f} GB")

        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}", exc_info=True)
            raise

        # Create tool descriptions for the AI
        self.tools_description = self._create_tools_description()
        logger.info(f"Autonomous MCP Agent initialized with model: {self.model_name}")

    def _generate_text(self, prompt: str) -> str:
        """
        Generate text using the local Granite model (synchronous, for use in executor)

        Args:
            prompt: The input prompt

        Returns:
            Generated text
        """
        import time
        import gc

        start_time = time.time()

        # Force garbage collection before inference to free memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Tokenize input with aggressive truncation to save memory
        logger.info(f"🔤 Tokenizing input (length: {len(prompt)} chars)...")
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048  # Reduced from 4096 to save memory
        )
        num_input_tokens = inputs["input_ids"].shape[-1]
        logger.info(f"✓ Tokenized to {num_input_tokens} tokens")

        # Get target device - handle models split across devices
        try:
            target_device = next(self.model.parameters()).device
        except StopIteration:
            # Fallback if no parameters found
            target_device = torch.device('cpu')
        logger.info(f"📍 Moving inputs to device: {target_device}")

        # Move inputs to the same device as the model
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        # Generate with memory-efficient settings
        logger.info("🤖 Generating response (max 400 tokens, temp=0.1)...")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=400,      # Reduced from 800 to save memory
                temperature=0.1,         # Low temperature for near-deterministic reasoning
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True,          # Use KV cache for efficiency
                num_beams=1,             # Single beam (no beam search) to save memory
            )

        # Decode only the new tokens
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )

        elapsed = time.time() - start_time
        num_output_tokens = outputs.shape[-1] - num_input_tokens
        tokens_per_sec = num_output_tokens / elapsed if elapsed > 0 else 0
        logger.info(f"✓ Generated {num_output_tokens} tokens in {elapsed:.1f}s ({tokens_per_sec:.1f} tokens/sec)")
        logger.info(f"📝 Response preview: {response[:100]}...")

        # Clean up to free memory
        del inputs, outputs
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response

    def _create_tools_description(self) -> str:
        """Create a formatted description of all available tools for the AI"""
        tools_text = "## Available MCP Tools:\n\n"

        for tool in MCP_TOOLS:
            tools_text += f"**{tool['name']}**\n"
            tools_text += f"  Description: {tool['description']}\n"
            tools_text += "  Parameters:\n"
            for prop_name, prop_data in tool['input_schema']['properties'].items():
                required = prop_name in tool['input_schema'].get('required', [])
                tools_text += f"    - {prop_name} ({prop_data['type']}){'*' if required else ''}: {prop_data.get('description', '')}\n"
            tools_text += "\n"

        return tools_text

    def _create_system_prompt(self) -> str:
        """Create the system prompt for the ReAct pattern"""
        return f"""You are an autonomous AI agent for B2B sales automation using the ReAct (Reasoning + Acting) framework.

You have access to MCP (Model Context Protocol) tools that let you:
- Search the web for company information and news
- Save prospects, companies, contacts, and facts to a database
- Send emails and manage email threads
- Schedule meetings and generate calendar invites

{self.tools_description}

## ReAct Format:
You must respond using this EXACT format:

Thought: [Your reasoning about what to do next]
Action: [tool_name]
Action Input: {{"param1": "value1", "param2": "value2"}}

After you see the Observation, you can continue with more Thought/Action/Observation cycles.

When you've completed the task, respond with:

Thought: [Your final reasoning]
Final Answer: [Your complete response to the user]

## Important Rules:
1. Always use "Thought:" to reason before acting
2. Always use "Action:" followed by the exact tool name
3. Always use "Action Input:" with valid JSON
4. Use tools multiple times if needed
5. Save important data to the database
6. When done, give a "Final Answer:"

## Example:
Thought: I need to research Shopify first
Action: search_web
Action Input: {{"query": "Shopify company information"}}

[You'll see Observation with results]

Thought: Now I should save the company data
Action: save_company
Action Input: {{"company_id": "shopify", "name": "Shopify", "domain": "shopify.com"}}

[Continue until task complete...]

Thought: I've gathered all the information and saved it
Final Answer: I've successfully researched Shopify and created a prospect profile with company information and recent facts.

Now complete your assigned task!"""

    async def run(
        self,
        task: str,
        max_iterations: int = 15
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent autonomously on a task using the ReAct pattern.

        Args:
            task: The task to complete
            max_iterations: Maximum tool calls to prevent infinite loops

        Yields:
            Events showing the agent's progress and tool calls
        """
        yield {
            "type": "agent_start",
            "message": "🤖 Autonomous AI Agent (Granite 4) starting task",
            "task": task,
            "model": self.model_name  # report the model id, not the loaded model object
        }

        # Initialize conversation with system prompt and task
        conversation_history = f"""{self._create_system_prompt()}

## Task:
{task}

Begin!
"""

        iteration = 0
        while iteration < max_iterations:
            iteration += 1

            yield {
                "type": "iteration_start",
                "iteration": iteration,
                "message": f"🔄 Iteration {iteration}: AI reasoning..."
            }

            try:
                # Get AI response using ReAct pattern
                response_text = ""
                try:
                    # Generate using the local model.
                    # Run in an executor to avoid blocking the event loop.
                    response_text = await asyncio.get_event_loop().run_in_executor(
                        None,
                        self._generate_text,
                        conversation_history
                    )
                except Exception as gen_error:
                    logger.error(f"Text generation failed: {gen_error}", exc_info=True)
                    yield {
                        "type": "agent_error",
                        "error": str(gen_error),
                        "message": f"❌ Model error: {str(gen_error)}"
                    }
                    break

                # Check if we got a response
                if not response_text or not response_text.strip():
                    logger.warning("Empty response from model")
                    yield {
                        "type": "parse_error",
                        "message": "⚠️ Model returned empty response. Retrying...",
                        "response": ""
                    }
                    continue

                # Log the raw response for debugging
                logger.info(f"Model response (iteration {iteration}): {response_text[:200]}...")

                # Parse the response for Thought, Action, Action Input
                thought_match = re.search(r'Thought:\s*(.+?)(?=\n(?:Action:|Final Answer:)|$)', response_text, re.DOTALL)
                action_match = re.search(r'Action:\s*(\w+)', response_text)
                action_input_match = re.search(r'Action Input:\s*(\{.+?\})', response_text, re.DOTALL)
                final_answer_match = re.search(r'Final Answer:\s*(.+?)$', response_text, re.DOTALL)

                # Extract thought
                if thought_match:
                    thought = thought_match.group(1).strip()
                    yield {
                        "type": "thought",
                        "thought": thought,
                        "message": f"💭 Thought: {thought}"
                    }

                # Check if the AI wants to finish
                if final_answer_match:
                    final_answer = final_answer_match.group(1).strip()
                    yield {
                        "type": "agent_complete",
                        "message": "✅ Task complete!",
                        "final_answer": final_answer,
                        "iterations": iteration
                    }
                    break
                # Execute action if present
                if action_match and action_input_match:
                    tool_name = action_match.group(1).strip()
                    action_input_str = action_input_match.group(1).strip()

                    # Parse action input JSON
                    try:
                        tool_input = json.loads(action_input_str)
                    except json.JSONDecodeError as e:
                        error_msg = f"Invalid JSON in Action Input: {e}"
                        logger.error(error_msg)
                        # Give feedback to the AI
                        conversation_history += response_text
                        conversation_history += f"\nObservation: Error - {error_msg}. Please provide valid JSON.\n\n"
                        continue

                    yield {
                        "type": "tool_call",
                        "tool": tool_name,
                        "input": tool_input,
                        "message": f"🔧 Action: {tool_name}"
                    }

                    # Execute the MCP tool
                    try:
                        result = await self._execute_mcp_tool(tool_name, tool_input)
                        yield {
                            "type": "tool_result",
                            "tool": tool_name,
                            "result": result,
                            "message": f"✓ Tool {tool_name} completed"
                        }
                        # Add to conversation history
                        conversation_history += response_text
                        conversation_history += f"\nObservation: {json.dumps(result, default=str)}\n\n"
                    except Exception as e:
                        error_msg = str(e)
                        logger.error(f"Tool execution failed: {tool_name} - {error_msg}")
                        yield {
                            "type": "tool_error",
                            "tool": tool_name,
                            "error": error_msg,
                            "message": f"❌ Tool {tool_name} failed: {error_msg}"
                        }
                        # Give error feedback to the AI
                        conversation_history += response_text
                        conversation_history += f"\nObservation: Error - {error_msg}\n\n"
                else:
                    # No action found - the AI might be confused
                    yield {
                        "type": "parse_error",
                        "message": "⚠️ Could not parse Action from AI response",
                        "response": response_text
                    }
                    # Give feedback to the AI
                    conversation_history += response_text
                    conversation_history += "\nObservation: Please follow the format: 'Action: tool_name' and 'Action Input: {...}'\n\n"
            except (RuntimeError, StopIteration, StopAsyncIteration) as stop_err:
                # Handle StopIteration errors that get wrapped in RuntimeError
                error_msg = str(stop_err)
                logger.error(f"Stop iteration in agent loop: {error_msg}", exc_info=True)
                if "StopIteration" in error_msg or "StopAsyncIteration" in error_msg:
                    yield {
                        "type": "agent_error",
                        "error": "Model inference error - the model may not be loaded correctly",
                        "message": f"❌ Model inference failed. Please check:\n"
                                   f"  1. HF_API_TOKEN is valid\n"
                                   f"  2. Model '{self.model_name}' is accessible\n"
                                   f"  3. Enough memory is available for local inference"
                    }
                else:
                    yield {
                        "type": "agent_error",
                        "error": error_msg,
                        "message": f"❌ Agent error: {error_msg}"
                    }
                break

            except Exception as e:
                logger.error(f"Agent iteration failed: {e}", exc_info=True)
                yield {
                    "type": "agent_error",
                    "error": str(e),
                    "message": f"❌ Agent error: {str(e)}"
                }
                break

        if iteration >= max_iterations:
            yield {
                "type": "agent_max_iterations",
                "message": f"⚠️ Reached maximum iterations ({max_iterations})",
                "iterations": iteration
            }

    async def _execute_mcp_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> Any:
        """
        Execute an MCP tool by routing it to the appropriate MCP server.
        This is where we actually call the MCP servers!
        """
        # ============ SEARCH MCP SERVER ============
        if tool_name == "search_web":
            query = tool_input["query"]
            max_results = tool_input.get("max_results", 5)
            results = await self.mcp_registry.search.query(query, max_results=max_results)
            return {
                "results": results[:max_results],
                "count": len(results[:max_results])
            }

        elif tool_name == "search_news":
            query = tool_input["query"]
            max_results = tool_input.get("max_results", 5)
            results = await self.mcp_registry.search.query(f"{query} news", max_results=max_results)
            return {
                "results": results[:max_results],
                "count": len(results[:max_results])
            }
        # ============ STORE MCP SERVER ============
        elif tool_name == "save_prospect":
            prospect_data = {
                "id": tool_input.get("prospect_id", str(uuid.uuid4())),
                "company": {
                    "id": tool_input.get("company_id"),
                    "name": tool_input.get("company_name"),
                    "domain": tool_input.get("company_domain")
                },
                "fit_score": tool_input.get("fit_score", 0),
                "status": tool_input.get("status", "new"),
                "metadata": tool_input.get("metadata", {})
            }
            result = await self.mcp_registry.store.save_prospect(prospect_data)
            return {"status": result, "prospect_id": prospect_data["id"]}

        elif tool_name == "get_prospect":
            prospect_id = tool_input["prospect_id"]
            prospect = await self.mcp_registry.store.get_prospect(prospect_id)
            return prospect or {"error": "Prospect not found"}

        elif tool_name == "list_prospects":
            prospects = await self.mcp_registry.store.list_prospects()
            status_filter = tool_input.get("status")
            if status_filter:
                prospects = [p for p in prospects if p.get("status") == status_filter]
            return {
                "prospects": prospects,
                "count": len(prospects)
            }

        elif tool_name == "save_company":
            company_data = {
                "id": tool_input.get("company_id", str(uuid.uuid4())),
                "name": tool_input["name"],
                "domain": tool_input["domain"],
                "industry": tool_input.get("industry"),
                "description": tool_input.get("description"),
                "employee_count": tool_input.get("employee_count")
            }
            result = await self.mcp_registry.store.save_company(company_data)
            return {"status": result, "company_id": company_data["id"]}

        elif tool_name == "get_company":
            company_id = tool_input["company_id"]
            company = await self.mcp_registry.store.get_company(company_id)
            return company or {"error": "Company not found"}

        elif tool_name == "save_fact":
            fact_data = {
                "id": tool_input.get("fact_id", str(uuid.uuid4())),
                "company_id": tool_input["company_id"],
                "fact_type": tool_input["fact_type"],
                "content": tool_input["content"],
                "source_url": tool_input.get("source_url"),
                "confidence_score": tool_input.get("confidence_score", 0.8)
            }
            result = await self.mcp_registry.store.save_fact(fact_data)
            return {"status": result, "fact_id": fact_data["id"]}

        elif tool_name == "save_contact":
            contact_data = {
                "id": tool_input.get("contact_id", str(uuid.uuid4())),
                "company_id": tool_input["company_id"],
                "email": tool_input["email"],
                "first_name": tool_input.get("first_name"),
                "last_name": tool_input.get("last_name"),
                "title": tool_input.get("title"),
                "seniority": tool_input.get("seniority")
            }
            result = await self.mcp_registry.store.save_contact(contact_data)
            return {"status": result, "contact_id": contact_data["id"]}

        elif tool_name == "list_contacts_by_domain":
            domain = tool_input["domain"]
            contacts = await self.mcp_registry.store.list_contacts_by_domain(domain)
            return {
                "contacts": contacts,
                "count": len(contacts)
            }

        elif tool_name == "check_suppression":
            supp_type = tool_input["suppression_type"]
            value = tool_input["value"]
            is_suppressed = await self.mcp_registry.store.check_suppression(supp_type, value)
            return {
                "suppressed": is_suppressed,
                "value": value,
                "type": supp_type
            }
        # ============ EMAIL MCP SERVER ============
        elif tool_name == "send_email":
            to = tool_input["to"]
            subject = tool_input["subject"]
            body = tool_input["body"]
            prospect_id = tool_input["prospect_id"]
            thread_id = await self.mcp_registry.email.send(to, subject, body, prospect_id)
            return {
                "status": "sent",
                "thread_id": thread_id,
                "to": to
            }

        elif tool_name == "get_email_thread":
            prospect_id = tool_input["prospect_id"]
            thread = await self.mcp_registry.email.get_thread(prospect_id)
            return thread or {"error": "No email thread found"}

        # ============ CALENDAR MCP SERVER ============
        elif tool_name == "suggest_meeting_slots":
            num_slots = tool_input.get("num_slots", 3)
            slots = await self.mcp_registry.calendar.suggest_slots()
            return {
                "slots": slots[:num_slots],
                "count": len(slots[:num_slots])
            }

        elif tool_name == "generate_calendar_invite":
            start_time = tool_input["start_time"]
            end_time = tool_input["end_time"]
            title = tool_input["title"]
            slot = {
                "start_iso": start_time,
                "end_iso": end_time,
                "title": title
            }
            ics = await self.mcp_registry.calendar.generate_ics(slot)
            return {
                "ics_content": ics,
                "meeting": slot
            }

        else:
            raise ValueError(f"Unknown MCP tool: {tool_name}")
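

# A minimal usage sketch (hypothetical): how this agent might be driven end to end.
# The MCPRegistry() default construction and the example task are assumptions for
# illustration only; the real registry setup and entry point live elsewhere in this project.
if __name__ == "__main__":
    async def _demo():
        registry = MCPRegistry()  # assumed default constructor; adjust to your setup
        agent = AutonomousMCPAgentGranite(registry)
        # Consume the async event stream produced by run(); each event is a dict
        # with a "type" field and a human-readable "message".
        async for event in agent.run("Research Shopify and save it as a prospect"):
            print(event.get("message", event))

    asyncio.run(_demo())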