AbstractPhil committed on
Commit
1893c89
·
verified ·
1 Parent(s): ad329bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -11
app.py CHANGED
@@ -185,8 +185,13 @@ class FlowMatchingPipeline:
185
  noise_pred, t, latents, return_dict=False
186
  )[0]
187
 
188
- # Decode latents
189
- latents = (latents / self.vae_scale_factor) * 5.52
 
 
 
 
 
190
 
191
  with torch.no_grad():
192
  image = self.vae.decode(latents).sample
@@ -248,6 +253,8 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
248
 
249
  print(f"🚀 Initializing {model_choice} pipeline...")
250
 
 
 
251
  # Load base components
252
  print("Loading VAE...")
253
  vae = AutoencoderKL.from_pretrained(
@@ -267,7 +274,7 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
267
  )
268
 
269
  # Load UNet based on model choice
270
- if model_choice == "Flow-Lune (Latest)":
271
  # Load latest checkpoint from repo
272
  repo_id = "AbstractPhil/sd15-flow-lune"
273
  # Find latest checkpoint - for now use a known one
@@ -293,7 +300,7 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
293
 
294
  print("✅ Pipeline initialized!")
295
 
296
- return FlowMatchingPipeline(
297
  vae=vae,
298
  text_encoder=text_encoder,
299
  tokenizer=tokenizer,
@@ -301,6 +308,11 @@ def initialize_pipeline(model_choice: str, device: str = "cuda"):
301
  scheduler=scheduler,
302
  device=device
303
  )
 
 
 
 
 
304
 
305
 
306
  # ============================================================================
@@ -417,16 +429,17 @@ def create_demo():
417
 
418
  with gr.Row():
419
  with gr.Column(scale=1):
420
- # Prompt
421
  prompt = gr.TextArea(
422
  label="Prompt",
423
- placeholder="A beautiful landscape with mountains and a lake at sunset...",
424
  lines=3
425
  )
426
 
427
  negative_prompt = gr.TextArea(
428
  label="Negative Prompt",
429
  placeholder="blurry, low quality, distorted...",
 
430
  lines=2
431
  )
432
 
@@ -460,7 +473,7 @@ def create_demo():
460
  prediction_type = gr.Radio(
461
  label="Prediction Type",
462
  choices=["epsilon", "v_prediction"],
463
- value="epsilon",
464
  info="Type of model prediction"
465
  )
466
 
@@ -531,7 +544,9 @@ def create_demo():
531
  - **Flow matching** works best with 15-25 steps (vs 50+ for standard diffusion)
532
  - **Shift** controls the flow trajectory (2.0-2.5 recommended for Lune)
533
  - Lower shift = more direct path, higher shift = more exploration
534
- - Try **v_prediction** mode if epsilon gives unstable results
 
 
535
 
536
  ### Model Info:
537
  - **Flow-Lune**: Trained with flow matching on 500k SD1.5 distillation pairs
@@ -553,7 +568,7 @@ def create_demo():
553
  512,
554
  2.5,
555
  True,
556
- "epsilon",
557
  42,
558
  False
559
  ],
@@ -567,7 +582,7 @@ def create_demo():
567
  512,
568
  2.5,
569
  True,
570
- "epsilon",
571
  123,
572
  False
573
  ],
@@ -581,7 +596,7 @@ def create_demo():
581
  512,
582
  2.0,
583
  True,
584
- "epsilon",
585
  456,
586
  False
587
  ]
@@ -597,6 +612,29 @@ def create_demo():
597
  )
598
 
599
  # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  generate_btn.click(
601
  fn=generate_image,
602
  inputs=[
 
185
  noise_pred, t, latents, return_dict=False
186
  )[0]
187
 
188
+ # Decode latents with model-specific scaling
189
+ latents = latents / self.vae_scale_factor
190
+
191
+ # Lune-specific scaling: multiply by 5.52, the scale factor of Lune's latent space
192
+ # This must be applied ONLY for the Lune model, not for SD1.5 Base
193
+ if hasattr(self, 'is_lune_model') and self.is_lune_model:
194
+ latents = latents * 5.52
195
 
