# Source: Hugging Face Space file view (Space status: Running).
# File size: 8,529 bytes; commits ded0e4a, 58c61e2.
import os
import io
import time
import base64
import uuid
import PIL.Image
from flask import Flask, render_template, request, jsonify
from dotenv import load_dotenv
# Google Cloud & GenAI specific imports
from google.cloud import storage
from google.api_core import exceptions as google_exceptions
from google import genai
from google.genai import types
# --- Configuration & Initialization ---
# The .env loader is disabled, so every setting read below must already be present
# in the process environment (e.g. deployment secrets).
# load_dotenv('.env')
app = Flask(__name__)

# Generated images are kept under Flask's static folder. app.static_folder is an
# absolute path derived from the app's root, so the location no longer depends on
# the working directory the server was started from.
LOCAL_IMAGE_DIR = os.path.join(app.static_folder, 'generated_images')
os.makedirs(LOCAL_IMAGE_DIR, exist_ok=True)

# Gemini Image Generation Client (API-key based)
API_KEY = os.environ.get("GOOGLE_API_KEY")
MODEL_ID_IMAGE = 'gemini-2.0-flash-exp-image-generation'

# Veo Video Generation Client (Vertex AI)
PROJECT_ID = os.environ.get("PROJECT_ID")
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME")
MODEL_ID_VIDEO = "veo-3.0-generate-preview"  # Your Veo model ID

# Fail fast at import time and name exactly which settings are absent or empty.
_REQUIRED_SETTINGS = {
    "GOOGLE_API_KEY": API_KEY,
    "PROJECT_ID": PROJECT_ID,
    "GCS_BUCKET_NAME": GCS_BUCKET_NAME,
    "GOOGLE_CLOUD_REGION": LOCATION,
}
_missing_settings = [name for name, value in _REQUIRED_SETTINGS.items() if not value]
if _missing_settings:
    raise RuntimeError(
        "Missing required environment variables: "
        f"{', '.join(_missing_settings)}. Set them in the process environment."
    )
# Initialize clients. If any constructor fails, the error is printed and all three
# module-level client names are bound to None so later code can check for that.
def _build_clients():
    """Construct the Gemini image, Veo (Vertex AI) video and Cloud Storage clients.

    Returns:
        A ``(gemini_image_client, veo_video_client, gcs_client)`` tuple, or
        ``(None, None, None)`` if any step raises.
    """
    try:
        image_client = genai.Client(api_key=API_KEY)
        print(f"Gemini Image Client initialized successfully for model: {MODEL_ID_IMAGE}")

        video_client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
        print(f"Veo Video Client (Vertex AI) initialized successfully for project: {PROJECT_ID}")

        storage_client = storage.Client(project=PROJECT_ID)
        print("Google Cloud Storage Client initialized successfully.")
    except Exception as init_error:
        print(f"Error during client initialization: {init_error}")
        return None, None, None
    return image_client, video_client, storage_client


gemini_image_client, veo_video_client, gcs_client = _build_clients()
# --- Helper Function to Upload to GCS ---
def upload_bytes_to_gcs(
    image_bytes: bytes,
    bucket_name: str,
    destination_blob_name: str,
    content_type: str = 'image/png',
) -> str:
    """Upload raw bytes to a Cloud Storage object and return its ``gs://`` URI.

    Args:
        image_bytes: Payload to store in the object.
        bucket_name: Name of the destination bucket.
        destination_blob_name: Object path inside the bucket.
        content_type: MIME type recorded on the object; defaults to PNG.

    Returns:
        The ``gs://<bucket_name>/<destination_blob_name>`` URI of the uploaded object.

    Raises:
        ConnectionError: If the module-level ``gcs_client`` is not initialized.
    """
    if not gcs_client:
        raise ConnectionError("GCS client is not initialized.")
    blob = gcs_client.bucket(bucket_name).blob(destination_blob_name)
    blob.upload_from_string(image_bytes, content_type=content_type)
    gcs_uri = f"gs://{bucket_name}/{destination_blob_name}"
    print(f"Image successfully uploaded to {gcs_uri}")
    return gcs_uri
# --- Main Routes ---
@app.route('/')
def index():
    """Serve the application's main page from the ``index.html`` template."""
    page_template = 'index.html'
    return render_template(page_template)
