Accelerating xIELU: An Optimized CUDA Kernel for LLMs

Nathan Ranchin

The xIELU (Expanded Integral of the Exponential Linear Unit) activation function outperforms both Squared ReLU and SwiGLU in LLM training, but its reference torch.compile() implementation is slow. I wrote an optimized CUDA kernel that achieves a 35% speedup over PyTorch in training and an 11% increase in inference token throughput. The code is available at github.com/nathanrchn/kernels/tree/main/xielu.

What is xIELU?

xIELU is a piecewise activation function with trainable parameters that control curvature:

$$\text{xIELU}(x) = \begin{cases} \alpha_p x^2 + \beta x & \text{if } x > 0 \\ \alpha_n (e^x - 1) - \alpha_n x + \beta x & \text{if } x \leq 0 \end{cases}$$

The parameters $\alpha_p$ and $\alpha_n$ are trainable per-layer scalars constrained by the softplus function, and $\beta$ is fixed at 0.5. The positive branch provides a linearly increasing gradient (like Squared ReLU), while the negative branch adds ELU-like saturation. This combination improves perplexity in LLM pretraining compared to Squared ReLU and SwiGLU.
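For reference, here is a minimal scalar sketch of the forward function in CUDA, following the definition above (it assumes $\alpha_p$ and $\alpha_n$ have already been passed through softplus; this is not the optimized kernel):

// Scalar reference sketch of the forward pass, directly following the definition above.
// alpha_p and alpha_n are assumed to already be the softplus-constrained values.
__device__ __forceinline__ float xielu_forward(float x, float alpha_p,
                                               float alpha_n, float beta) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;                         // quadratic branch
    }
    return alpha_n * (__expf(x) - 1.0f) - alpha_n * x + beta * x;  // ELU-like branch
}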

The backward pass is more complex than typical activations because it requires gradients for both $x$ and the shared scalar parameters $\alpha_p$ and $\alpha_n$, which must be accumulated across the entire tensor via a global reduction.
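Ignoring the softplus reparameterization, the per-element gradients follow directly from the definition above (a sketch; the kernel may organize these terms differently):

$$\frac{\partial y}{\partial x} = \begin{cases} 2\alpha_p x + \beta & \text{if } x > 0 \\ \alpha_n e^x - \alpha_n + \beta & \text{if } x \leq 0 \end{cases} \qquad \frac{\partial y}{\partial \alpha_p} = \begin{cases} x^2 & \text{if } x > 0 \\ 0 & \text{if } x \leq 0 \end{cases} \qquad \frac{\partial y}{\partial \alpha_n} = \begin{cases} 0 & \text{if } x > 0 \\ e^x - 1 - x & \text{if } x \leq 0 \end{cases}$$

The $\alpha$ gradients are summed over every element, which is the global reduction the backward kernel has to perform.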

Optimizations

Vectorized 128-bit Memory Access

We load 8 BFloat16 elements per instruction using 128-bit vectorized loads, maximizing HBM3e memory bandwidth utilization:

uint4 x_data = __ldg(reinterpret_cast<const uint4*>(x + offset));
const __nv_bfloat16* x_local = reinterpret_cast<const __nv_bfloat16*>(&x_data);

uint4 y_data;
__nv_bfloat16* y_local = reinterpret_cast<__nv_bfloat16*>(&y_data);

#pragma unroll
for (int j = 0; j < 8; j++) {
    float xf = __bfloat162float(x_local[j]);
    // ... compute yf in float32 for numerical stability ...
    y_local[j] = __float2bfloat16_rn(yf);
}

*reinterpret_cast<uint4*>(y + offset) = y_data;

The only constraint is that the total number of elements be a multiple of 128, which guarantees proper alignment for the vectorized accesses. All arithmetic is performed in single-precision float for numerical stability, with explicit casts kept at the register level.
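A hypothetical host-side launch that enforces this constraint (the kernel and variable names are illustrative, not the actual API):

// Hypothetical launch sketch: each thread consumes 8 bf16 elements per 128-bit load.
const int64_t numel = x.numel();
TORCH_CHECK(numel % 128 == 0, "xIELU kernel expects numel to be a multiple of 128");

const int threads = 256;                                   // illustrative block size
const int64_t packets = numel / 8;                         // one 128-bit packet per thread
const int blocks = static_cast<int>((packets + threads - 1) / threads);
// xielu_forward_kernel<<<blocks, threads, 0, stream>>>(y_ptr, x_ptr, alpha_p, alpha_n, numel);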

Texture Cache for Parameter Access

The scalar parameters αp\alpha_p and αn\alpha_n are constant across all threads. We load them using the __ldg() intrinsic, routing requests through the read-only texture cache instead of the standard L1 cache. This reduces cache contention when all threads in a warp request the same address.
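A sketch of what the parameter loads can look like (the pointer names and the exact placement of the softplus constraint are assumptions, chosen to match the variable names in the snippets below):

// Sketch: read the per-layer scalars once through the read-only data path.
const float alpha_p_raw = __ldg(alpha_p_ptr);              // hypothetical device pointers
const float alpha_n_raw = __ldg(alpha_n_ptr);

// Softplus constraint applied once, in registers (one plausible definition
// of the s_a_p / alpha_n_val values used later).
const float s_a_p       = log1pf(__expf(alpha_p_raw));
const float alpha_n_val = log1pf(__expf(alpha_n_raw));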

Warp-Level Reduction

In the backward pass, gradients for αp\alpha_p and αn\alpha_n must be summed across the thread block. Instead of shared memory with __syncthreads() barriers, we use __shfl_down_sync primitives to reduce directly between registers within a warp:

#pragma unroll
for (int i = 16; i > 0; i >>= 1) {
    galpha_p_local += __shfl_down_sync(0xffffffff, galpha_p_local, i);
    galpha_n_local += __shfl_down_sync(0xffffffff, galpha_n_local, i);
}

if (threadIdx.x % 32 == 0) {
    atomicAdd(galpha_p, galpha_p_local);
    atomicAdd(galpha_n, galpha_n_local);
}

This eliminates shared memory allocation and avoids the latency of block-wide synchronization barriers.

Fused Multiply-Add

We use the __fmaf_rn (fused multiply-add) intrinsic to combine a multiplication and an addition into a single instruction, doubling peak floating-point throughput for these operations and improving numerical accuracy by rounding only once:

if (xf > 0.0f) {
    // Positive branch: x * (alpha_p * x + beta), with s_a_p the softplus-constrained alpha_p
    yf = xf * __fmaf_rn(s_a_p, xf, beta);
} else {
    // Negative branch: clamp the exponent argument, then fuse the remaining terms
    float e = __expf(fminf(xf, eps)) - 1.0f;
    yf = __fmaf_rn(alpha_n_val, e, neg_s_a_n * xf);
}

Bug Fix: 64-bit Indexing

We identified a correctness issue in the previous CUDA baseline: it casts tensor.numel() to a 32-bit integer for index calculations. This causes integer overflow when processing tensors with more than $2^{32}$ elements (roughly 4 billion), making it unsafe for large-scale models. Our implementation uses int64_t throughout.
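A minimal sketch of the indexing in grid-stride form (one possible layout; the actual kernel may differ):

// Sketch: keep all index math in int64_t so tensors with more than 2^32 elements
// index correctly. Each iteration handles one 128-bit packet of 8 bf16 values.
const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
     i * 8 < numel; i += stride) {
    const int64_t offset = i * 8;
    // ... vectorized load, compute, store as shown above ...
}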

Results

Experiments were conducted on NVIDIA GH200 Grace Hopper Superchips. Microbenchmarks used synthetic tensors of shape $(64, 1024, 21504)$, representing MLP activations for an 8B parameter model.

Microbenchmarks

Microbenchmark speedup comparison

Each optimization layer provides progressive improvements. The gap between the PyTorch baseline and our kernel widens as element count increases, showing that the vectorization strategy amortizes overheads and saturates the memory bus more effectively than compiler-generated code. By keeping register usage under 32 per thread, we achieve 96% theoretical GPU occupancy.
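One common way to keep the register count bounded is to give the compiler an explicit block-size target with __launch_bounds__; a sketch of the general technique (the signature and kernel name are illustrative, not necessarily what our kernel uses):

// __launch_bounds__(maxThreadsPerBlock, minBlocksPerSM) lets nvcc cap register
// allocation so the requested occupancy is achievable.
__global__ void __launch_bounds__(256, 8)
xielu_forward_kernel(__nv_bfloat16* y, const __nv_bfloat16* x,
                     const float* alpha_p, const float* alpha_n, int64_t numel) {
    // ... kernel body ...
}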

Training

We integrated the kernel into the post-training pipeline of an 8B parameter Llama 2 model, training for 1.5 hours on a multi-node cluster (16 nodes, 64 GPUs) using Fully Sharded Data Parallel (FSDP).

Training metrics comparison
Implementation            Mean Time/Step   Steps in 1.5 h
PyTorch Compiled          9.39 s           559
Baseline CUDA (Scalar)    6.16 s           854
Ours                      6.05 s           869

Our kernel achieves 6.05s per step, a 35% improvement over PyTorch and 1% over the baseline. The modest gain over the baseline CUDA kernel is consistent with Amdahl's Law: the activation function consumes less than 2% of total runtime in a full transformer block, and FSDP communication overhead further masks element-wise kernel improvements. Importantly, the training loss curves overlap exactly, confirming numerical correctness (unlike the previous vectorized baseline, which diverged).

Inference Throughput

In latency-sensitive inference with 1024 input/output tokens:

Implementation   Request Throughput (req/s)   Token Throughput (tok/s)
Baseline CUDA    17.01                        17,390
Ours             18.87                        19,291
Improvement      +10.9%                       +11.1%

Inference, especially decoding, is heavily memory-bandwidth bound. Our vectorized memory loads provide a more substantial benefit here, directly translating to reduced latency and higher serving capacity.

Comparison with Standard Activations

Activation function comparison

Our optimized xIELU matches Squared ReLU performance and slightly outperforms native SiLU, despite involving a conditional branch and exponential calculation. The optimizations effectively hide the additional arithmetic latency.

Takeaways

The dominant speedup (35% over PyTorch) comes from removing Python overhead and optimizing memory access patterns that torch.compile could not resolve. Even with modern compilers, low-level manual optimization remains valuable for maximizing hardware utilization in large-scale LLM training and serving.