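"""Standalone DC-AE (Deep Compression Autoencoder) encoder/decoder building
blocks, assembled from diffusers components; the classes closely mirror the
internals of diffusers' AutoencoderDC. The `__main__` section at the bottom
runs small shape/parameter smoke tests."""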

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.autoencoder_dc import EfficientViTBlock
from diffusers.models.normalization import RMSNorm, get_normalization


class ResBlock(nn.Module):
    """Residual block: conv -> act -> conv -> norm, plus an identity skip.

    The identity skip assumes ``in_channels == out_channels``. When ``rank_ratio``
    is set, each 3x3 conv is factorized into a low-rank pair of 3x3 convs with a
    ``channels // rank_ratio`` bottleneck.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_type: str = "batch_norm",
        act_fn: Optional[str] = "relu6",
        rank_ratio: Optional[int] = 16,
    ) -> None:
        super().__init__()

        self.norm_type = norm_type
        self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()

        if rank_ratio is None:
            self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
            self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        else:
            assert isinstance(rank_ratio, int)
            # Low-rank factorization: project down to channels // rank_ratio, then back up.
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels, in_channels // rank_ratio, 3, 1, 1, bias=False),
                nn.Conv2d(in_channels // rank_ratio, in_channels, 3, 1, 1),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels // rank_ratio, 3, 1, 1, bias=False),
                nn.Conv2d(out_channels // rank_ratio, out_channels, 3, 1, 1, bias=False),
            )

        self.norm = get_normalization(norm_type, out_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.norm_type == "rms_norm":
            # Move channels to the last dimension so RMSNorm normalizes across channels.
            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
        else:
            hidden_states = self.norm(hidden_states)

        return hidden_states + residual


def get_block(
    block_type: str,
    in_channels: int,
    out_channels: int,
    attention_head_dim: int,
    norm_type: str,
    act_fn: str,
    qkv_multiscales: Tuple[int, ...] = (),
):
    if block_type == "ResBlock":
        block = ResBlock(in_channels, out_channels, norm_type, act_fn)
    elif block_type == "EfficientViTBlock":
        block = EfficientViTBlock(
            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
        )
    else:
        raise ValueError(f"Block with {block_type=} is not supported.")
    return block


class DCDownBlock2d(nn.Module):
    """2x spatial downsampling block.

    With ``downsample=True``, a stride-1 conv is followed by pixel-unshuffle
    (space to channels); otherwise a stride-2 conv downsamples directly. The
    optional shortcut pixel-unshuffles the input and averages channel groups
    to match the output width, adding a parameter-free skip path.
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
        super().__init__()

        self.downsample = downsample
        self.factor = 2
        self.stride = 1 if downsample else 2
        self.group_size = in_channels * self.factor**2 // out_channels
        self.shortcut = shortcut

        out_ratio = self.factor**2
        if downsample:
            # pixel_unshuffle multiplies channels by factor**2, so divide here.
            assert out_channels % out_ratio == 0
            out_channels = out_channels // out_ratio

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.conv(hidden_states)
        if self.downsample:
            x = F.pixel_unshuffle(x, self.factor)

        if self.shortcut:
            # Parameter-free skip: pixel-unshuffle the input, then average groups
            # of `group_size` channels down to the output width.
            y = F.pixel_unshuffle(hidden_states, self.factor)
            y = y.unflatten(1, (-1, self.group_size))
            y = y.mean(dim=2)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states


class DCUpBlock2d(nn.Module):
    """2x spatial upsampling block.

    With ``interpolate=True``, interpolation is followed by a conv; otherwise a
    conv expands channels by factor**2 and pixel-shuffle moves them to space.
    The optional shortcut repeats input channels and pixel-shuffles them,
    adding a parameter-free skip path.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        interpolate: bool = False,
        shortcut: bool = True,
        interpolation_mode: str = "nearest",
    ) -> None:
        super().__init__()

        self.interpolate = interpolate
        self.interpolation_mode = interpolation_mode
        self.shortcut = shortcut
        self.factor = 2
        self.repeats = out_channels * self.factor**2 // in_channels

        out_ratio = self.factor**2
        if not interpolate:
            # pixel_shuffle divides channels by factor**2, so expand here.
            out_channels = out_channels * out_ratio

        self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.interpolate:
            x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
            x = self.conv(x)
        else:
            x = self.conv(hidden_states)
            x = F.pixel_shuffle(x, self.factor)

        if self.shortcut:
            # Parameter-free skip: repeat channels to factor**2 times the target
            # width, then pixel-shuffle the extra channels into space.
            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
            y = F.pixel_shuffle(y, self.factor)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states


class DCAEEncoder(nn.Module):
    """DC-AE encoder: conv_in -> staged down blocks -> conv_out to latent
    channels, with an optional channel-group-averaging shortcut around conv_out."""

    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, Tuple[str, ...]] = "ResBlock",
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
        downsample_block_type: str = "pixel_unshuffle",
        out_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)
        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks

        if layers_per_block[0] > 0:
            self.conv_in = nn.Conv2d(
                in_channels,
                block_out_channels[0],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        else:
            # Stage 0 has no layers: fold the first downsample into conv_in.
            self.conv_in = DCDownBlock2d(
                in_channels=in_channels,
                out_channels=block_out_channels[1],
                downsample=downsample_block_type == "pixel_unshuffle",
                shortcut=False,
            )

        down_blocks = []
        for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
            down_block_list = []

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type="rms_norm",
                    act_fn="silu",
                    qkv_multiscales=qkv_multiscales[i],
                )
                down_block_list.append(block)

            if i < num_blocks - 1 and num_layers > 0:
                downsample_block = DCDownBlock2d(
                    in_channels=out_channel,
                    out_channels=block_out_channels[i + 1],
                    downsample=downsample_block_type == "pixel_unshuffle",
                    shortcut=True,
                )
                down_block_list.append(downsample_block)

            down_blocks.append(nn.Sequential(*down_block_list))

        self.down_blocks = nn.ModuleList(down_blocks)

        self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)

        self.out_shortcut = out_shortcut
        if out_shortcut:
            self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        if self.out_shortcut:
            # Parameter-free skip around conv_out: average channel groups down
            # to the latent width.
            x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
            x = x.mean(dim=2)
            hidden_states = self.conv_out(hidden_states) + x
        else:
            hidden_states = self.conv_out(hidden_states)

        return hidden_states


class DCAEDecoder(nn.Module):
    """DC-AE decoder: conv_in from latent channels (with an optional
    channel-repeat shortcut) -> staged up blocks -> norm/act/conv_out."""

    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, Tuple[str, ...]] = "ResBlock",
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
        norm_type: Union[str, Tuple[str, ...]] = "rms_norm",
        act_fn: Union[str, Tuple[str, ...]] = "silu",
        upsample_block_type: str = "pixel_shuffle",
        in_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)
        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks
        if isinstance(norm_type, str):
            norm_type = (norm_type,) * num_blocks
        if isinstance(act_fn, str):
            act_fn = (act_fn,) * num_blocks

        self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)

        self.in_shortcut = in_shortcut
        if in_shortcut:
            self.in_shortcut_repeats = block_out_channels[-1] // latent_channels

        up_blocks = []
        for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
            up_block_list = []

            if i < num_blocks - 1 and num_layers > 0:
                upsample_block = DCUpBlock2d(
                    block_out_channels[i + 1],
                    out_channel,
                    interpolate=upsample_block_type == "interpolate",
                    shortcut=True,
                )
                up_block_list.append(upsample_block)

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type=norm_type[i],
                    act_fn=act_fn[i],
                    qkv_multiscales=qkv_multiscales[i],
                )
                up_block_list.append(block)

            # Built from the deepest stage outward, but stored in stage order;
            # forward() iterates self.up_blocks in reverse.
            up_blocks.insert(0, nn.Sequential(*up_block_list))

        self.up_blocks = nn.ModuleList(up_blocks)

        channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]

        self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
        self.conv_act = nn.ReLU()
        if layers_per_block[0] > 0:
            self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
        else:
            # Stage 0 has no layers: fold the final upsample into conv_out.
            self.conv_out = DCUpBlock2d(
                channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.in_shortcut:
            # Parameter-free skip around conv_in: repeat latent channels up to
            # the first hidden width.
            x = hidden_states.repeat_interleave(
                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
            )
            hidden_states = self.conv_in(hidden_states) + x
        else:
            hidden_states = self.conv_in(hidden_states)

        for up_block in reversed(self.up_blocks):
            hidden_states = up_block(hidden_states)

        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # case 1: from 32 x 8 x 8 to 3 x 256 x 256
    data = torch.randn(1, 32, 8, 8).to(device)
    model = DCAEDecoder(
        in_channels=3,
        latent_channels=32,
        attention_head_dim=32,
        block_type=["ResBlock", "ResBlock", "ResBlock",
                    "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
        block_out_channels=[128, 256, 512, 512, 1024, 1024],
        layers_per_block=[3, 3, 3, 3, 3, 3],
        qkv_multiscales=[[], [], [], [5], [5], [5]],
        norm_type="rms_norm",
        act_fn="silu",
        upsample_block_type="interpolate",
    ).to(device)
    output = model(data)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params / 1e6:.2f}M")
    print(f"Output shape: {output.shape}")

    # case 2: from 2 x 2 x 2 to 32 x 8 x 8
    data = torch.randn(1, 2, 2, 2).to(device)
    model = DCAEDecoder(
        in_channels=32,
        latent_channels=2,
        attention_head_dim=32,
        block_type=["EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
        block_out_channels=[256, 512, 512],
        layers_per_block=[6, 6, 6],
        qkv_multiscales=[[5], [5], [5]],
        norm_type="rms_norm",
        act_fn="silu",
        upsample_block_type="interpolate",
    ).to(device)
    output = model(data)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params / 1e6:.2f}M")
    print(f"Output shape: {output.shape}")