Mechanics of Next Token Prediction with Self-Attention
arXiv (2024)
Abstract
Transformer-based language models are trained on large datasets to predict
the next token given an input sequence. Despite this simple training objective,
they have led to revolutionary advances in natural language processing.
Underlying this success is the self-attention mechanism. In this work, we ask:
What does a single self-attention layer learn from next-token prediction? We
show that training self-attention with gradient descent learns an automaton
which generates the next token in two distinct steps: (1) Hard retrieval: Given
an input sequence, self-attention precisely selects the high-priority input
tokens associated with the last input token. (2) Soft composition: It then
creates a convex combination of the high-priority tokens from which the next
token can be sampled. Under suitable conditions, we rigorously
characterize these mechanics through a directed graph over tokens extracted
from the training data. We prove that gradient descent implicitly discovers the
strongly-connected components (SCCs) of this graph and self-attention learns to
retrieve the tokens that belong to the highest-priority SCC available in the
context window. Our theory relies on decomposing the model weights into a
directional component and a finite component that correspond to hard retrieval
and soft composition steps respectively. This also formalizes a related
implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that
these findings shed light on how self-attention processes sequential data and
pave the path toward demystifying more complex architectures.
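
The two-step picture described above (hard retrieval of the highest-priority SCC in the context, followed by soft composition over the retrieved tokens) can be illustrated with a small toy sketch. This is not the paper's model or training procedure. It assumes a single hand-written priority graph in which an edge (u, v) means "u has priority over v", and hypothetical finite scores standing in for the finite weight component. The function names `reachable`, `sccs`, `hard_retrieve`, and `soft_compose` are made up for illustration.

```python
import math

def reachable(nodes, edges):
    """Map each node to the set of nodes reachable from it (including itself)."""
    adj = {u: set() for u in nodes}
    for u, v in edges:
        adj[u].add(v)
    reach = {}
    for start in nodes:
        seen, stack = {start}, [start]
        while stack:
            u = stack.pop()
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    stack.append(v)
        reach[start] = seen
    return reach

def sccs(reach):
    """Two nodes share an SCC exactly when each is reachable from the other."""
    return {frozenset(v for v in reach[u] if u in reach[v]) for u in reach}

def hard_retrieve(context, reach):
    """Keep context tokens that no other context token strictly dominates.

    Token o strictly dominates t when t is reachable from o but not vice versa,
    i.e. o sits in a higher-priority SCC (toy assumption about edge direction).
    """
    present = set(context)
    return [t for t in context
            if not any(t in reach[o] and o not in reach[t] for o in present)]

def soft_compose(selected, scores):
    """Softmax over finite scores: convex weights over the retrieved tokens."""
    top = max(scores[t] for t in selected)
    weights = [math.exp(scores[t] - top) for t in selected]
    total = sum(weights)
    probs = {}
    for t, w in zip(selected, weights):
        probs[t] = probs.get(t, 0.0) + w / total
    return probs

# Toy graph: a and b form one SCC, which has priority over c, which has priority over d.
nodes = ["a", "b", "c", "d"]
edges = [("a", "b"), ("b", "a"), ("b", "c"), ("c", "d")]
reach = reachable(nodes, edges)

selected = hard_retrieve(["c", "d", "a", "b"], reach)   # tokens from the top SCC
probs = soft_compose(selected, {"a": 0.0, "b": 0.0, "c": 0.0, "d": 0.0})
print(sccs(reach), selected, probs)
```

In this sketch the hard step plays the role of the directional weight component (attention mass outside the top SCC is driven to zero), while the softmax over hypothetical finite scores plays the role of the finite component. In the paper's setting the relevant graph is associated with the last input token, whereas the toy uses one fixed graph for simplicity.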