Flux.2-Dev and Flux.2-Tiny-Autoencoder latent shapes differ

#6 opened by yuvrajrathore203

I am attempting to run inference and perform a comparative analysis between the FLUX.2-Tiny-Autoencoder (TAEF) decoder and the FLUX.2-Dev encoder. However, the FLUX.2-Dev encoder produces a latent tensor with dimensions [1, 32, 128, 128], while the FLUX.2-Tiny-Autoencoder expects an input latent of shape [1, 128, 64, 64]. How can I convert the FLUX.2-Dev encoded latent into the shape the FLUX.2-Tiny-Autoencoder expects?

To make the FAL Flux2TinyAutoEncoder output shapes match the FLUX.2-Dev VAE, I think you need to pixel_shuffle (factor 2) after encoding and pixel_unshuffle (factor 2) before decoding: pixel_unshuffle maps the [1, 32, 128, 128] FLUX.2-Dev latent to [1, 128, 64, 64], and pixel_shuffle goes the other way. In my FLUX.2 AE comparisons, I used the following wrapper code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from flux2_tiny_autoencoder import Flux2TinyAutoEncoder

class DotDict(dict):
    # minimal dict with attribute access, to mimic diffusers-style output objects
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

class FALTinyFLUX2AEWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.vae = Flux2TinyAutoEncoder.from_pretrained(
            "fal/FLUX.2-Tiny-AutoEncoder",
        ).to(device=torch.device("cuda"), dtype=torch.bfloat16)
        self.bn = nn.BatchNorm2d(128, affine=False, eps=0.0)  # default bn (non-affine) over the 128 latent channels
        self.config = DotDict(batch_norm_eps=self.bn.eps)
    def encode(self, x):
        # tiny AE latent is [B, 128, H/16, W/16]; pixel_shuffle(2) -> [B, 32, H/8, W/8], matching the FLUX.2-Dev VAE
        x = F.pixel_shuffle(self.vae.encode(x).latent, 2)
        return DotDict(latent_dist=DotDict(sample=lambda: x))
    def decode(self, x):
        # pixel_unshuffle(2) maps the FLUX.2-Dev-shaped latent back to [B, 128, H/16, W/16] before decoding
        return DotDict(sample=self.vae.decode(F.pixel_unshuffle(x, 2)).sample)

fal_tiny_ae = FALTinyFLUX2AEWrapper()
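
For reference, a minimal usage sketch of the wrapper (shape check only; the 1024×1024 random input is just a placeholder, and it assumes the wrapper above is on CUDA in bfloat16):

img = torch.randn(1, 3, 1024, 1024, device="cuda", dtype=torch.bfloat16)  # placeholder input batch
lat = fal_tiny_ae.encode(img).latent_dist.sample()  # [1, 32, 128, 128], same shape as the FLUX.2-Dev VAE latent
rec = fal_tiny_ae.decode(lat).sample                # [1, 3, 1024, 1024]
print(lat.shape, rec.shape)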
