Spaces:
Running
Running
| import os | |
| import io | |
| import time | |
| import base64 | |
| import uuid | |
| import PIL.Image | |
| from flask import Flask, render_template, request, jsonify | |
| from dotenv import load_dotenv | |
| # Google Cloud & GenAI specific imports | |
| from google.cloud import storage | |
| from google.api_core import exceptions as google_exceptions | |
| from google import genai | |
| from google.genai import types | |
# --- Configuration & Initialization ---
# NOTE: .env loading is currently disabled, so every setting below must
# already be present in the process environment.
# load_dotenv('.env')
app = Flask(__name__)

# Directory where a local copy of each generated image is written.
LOCAL_IMAGE_DIR = os.path.join('static', 'generated_images')
os.makedirs(LOCAL_IMAGE_DIR, exist_ok=True)

# Gemini image generation settings
API_KEY = os.environ.get("GOOGLE_API_KEY")
MODEL_ID_IMAGE = 'gemini-2.0-flash-exp-image-generation'

# Veo video generation settings (Vertex AI project/region plus output bucket)
PROJECT_ID = os.environ.get("PROJECT_ID")
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME")
MODEL_ID_VIDEO = "veo-3.0-generate-preview"

# Fail fast at import time, and name the variables that are unset or empty so
# the operator does not have to guess which one is wrong.
_REQUIRED_ENV = {
    "GOOGLE_API_KEY": API_KEY,
    "PROJECT_ID": PROJECT_ID,
    "GCS_BUCKET_NAME": GCS_BUCKET_NAME,
    "GOOGLE_CLOUD_REGION": LOCATION,
}
_missing_env = [name for name, value in _REQUIRED_ENV.items() if not value]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variables: "
        + ", ".join(_missing_env)
        + ". Check your .env file."
    )
# Initialize clients
# The three clients are built inside one try block: if any constructor raises,
# the error is printed and all three names are reset to None.
try:
    # Gemini image generation, authenticated with the API key
    gemini_image_client = genai.Client(api_key=API_KEY)
    print(f"Gemini Image Client initialized successfully for model: {MODEL_ID_IMAGE}")

    # Veo video generation through Vertex AI (project + location)
    veo_video_client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
    print(f"Veo Video Client (Vertex AI) initialized successfully for project: {PROJECT_ID}")

    # Google Cloud Storage
    gcs_client = storage.Client(project=PROJECT_ID)
    print("Google Cloud Storage Client initialized successfully.")
except Exception as init_error:
    print(f"Error during client initialization: {init_error}")
    gemini_image_client = None
    veo_video_client = None
    gcs_client = None
| # --- Helper Function to Upload to GCS (NEW) --- | |
def upload_bytes_to_gcs(
    image_bytes: bytes,
    bucket_name: str,
    destination_blob_name: str,
    content_type: str = 'image/png',
) -> str:
    """Upload raw bytes to a GCS object and return its ``gs://`` URI.

    Args:
        image_bytes: The payload to store.
        bucket_name: Name of the destination bucket.
        destination_blob_name: Object path inside the bucket.
        content_type: MIME type recorded on the object. Defaults to
            ``'image/png'``, the value this helper previously hard-coded.

    Returns:
        The URI ``gs://<bucket_name>/<destination_blob_name>``.

    Raises:
        ConnectionError: If the module-level ``gcs_client`` is not initialized.
    """
    if not gcs_client:
        raise ConnectionError("GCS client is not initialized.")

    blob = gcs_client.bucket(bucket_name).blob(destination_blob_name)
    blob.upload_from_string(image_bytes, content_type=content_type)

    gcs_uri = f"gs://{bucket_name}/{destination_blob_name}"
    print(f"Image successfully uploaded to {gcs_uri}")
    return gcs_uri
| # --- Main Routes --- | |
@app.route('/')
def index():
    """Render the main HTML page (the ``index.html`` template).

    Registered on the Flask app at ``/``; without a route registration the
    view would never be reachable.
    """
    return render_template('index.html')
