DeFT: Flash Tree-attention with IO-Awareness for Efficient Tree-search-based LLM Inference
arXiv (2024)
Abstract
Decoding using tree search can greatly enhance the inference quality of
transformer-based Large Language Models (LLMs). Depending on the guidance
signal, it forms LLM outputs into a tree and searches for the best
root-to-leaf path, improving controllability, reasoning ability, alignment,
et cetera. However, current tree decoding strategies and their inference
systems are poorly matched: redundancy in computation, memory footprint, and
memory access makes inference inefficient. To address
this issue, we propose DeFT, an IO-aware tree attention algorithm that
maintains memory-efficient attention calculation with a low memory footprint
in two stages: (1) QKV Preparation: we propose a KV-Guided Tree Split
strategy that groups QKV to keep GPU utilization high and to minimize
reads/writes of the KV cache between GPU global memory and on-chip shared
memory; (2) Attention Calculation: we compute the partial attention of each
QKV group in a fused kernel and then apply a Tree-topology-aware
Global Reduction strategy to obtain the final attention. Thanks to a
3.6-4.5× reduction in KV cache IO, along with an additional reduction in IO
for QK^⊤ and Softmax equivalent to 25% of the KV cache IO, DeFT achieves a
1.7-2.4× speedup in end-to-end latency over the SOTA attention algorithms
across two practical reasoning tasks.
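
To make the two-stage idea concrete, here is a minimal NumPy sketch based only on this abstract: each tree node holds one KV segment (one "QKV group"), partial attention is computed per group, and a log-sum-exp global reduction merges the partials into exact per-query attention. The `nodes` structure, shapes, and function names are illustrative assumptions, not the paper's actual fused CUDA kernels.

```python
# Sketch of two-stage tree attention with a log-sum-exp global reduction,
# assumed from the abstract; `nodes`, shapes, and names are hypothetical.
import numpy as np

def partial_attention(q, k, v):
    """Partial attention of one QKV group (one tree node's KV segment).
    Returns the unnormalized output plus the (max, sum) softmax statistics
    needed to merge groups later."""
    s = q @ k.T / np.sqrt(q.shape[-1])       # scores against this segment only
    m = s.max(axis=-1, keepdims=True)        # per-query max, for stability
    p = np.exp(s - m)
    return p @ v, m.squeeze(-1), p.sum(axis=-1)

def tree_attention(queries, nodes):
    """KV-guided split: each node's KV is visited once and attended by all
    queries sharing it; partials are merged per query (global reduction)."""
    out = np.zeros_like(queries)
    m_acc = np.full(queries.shape[0], -np.inf)   # running per-query max
    l_acc = np.zeros(queries.shape[0])           # running per-query sum
    for node in nodes:                           # one fused kernel per group
        qi = node["query_ids"]                   # queries whose path hits node
        o, m, l = partial_attention(queries[qi], node["k"], node["v"])
        m_new = np.maximum(m_acc[qi], m)         # rescale both partials to
        a, b = np.exp(m_acc[qi] - m_new), np.exp(m - m_new)  # the new max
        out[qi] = a[:, None] * out[qi] + b[:, None] * o
        l_acc[qi] = a * l_acc[qi] + b * l
        m_acc[qi] = m_new
    return out / l_acc[:, None]                  # exact softmax-weighted output

# Example: one shared prefix node and two branches, one query per leaf.
rng = np.random.default_rng(0)
d = 8
q = rng.standard_normal((2, d))
seg = lambda n: {"k": rng.standard_normal((n, d)),
                 "v": rng.standard_normal((n, d))}
nodes = [dict(seg(5), query_ids=[0, 1]),   # prefix shared by both queries
         dict(seg(3), query_ids=[0]),      # branch for query 0
         dict(seg(3), query_ids=[1])]      # branch for query 1
print(tree_attention(q, nodes).shape)      # (2, 8)
```

In this sketch each node's K/V segment is loaded once and shared by every query whose root-to-leaf path crosses it, so a shared prefix is not re-read per branch; this is the kind of KV-cache IO reuse the abstract credits for its reported speedups.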