| | import comfy.sd
|
| | import comfy.utils
|
| | import comfy.model_base
|
| | import comfy.model_management
|
| | import comfy.model_sampling
|
| |
|
| | import torch
|
| | import folder_paths
|
| | import json
|
| | import os
|
| |
|
| |
|
try:
    from comfy.cli_args import args
except ImportError:
    # comfy.cli_args is unavailable (e.g. running outside ComfyUI): fall back to
    # a tiny stand-in object exposing a `disable_metadata` attribute set to False.
    class ArgsMock:
        """Minimal substitute for comfy's parsed command-line arguments."""

        disable_metadata = False

    args = ArgsMock()
|
| |
|
| |
|
class ModelMergeSimple:
    """Node that blends the diffusion-model weights of two models by one ratio."""

    @classmethod
    def INPUT_TYPES(cls):
        ratio_spec = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        required = {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "ratio": ("FLOAT", ratio_spec),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"

    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        """Return a clone of model1 with every "diffusion_model." key of model2 patched in.

        Each key is added with strength ``ratio`` for model2's patch and
        ``1.0 - ratio`` for model1's own weight.

        NOTE(review): upstream ComfyUI's ModelMergeSimple passes
        ``(1.0 - ratio, ratio)`` (ratio weights model1); here ratio weights
        model2. Confirm this convention is intended.

        Returns:
            A one-tuple containing the patched clone.
        """
        merged = model1.clone()
        keep = 1.0 - ratio
        source_patches = model2.get_key_patches("diffusion_model.")
        for key in source_patches:
            merged.add_patches({key: source_patches[key]}, ratio, keep)
        return (merged, )
|
| |
|
class ModelMergeMultiSimple:
    """Node that merges up to five models as a weighted average of their diffusion weights."""

    @classmethod
    def INPUT_TYPES(s):
        inputs = {"required": {}}
        for i in range(1, 6):
            inputs["required"][f"model{i}"] = ("MODEL",)
            inputs["required"][f"ratio{i}"] = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return inputs

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_five"

    CATEGORY = "advanced/model_merging"

    def merge_five(self, **kwargs):
        """Blend ``model1``..``model5`` according to ``ratio1``..``ratio5``.

        Models whose ratio is 0 are left out. The remaining ratios are treated
        as relative weights: the result gives each active model the weight
        ``ratio_i / sum(active ratios)``.

        A missing model (``None``) is skipped together with its ratio. A warning
        is printed if that ratio was positive. A missing ratio counts as 0.

        Returns:
            A one-tuple holding the merged model. It is a clone of the first
            active model, or of the first provided model when every ratio is 0.

        Raises:
            ValueError: if no model was provided at all.
        """
        models = []
        ratios = []
        for i in range(1, 6):
            model = kwargs.get(f"model{i}")
            ratio = kwargs.get(f"ratio{i}")
            if ratio is None:
                ratio = 0.0
            if model is not None:
                models.append(model)
                ratios.append(ratio)
            elif ratio > 0:
                # Do NOT append to `ratios` here: models and ratios must stay
                # index-aligned for the zip below.
                print(f"Warning: Ratio {ratio} provided for model{i} but model is missing. Ignoring.")

        if not models:
            raise ValueError("No models provided for merging.")

        active = [(model, ratio) for model, ratio in zip(models, ratios) if ratio > 0]
        if not active:
            print("Warning: All model ratios are 0. Returning the first provided model without changes.")
            return (models[0].clone(), )

        first_model, accumulated = active[0]
        merged_model = first_model.clone()

        # Fold the remaining models in one at a time. Assuming
        # add_patches(patches, strength_patch, strength_model) yields
        # strength_model * current + strength_patch * patch (as ModelMergeSimple's
        # (ratio, 1 - ratio) call implies), blending with
        # (w / (acc + w), acc / (acc + w)) keeps every earlier model at
        # weight_j / acc. After the last step each model sits at
        # ratio_j / sum(ratios), so no separate normalization pass is needed.
        for model, weight in active[1:]:
            total = accumulated + weight
            strength_next = weight / total
            strength_self = accumulated / total
            key_patches = model.get_key_patches("diffusion_model.")
            for k in key_patches:
                merged_model.add_patches({k: key_patches[k]}, strength_next, strength_self)
            accumulated = total

        return (merged_model,)
|
| |
|
| |
|
| |
|
class ModelSubtract:
    """Node that subtracts a scaled copy of model2's diffusion weights from model1."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model1": ("MODEL",),
                             "model2": ("MODEL",),
                             "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, multiplier):
        """Return a clone of model1 patched with model2's "diffusion_model." keys.

        Each key is added with patch strength ``-multiplier`` and model strength
        ``1.0``. Assuming add_patches(patches, strength_patch, strength_model)
        blends as strength_model * own + strength_patch * patch, the result per
        key is ``model1 - multiplier * model2``.

        An earlier, discarded first pass (clone + patch with
        ``(multiplier, -multiplier)``) was removed. It never affected the result.

        Returns:
            A one-tuple containing the patched clone.
        """
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, -multiplier, 1.0)
        return (m, )
|
| |
|
class ModelAdd:
    """Node that adds model2's diffusion weights onto model1 at full strength."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2):
        """Clone model1 and add each "diffusion_model." key of model2 with strengths (1.0, 1.0).

        Returns:
            A one-tuple containing the patched clone.
        """
        result = model1.clone()
        addends = model2.get_key_patches("diffusion_model.")
        for name in addends:
            result.add_patches({name: addends[name]}, 1.0, 1.0)
        return (result, )
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class CLIPMergeSimple:
    """Node declared for blending two CLIP models by a ratio.

    Placeholder: ``merge`` currently returns clip1 unchanged and ignores clip2
    and ratio. TODO(review): implement the actual merge.
    """

    @classmethod
    def INPUT_TYPES(cls):
        ratio_spec = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        return {
            "required": {
                "clip1": ("CLIP",),
                "clip2": ("CLIP",),
                "ratio": ("FLOAT", ratio_spec),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        """Return ``(clip1,)``; clip2 and ratio are not used yet."""
        unchanged = clip1
        return (unchanged, )
|
| |
|
class CLIPSubtract:
    """Node declared for subtracting one CLIP model from another.

    Placeholder: ``merge`` currently returns clip1 unchanged and ignores clip2
    and multiplier. TODO(review): implement the actual subtraction.
    """

    @classmethod
    def INPUT_TYPES(cls):
        multiplier_spec = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        return {
            "required": {
                "clip1": ("CLIP",),
                "clip2": ("CLIP",),
                "multiplier": ("FLOAT", multiplier_spec),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, multiplier):
        """Return ``(clip1,)``; clip2 and multiplier are not used yet."""
        unchanged = clip1
        return (unchanged,)
|
| |
|
class CLIPAdd:
    """Node declared for adding one CLIP model onto another.

    Placeholder: ``merge`` currently returns clip1 unchanged and ignores clip2.
    TODO(review): implement the actual addition.
    """

    @classmethod
    def INPUT_TYPES(cls):
        clip_inputs = {"clip1": ("CLIP",), "clip2": ("CLIP",)}
        return {"required": clip_inputs}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2):
        """Return ``(clip1,)``; clip2 is not used yet."""
        unchanged = clip1
        return (unchanged,)
|
| |
|
class ModelMergeBlocks:
    """Node declared for merging two models with separate input/middle/out ratios.

    Placeholder: ``merge`` currently returns model1 unchanged. model2 and the
    per-block ratios passed via kwargs are ignored. TODO(review): implement.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {"model1": ("MODEL",), "model2": ("MODEL",)}
        for block_name in ("input", "middle", "out"):
            required[block_name] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        """Return ``(model1,)``; model2 and the block ratios are not used yet."""
        unchanged = model1
        return (unchanged,)
|
| |
|
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
    """Placeholder checkpoint writer.

    Does nothing with its arguments and returns None.
    TODO(review): implement or remove.
    """
    return None
|
| |
|
class CheckpointSave:
    """Output node intended to save a model/CLIP/VAE checkpoint.

    Placeholder: ``save`` currently writes nothing and returns an empty dict.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty result dict without saving anything (TODO(review): implement)."""
        result = {}
        return result
|
| |
|
class CLIPSave:
    """Output node intended to save a CLIP model.

    Placeholder: ``save`` currently writes nothing and returns an empty dict.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "clip": ("CLIP",),
            "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty result dict without saving anything (TODO(review): implement)."""
        result = {}
        return result
|
| |
|
class VAESave:
    """Output node intended to save a VAE.

    Placeholder: ``save`` currently writes nothing and returns an empty dict.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty result dict without saving anything (TODO(review): implement)."""
        result = {}
        return result
|
| |
|
class ModelSave:
    """Output node intended to save a diffusion model.

    Placeholder: ``save`` currently writes nothing and returns an empty dict.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        """Return an empty result dict without saving anything (TODO(review): implement)."""
        result = {}
        return result
|
| |
|
| |
|
# Node type name -> implementing class (presumably read by the ComfyUI node
# loader). Insertion order is kept identical to the original listing.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSimple=ModelMergeSimple,
    ModelMergeMultiSimple=ModelMergeMultiSimple,
    ModelMergeBlocks=ModelMergeBlocks,
    ModelMergeSubtract=ModelSubtract,
    ModelMergeAdd=ModelAdd,
    CheckpointSave=CheckpointSave,
    CLIPMergeSimple=CLIPMergeSimple,
    CLIPMergeSubtract=CLIPSubtract,
    CLIPMergeAdd=CLIPAdd,
    CLIPSave=CLIPSave,
    VAESave=VAESave,
    ModelSave=ModelSave,
)
|
| |
|
# Node type name -> human-readable title, in the same order as the class mapping.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ModelMergeSimple="Model Merge Simple (2 Models)",
    ModelMergeMultiSimple="Model Merge Multi Simple (5 Models)",
    ModelMergeBlocks="Model Merge Blocks",
    ModelMergeSubtract="Model Subtract",
    ModelMergeAdd="Model Add",
    CheckpointSave="Save Checkpoint",
    CLIPMergeSimple="CLIP Merge Simple",
    CLIPMergeSubtract="CLIP Subtract",
    CLIPMergeAdd="CLIP Add",
    CLIPSave="CLIP Save",
    VAESave="VAE Save",
    ModelSave="Model Save",
)

# Import-time notice that this module finished loading.
print("Custom model merging nodes loaded.")