_SKETCH_TO_PHOTO_PROMPT = (
    "Convert this sketch into a photorealistic image as if it were taken from a "
    "real DSLR camera. The elements and objects should look real."
)
_DEFAULT_VIDEO_PROMPT = "Animate this image. Add subtle, cinematic motion."
_VIDEO_TIMEOUT_SECONDS = 300  # 5 minutes
_VIDEO_POLL_INTERVAL_SECONDS = 15


def _decode_sketch(image_data: str) -> PIL.Image.Image:
    """Decode the request's base64 sketch into a fully loaded PIL image.

    Accepts bare base64 or a data URL: if a comma is present, everything up to the
    first comma (the ``data:<mime>;base64`` header) is dropped.

    Raises:
        ValueError: ``binascii.Error`` when the payload is not valid base64.
        OSError: ``PIL.UnidentifiedImageError`` when the bytes are not a readable image.
        PIL.Image.DecompressionBombError: when the image is implausibly large.
    """
    header, comma, body = image_data.partition(',')
    payload = body if comma else header
    sketch = PIL.Image.open(io.BytesIO(base64.b64decode(payload)))
    # Image.open is lazy; decode now so corrupt uploads are reported as client errors.
    sketch.load()
    return sketch


def _generate_photo_from_sketch(sketch: PIL.Image.Image) -> bytes:
    """Ask Gemini to turn the sketch into a photorealistic image; return its bytes.

    Raises:
        ValueError: If Gemini returns no candidates or no non-empty image part.
    """
    response = gemini_image_client.models.generate_content(
        model=MODEL_ID_IMAGE,
        contents=[_SKETCH_TO_PHOTO_PROMPT, sketch],
        config=types.GenerateContentConfig(response_modalities=['TEXT', 'IMAGE']),
    )
    if not response.candidates:
        raise ValueError("Gemini image generation returned no candidates.")
    content = response.candidates[0].content
    # content, its parts and a part's mime_type are optional in the SDK types, so guard
    # against None instead of crashing with a TypeError/AttributeError.
    parts = content.parts if content and content.parts else []
    for part in parts:
        inline = part.inline_data
        if inline and inline.data and (inline.mime_type or '').startswith('image/'):
            return inline.data
    raise ValueError("Gemini did not return an image in the response.")


def _save_image_locally(image_bytes: bytes) -> None:
    """Best-effort copy of the generated image into LOCAL_IMAGE_DIR; failures only warn."""
    # A unique filename prevents concurrent requests from overwriting each other.
    local_image_path = os.path.join(LOCAL_IMAGE_DIR, f"generated-image-{uuid.uuid4()}.png")
    try:
        with open(local_image_path, "wb") as f:
            f.write(image_bytes)
    except OSError as e:
        # Not critical for the pipeline, so warn and continue.
        print(f"[Warning] Could not save image locally: {e}")
        return
    print(f"Image also saved locally to: {local_image_path}")


def _generate_video_from_image(image_bytes: bytes, user_prompt: str) -> str:
    """Upload the image to GCS, animate it with Veo and return the video's gs:// URI.

    Raises:
        TimeoutError: If the Veo operation is not done within _VIDEO_TIMEOUT_SECONDS.
        RuntimeError: If the finished operation reports an error.
        ValueError: If the finished operation carries no video URI.
    """
    print("\n--- Step 2: Uploading generated image to GCS ---")
    image_blob_name = f"images/generated-image-{uuid.uuid4()}.png"
    output_gcs_prefix = f"gs://{GCS_BUCKET_NAME}/videos/"  # Folder for video outputs
    image_gcs_uri = upload_bytes_to_gcs(image_bytes, GCS_BUCKET_NAME, image_blob_name)

    print("\n--- Step 3: Calling Veo to generate video ---")
    video_prompt = user_prompt or _DEFAULT_VIDEO_PROMPT
    print(video_prompt)
    operation = veo_video_client.models.generate_videos(
        model=MODEL_ID_VIDEO,
        prompt=video_prompt,
        image=types.Image(gcs_uri=image_gcs_uri, mime_type="image/png"),
        config=types.GenerateVideosConfig(
            aspect_ratio="16:9",
            output_gcs_uri=output_gcs_prefix,
            duration_seconds=8,
            person_generation="allow_adult",
            enhance_prompt=True,
            generate_audio=True,
        ),
    )

    # WARNING: This is a synchronous poll, which blocks the server thread.
    # For production, consider an asynchronous pattern (e.g. websockets or long polling).
    # time.monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + _VIDEO_TIMEOUT_SECONDS
    while not operation.done:
        if time.monotonic() > deadline:
            raise TimeoutError("Video generation timed out.")
        time.sleep(_VIDEO_POLL_INTERVAL_SECONDS)
        # The operation object is a snapshot; fetch it again to refresh its status.
        operation = veo_video_client.operations.get(operation)
        print(operation)
    print("Video generation operation complete.")

    if operation.error:
        raise RuntimeError(f"Veo operation failed: {operation.error}")
    videos = operation.response.generated_videos if operation.response else None
    if not videos:
        raise ValueError("Veo operation completed but returned no video.")
    video = videos[0].video
    if not video or not video.uri:
        raise ValueError("Veo operation completed but returned no video.")
    return video.uri


