AbstractPhil commited on
Commit
5731dbc
·
verified ·
1 Parent(s): 1893c89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -127,6 +127,11 @@ class FlowMatchingPipeline:
127
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
128
  timesteps = self.scheduler.timesteps
129
 
 
 
 
 
 
130
  # Denoising loop
131
  for i, t in enumerate(timesteps):
132
  if progress_callback:
@@ -135,15 +140,19 @@ class FlowMatchingPipeline:
135
  # Expand latents for classifier-free guidance
136
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
137
 
138
- # Apply shift for flow matching
 
139
  if use_flow_matching and shift > 0:
140
  # Compute sigma from timestep with shift
141
  sigma = t.float() / 1000.0
142
  sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
143
 
144
- # Scale latent input
145
  scaling = torch.sqrt(1 + sigma_shifted ** 2)
146
  latent_model_input = latent_model_input / scaling
 
 
 
147
 
148
  # Prepare timestep
149
  timestep = t.expand(latent_model_input.shape[0])
 
127
  self.scheduler.set_timesteps(num_inference_steps, device=self.device)
128
  timesteps = self.scheduler.timesteps
129
 
130
+ # Scale initial latents by scheduler's init_noise_sigma for standard diffusion
131
+ # Flow matching uses unscaled latents and custom ODE integration
132
+ if not use_flow_matching:
133
+ latents = latents * self.scheduler.init_noise_sigma
134
+
135
  # Denoising loop
136
  for i, t in enumerate(timesteps):
137
  if progress_callback:
 
140
  # Expand latents for classifier-free guidance
141
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
142
 
143
+ # For standard diffusion, let scheduler handle scaling
144
+ # For flow matching, apply custom shift-based scaling
145
  if use_flow_matching and shift > 0:
146
  # Compute sigma from timestep with shift
147
  sigma = t.float() / 1000.0
148
  sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
149
 
150
+ # Scale latent input for flow matching
151
  scaling = torch.sqrt(1 + sigma_shifted ** 2)
152
  latent_model_input = latent_model_input / scaling
153
+ else:
154
+ # For standard diffusion, scale by scheduler
155
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
156
 
157
  # Prepare timestep
158
  timestep = t.expand(latent_model_input.shape[0])