\section{Transformer-Based Compression and Decompression Architecture}
\label{sec:compression}
The compression-decompression pipeline forms the core bridge between high-dimensional ESM-2 embeddings and the efficient latent space required for flow matching generation. Our architecture employs a symmetric hourglass design that combines transformer self-attention with mean pooling of adjacent positions, achieving a 16× reduction in embedding dimension ($1280 \rightarrow 80$) together with a 2× reduction in sequence length while preserving semantic protein information.
\subsection{Compression Architecture Overview}
The compressor $\mathcal{C}: \mathbb{R}^{L \times 1280} \rightarrow \mathbb{R}^{L/2 \times 80}$ transforms normalized ESM-2 embeddings into a compressed latent representation suitable for flow matching. The architecture follows an hourglass design inspired by ProtFlow, combining spatial pooling with transformer self-attention to retain as much of the embedding information as possible.
\subsubsection{Compressor Network Design}
\label{sec:compressor_design}
The compressor employs a four-stage architecture with symmetric transformer processing before and after spatial pooling:
\begin{align}
\mathbf{H}^{(0)} &= \text{LayerNorm}(\mathbf{H}^{(norm)}) \label{eq:comp_input_norm}\\
\mathbf{H}^{(pre)} &= \text{TransformerEncoder}_{\text{pre}}(\mathbf{H}^{(0)}) \label{eq:comp_pre_transformer}\\
\mathbf{H}^{(pool)} &= \text{HourglassPool}(\mathbf{H}^{(pre)}) \label{eq:comp_hourglass_pool}\\
\mathbf{H}^{(post)} &= \text{TransformerEncoder}_{\text{post}}(\mathbf{H}^{(pool)}) \label{eq:comp_post_transformer}\\
\mathbf{Z}^{(comp)} &= \tanh(\text{LayerNorm}(\mathbf{H}^{(post)}) \mathbf{W}^{(proj)} + \mathbf{b}^{(proj)}) \label{eq:comp_final_projection}
\end{align}
where both $\text{TransformerEncoder}_{\text{pre}}$ and $\text{TransformerEncoder}_{\text{post}}$ consist of 2 transformer layers each, maintaining the full ESM-2 dimensionality (1280) until the final projection.
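For concreteness, this forward pass can be sketched in PyTorch as follows. The sketch assumes PyTorch's \texttt{nn.TransformerEncoder} stack with pre-normalization; module names such as \texttt{pre\_encoder}, \texttt{post\_encoder}, and \texttt{proj} are illustrative rather than taken from a reference implementation.
\begin{verbatim}
import torch
import torch.nn as nn

def _encoder(d_model, n_heads, d_ff, n_layers=2):
    # Stack of pre-normalized transformer encoder layers (GELU FFN, dropout 0.1).
    layer = nn.TransformerEncoderLayer(
        d_model, n_heads, dim_feedforward=d_ff, dropout=0.1,
        activation="gelu", batch_first=True, norm_first=True)
    return nn.TransformerEncoder(layer, num_layers=n_layers)

class Compressor(nn.Module):
    def __init__(self, d_model=1280, d_latent=80, n_heads=8, d_ff=5120):
        super().__init__()
        self.input_norm = nn.LayerNorm(d_model)
        self.pre_encoder = _encoder(d_model, n_heads, d_ff)   # before pooling
        self.post_encoder = _encoder(d_model, n_heads, d_ff)  # after pooling
        self.out_norm = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, d_latent)

    def forward(self, h):                        # h: (B, L, 1280)
        h = self.pre_encoder(self.input_norm(h))
        if h.size(1) % 2 == 1:                   # drop last position if L is odd
            h = h[:, :-1, :]
        B, L, D = h.shape
        h = h.view(B, L // 2, 2, D).mean(dim=2)  # hourglass mean pooling
        h = self.post_encoder(h)
        return torch.tanh(self.proj(self.out_norm(h)))  # (B, L/2, 80), in [-1, 1]
\end{verbatim}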
\subsubsection{Hourglass Pooling Strategy}
\label{sec:hourglass_pooling}
The hourglass pooling operation reduces sequence length by exactly half while preserving local spatial relationships. This operation is crucial for computational efficiency in the flow matching process:
\begin{align}
\text{HourglassPool}(\mathbf{H}) &= \begin{cases}
\text{Pool}(\mathbf{H}[:, :L-1, :]) & \text{if } L \text{ is odd} \\
\text{Pool}(\mathbf{H}) & \text{if } L \text{ is even}
\end{cases} \label{eq:hourglass_length_handling}
\end{align}
The pooling operation groups adjacent residue positions and averages their representations:
\begin{align}
\mathbf{H}^{(grouped)} &= \text{Reshape}(\mathbf{H}, [B, L/2, 2, D]) \label{eq:reshape_for_pooling}\\
\mathbf{H}^{(pool)} &= \frac{1}{2}\sum_{k=1}^{2} \mathbf{H}^{(grouped)}[:, :, k, :] \label{eq:mean_pooling}
\end{align}
This pooling strategy preserves local sequence context while achieving the desired compression in sequence length.
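A minimal PyTorch sketch of this pooling step (Eqs.~\eqref{eq:reshape_for_pooling}--\eqref{eq:mean_pooling}), shown for illustration only, is:
\begin{verbatim}
import torch

def hourglass_pool(x: torch.Tensor) -> torch.Tensor:
    # Average adjacent positions: (B, L, D) -> (B, L // 2, D).
    if x.size(1) % 2 == 1:          # odd lengths: drop the final position
        x = x[:, :-1, :]
    B, L, D = x.shape
    return x.view(B, L // 2, 2, D).mean(dim=2)

x = torch.randn(4, 51, 1280)        # odd length is truncated to 50
print(hourglass_pool(x).shape)      # torch.Size([4, 25, 1280])
\end{verbatim}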
\subsubsection{Final Projection and Activation}
\label{sec:comp_projection}
The final projection layer reduces dimensionality from 1280 to 80 (16× compression) with tanh activation to ensure bounded outputs:
\begin{align}
\mathbf{W}^{(proj)} &\in \mathbb{R}^{1280 \times 80}, \quad \mathbf{b}^{(proj)} \in \mathbb{R}^{80} \label{eq:projection_parameters}\\
\mathbf{Z}^{(comp)} &= \tanh(\text{LayerNorm}(\mathbf{H}^{(post)}) \mathbf{W}^{(proj)} + \mathbf{b}^{(proj)}) \in [-1, 1]^{L/2 \times 80} \label{eq:bounded_compression}
\end{align}
The tanh activation ensures that compressed embeddings remain in a bounded range, facilitating stable flow matching training.
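As a quick sanity check on the bounded output, the projection head can be expressed, illustratively, as a \texttt{LayerNorm}--\texttt{Linear}--\texttt{Tanh} stack:
\begin{verbatim}
import torch
import torch.nn as nn

# Illustrative projection head: LayerNorm -> Linear(1280, 80) -> tanh.
head = nn.Sequential(nn.LayerNorm(1280), nn.Linear(1280, 80), nn.Tanh())
z = head(torch.randn(4, 25, 1280))
assert z.shape == (4, 25, 80) and z.abs().max() <= 1.0   # bounded in [-1, 1]
\end{verbatim}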
\subsection{Decompression Architecture}
The decompressor $\mathcal{D}: \mathbb{R}^{L/2 \times 80} \rightarrow \mathbb{R}^{L \times 1280}$ reconstructs full-dimensional ESM-2 embeddings from compressed representations. The architecture mirrors the compressor with reverse operations: dimension expansion, spatial unpooling, and transformer refinement.
\subsubsection{Decompressor Network Design}
\label{sec:decompressor_design}
The decompressor employs a three-stage reconstruction process:
\begin{align}
\mathbf{H}^{(expanded)} &= \text{LayerNorm}(\mathbf{Z}^{(comp)}) \mathbf{W}^{(expand)} + \mathbf{b}^{(expand)} \label{eq:decomp_expansion}\\
\mathbf{H}^{(unpool)} &= \text{HourglassUnpool}(\mathbf{H}^{(expanded)}) \label{eq:decomp_unpooling}\\
\mathbf{H}^{(recon)} &= \text{TransformerEncoder}_{\text{decode}}(\mathbf{H}^{(unpool)}) \label{eq:decomp_transformer}
\end{align}
where $\mathbf{W}^{(expand)} \in \mathbb{R}^{80 \times 1280}$ and $\mathbf{b}^{(expand)} \in \mathbb{R}^{1280}$ expand the compressed representation back to ESM-2 dimensionality.
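A corresponding PyTorch sketch of the decompressor, under the same assumptions as the compressor sketch above (module names are illustrative), is:
\begin{verbatim}
import torch
import torch.nn as nn

class Decompressor(nn.Module):
    def __init__(self, d_model=1280, d_latent=80, n_heads=8, d_ff=5120):
        super().__init__()
        self.input_norm = nn.LayerNorm(d_latent)
        self.expand = nn.Linear(d_latent, d_model)     # 80 -> 1280
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=d_ff, dropout=0.1,
            activation="gelu", batch_first=True, norm_first=True)
        self.decoder = nn.TransformerEncoder(layer, num_layers=2)

    def forward(self, z):                              # z: (B, L/2, 80)
        h = self.expand(self.input_norm(z))            # (B, L/2, 1280)
        h = torch.repeat_interleave(h, 2, dim=1)       # (B, L, 1280)
        return self.decoder(h)                         # refined reconstruction
\end{verbatim}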
\subsubsection{Hourglass Unpooling Operation}
\label{sec:hourglass_unpooling}
The unpooling operation reverses the compression by duplicating each compressed position to restore the original sequence length:
\begin{align}
\text{HourglassUnpool}(\mathbf{H}^{(expanded)}) &= \text{repeat\_interleave}(\mathbf{H}^{(expanded)}, 2, \text{dim}=1) \label{eq:repeat_interleave}
\end{align}
This operation doubles the sequence length, restoring the spatial resolution lost during compression:
\begin{align}
\mathbf{H}^{(unpool)}[b, 2i, :] &= \mathbf{H}^{(expanded)}[b, i, :] \label{eq:unpool_even}\\
\mathbf{H}^{(unpool)}[b, 2i+1, :] &= \mathbf{H}^{(expanded)}[b, i, :] \label{eq:unpool_odd}
\end{align}
for $i = 0, 1, \ldots, L/2-1$, effectively creating identical copies for adjacent positions.
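A toy example (illustrative values) makes the duplication in Eq.~\eqref{eq:repeat_interleave} explicit:
\begin{verbatim}
import torch

x = torch.tensor([[[1.0], [2.0], [3.0]]])     # shape (B=1, L/2=3, D=1)
y = torch.repeat_interleave(x, 2, dim=1)      # shape (1, 6, 1)
print(y.squeeze(-1))                          # tensor([[1., 1., 2., 2., 3., 3.]])
\end{verbatim}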
\subsubsection{Transformer Refinement}
\label{sec:decomp_refinement}
The final transformer encoder (2 layers) refines the unpooled representations to recover fine-grained positional information lost during compression:
\begin{align}
\mathbf{H}^{(recon)} = \text{TransformerEncoder}_{\text{decode}}(\mathbf{H}^{(unpool)}) \label{eq:refinement_transformer}
\end{align}
This refinement stage is crucial for recovering the subtle positional dependencies present in ESM-2 embeddings.
\subsection{Training Methodology and Optimization}
The compressor-decompressor pair is trained jointly on a reconstruction loss, using learning rate warmup with cosine annealing, weight decay, and gradient clipping for stable convergence.
\subsubsection{Reconstruction Loss Function}
\label{sec:reconstruction_loss}
The training objective minimizes mean squared error between original and reconstructed embeddings:
\begin{align}
\mathcal{L}_{\text{recon}}(\theta_{\mathcal{C}}, \theta_{\mathcal{D}}) &= \mathbb{E}_{\mathbf{H} \sim \mathcal{T}} \left[ \|\mathbf{H} - \mathcal{D}(\mathcal{C}(\mathbf{H}; \theta_{\mathcal{C}}); \theta_{\mathcal{D}})\|_2^2 \right] \label{eq:mse_loss}
\end{align}
where $\mathcal{T}$ represents the training dataset distribution and $\theta_{\mathcal{C}}, \theta_{\mathcal{D}}$ are the compressor and decompressor parameters respectively.
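In PyTorch terms, the objective in Eq.~\eqref{eq:mse_loss} reduces to a single \texttt{mse\_loss} call over the compression-reconstruction round trip. The sketch below assumes even sequence lengths (odd lengths are truncated beforehand) and reuses the illustrative module names from the earlier sketches.
\begin{verbatim}
import torch.nn.functional as F

def reconstruction_loss(h, compressor, decompressor):
    # Mean-squared error between original and reconstructed embeddings.
    z = compressor(h)            # (B, L, 1280) -> (B, L/2, 80)
    h_hat = decompressor(z)      # (B, L/2, 80) -> (B, L, 1280)
    return F.mse_loss(h_hat, h)  # averaged over all elements
\end{verbatim}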
\subsubsection{Learning Rate Scheduling}
\label{sec:lr_scheduling}
Training employs a learning rate schedule that combines linear warmup with cosine annealing:
\begin{align}
\text{lr}_{\text{warmup}}(t) &= \text{lr}_{\max} \cdot \frac{t}{T_{\text{warmup}}} \quad \text{for } t \leq T_{\text{warmup}} \label{eq:warmup_lr}\\
\text{lr}_{\text{cosine}}(t) &= \text{lr}_{\min} + \frac{1}{2}(\text{lr}_{\max} - \text{lr}_{\min})\left(1 + \cos\left(\frac{\pi(t - T_{\text{warmup}})}{T_{\text{total}} - T_{\text{warmup}}}\right)\right) \label{eq:cosine_lr}
\end{align}
with hyperparameters: $\text{lr}_{\max} = 10^{-3}$, $\text{lr}_{\min} = 8 \times 10^{-5}$, $T_{\text{warmup}} = 10,000$ steps.
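This schedule matches the \texttt{SequentialLR} setup used in Algorithm~\ref{alg:joint_training}; a minimal PyTorch sketch follows, where the \texttt{T\_total} value and the placeholder parameter list are assumptions chosen for illustration.
\begin{verbatim}
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model parameters
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)

T_warmup, T_total = 10_000, 200_000             # T_total chosen for illustration
scheduler = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=T_warmup),
        CosineAnnealingLR(opt, T_max=T_total - T_warmup, eta_min=8e-5),
    ],
    milestones=[T_warmup],                      # switch to cosine after warmup
)
\end{verbatim}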
\subsubsection{Normalization and Regularization}
\label{sec:normalization_reg}
The architecture incorporates several regularization techniques, combined in the single training step sketched after this list:
\begin{itemize}
\item \textbf{Layer Normalization}: Applied before each major operation for training stability
\item \textbf{Dropout}: 0.1 dropout rate in the transformer attention and feedforward sublayers during training
\item \textbf{Weight Decay}: $10^{-4}$ weight decay in AdamW optimizer
\item \textbf{Gradient Clipping}: Maximum gradient norm of 1.0 to prevent exploding gradients
\end{itemize}
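These pieces combine into one optimization step, mirroring Algorithm~\ref{alg:joint_training}; the sketch below is an illustration rather than the reference training loop.
\begin{verbatim}
import torch
import torch.nn.functional as F

def train_step(batch, compressor, decompressor, optimizer, scheduler):
    # One optimization step: MSE loss, gradient clipping (max norm 1.0),
    # AdamW update (weight decay 1e-4), and scheduler step.
    params = list(compressor.parameters()) + list(decompressor.parameters())
    optimizer.zero_grad()
    loss = F.mse_loss(decompressor(compressor(batch)), batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    optimizer.step()
    scheduler.step()
    return loss.item()
\end{verbatim}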
\subsection{Architecture Specifications}
\subsubsection{Transformer Layer Configuration}
\label{sec:transformer_config}
Both compressor and decompressor transformer layers share identical specifications, listed below and mirrored in the configuration sketch that follows:
\begin{itemize}
\item \textbf{Model Dimension}: $d_{\text{model}} = 1280$ (matching ESM-2)
\item \textbf{Attention Heads}: $n_{\text{heads}} = 8$
\item \textbf{Feedforward Dimension}: $d_{\text{ff}} = 5120$ (4× model dimension)
\item \textbf{Activation Function}: GELU in feedforward sublayers
\item \textbf{Layer Normalization}: Pre-normalization architecture
\item \textbf{Residual Connections}: Around each sublayer
\end{itemize}
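Under the assumption that the layers follow PyTorch's standard encoder block, this specification corresponds to the following configuration:
\begin{verbatim}
import torch.nn as nn

# One encoder layer with the listed specification; norm_first=True selects
# the pre-normalization variant with residuals around each sublayer.
layer = nn.TransformerEncoderLayer(
    d_model=1280, nhead=8, dim_feedforward=5120,
    dropout=0.1, activation="gelu",
    batch_first=True, norm_first=True,
)
\end{verbatim}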
\subsubsection{Memory and Computational Efficiency}
\label{sec:efficiency}
The compression architecture is optimized for computational efficiency:
\begin{itemize}
\item \textbf{Parameter Count}:
\begin{itemize}
\item Compressor: $\sim$52M parameters
\item Decompressor: $\sim$26M parameters
\item Total: $\sim$78M parameters
\end{itemize}
\item \textbf{Training Memory}: $\sim$12GB GPU memory for batch size 32
\item \textbf{Inference Speed}: $\sim$1000 sequences/second on A100 GPU
\item \textbf{Compression Ratio}: 16× reduction in embedding dimension
\item \textbf{Storage Savings}: 94% reduction in embedding storage requirements
\end{itemize}
\subsection{Performance Metrics and Validation}
\subsubsection{Reconstruction Quality}
\label{sec:reconstruction_quality}
The trained compressor-decompressor achieves high-fidelity reconstruction (the metric computations are sketched after this list):
\begin{itemize}
\item \textbf{MSE Loss}: $< 0.01$ on validation set
\item \textbf{Cosine Similarity}: $> 0.95$ between original and reconstructed embeddings
\item \textbf{Pearson Correlation}: $> 0.98$ across all embedding dimensions
\item \textbf{Max Absolute Error}: $< 0.1$ per embedding component
\end{itemize}
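The metrics above can be computed as in the following sketch; note that the Pearson correlation here is taken globally over all elements as a simplification of the per-dimension statistic.
\begin{verbatim}
import torch
import torch.nn.functional as F

def reconstruction_metrics(h, h_hat):
    # Reconstruction-quality metrics for (B, L, D) embedding tensors.
    mse = F.mse_loss(h_hat, h).item()
    cos = F.cosine_similarity(h_hat.flatten(0, 1),
                              h.flatten(0, 1), dim=-1).mean().item()
    x, y = h.flatten() - h.mean(), h_hat.flatten() - h_hat.mean()
    pearson = (x @ y / (x.norm() * y.norm())).item()
    max_err = (h - h_hat).abs().max().item()
    return {"mse": mse, "cosine": cos, "pearson": pearson, "max_abs_err": max_err}
\end{verbatim}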
\subsubsection{Downstream Task Preservation}
\label{sec:downstream_preservation}
Compressed embeddings maintain performance on downstream tasks:
\begin{itemize}
\item \textbf{AMP Classification}: $< 2\%$ accuracy drop using compressed embeddings
\item \textbf{Secondary Structure}: $< 3\%$ accuracy drop on DSSP prediction
\item \textbf{Contact Prediction}: $< 5\%$ precision drop on contact maps
\item \textbf{Homology Detection}: $< 1\%$ AUC drop on SCOP fold recognition
\end{itemize}
\begin{algorithm}[h]
\caption{Transformer-Based Compressor}
\label{alg:compressor}
\begin{algorithmic}[1]
\REQUIRE Normalized ESM-2 embeddings $\mathbf{H}^{(norm)} \in \mathbb{R}^{B \times L \times 1280}$
\REQUIRE Trained compressor parameters $\theta_{\mathcal{C}}$
\ENSURE Compressed embeddings $\mathbf{Z}^{(comp)} \in \mathbb{R}^{B \times L/2 \times 80}$
\STATE \textbf{// Stage 1: Input Normalization}
\STATE $\mathbf{H}^{(0)} \leftarrow \text{LayerNorm}(\mathbf{H}^{(norm)})$ \COMMENT{Stabilize input distributions}
\STATE \textbf{// Stage 2: Pre-Pooling Transformer Processing}
\FOR{$\ell = 1$ to $2$} \COMMENT{2 pre-pooling transformer layers}
\STATE $\tilde{\mathbf{H}}^{(\ell)} \leftarrow \text{LayerNorm}(\mathbf{H}^{(\ell-1)})$ \COMMENT{Pre-normalization before attention}
\STATE $\mathbf{A}^{(\ell)} \leftarrow \text{MultiHeadAttention}(\tilde{\mathbf{H}}^{(\ell)}, \tilde{\mathbf{H}}^{(\ell)}, \tilde{\mathbf{H}}^{(\ell)})$ \COMMENT{Self-attention}
\STATE $\mathbf{H}^{(\ell)} \leftarrow \mathbf{H}^{(\ell-1)} + \text{Dropout}(\mathbf{A}^{(\ell)})$ \COMMENT{Residual connection}
\STATE $\mathbf{F}^{(\ell)} \leftarrow \text{GELU}(\text{LayerNorm}(\mathbf{H}^{(\ell)}) \mathbf{W}_1^{(\ell)} + \mathbf{b}_1^{(\ell)}) \mathbf{W}_2^{(\ell)} + \mathbf{b}_2^{(\ell)}$ \COMMENT{Pre-normalized FFN}
\STATE $\mathbf{H}^{(\ell)} \leftarrow \mathbf{H}^{(\ell)} + \text{Dropout}(\mathbf{F}^{(\ell)})$ \COMMENT{Residual connection}
\ENDFOR
\STATE $\mathbf{H}^{(pre)} \leftarrow \mathbf{H}^{(2)}$
\STATE \textbf{// Stage 3: Hourglass Pooling}
\IF{$L \bmod 2 = 1$} \COMMENT{Handle odd sequence lengths}
\STATE $\mathbf{H}^{(pre)} \leftarrow \mathbf{H}^{(pre)}[:, :L-1, :]$ \COMMENT{Remove last position}
\STATE $L \leftarrow L - 1$
\ENDIF
\STATE $\mathbf{H}^{(grouped)} \leftarrow \text{Reshape}(\mathbf{H}^{(pre)}, [B, L/2, 2, 1280])$
\STATE $\mathbf{H}^{(pool)} \leftarrow \text{Mean}(\mathbf{H}^{(grouped)}, \text{dim}=2)$ \COMMENT{Average adjacent positions}
\STATE \textbf{// Stage 4: Post-Pooling Transformer Processing}
\STATE $\mathbf{H}^{(2)} \leftarrow \mathbf{H}^{(pool)}$ \COMMENT{Continue from the pooled representation}
\FOR{$\ell = 3$ to $4$} \COMMENT{2 post-pooling transformer layers}
\STATE $\mathbf{H}^{(\ell)} \leftarrow \text{TransformerLayer}(\mathbf{H}^{(\ell-1)})$ \COMMENT{Same pre-normalized block as the pre-pooling layers}
\ENDFOR
\STATE $\mathbf{H}^{(post)} \leftarrow \mathbf{H}^{(4)}$
\STATE \textbf{// Stage 5: Final Projection and Activation}
\STATE $\mathbf{H}^{(proj\_input)} \leftarrow \text{LayerNorm}(\mathbf{H}^{(post)})$
\STATE $\mathbf{Z}^{(comp)} \leftarrow \tanh(\mathbf{H}^{(proj\_input)} \mathbf{W}^{(proj)} + \mathbf{b}^{(proj)})$
\RETURN $\mathbf{Z}^{(comp)}$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{Transformer-Based Decompressor}
\label{alg:decompressor}
\begin{algorithmic}[1]
\REQUIRE Compressed embeddings $\mathbf{Z}^{(comp)} \in \mathbb{R}^{B \times L/2 \times 80}$
\REQUIRE Trained decompressor parameters $\theta_{\mathcal{D}}$
\ENSURE Reconstructed embeddings $\mathbf{H}^{(recon)} \in \mathbb{R}^{B \times L \times 1280}$
\STATE \textbf{// Stage 1: Dimension Expansion}
\STATE $\mathbf{Z}^{(norm)} \leftarrow \text{LayerNorm}(\mathbf{Z}^{(comp)})$ \COMMENT{Normalize compressed input}
\STATE $\mathbf{H}^{(expanded)} \leftarrow \mathbf{Z}^{(norm)} \mathbf{W}^{(expand)} + \mathbf{b}^{(expand)}$ \COMMENT{80 → 1280 dimensions}
\STATE \textbf{// Stage 2: Hourglass Unpooling}
\STATE $\mathbf{H}^{(unpool)} \leftarrow \text{repeat\_interleave}(\mathbf{H}^{(expanded)}, 2, \text{dim}=1)$ \COMMENT{L/2 → L length}
\STATE \textbf{// Equivalent element-wise view of repeat\_interleave}
\FOR{$b = 1$ to $B$} \COMMENT{For each batch}
\FOR{$i = 0$ to $L/2-1$} \COMMENT{For each compressed position}
\STATE $\mathbf{H}^{(unpool)}[b, 2i, :] \leftarrow \mathbf{H}^{(expanded)}[b, i, :]$ \COMMENT{Even positions}
\STATE $\mathbf{H}^{(unpool)}[b, 2i+1, :] \leftarrow \mathbf{H}^{(expanded)}[b, i, :]$ \COMMENT{Odd positions}
\ENDFOR
\ENDFOR
\STATE \textbf{// Stage 3: Transformer Refinement}
\STATE $\mathbf{H}^{(0)} \leftarrow \mathbf{H}^{(unpool)}$
\FOR{$\ell = 1$ to $2$} \COMMENT{2 refinement transformer layers}
\STATE $\tilde{\mathbf{H}}^{(\ell)} \leftarrow \text{LayerNorm}(\mathbf{H}^{(\ell-1)})$ \COMMENT{Pre-normalization before attention}
\STATE $\mathbf{A}^{(\ell)} \leftarrow \text{MultiHeadAttention}(\tilde{\mathbf{H}}^{(\ell)}, \tilde{\mathbf{H}}^{(\ell)}, \tilde{\mathbf{H}}^{(\ell)})$ \COMMENT{Self-attention}
\STATE $\mathbf{H}^{(\ell)} \leftarrow \mathbf{H}^{(\ell-1)} + \text{Dropout}(\mathbf{A}^{(\ell)})$ \COMMENT{Residual connection}
\STATE $\mathbf{F}^{(\ell)} \leftarrow \text{GELU}(\text{LayerNorm}(\mathbf{H}^{(\ell)}) \mathbf{W}_1^{(\ell)} + \mathbf{b}_1^{(\ell)}) \mathbf{W}_2^{(\ell)} + \mathbf{b}_2^{(\ell)}$ \COMMENT{Pre-normalized FFN}
\STATE $\mathbf{H}^{(\ell)} \leftarrow \mathbf{H}^{(\ell)} + \text{Dropout}(\mathbf{F}^{(\ell)})$ \COMMENT{Residual connection}
\ENDFOR
\STATE $\mathbf{H}^{(recon)} \leftarrow \mathbf{H}^{(2)}$ \COMMENT{Final reconstructed embeddings}
\RETURN $\mathbf{H}^{(recon)}$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{Joint Compressor-Decompressor Training}
\label{alg:joint_training}
\begin{algorithmic}[1]
\REQUIRE Training dataset $\mathcal{D} = \{\mathbf{H}_1^{(norm)}, \ldots, \mathbf{H}_N^{(norm)}\}$
\REQUIRE Hyperparameters: $\text{lr}_{\max}, \text{lr}_{\min}, T_{\text{warmup}}, T_{\text{total}}$
\ENSURE Trained compressor $\mathcal{C}(\cdot; \theta_{\mathcal{C}}^*)$ and decompressor $\mathcal{D}(\cdot; \theta_{\mathcal{D}}^*)$
\STATE \textbf{// Initialize models and optimizer}
\STATE $\theta_{\mathcal{C}}, \theta_{\mathcal{D}} \leftarrow \text{InitializeParameters}()$
\STATE $\text{optimizer} \leftarrow \text{AdamW}(\{\theta_{\mathcal{C}}, \theta_{\mathcal{D}}\}, \text{lr}=\text{lr}_{\max}, \text{weight\_decay}=10^{-4})$
\STATE \textbf{// Setup learning rate schedulers}
\STATE $\text{warmup\_sched} \leftarrow \text{LinearLR}(\text{start\_factor}=10^{-8}, \text{end\_factor}=1.0, \text{total\_iters}=T_{\text{warmup}})$
\STATE $\text{cosine\_sched} \leftarrow \text{CosineAnnealingLR}(T_{\max}=T_{\text{total}} - T_{\text{warmup}}, \eta_{\min}=\text{lr}_{\min})$ \COMMENT{Matches Eq.~\eqref{eq:cosine_lr}}
\STATE $\text{scheduler} \leftarrow \text{SequentialLR}([\text{warmup\_sched}, \text{cosine\_sched}], [T_{\text{warmup}}])$
\FOR{$\text{epoch} = 1$ to $\text{EPOCHS}$}
\STATE $\text{total\_loss} \leftarrow 0$
\FOR{$\mathbf{H}^{(batch)} \in \text{DataLoader}(\mathcal{D}, \text{batch\_size}=32, \text{shuffle}=\text{True})$}
\STATE \textbf{// Forward pass through compressor-decompressor}
\STATE $\mathbf{Z}^{(comp)} \leftarrow \mathcal{C}(\mathbf{H}^{(batch)}; \theta_{\mathcal{C}})$ \COMMENT{Compress}
\STATE $\mathbf{H}^{(recon)} \leftarrow \mathcal{D}(\mathbf{Z}^{(comp)}; \theta_{\mathcal{D}})$ \COMMENT{Decompress}
\STATE \textbf{// Compute reconstruction loss}
\STATE $\mathcal{L} \leftarrow \|\mathbf{H}^{(batch)} - \mathbf{H}^{(recon)}\|_2^2 / |\mathbf{H}^{(batch)}|$ \COMMENT{MSE loss}
\STATE \textbf{// Backward pass and optimization}
\STATE $\text{optimizer.zero\_grad()}$
\STATE $\mathcal{L}.\text{backward()}$
\STATE $\text{torch.nn.utils.clip\_grad\_norm\_}(\{\theta_{\mathcal{C}}, \theta_{\mathcal{D}}\}, \text{max\_norm}=1.0)$
\STATE $\text{optimizer.step()}$
\STATE $\text{scheduler.step()}$
\STATE $\text{total\_loss} \leftarrow \text{total\_loss} + \mathcal{L}.\text{item()}$
\ENDFOR
\STATE $\text{avg\_loss} \leftarrow \text{total\_loss} / |\text{DataLoader}|$
\STATE \textbf{print} ``Epoch \{epoch\}: Average MSE = \{avg\_loss\}''
\IF{$\text{epoch} \bmod 5 = 0$} \COMMENT{Save checkpoint every 5 epochs}
\STATE $\text{SaveCheckpoint}(\theta_{\mathcal{C}}, \theta_{\mathcal{D}}, \text{optimizer}, \text{avg\_loss}, \text{epoch})$
\ENDIF
\ENDFOR
\STATE \textbf{// Save final trained models}
\STATE $\text{SaveModel}(\theta_{\mathcal{C}}, \text{"final\_compressor\_model.pth"})$
\STATE $\text{SaveModel}(\theta_{\mathcal{D}}, \text{"final\_decompressor\_model.pth"})$
\RETURN $\theta_{\mathcal{C}}^*, \theta_{\mathcal{D}}^*$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{Hourglass Pooling and Unpooling Operations}
\label{alg:hourglass_operations}
\begin{algorithmic}[1]
\REQUIRE Input tensor $\mathbf{X} \in \mathbb{R}^{B \times L \times D}$
\ENSURE Pooled tensor $\mathbf{X}^{(pool)} \in \mathbb{R}^{B \times L/2 \times D}$ and unpooled tensor $\mathbf{X}^{(unpool)} \in \mathbb{R}^{B \times L \times D}$
\STATE \textbf{// Hourglass Pooling Operation}
\FUNCTION{HourglassPool}{$\mathbf{X}$}
\STATE $B, L, D \leftarrow \mathbf{X}.\text{shape}$
\IF{$L \bmod 2 = 1$} \COMMENT{Handle odd sequence lengths}
\STATE $\mathbf{X} \leftarrow \mathbf{X}[:, :L-1, :]$ \COMMENT{Remove last position}
\STATE $L \leftarrow L - 1$
\ENDIF
\STATE $\mathbf{X}^{(grouped)} \leftarrow \text{Reshape}(\mathbf{X}, [B, L/2, 2, D])$ \COMMENT{Group adjacent positions}
\STATE $\mathbf{X}^{(pool)} \leftarrow \text{Mean}(\mathbf{X}^{(grouped)}, \text{dim}=2)$ \COMMENT{Average grouped positions}
\RETURN $\mathbf{X}^{(pool)}$
\ENDFUNCTION
\STATE \textbf{// Hourglass Unpooling Operation}
\FUNCTION{HourglassUnpool}{$\mathbf{X}^{(pool)}$}
\STATE $B, L_{pool}, D \leftarrow \mathbf{X}^{(pool)}.\text{shape}$
\STATE $L \leftarrow 2 \times L_{pool}$ \COMMENT{Double the sequence length}
\STATE $\mathbf{X}^{(unpool)} \leftarrow \text{repeat\_interleave}(\mathbf{X}^{(pool)}, 2, \text{dim}=1)$
\STATE \textbf{// Verify unpooling correctness}
\FOR{$b = 1$ to $B$}
\FOR{$i = 0$ to $L_{pool}-1$}
\STATE \textbf{assert} $\mathbf{X}^{(unpool)}[b, 2i, :] = \mathbf{X}^{(pool)}[b, i, :]$
\STATE \textbf{assert} $\mathbf{X}^{(unpool)}[b, 2i+1, :] = \mathbf{X}^{(pool)}[b, i, :]$
\ENDFOR
\ENDFOR
\RETURN $\mathbf{X}^{(unpool)}$
\ENDFUNCTION
\STATE \textbf{// Round-trip example (pooling is lossy, so exact inversion is not expected)}
\STATE $\mathbf{X}^{(pool)} \leftarrow \text{HourglassPool}(\mathbf{X})$
\STATE $\mathbf{X}^{(unpool)} \leftarrow \text{HourglassUnpool}(\mathbf{X}^{(pool)})$
\STATE \textbf{// Note: $\mathbf{X}^{(unpool)} \neq \mathbf{X}$ due to information loss in pooling}
\STATE \textbf{// But spatial structure is preserved through duplication}
\RETURN $\mathbf{X}^{(pool)}, \mathbf{X}^{(unpool)}$
\end{algorithmic}
\end{algorithm}