[arXiv 2025] FlashAttention-4: Conquering the Asymmetric Scaling of Blackwell GPUs
Abstract

FlashAttention-4 is a next-generation attention algorithm specifically co-designed for the NVIDIA Blackwell architecture. It introduces techniques like asynchronous pipeline redesign, software-emulated exponentials, and 2-CTA MMA modes to achieve up to 1613 TFLOPs/s (71% utilization) and a 1.3x speedup over cuDNN 9.13.

TL;DR

FlashAttention-4 is not just an incremental update; it is a fundamental redesign of the attention kernel for the NVIDIA Blackwell (B200/GB200) era. Tensor Core throughput has doubled, but shared-memory bandwidth and exponential-unit throughput have not kept pace. FlashAttention-4 uses kernel-algorithm co-design to mitigate the resulting non-matmul bottlenecks. It reaches a staggering 1613 TFLOPs/s (71% utilization) and is implemented in a new Python-based DSL that compiles roughly 30x faster than the previous C++ templates.

The "Blackwell Bottleneck": Why FlashAttention-3 Wasn't Enough

In the jump from Hopper (H100) to Blackwell (B200), NVIDIA roughly doubled peak dense BF16 Tensor Core throughput, to 2.25 PFLOPS. However, neither the Multi-Function Unit (MUFU), which evaluates the exp operations in softmax, nor shared memory (SMEM) bandwidth scaled at the same rate.

The authors' roofline analysis reveals a grim reality for standard kernels: for typical tile sizes on Blackwell, the time spent on shared memory traffic and exponential operations now exceeds MMA compute time by up to 60%. In other words, the Tensor Cores sit idle, waiting on data movement and exponent evaluation rather than on the matrix math itself.
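To make the roofline argument concrete, here is a back-of-envelope model. For one attention tile it estimates how many SM clock cycles the Tensor Core MMAs, the MUFU exponentials and the shared-memory operand reads would each take, and reports the slowest. The names (SMRates, forward_tile_cycles, bottleneck) are hypothetical, the per-SM rates in the usage lines are illustrative placeholders rather than figures from the paper, and the traffic model deliberately ignores details such as TMEM operands and tile reuse.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SMRates:
    """Per-SM, per-clock throughputs (caller-supplied assumptions)."""
    mma_flops_per_clk: float   # dense BF16 Tensor Core FLOPs per clock
    exp_per_clk: float         # MUFU exponentials per clock
    smem_bytes_per_clk: float  # shared-memory bytes read per clock


def forward_tile_cycles(m, n, d, rates, bytes_per_elem=2):
    """Estimate cycles for one (m x n) attention tile with head dimension d.

    Simplified traffic model: the Q, K and V tiles are each read once from
    shared memory; P is assumed to stay out of shared memory.
    """
    mma_flops = 2 * m * n * d + 2 * m * n * d  # S = Q K^T, then O = P V
    exps = m * n                                # one exponential per score
    smem_bytes = (m * d + 2 * n * d) * bytes_per_elem
    return {
        "mma": mma_flops / rates.mma_flops_per_clk,
        "exp": exps / rates.exp_per_clk,
        "smem": smem_bytes / rates.smem_bytes_per_clk,
    }


def bottleneck(cycles):
    """Name of the resource that needs the most cycles."""
    return max(cycles, key=cycles.get)


if __name__ == "__main__":
    # Placeholder rates for illustration only.
    rates = SMRates(mma_flops_per_clk=8192, exp_per_clk=16, smem_bytes_per_clk=128)
    print(forward_tile_cycles(128, 128, 128, rates))  # {'mma': 1024.0, 'exp': 1024.0, 'smem': 768.0}
    print(bottleneck(forward_tile_cycles(128, 128, 64, rates)))  # exp
```

Under these placeholder rates, exponentials already tie the MMAs at head dimension 128, and halving the head dimension halves the MMA work but not the exponential count, so exp becomes the limiter. That is the qualitative shape of the asymmetry described above, not a reproduction of the paper's measured numbers.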

Methodology: High-Precision Engineering for Asymmetric Hardware

1. Software-Emulated Exponentials

To relieve the bottleneck of the limited MUFU units, FlashAttention-4 evaluates part of its 2^x computations in software on the fused multiply-add (FMA) units.

  • The Insight: Use the abundant FMA units to evaluate a polynomial approximation of 2^f for the fractional part f of the exponent, and integer ALU shifts to place the integer part directly into the floating-point exponent bits.
  • The Result: By offloading 10-25% of the exponential work to the FMA units, the kernel effectively raises the SM's total exponential throughput.

Accuracy Table: Once rounded to BF16, the polynomial approximation is indistinguishable from the hardware MUFU result.
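A minimal Python sketch of the exponent-splitting idea follows. It writes x = n + f with n = floor(x), approximates 2^f on [0, 1) with a polynomial evaluated by Horner's rule (one multiply-add per coefficient, the operation FMA units provide), then shifts n into the float32 exponent field with an integer add. The degree-5 Taylor coefficients are a stand-in assumption, since the coefficients used in FlashAttention-4 itself may differ, and the function names are hypothetical. The worst-case relative error of this choice is about (ln 2)^6/720 ≈ 1.5e-4, well below BF16's maximum relative rounding error of 2^-8 ≈ 3.9e-3.

```python
import math
import struct

LN2 = math.log(2.0)
# Degree-5 Taylor coefficients of 2**f = exp(f * ln 2). Assumption: a real
# kernel would use its own (e.g. minimax) coefficients, possibly of lower degree.
COEFFS = [LN2 ** k / math.factorial(k) for k in range(6)]


def exp2_fraction(f):
    """Approximate 2**f for f in [0, 1) with Horner's rule."""
    acc = COEFFS[-1]
    for c in reversed(COEFFS[:-1]):
        acc = acc * f + c  # maps to one fused multiply-add on a GPU
    return acc


def exp2_emulated(x):
    """Emulate 2**x in float32 without a hardware exp instruction."""
    n = math.floor(x)  # integer part
    f = x - n          # fractional part in [0, 1)
    if not -126 <= n <= 127:
        raise ValueError("sketch only handles results in float32's normal range")
    frac = exp2_fraction(f)  # in [1, 2), so its float32 exponent field is 127
    bits = struct.unpack("<I", struct.pack("<f", frac))[0]
    bits += n << 23          # add n into the 8-bit exponent field (bits 23..30)
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    for x in (-3.7, -0.5, 0.0, 1.0, 5.25):
        print(x, exp2_emulated(x), 2.0 ** x)
```

Evaluating the fraction with Horner's rule keeps the dependency chain to one multiply-add per coefficient, which is why the polynomial half of the work maps naturally onto FMA units while the integer half needs only a shift and an add.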

2. The 2-CTA Backward Pass

The backward pass is notoriously limited by memory traffic rather than by Tensor Core throughput. Blackwell introduces a 2-CTA MMA mode in which a pair of CTAs (thread blocks) cooperates on a single matrix multiplication.

  • Shared Memory Savings: Each CTA stages only half of operand B, reducing redundant traffic.
  • Atomic Reduction: By partitioning the output across the CTA pair, FlashAttention-4 halves the number of global atomic adds required to accumulate the dQ gradient; these atomics are a major source of both latency and non-determinism.

2-CTA Architecture Figure: The 2-CTA dQ step partitioning the workload to minimize SMEM traffic.
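The operand-B split can be illustrated with a toy model. This is a conceptual simulation, not Blackwell's MMA interface: matmul and two_cta_mma are hypothetical helpers, and the model assumes that CTA i owns half of A's rows and stages only half of B's columns in its own shared memory, while the pair's MMA reads both halves. The product matches an ordinary matrix multiply, yet each CTA stages only half of B.

```python
def matmul(a, b):
    """Plain matrix product of nested lists (rows of a times columns of b)."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)] for row in a]


def two_cta_mma(a, b):
    """Toy model of a 2-CTA MMA computing C = A @ B.

    Assumption: CTA i holds rows half i of A and stages columns half i of B.
    Returns (C, number of B elements staged by each CTA).
    """
    m, n = len(a), len(b[0])
    assert m % 2 == 0 and n % 2 == 0, "toy model assumes even M and N"
    a_halves = [a[: m // 2], a[m // 2:]]
    b_halves = [[row[: n // 2] for row in b], [row[n // 2:] for row in b]]
    # The pair's MMA sees B as the concatenation of both CTAs' staged halves.
    b_pair = [left + right for left, right in zip(b_halves[0], b_halves[1])]
    c = []
    for cta in range(2):
        c.extend(matmul(a_halves[cta], b_pair))  # CTA i produces its rows of C
    staged = [sum(len(row) for row in half) for half in b_halves]
    return c, staged


if __name__ == "__main__":
    a = [[i + j for j in range(3)] for i in range(4)]
    b = [[i * j + 1 for j in range(6)] for i in range(3)]
    c, staged = two_cta_mma(a, b)
    print(c == matmul(a, b), staged)  # True [9, 9]
```

A single CTA computing the same product alone would stage all 18 elements of B; in the paired model each CTA stages 9, which is the shared-memory saving the bullet list describes.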

3. Pipelining with Tensor Memory (TMEM)

Unlike Hopper, where MMA accumulators lived in registers, Blackwell introduces Tensor Memory (TMEM), a 256 KB on-chip buffer per SM. FlashAttention-4 uses a "ping-pong" schedule in TMEM to overlap the next tile's matrix multiplication with the current tile's softmax computation, keeping the Tensor Cores busy instead of stalling on softmax.
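A simplified model of the ping-pong idea, assuming two TMEM accumulator buffers used alternately (the function names and the two-stage timing model are illustrative assumptions, not the kernel's actual schedule): at step t the MMA writes score tile t into buffer t % 2 while softmax consumes tile t - 1 from the other buffer, so the two stages never touch the same buffer at once.

```python
def ping_pong_schedule(num_tiles):
    """Steps of a double-buffered MMA/softmax pipeline.

    Each entry is (mma_tile, mma_buf, softmax_tile, softmax_buf); None = idle.
    """
    steps = []
    for t in range(num_tiles + 1):
        mma = (t, t % 2) if t < num_tiles else (None, None)
        softmax = (t - 1, (t - 1) % 2) if t >= 1 else (None, None)
        steps.append(mma + softmax)
    return steps


def makespan(num_tiles, t_mma, t_softmax, overlapped):
    """Total time for num_tiles tiles with serial or overlapped stages."""
    if not overlapped:
        return num_tiles * (t_mma + t_softmax)
    # Fill (first MMA) + steady state limited by the slower stage + drain (last softmax).
    return t_mma + (num_tiles - 1) * max(t_mma, t_softmax) + t_softmax


if __name__ == "__main__":
    for step in ping_pong_schedule(3):
        print(step)
    print(makespan(4, 2, 3, overlapped=False), makespan(4, 2, 3, overlapped=True))  # 20 14
```

Once the pipeline is full, each step costs only the slower of the two stages, which is why overlapping softmax with the next MMA hides most of the non-matmul time when the two are comparable.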

Performance: Breaking Records

FlashAttention-4 doesn't just beat the baseline; it dominates.

  • Speedup: 1.3x faster than cuDNN 9.13 and 2.7x faster than Triton on B200.
  • Efficiency: Reaches 71% of the B200's theoretical peak BF16 throughput.
  • Compilation: By moving from C++ templates to CuTe-DSL in Python, the team cut compile times from minutes to seconds (32x faster for the backward pass).

Forward Pass TFLOPS Figure: Performance comparison across different sequence lengths on B200.
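The headline numbers can be sanity-checked with a few lines of arithmetic. The FLOP count follows the convention commonly used in FlashAttention benchmark scripts (4 · batch · heads · seqlen² · head_dim for the forward pass, halved under causal masking); whether the paper counts exactly this way is an assumption, as are the helper names. Dividing 1613 TFLOPs/s by the 2.25 PFLOPS BF16 peak quoted earlier gives about 71.7%, consistent with the reported 71% utilization.

```python
def attention_fwd_flops(batch, heads, seqlen, head_dim, causal=False):
    """Forward FLOPs: two GEMMs (Q K^T and P V), 2 FLOPs per multiply-add."""
    flops = 4 * batch * heads * seqlen * seqlen * head_dim
    return flops // 2 if causal else flops  # causal masking skips about half the tiles


def achieved_tflops(flops, seconds):
    """Measured throughput in TFLOPs/s."""
    return flops / seconds / 1e12


def utilization(tflops, peak_tflops):
    """Fraction of the hardware peak actually achieved."""
    return tflops / peak_tflops


if __name__ == "__main__":
    print(f"{utilization(1613, 2250):.1%}")  # 71.7%
```

To reproduce a point on a TFLOPS-versus-sequence-length plot, one would time a kernel call, pass its FLOP count and elapsed seconds to achieved_tflops, and compare against the peak.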

Critical Insight & Future Outlook

The most profound contribution of FlashAttention-4 is the move toward Python-embedded DSLs for high-performance kernels. For years, the barrier to entry for CUDA optimization was "C++ template hell." By providing a Python-based framework that compiles down to PTX and SASS without sacrificing low-level control, the authors have democratized low-level GPU programming.

However, the reliance on Blackwell-specific features (TMEM, 2-CTA mode) means that while this is the "gold standard" for B200, the industry still faces a fragmented landscape between Blackwell datacenters and consumer RDNA/GeForce hardware.

Conclusion

FlashAttention-4 proves that as hardware scales asymmetrically, we can no longer rely on hardware-agnostic code. The future of AI efficiency lies in the tight coupling of algorithmic math (such as skipping unnecessary softmax rescaling) and hardware-specific plumbing (TMEM/2-CTA).
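"Skipping rescaling" refers to the bookkeeping of online softmax. The sketch below illustrates the idea under an assumed threshold rule; the threshold of 8.0 (natural-log units) and all names are hypothetical, and FlashAttention-4's exact criterion may differ. The running maximum used as the exponent offset, and hence the accumulators, is only updated when a new score exceeds it by more than the threshold. Because numerator and denominator share the same possibly stale offset, the final ratio is still exact; the threshold only bounds how large each exp(s - m) can grow.

```python
import math


def softmax_weighted_sum(scores, values, threshold=8.0):
    """Online sum(softmax(scores) * values) with conditional rescaling.

    Returns (result, rescale_count); the count excludes the initial max.
    """
    m_used = -math.inf  # exponent offset currently applied to acc and denom
    acc = 0.0           # running sum of exp(s - m_used) * v
    denom = 0.0         # running sum of exp(s - m_used)
    rescales = 0
    for s, v in zip(scores, values):
        if s > m_used + threshold:
            # With threshold=0 this is classic online softmax (rescale on every new max).
            if m_used != -math.inf:
                scale = math.exp(m_used - s)
                acc *= scale
                denom *= scale
                rescales += 1
            m_used = s
        p = math.exp(s - m_used)  # never exceeds exp(threshold)
        acc += p * v
        denom += p
    return acc / denom, rescales


if __name__ == "__main__":
    scores = [1.0, 3.0, 2.0, 9.5, 9.0, 12.0]
    values = [1, 2, 3, 4, 5, 6]
    print(softmax_weighted_sum(scores, values, threshold=8.0))
    print(softmax_weighted_sum(scores, values, threshold=0.0))
```

On this input the thresholded version rescales once and the classic version three times, yet both return the same weighted sum up to floating-point rounding; on a GPU each avoided rescale saves a pass over the output accumulator.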