196
  with torch.no_grad():
197
  image = self.vae.decode(latents).sample
 
253
 
254
  print(f"🚀 Initializing {model_choice} pipeline...")
255
 
256
+ is_lune = "Lune" in model_choice
257
+
258
  # Load base components
259
  print("Loading VAE...")
260
  vae = AutoencoderKL.from_pretrained(
 
274
  )
275
 
276
  # Load UNet based on model choice
277
+ if is_lune:
278
  # Load latest checkpoint from repo
279
  repo_id = "AbstractPhil/sd15-flow-lune"
280
  # Find latest checkpoint - for now use a known one
 
300
 
301
  print("✅ Pipeline initialized!")
302
 
303
+ pipeline = FlowMatchingPipeline(
304
  vae=vae,
305
  text_encoder=text_encoder,
306
  tokenizer=tokenizer,
 
308
  scheduler=scheduler,
309
  device=device
310
  )
311
+
312
+ # Set flag for Lune-specific VAE scaling
313
+ pipeline.is_lune_model = is_lune
314
+
315
+ return pipeline
316
 
317
 
318
  # ============================================================================
 
429
 
430
  with gr.Row():
431
  with gr.Column(scale=1):
432
+ # Prompt - default to first example
433
  prompt = gr.TextArea(
434
  label="Prompt",
435
+ value="A serene mountain landscape at golden hour, crystal clear lake reflecting snow-capped peaks, photorealistic, 8k",
436
  lines=3
437
  )
438
 
439
  negative_prompt = gr.TextArea(
440
  label="Negative Prompt",
441
  placeholder="blurry, low quality, distorted...",
442
+ value="blurry, low quality",
443
  lines=2
444
  )
445
 
 
473
  prediction_type = gr.Radio(
474
  label="Prediction Type",
475
  choices=["epsilon", "v_prediction"],
476
+ value="v_prediction", # Default to v_prediction for Lune
477
  info="Type of model prediction"
478
  )
479
 
 
544
  - **Flow matching** works best with 15-25 steps (vs 50+ for standard diffusion)
545
  - **Shift** controls the flow trajectory (2.0-2.5 recommended for Lune)
546
  - Lower shift = more direct path, higher shift = more exploration
547
+ - **Lune** uses v_prediction by default for optimal results
548
+ - **SD1.5 Base** uses epsilon (standard diffusion)
549
+ - Lune operates in a scaled latent space (5.52x) for geometric efficiency
550
 
551
  ### Model Info:
552
  - **Flow-Lune**: Trained with flow matching on 500k SD1.5 distillation pairs
 
568
  512,
569
  2.5,
570
  True,
571
+ "v_prediction",
572
  42,
573
  False
574
  ],
 
582
  512,
583
  2.5,
584
  True,
585
+ "v_prediction",
586
  123,
587
  False
588
  ],
 
596
  512,
597
  2.0,
598
  True,
599
+ "v_prediction",
600
  456,
601
  False
602
  ]
 
612
  )
613
 
614
  # Event handlers
615
+
616
+ # Update settings when model changes
617
+ def on_model_change(model_name):
618
+ """Update default settings based on model selection."""
619
+ if model_name == "SD1.5 Base":
620
+ # SD1.5: disable flow matching, use epsilon
621
+ return {
622
+ use_flow_matching: gr.update(value=False),
623
+ prediction_type: gr.update(value="epsilon")
624
+ }
625
+ else:
626
+ # Lune: enable flow matching, use v_prediction
627
+ return {
628
+ use_flow_matching: gr.update(value=True),
629
+ prediction_type: gr.update(value="v_prediction")
630
+ }
631
+
632
+ model_choice.change(
633
+ fn=on_model_change,
634
+ inputs=[model_choice],
635
+ outputs=[use_flow_matching, prediction_type]
636
+ )
637
+
638
  generate_btn.click(
639
  fn=generate_image,
640
  inputs=[