jeanbaptdzd's picture
feat: Clean deployment to HuggingFace Space with model config test endpoint
8c0b652
#!/usr/bin/env python3
"""
Tool usage test suite.
Tests the model's ability to use tools and function calling.
"""
import json
from typing import List, Dict, Any
from ..core.base_tester import BaseTester, TestCase
from ..tools.time_tool import TimeTool
from ..tools.ticker_tool import TickerTool
class ToolUsageTester(BaseTester):
    """Test tool usage and function calling capabilities.

    Validation is a keyword heuristic over the response text: a response
    passes when it contains at least one tool-usage indicator and, for the
    known test cases, at least one term from every required term group.
    All matching is case-insensitive substring matching.
    """

    # Substrings whose presence is taken as evidence that a tool was used.
    _TOOL_INDICATORS = (
        "calling", "using", "tool", "function", "get_current_time",
        "get_ticker_info", "executing", "retrieved", "fetched",
    )

    # Test-case name -> groups of lowercase terms. A response satisfies a
    # test case when EVERY group has at least one term present in the
    # lowercased response text. Names not listed here have no extra
    # requirements beyond the tool-usage indicators.
    # NOTE(review): "2024"/"2025" are literal years and will go stale, and
    # short terms such as "am"/"pm"/":" match inside unrelated words; they are
    # kept as-is so the set of accepted responses does not change.
    _REQUIRED_TERM_GROUPS = {
        "time_tool_basic": (
            ("utc", "time", "2024", "2025", ":", "am", "pm"),
        ),
        "time_tool_formatted": (
            ("readable", "format"),
        ),
        "ticker_tool_basic": (
            ("aapl", "apple"),
            ("$", "price"),
        ),
        "ticker_tool_company_info": (
            ("msft", "microsoft"),
            ("company", "corporation"),
        ),
        "multiple_tools": (
            ("time", "utc"),
            ("googl", "google"),
        ),
        "tool_with_context": (
            ("time", "utc"),
            ("jpm", "jpmorgan"),
            ("risk", "assessment"),
        ),
    }

    def __init__(self, endpoint: str, model: str = None):
        """Create the tester and the tool instances it exposes.

        Args:
            endpoint: Forwarded unchanged to ``BaseTester.__init__``.
            model: Forwarded unchanged to ``BaseTester.__init__``; defaults
                to ``None``.
        """
        super().__init__(endpoint, model)
        self.time_tool = TimeTool()
        self.ticker_tool = TickerTool()
        self.available_tools = [self.time_tool, self.ticker_tool]

    def load_test_cases(self) -> List[TestCase]:
        """Load tool usage test cases.

        Returns:
            Six test cases: two for the time tool, two for the ticker tool
            and two that ask for both in a single prompt. Each case's
            ``metadata`` records the tool name(s) (and, where relevant, the
            ticker symbol) the prompt is meant to exercise.
        """
        return [
            TestCase(
                name="time_tool_basic",
                prompt="What is the current UTC time?",
                expected_keys=["response"],
                max_tokens=100,
                metadata={"expected_tool": "get_current_time"},
            ),
            TestCase(
                name="time_tool_formatted",
                prompt="Get the current time in a readable format.",
                expected_keys=["response"],
                max_tokens=100,
                metadata={"expected_tool": "get_current_time", "expected_format": "readable"},
            ),
            TestCase(
                name="ticker_tool_basic",
                prompt="What is the current price of Apple stock (AAPL)?",
                expected_keys=["response"],
                max_tokens=100,
                metadata={"expected_tool": "get_ticker_info", "expected_symbol": "AAPL"},
            ),
            TestCase(
                name="ticker_tool_company_info",
                prompt="Get company information for Microsoft (MSFT).",
                expected_keys=["response"],
                max_tokens=120,
                metadata={"expected_tool": "get_ticker_info", "expected_symbol": "MSFT", "expected_info": "company"},
            ),
            TestCase(
                name="multiple_tools",
                prompt="What is the current time and the price of Google stock (GOOGL)?",
                expected_keys=["response"],
                max_tokens=150,
                metadata={"expected_tools": ["get_current_time", "get_ticker_info"]},
            ),
            TestCase(
                name="tool_with_context",
                prompt="I need to know the current time in UTC format and the financial information for JPMorgan Chase (JPM) for my risk assessment report.",
                expected_keys=["response"],
                max_tokens=180,
                metadata={"expected_tools": ["get_current_time", "get_ticker_info"], "expected_symbol": "JPM"},
            ),
        ]

    @staticmethod
    def _contains_any(text: str, terms) -> bool:
        """Return True if any of *terms* occurs as a substring of *text*."""
        return any(term in text for term in terms)

    def validate_response(self, response: Dict[str, Any], test_case: TestCase) -> bool:
        """Validate tool usage response.

        Args:
            response: Mapping expected to hold the model output under the
                ``"response"`` key.
            test_case: The test case being validated; only its ``name`` is
                used to pick the required term groups.

        Returns:
            ``True`` when the response text is a string of at least 10
            non-whitespace-trimmed characters, contains a tool-usage
            indicator, and satisfies every term group registered for
            ``test_case.name``. ``False`` otherwise, including when any
            unexpected error occurs during validation.
        """
        try:
            if "response" not in response:
                return False

            response_text = response["response"]
            # Reject missing / non-text payloads explicitly instead of relying
            # on ``.strip()`` raising into the handler below.
            if not isinstance(response_text, str):
                return False
            if len(response_text.strip()) < 10:
                return False

            # Lowercase once; every check below is case-insensitive.
            text = response_text.lower()

            # Every test case (and the default) requires a tool indicator.
            if not self._contains_any(text, self._TOOL_INDICATORS):
                return False

            required_groups = self._REQUIRED_TERM_GROUPS.get(test_case.name, ())
            return all(self._contains_any(text, group) for group in required_groups)
        except Exception as e:
            # Deliberate boundary: validation must yield a bool, never raise.
            print(f"Validation error: {e}")
            return False

    def simulate_tool_execution(self, tool_name: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Simulate tool execution for testing purposes.

        Args:
            tool_name: ``"get_current_time"`` or ``"get_ticker_info"``.
            parameters: Keyword arguments passed to the tool's ``execute``
                method; ``None`` is treated as no arguments.

        Returns:
            Whatever the selected tool's ``execute`` returns, or a dict with
            ``"success": False`` and an ``"error"`` message when
            ``tool_name`` is not recognised.
        """
        # ``**None`` would raise TypeError, so fall back to an empty mapping.
        kwargs = parameters or {}

        if tool_name == "get_current_time":
            return self.time_tool.execute(**kwargs)
        if tool_name == "get_ticker_info":
            return self.ticker_tool.execute(**kwargs)

        return {
            "success": False,
            "error": f"Unknown tool: {tool_name}",
        }

    def get_available_tools(self) -> List[Dict[str, Any]]:
        """Get list of available tools for function calling.

        Returns:
            The result of ``get_tool_definition()`` for each tool in
            ``self.available_tools``, in that order.
        """
        return [tool.get_tool_definition() for tool in self.available_tools]