functiongemma-270m-it-demo

#8
by aifeifei798 - opened
import warnings
import re  # Import the regular expression library for parsing
import json  # Import the json library for more reliable parameter parsing

from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Model and Schema Definition Section ---
# Local checkpoint directory shared by the tokenizer and the model.
MODEL_PATH = "./functiongemma-270m-it"

# Silence one specific tokenizer-loading warning (matched by message regex) so the
# demo output stays readable. NOTE(review): presumably benign for this checkpoint;
# confirm before reusing this filter elsewhere.
warnings.filterwarnings(
    "ignore",
    message="The tokenizer you are loading from.*with an incorrect regex pattern.*"
)
# Tokenizers have no device placement, so no `device_map` is passed here; only the
# model's weights are placed on a device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# dtype="auto" lets transformers pick the dtype from the checkpoint, and
# device_map="auto" lets it place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, dtype="auto", device_map="auto")

def _tool_schema(name, description, properties=None, required=()):
    """Return one tool declaration in the shape passed to ``apply_chat_template(tools=...)``.

    The result is ``{"type": "function", "function": {name, description,
    parameters}}``, where ``parameters`` is a JSON-schema object. Fresh
    containers are created on every call, so schemas never share state.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": dict(properties or {}),
                "required": list(required),
            },
        },
    }


# Play a specific track: both artist and title are mandatory.
play_music_schema = _tool_schema(
    "play_music",
    "When the user provides a clear artist or song title, use this function to play music.",
    properties={
        "artist": {"type": "string", "description": "The artist's name"},
        "song_title": {"type": "string", "description": "The specific title of the song"},
    },
    required=["artist", "song_title"],
)

# Vague requests: an optional genre hint only.
recommend_music_schema = _tool_schema(
    "recommend_music",
    "When the user's request is vague, for example, only providing a music genre or saying 'play anything', use this function to recommend and play a song.",
    properties={"genre": {"type": "string", "description": "The user-mentioned music genre, e.g., R&B, Pop, Rock"}},
)

# Takes no arguments at all.
stop_music_schema = _tool_schema(
    "stop_music",
    "Stops the currently playing music.",
)

# The conversation: a "developer" turn that enables function calling, followed by
# the user's request.
message = [
    {"role": "developer", "content": "You are a model that can do function calling with the following functions"},
    {"role": "user", "content": "Play an R&B song"}
]

# Render the conversation plus the three tool schemas through the tokenizer's chat
# template. add_generation_prompt=True appends the start of the model's turn.
# return_dict=True with return_tensors="pt" yields a dict-like batch of PyTorch
# tensors (including "input_ids") that can be unpacked into generate().
inputs = tokenizer.apply_chat_template(
    message,
    tools=[play_music_schema, recommend_music_schema, stop_music_schema],
    add_generation_prompt=True, return_dict=True, return_tensors="pt"
)
# Move the inputs to the model's device and generate at most 128 new tokens,
# reusing the EOS id as the padding id.
out = model.generate(**inputs.to(device=model.device), pad_token_id=tokenizer.eos_token_id, max_new_tokens=128)
# Decode only the newly generated tokens of the first sequence; slicing off the
# prompt length drops the echoed prompt.
# NOTE(review): parse_and_execute() relies on the <start_function_call>, <escape>
# and <end_function_call> markers surviving decoding. The sample output at the end
# of this file shows they do, but if the tokenizer registers them as special
# tokens, skip_special_tokens=True would strip them. Confirm.
model_output = tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

print(f"Raw Model Output: {model_output}\n")


# ==============================================================================
# --- Steps 1, 2 & 3: Parse, Route, and Execute ---
# ==============================================================================

# --- First, define our actual Python functions ---

def play_music(artist: str, song_title: str):
    """Simulate starting playback of *song_title* by *artist*.

    Nothing is actually played: the function prints an "[Executing]" trace line.
    A real implementation would call a music API at this point.

    Returns:
        A short confirmation sentence for the user.
    """
    track = f"{song_title} by {artist}"
    print(f"--> [Executing] Now playing '{song_title}' by {artist}...")
    return "Okay, now playing " + track

def recommend_music(genre: str = "any"):
    """Simulate recommending a song for *genre* and then playing it.

    The pick is a hard-coded placeholder: *genre* is only echoed in the trace
    output and the reply. A real implementation would query a catalogue or a
    recommendation engine here. The chosen track is handed to ``play_music``,
    whose return value is not used.

    Returns:
        A confirmation sentence mentioning the requested genre.
    """
    print(f"--> [Executing] Received recommendation request for genre: {genre}")
    pick_artist, pick_title = "John Mayer", "Slow Dancing in a Burning Room"
    print(f"--> [Action] Recommended song for you: {pick_artist} - {pick_title}")
    play_music(song_title=pick_title, artist=pick_artist)
    return f"Okay, I've recommended a {genre} song for you."

def stop_music():
    """Simulate stopping playback.

    Only a trace line is printed; no player is touched.

    Returns:
        A short confirmation string.
    """
    confirmation = "Music stopped."
    print("--> [Executing] Music has been stopped.")
    return confirmation

# --- Tool registry: maps each function name the model may emit (matching the
# "name" fields of the schemas) to the Python callable that implements it.
# Looking a name up here is the "routing" step. ---
available_tools = dict(
    play_music=play_music,
    recommend_music=recommend_music,
    stop_music=stop_music,
)

# --- Logic to parse the model's output and execute the function ---

# One FunctionGemma-style call, e.g.
#   <start_function_call>call:recommend_music{genre:<escape>R&B<escape>}<end_function_call>
# The argument group uses `*` (not `+`) so zero-argument calls such as
# `call:stop_music{}` are recognised. Anchoring the closing `}` to
# <end_function_call> keeps a `}` inside the arguments from ending the match early.
# DOTALL lets values span lines.
_FUNCTION_CALL_RE = re.compile(
    r"<start_function_call>\s*call:\s*([^{]+?)\s*\{(.*?)\}\s*<end_function_call>",
    re.DOTALL,
)

# A bare (unquoted) object key: at the start of a structural segment or right
# after `{` / `,`, followed by `:`.
_BARE_KEY_RE = re.compile(r"(^|[{,])\s*([A-Za-z0-9_]+)\s*:")


def _gemma_args_to_dict(params_str: str) -> dict:
    """Convert FunctionGemma argument text (``key:<escape>value<escape>,...``) to a dict.

    Text between paired ``<escape>`` markers is a literal string value and is
    JSON-encoded as a whole. Quotes, colons or commas inside a value therefore
    cannot corrupt the result. Bare keys are quoted only in the structural text
    between those literals. Unescaped values (numbers, ``true``/``false``, nested
    ``{...}``/``[...]``) are left for ``json.loads`` to interpret.

    Raises:
        ValueError: if the ``<escape>`` markers are unbalanced or the converted
            text is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    segments = params_str.split("<escape>")
    # N markers produce N + 1 segments; an even count means an odd number of
    # markers, i.e. a string value that is never closed.
    if len(segments) % 2 == 0:
        raise ValueError("unbalanced <escape> markers in function arguments")
    pieces = []
    for index, segment in enumerate(segments):
        if index % 2:  # odd segments sit between a pair of markers: a string value
            pieces.append(json.dumps(segment))
        else:
            pieces.append(_BARE_KEY_RE.sub(r'\1"\2":', segment))
    return json.loads("{" + "".join(pieces) + "}")


