esmfold2-trimul-kernel
Fused inference Triton kernel for ESMFold2's triangle multiplication — the
O(N³) hotspot of the folding trunk (124 sites in the full model). Packaged as a
kernels-library Hub kernel so
transformers can load it on demand; the pure-PyTorch block in transformers is the
fallback.
What it fuses
One Triton kernel for the whole TriangleMultiplicativeBlock:
norm_start → gated dual-GEMM (sigmoid(x@Wg)·(x@Wp)) → triangular einsum (bikd,bjkd→bijd) → norm_mix → proj_emit → output gate, with the delta intermediate
never written to HBM. bf16 in/out, fp32 accumulation. Forward + backward present;
registered inference-only in transformers.
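As a reading aid, the fused pipeline above can be written out unfused. This is a minimal NumPy sketch under stated assumptions, not the kernel or the in-tree module: the weight names and shapes are hypothetical, the gate/projection pairs are shown as separate matrices rather than one bundled projection, LayerNorm omits its learned scale and shift, the output gate is assumed to read the normalized input, and the visibility mask is left out.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the channel axis, without learned scale/shift (simplification).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def trimul_reference(pair, w_gate_a, w_proj_a, w_gate_b, w_proj_b, w_emit, w_out_gate):
    """Unfused reference for the stages listed above; pair has shape [B, N, N, C]."""
    x = layer_norm(pair)                         # norm_start
    a = sigmoid(x @ w_gate_a) * (x @ w_proj_a)   # gated dual GEMM, left operand
    b = sigmoid(x @ w_gate_b) * (x @ w_proj_b)   # gated dual GEMM, right operand
    mixed = np.einsum("bikd,bjkd->bijd", a, b)   # triangular einsum (contracts over k)
    y = layer_norm(mixed) @ w_emit               # norm_mix, then proj_emit
    return sigmoid(x @ w_out_gate) * y           # output gate (assumed input: x)

# Hypothetical sizes, chosen only to keep the example small.
rng = np.random.default_rng(0)
B, N, C, D = 1, 6, 8, 4
pair = rng.standard_normal((B, N, N, C)).astype(np.float32)
shapes = [(C, D)] * 4 + [(D, C), (C, C)]
weights = [rng.standard_normal(s).astype(np.float32) for s in shapes]
delta = trimul_reference(pair, *weights)
print(delta.shape)  # (1, 6, 6, 8)
```

In this unfused form every intermediate (a, b, mixed, y) is a full array in memory; the fused kernel's stated benefit is avoiding that round trip for the delta intermediate.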
How transformers uses it
TriangleMultiplicativeBlock is decorated with @use_kernel_forward_from_hub("ESMFold2TriangleMultiplication") and mapped to this repo in
integrations/hub_kernels.py (cuda, Mode.INFERENCE). The layer here
(ESMFold2TriangleMultiplication) reimplements that module's forward(pair_grid, visibility), reading its parameters (norm_start, norm_mix, proj_bundle,
proj_emit, proj_gate). Keep this layer in sync with the in-tree module's
attribute names and forward signature: that is the contract.
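import torch
from transformers import ESMFold2Model

# Example input: human ubiquitin (76 residues); substitute any amino-acid sequence.
seq = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"

# use_kernels=True swaps in this kernel for the 124 trimul sites (CUDA + inference).
model = ESMFold2Model.from_pretrained(
    "biohub/ESMFold2", dtype=torch.bfloat16, device_map="cuda", use_kernels=True
).eval()
out = model.infer_protein(seq)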
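The contract described above (attribute names plus the forward signature) is the kind of thing a small test can guard. The sketch below is a hypothetical helper, not part of transformers or kernels; the expected names are the ones listed in this README, and the stub object stands in for a real TriangleMultiplicativeBlock.

```python
import inspect
from types import SimpleNamespace

# Names taken from this README; the helper itself is hypothetical.
EXPECTED_ATTRS = ("norm_start", "norm_mix", "proj_bundle", "proj_emit", "proj_gate")
EXPECTED_FORWARD_ARGS = ("pair_grid", "visibility")

def contract_problems(module):
    """Return a list of human-readable mismatches between `module` and the contract."""
    problems = [f"missing attribute: {n}" for n in EXPECTED_ATTRS if not hasattr(module, n)]
    forward = getattr(module, "forward", None)
    if forward is None:
        problems.append("missing forward()")
    else:
        # For a bound method, inspect.signature already excludes `self`.
        args = tuple(inspect.signature(forward).parameters)
        if args != EXPECTED_FORWARD_ARGS:
            problems.append(f"forward args {args} != {EXPECTED_FORWARD_ARGS}")
    return problems

# Stand-in for the in-tree module, used only to exercise the check.
stub = SimpleNamespace(
    **{name: object() for name in EXPECTED_ATTRS},
    forward=lambda pair_grid, visibility: None,
)
print(contract_problems(stub))  # []
```

Running such a check against the real in-tree block whenever either side changes would surface a renamed parameter before it turns into a silent fallback or a load-time error.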
Layout
The package name must match kernels' repo-derived name
(repo_id.split("/")[-1].replace("-", "_")), i.e. esmfold2_trimul_kernel for the
repo …/esmfold2-trimul-kernel, and build.toml's [general] name must match it too.
kernels.get_kernel loads from build/torch-universal/ (not torch-ext/).
build.toml                                      # kernel-builder config (universal/Triton)
flake.nix                                       # kernel-builder entry (verify vs current version)
torch-ext/esmfold2_trimul_kernel/               # source (read by kernel-builder)
    __init__.py                                 # exports the layer + the functional entry
    layers.py                                   # ESMFold2TriangleMultiplication (the Hub layer)
    trimul_with_residual.py                     # kernel entrypoint
    fused_dual_gemm.py                          # helper: gated dual GEMM
    fused_ln_residual.py                        # helper: LN + transpose / residual-link epilogues
    trimul_einsum_triton.py                     # helper: batched triangular einsum
build/torch-universal/esmfold2_trimul_kernel/   # loaded by kernels.get_kernel (same files)
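The naming rule quoted above is easy to get wrong when a repo is moved or renamed, so it can help to spell it out. This sketch simply wraps the expression from this README in a hypothetical helper; it is not the kernels library's own code.

```python
def package_name_from_repo_id(repo_id: str) -> str:
    """Apply the rule quoted above: last path component, dashes to underscores."""
    return repo_id.split("/")[-1].replace("-", "_")

print(package_name_from_repo_id("Rocketknight1/esmfold2-trimul-kernel"))
# esmfold2_trimul_kernel
```

Because only the last path component is used, moving the repo to another org leaves the package name unchanged; renaming the repo itself does not.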
Build & publish
The build/torch-universal/ dir checked in here is a hand-built universal layout
(the Triton package copied in — no compile step), which is sufficient for
kernels.get_kernel. To regenerate it properly with kernel-builder:
nix build .#bundle # or: kernel-builder build (see kernel-builder docs)
Mapped from transformers in integrations/hub_kernels.py under the layer name
ESMFold2TriangleMultiplication. Currently repo_id = Rocketknight1/esmfold2-trimul-kernel
(testing); move it to a kernels-community org and update repo_id before merging.
Validation
Swapped into all 124 TriangleMultiplicativeBlock instances of the real model
(biohub/ESMFold2, bf16, GPU), folds match the pure-PyTorch fallback within the
model's own non-determinism: ubiquitin 0.801 vs 0.799 pLDDT (Δ +0.002), GB1 0.849 vs
0.849, pTM identical. Standalone microbench (dim=128, B=1): 5–37× over the chunked
fp32 fallback, gap growing with N (torch.compile of the fallback only reaches
~1–7×).
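For a sense of scale in the microbench setting (dim=128, B=1), a back-of-envelope operation count for the einsum bikd,bjkd→bijd is sketched below. It counts each multiply-add as 2 FLOPs and covers only the einsum, not the GEMMs, norms, or memory traffic, so it does not by itself account for the measured speedups; the sequence lengths are illustrative, not the ones benchmarked.

```python
def einsum_flops(batch: int, n: int, dim: int) -> int:
    """Approximate FLOPs of bikd,bjkd->bijd: one length-N dot product per (b, i, j, d)."""
    return 2 * batch * n * n * n * dim

# Illustrative sequence lengths (not the benchmarked ones), with B=1 and dim=128.
for n in (128, 256, 512):
    print(f"N={n}: {einsum_flops(1, n, 128) / 1e9:.2f} GFLOP")
# N=128: 0.54 GFLOP
# N=256: 4.29 GFLOP
# N=512: 34.36 GFLOP
```

Doubling N multiplies this term by 8, while the [B,N,N,C] tensors around it grow only 4×.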
Follow-ups
- Residual-optional entry. The in-tree boundary is delta-only, so the layer passes residual=zeros_like(pair), costing one [B,N,N,C] alloc+read per call. Adding a residual=None fast path (skip the in-kernel residual add) recovers that.
- cuequivariance provides the same op (triangle_multiplicative_update) as an alternative backend if a vendored-Triton kernel is undesirable.
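To put a rough number on the zeros_like(pair) cost in the first follow-up, the sketch below sizes one [B,N,N,C] bf16 buffer. The helper is hypothetical, and C=128 and the sequence lengths are assumptions chosen for illustration.

```python
def residual_buffer_bytes(batch: int, n: int, channels: int, bytes_per_elem: int = 2) -> int:
    """Size of one [B, N, N, C] tensor; 2 bytes per element corresponds to bf16."""
    return batch * n * n * channels * bytes_per_elem

# Assumed sizes: B=1, C=128; sequence lengths are illustrative.
for n in (128, 256, 512):
    mib = residual_buffer_bytes(1, n, 128) / 2**20
    print(f"N={n}: {mib:.0f} MiB")
# N=128: 4 MiB
# N=256: 16 MiB
# N=512: 64 MiB
```

Under these assumptions, that is the amount allocated and read once per call at each of the 124 sites that a residual=None path would avoid.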