Spaces:
Sleeping
Sleeping
| from google import genai | |
| from google.genai import types | |
| from typing import Union, List, Generator, Dict, Optional | |
| from PIL import Image | |
| from io import BytesIO | |
| import base64 | |
| import requests | |
| import asyncio | |
| import os | |
| from dotenv import load_dotenv | |
| from .category_instructions import get_instruction_for_category | |
| from .category_config import CATEGORY_CONFIGS | |
# Read a local .env file (if present) so the environment lookups below can
# pick up values defined there.
load_dotenv()

# Module-wide google-genai client shared by every helper in this file; the
# key comes from the API_KEY environment variable (None when it is unset).
client = genai.Client(api_key=os.getenv("API_KEY"))
def bytes_to_base64(
    data: bytes, with_prefix: bool = True, mime_type: str = "image/png"
) -> str:
    """Encode raw bytes as base64 text, optionally wrapped as a data URI.

    Args:
        data: The raw bytes to encode.
        with_prefix: When true (the default) the result is a data URI of the
            form ``data:<mime_type>;base64,<payload>``; otherwise only the
            bare base64 payload is returned.
        mime_type: Media type written into the data URI prefix. Defaults to
            ``"image/png"``, which is what this function always emitted
            before the parameter existed; pass the real type when the bytes
            are not PNG.

    Returns:
        The base64 text, with or without the data URI prefix.
    """
    # base64 output is pure ASCII, so decoding can never fail here.
    encoded = base64.b64encode(data).decode("ascii")
    if not with_prefix:
        return encoded
    return f"data:{mime_type};base64,{encoded}"
def decode_base64_image(base64_str: str) -> Image.Image:
    """Decode a base64 string (bare or wrapped in a data URI) into a PIL image.

    Args:
        base64_str: Either a bare base64 payload or a data URI such as
            ``data:image/png;base64,<payload>``.

    Returns:
        The image object produced by ``Image.open`` on the decoded bytes.

    Raises:
        ValueError: If the input starts with ``data:image`` but has no ``,``
            separating the header from the payload.
    """
    if base64_str.startswith("data:image"):
        # Everything after the first comma is the payload. A data URI without
        # a comma used to surface as a bare IndexError; report it explicitly.
        _header, separator, payload = base64_str.partition(",")
        if not separator:
            raise ValueError(
                "Malformed image data URI: missing ',' before the base64 payload"
            )
        base64_str = payload
    image_bytes = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_bytes))
# Line prefixes ("1." through "5.") that mark the start of a numbered section
# in the model's text output.
_SECTION_PREFIXES = tuple(f"{number}." for number in range(1, 6))


def _build_text_prompt(instruction, prompt, config: dict) -> str:
    """Join the category instruction, optional style guidance and user request."""
    if not config:
        return f"{instruction}\n\nUser Request: {prompt}"
    style_guide = config.get("style_guide", "")
    conventions = "\n- ".join(config.get("conventions", []))
    common_elements = "\n- ".join(config.get("common_elements", []))
    return (
        f"{instruction}\n\n"
        f"Style Guide: {style_guide}\n"
        f"Drawing Conventions to Follow:\n- {conventions}\n"
        f"Consider Including These Elements:\n- {common_elements}\n\n"
        f"User Request: {prompt}"
    )


def _parse_text_sections(text: str) -> dict:
    """Split model text into ``{section_name: body}`` using "N." header lines.

    A line starting with "1." .. "5." opens a section whose name is the text
    between the first "." and the first ":" (lower-cased, spaces replaced by
    underscores). Lines seen before any header go to "overview". Blank lines
    are dropped and every line is stripped.
    """
    sections = {}
    current = "overview"
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(_SECTION_PREFIXES):
            title, _, trailing = line.split(".", 1)[1].partition(":")
            current = title.strip().lower().replace(" ", "_")
            # setdefault: a repeated header extends the existing section
            # instead of throwing away what was already collected for it.
            bucket = sections.setdefault(current, [])
            # Keep text that shares the header line ("1. Overview: some text");
            # it used to be discarded.
            trailing = trailing.strip()
            if trailing:
                bucket.append(trailing)
        else:
            sections.setdefault(current, []).append(line)
    return {name: "\n".join(lines).strip() for name, lines in sections.items()}


def _iter_response_events(response):
    """Yield one ``{'type', 'data'}`` dict per part of the first candidate.

    Text parts become ``{'type': 'text', 'data': <sections dict>}`` and parts
    carrying ``inline_data`` become ``{'type': 'image', 'data': <data URI>}``.
    Nothing is yielded when the response has no candidates, or the first
    candidate has no content or parts, rather than failing on the missing
    attribute.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return
    content = getattr(candidates[0], "content", None)
    for part in getattr(content, "parts", None) or []:
        text = getattr(part, "text", None)
        inline_data = getattr(part, "inline_data", None)
        if text is not None:
            try:
                sections = _parse_text_sections(text)
            except Exception:
                # Best-effort fallback kept from the original behaviour: hand
                # back the unparsed text instead of failing the whole stream.
                yield {"type": "text", "data": {"raw_text": text}}
            else:
                yield {"type": "text", "data": sections}
        elif inline_data is not None:
            yield {"type": "image", "data": bytes_to_base64(inline_data.data)}


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and image output for a text prompt.

    The prompt sent to the model is the category instruction followed by the
    user request; when CATEGORY_CONFIGS has an entry for the lower-cased
    category, its style guide, conventions and common elements are inserted
    in between.

    Args:
        prompt: The user's request text.
        category: Optional category name used to pick the instruction and the
            configuration entry.

    Yields:
        ``{'type': 'text', 'data': dict}`` for each text part (section name
        mapped to section body, or ``{'raw_text': ...}`` if parsing fails) and
        ``{'type': 'image', 'data': str}`` (a base64 data URI) for each part
        carrying inline data.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})
    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=_build_text_prompt(instruction, prompt, config),
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        ),
    )
    for event in _iter_response_events(response):
        yield event


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and image output from an input image plus optional text.

    Args:
        text: Optional user request; when falsy only the category instruction
            is sent alongside the image.
        image_path: Despite the name, a base64 data URI starting with
            ``data:image/`` (not a filesystem path).
        category: Optional category name used to pick the instruction.

    Yields:
        The same event dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: If ``image_path`` is not a string starting with
            ``data:image/``. Because this is an async generator, the error is
            raised when iteration starts, not when the function is called.
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")
    image = decode_base64_image(image_path)
    instruction = get_instruction_for_category(category)
    request_text = f"{instruction}\n\nUser Request: {text}" if text else instruction
    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=[request_text, image],
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        ),
    )
    for event in _iter_response_events(response):
        yield event