
CUDA / ML systems

Fused Linear Attention

A three-kernel CUDA study of what happens when QKV projection and attention are fused, split, and benchmarked carefully on H100 GPUs.

This project started from a systems question rather than a benchmark alone: what really changes when QKV projection and attention are forced into a tighter CUDA execution path so intermediate data moves less through HBM? The work became more useful once the profiler complicated the story. Reducing memory traffic helped, but it did not automatically produce a faster kernel.

This was built as a group project for NYU's High Performance Machine Learning course. My contribution centered on the tiling strategy, shared-memory layout, HBM traffic model, correctness oracle, and WMMA benchmarking work used to test the Tensor Core direction directly.

Problem

Standard inference paths split projection and attention into separate stages: the projection stage writes Q, K, and V to HBM, and the attention stage immediately reads them back. On H100 hardware, that round-trip is worth questioning, but the harder question is what happens after you remove it.
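A back-of-the-envelope version of that round-trip is just a byte count. The sketch below is an idealized model under assumed shapes (one head, fp32 elements, input X of N × d_model, a fused QKV weight of d_model × 3d, every tensor touched once per stage). It is not the project's actual traffic model, and real kernels re-read tiles, so measured savings can be smaller.

```python
# Idealized HBM byte count for the QKV projection + attention path.
# Assumptions (hypothetical): one head, fp32 (4-byte) elements, each tensor
# read or written exactly once per stage, no re-reads of tiles.

def unfused_bytes(N, d_model, d, elem=4):
    x = N * d_model * elem        # input activations X
    w = d_model * 3 * d * elem    # fused QKV projection weights
    qkv = N * 3 * d * elem        # Q, K, and V together
    o = N * d * elem              # attention output
    # Stage 1 reads X and W and writes QKV; stage 2 reads QKV back and writes O.
    return (x + w + qkv) + (qkv + o)

def fused_bytes(N, d_model, d, elem=4):
    x = N * d_model * elem
    w = d_model * 3 * d * elem
    o = N * d * elem
    # Q, K, V stay on-chip, so both QKV terms disappear.
    return x + w + o

def traffic_reduction(N, d_model, d):
    """Fraction of HBM reads + writes removed by fusion in this model."""
    return 1.0 - fused_bytes(N, d_model, d) / unfused_bytes(N, d_model, d)

for n in (64, 128, 256, 512, 1024):
    print(f"N = {n:5d}: {traffic_reduction(n, 512, 64):.1%} fewer bytes")
```

In this model the weight term is fixed while the QKV terms grow with N, so the relative saving from fusion grows with sequence length.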

Project shape

The project was built as a family of three kernels rather than one monolithic implementation. The fully fused path combined projection and attention in one kernel without writing Q, K, or V back to HBM. The hybrid path left projection on cuBLAS and replaced the attention stage with a custom tiled kernel. The third path, hybrid_warp4 in bf16, added warp-cooperative execution to the hybrid design and became the best-performing custom variant.
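The page doesn't spell out what "warp-cooperative" meant inside hybrid_warp4. One common pattern is to split a dot product across the 32 lanes of a warp and combine the partial sums with a shuffle-based tree reduction. The Python below only emulates that pattern (the strided lane layout and function names are hypothetical); in CUDA each step would be a `__shfl_down_sync` call.

```python
WARP_SIZE = 32

def shfl_down(vals, offset):
    """Emulate __shfl_down_sync: lane i reads lane i + offset. Lanes whose
    source would fall past the warp keep their own value (no wrap-around)."""
    n = len(vals)
    return [vals[i + offset] if i + offset < n else vals[i] for i in range(n)]

def warp_dot(a, b):
    """Dot product computed the way a single warp might (hypothetical layout)."""
    # Each lane accumulates a strided partial sum: lane l owns elements l, l+32, ...
    partial = [sum(a[j] * b[j] for j in range(lane, len(a), WARP_SIZE))
               for lane in range(WARP_SIZE)]
    offset = WARP_SIZE // 2
    while offset > 0:                       # 16, 8, 4, 2, 1: five shuffle steps
        shifted = shfl_down(partial, offset)
        partial = [p + s for p, s in zip(partial, shifted)]
        offset //= 2
    return partial[0]                       # only lane 0 holds the complete sum

a = list(range(100))
b = list(range(100, 200))
print(warp_dot(a, b), sum(x * y for x, y in zip(a, b)))
```

Only lane 0's result is meaningful after the loop; the other lanes hold partial sums, which is why real kernels typically read the result from lane 0 or broadcast it.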

For correctness, the oracle was a NumPy reference implementation. On the layout side, the shared-memory arrays used [T][d+1] padding to avoid bank conflicts, and the final fused kernel used a tile size of T = 64 on H100.
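The [T][d + 1] padding is easiest to see with bank arithmetic. Shared memory on NVIDIA GPUs is split into 32 banks of 4-byte words, so an fp32 element at word index w lives in bank w mod 32. If 32 threads each read the same column from different rows of an unpadded row-major tile whose row length is a multiple of 32, every address maps to one bank; one extra column shifts each row by one bank. This sketch assumes fp32 elements and d = 64, since the page does not state d.

```python
NUM_BANKS = 32  # shared memory: 32 banks, each one 4-byte word wide

def conflict_degree(row_stride_words, col, threads=32):
    """Worst-case number of distinct words sharing one bank when thread r
    reads element [r][col] of a row-major fp32 tile."""
    words_per_bank = {}
    for r in range(threads):
        word = r * row_stride_words + col
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())

d = 64  # hypothetical head dimension
print("unpadded [T][d]:  ", conflict_degree(d, col=0))      # 32-way conflict
print("padded   [T][d+1]:", conflict_degree(d + 1, col=0))  # conflict-free
```

A degree of 32 means the access is serialized into 32 separate bank transactions; a degree of 1 means all 32 threads are served at once. The cost of the fix is one extra word per row.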

Architecture diagram

Diagram comparing the unfused baseline HBM round-trip with the custom tiled attention path and the shared-memory layout used by the fused kernel.

The core change was not just “fusion” in the abstract. It was staging Q, K, V, and output tiles in shared memory with T = 64 and d + 1 padding so the kernel could avoid the repeated QKV round-trip through HBM.
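Staging four padded tiles has a shared-memory cost worth checking against per-block limits. This sketch assumes d = 64 and compares fp32 and bf16 staging (the page gives T = 64 and the d + 1 padding, not d or the staging precision). Statically declared `__shared__` arrays are capped at 48 KB per block; on compute capability 9.0 (H100) a kernel can use up to 227 KB per block of dynamic shared memory after opting in with `cudaFuncSetAttribute` and `cudaFuncAttributeMaxDynamicSharedMemorySize`.

```python
# Shared-memory budget for staging Q, K, V, and O tiles as [T][d + 1] arrays.
STATIC_LIMIT = 48 * 1024          # bytes: cap for static __shared__ arrays per block
H100_DYNAMIC_LIMIT = 227 * 1024   # bytes: opt-in max per block on compute capability 9.0

def tile_bytes(T, d, elem_bytes=4, padded=True):
    """Bytes for one [T][d] tile, optionally with one padding column per row."""
    return T * (d + 1 if padded else d) * elem_bytes

def staged_bytes(T, d, elem_bytes=4, tiles=4):
    """Bytes for all staged tiles (Q, K, V, O by default)."""
    return tiles * tile_bytes(T, d, elem_bytes)

T, d = 64, 64  # T from the page; d is a hypothetical head dimension
for name, elem in (("fp32", 4), ("bf16", 2)):
    total = staged_bytes(T, d, elem)
    print(f"{name}: {total} B, fits static: {total <= STATIC_LIMIT}, "
          f"fits H100 dynamic: {total <= H100_DYNAMIC_LIMIT}")
```

Under these assumptions the fp32 layout needs 66,560 bytes, above the 48 KB static cap but well under the 227 KB opt-in limit, while bf16 staging halves it to 33,280 bytes. The padding itself costs only 256 bytes per fp32 tile.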

Memory path sketch

Standard path: QKV projection, write intermediate tensors to HBM, run attention as a separate stage.
Hybrid path: projection stays on cuBLAS Tensor Cores, custom tiled attention replaces PyTorch SDPA; the best custom runtime path in the study.
Fused path: keep projection and attention closer together, cut repeated HBM writes and reads, measure what bottleneck remains after fusion.
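The fused path's data flow can be sketched in plain Python to show what "never writing Q, K, or V to HBM" means structurally: each query tile is projected from X on demand, and K and V tiles are recomputed for every query tile instead of being stored. The softmax is accumulated with a running max and denominator, the online-softmax technique popularized by FlashAttention. The page does not say whether the project's kernels used online softmax or recomputed K and V this way, so treat this as an illustration under assumptions (single head, no masking, 1/√d scaling, tile size `T`), not the project's kernel.

```python
import math

def matmul(a, b):
    """Naive matrix product on lists of lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def fused_attention(X, Wq, Wk, Wv, T=64):
    """Softmax attention with Q/K/V projected tile by tile, never stored whole."""
    N, d = len(X), len(Wq[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q0 in range(0, N, T):
        Q = matmul(X[q0:q0 + T], Wq)              # project only this query tile
        rows = len(Q)
        m = [-math.inf] * rows                     # running row max of scores
        l = [0.0] * rows                           # running softmax denominator
        acc = [[0.0] * d for _ in range(rows)]     # running unnormalized output
        for k0 in range(0, N, T):
            Xk = X[k0:k0 + T]
            K, V = matmul(Xk, Wk), matmul(Xk, Wv)  # recomputed per query tile
            for i in range(rows):
                s = [scale * sum(qc * kc for qc, kc in zip(Q[i], k)) for k in K]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)     # rescales earlier partial sums
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [acc[i][c] * alpha + sum(p[j] * V[j][c] for j in range(len(V)))
                          for c in range(d)]
                m[i] = m_new
        out.extend([a / l[i] for a in acc[i]] for i in range(rows))
    return out
```

In this sketch, avoiding HBM writes is paid for with repeated projection arithmetic on every query tile, which is exactly the kind of work that becomes expensive when it runs as scalar fp32 loops rather than on Tensor Cores.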

What the results actually showed

The important result was not simply that fusion reduced HBM traffic. It did. The more useful result was seeing where that stopped being enough. The initial fully fused scalar kernel removed the QKV round-trip but still ran much slower than the baseline because Tensor Cores stayed idle. The best custom path ended up being the hybrid warp-cooperative bf16 version, which delegated projection back to cuBLAS and then beat the PyTorch baseline at short sequence lengths (N = 64 and N = 128).

Numbers

Metric | Result | Why it mattered
Correctness | 11 / 11 tests passed, max absolute error ≤ 1.5 × 10⁻⁷ | Confirmed that the fused kernel matched the PyTorch reference path closely enough to trust profiling.
Short-sequence wins | 1.22x speedup at N = 64 and 1.41x at N = 128 | Showed that the custom path could turn the memory idea into real runtime wins in the right regime.
HBM reads | 11–55% reduction from fusion, with a peak measured cut of 54.6% | Supported the original hypothesis that repeated global-memory movement was a real cost worth chasing.
Hybrid bf16 path | About 50% lower HBM reads and 85–91% lower peak GPU allocation | Showed that the best-performing custom path was clearly better on memory behavior, not just on latency in one regime.
Environment | H100 SXM5, 132 SMs, 3.35 TB/s HBM bandwidth | Made it easier to interpret why the baseline was still far from purely HBM-bound at longer sequences.
Main bottleneck | Scalar fp32 projection loops in the fully fused path | Explained why lower memory traffic alone did not produce a faster kernel without a Tensor Core path.
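The correctness row rests on comparing kernel output against a trusted reference. The project's oracle was NumPy-based; the sketch below is a dependency-free stand-in that computes attention in float64 (Python's native float) and reports the maximum absolute error. The `to_f32` helper, the 1e-6 tolerance, and the test shapes are assumptions for illustration, not the project's test harness.

```python
import math
import random
import struct

def to_f32(x):
    """Round a float64 to the nearest fp32 value, mimicking fp32 storage."""
    return struct.unpack("f", struct.pack("f", x))[0]

def reference_attention(Q, K, V):
    """float64 softmax attention for one head, no masking."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)                                  # subtract max for stability
        e = [math.exp(x - mx) for x in s]
        z = sum(e)
        out.append([sum(e[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out

def max_abs_error(got, ref):
    return max(abs(g - r) for gr, rr in zip(got, ref) for g, r in zip(gr, rr))

# Usage: treat an fp32-rounded result as the "kernel" output and compare.
random.seed(0)
rand = lambda r, c: [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]
Q, K, V = rand(16, 8), rand(16, 8), rand(16, 8)
ref = reference_attention(Q, K, V)
candidate = [[to_f32(x) for x in row] for row in ref]
err = max_abs_error(candidate, ref)
print(f"max abs error: {err:.2e}", "PASS" if err <= 1e-6 else "FAIL")
```

A float64 reference keeps the oracle's own rounding error far below the fp32-level errors being checked. A bf16 path would need a much looser tolerance, since bf16 keeps only 8 significant bits.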

Baseline vs custom runtime matrix

Sequence length (N) | PyTorch baseline runtime | Best custom runtime | Speedup (baseline ÷ custom) | Memory note
64 | 107.2 µs | 88.0 µs | 1.22x faster | HBM movement was low enough that the hybrid path could turn memory savings into a real win.
128 | 142.1 µs | 100.5 µs | 1.41x faster | This was the strongest short-sequence win for the custom path, just before the crossover.
256 | 127.2 µs | 164.2 µs | 0.77x (slower) | Lower HBM traffic still helped, but the compute path was no longer efficient enough to hold the lead.
512 | 121.6 µs | 293.3 µs | 0.41x (slower) | By this point the scalar arithmetic path was clearly the bottleneck, not memory traffic alone.
1024 | 200.7 µs | 556.4 µs | 0.36x (slower) | The profiler story was honest here: the fused scalar path left Tensor Core throughput on the table.
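The speedup column is baseline runtime divided by custom runtime, so values below 1 mean the custom kernel was slower. This small script recomputes it from the measured runtimes in the table and finds where the custom path stops winning.

```python
# Runtimes in microseconds, copied from the table: N -> (PyTorch baseline, best custom).
RUNTIMES_US = {
    64: (107.2, 88.0),
    128: (142.1, 100.5),
    256: (127.2, 164.2),
    512: (121.6, 293.3),
    1024: (200.7, 556.4),
}

def speedups(runtimes):
    """Speedup = baseline / custom; above 1.0 means the custom kernel is faster."""
    return {n: base / custom for n, (base, custom) in sorted(runtimes.items())}

def crossover(sp):
    """Return (last N where custom wins, first N where it loses), or None."""
    ns = sorted(sp)
    for a, b in zip(ns, ns[1:]):
        if sp[a] >= 1.0 > sp[b]:
            return a, b
    return None

sp = speedups(RUNTIMES_US)
for n, s in sp.items():
    print(f"N = {n:5d}: {s:.2f}x")
last_win, first_loss = crossover(sp)
print(f"custom path stops winning between N = {last_win} and N = {first_loss}")
```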

Profiler summary

Nsight-guided summary

Hardware: H100 SXM5, 132 SMs, 3.35 TB/s HBM bandwidth
Memory: 54.6% peak measured HBM read reduction in the fused path
Compute: scalar fp32 projection loops dominated the fully fused path and kept Tensor Cores idle

This is a compact summary of the profiler evidence surfaced in the report, not a literal timeline screenshot. The useful outcome was identifying the exact handoff where the memory story stopped being enough and a Tensor Core path became the real next step.
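A simple roofline calculation shows why cutting bytes stops paying off. A kernel is bandwidth-bound only while its arithmetic intensity (FLOPs per HBM byte) sits below the ridge point, which is peak FLOP/s divided by bandwidth. The 3.35 TB/s figure comes from the page; the peak-compute figures are approximate NVIDIA datasheet values for H100 SXM5 (about 67 TFLOP/s fp32 on CUDA cores and about 989 TFLOP/s dense bf16 on Tensor Cores) and should be treated as assumptions, as should the hypothetical projection shape below.

```python
HBM_BW = 3.35e12        # bytes/s, from the page
PEAK_FP32 = 67e12       # FLOP/s, fp32 on CUDA cores (approximate datasheet value)
PEAK_BF16_TC = 989e12   # FLOP/s, dense bf16 on Tensor Cores (approximate)

def ridge_point(peak_flops, bw=HBM_BW):
    """Arithmetic intensity (FLOP/byte) above which a kernel is compute-bound."""
    return peak_flops / bw

def attainable_flops(intensity, peak_flops, bw=HBM_BW):
    """Classic roofline: the lower of the compute roof and bandwidth * intensity."""
    return min(peak_flops, bw * intensity)

def projection_intensity(N, d_model, d_out, elem_bytes):
    """FLOP/byte for Y = X @ W, reading X and W once and writing Y once."""
    flops = 2 * N * d_model * d_out
    traffic = (N * d_model + d_model * d_out + N * d_out) * elem_bytes
    return flops / traffic

# Hypothetical shape: N = 1024 tokens, d_model = 512, fused QKV output width 1536.
I32 = projection_intensity(1024, 512, 1536, elem_bytes=4)   # fp32 storage
I16 = projection_intensity(1024, 512, 1536, elem_bytes=2)   # bf16 storage
print(f"fp32 ridge point: {ridge_point(PEAK_FP32):.0f} FLOP/B, projection: {I32:.0f} FLOP/B")
print(f"fp32 CUDA-core ceiling: {attainable_flops(I32, PEAK_FP32) / 1e12:.0f} TFLOP/s")
print(f"bf16 Tensor Core ceiling: {attainable_flops(I16, PEAK_BF16_TC) / 1e12:.0f} TFLOP/s")
```

Under these assumptions the projection sits far to the right of the fp32 ridge point (about 140 FLOP/byte against about 20), so on CUDA cores it is compute-bound and shaving HBM bytes cannot speed it up, while the bf16 Tensor Core ceiling is roughly 14 times higher.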

Benchmark crossover

The crossover between N = 128 and N = 256 mattered. It made the next step obvious: the memory story was partly right, but the attention loop and projection arithmetic still needed a more hardware-efficient path. The WMMA prototype made that concrete and pointed toward Tensor Core integration as the highest-priority follow-up.
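For readers new to WMMA: the `nvcuda::wmma` API has a warp load matrix fragments with `load_matrix_sync`, multiply-accumulate them with `mma_sync`, and write results with `store_matrix_sync`, using fixed fragment shapes such as 16 × 16 × 16 for bf16 inputs. The Python below only emulates that decomposition to show how a GEMM breaks into fragment-sized steps; it is not the project's prototype, and the matrix shapes and contents are hypothetical.

```python
FRAG = 16  # m16n16k16 fragment shape

def mma_sync_emulated(C, A, B, i0, j0, k0):
    """C[i0:i0+16, j0:j0+16] += A[i0:i0+16, k0:k0+16] @ B[k0:k0+16, j0:j0+16]."""
    for i in range(i0, i0 + FRAG):
        for j in range(j0, j0 + FRAG):
            C[i][j] += sum(A[i][k] * B[k][j] for k in range(k0, k0 + FRAG))

def wmma_style_gemm(A, B):
    """Compute A @ B one 16x16x16 fragment step at a time; also count the steps."""
    M, K, N = len(A), len(B), len(B[0])
    if M % FRAG or K % FRAG or N % FRAG:
        raise ValueError("this sketch assumes every dimension is a multiple of 16")
    C = [[0] * N for _ in range(M)]
    mma_calls = 0
    for i0 in range(0, M, FRAG):            # one output fragment per (i0, j0)
        for j0 in range(0, N, FRAG):
            for k0 in range(0, K, FRAG):    # accumulate along K into that fragment
                mma_sync_emulated(C, A, B, i0, j0, k0)
                mma_calls += 1
    return C, mma_calls

# Hypothetical: a 64 x 64 score tile S = Q K^T with head dimension 64.
Q = [[(i + k) % 5 for k in range(64)] for i in range(64)]
Kt = [[(k * j) % 3 for j in range(64)] for k in range(64)]
S, calls = wmma_style_gemm(Q, Kt)
print(calls)  # (64/16) * (64/16) * (64/16) = 64 fragment multiply-accumulates
```

On real hardware each of those 64 steps is a warp-wide Tensor Core operation rather than 16³ = 4,096 scalar multiply-adds, which is the gap the profiler pointed at.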

The useful lesson from this project was not “fusion is always better.” It was learning exactly when lower memory traffic stops being enough and the compute path becomes the real problem.

Why it matters

This is still the clearest example on my site of the kind of work I want to keep doing: systems work that is measured carefully, honest about failure modes, and useful because it sharpens the next engineering question instead of hiding it.