Flux.2-Dev and Flux.2-Tiny-Autoencoder latent shapes differ
#6
by yuvrajrathore203 · opened
I am running inference and doing a comparative analysis between the FLUX.2-Tiny-Autoencoder (TAEF) decoder and the FLUX.2-Dev encoder. However, the FLUX.2-Dev encoder produces a latent tensor of shape [1, 32, 128, 128], whereas the FLUX.2-Tiny-Autoencoder decoder expects a latent of shape [1, 128, 64, 64]. How can I convert the FLUX.2-Dev encoded latent into the shape the FLUX.2-Tiny-Autoencoder expects?
To make the fal Flux2TinyAutoEncoder latent shapes match the FLUX.2-Dev VAE, I think you need to pixel_shuffle after encoding and pixel_unshuffle before decoding: a factor-2 unshuffle folds each 2×2 spatial block into channels, turning [1, 32, 128, 128] into [1, 128, 64, 64], and the shuffle is the exact inverse. In my FLUX.2 AE comparisons, I used the following wrapper code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from flux2_tiny_autoencoder import Flux2TinyAutoEncoder


class DotDict(dict):
    """dict with attribute access, to mimic diffusers-style output objects."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


class FALTinyFLUX2AEWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.vae = Flux2TinyAutoEncoder.from_pretrained(
            "fal/FLUX.2-Tiny-AutoEncoder",
        ).to(device=torch.device("cuda"), dtype=torch.bfloat16)
        self.bn = nn.BatchNorm2d(128, affine=False, eps=0.0)  # default bn
        self.config = DotDict(batch_norm_eps=self.bn.eps)

    def encode(self, x):
        # [1, 128, 64, 64] -> [1, 32, 128, 128], matching the FLUX.2-Dev VAE
        x = F.pixel_shuffle(self.vae.encode(x).latent, 2)
        return DotDict(latent_dist=DotDict(sample=lambda: x))

    def decode(self, x):
        # [1, 32, 128, 128] -> [1, 128, 64, 64] before the tiny decoder
        return DotDict(sample=self.vae.decode(F.pixel_unshuffle(x, 2)).sample)


fal_tiny_ae = FALTinyFLUX2AEWrapper()
```
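To see just the shape relationship without loading any model weights, here is a minimal sketch with random tensors (the tensor contents are placeholders, only the shapes match the two VAEs):

```python
import torch
import torch.nn.functional as F

# Stand-in for a FLUX.2-Dev VAE latent: 32 channels at 128x128
dev_latent = torch.randn(1, 32, 128, 128)

# pixel_unshuffle with factor 2 folds each 2x2 spatial block into channels:
# [1, 32, 128, 128] -> [1, 32 * 4, 64, 64] = [1, 128, 64, 64]
tiny_latent = F.pixel_unshuffle(dev_latent, 2)
print(tiny_latent.shape)  # torch.Size([1, 128, 64, 64])

# pixel_shuffle with the same factor is the exact inverse permutation,
# so the round trip recovers the original latent bit-for-bit
roundtrip = F.pixel_shuffle(tiny_latent, 2)
assert torch.equal(roundtrip, dev_latent)
```

So to feed a FLUX.2-Dev encoded latent into the tiny decoder directly, `F.pixel_unshuffle(latent, 2)` is the only conversion needed.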