def _gcs_uri_to_public_url(gcs_uri: str) -> str:
    """Map ``gs://<bucket>/<object>`` to ``https://storage.googleapis.com/<bucket>/<object>``.

    The URL is only usable by clients if the object is publicly readable (assumed here).

    Raises:
        ValueError: If ``gcs_uri`` is not a ``gs://`` URI.
    """
    if not gcs_uri.startswith("gs://"):
        raise ValueError(f"Unexpected video URI (expected gs://...): {gcs_uri}")
    bucket, _, blob_name = gcs_uri[len("gs://"):].partition("/")
    return f"https://storage.googleapis.com/{bucket}/{blob_name}"


@app.route('/generate', methods=['POST'])
def generate_video_from_sketch():
    """Full pipeline: sketch -> image -> video.

    Expects a JSON object body with ``image_data`` (base64 sketch, bare or as a data
    URL) and an optional ``prompt``. The prompt only steers the video step; the image
    step always uses the fixed sketch-to-photo prompt.

    Returns:
        ``{"generated_video_url": ...}`` on success; ``{"error": ...}`` with status 400
        for a malformed request, or 500 for server-side failures.
    """
    if not all([gemini_image_client, veo_video_client, gcs_client]):
        return jsonify({"error": "A server-side client is not initialized. Check server logs."}), 500

    # silent=True returns None for a missing/non-JSON body instead of raising, so the
    # client gets our JSON 400 rather than Flask's HTML error page.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image_data' not in payload:
        return jsonify({"error": "Missing image_data in request"}), 400
    image_data = payload['image_data']
    if not isinstance(image_data, str):
        return jsonify({"error": "image_data must be a base64-encoded string"}), 400
    raw_prompt = payload.get('prompt')
    user_prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ''

    # --- Step 1: Generate Image with Gemini ---
    try:
        print("--- Step 1: Generating image from sketch with Gemini ---")
        try:
            sketch = _decode_sketch(image_data)
        except (ValueError, OSError, PIL.Image.DecompressionBombError) as e:
            return jsonify({"error": f"Invalid image_data: {e}"}), 400
        generated_image_bytes = _generate_photo_from_sketch(sketch)
        print("Image generated successfully.")
    except Exception as e:
        print(f"Error during Gemini image generation: {e}")
        return jsonify({"error": f"Failed to generate image: {e}"}), 500

    _save_image_locally(generated_image_bytes)

    # --- Step 2 & 3: Upload Image to GCS and Generate Video with Veo ---
    try:
        video_gcs_uri = _generate_video_from_image(generated_image_bytes, user_prompt)
        print(f"Video saved to GCS at: {video_gcs_uri}")
        public_video_url = _gcs_uri_to_public_url(video_gcs_uri)
        print(f"Video generated successfully. Public URL: {public_video_url}")
    except Exception as e:
        print(f"An error occurred during video generation: {e}")
        return jsonify({"error": f"Failed to generate video: {e}"}), 500
    return jsonify({"generated_video_url": public_video_url})
if __name__ == '__main__':
    # Debug is not forced on. Werkzeug's interactive debugger allows arbitrary code
    # execution, and this server binds to all interfaces. Flask's app.run() enables
    # debug mode only when FLASK_DEBUG is set in the environment.
    # PORT may override the default port 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", "5000")))