Transformer 아키텍처는 자연어 처리와 컴퓨터 비전 분야에서 표준이 되었지만, Attention 연산의 연산량과 메모리 소비는 특히 긴 시퀀스 처리에서 성능 병목의 원인이 됩니다. 2022년 발표된 FlashAttention은 GPU 아키텍처의 특성을 활용하여 정확한(Exact) Attention을 제공하면서도 속도와 메모리 효율성을 극적으로 개선한 알고리즘입니다.
** You can find the English verion of this content at this page (https://markbyun.blogspot.com/2025/05/flashattention-high-speed-memory.html)
1. 기존 Attention의 병목
Transformer의 핵심 연산인 Attention은 다음과 같이 정의됩니다:
$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d}}) V$
이 연산에서 Q, K, V의 곱셈과 softmax는 모든 pairwise interaction을 포함하며, 시퀀스 길이가 늘어날수록 계산량은 $O(n^2)$, 메모리 사용량도 $O(n^2)$가 됩니다. 대부분의 구현에서는 중간 결과(예: $QK^T$)를 GPU HBM에 저장하는데, 이는 대량의 메모리 이동(memory I/O)을 유발합니다.
2. FlashAttention 1의 핵심 아이디어
FlashAttention은 메모리 접근을 최소화하면서도 정확한 attention 결과를 얻는 방법을 제시합니다. 그 핵심은 다음과 같습니다:
- Tile 기반 스트리밍 계산: 시퀀스를 작은 타일로 분할해 register와 SRAM 수준에서 softmax normalization을 수행.
- 중간 결과 미저장: $QK^T$ 행렬 전체를 저장하지 않고, softmax 누적을 streaming 방식으로 구현.
- Numerical Stability 보장: Max-trick을 활용하여 softmax overflow를 방지.
알고리즘 설명
FlashAttention은 Qᵢ와 Kᵢ를 타일 단위로 가져오고, 아래와 같은 방식으로 softmax를 누적 계산합니다:
for each query tile:
initialize sum = 0, max = -inf
for each key tile:
score = Q · Kᵀ
max = max(prev_max, max(score))
score = exp(score - max)
sum += score
acc += score · V
output = acc / sum
이러한 방식은 GPU의 shared memory를 활용하여 대량의 global memory 접근 없이 softmax의 정확한 결과를 구현합니다.
3. FlashAttention-2의 개선점
2023년 발표된 FlashAttention-2는 더욱 빠른 속도와 병렬성을 확보합니다. 주요 개선점은 다음과 같습니다:
- 작업 분할 최적화: Query 축을 따라 병렬 처리하면서 warp-level parallelism을 최대화.
- 더 적은 레지스터 사용: thread간 register spilling 최소화.
- FP16/BF16 효율 개선: Transformer에서 흔히 쓰는 저정밀 연산을 안정적으로 지원.
3.1 Work Partitioning 전략
FlashAttention-2는 다음의 세 가지 병렬화 전략을 사용합니다:
- Block-per-query: 각 block이 query를 담당.
- Warp-per-query: 각 warp가 query를 병렬 처리.
- Thread-per-query: 고속 연산을 위한 fine-grained parallelism 제공.
3.2 커널 구조
FlashAttention-2는 Triton 커널로 구현되어 있으며, register reuse와 shared memory 타일링 전략을 더욱 정교화하였습니다. 특히 softmax normalization을 정확하게 유지하면서도 peak bandwidth 활용률을 대폭 개선했습니다.
4. 성능 비교
모델 | 속도 개선 | 메모리 사용량 | 정확도 |
---|---|---|---|
기존 Attention | Baseline | 높음 | 정확 |
FlashAttention 1 | 1.7x ~ 2.7x | 낮음 | 정확 |
FlashAttention 2 | 2.5x ~ 4.0x | 매우 낮음 | 정확 |
5. 적용 사례
FlashAttention은 HuggingFace Transformers 및 NVIDIA Megatron-LM에서 지원되며, LLaMA, GPT-NeoX, BERT 등의 모델 학습에 적용되어 훈련 시간 단축과 메모리 여유 확보라는 두 가지 장점을 모두 제공합니다.
6. 결론
FlashAttention 시리즈는 단순한 최적화 기법을 넘어서, 하드웨어 친화적인 알고리즘 설계가 얼마나 큰 차이를 만들 수 있는지를 보여주는 대표 사례입니다. 메모리 I/O 병목 해소는 더 큰 모델을 훈련할 수 있는 기반이 되며, 앞으로의 LLM 학습 및 추론에 필수적인 구성요소가 될 것입니다.
참고 문헌
- Tri Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, arXiv:2205.14135
- Tri Dao et al., FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, arXiv:2307.08691
- https://github.com/HazyResearch/flash-attention
'AI' 카테고리의 다른 글
SentencePiece 완전 정복: AI 엔지니어를 위한 언어 독립형 토크나이저 (2) | 2025.05.28 |
---|---|
ZeRO: 대규모 모델 학습을 위한 메모리 최적화 기법 분석 (1) | 2025.05.27 |
RoFormer와 Rotary Position Embedding: Transformer 위치 인코딩의 혁신 (0) | 2025.05.27 |
대규모 언어 모델의 KV 캐시: 설계, 최적화 및 추론 가속 (0) | 2025.05.26 |
SmolVLM: 허깅페이스의 작고 효율적인 멀티모달 모델 (0) | 2025.05.14 |