Kernels
sae
File size: 2,991 Bytes
a262a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a38f7ad
 
a262a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a38f7ad
a262a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a38f7ad
a262a48
 
 
2b62ea9
 
a262a48
 
 
 
 
 
 
 
 
 
a38f7ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# /// script
# dependencies = [
#   "torch",
#   "numpy",
#   "kernels",
# ]
# ///

import torch
import numpy as np
from kernels import get_kernel

# Load the "t-tech/flex-sae" kernel bundle via kernels.get_kernel (Fast Kernels);
# the benchmark below calls its triton_hierarchical_sae_loss.
flex = get_kernel("t-tech/flex-sae")  #Fast Kernels

@torch.compile(fullgraph=True)
def hierarchical_sae_loss(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,  # [F, D]
    vals: torch.Tensor,  # [B, K]
    bias: torch.Tensor,  # [D]
    target: torch.Tensor,  # [B, D]
) -> torch.Tensor:
    """Reference hierarchical SAE reconstruction loss (torch.compile baseline).

    For every row ``b`` and every prefix length ``k = 1..K`` the partial
    reconstruction ``bias + sum_{j < k} vals[b, j] * weight[indices[b, j]]``
    is compared with ``target[b]``. The loss is the mean squared error over
    all ``B * K * D`` (row, prefix, feature) entries, so every prefix length
    contributes equally.

    Args:
        indices: Integer row indices into ``weight``, shape ``[B, K]``.
        weight: Decoder matrix, shape ``[F, D]``.
        vals: Coefficient for each selected row, shape ``[B, K]``.
        bias: Bias added to every reconstruction, shape ``[D]``.
        target: Vectors to reconstruct, shape ``[B, D]``.

    Returns:
        A 0-d float32 tensor holding the mean squared error.
    """
    # Selected decoder rows for every (b, k): shape [B, K, D].
    emb = weight[indices].to(torch.float32)
    # cumsum over K: slot k holds bias + the first k + 1 weighted rows, i.e.
    # all prefix reconstructions at once; shape [B, K, D].
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
    # The explicit cast is a no-op for float32 inputs but pins the residual to
    # float32 if `vals` has a wider dtype (type promotion would carry it through).
    diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
    return diff.pow(2).mean()


# Benchmark problem size (as used by init_parameters and the loss):
#   B - rows of vals / indices / target
#   K - selected decoder rows per B-row
#   F - number of decoder rows
#   D - width of each decoder row and target vector
B = 2048
K = 256
F = 128 * 1024
D = 1024

# Iteration schedule: WARMUP untimed iterations, then NUM_ITER timed ones.
WARMUP, NUM_ITER = 5, 100
dtype = torch.float32  # dtype of every floating-point benchmark tensor

# Module-level benchmark tensors; (re)bound by init_parameters().
vals = decoder = bias = target = indices = None


def init_parameters():
    """Draw a fresh set of random benchmark inputs on the GPU.

    Rebinds the module-level ``vals``, ``decoder``, ``bias``, ``target`` and
    ``indices``. ``vals``, ``decoder`` and ``bias`` are leaf tensors that
    require grad, and ``vals`` is made non-negative with ``abs()``.
    ``indices`` holds int64 row indices in ``[0, F)``.
    """
    global vals, decoder, bias, target, indices
    on_gpu = {"dtype": dtype, "device": "cuda"}
    # The draw order below matches the RNG consumption order of the original.
    vals = torch.randn(B, K, **on_gpu).abs().requires_grad_()
    decoder = torch.randn(F, D, **on_gpu).requires_grad_()
    bias = torch.randn(D, **on_gpu).requires_grad_()
    target = torch.randn(B, D, **on_gpu)
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")


# Per-iteration results collected by the benchmark loop (timed iterations only).
timing_kernel = []  # ms, Triton kernel forward + backward
timing_vanilla = []  # ms, torch.compile reference forward + backward
# NOTE(review): peak-memory stats are reset here but not read back anywhere
# in the visible script.
torch.cuda.reset_peak_memory_stats()
# One loss slot per timed iteration. Sized from NUM_ITER rather than a
# hard-coded 100 so that changing NUM_ITER cannot index out of range or leave
# unused zero entries in the comparison.
loss_kernel_list = torch.zeros((NUM_ITER,))
loss_vanilla_list = torch.zeros((NUM_ITER,))


def zero_grad():
    """Drop the gradients of the trainable benchmark tensors and release cached GPU memory.

    Gradients are set to ``None`` rather than zero-filled, so the next
    backward pass allocates fresh ``.grad`` tensors.
    """
    for param in (vals, decoder, bias):
        param.grad = None
    torch.cuda.empty_cache()


# Benchmark loop: each iteration draws fresh inputs, then times forward +
# backward of the Triton kernel and of the torch.compile reference on the same
# tensors. The first WARMUP iterations absorb one-time costs (including
# torch.compile's first-call compilation) and are not recorded.
for i in range(NUM_ITER + WARMUP):
    init_parameters()
    start_kernel = torch.cuda.Event(enable_timing=True)
    end_kernel = torch.cuda.Event(enable_timing=True)
    start_vanilla = torch.cuda.Event(enable_timing=True)
    end_vanilla = torch.cuda.Event(enable_timing=True)

    start_kernel.record()
    loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_kernel.backward()
    end_kernel.record()

    # Discard the kernel's gradients so the reference backward starts from
    # the same (grad-free) state.
    zero_grad()
    start_vanilla.record()
    loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_vanilla.backward()
    end_vanilla.record()
    if i >= WARMUP:
        # elapsed_time needs both events to have completed on the device.
        torch.cuda.synchronize()
        timing_kernel.append(start_kernel.elapsed_time(end_kernel))
        timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
        loss_kernel_list[i - WARMUP] = loss_kernel.detach()
        loss_vanilla_list[i - WARMUP] = loss_vanilla.detach()
    zero_grad()

# Check the losses of every timed iteration, not only the last one.
# NOTE(review): only forward losses are compared; gradients are not checked.
if torch.allclose(loss_kernel_list, loss_vanilla_list):
    print("✅ Outputs are close! Everything is good! 🎉")
else:
    max_diff = (loss_kernel_list - loss_vanilla_list).abs().max().item()
    print("❌ Outputs mismatch... ⚠️🤔")
    print(f"   max |loss_kernel - loss_vanilla| = {max_diff:.3e}")


print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")