# /// script
# dependencies = [
# "torch",
# "numpy",
# "kernels",
# ]
# ///
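# Benchmark: compare the fused `flex.triton_hierarchical_sae_loss` Triton
# kernel against a torch.compile'd PyTorch reference of the hierarchical
# SAE loss. Thanks to the inline script metadata above (PEP 723), the script
# can be executed directly with a metadata-aware runner, e.g.
# `uv run benchmark.py` (filename assumed here).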
import torch
import numpy as np
from kernels import get_kernel
flex = get_kernel("t-tech/flex-sae")  # fused Triton kernels
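# `get_kernel` loads a pre-built kernel from the Hugging Face Hub, so no local
# Triton/CUDA build step is needed (assuming the `kernels` package resolves
# "t-tech/flex-sae" as a Hub kernel repository).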
@torch.compile(fullgraph=True)
def hierarchical_sae_loss(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,   # [F, D]
    vals: torch.Tensor,     # [B, K]
    bias: torch.Tensor,     # [D]
    target: torch.Tensor,   # [B, D]
) -> torch.Tensor:
    # Gather the decoder rows of the active latents for each sample.
    emb = weight[indices].to(torch.float32)  # [B, K, D]
    # Prefix sum over the K active latents: reconstruction after 1, 2, ..., K codes.
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)  # [B, K, D]
    # Mean squared error of every prefix reconstruction against the target.
    diff = recon_cum - target.to(torch.float32).unsqueeze(1)
    loss = diff.pow(2).mean()
    return loss
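# Rough cost of the naive path, for intuition: `recon_cum` materializes a
# [B, K, D] float32 tensor — with the sizes below that is
# 2048 * 256 * 1024 * 4 bytes = 2 GiB — plus a same-sized buffer in the
# backward pass. A fused kernel can avoid materializing the prefix
# reconstructions, which is presumably where the speedup comes from.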
B = 2048        # batch size
K = 256         # active (top-k) latents per sample
F = 1024 * 128  # dictionary size (number of SAE latents)
D = 1024        # model hidden dimension
WARMUP = 5
NUM_ITER = 100
dtype = torch.float32

vals = None
decoder = None
bias = None
target = None
indices = None
def init_parameters():
    """Draw a fresh random problem instance (codes, decoder, bias, target)."""
    global vals, decoder, bias, target, indices
    vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
    decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(B, D, dtype=dtype, device="cuda")
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
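# Note: `.abs()` makes the activations nonnegative (as SAE codes typically
# are), and `requires_grad_()` is applied after it, so `vals` is the leaf
# tensor that receives gradients.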
timing_kernel = []
timing_vanilla = []
torch.cuda.reset_peak_memory_stats()
loss_kernel_list = torch.zeros((NUM_ITER,))
loss_vanilla_list = torch.zeros((NUM_ITER,))
def zero_grad():
    vals.grad = None
    decoder.grad = None
    bias.grad = None
    torch.cuda.empty_cache()
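# Gradients are cleared and the CUDA caching allocator drained between the two
# timed paths, so each backward starts from `.grad is None` and neither path
# is measured against the other's leftover allocations.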
for i in range(NUM_ITER + WARMUP):
    init_parameters()
    start_kernel = torch.cuda.Event(enable_timing=True)
    end_kernel = torch.cuda.Event(enable_timing=True)
    start_vanilla = torch.cuda.Event(enable_timing=True)
    end_vanilla = torch.cuda.Event(enable_timing=True)

    # Time the fused Triton kernel (forward + backward).
    start_kernel.record()
    loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_kernel.backward()
    end_kernel.record()

    zero_grad()

    # Time the torch.compile reference (forward + backward).
    start_vanilla.record()
    loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_vanilla.backward()
    end_vanilla.record()

    if i >= WARMUP:
        torch.cuda.synchronize()
        timing_kernel.append(start_kernel.elapsed_time(end_kernel))
        timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
        loss_kernel_list[i - WARMUP] = loss_kernel.detach()
        loss_vanilla_list[i - WARMUP] = loss_vanilla.detach()

    zero_grad()
# Only the final iteration's losses are compared here; the full per-iteration
# lists are kept above if a stricter elementwise check is wanted.
if torch.allclose(loss_kernel, loss_vanilla):
    print("✅ Outputs are close! Everything is good! 🎉")
else:
    print("❌ Outputs mismatch... ⚠️🤔")

print(f"🟦 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
print(f"🟥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")