optimizing snake1d activation kernel in triton

published

mar 13, 2026

i've been working on optimizing various components of SNAC (a multi-scale neural audio codec) in triton. my eventual plan was to publish a worklog covering all of the optimizations, but i realized it would get very long, so i've narrowed the scope here.

so here we'll iteratively optimize snake1d activation kernel on nvidia H100 SXM5 from torch eager to hand-tuned triton, one step at a time.

but before that, some background:

SNAC (multi-scale neural audio codec) compresses audio into hierarchical discrete tokens. its decoder reconstructs waveforms from these tokens, and a critical activation function called snake1d is called 29 times per decode pass.

(figure: SNAC compresses audio into hierarchical discrete tokens using RVQ)

coming back to the point,

snake1d is a learned periodic activation:

def snake(x, alpha):
    return x + (1 / alpha) * sin(alpha * x) ** 2

where x has shape [B, C, T] and alpha has shape [1, C, 1] (one learnable parameter per channel, broadcasted across batch and time).

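
as a concrete reference, the eager version is just broadcasting; a minimal sketch in PyTorch (`snake_eager` is my name for it, not SNAC's):

```python
import torch

def snake_eager(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # x: [B, C, T], alpha: [1, C, 1] -- broadcasting replicates alpha
    # across the batch and time dimensions
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2
```
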
this is a pure elementwise operation, no reductions, no cross-element dependencies, mathematically written as:

y = x + \frac{1}{\alpha}\sin^2(\alpha x)

for each element we roughly perform 3 multiplications, 1 addition, 1 division and 1 sine evaluation. using common roofline approximations (add/mul ≈ 1 FLOP, div ≈ 4 FLOPs, sin ≈ 10 FLOPs), this gives roughly ~18 FLOPs per element.

memory traffic per element is:

memory traffic per element is:

  • read x (4 bytes)

  • read alpha (4 bytes)

  • write y (4 bytes)

so total ≈ 12 bytes moved per element.

this yields an arithmetic intensity of: 18/12 ≈ 1.5 FLOPs/byte

this is well below the H100's ~590 FLOPs/byte (fp16) balance point. that means the optimization target is maximizing memory bandwidth utilization (H100 HBM3 peak: 3,350 GB/s).

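
these numbers can be sanity-checked with a tiny roofline model. a sketch, assuming the commonly quoted H100 SXM specs used above (3,350 GB/s HBM3, ~67 fp32 TFLOPS), so treat the constants as approximations:

```python
# FLOP weights as assumed above: add/mul = 1, div = 4, sin = 10
FLOPS_PER_ELEM = 3 * 1 + 1 * 1 + 1 * 4 + 1 * 10   # 3 mul + 1 add + 1 div + 1 sin = 18
BYTES_PER_ELEM = 4 + 4 + 4                         # read x, read alpha, write y (fp32)

def roofline_bound(n_elements: int, hbm_gbs: float = 3350.0,
                   fp32_tflops: float = 67.0) -> str:
    # whichever side takes longer to service N elements is the bottleneck
    mem_s = n_elements * BYTES_PER_ELEM / (hbm_gbs * 1e9)
    cmp_s = n_elements * FLOPS_PER_ELEM / (fp32_tflops * 1e12)
    return "memory" if mem_s > cmp_s else "compute"
```
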
v1: naive 1D triton kernel

the simplest possible triton kernel: treat the entire [B, C, T] tensor as a flat 1D array. each triton program processes BLOCK_SIZE consecutive elements, and the channel index is computed from the flat offset to look up the correct alpha[c].

@triton.jit
def snake_v1_kernel(x_ptr, alpha_ptr, out_ptr, B, C, T,
                    stride_b, stride_c, stride_t,
                    BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    n_elements = B * C * T
    mask = offsets < n_elements

    # flat offset = b*C*T + c*T + t, so channel c = (offset // T) % C
    c_idx = (offsets // T) % C

    x = tl.load(x_ptr + offsets, mask=mask)
    a = tl.load(alpha_ptr + c_idx, mask=mask)

    a_safe = a + 1e-9
    sin_val = tl.sin(a_safe * x)
    result = x + (1.0 / a_safe) * sin_val * sin_val  # pow(2) mul

    tl.store(out_ptr + offsets, result, mask=mask)

to reduce unnecessary overhead, sin_val * sin_val is used instead of pow(2) to avoid the expensive power intrinsic. each element therefore loads x and its channel’s alpha, applies the Snake activation, and writes the result back. a default BLOCK_SIZE=1024 is used for simplicity, with an autotuned variant exploring block sizes from 256 to 8192 to find the best configuration for memory throughput.

results fp32 (realistic shapes)

| shape | eager (ms) | eager bw | triton v1 (ms) | triton bw | speedup |
|---|---|---|---|---|---|
| [1, 1024, 236] | 0.0066 | 294 GB/s | 0.0070 | 278 GB/s | 0.94x |
| [1, 512, 1888] | 0.0102 | 761 GB/s | 0.0087 | 887 GB/s | 1.17x |
| [1, 256, 15104] | 0.0233 | 1329 GB/s | 0.0173 | 1789 GB/s | 1.35x |
| [1, 128, 60416] | 0.0404 | 1533 GB/s | 0.0274 | 2259 GB/s | 1.47x |
| [1, 64, 120832] | 0.0403 | 1534 GB/s | 0.0276 | 2245 GB/s | 1.46x |

the naive triton kernel is 1.35-1.47x faster than the JIT-scripted eager on the shapes that matter most (the large ones called 7x each per decode).

eager torch writes intermediate results to HBM between each operation. the triton kernel reads once, computes everything in registers, and writes once to HBM.

scaling tensor sizes (synthetic shapes, fp32)

the current kernel's advantage grows with tensor size:

| tensor elements | eager bw | triton v1 bw | % of h100 peak |
|---|---|---|---|
| 16k | 311 GB/s | 308 GB/s | 9% |
| 1m | 1329 GB/s | 1860 GB/s | 55% |
| 8m | 1675 GB/s | 2595 GB/s | 77% |
| 33m | 1762 GB/s | 2816 GB/s | 84% |
| 67m | 1798 GB/s | 2940 GB/s | 88% |

at 67M elements, we're hitting 88% of H100 peak memory bandwidth with a naive kernel!

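
for reference, the bandwidth figures in these tables follow the usual bytes-moved-over-time convention; a small helper (my own, assuming the 12 bytes/element fp32 traffic estimate from earlier):

```python
def achieved_bandwidth_gbs(n_elements: int, time_ms: float,
                           bytes_per_elem: int = 12) -> float:
    # bytes moved = read x + read alpha + write y (fp32 -> 12 bytes/elem);
    # for fp16/bf16 pass bytes_per_elem=6
    return n_elements * bytes_per_elem / (time_ms * 1e-3) / 1e9
```
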
half-precision (fp16/bf16)

the speedup pattern holds for half-precision, with triton roughly 2x faster than eager:

| shape | fp16 eager | fp16 triton v1 | fp16 speedup |
|---|---|---|---|
| [1, 128, 60416] | 712 GB/s | 1449 GB/s | 2.0x |
| [1, 64, 120832] | 717 GB/s | 1444 GB/s | 2.0x |

fp16 bandwidths are lower than fp32 in absolute terms because the tensors are half the size (fewer bytes to transfer), so kernel launch overhead is relatively larger.

autotuning

the autotuned variant (sweeping BLOCK_SIZE from 256 to 8192) performs nearly identically to the fixed BLOCK_SIZE=1024 version. at these tensor sizes, the kernel is so memory-bound that block size barely matters: the memory controller is the bottleneck, not the compute scheduler.

v2: 2D grid (channel-aware scheduling)

in the v1 kernel, the channel index is recovered from a flattened offset using:

c = (offset // T) % C

this requires an integer division and modulo per element, which are relatively expensive gpu operations (often ~10–20 cycles compared to ~1 cycle for add/mul). while the kernel is still memory-bound, this unnecessary integer arithmetic adds extra latency on the critical path.

the solution is to change how the kernel is scheduled, rather than how the math is computed.

instead of flattening [B, C, T] into a single dimension, we launch a 2D grid:

  • program_id(0) → indexes blocks along time

  • program_id(1) → indexes batch × channel

this means each program is responsible for processing a small time slice of a single (b, c) pair. because the channel is already known from the launch index, the kernel no longer needs to compute it via division.

two useful things happen as a result:

  • the expensive offset // T computation disappears entirely.

  • alpha[c] becomes a scalar load reused across the entire program, instead of being gathered per element.

@triton.jit
def snake_v2_kernel(x_ptr, alpha_ptr, out_ptr, B, C, T,
                    stride_b, stride_c, stride_t,
                    BLOCK_T: tl.constexpr):
    pid_t = tl.program_id(0)   # time block index
    pid_bc = tl.program_id(1)  # batch*channel index

    b = pid_bc // C
    c = pid_bc % C

    t_offsets = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = t_offsets < T
    ptrs = b * stride_b + c * stride_c + t_offsets * stride_t

    x = tl.load(x_ptr + ptrs, mask=mask)
    a = tl.load(alpha_ptr + c)  # scalar load same alpha for entire program

    a_safe = a + 1e-9
    sin_val = tl.sin(a_safe * x)
    result = x + (1.0 / a_safe) * sin_val * sin_val

    tl.store(out_ptr + ptrs, result, mask=mask)

in practice, this improves efficiency in two ways: (1) it removes the unnecessary integer arithmetic, and (2) it slightly reduces memory pressure by letting the same alpha value be reused for all elements processed by a program.

the computation itself remains unchanged; the optimization comes purely from better mapping the problem structure to the gpu launch grid.

v1 maps a flat 1D grid over B×C×T, each program must compute its channel index via integer division. v2 uses a 2D grid where program_id(1) directly indexes the channel, eliminating per-element division.

results fp32 (on realistic shapes)

| shape | eager | v1 (1d) | v2 (2d) | v2 vs v1 |
|---|---|---|---|---|
| [1, 1024, 236] | 317 GB/s | 289 GB/s | 273 GB/s | -5% |
| [1, 512, 1888] | 790 GB/s | 922 GB/s | 939 GB/s | +2% |
| [1, 256, 15104] | 1348 GB/s | 1841 GB/s | 1834 GB/s | ~0% |
| [1, 128, 60416] | 1548 GB/s | 2251 GB/s | 2283 GB/s | +1% |
| [1, 64, 120832] | 1542 GB/s | 2289 GB/s | 2288 GB/s | ~0% |

for fp32, v2 is essentially the same as v1 on realistic shapes. the integer division overhead is negligible when each element is 4 bytes: memory bandwidth is the bottleneck, not compute.

results on fp16 (realistic shapes)

| shape | eager | v1 (1d) | v2 (2d) | v2 vs v1 |
|---|---|---|---|---|
| [1, 1024, 236] | 150 GB/s | 150 GB/s | 139 GB/s | -7% |
| [1, 512, 1888] | 373 GB/s | 495 GB/s | 520 GB/s | +5% |
| [1, 256, 15104] | 646 GB/s | 1137 GB/s | 1276 GB/s | +12% |
| [1, 128, 60416] | 725 GB/s | 1484 GB/s | 1720 GB/s | +16% |
| [1, 64, 120832] | 722 GB/s | 1465 GB/s | 1731 GB/s | +18% |

for fp16, v2 is significantly faster, with up to an 18% improvement on the largest shapes. with half-precision, the data is half the size but the integer division cost stays the same, so eliminating it has proportionally more impact.

scaling tensor sizes (synthetic shapes, fp32)

| elements | v1 bw | v2 bw | v2 peak % |
|---|---|---|---|
| 16k | 319 GB/s | 310 GB/s | 9% |
| 1m | 1900 GB/s | 1926 GB/s | 57% |
| 8m | 2615 GB/s | 2625 GB/s | 78% |
| 67m | 2923 GB/s | 2920 GB/s | 87% |

at fp32, the two kernels converge at large sizes: both hit ~87% of peak. the gap is in fp16 where v2 has a clear edge.

v3: triton autotune (finding the optimal launch configuration)

in v2 we established the 2D grid structure, but used a fixed BLOCK_T=1024 with default num_warps=4 and num_stages=2. these parameters control:

  • BLOCK_T: elements per program along the time axis (affects parallelism vs. overhead)

  • num_warps: warps per SM (affects occupancy and register pressure)

  • num_stages: software pipeline depth (prefetches loads while computing; higher = more registers)

for a memory-bound kernel, the conventional wisdom is "fewer warps, fewer stages" because:

  • more warps → more register pressure → lower occupancy → fewer active threads

  • more stages → more registers for pipeline buffers → same problem

  • the gpu memory controller is already saturated i.e. additional compute parallelism doesn't help!

to explore this space, i did a systematic autotuning sweep across:

  • 6 block sizes

  • 5 warp counts

  • 4 pipeline depths

for a total of 120 launch configurations per tensor shape

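
in triton, this kind of sweep is usually expressed with the built-in `@triton.autotune` decorator. the config list below is an illustrative subset, not my exact 120-point grid, and the kernel body simply repeats v2:

```python
import triton
import triton.language as tl

# illustrative subset of the sweep space
_configs = [
    triton.Config({"BLOCK_T": bt}, num_warps=w, num_stages=s)
    for bt in (256, 512, 1024, 2048)
    for w in (2, 4, 8)
    for s in (2, 3, 4)
]

@triton.autotune(configs=_configs, key=["C", "T"])  # re-tune when the shape changes
@triton.jit
def snake_v3_kernel(x_ptr, alpha_ptr, out_ptr, B, C, T,
                    stride_b, stride_c, stride_t,
                    BLOCK_T: tl.constexpr):
    # body identical to snake_v2_kernel; only the launch config is tuned
    pid_t = tl.program_id(0)
    pid_bc = tl.program_id(1)
    b = pid_bc // C
    c = pid_bc % C
    t_offsets = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = t_offsets < T
    ptrs = b * stride_b + c * stride_c + t_offsets * stride_t
    x = tl.load(x_ptr + ptrs, mask=mask)
    a_safe = tl.load(alpha_ptr + c) + 1e-9
    sin_val = tl.sin(a_safe * x)
    tl.store(out_ptr + ptrs, x + (1.0 / a_safe) * sin_val * sin_val, mask=mask)
```
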
autotune results (selected configurations)

the autotuner reveals a clear pattern:

| shape | elements | best config | why |
|---|---|---|---|
| [1, 1024, 236] | 242k | block=256, warps=2, stages=3 | tiny T: small blocks, few warps |
| [1, 512, 1888] | 967k | block=2048, warps=16, stages=4 | medium T: big blocks, max warps |
| [1, 256, 15104] | 3.9m | block=1024, warps=2, stages=4 | large T: moderate blocks, few warps |
| [1, 128, 60416] | 7.7m | block=256, warps=2, stages=4 | large T, few C: small blocks, few warps |
| [1, 64, 120832] | 7.7m | block=512, warps=2, stages=4 | largest T: small-ish blocks, few warps |

across the larger tensors (where the kernel clearly operates in a memory-bound regime), the autotuner consistently prefers:

  • low warp counts (≈2 warps)

  • deeper pipelines (≈4 stages)

a lower warp count reduces register pressure per SM, which helps maintain higher occupancy. meanwhile, deeper software pipelines allow triton to issue memory loads earlier and overlap them with computation, improving the GPU’s ability to hide HBM latency.

because the snake1d kernel performs very little computation per byte of data moved, performance is largely determined by how efficiently memory requests are issued and overlapped, rather than by raw compute parallelism.

with num_stages=1, the SM stalls waiting for each memory load. with num_stages=4, loads for future iterations are prefetched while the current iteration computes, keeping the memory controller saturated. fewer warps (2 vs 8) leaves more registers per warp for pipeline buffers.

the exception is [1,512,1888], where 16 warps are chosen: this shape has many channels (512) but a short time axis (1888), so there are 512 programs each handling only 1-2 blocks. more warps help fill the SM when there's less work per program.

results fp32 (realistic shapes)

| shape | eager | v2 (fixed) | v3 (autotuned) | v3 vs v2 |
|---|---|---|---|---|
| [1, 1024, 236] | 313 GB/s | 265 GB/s | 302 GB/s | +14% |
| [1, 512, 1888] | 797 GB/s | 925 GB/s | 928 GB/s | ~0% |
| [1, 256, 15104] | 1342 GB/s | 1848 GB/s | 1834 GB/s | ~0% |
| [1, 128, 60416] | 1539 GB/s | 2284 GB/s | 2284 GB/s | ~0% |
| [1, 64, 120832] | 1540 GB/s | 2285 GB/s | 2297 GB/s | +0.5% |

for fp32, v3 barely moves the needle on large shapes: the kernel was already near-optimal with the default config. the biggest win is on the smallest shape, where the autotuner finds better launch parameters.

the verdict here is that autotuning provides marginal gains (~3% on fp16) and confirms our kernel is already well-configured. the 2D grid structure with 2 warps and 4 pipeline stages is the sweet spot for this memory-bound workload.

v4: fast sin approximation (inlined polynomial)

snake activation relies on the sin() function, which maps to a transcendental instruction on the gpu. these instructions are relatively expensive and have significantly higher latency than basic arithmetic operations.

instead of calling tl.sin(), we inline a 7th-order polynomial approximation of sine. this replaces the transcendental operation with a sequence of fused multiply-add–friendly arithmetic instructions.

the approximation uses two steps:

  • range reduction to map the input into the interval [−𝜋,𝜋]

  • Horner polynomial evaluation of the Taylor series

inlined in the kernel this becomes:

# range reduction: ax_mod = ax - round(ax / 2π) * 2π
ax = a_safe * x
n = tl.floor(ax * 0.15915494309189535 + 0.5)   # round to nearest multiple of 2π
ax_mod = ax - n * 6.283185307179586

# Horner's method: sin(x) ≈ x - x³/6 + x⁵/120 - x⁷/5040
ax2 = ax_mod * ax_mod
ax3 = ax_mod * ax2
sin_val = ax_mod - ax3 * (0.16667 - ax2 * (0.00833 - ax2 * 0.000198))

this approximation slightly reduces numerical precision compared to the hardware sin() instruction, but in practice the error is small and the Snake activation is tolerant to minor deviations.

this approximation effectively reduces the sine computation to roughly ~6–8 FP operations instead of a tens-of-cycles transcendental instruction.

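
the accuracy claim is easy to check off-device; a numpy sketch of the same range reduction + Horner evaluation (using exact coefficients 1/6, 1/120, 1/5040 where the kernel rounds them):

```python
import numpy as np

TWO_PI = 6.283185307179586
INV_TWO_PI = 0.15915494309189535

def fast_sin(ax: np.ndarray) -> np.ndarray:
    # range reduction: fold the argument into [-pi, pi]
    n = np.floor(ax * INV_TWO_PI + 0.5)
    r = ax - n * TWO_PI
    # Horner form of x - x^3/6 + x^5/120 - x^7/5040
    r2 = r * r
    return r - (r * r2) * (1.0 / 6 - r2 * (1.0 / 120 - r2 * (1.0 / 5040)))

x = np.linspace(-10.0, 10.0, 10001)
max_err = np.abs(fast_sin(x) - np.sin(x)).max()   # worst near the ±pi fold boundary
```

the truncation error grows toward the ±π fold boundary (a few 1e-2), the same order of magnitude as the max-error columns reported in the results tables.
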
even though the kernel is memory-bound, reducing instruction latency can still help warp scheduling and pipeline progress, which sometimes yields measurable gains.

the H100 has only 4 SFU units per SM vs 128 FP32 ALU cores. tl.sin() always runs at fp32 speed through the SFU. the polynomial approximation routes through the much wider ALU datapath, which is 2x faster for fp16.

results fp32:

| shape | v3 (tl.sin) | v4c (fast sin) | Δ | max error |
|---|---|---|---|---|
| [1, 256, 15104] | 1833 GB/s | 1917 GB/s | +4.6% | 0.024 |
| [1, 128, 60416] | 2275 GB/s | 2291 GB/s | +0.7% | 0.034 |
| [1, 64, 120832] | 2267 GB/s | 2301 GB/s | +1.5% | 0.035 |
| [1, 1024, 65536] | 2932 GB/s | 2901 GB/s | -1.1% | 0.035 |

for fp32, the fast sin gives a marginal improvement: the SFU sin() is already well pipelined.

results: fp16 (the actual win here):

| shape | v3 (tl.sin) | v4c (fast sin) | Δ | max error |
|---|---|---|---|---|
| [1, 256, 15104] | 1267 GB/s | 1399 GB/s | +10.4% | 0.023 |
| [1, 128, 60416] | 1736 GB/s | 1872 GB/s | +7.8% | 0.031 |
| [1, 64, 120832] | 1755 GB/s | 1888 GB/s | +7.5% | 0.031 |
| [1, 1024, 65536] | 2684 GB/s | 2834 GB/s | +5.6% | 0.035 |

the polynomial approximation helps fp16 significantly more (+5-10%) because:

  • SFU sin() is computed in fp32 even for fp16 inputs: the GPU upconverts, computes, and downconverts. our polynomial stays in the native datatype.

  • the polynomial is 4 FMA ops (all running on the FP32/FP16 ALUs at full throughput) vs 1 SFU op (only 4 SFUs per SM). at fp16, the ALUs are 2x faster than SFUs.

  • max error is ~0.035, within acceptable tolerance for audio codec inference, unless your use-case is for something much more precise than this.

conclusion

in this worklog we incrementally optimized snake1d activation from a sequence of HBM round-trips into a single fused triton kernel that performs all computation in registers.

in the end, the kernel reaches ~2.3–2.9 TB/s on H100, roughly 85–90% of peak HBM bandwidth, which is close to the practical ceiling for a purely elementwise memory-bound kernel.

(micro-optimizations i tried that didn't work)

  • utilizing triton’s L2 cache eviction policies

  • precomputing 1/alpha to remove the per-element division.

i still think more work can be done to squeeze out the remaining perf on H100s for this kernel, but i'll leave that to a future worklog. thank you for reading!