AbstractPhil committed
Commit e4fdf48 · verified · 1 Parent(s): 5731dbc

Update app.py

Files changed (1)
  1. app.py +245 -16
app.py CHANGED
@@ -20,9 +20,17 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerDiscreteScheduler
 )
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 from huggingface_hub import hf_hub_download
 
+# Import Lyra VAE from geovocab2
+try:
+    from geovocab2.train.model.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig
+    LYRA_AVAILABLE = True
+except ImportError:
+    print("⚠️ Lyra VAE not available - install geovocab2")
+    LYRA_AVAILABLE = False
+
 
 # ============================================================================
 # MODEL LOADING
 
@@ -38,7 +46,10 @@ class FlowMatchingPipeline:
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler,
-        device: str = "cuda"
+        device: str = "cuda",
+        t5_encoder: Optional[T5EncoderModel] = None,
+        t5_tokenizer: Optional[T5Tokenizer] = None,
+        lyra_model: Optional[any] = None
     ):
         self.vae = vae
         self.text_encoder = text_encoder
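All three new parameters default to None, so call sites that stop at device keep working unchanged. A hypothetical construction sketch (the component variables are assumed to be loaded elsewhere; not taken from app.py):

pipe = FlowMatchingPipeline(
    vae=vae,                    # AutoencoderKL
    text_encoder=text_encoder,  # CLIPTextModel
    tokenizer=tokenizer,        # CLIPTokenizer
    unet=unet,                  # UNet2DConditionModel
    scheduler=scheduler,
    device="cuda",
    # t5_encoder, t5_tokenizer and lyra_model default to None,
    # so non-Lyra models build the pipeline exactly as before.
)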
 
@@ -47,6 +58,11 @@ class FlowMatchingPipeline:
         self.scheduler = scheduler
         self.device = device
 
+        # Lyra-specific components
+        self.t5_encoder = t5_encoder
+        self.t5_tokenizer = t5_tokenizer
+        self.lyra_model = lyra_model
+
         # VAE scaling factor
         self.vae_scale_factor = 0.18215
 
@@ -83,6 +99,90 @@ class FlowMatchingPipeline:
 
         return prompt_embeds, negative_prompt_embeds
 
+    def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
+        """Encode text prompts using Lyra VAE (CLIP + T5 fusion)."""
+        if self.lyra_model is None or self.t5_encoder is None:
+            raise ValueError("Lyra VAE components not initialized")
+
+        # Get CLIP embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(self.device)
+
+        with torch.no_grad():
+            clip_embeds = self.text_encoder(text_input_ids)[0]
+
+        # Get T5 embeddings
+        t5_inputs = self.t5_tokenizer(
+            prompt,
+            max_length=77,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        ).to(self.device)
+
+        with torch.no_grad():
+            t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
+
+        # Fuse through Lyra VAE
+        modality_inputs = {
+            'clip': clip_embeds,
+            't5': t5_embeds
+        }
+
+        with torch.no_grad():
+            reconstructions, mu, logvar = self.lyra_model(
+                modality_inputs,
+                target_modalities=['clip']
+            )
+        prompt_embeds = reconstructions['clip']
+
+        # Process negative prompt
+        if negative_prompt:
+            uncond_inputs = self.tokenizer(
+                negative_prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_input_ids = uncond_inputs.input_ids.to(self.device)
+
+            with torch.no_grad():
+                clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
+
+            t5_inputs_uncond = self.t5_tokenizer(
+                negative_prompt,
+                max_length=77,
+                padding='max_length',
+                truncation=True,
+                return_tensors='pt'
+            ).to(self.device)
+
+            with torch.no_grad():
+                t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state
+
+            modality_inputs_uncond = {
+                'clip': clip_embeds_uncond,
+                't5': t5_embeds_uncond
+            }
+
+            with torch.no_grad():
+                reconstructions_uncond, _, _ = self.lyra_model(
+                    modality_inputs_uncond,
+                    target_modalities=['clip']
+                )
+            negative_prompt_embeds = reconstructions_uncond['clip']
+        else:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+        return prompt_embeds, negative_prompt_embeds
+
     @torch.no_grad()
     def __call__(
         self,
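Because encode_prompt_lyra reconstructs the CLIP modality, its outputs are meant to slot in wherever the plain CLIP embeddings were used. A hypothetical usage sketch (assumes pipe is a FlowMatchingPipeline built with the T5 and Lyra components; not taken from app.py):

cond, uncond = pipe.encode_prompt_lyra(
    "a crystalline cathedral at dawn, volumetric light",
    negative_prompt="low quality, blurry",
)
# With the default config shown later (seq_len=77, 768-dim modalities),
# both tensors are expected to be [1, 77, 768], the same shape the UNet
# already receives from encode_prompt().
print(cond.shape, uncond.shape)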
 
@@ -96,6 +196,7 @@ class FlowMatchingPipeline:
         use_flow_matching: bool = True,
         prediction_type: str = "epsilon",
         seed: Optional[int] = None,
+        use_lyra: bool = False,
         progress_callback=None
     ):
         """Generate image using flow matching or standard diffusion."""
 
@@ -106,10 +207,15 @@ class FlowMatchingPipeline:
         else:
             generator = None
 
-        # Encode prompts
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, negative_prompt
-        )
+        # Encode prompts - use Lyra if specified
+        if use_lyra and self.lyra_model is not None:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt, negative_prompt
+            )
 
         # Prepare latents
         latent_channels = 4
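A hypothetical end-to-end call (only arguments visible in this diff are shown; the real __call__ accepts more, and passing the prompt positionally is an assumption):

result = pipe(
    "a futuristic cyberpunk city at night, neon lights",
    negative_prompt="low quality, blurry",
    use_flow_matching=False,   # the UI defaults Lyra to standard diffusion
    prediction_type="epsilon",
    seed=123,
    use_lyra=True,             # routes prompts through encode_prompt_lyra()
)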
 
@@ -257,12 +363,94 @@ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
     return unet.to(device)
 
 
+def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
+    """Load Lyra VAE from HuggingFace."""
+    if not LYRA_AVAILABLE:
+        print("⚠️ Lyra VAE not available - geovocab2 not installed")
+        return None
+
+    print(f"🎵 Loading Lyra VAE from {repo_id}...")
+
+    try:
+        # Download checkpoint
+        checkpoint_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="best_model.pt",
+            repo_type="model"
+        )
+
+        print(f"✓ Downloaded checkpoint: {checkpoint_path}")
+
+        # Load checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+        # Extract config
+        if 'config' in checkpoint:
+            config_dict = checkpoint['config']
+        else:
+            # Use default config
+            config_dict = {
+                'modality_dims': {"clip": 768, "t5": 768},
+                'latent_dim': 768,
+                'seq_len': 77,
+                'encoder_layers': 3,
+                'decoder_layers': 3,
+                'hidden_dim': 1024,
+                'dropout': 0.1,
+                'fusion_strategy': 'cantor',
+                'fusion_heads': 8,
+                'fusion_dropout': 0.1
+            }
+
+        # Create VAE config
+        vae_config = MultiModalVAEConfig(
+            modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
+            latent_dim=config_dict.get('latent_dim', 768),
+            seq_len=config_dict.get('seq_len', 77),
+            encoder_layers=config_dict.get('encoder_layers', 3),
+            decoder_layers=config_dict.get('decoder_layers', 3),
+            hidden_dim=config_dict.get('hidden_dim', 1024),
+            dropout=config_dict.get('dropout', 0.1),
+            fusion_strategy=config_dict.get('fusion_strategy', 'cantor'),
+            fusion_heads=config_dict.get('fusion_heads', 8),
+            fusion_dropout=config_dict.get('fusion_dropout', 0.1)
+        )
+
+        # Create model
+        lyra_model = MultiModalVAE(vae_config)
+
+        # Load weights
+        if 'model_state_dict' in checkpoint:
+            lyra_model.load_state_dict(checkpoint['model_state_dict'])
+        else:
+            lyra_model.load_state_dict(checkpoint)
+
+        lyra_model.to(device)
+        lyra_model.eval()
+
+        # Print info
+        print(f"✅ Lyra VAE loaded successfully")
+        if 'global_step' in checkpoint:
+            print(f"   Training step: {checkpoint['global_step']:,}")
+        if 'best_loss' in checkpoint:
+            print(f"   Best loss: {checkpoint['best_loss']:.4f}")
+        print(f"   Fusion strategy: {vae_config.fusion_strategy}")
+        print(f"   Latent dim: {vae_config.latent_dim}")
+
+        return lyra_model
+
+    except Exception as e:
+        print(f"❌ Failed to load Lyra VAE: {e}")
+        return None
+
+
 def initialize_pipeline(model_choice: str, device: str = "cuda"):
     """Initialize the complete pipeline."""
 
     print(f"🚀 Initializing {model_choice} pipeline...")
 
     is_lune = "Lune" in model_choice
+    is_lyra = "Lyra" in model_choice
 
     # Load base components
     print("Loading VAE...")
460
  torch_dtype=torch.float32
461
  ).to(device)
462
 
463
+ print("Loading CLIP text encoder...")
464
  text_encoder = CLIPTextModel.from_pretrained(
465
  "openai/clip-vit-large-patch14",
466
  torch_dtype=torch.float32
 
470
  "openai/clip-vit-large-patch14"
471
  )
472
 
473
+ # Load T5 and Lyra if needed
474
+ t5_encoder = None
475
+ t5_tokenizer = None
476
+ lyra_model = None
477
+
478
+ if is_lyra:
479
+ print("Loading T5-base encoder...")
480
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
481
+ t5_encoder = T5EncoderModel.from_pretrained(
482
+ "t5-base",
483
+ torch_dtype=torch.float32
484
+ ).to(device)
485
+ t5_encoder.eval()
486
+ print("✓ T5 loaded")
487
+
488
+ print("Loading Lyra VAE...")
489
+ lyra_model = load_lyra_vae(device=device)
490
+ if lyra_model is None:
491
+ raise ValueError("Failed to load Lyra VAE")
492
+
493
  # Load UNet based on model choice
494
  if is_lune:
495
  # Load latest checkpoint from repo
 
498
  filename = "sd15_flow_lune_e34_s34000.pt"
499
  unet = load_lune_checkpoint(repo_id, filename, device)
500
 
501
+ elif is_lyra or model_choice == "SD1.5 Base":
502
+ # Use standard SD1.5 UNet for both Lyra and base
503
  print("Loading SD1.5 base UNet...")
504
  unet = UNet2DConditionModel.from_pretrained(
505
  "runwayml/stable-diffusion-v1-5",
 
524
  tokenizer=tokenizer,
525
  unet=unet,
526
  scheduler=scheduler,
527
+ device=device,
528
+ t5_encoder=t5_encoder,
529
+ t5_tokenizer=t5_tokenizer,
530
+ lyra_model=lyra_model
531
  )
532
 
533
  # Set flag for Lune-specific VAE scaling
 
605
  # Get pipeline
606
  pipeline = get_pipeline(model_choice)
607
 
608
+ # Determine if we should use Lyra encoding
609
+ use_lyra = "Lyra" in model_choice
610
+
611
  # Generate
612
  progress(0.05, desc="Starting generation...")
613
 
 
@@ -407,6 +622,7 @@ def generate_image(
         use_flow_matching=use_flow_matching,
         prediction_type=prediction_type,
         seed=seed,
+        use_lyra=use_lyra,
         progress_callback=progress_callback
     )
 
@@ -432,7 +648,11 @@ def create_demo():
 
     **Geometric crystalline diffusion with flow matching** by [AbstractPhil](https://huggingface.co/AbstractPhil)
 
-    Generate images using SD1.5-based flow matching with pentachoron geometric structures.
+    Generate images using SD1.5-based models with geometric deep learning approaches:
+    - **Flow-Lune**: Flow matching with pentachoron geometric structures
+    - **Lyra-VAE**: Multi-modal fusion (CLIP+T5) via geometric attention
+    - **SD1.5 Base**: Standard baseline for comparison
+
     Achieves high quality with dramatically reduced step counts through geometric efficiency.
     """)
 
@@ -457,6 +677,7 @@ def create_demo():
                 label="Model",
                 choices=[
                     "Flow-Lune (Latest)",
+                    "Lyra-VAE (Geometric Fusion)",
                     "SD1.5 Base"
                 ],
                 value="Flow-Lune (Latest)"
 
@@ -554,11 +775,13 @@ def create_demo():
     - **Shift** controls the flow trajectory (2.0-2.5 recommended for Lune)
     - Lower shift = more direct path, higher shift = more exploration
     - **Lune** uses v_prediction by default for optimal results
+    - **Lyra** fuses CLIP+T5 encoders through geometric VAE for richer embeddings
     - **SD1.5 Base** uses epsilon (standard diffusion)
     - Lune operates in a scaled latent space (5.52x) for geometric efficiency
 
     ### Model Info:
     - **Flow-Lune**: Trained with flow matching on 500k SD1.5 distillation pairs
+    - **Lyra-VAE**: Multi-modal fusion (CLIP+T5) via Cantor geometric attention
     - **SD1.5 Base**: Standard Stable Diffusion 1.5 for comparison
 
     [📚 Learn more about geometric deep learning](https://github.com/AbstractEyes/lattice_vocabulary)
 
@@ -584,14 +807,14 @@ def create_demo():
             [
                 "A futuristic cyberpunk city at night, neon lights, rain-slicked streets, highly detailed",
                 "low quality, blurry",
-                "Flow-Lune (Latest)",
-                22,
-                8.0,
+                "Lyra-VAE (Geometric Fusion)",
+                30,
+                7.5,
                 512,
                 512,
-                2.5,
-                True,
-                "v_prediction",
+                0.0,
+                False,
+                "epsilon",
                 123,
                 False
             ],
 
@@ -631,6 +854,12 @@ def create_demo():
                 use_flow_matching: gr.update(value=False),
                 prediction_type: gr.update(value="epsilon")
             }
+        elif model_name == "Lyra-VAE (Geometric Fusion)":
+            # Lyra: disable flow matching (uses standard diffusion), use epsilon
+            return {
+                use_flow_matching: gr.update(value=False),
+                prediction_type: gr.update(value="epsilon")
+            }
         else:
             # Lune: enable flow matching, use v_prediction
             return {
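Handlers that return a dict keyed by components, like the one above, need those components listed as outputs when the event is wired up. A hypothetical wiring sketch (the handler name on_model_change is assumed; not from app.py):

model_choice.change(
    fn=on_model_change,
    inputs=model_choice,
    outputs=[use_flow_matching, prediction_type],
)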