optimizing snake1d activation kernel in triton

i've been working on optimizing various components of SNAC (a multi-scale neural audio codec) in triton. my original plan was to publish a worklog covering all of the optimizations, but i realized it would get very long, so i've narrowed the scope.
so here we'll iteratively optimize the snake1d activation kernel on an nvidia H100 SXM5, from torch eager to a hand-tuned triton kernel, one step at a time.
but before that, some background:
SNAC (multi-scale neural audio codec) compresses audio into hierarchical discrete tokens. its decoder reconstructs waveforms from these tokens, and a critical activation function called snake1d is called 29 times per decode pass.

SNAC compresses audio into hierarchical discrete tokens using RVQ (right)

coming back to the point, snake1d is a learned periodic activation:
```python
def snake(x, alpha):
    return x + (1 / alpha) * sin(alpha * x) ** 2
```
where x has shape [B, C, T] and alpha has shape [1, C, 1] (one learnable parameter per channel, broadcasted across batch and time).
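the broadcast semantics are easy to see in plain numpy. this is my own reference sketch for illustration, not SNAC's code:

```python
import numpy as np

def snake_ref(x: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """reference snake1d: x is [B, C, T], alpha is [1, C, 1]."""
    # alpha broadcasts across batch and time: one scalar per channel
    return x + (1.0 / alpha) * np.sin(alpha * x) ** 2

B, C, T = 2, 4, 8
x = np.random.randn(B, C, T).astype(np.float32)
alpha = np.random.rand(1, C, 1).astype(np.float32) + 0.5  # keep alpha away from 0
y = snake_ref(x, alpha)
assert y.shape == (B, C, T)
```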
this is a pure elementwise operation: no reductions, no cross-element dependencies. mathematically:

y = x + (1/α) · sin²(α · x)
for each element we perform roughly 3 multiplications, 1 addition, 1 division and 1 sine evaluation. using common roofline approximations (add/mul ≈ 1 FLOP, div ≈ 4 FLOPs, sin ≈ 10 FLOPs), this gives ~18 FLOPs per element.
memory traffic per element is:

- read x (4 bytes)
- read alpha (4 bytes)
- write y (4 bytes)
so total ≈ 12 bytes moved per element.
this yields an arithmetic intensity of: 18/12 ≈ 1.5 FLOPs/byte
this is well below the H100's ~590 FLOPs/byte (fp16) balance point. that means the optimization target is maximizing memory bandwidth utilization (H100 HBM3 peak: 3,350 GB/s).
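these back-of-envelope numbers are easy to reproduce. a tiny sketch using the FLOP weights from the text (the weights are rough conventions, not measurements):

```python
# rough roofline arithmetic for snake1d at fp32
flops_per_elem = 3 * 1 + 1 * 1 + 1 * 4 + 1 * 10   # 3 mul + 1 add + 1 div + 1 sin = 18 FLOPs
bytes_per_elem = 4 + 4 + 4                         # read x, read alpha, write y
intensity = flops_per_elem / bytes_per_elem        # 1.5 FLOPs/byte

h100_balance = 590.0   # approximate fp16 balance point, FLOPs/byte
assert intensity < h100_balance                    # far below: firmly memory-bound
```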

v1: naive 1D triton kernel
the simplest possible triton kernel: treat the entire [B, C, T] tensor as a flat 1D array. each triton program processes BLOCK_SIZE consecutive elements. we compute the channel index from the flat offset to look up the correct alpha[c].
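the index arithmetic the kernel relies on can be sanity-checked on the cpu. a pure-python sketch of the same math:

```python
def unflatten(offset: int, C: int, T: int):
    """recover (b, c, t) from a flat offset into a contiguous [B, C, T] tensor."""
    b = offset // (C * T)
    c = (offset // T) % C   # the expression used in the v1 kernel
    t = offset % T
    return b, c, t

# check against a full row-major enumeration
B, C, T = 2, 3, 5
flat = 0
for b in range(B):
    for c in range(C):
        for t in range(T):
            assert unflatten(flat, C, T) == (b, c, t)
            flat += 1
```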
```python
import triton
import triton.language as tl

@triton.jit
def snake_v1_kernel(x_ptr, alpha_ptr, out_ptr,
                    B, C, T,
                    stride_b, stride_c, stride_t,
                    BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    n_elements = B * C * T
    mask = offsets < n_elements
    # compute channel index: offset = b*C*T + c*T + t → c = (offset // T) % C
    c_idx = (offsets // T) % C
    x = tl.load(x_ptr + offsets, mask=mask)
    a = tl.load(alpha_ptr + c_idx, mask=mask)
    a_safe = a + 1e-9
    sin_val = tl.sin(a_safe * x)
    result = x + (1.0 / a_safe) * sin_val * sin_val  # pow(2) → mul
    tl.store(out_ptr + offsets, result, mask=mask)
```
to reduce unnecessary overhead, sin_val * sin_val is used instead of pow(2), avoiding the expensive power intrinsic. each element therefore loads x and its channel's alpha, applies the snake activation, and writes the result back. a default BLOCK_SIZE=1024 is used for simplicity, with an autotuned variant exploring block sizes from 256 to 8192 to find the best configuration for memory throughput.
results fp32 (realistic shapes)
| shape | eager (ms) | eager bw | triton v1 (ms) | triton v1 bw | speedup |
|---|---|---|---|---|---|
| [1, 1024, 236] | 0.0066 | 294 gb/s | 0.0070 | 278 gb/s | 0.94x |
| [1, 512, 1888] | 0.0102 | 761 gb/s | 0.0087 | 887 gb/s | 1.17x |
| [1, 256, 15104] | 0.0233 | 1329 gb/s | 0.0173 | 1789 gb/s | 1.35x |
| [1, 128, 60416] | 0.0404 | 1533 gb/s | 0.0274 | 2259 gb/s | 1.47x |
| [1, 64, 120832] | 0.0403 | 1534 gb/s | 0.0276 | 2245 gb/s | 1.46x |
the naive triton kernel is 1.35-1.47x faster than the JIT-scripted eager on the shapes that matter most (the large ones called 7x each per decode).
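the bandwidth columns in these tables follow from simple accounting: bytes moved divided by kernel time. a small helper in the same spirit (the function name and the read-plus-write accounting are my sketch, not the benchmark harness):

```python
def effective_bandwidth_gbs(shape, dtype_bytes: int, time_ms: float) -> float:
    """effective HBM bandwidth: one read of x plus one write of y per element
    (alpha traffic is negligible: one scalar per channel)."""
    numel = 1
    for d in shape:
        numel *= d
    bytes_moved = 2 * numel * dtype_bytes
    return bytes_moved / (time_ms * 1e-3) / 1e9

# e.g. the [1, 64, 120832] fp32 row at 0.0276 ms
bw = effective_bandwidth_gbs((1, 64, 120832), 4, 0.0276)  # ≈ 2242 gb/s, cf. ~2245 gb/s above
```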

eager torch writes intermediate results to HBM between each operation. the triton kernel reads once, computes everything in registers, and writes once to HBM.
scaling tensor sizes (synthetic shapes, fp32)

the current kernel's advantage grows with tensor size:
| tensor elements | eager bw | triton v1 bw | % of h100 peak |
|---|---|---|---|
| 16k | 311 gb/s | 308 gb/s | 9% |
| 1m | 1329 gb/s | 1860 gb/s | 55% |
| 8m | 1675 gb/s | 2595 gb/s | 77% |
| 33m | 1762 gb/s | 2816 gb/s | 84% |
| 67m | 1798 gb/s | 2940 gb/s | 88% |
at 67M elements, we're hitting 88% of H100 peak memory bandwidth with a naive kernel!
half-precision (fp16/bf16)
the speedup pattern holds for half-precision, with triton roughly 2x faster than eager:
| shape | fp16 eager | fp16 triton v1 | fp16 speedup |
|---|---|---|---|
| [1, 128, 60416] | 712 gb/s | 1449 gb/s | 2.0x |
| [1, 64, 120832] | 717 gb/s | 1444 gb/s | 2.0x |
fp16 bandwidths are lower than fp32 in absolute terms because the tensors are half the size (fewer bytes to transfer), so kernel launch overhead is relatively larger.
autotuning
the autotuned variant (sweeping BLOCK_SIZE from 256 to 8192) performs nearly identically to the fixed BLOCK_SIZE=1024 version. at these tensor sizes the kernel is so memory-bound that block size barely matters: the memory controller is the bottleneck, not the compute scheduler.
v2: 2D grid (channel-aware scheduling)
in the v1 kernel, the channel index is recovered from the flattened offset using:
c = (offset // T) % C
this requires an integer division and modulo per element, which are relatively expensive gpu operations (often ~10–20 cycles compared to ~1 cycle for add/mul). while the kernel is still memory-bound, this unnecessary integer arithmetic adds extra latency on the critical path.
the solution is to change how the kernel is scheduled, rather than how the math is computed.
instead of flattening [B, C, T] into a single dimension, we launch a 2D grid:
- program_id(0) → indexes blocks along time
- program_id(1) → indexes batch × channel
this means each program is responsible for processing a small time slice of a single (b, c) pair. because the channel is already known from the launch index, the kernel no longer needs to compute it via division.
two useful things happen as a result:

- the expensive offset // T computation disappears entirely.
- alpha[c] becomes a scalar load reused across the entire program, instead of being gathered per element.
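the remapping can be sanity-checked on the cpu: every program in the 2D scheme touches the same flat offsets v1 would, with the channel known up front. a pure-python sketch:

```python
B, C, T = 2, 3, 4
covered = set()
for pid_bc in range(B * C):          # grid axis 1: one (batch, channel) pair per program
    b, c = pid_bc // C, pid_bc % C   # cheap per-program division, not per-element
    for t in range(T):               # grid axis 0 would tile this range in BLOCK_T chunks
        flat = b * (C * T) + c * T + t
        # v1 would have recomputed the channel per element:
        assert (flat // T) % C == c
        covered.add(flat)

assert covered == set(range(B * C * T))   # every element visited exactly once
```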
```python
@triton.jit
def snake_v2_kernel(x_ptr, alpha_ptr, out_ptr,
                    B, C, T,
                    stride_b, stride_c, stride_t,
                    BLOCK_T: tl.constexpr):
    pid_t = tl.program_id(0)   # time block index
    pid_bc = tl.program_id(1)  # batch*channel index
    b = pid_bc // C
    c = pid_bc % C
    t_offsets = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = t_offsets < T
    ptrs = b * stride_b + c * stride_c + t_offsets * stride_t
    x = tl.load(x_ptr + ptrs, mask=mask)
    a = tl.load(alpha_ptr + c)  # scalar load: same alpha for the entire program
    a_safe = a + 1e-9
    sin_val = tl.sin(a_safe * x)
    result = x + (1.0 / a_safe) * sin_val * sin_val
    tl.store(out_ptr + ptrs, result, mask=mask)
```
in practice, this improves efficiency in two ways: 1. it removes unnecessary integer arithmetic, and 2. it slightly reduces memory pressure by letting the same alpha value be reused for all elements processed by the program.

the computation itself remains unchanged; the optimization comes purely from better mapping the problem structure to the gpu launch grid.

v1 maps a flat 1D grid over B×C×T, so each program must compute its channel index via integer division. v2 uses a 2D grid where program_id(1) directly indexes the channel, eliminating per-element division.
results fp32 (on realistic shapes)
| shape | eager | v1 (1d) | v2 (2d) | v2 vs v1 |
|---|---|---|---|---|
| [1, 1024, 236] | 317 gb/s | 289 gb/s | 273 gb/s | -5% |
| [1, 512, 1888] | 790 gb/s | 922 gb/s | 939 gb/s | +2% |
| [1, 256, 15104] | 1348 gb/s | 1841 gb/s | 1834 gb/s | ~0% |
| [1, 128, 60416] | 1548 gb/s | 2251 gb/s | 2283 gb/s | +1% |
| [1, 64, 120832] | 1542 gb/s | 2289 gb/s | 2288 gb/s | ~0% |
for fp32, v2 performs essentially the same as v1 on realistic shapes. the integer division overhead is negligible when each element is 4 bytes: memory bandwidth is the bottleneck, not compute.
results on fp16 (realistic shapes)
| shape | eager | v1 (1d) | v2 (2d) | v2 vs v1 |
|---|---|---|---|---|
| [1, 1024, 236] | 150 gb/s | 150 gb/s | 139 gb/s | -7% |
| [1, 512, 1888] | 373 gb/s | 495 gb/s | 520 gb/s | +5% |
| [1, 256, 15104] | 646 gb/s | 1137 gb/s | 1276 gb/s | +12% |
| [1, 128, 60416] | 725 gb/s | 1484 gb/s | 1720 gb/s | +16% |
| [1, 64, 120832] | 722 gb/s | 1465 gb/s | 1731 gb/s | +18% |
for fp16, v2 is significantly faster: up to 18% improvement on the largest shapes. with half-precision the data is half the size but the integer division cost stays the same, so eliminating it has proportionally more impact.
scaling tensor sizes (synthetic shapes, fp32)
| elements | v1 bw | v2 bw | v2 peak % |
|---|---|---|---|
| 16k | 319 gb/s | 310 gb/s | 9% |
| 1m | 1900 gb/s | 1926 gb/s | 57% |
| 8m | 2615 gb/s | 2625 gb/s | 78% |
| 67m | 2923 gb/s | 2920 gb/s | 87% |
at fp32, the two kernels converge at large sizes: both hit ~87% of peak. the gap is in fp16 where v2 has a clear edge.
v3: triton autotune (finding the optimal launch configuration)
in v2 we established the 2D grid structure, but we used a fixed BLOCK_T=1024 with default num_warps=4 and num_stages=2. these parameters control:
- BLOCK_T: elements per program along the time axis (affects parallelism vs. overhead)
- num_warps: warps per SM (affects occupancy and register pressure)
- num_stages: software pipeline depth (prefetches loads while computing; higher = more registers)
for a memory-bound kernel, the conventional wisdom is "fewer warps, fewer stages" because:
- more warps → more register pressure → lower occupancy → fewer active threads
- more stages → more registers for pipeline buffers → same problem
- the gpu memory controller is already saturated, i.e. additional compute parallelism doesn't help!
to explore this space, i did a systematic autotuning sweep across:
- 6 block sizes
- 5 warp counts
- 4 pipeline depths

for a total of 6 × 5 × 4 = 120 launch configurations per tensor shape
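such a sweep is straightforward to enumerate. the grid below is illustrative (plausible values that multiply out to 120; the exact values in my sweep may differ):

```python
from itertools import product

# illustrative sweep grid: 6 block sizes x 5 warp counts x 4 pipeline depths = 120 configs
block_sizes = [256, 512, 1024, 2048, 4096, 8192]
warp_counts = [1, 2, 4, 8, 16]
stage_counts = [1, 2, 3, 4]

configs = [
    {"BLOCK_T": bt, "num_warps": w, "num_stages": s}
    for bt, w, s in product(block_sizes, warp_counts, stage_counts)
]
assert len(configs) == 120   # one timing run per config per shape
```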
autotune results (selected configurations)
the autotuner reveals a clear pattern:
| shape | elements | best config | why |
|---|---|---|---|
| [1, 1024, 236] | 242k | block=256, warps=2, stages=3 | tiny T, small blocks, few warps |
| [1, 512, 1888] | 967k | block=2048, warps=16, stages=4 | medium T, big blocks, max warps |
| [1, 256, 15104] | 3.9m | block=1024, warps=2, stages=4 | large T, moderate blocks, few warps |
| [1, 128, 60416] | 7.7m | block=256, warps=2, stages=4 | large T, few C, small blocks, few warps |
| [1, 64, 120832] | 7.7m | block=512, warps=2, stages=4 | largest T, small-ish blocks, few warps |
across the larger tensors (where the kernel clearly operates in a memory-bound regime), the autotuner consistently prefers:

- low warp counts (≈2 warps)
- deeper pipelines (≈4 stages)
a lower warp count reduces register pressure per SM, which helps maintain higher occupancy. meanwhile, deeper software pipelines allow triton to issue memory loads earlier and overlap them with computation, improving the GPU’s ability to hide HBM latency.
because the snake1d kernel performs very little computation per byte of data moved, performance is largely determined by how efficiently memory requests are issued and overlapped, rather than by raw compute parallelism.

with num_stages=1, the SM stalls waiting for each memory load. with num_stages=4, loads for future iterations are prefetched while the current iteration computes, keeping the memory controller saturated. fewer warps (2 vs 8) leaves more registers per warp for pipeline buffers.
the exception is [1, 512, 1888], where 16 warps are chosen: this shape has many channels (512) but a short time axis (1888), so there are 512 programs each handling only 1-2 blocks. more warps help fill the SM when there's less work per program.
results fp32 (realistic shapes)
| shape | eager | v2 (fixed) | v3 (autotuned) | v3 vs v2 |
|---|---|---|---|---|
| [1, 1024, 236] | 313 gb/s | 265 gb/s | 302 gb/s | +14% |
| [1, 512, 1888] | 797 gb/s | 925 gb/s | 928 gb/s | ~0% |
| [1, 256, 15104] | 1342 gb/s | 1848 gb/s | 1834 gb/s | ~0% |
| [1, 128, 60416] | 1539 gb/s | 2284 gb/s | 2284 gb/s | ~0% |
| [1, 64, 120832] | 1540 gb/s | 2285 gb/s | 2297 gb/s | +0.5% |
for fp32, v3 barely moves the needle on large shapes: the kernel was already near-optimal with the default config. the biggest win is on the smallest shape, where the autotuner finds better launch parameters.
the verdict here is that autotuning provides marginal gains (~3% on fp16) and confirms our kernel is already well-configured. the 2D grid structure with 2 warps and 4 pipeline stages is the sweet spot for this memory-bound workload.
v4: fast sin approximation (inlined polynomial)
snake activation relies on the sin() function, which maps to a transcendental instruction on the gpu. these instructions are relatively expensive and have significantly higher latency than basic arithmetic operations.
instead of calling tl.sin(), we inline a 7th-order polynomial approximation of sine. this replaces the transcendental operation with a sequence of fused multiply-add–friendly arithmetic instructions.
the approximation uses two steps:
- range reduction to map the input into the interval [−π, π]
- Horner evaluation of the truncated Taylor series
inlined in the kernel this becomes:
```python
# range reduction: ax_mod = ax - round(ax / 2π) * 2π
ax = a_safe * x
n = tl.floor(ax * 0.15915494309189535 + 0.5)  # 1/(2π); floor(y + 0.5) rounds to nearest
ax_mod = ax - n * 6.283185307179586
# Horner's method: sin(x) ≈ x - x³/6 + x⁵/120 - x⁷/5040
ax2 = ax_mod * ax_mod
ax3 = ax_mod * ax2
sin_val = ax_mod - ax3 * (0.16667 - ax2 * (0.00833 - ax2 * 0.000198))
```
this approximation slightly reduces numerical precision compared to the hardware sin() instruction, but in practice the error is small and the Snake activation is tolerant to minor deviations.
this approximation effectively reduces the sine computation to roughly ~6–8 FP operations instead of a tens-of-cycles transcendental instruction.
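the accuracy is easy to probe in numpy with a float64 sketch of the same range reduction and polynomial. on a dense grid the worst-case error sits near odd multiples of π, at roughly 0.075; the ~0.035 figures in the tables below are measured on real activations, which rarely land exactly at ±π:

```python
import numpy as np

def fast_sin(x: np.ndarray) -> np.ndarray:
    """range-reduce to [-pi, pi], then 7th-order Taylor via Horner."""
    n = np.floor(x * 0.15915494309189535 + 0.5)    # round(x / 2π)
    xm = x - n * 6.283185307179586                 # now in [-π, π]
    x2 = xm * xm
    return xm - (xm * x2) * (0.16667 - x2 * (0.00833 - x2 * 0.000198))

xs = np.linspace(-20.0, 20.0, 200001)
max_err = np.max(np.abs(fast_sin(xs) - np.sin(xs)))
assert max_err < 0.08   # worst case near the edges of the reduced interval
```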
even though the kernel is memory-bound, reducing instruction latency can still help warp scheduling and pipeline progress, which sometimes yields measurable gains.

the H100 has only 4 SFU units per SM vs 128 FP32 ALU cores. tl.sin() always runs at fp32 speed through the SFU. the polynomial approximation routes through the much wider ALU datapath, which is 2x faster for fp16.
results fp32:
| shape | v3 (tl.sin) | v4c (fast sin) | Δ | max error |
|---|---|---|---|---|
| [1,256,15104] | 1833 gb/s | 1917 gb/s | +4.6% | 0.024 |
| [1,128,60416] | 2275 gb/s | 2291 gb/s | +0.7% | 0.034 |
| [1,64,120832] | 2267 gb/s | 2301 gb/s | +1.5% | 0.035 |
| [1,1024,65536] | 2932 gb/s | 2901 gb/s | -1.1% | 0.035 |
for fp32, the fast sin gives only a marginal improvement: the SFU sin() is already well pipelined.
results fp16 (the actual win here)
| shape | v3 (tl.sin) | v4c (fast sin) | Δ | max error |
|---|---|---|---|---|
| [1,256,15104] | 1267 gb/s | 1399 gb/s | +10.4% | 0.023 |
| [1,128,60416] | 1736 gb/s | 1872 gb/s | +7.8% | 0.031 |
| [1,64,120832] | 1755 gb/s | 1888 gb/s | +7.5% | 0.031 |
| [1,1024,65536] | 2684 gb/s | 2834 gb/s | +5.6% | 0.035 |
the polynomial approximation helps fp16 significantly more (+5-10%) because:
- SFU sin() is computed in fp32 even for fp16 inputs: the GPU upconverts, computes, and downconverts. our polynomial stays in the native datatype.
- the polynomial is 4 FMA ops (all running on the fp32/fp16 ALUs at full throughput) vs 1 SFU op (only 4 SFUs per SM). at fp16, the ALUs are 2x faster than SFUs.
- max error is ~0.035, within acceptable tolerance for audio codec inference, unless your use-case needs much higher precision.
conclusion
in this worklog we incrementally optimized snake1d activation from a sequence of HBM round-trips into a single fused triton kernel that performs all computation in registers.
in the end, the kernel reaches ~2.3–2.9 TB/s on H100, roughly 85–90% of peak HBM bandwidth, which is close to the practical ceiling for a purely elementwise memory-bound kernel.
(micro-optimizations i tried that didn't work)
- utilizing triton's L2 cache eviction policies
- precomputing 1/alpha to remove the per-element division
i still think more work can be done to squeeze out the remaining perf on H100s for this kernel, but i'll leave that for a future worklog. thank you for reading!
