
esmfold2-trimul-kernel

Fused Triton kernel for inference of ESMFold2's triangle multiplication, the O(N³) hotspot of the folding trunk (124 call sites in the full model). It is packaged as a Hub kernel for the kernels library so that transformers can load it on demand; the pure-PyTorch block in transformers remains the fallback.

What it fuses

One Triton kernel covers the whole TriangleMultiplicativeBlock: norm_start → gated dual GEMM (sigmoid(x@Wg)·(x@Wp)) → triangular einsum (bikd,bjkd→bijd) → norm_mix → proj_emit → output gate, and the delta intermediate is never written to HBM. Inputs and outputs are bf16, with fp32 accumulation. Both forward and backward are implemented, but transformers registers the kernel for inference only.
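To make the fused pipeline concrete, here is an unfused pure-PyTorch sketch of the same sequence of steps. It is not the in-tree module: the class name ReferenceTriMul is hypothetical, and so is the assumption that proj_bundle packs the four gate/projection weights into one Linear(dim, 4 * hidden) split as [a_gate, a_proj, b_gate, b_proj]. The visibility argument is assumed to be a [B, N, N] pair mask.

```python
import torch
import torch.nn as nn


class ReferenceTriMul(nn.Module):
    """Unfused sketch of the fused pipeline (hypothetical weight layout).

    Assumes proj_bundle packs [a_gate, a_proj, b_gate, b_proj] into a single
    Linear(dim, 4 * hidden); the real module may lay its weights out differently.
    """

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.norm_start = nn.LayerNorm(dim)
        self.proj_bundle = nn.Linear(dim, 4 * hidden)
        self.norm_mix = nn.LayerNorm(hidden)
        self.proj_emit = nn.Linear(hidden, dim)
        self.proj_gate = nn.Linear(dim, dim)

    def forward(self, pair_grid: torch.Tensor, visibility: torch.Tensor) -> torch.Tensor:
        # pair_grid: [B, N, N, C]; visibility: [B, N, N] (1 = valid pair, 0 = masked)
        x = self.norm_start(pair_grid)
        a_gate, a_proj, b_gate, b_proj = self.proj_bundle(x).chunk(4, dim=-1)
        mask = visibility.unsqueeze(-1).to(x.dtype)
        # Gated dual GEMM: sigmoid(x @ Wg) * (x @ Wp), zeroed on masked pairs
        left = torch.sigmoid(a_gate) * a_proj * mask
        right = torch.sigmoid(b_gate) * b_proj * mask
        # Triangular einsum: mixed[b,i,j,d] = sum_k left[b,i,k,d] * right[b,j,k,d] (O(N^3))
        mixed = torch.einsum("bikd,bjkd->bijd", left, right)
        # norm_mix -> proj_emit, then the output gate computed from the normalized input
        delta = self.proj_emit(self.norm_mix(mixed))
        return torch.sigmoid(self.proj_gate(x)) * delta
```

The sketch returns only the delta, matching the delta-only in-tree boundary; the fused kernel computes the same chain without materializing mixed or delta in HBM.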

How transformers uses it

TriangleMultiplicativeBlock is decorated with @use_kernel_forward_from_hub("ESMFold2TriangleMultiplication") and mapped to this repo in integrations/hub_kernels.py (cuda, Mode.INFERENCE). The layer in this repo (ESMFold2TriangleMultiplication) reimplements that module's forward(pair_grid, visibility) and reads the module's own parameters (norm_start, norm_mix, proj_bundle, proj_emit, proj_gate). That is the contract: keep this layer in sync with the in-tree module's attribute names and forward signature.

import torch
from transformers import ESMFold2Model

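# use_kernels=True swaps in this kernel at all 124 trimul sites (CUDA + inference only).
model = ESMFold2Model.from_pretrained(
    "biohub/ESMFold2", dtype=torch.bfloat16, device_map="cuda", use_kernels=True
).eval()

# Example input: human ubiquitin (76 residues).
seq = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
out = model.infer_protein(seq)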
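Because the contract is just "same attribute names, same forward signature", drift can be caught with a small standard-library check, for example in a unit test run against the in-tree module. This is a minimal sketch; contract_violations and the two constants are hypothetical names, populated from the attribute and parameter names listed above.

```python
import inspect

# Names the Hub layer reads from the in-tree module (taken from this README).
REQUIRED_ATTRS = ("norm_start", "norm_mix", "proj_bundle", "proj_emit", "proj_gate")
REQUIRED_FORWARD_PARAMS = ("pair_grid", "visibility")


def contract_violations(module) -> list:
    """Return human-readable contract violations for `module` (empty list if none)."""
    problems = [f"missing attribute: {name}" for name in REQUIRED_ATTRS
                if not hasattr(module, name)]
    # Signature of the bound method, so `self` is already excluded.
    params = list(inspect.signature(module.forward).parameters)
    if params[: len(REQUIRED_FORWARD_PARAMS)] != list(REQUIRED_FORWARD_PARAMS):
        problems.append(
            f"forward signature is {params}, expected to start with {list(REQUIRED_FORWARD_PARAMS)}"
        )
    return problems
```

hasattr also works for registered submodules of an nn.Module, since those are resolved through Module.__getattr__.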

Layout

The package name must match the name kernels derives from the repo id (repo_id.split("/")[-1].replace("-", "_")), i.e. esmfold2_trimul_kernel for the repo …/esmfold2-trimul-kernel, and the name under [general] in build.toml must match it too. kernels.get_kernel loads from build/torch-universal/, not from torch-ext/.

build.toml                                       # kernel-builder config (universal/Triton)
flake.nix                                        # kernel-builder entry (verify vs current version)
torch-ext/esmfold2_trimul_kernel/                # source (read by kernel-builder)
  __init__.py                                    # exports the layer + the functional entry
  layers.py                                      # ESMFold2TriangleMultiplication (the Hub layer)
  trimul_with_residual.py                        # kernel entrypoint
  fused_dual_gemm.py                             # helper: gated dual GEMM
  fused_ln_residual.py                           # helper: LN + transpose / residual-link epilogues
  trimul_einsum_triton.py                        # helper: batched triangular einsum
build/torch-universal/esmfold2_trimul_kernel/    # loaded by kernels.get_kernel (same files)

Build & publish

The build/torch-universal/ directory checked in here is a hand-built universal layout (the Triton package copied in as-is, with no compile step), which is sufficient for kernels.get_kernel. To regenerate it properly with kernel-builder:

nix build .#bundle      # or: kernel-builder build  (see kernel-builder docs)
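Until kernel-builder is wired up, the hand-built copy can be refreshed and checked for staleness with the standard library alone. This is a sketch of that plain-copy approach, not what kernel-builder does; sync_universal_build and universal_build_is_stale are hypothetical names, and it assumes the package directory is flat, as in the layout above.

```python
import filecmp
import shutil
from pathlib import Path

PKG = "esmfold2_trimul_kernel"


def sync_universal_build(repo_root: Path) -> Path:
    """Replace build/torch-universal/<pkg>/ with a plain copy of torch-ext/<pkg>/."""
    src = repo_root / "torch-ext" / PKG
    dst = repo_root / "build" / "torch-universal" / PKG
    if dst.exists():
        shutil.rmtree(dst)
    # copytree uses copy2, which also preserves file timestamps.
    shutil.copytree(src, dst, ignore=shutil.ignore_patterns("__pycache__", "*.pyc"))
    return dst


def universal_build_is_stale(repo_root: Path) -> bool:
    """True if the checked-in universal copy differs from the source package (top level only)."""
    cmp = filecmp.dircmp(
        repo_root / "torch-ext" / PKG,
        repo_root / "build" / "torch-universal" / PKG,
        ignore=["__pycache__"],
    )
    return bool(cmp.left_only or cmp.right_only or cmp.diff_files)
```

A CI job could call universal_build_is_stale on the repo root and fail when it returns True.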

transformers maps this repo in integrations/hub_kernels.py under the layer name ESMFold2TriangleMultiplication. The repo_id is currently Rocketknight1/esmfold2-trimul-kernel (for testing); before merging, move the repo to a kernels-community org and update repo_id.
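For orientation, the mapping entry presumably looks roughly like the following, modeled on the other Mode-keyed entries in hub_kernels.py; check the current file before copying. The variable name ESMFOLD2_TRIMUL_MAPPING is hypothetical (in transformers the entry sits inside the file's existing mapping dict), and only repo_id should need to change when the repo moves.

```python
from kernels import LayerRepository, Mode

# Hypothetical standalone copy of the entry; in transformers it lives in the
# kernel mapping inside integrations/hub_kernels.py.
ESMFOLD2_TRIMUL_MAPPING = {
    # Key: the name passed to @use_kernel_forward_from_hub on the in-tree block.
    "ESMFold2TriangleMultiplication": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                # Testing repo; replace with the kernels-community repo id before merging.
                repo_id="Rocketknight1/esmfold2-trimul-kernel",
                # The layer class exported by this repo's layers.py.
                layer_name="ESMFold2TriangleMultiplication",
            )
        }
    }
}
```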

Validation

With the kernel swapped into all 124 TriangleMultiplicativeBlock instances of the real model (biohub/ESMFold2, bf16, GPU), folds match the pure-PyTorch fallback within the model's own non-determinism: ubiquitin pLDDT 0.801 vs 0.799 (Δ +0.002), GB1 0.849 vs 0.849, and identical pTM. In a standalone microbenchmark (dim=128, B=1) the kernel is 5–37× faster than the chunked fp32 fallback, with the gap growing with N; torch.compile of the fallback reaches only ~1–7× on the same benchmark.
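For anyone re-running the standalone comparison, a minimal timing harness could look like the sketch below. It is not the script behind the numbers above; median_ms and speedup are hypothetical names. The key detail for GPU work is passing a synchronize callable, since CUDA launches are asynchronous.

```python
import statistics
import time
from typing import Callable, Optional


def median_ms(fn: Callable[[], object], *, warmup: int = 3, iters: int = 20,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in milliseconds.

    For GPU work pass sync=torch.cuda.synchronize; without it only the
    asynchronous launch is timed, not the kernel itself.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # drain warmup work before the timed runs
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def speedup(baseline: Callable[[], object], candidate: Callable[[], object], **kw) -> float:
    """Ratio of median times; 5.0 means candidate is 5x faster than baseline."""
    return median_ms(baseline, **kw) / median_ms(candidate, **kw)
```

Calling speedup with the fallback as baseline and the kernel as candidate at several sequence lengths N (with sync=torch.cuda.synchronize) should show whether the gap grows with N. The median is used so that a few slow outlier runs do not skew the ratio.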

Follow-ups

  • Residual-optional entry. The in-tree boundary is delta-only, so the layer passes residual=zeros_like(pair), which costs one [B,N,N,C] allocation plus a read per call. Adding a residual=None fast path that skips the in-kernel residual add would recover that.
  • cuequivariance provides the same op (triangle_multiplicative_update), which could serve as an alternative backend if a vendored Triton kernel is undesirable.
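The residual-optional fast path amounts to the semantics below. Inside the Triton kernel this would typically be a tl.constexpr flag (a hypothetical HAS_RESIDUAL, say) so that the residual load and add are compiled out entirely; the PyTorch sketch only shows the intended behavior. gated_output is a hypothetical name, not the entry point in trimul_with_residual.py.

```python
from typing import Optional

import torch


def gated_output(delta: torch.Tensor, gate_logits: torch.Tensor,
                 residual: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical epilogue: output gate, then an optional residual add.

    residual=None skips the add, so no [B, N, N, C] zeros tensor has to be
    allocated and read; passing zeros_like(delta) gives the same values.
    """
    out = torch.sigmoid(gate_logits) * delta
    if residual is None:
        return out
    return out + residual
```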