| |
|
| |
|
| | import torch
|
| | from diffusers import FluxPipeline, DPMSolverMultistepScheduler
|
| | from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
|
| | from functools import partial
|
| | from BeamDiffusionModel.models.diffusionModel.Latents_Singleton import Latents
|
| |
|
class Flux:
    """Wrapper around a diffusers ``FluxPipeline`` configured from ``CONFIG["flux"]``.

    ``generate_image`` returns the generated image together with the latents
    recorded by a per-step callback during denoising.
    """

    # Accepted values for CONFIG["flux"]["precision"]. Anything else (including a
    # missing key) falls back to bfloat16, the dtype this class has always loaded
    # the model in.
    _DTYPES = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    def __init__(self):
        """Load the pipeline named by ``CONFIG["flux"]["id"]`` and apply memory options.

        Raises:
            KeyError: if ``CONFIG`` has no ``flux`` section or no ``flux.id`` entry.
        """
        flux_cfg = CONFIG.get("flux", {})
        self.device = "cuda" if flux_cfg.get("use_cuda", True) and torch.cuda.is_available() else "cpu"
        # The dtype actually used to load the pipeline (previously computed but ignored).
        self.torch_dtype = self._resolve_dtype(flux_cfg.get("precision"))

        print(f"Loading model: {CONFIG['flux']['id']} on {self.device}")

        self.pipe = FluxPipeline.from_pretrained(CONFIG["flux"]["id"], torch_dtype=self.torch_dtype)
        if self.device == "cuda":
            # Sequential CPU offload keeps weights on the CPU and moves submodules to
            # the accelerator on demand; it requires an accelerator to be present.
            self.pipe.enable_sequential_cpu_offload()
        else:
            # No accelerator (or CUDA disabled in config): run the pipeline on CPU.
            self.pipe.to(self.device)
        # VAE slicing/tiling lower peak memory when decoding latents to pixels.
        self.pipe.vae.enable_slicing()
        self.pipe.vae.enable_tiling()
        # NOTE(review): presumably chosen so over-long prompts keep their ending
        # rather than their beginning after truncation — confirm intent.
        self.pipe.tokenizer.truncation_side = 'left'

        print("Model loaded successfully!")

    @staticmethod
    def _resolve_dtype(precision):
        """Return the torch dtype for a precision config string (bfloat16 by default)."""
        return Flux._DTYPES.get(precision, torch.bfloat16)

    def capture_latents(self, latents_store: "Latents", pipe, step, timestep, callback_kwargs):
        """``callback_on_step_end`` hook: record the current step's latents.

        Args:
            latents_store: object whose ``add_latents`` receives the tensor.
            pipe, step, timestep: supplied by the pipeline; unused here.
            callback_kwargs: dict containing at least ``"latents"`` (requested via
                ``callback_on_step_end_tensor_inputs`` in ``generate_image``).

        Returns:
            ``callback_kwargs`` unchanged, so denoising continues with the same tensors.
        """
        latents_store.add_latents(callback_kwargs["latents"])
        return callback_kwargs

    def generate_image(self, prompt: str, latent=None, generator=None, *,
                       height: int = 768, width: int = 768,
                       guidance_scale: float = 3.5,
                       max_sequence_length: int = 512,
                       num_inference_steps=None):
        """Generate one image for ``prompt`` and return it with its per-step latents.

        Args:
            prompt: text prompt for the pipeline.
            latent: optional starting latents, forwarded as ``latents=``.
            generator: optional random generator forwarded to the pipeline.
            height, width: output size in pixels.
            guidance_scale: guidance strength forwarded to the pipeline.
            max_sequence_length: prompt token limit forwarded to the pipeline.
            num_inference_steps: denoising steps; defaults to
                ``CONFIG["flux"]["diffusion_settings"]["steps"]``.

        Returns:
            ``(image, captured)`` — the first image of the pipeline output and the
            result of ``Latents.dump_and_clear()`` for the latents recorded per step.
        """
        if num_inference_steps is None:
            num_inference_steps = CONFIG["flux"]["diffusion_settings"]["steps"]

        latents = Latents()
        callback = partial(self.capture_latents, latents)
        try:
            output = self.pipe(
                prompt,
                latents=latent,
                callback_on_step_end=callback,
                generator=generator,
                callback_on_step_end_tensor_inputs=["latents"],
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                max_sequence_length=max_sequence_length,
                num_inference_steps=num_inference_steps,
            )
        except Exception:
            # Drop partially captured latents so they cannot leak into the next call
            # (the store comes from a "Latents_Singleton" module, so it may be shared).
            latents.dump_and_clear()
            raise
        return output.images[0], latents.dump_and_clear()