본문 바로가기
AI

FlashAttention: Transformer의 메모리 효율적 고속 Attention 메커니즘 분석

by markbyun 2025. 5. 27.

Left: FlashAttention uses tiling to prevent materialization of the large 𝑁 × 𝑁 attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of Q matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large 𝑁 × 𝑁 attention matrix to HBM, resulting in an 7.6× speedup on the attention computation.

Transformer 아키텍처는 자연어 처리와 컴퓨터 비전 분야에서 표준이 되었지만, Attention 연산의 연산량과 메모리 소비는 특히 긴 시퀀스 처리에서 성능 병목의 원인이 됩니다. 2022년 발표된 FlashAttention은 GPU 아키텍처의 특성을 활용하여 정확한(Exact) Attention을 제공하면서도 속도와 메모리 효율성을 극적으로 개선한 알고리즘입니다.

** You can find the English verion of this content at this page (https://markbyun.blogspot.com/2025/05/flashattention-high-speed-memory.html)


1. 기존 Attention의 병목

Transformer의 핵심 연산인 Attention은 다음과 같이 정의됩니다:

$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d}}) V$

이 연산에서 Q, K, V의 곱셈과 softmax는 모든 pairwise interaction을 포함하며, 시퀀스 길이가 늘어날수록 계산량은 $O(n^2)$, 메모리 사용량도 $O(n^2)$가 됩니다. 대부분의 구현에서는 중간 결과(예: $QK^T$)를 GPU HBM에 저장하는데, 이는 대량의 메모리 이동(memory I/O)을 유발합니다.


2. FlashAttention 1의 핵심 아이디어

FlashAttention은 메모리 접근을 최소화하면서도 정확한 attention 결과를 얻는 방법을 제시합니다. 그 핵심은 다음과 같습니다:

  • Tile 기반 스트리밍 계산: 시퀀스를 작은 타일로 분할해 register와 SRAM 수준에서 softmax normalization을 수행.
  • 중간 결과 미저장: $QK^T$ 행렬 전체를 저장하지 않고, softmax 누적을 streaming 방식으로 구현.
  • Numerical Stability 보장: Max-trick을 활용하여 softmax overflow를 방지.

알고리즘 설명

FlashAttention은 Qᵢ와 Kᵢ를 타일 단위로 가져오고, 아래와 같은 방식으로 softmax를 누적 계산합니다:

for each query tile:
  initialize sum = 0, max = -inf
  for each key tile:
    score = Q · Kᵀ
    max = max(prev_max, max(score))
    score = exp(score - max)
    sum += score
    acc += score · V
output = acc / sum

이러한 방식은 GPU의 shared memory를 활용하여 대량의 global memory 접근 없이 softmax의 정확한 결과를 구현합니다.


3. FlashAttention-2의 개선점

2023년 발표된 FlashAttention-2는 더욱 빠른 속도와 병렬성을 확보합니다. 주요 개선점은 다음과 같습니다:

  • 작업 분할 최적화: Query 축을 따라 병렬 처리하면서 warp-level parallelism을 최대화.
  • 더 적은 레지스터 사용: thread간 register spilling 최소화.
  • FP16/BF16 효율 개선: Transformer에서 흔히 쓰는 저정밀 연산을 안정적으로 지원.

3.1 Work Partitioning 전략

FlashAttention-2는 다음의 세 가지 병렬화 전략을 사용합니다:

  1. Block-per-query: 각 block이 query를 담당.
  2. Warp-per-query: 각 warp가 query를 병렬 처리.
  3. Thread-per-query: 고속 연산을 위한 fine-grained parallelism 제공.

3.2 커널 구조

FlashAttention-2는 Triton 커널로 구현되어 있으며, register reuse와 shared memory 타일링 전략을 더욱 정교화하였습니다. 특히 softmax normalization을 정확하게 유지하면서도 peak bandwidth 활용률을 대폭 개선했습니다.


4. 성능 비교

모델 속도 개선 메모리 사용량 정확도
기존 Attention Baseline 높음 정확
FlashAttention 1 1.7x ~ 2.7x 낮음 정확
FlashAttention 2 2.5x ~ 4.0x 매우 낮음 정확

5. 적용 사례

FlashAttention은 HuggingFace Transformers 및 NVIDIA Megatron-LM에서 지원되며, LLaMA, GPT-NeoX, BERT 등의 모델 학습에 적용되어 훈련 시간 단축메모리 여유 확보라는 두 가지 장점을 모두 제공합니다.


6. 결론

FlashAttention 시리즈는 단순한 최적화 기법을 넘어서, 하드웨어 친화적인 알고리즘 설계가 얼마나 큰 차이를 만들 수 있는지를 보여주는 대표 사례입니다. 메모리 I/O 병목 해소는 더 큰 모델을 훈련할 수 있는 기반이 되며, 앞으로의 LLM 학습 및 추론에 필수적인 구성요소가 될 것입니다.

참고 문헌