FlashAttention Explained: The Optimization That Made Modern LLMs Practical

개요

FlashAttention은 Transformer 아키텍처의 Attention 메커니즘에서 발생하는 메모리 병목 현상을 GPU 하드웨어에 최적화된 방식으로 재설계하여 해결함으로써 현대 LLM의 실용성을 높인 최적화 기술입니다.

주요 내용

* Attention의 병목 현상: Transformer 모델의 Attention 메커니즘은 각 토큰이 다른 모든 토큰과 비교하는 과정에서 N x N 크기의 Attention 점수 행렬을 생성하는데, 시퀀스 길이가 길어질수록 메모리 사용량이 기하급수적으로 증가하여 GPU의 메모리 대역폭 병목 현상을 유발합니다.
* GPU의 메모리 제약: 현대 GPU는 연산 능력은 뛰어나지만, 메모리 계층 간 데이터 이동 속도는 상대적으로 느립니다. 기존 Attention 방식은 전체 Attention 행렬을 HBM(GPU 메모리)에 반복적으로 쓰고 읽어 비효율적입니다.
* 핵심 아이디어: Attention 행렬 직접 생성 금지: FlashAttention은 전체 Attention 점수 행렬을 메모리에 직접 저장하지 않고, 데이터를 작은 블록 단위로 나누어 처리합니다. 각 블록은 처리 후 결과를 누적하고 메모리에서 즉시 제거됩니다.
* Online Softmax: 전체 Attention 행렬을 저장하지 않고도 Softmax 연산을 정확하게 수행하기 위해, FlashAttention은 블록 단위로 처리하면서 실행 중인 최대값, 정규화 항, 출력 누적 값 등의 통계 정보를 유지하여 점진적으로 Softmax 결과를 업데이트하는 Online Softmax 기법을 사용합니다.
* Tiling과 IO-Aware 알고리즘: FlashAttention은 데이터를 GPU의 빠른 온칩 메모리(Shared Memory)로 로드하여 연산하는 Tiling 기법을 활용합니다. 이는 메모리 이동을 최소화하고 GPU 컴퓨팅 자원을 최대한 활용하는 IO-Aware 알고리즘 설계를 통해 메모리 트래픽을 대폭 줄여줍니다.
* FlashAttention의 발전 및 활용: FlashAttention, FlashAttention-2, FlashAttention-3는 GPU 활용도, 병렬성, 학습 처리량, 최신 GPU 및 저정밀도 포맷 지원 등을 개선해왔습니다. 현재 PyTorch, Hugging Face Transformers, vLLM, TensorRT-LLM 등 다양한 AI 프레임워크와 LLM에서 활용되어 모델 아키텍처 변경 없이 성능을 향상시킵니다.
* FlashAttention의 한계: FlashAttention은 메모리 접근 병목을 해결하지만, Attention 연산 자체의 O(N^2) 계산 복잡성은 여전히 유지됩니다. 따라서 매우 긴 컨텍스트를 처리하려면 여전히 상당한 연산 자원이 필요하며, 이를 해결하기 위한 Sliding-window attention, Linear attention 등의 연구가 지속되고 있습니다.

시사점

FlashAttention은 알고리즘 자체의 변경보다는 하드웨어와의 상호작용을 이해하여 메모리 접근을 최적화하는 접근 방식이 AI 인프라의 근본적인 돌파구를 마련할 수 있음을 보여주며, 현대 LLM, 특히 장문 컨텍스트 처리에 있어 필수적인 기술로 자리매김했습니다.

원문 읽기 →
원문을 불러오는 중...

댓글

GitHub Discussions