Algorytm dzieli macierze Q, K, V na bloki rozmiaru B_r × d i B_c × d mieszczące się w SRAM (typowo 100-200 kB per SM). Dla każdego bloku Q ładuje go raz, następnie iteruje po blokach K i V, obliczając cząstkowe wyniki attention i akumulując je z numerycznie stabilnym online softmax: utrzymuje bieżący max m i sumę l, przy każdej nowej parze (K_j, V_j) aktualizuje O ← rescale(O_prev, m_old, m_new) + exp(S_new - m_new) · V_j. Macierz attention n×n nigdy nie jest materializowana w HBM. Backward pass używa rekomputacji zamiast zapisywanej macierzy uwagi (gradient checkpointing).
Standardowa implementacja attention materializuje macierz n×n w HBM i jest memory-bound — dominującym kosztem nie są FLOPs softmaxu, lecz transfer danych. Ogranicza to maksymalną długość kontekstu i throughput.
Podział macierzy Q, K, V na bloki rozmiaru mieszczącego się w SRAM GPU (zwykle B_r × d ~ 64-128 × 64-128).
Numerycznie stabilna rekurencja utrzymująca bieżący max i sumę wykładniczą — pozwala obliczać softmax blokami bez materializacji pełnej macierzy.
Backward nie zapisuje macierzy attention, rekomputuje ją z zapisanych O, L (logsumexp) — kompromis FLOPs vs pamięć.
FlashAttention-3 wymaga Hopper (H100/H200) — nie działa na Ampere (A100). v2 jest standardem na A100. Wybór złej wersji = utrata 2-4× speedup.
FlashAttention zakłada standardową scaled-dot-product attention z opcjonalnym causal mask. Niestandardowe maski (np. ALiBi, block-sparse) wymagają specjalnych wariantów lub uniemożliwiają jego użycie.
Pierwsza publikacja — tiling + online softmax, 2-4× speedup, O(n) pamięć dla exact attention.
Lepsze work partitioning po warpach GPU, parallelism po wymiarze sekwencji — 2× szybsze niż v1, ~50-70% peak FLOPs na A100.
PyTorch dodaje FlashAttention jako domyślny backend dla F.scaled_dot_product_attention — masowa adopcja w ekosystemie.
Wsparcie dla Hopper (H100): asynchroniczne TMA, warp-specialization, FP8 — do 75% peak FLOPs na H100, 2× szybsze niż v2.
FlashAttention przesuwa attention z memory-bound do bliżej compute-bound przez maksymalizację reużycia danych w SRAM. Pozostaje limit przepustowości HBM dla ładowania bloków Q/K/V.
Exact attention — wszystkie pary token-token są obliczane (modulo causal mask). Optymalizacja czysto algorytmiczna, bez zmiany matematyki modelu.