Update app.py
Browse files
app.py
CHANGED
|
@@ -127,6 +127,11 @@ class FlowMatchingPipeline:
|
|
| 127 |
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
| 128 |
timesteps = self.scheduler.timesteps
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
# Denoising loop
|
| 131 |
for i, t in enumerate(timesteps):
|
| 132 |
if progress_callback:
|
|
@@ -135,15 +140,19 @@ class FlowMatchingPipeline:
|
|
| 135 |
# Expand latents for classifier-free guidance
|
| 136 |
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
|
| 137 |
|
| 138 |
-
#
|
|
|
|
| 139 |
if use_flow_matching and shift > 0:
|
| 140 |
# Compute sigma from timestep with shift
|
| 141 |
sigma = t.float() / 1000.0
|
| 142 |
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
|
| 143 |
|
| 144 |
-
# Scale latent input
|
| 145 |
scaling = torch.sqrt(1 + sigma_shifted ** 2)
|
| 146 |
latent_model_input = latent_model_input / scaling
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# Prepare timestep
|
| 149 |
timestep = t.expand(latent_model_input.shape[0])
|
|
|
|
| 127 |
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
| 128 |
timesteps = self.scheduler.timesteps
|
| 129 |
|
| 130 |
+
# Scale initial latents by scheduler's init_noise_sigma for standard diffusion
|
| 131 |
+
# Flow matching uses unscaled latents and custom ODE integration
|
| 132 |
+
if not use_flow_matching:
|
| 133 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 134 |
+
|
| 135 |
# Denoising loop
|
| 136 |
for i, t in enumerate(timesteps):
|
| 137 |
if progress_callback:
|
|
|
|
| 140 |
# Expand latents for classifier-free guidance
|
| 141 |
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
|
| 142 |
|
| 143 |
+
# For standard diffusion, let scheduler handle scaling
|
| 144 |
+
# For flow matching with shift > 0, apply custom shift-based scaling
|
| 145 |
if use_flow_matching and shift > 0:
|
| 146 |
# Compute sigma from timestep with shift
|
| 147 |
sigma = t.float() / 1000.0
|
| 148 |
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
|
| 149 |
|
| 150 |
+
# Scale latent input for flow matching
|
| 151 |
scaling = torch.sqrt(1 + sigma_shifted ** 2)
|
| 152 |
latent_model_input = latent_model_input / scaling
|
| 153 |
+
else:
|
| 154 |
+
# Otherwise (standard diffusion, or flow matching with shift <= 0), scale by scheduler
|
| 155 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 156 |
|
| 157 |
# Prepare timestep
|
| 158 |
timestep = t.expand(latent_model_input.shape[0])
|