functiongemma-270m-it-demo #8, opened by aifeifei798
import warnings
import re # Import the regular expression library for parsing
import json # Import the json library for more reliable parameter parsing
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Model and Schema Definition Section ---
warnings.filterwarnings(
    "ignore",
    message="The tokenizer you are loading from.*with an incorrect regex pattern.*"
)
tokenizer = AutoTokenizer.from_pretrained("./functiongemma-270m-it", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("./functiongemma-270m-it", dtype="auto", device_map="auto")
play_music_schema = {
    "type": "function",
    "function": {
        "name": "play_music",
        "description": "When the user provides a clear artist or song title, use this function to play music.",
        "parameters": {
            "type": "object",
            "properties": {
                "artist": {"type": "string", "description": "The artist's name"},
                "song_title": {"type": "string", "description": "The specific title of the song"},
            },
            "required": ["artist", "song_title"],
        },
    },
}
recommend_music_schema = {
    "type": "function",
    "function": {
        "name": "recommend_music",
        "description": "When the user's request is vague, for example, only providing a music genre or saying 'play anything', use this function to recommend and play a song.",
        "parameters": {
            "type": "object",
            "properties": {"genre": {"type": "string", "description": "The user-mentioned music genre, e.g., R&B, Pop, Rock"}},
            "required": [],
        },
    },
}
stop_music_schema = {
    "type": "function",
    "function": {
        "name": "stop_music",
        "description": "Stops the currently playing music.",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
}
# The conversation: a developer instruction plus the user's prompt
message = [
    {"role": "developer", "content": "You are a model that can do function calling with the following functions"},
    {"role": "user", "content": "Play an R&B song"},
]
inputs = tokenizer.apply_chat_template(
    message,
    tools=[play_music_schema, recommend_music_schema, stop_music_schema],
    add_generation_prompt=True, return_dict=True, return_tensors="pt",
)
out = model.generate(**inputs.to(device=model.device), pad_token_id=tokenizer.eos_token_id, max_new_tokens=128)
model_output = tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
print(f"Raw Model Output: {model_output}\n")
# ==============================================================================
# --- Steps 1, 2 & 3: Parse, Route, and Execute ---
# ==============================================================================
# --- First, define our actual Python functions ---
def play_music(artist: str, song_title: str):
    """(Simulated) The actual function to play music."""
    print(f"--> [Executing] Now playing '{song_title}' by {artist}...")
    # Here, you would write the code to call a music API
    return f"Okay, now playing {song_title} by {artist}"

def recommend_music(genre: str = "any"):
    """(Simulated) The actual function to recommend music."""
    print(f"--> [Executing] Received recommendation request for genre: {genre}")
    # Here, you would write code to query a database or a recommendation engine
    recommended_song = "Slow Dancing in a Burning Room"
    recommended_artist = "John Mayer"
    print(f"--> [Action] Recommended song for you: {recommended_artist} - {recommended_song}")
    play_music(artist=recommended_artist, song_title=recommended_song)
    return f"Okay, I've recommended a {genre} song for you."

def stop_music():
    """(Simulated) The actual function to stop the music."""
    print("--> [Executing] Music has been stopped.")
    return "Music stopped."
# --- Create a "tool registry" to map function name strings to actual functions (this is the core of "routing") ---
available_tools = {
    "play_music": play_music,
    "recommend_music": recommend_music,
    "stop_music": stop_music,
}
# --- Logic to parse the model's output and execute the function ---
def parse_and_execute(model_output: str):
    # Use a regular expression to parse the function call string; the argument group is
    # optional (.*?) so calls without parameters, e.g. stop_music, still match
    match = re.search(r"<start_function_call>call:(.+?)\{(.*?)\}<end_function_call>", model_output)
    if not match:
        print("No function call detected, printing the model's direct response:")
        print(model_output)
        return
    function_name = match.group(1).strip()
    params_str = match.group(2).strip()
    # Convert Gemma's output format (key:<escape>value<escape>) to JSON format ("key":"value")
    # 1. Replace <escape> with a quote
    params_str = params_str.replace("<escape>", '"')
    # 2. Add quotes around the key
    params_str = re.sub(r'([a-zA-Z0-9_]+):', r'"\1":', params_str)
    # 3. Ensure the entire string is wrapped in {} to be valid JSON
    json_str = f"{{{params_str}}}"
    try:
        # Parse into a Python dictionary
        arguments = json.loads(json_str)
        print(f"Successfully parsed -> Function: '{function_name}', Arguments: {arguments}")
        # Route: Look up the function in our tool registry
        function_to_call = available_tools.get(function_name)
        if function_to_call:
            # Execute: Call the function with the parsed arguments
            print("\n--- Starting Execution ---")
            function_to_call(**arguments)
            print("--- Execution Finished ---\n")
        else:
            print(f"Error: Function named '{function_name}' not found.")
    except json.JSONDecodeError as e:
        print(f"Error parsing arguments: {e}")
    except Exception as e:
        print(f"Error executing function: {e}")
# --- Main execution entry point ---
parse_and_execute(model_output)
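# --- Optional follow-up (a minimal sketch, not part of the original demo) ---
# A common next step in a function-calling loop is to feed the tool's result back to the
# model so it can phrase a final answer. The commented-out snippet below is an illustrative
# assumption: it presumes the functiongemma chat template accepts a "tool" role message
# carrying the function result; the exact role name and message format may differ.
# followup_messages = message + [
#     {"role": "assistant", "content": model_output},
#     {"role": "tool", "content": "Okay, I've recommended a R&B song for you."},
# ]
# followup_inputs = tokenizer.apply_chat_template(
#     followup_messages,
#     tools=[play_music_schema, recommend_music_schema, stop_music_schema],
#     add_generation_prompt=True, return_dict=True, return_tensors="pt",
# )
# followup_out = model.generate(**followup_inputs.to(device=model.device),
#                               pad_token_id=tokenizer.eos_token_id, max_new_tokens=128)
# print(tokenizer.decode(followup_out[0][len(followup_inputs["input_ids"][0]):], skip_special_tokens=True))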
# --- Output ---
# Raw Model Output: <start_function_call>call:recommend_music{genre:<escape>R&B<escape>}<end_function_call>
# Successfully parsed -> Function: 'recommend_music', Arguments: {'genre': 'R&B'}
# --- Starting Execution ---
# --> [Executing] Received recommendation request for genre: R&B
# --> [Action] Recommended song for you: John Mayer - Slow Dancing in a Burning Room
# --> [Executing] Now playing 'Slow Dancing in a Burning Room' by John Mayer...
# --- Execution Finished ---