DZRobo committed on
Commit e69f3b7 · 1 Parent(s): 7af46cf

Add Z_image support and improve latent/channel handling


Adds functions to harmonize latent channel counts and condition token lengths to prevent mismatches, especially for models like FLUX/Z_image. Enhances error reporting with debug output and traceback printing. Updates mg_combinode to better validate VAE/CLIP presence for checkpoint and input selection. Fixes hybrid sigma schedule alignment in mg_zesmart_sampler_v1_1.
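As a quick aside on the scaling used by the new _match_latent_channels helper below: when a latent is replicated from, say, 4 to 16 channels, dividing by sqrt(rep) keeps the total squared magnitude across channels unchanged. A minimal standalone sketch with hypothetical shapes (not part of the commit):

import torch

# Hypothetical 4-channel latent adapted to a 16-channel model (rep = 4).
z = torch.randn(1, 4, 64, 64)
rep = 4
z_fixed = z.repeat(1, rep, 1, 1) / rep ** 0.5  # replicate channels, then rescale

# The 1/sqrt(rep) factor keeps the overall energy (sum of squares) unchanged.
print(z.pow(2).sum().item(), z_fixed.pow(2).sum().item())  # equal up to float error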

mod/easy/mg_cade25_easy.py CHANGED
@@ -7,6 +7,7 @@ import torch
 import os
 import numpy as np
 import torch.nn.functional as F
+import traceback
 
 import nodes
 import comfy.model_management as model_management
@@ -1115,6 +1116,133 @@ def safe_decode(vae, lat, tile=512, ovlp=128, to_fp32: bool = False):
     return out
 
 
+def _match_latent_channels(vae, latent: dict, model=None):
+    """Align latent channel count to model/VAE expectations (e.g., FLUX/Z_image 16ch) with variance preservation."""
+    if not isinstance(latent, dict) or ("samples" not in latent):
+        return latent
+    z = latent.get("samples", None)
+    if z is None:
+        return latent
+    try:
+        target_c = None
+        # Prefer model latent_format if available (more reliable than VAE decoder)
+        if model is not None:
+            try:
+                lf = model.get_model_object("latent_format")
+                target_c = int(getattr(lf, "latent_channels", None) or 0) or None
+            except Exception:
+                target_c = None
+        fs = getattr(vae, "first_stage_model", None)
+        dec = getattr(fs, "decoder", None)
+        if dec is not None and hasattr(dec, "conv_in"):
+            target_c = target_c or int(dec.conv_in.in_channels)
+        if target_c is None and hasattr(fs, "latent_channels"):
+            target_c = int(getattr(fs, "latent_channels"))
+        if target_c is None and hasattr(vae, "latent_channels"):
+            target_c = int(getattr(vae, "latent_channels"))
+        if target_c is None:
+            return latent
+        cur_c = int(z.shape[1])
+        if cur_c == target_c:
+            return latent
+        # Repeat channels when divisible (common case: 4 -> 16)
+        if target_c % cur_c == 0 and cur_c > 0:
+            rep = target_c // cur_c
+            reps = [1, rep] + [1] * (z.ndim - 2)
+            z_fixed = z.repeat(*reps)
+            # Preserve variance after channel replication
+            z_fixed = z_fixed / (rep ** 0.5)
+        else:
+            # Fallback: pad zeros or slice to match
+            if target_c > cur_c:
+                pad = target_c - cur_c
+                pad_tensor = torch.zeros_like(z[:, :1, ...]).repeat(1, pad, *([1] * (z.ndim - 2)))
+                z_fixed = torch.cat([z, pad_tensor], dim=1)
+            else:
+                z_fixed = z[:, :target_c, ...]
+        latent = {**latent, "samples": z_fixed}
+    except Exception:
+        pass
+    return latent
+
+
+def _harmonize_cond_tokens(cond_list):
+    """Pad/truncate cond tokens + masks to a common length to avoid mismatches (e.g., 499 vs 528 or 981 vs 1286)."""
+    if not isinstance(cond_list, list):
+        return cond_list
+    # pass 1: find max token length across cross_attn
+    max_len = 0
+    for c in cond_list:
+        if isinstance(c, dict):
+            ca = c.get("cross_attn", None)
+            if ca is not None:
+                try:
+                    max_len = max(max_len, int(ca.shape[1]))
+                except Exception:
+                    pass
+    if max_len <= 0:
+        return cond_list
+    fixed = []
+    for c in cond_list:
+        if not isinstance(c, dict):
+            fixed.append(c)
+            continue
+        d = c.copy()
+        ca = d.get("cross_attn", None)
+        am = d.get("attention_mask", None)
+        # Harmonize cross_attn length
+        if ca is not None:
+            try:
+                ca_len = int(ca.shape[1])
+                if ca_len < max_len:
+                    pad_shape = list(ca.shape)
+                    pad_shape[1] = max_len - ca_len
+                    ca_pad = torch.zeros(pad_shape, device=ca.device, dtype=ca.dtype)
+                    ca = torch.cat([ca, ca_pad], dim=1)
+                elif ca_len > max_len:
+                    ca = ca[:, :max_len, ...]
+                d["cross_attn"] = ca
+            except Exception:
+                pass
+        # Harmonize mask length to cross_attn length
+        if ca is not None:
+            ca_len = int(ca.shape[1])
+            if am is None:
+                am = torch.ones((ca.shape[0], ca_len), device=ca.device, dtype=ca.dtype)
+            try:
+                am_len = int(am.shape[-1] if am.dim() == 2 else am.shape[1])
+                if am_len < ca_len:
+                    pad = ca_len - am_len
+                    pad_shape = list(am.shape)
+                    pad_shape[-1] = pad
+                    pad_tensor = torch.zeros(pad_shape, device=am.device, dtype=am.dtype)
+                    am = torch.cat([am, pad_tensor], dim=-1)
+                elif am_len > ca_len:
+                    am = am[..., :ca_len]
+                d["attention_mask"] = am
+                try:
+                    d["num_tokens"] = int(torch.count_nonzero(am, dim=-1).max().item())
+                except Exception:
+                    d["num_tokens"] = ca_len
+            except Exception:
+                pass
+        fixed.append(d)
+    return fixed
+
+
+def _summarize_conds(label, conds):
+    out = []
+    if isinstance(conds, list):
+        for idx, c in enumerate(conds):
+            try:
+                ca = c.get("cross_attn", None) if isinstance(c, dict) else None
+                am = c.get("attention_mask", None) if isinstance(c, dict) else None
+                out.append(f"{label}[{idx}]: ca={None if ca is None else list(ca.shape)}, am={None if am is None else list(am.shape)}")
+            except Exception:
+                pass
+    return "; ".join(out)
+
+
 def safe_encode(vae, img, tile=512, ovlp=64):
     import math, torch.nn.functional as F
     h, w = img.shape[1:3]
@@ -2309,6 +2437,13 @@ class ComfyAdaptiveDetailEnhancer25:
         except Exception:
             pass
 
+        # Align latent channels to VAE/model (e.g., Z_image/FLUX use 16ch latents)
+        latent = _match_latent_channels(vae, latent, model)
+
+        # Harmonize cond token lengths to prevent rare MGHybrid size mismatches
+        positive = _harmonize_cond_tokens(positive)
+        negative = _harmonize_cond_tokens(negative)
+
         image = safe_decode(vae, latent, to_fp32=bool(vae_decode_fp32))
         # allow user cancel right after initial decode
         model_management.throw_exception_if_processing_interrupted()
@@ -2830,6 +2965,7 @@ class ComfyAdaptiveDetailEnhancer25:
             )
             # Prepare latent + noise like in MG_ZeSmartSampler
            lat_img = current_latent["samples"]
+           lat_img = _match_latent_channels(vae, {"samples": lat_img}, sampler_model)["samples"]
            lat_img = _sample.fix_empty_latent_channels(sampler_model, lat_img)
            batch_inds = current_latent.get("batch_index", None)
            noise = _sample.prepare_noise(lat_img, int(iter_seed), batch_inds)
@@ -2848,6 +2984,16 @@ class ComfyAdaptiveDetailEnhancer25:
                current_latent = {**current_latent}
                current_latent["samples"] = samples
            except Exception as e:
+               try:
+                   print(f"[CADE2.5][MGHybrid][debug] sigmas={list(sigmas.shape)} lat={list(current_latent['samples'].shape)}")
+                   print(_summarize_conds("pos", positive))
+                   print(_summarize_conds("neg", negative))
+               except Exception:
+                   pass
+               try:
+                   traceback.print_exc()
+               except Exception:
+                   pass
                # Before any fallback, propagate user cancel if set
                try:
                    model_management.throw_exception_if_processing_interrupted()
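For readers unfamiliar with the conditioning layout, the padding that _harmonize_cond_tokens applies can be pictured in isolation. A standalone sketch with made-up shapes (the real helper operates on ComfyUI conditioning dicts carrying cross_attn and attention_mask entries):

import torch

# Made-up token counts, e.g. 499 vs 528 embeddings from two conditioning entries.
pos = torch.randn(1, 499, 4096)
neg = torch.randn(1, 528, 4096)
max_len = max(pos.shape[1], neg.shape[1])

def pad_or_trim(t, length):
    # Zero-pad the token dimension up to `length`, or truncate if it is longer.
    if t.shape[1] < length:
        pad = torch.zeros(t.shape[0], length - t.shape[1], t.shape[2],
                          dtype=t.dtype, device=t.device)
        return torch.cat([t, pad], dim=1)
    return t[:, :length, :]

pos, neg = pad_or_trim(pos, max_len), pad_or_trim(neg, max_len)
assert pos.shape[1] == neg.shape[1] == 528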
mod/hard/mg_cade25.py CHANGED
@@ -11,6 +11,7 @@ import torch
 import os
 import numpy as np
 import torch.nn.functional as F
+import traceback
 
 import nodes
 import comfy.model_management as model_management
mod/hard/mg_zesmart_sampler_v1_1.py CHANGED
@@ -33,7 +33,15 @@ def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
     sig_k = _samplers.calculate_sigmas(ms, "karras", steps)
     sig_b = _samplers.calculate_sigmas(ms, "beta", steps)
 
+    def _align_len(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Align two sigma schedules to the same length (use tail of longer)."""
+        if a.shape[0] == b.shape[0]:
+            return a, b
+        m = min(a.shape[0], b.shape[0])
+        return a[-m:], b[-m:]
+
     mode = str(mode).lower()
+    sig_k, sig_b = _align_len(sig_k, sig_b)
     if mode == "karras":
         sig = sig_k
     elif mode == "beta":
@@ -54,6 +62,7 @@ def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
     new_steps = max(1, int(steps / max(1e-6, float(denoise))))
     sk = _samplers.calculate_sigmas(ms, "karras", new_steps)
     sb = _samplers.calculate_sigmas(ms, "beta", new_steps)
+    sk, sb = _align_len(sk, sb)
     if mode == "karras":
         sig_full = sk
     elif mode == "beta":
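A toy check of the tail alignment that _align_len performs; torch.linspace stands in for the real karras/beta schedules, which can come back one element apart:

import torch

sig_k = torch.linspace(14.6, 0.0, 21)  # stand-in for the karras schedule, 21 values
sig_b = torch.linspace(14.6, 0.0, 20)  # stand-in for the beta schedule, 20 values
m = min(sig_k.shape[0], sig_b.shape[0])
sig_k, sig_b = sig_k[-m:], sig_b[-m:]  # keep the tails so both still end at sigma 0
assert sig_k.shape == sig_b.shape == torch.Size([20])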
mod/mg_combinode.py CHANGED
@@ -275,13 +275,30 @@ class MagicNodesCombiNode:
         pos_text_expanded = _norm_prompt(_expand_dynamic(positive_prompt, int(dyn_seed), bool(dynamic_break_freeze)) if bool(dynamic_pos) else positive_prompt)
         neg_text_expanded = _norm_prompt(_expand_dynamic(negative_prompt, int(dyn_seed), bool(dynamic_break_freeze)) if bool(dynamic_neg) else negative_prompt)
 
+        def _valid_vae(v):
+            try:
+                return (v is not None) and (getattr(v, "first_stage_model", None) is not None)
+            except Exception:
+                return False
+
         if use_checkpoint and checkpoint:
             checkpoint_path = folder_paths.get_full_path_or_raise("checkpoints", checkpoint)
             _unload_old_checkpoint(checkpoint_path)
             base_model, base_clip, vae = _load_checkpoint(checkpoint_path)
             model = base_model.clone()
-            clip = base_clip.clone()
-            clip_clean = base_clip.clone()  # keep pristine CLIP for standard pipeline path
+            # Some flow/DiT style checkpoints (e.g., Z_image) ship without CLIP/VAE.
+            clip_source = base_clip or clip_in
+            if clip_source is None:
+                raise Exception("Checkpoint has no CLIP. Connect a CLIP input node or use a checkpoint that bundles CLIP.")
+            clip = clip_source.clone()
+            clip_clean = clip_source.clone()  # keep pristine CLIP for standard pipeline path
+            # Prefer external VAE when provided; some FLOW/DiT checkpoints return an invalid stub VAE.
+            for candidate in (vae_in, vae):
+                if _valid_vae(candidate):
+                    vae = candidate
+                    break
+            else:
+                raise Exception("Checkpoint has no valid VAE. Connect a VAE input node or use a checkpoint that bundles VAE.")
 
         elif model_in and clip_in:
             _unload_old_checkpoint(None)
@@ -289,6 +306,8 @@ class MagicNodesCombiNode:
             clip = clip_in.clone()
             clip_clean = clip_in.clone()
             vae = vae_in
+            if not _valid_vae(vae):
+                raise Exception("VAE input is missing or invalid. Please connect a proper VAE node.")
         else:
             raise Exception("No model selected!")
 
313