from itertools import chain
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import RMSNorm


class IPFluxAttnProcessor2_0(nn.Module):
    """Flux-style joint attention processor with an additional IP-Adapter image cross-attention branch."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        # Decoupled key/value projections for the image (IP) tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # 128 = per-head dimension of the Flux transformer (3072 hidden / 24 heads).
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        ip_encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        ip_hidden_states = ip_encoder_hidden_states

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # IP-Adapter branch: attend from the latent query to the projected image tokens first.
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # Reshape to match the query layout: (batch, heads, tokens, head_dim).
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_key = self.norm_added_k(ip_key)

            # Flux-style scaled dot-product attention over the IP tokens.
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            # Reshape ip_hidden_states the same way as hidden_states below.
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # The attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`.
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Joint attention: prepend the text (context) tokens to the latent tokens.
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).bool()
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=query.dtype)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Per-sample scale for the IP contribution; default to 1.0 when none is provided.
        if layer_scale is None:
            layer_scale = hidden_states.new_ones(1)
        layer_scale = layer_scale.view(-1, 1, 1)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # Final injection of the IP-Adapter hidden states.
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            # Final injection of the IP-Adapter hidden states.
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states
            if attn.to_out is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)
            return hidden_states


class ImageProjModel(nn.Module):
    """Projects a pooled image embedding (e.g. CLIP) into `num_tokens` IP tokens of width `cross_attention_dim`."""

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim
        self.proj = nn.Sequential(
            nn.Linear(clip_dim, clip_dim * 2),
            nn.GELU(),
            nn.Linear(clip_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        raw_proj = self.proj(image_embeds)
        reshaped_proj = raw_proj.reshape(image_embeds.shape[0], self.num_tokens, self.cross_attention_dim)
        return self.norm(reshaped_proj)


class LibreFluxIPAdapter(nn.Module):
    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # Collect the attention modules of both the double (`transformer_blocks`)
        # and single (`single_transformer_blocks`) streams; all other modules are skipped.
        self.culled_transformer_blocks = {}
        for name, module in self.transformer.named_modules():
            if isinstance(module, Attention):
                if name.startswith("transformer_blocks") or name.startswith("single_transformer_blocks"):
                    self.culled_transformer_blocks[name] = module

        # Apply the adapter to the culled blocks.
        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject the IP-Adapter attention processors into the transformer."""
        sample_attn = self.transformer.transformer_blocks[0].attn
        hidden_size = sample_attn.inner_dim
        num_heads = sample_attn.heads

        processor_list = []
        for name, module in self.culled_transformer_blocks.items():
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=4096,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)

        print(f"Added Attention IP Wrapper to {len(processor_list)} layers")
        # Store the adapters as a module list for saving/loading.
        self.adapter_modules = nn.ModuleList(processor_list)

    def parameters(self):
        """Return only the trainable parameters: the IP processors and the image projection model."""
        adapter_param_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            adapter_param_list.append(module.processor.parameters())
        return chain(*adapter_param_list, self.image_proj_model.parameters())

    def forward(self, ref_image, *args, layer_scale=torch.tensor([1.0]), **kwargs):
        """Project the reference image embedding and run the wrapped transformer."""
        mod_dtype = next(self.image_proj_model.parameters()).dtype
        mod_device = next(self.image_proj_model.parameters()).device

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        # Route the IP hidden states and scale to the attention processors.
        if "joint_attention_kwargs" not in kwargs:
            kwargs["joint_attention_kwargs"] = {}
        layer_scale = layer_scale.to(dtype=mod_dtype, device=mod_device)
        kwargs["joint_attention_kwargs"]["ip_layer_scale"] = layer_scale
        kwargs["joint_attention_kwargs"]["ip_hidden_states"] = ip_encoder_hidden_states

        return self.transformer(*args, **kwargs)

    def save_pretrained(self, ckpt_path):
        """Save the adapter weights (image projection + attention processors)."""
        state_dict = {
            "image_proj": self.image_proj_model.state_dict(),
            "ip_adapter": self.adapter_modules.state_dict(),
        }
        torch.save(state_dict, ckpt_path)

    def load_from_checkpoint(self, ckpt_path):
        """Checkpoint loader adapted from the Tencent IP-Adapter repo."""
        # Calculate original checksums.
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dicts for image_proj_model and adapter_modules.
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums.
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify that the weights have changed.
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
"Weights of image_proj_model did not change!" assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" print(f"Successfully loaded weights from checkpoint {ckpt_path}") @property def dtype(self): return next(self.image_proj_model.parameters()).dtype