def _decode_sketch(image_data):
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image."""
    # For a data URL ("<header>,<base64>") only the part after the first comma
    # is the base64 payload.
    encoded = image_data.split(',', 1)[1] if ',' in image_data else image_data
    return PIL.Image.open(io.BytesIO(base64.b64decode(encoded)))


def _extract_image_bytes(response):
    """Return the bytes of the first image part in a Gemini response.

    Raises:
        ValueError: If there are no candidates or no image part with data.
    """
    if not response.candidates:
        raise ValueError("Gemini image generation returned no candidates.")
    content = response.candidates[0].content
    # Guard against a candidate without content/parts so the caller gets the
    # "no image" error below rather than an AttributeError/TypeError.
    for part in (getattr(content, 'parts', None) or []):
        inline = part.inline_data
        if inline and (inline.mime_type or '').startswith('image/') and inline.data:
            return inline.data
    raise ValueError("Gemini did not return an image in the response.")


def _save_image_locally(image_bytes):
    """Best-effort: write a copy of the generated image under LOCAL_IMAGE_DIR."""
    try:
        # Use a unique filename to prevent overwrites.
        local_filename = f"generated-image-{uuid.uuid4()}.png"
        local_image_path = os.path.join(LOCAL_IMAGE_DIR, local_filename)
        with open(local_image_path, "wb") as f:
            f.write(image_bytes)
        print(f"Image also saved locally to: {local_image_path}")
    except Exception as e:
        # Deliberately broad: the local copy is optional, so a failure here
        # must never abort the request.
        print(f"[Warning] Could not save image locally: {e}")


def _wait_for_video_operation(operation, timeout_seconds=300, poll_seconds=15):
    """Poll a Veo operation until it reports done and return the final state.

    WARNING: this is a synchronous poll that blocks the calling server thread.
    For production, consider an asynchronous pattern instead.

    Raises:
        TimeoutError: If the operation is still running once more than
            ``timeout_seconds`` have elapsed at the start of a poll cycle.
    """
    started = time.monotonic()
    while not operation.done:
        if time.monotonic() - started > timeout_seconds:
            raise TimeoutError("Video generation timed out.")
        time.sleep(poll_seconds)
        # The operation object must be fetched again to refresh its status.
        operation = veo_video_client.operations.get(operation)
        print(operation)
    return operation


def generate_video_from_sketch():
    """Run the full pipeline: sketch -> image -> video.

    Expects a JSON object in the request body with:
        image_data: base64-encoded sketch, optionally in data-URL form.
        prompt: optional text; when non-empty it replaces the default video
            prompt. It is not used for the image-generation step.

    Returns:
        A Flask JSON response: ``{"generated_video_url": ...}`` on success, or
        ``({"error": ...}, status)`` with 400 for a malformed request and 500
        for a server-side failure.

    NOTE(review): no ``@app.route`` decorator is visible on this view in this
    file; confirm how (and under which URL) it is registered.
    """
    if not gemini_image_client:
        return jsonify({"error": "A server-side client is not initialized. Check server logs."}), 500

    # Parse the body once; silent=True yields None instead of raising when the
    # body is missing, malformed, or not declared as JSON.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image_data' not in payload:
        return jsonify({"error": "Missing image_data in request"}), 400

    base64_image_data = payload['image_data']
    # A null or non-string prompt is treated as "no prompt" rather than crashing.
    raw_prompt = payload.get('prompt')
    user_prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ''

    # --- Step 1: Generate Image with Gemini ---
    try:
        print("--- Step 1: Generating image from sketch with Gemini ---")
        sketch_pil_image = _decode_sketch(base64_image_data)
        default_prompt = "Convert this sketch into a photorealistic image as if it were taken from a real DSLR camera. The elements and objects should look real."
        response = gemini_image_client.models.generate_content(
            model=MODEL_ID_IMAGE,
            contents=[default_prompt, sketch_pil_image],
            config=types.GenerateContentConfig(response_modalities=['TEXT', 'IMAGE']),
        )
        generated_image_bytes = _extract_image_bytes(response)
        print("Image generated successfully.")
        _save_image_locally(generated_image_bytes)
    except Exception as e:
        print(f"Error during Gemini image generation: {e}")
        return jsonify({"error": f"Failed to generate image: {e}"}), 500

    # --- Step 2 & 3: Upload Image to GCS and Generate Video with Veo ---
    try:
        # Check before uploading so a missing client yields a clear message
        # instead of an AttributeError on None after the upload.
        if not veo_video_client:
            raise ConnectionError("Veo client is not initialized.")

        print("\n--- Step 2: Uploading generated image to GCS ---")
        unique_id = uuid.uuid4()
        image_blob_name = f"images/generated-image-{unique_id}.png"
        output_gcs_prefix = f"gs://{GCS_BUCKET_NAME}/videos/"  # Folder for video outputs
        image_gcs_uri = upload_bytes_to_gcs(generated_image_bytes, GCS_BUCKET_NAME, image_blob_name)

        print("\n--- Step 3: Calling Veo to generate video ---")
        default_video_prompt = "Animate this image. Add subtle, cinematic motion."
        video_prompt = user_prompt if user_prompt else default_video_prompt
        print(video_prompt)
        operation = veo_video_client.models.generate_videos(
            model=MODEL_ID_VIDEO,
            prompt=video_prompt,
            image=types.Image(gcs_uri=image_gcs_uri, mime_type="image/png"),
            config=types.GenerateVideosConfig(
                aspect_ratio="16:9",
                output_gcs_uri=output_gcs_prefix,
                duration_seconds=8,
                person_generation="allow_adult",
                enhance_prompt=True,
                generate_audio=True,
            ),
        )

        operation = _wait_for_video_operation(operation, timeout_seconds=300)  # 5 minutes
        print("Video generation operation complete.")

        # Surface a service-reported failure instead of the generic
        # "returned no video" message below.
        operation_error = getattr(operation, 'error', None)
        if operation_error:
            raise RuntimeError(f"Veo operation failed: {operation_error}")
        if not operation.response or not operation.result.generated_videos:
            raise ValueError("Veo operation completed but returned no video.")

        video_gcs_uri = operation.result.generated_videos[0].video.uri
        print(f"Video saved to GCS at: {video_gcs_uri}")

        # Convert gs:// URI to an https:// URL.
        # NOTE(review): this URL only works if the object is publicly
        # readable - confirm the bucket's access settings.
        video_blob_name = video_gcs_uri.replace(f"gs://{GCS_BUCKET_NAME}/", "")
        public_video_url = f"https://storage.googleapis.com/{GCS_BUCKET_NAME}/{video_blob_name}"
        print(f"Video generated successfully. Public URL: {public_video_url}")
        return jsonify({"generated_video_url": public_video_url})
    except Exception as e:
        print(f"An error occurred during video generation: {e}")
        return jsonify({"error": f"Failed to generate video: {e}"}), 500
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # arbitrary Python, and host='0.0.0.0' listens on every interface. Debug
    # mode is therefore opt-in: set FLASK_DEBUG=1 (or true/yes/on) to enable it.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)