def parse_and_execute(model_output: str, tools: "dict | None" = None):
    """Find a FunctionGemma function call in *model_output*, then route and run it.

    Args:
        model_output: Decoded model text. If it contains no
            ``<start_function_call>call:NAME{...}<end_function_call>`` span, the
            text is printed as the model's direct reply.
        tools: Registry mapping function-name strings to callables. Defaults to
            the module-level ``available_tools`` (looked up at call time).

    Returns:
        The called tool's return value. Returns ``None`` when no call was found,
        the arguments could not be parsed, the name is not registered, or the
        tool raised. Each of those cases is reported with ``print`` instead of
        raising, so the demo keeps running.
    """
    if tools is None:
        tools = available_tools

    match = _FUNCTION_CALL_RE.search(model_output)
    if match is None:
        print("No function call detected, printing the model's direct response:")
        print(model_output)
        return None

    function_name = match.group(1).strip()
    try:
        arguments = _gemma_args_to_dict(match.group(2))
    except ValueError as e:
        print(f"Error parsing arguments: {e}")
        return None
    print(f"Successfully parsed -> Function: '{function_name}', Arguments: {arguments}")

    # Route: look up the function in the tool registry.
    function_to_call = tools.get(function_name)
    if function_to_call is None:
        print(f"Error: Function named '{function_name}' not found.")
        return None

    # Execute. This is the demo's top-level boundary, so any failure inside the
    # tool is reported rather than raised. That includes a TypeError from
    # arguments that don't match the tool's signature.
    print("\n--- Starting Execution ---")
    try:
        result = function_to_call(**arguments)
    except Exception as e:
        print(f"Error executing function: {e}")
        return None
    print("--- Execution Finished ---\n")
    return result

# --- Script entry point: hand the generated text to the parser/dispatcher ---
parse_and_execute(model_output=model_output)

# --- Output ---

# Raw Model Output: <start_function_call>call:recommend_music{genre:<escape>R&B<escape>}<end_function_call>

# Successfully parsed -> Function: 'recommend_music', Arguments: {'genre': 'R&B'}

# --- Starting Execution ---
# --> [Executing] Received recommendation request for genre: R&B
# --> [Action] Recommended song for you: John Mayer - Slow Dancing in a Burning Room
# --> [Executing] Now playing 'Slow Dancing in a Burning Room' by John Mayer...
# --- Execution Finished ---

Sign up or log in to comment