import spaces
import gradio as gr
import os
import random
import numpy as np
import torch
from transformers import pipeline

# Import the custom FLUX tiling pipeline
from flux_pipeline_mod import FluxMoDTilingPipeline

# --- 1. Load Translation Models ---
# These models are small and run efficiently on CPU.
print("Loading translation models...")
try:
    ko_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
    zh_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
    print("Translation models loaded successfully.")
except Exception as e:
    print(f"Could not load translation models: {e}")
    ko_en_translator = None
    zh_en_translator = None

# --- 2. Conditional MMGP Setup ---
USE_MMGP_ENV = os.getenv("USE_MMGP", "true").lower()
USE_MMGP = USE_MMGP_ENV not in ("false", "0", "no", "none")
offload = None
if USE_MMGP:
    print("INFO: Attempting to use MMGP.")
    try:
        from mmgp import offload, profile_type

        print("Successfully imported MMGP.")
    except ImportError:
        print("WARNING: MMGP import failed. Falling back to standard offload.")
        USE_MMGP = False
else:
    print("INFO: MMGP is disabled.")

MAX_SEED = np.iinfo(np.int32).max

# --- 3. Load the Main Pipeline ---
print("Loading the FLUX Tiling pipeline...")
# Use an environment variable for the model path to make it flexible.
MODEL_PATH = os.getenv("MODEL_PATH", "black-forest-labs/FLUX.1-schnell")
print(f"Loading model from: {MODEL_PATH}")
pipe = FluxMoDTilingPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to("cuda")

if USE_MMGP and offload:
    print("Applying LowRAM_LowVRAM offload profile via MMGP...")
    offload.profile(pipe, profile_type.LowRAM_LowVRAM)
else:
    print("Attempting to use the standard Diffusers CPU offload...")
    try:
        pipe.enable_model_cpu_offload()
    except Exception as e:
        print(f"Could not apply standard offload: {e}")

print("Pipeline loaded and ready.")


# Helper Functions
def translate_prompt(text: str, language: str) -> str:
    """Translates text to English if the selected language is not English."""
    if language == "English" or not text.strip():
        return text
    translated_text = text
    if language == "Korean" and ko_en_translator:
        # Translate only if Korean (Hangul) characters are actually present.
        if any("\uac00" <= char <= "\ud7a3" for char in text):
            print(f"Translating Korean to English: '{text}'")
            translated_text = ko_en_translator(text)[0]["translation_text"]
            print(f" -> Translated: '{translated_text}'")
    elif language == "Chinese" and zh_en_translator:
        # Translate only if Chinese (CJK) characters are actually present.
        if any("\u4e00" <= char <= "\u9fff" for char in text):
            print(f"Translating Chinese to English: '{text}'")
            translated_text = zh_en_translator(text)[0]["translation_text"]
            print(f" -> Translated: '{translated_text}'")
    return translated_text


def create_hdr_effect(image, hdr_strength):
    """Applies an unsharp-mask style detail boost plus a saturation boost."""
    if hdr_strength == 0:
        return image
    from PIL import Image, ImageEnhance

    if isinstance(image, Image.Image):
        image = np.array(image)
    from scipy.ndimage import gaussian_filter

    # Work in float to avoid uint8 wrap-around in the (image - blurred) term,
    # and blur each channel independently for RGB inputs.
    image = image.astype(np.float32)
    sigma = (5, 5, 0) if image.ndim == 3 else 5
    blurred = gaussian_filter(image, sigma=sigma)
    sharpened = np.clip(image + hdr_strength * (image - blurred), 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(sharpened)
    converter = ImageEnhance.Color(pil_img)
    return converter.enhance(1 + hdr_strength)
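
# Illustrative usage sketch for create_hdr_effect (comments only, not executed):
# it is a plain NumPy/PIL post-process, so it can be tried on any RGB image
# outside the app. The file name below is hypothetical.
#
#   from PIL import Image
#   img = Image.open("panorama_sample.png")   # hypothetical input file
#   out = create_hdr_effect(img, 0.5)         # same value range as the HDR slider (assumed 0-1)
#   out.save("panorama_sample_hdr.png")
#
# With hdr_strength == 0 the input image is returned unchanged.
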
@spaces.GPU(duration=120)
def generate_flux_panorama(
    left_prompt,
    center_prompt,
    right_prompt,
    left_gs,
    center_gs,
    right_gs,
    overlap_pixels,
    steps,
    generation_seed,
    tile_weighting_method,
    prompt_language,
    target_height,
    target_width,
    hdr,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a panoramic image using the FLUX model with tiling and composition.

    Args:
        left_prompt (str): Text prompt for the left section of the panorama.
        center_prompt (str): Text prompt for the center section of the panorama.
        right_prompt (str): Text prompt for the right section of the panorama.
        left_gs (float): Guidance scale for the left tile.
        center_gs (float): Guidance scale for the center tile.
        right_gs (float): Guidance scale for the right tile.
        overlap_pixels (int): Number of pixels to overlap between adjacent tiles.
        steps (int): Number of inference steps for generation.
        generation_seed (int): Random seed for reproducibility.
        tile_weighting_method (str): Method for weighting overlapping tile regions.
        prompt_language (str): Prompt language ("English", "Korean", or "Chinese");
            non-English prompts are translated to English before generation.
        target_height (int): Height of the generated panorama in pixels.
        target_width (int): Width of the generated panorama in pixels.
        hdr (float): HDR effect intensity (0 disables the effect).
        randomize_seed (bool): Accepted for interface compatibility; not used here.
        progress (gr.Progress): Gradio progress tracker.

    Returns:
        PIL.Image: The generated panoramic image with the optional HDR effect applied.
    """
    if not left_prompt or not center_prompt or not right_prompt:
        gr.Info("⚡️ All three prompts must be provided!")
        return gr.skip()

    global pipe
    generator = torch.Generator("cuda").manual_seed(generation_seed)
    final_height, final_width = int(target_height), int(target_width)

    # Translate prompts if necessary.
    translated_left = translate_prompt(left_prompt, prompt_language)
    translated_center = translate_prompt(center_prompt, prompt_language)
    translated_right = translate_prompt(right_prompt, prompt_language)

    print("Starting generation with Tiling Pipeline (Composition Mode)...")
    image = pipe(
        prompt=[[translated_left, translated_center, translated_right]],
        height=final_height,
        width=final_width,
        tile_overlap=overlap_pixels,
        guidance_scale_tiles=[[left_gs, center_gs, right_gs]],
        tile_weighting_method=tile_weighting_method,
        generator=generator,
        num_inference_steps=steps,
        max_sequence_length=512,
    ).images[0]
    return create_hdr_effect(image, hdr)
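
# Worked example for calculate_tile_size below (illustrative numbers; the function
# hard-codes a 3x1 tile grid): with target_width=3072 and overlap_pixels=256,
#   tile_width  = (3072 + 2 * 256) // 3 = 1194 -> 1184 after rounding down to a multiple of 16
#   final_width = 1184 * 3 - 2 * 256   = 3040
# so the final panorama can come out slightly smaller than the requested size.
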
def calculate_tile_size(target_height, target_width, overlap_pixels):
    """
    Calculate tile dimensions for panoramic image generation.

    Args:
        target_height (int): The target height of the final panoramic image in pixels.
        target_width (int): The target width of the final panoramic image in pixels.
        overlap_pixels (int): The number of overlapping pixels between adjacent tiles.

    Returns:
        tuple: A tuple of 2 gr.update objects containing:
            - final_height: Final panorama height after tiling.
            - final_width: Final panorama width after tiling.
    """
    num_cols = 3
    num_rows = 1
    tile_width = (target_width + (num_cols - 1) * overlap_pixels) // num_cols
    tile_height = (target_height + (num_rows - 1) * overlap_pixels) // num_rows
    # Snap tile sizes down to multiples of 16 so they align with the latent grid.
    tile_width -= tile_width % 16
    tile_height -= tile_height % 16
    final_width = tile_width * num_cols - (num_cols - 1) * overlap_pixels
    final_height = tile_height * num_rows - (num_rows - 1) * overlap_pixels
    return (
        gr.update(value=final_height),
        gr.update(value=final_width),
    )


def clear_result():
    return gr.update(value=None)


def run_for_examples(
    left_prompt,
    center_prompt,
    right_prompt,
    left_gs,
    center_gs,
    right_gs,
    overlap_pixels,
    steps,
    generation_seed,
    tile_weighting_method,
    prompt_language,
    target_height,
    target_width,
    hdr,
    randomize_seed,
):
    return generate_flux_panorama(
        left_prompt,
        center_prompt,
        right_prompt,
        left_gs,
        center_gs,
        right_gs,
        overlap_pixels,
        steps,
        generation_seed,
        tile_weighting_method,
        prompt_language,
        target_height,
        target_width,
        hdr,
        randomize_seed,
    )


def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        generation_seed = random.randint(0, MAX_SEED)
    return generation_seed


# UI Layout
theme = gr.themes.Default(
    primary_hue="blue", secondary_hue="teal", neutral_hue="neutral"
).set(
    body_background_fill="*neutral_100",
    body_background_fill_dark="*neutral_900",
    body_text_color="*neutral_700",
    body_text_color_dark="*neutral_200",
    body_text_weight="400",
    link_text_color="*primary_500",
    link_text_color_dark="*primary_400",
    code_background_fill="*neutral_100",
    code_background_fill_dark="*neutral_800",
    shadow_drop="0 1px 3px rgba(0,0,0,0.1)",
    shadow_inset="inset 0 2px 4px rgba(0,0,0,0.05)",
    block_background_fill="*neutral_50",
    block_background_fill_dark="*neutral_700",
    block_border_color="*neutral_200",
    block_border_color_dark="*neutral_600",
    block_border_width="1px",
    block_border_width_dark="1px",
    block_label_background_fill="*primary_50",
    block_label_background_fill_dark="*primary_600",
    block_label_text_color="*primary_600",
    block_label_text_color_dark="*primary_50",
    panel_background_fill="white",
    panel_background_fill_dark="*neutral_800",
    panel_border_color="*neutral_200",
    panel_border_color_dark="*neutral_700",
    panel_border_width="1px",
    panel_border_width_dark="1px",
    input_background_fill="white",
    input_background_fill_dark="*neutral_800",
    input_border_color="*neutral_300",
    input_border_color_dark="*neutral_700",
    slider_color="*primary_500",
    slider_color_dark="*primary_400",
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_dark="*primary_500",
    button_primary_background_fill_hover="*primary_700",
    button_primary_background_fill_hover_dark="*primary_400",
    button_primary_border_color="transparent",
    button_primary_border_color_dark="transparent",
    button_primary_text_color="white",
    button_primary_text_color_dark="white",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_dark="*neutral_600",
    button_secondary_background_fill_hover="*neutral_300",
    button_secondary_background_fill_hover_dark="*neutral_500",
    button_secondary_border_color="transparent",
    button_secondary_border_color_dark="transparent",
    button_secondary_text_color="*neutral_700",
    button_secondary_text_color_dark="*neutral_200",
)

css_code = ""
try:
    with open("./style.css", "r", encoding="utf-8") as f:
        css_code += f.read() + "\n"
except FileNotFoundError:
    pass

title = """