Learning Hierarchical Polynomials with Three-Layer Neural Networks
ICLR 2024 (2023)
Abstract
We study the problem of learning hierarchical polynomials over the standard
Gaussian distribution with three-layer neural networks. We specifically
consider target functions of the form $h = g \circ p$ where $p : \mathbb{R}^d
\rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g: \mathbb{R}
\rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class
generalizes the single-index model, which corresponds to $k=1$, and is a
natural class of functions possessing an underlying hierarchical structure. Our
main result shows that for a large subclass of degree $k$ polynomials $p$, a
three-layer neural network trained via layerwise gradient descent on the square
loss learns the target $h$ up to vanishing test error in
$\widetilde{\mathcal{O}}(d^k)$ samples and polynomial time. This is a strict
improvement over kernel methods, which require $\widetilde \Theta(d^{kq})$
samples, as well as existing guarantees for two-layer networks, which require
the target function to be low-rank. Our result also generalizes prior works on
three-layer neural networks, which were restricted to the case of $p$ being a
quadratic. When $p$ is indeed a quadratic, we achieve the
information-theoretically optimal sample complexity
$\widetilde{\mathcal{O}}(d^2)$, which is an improvement over prior
work~\citep{nichani2023provable} requiring a sample size of
$\widetilde\Theta(d^4)$. Our proof proceeds by showing that, during the initial
stage of training, the network performs feature learning to recover the feature
$p$ with $\widetilde{\mathcal{O}}(d^k)$ samples. This work demonstrates the
ability of three-layer neural networks to learn complex features and, as a
result, to learn a broad class of hierarchical functions.
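To make the setting concrete, the following is a minimal sketch (not the authors' implementation) of the hierarchical target $h = g \circ p$ over Gaussian inputs and of layerwise gradient descent on the square loss for a three-layer network. The quadratic choice of $p(x) = x^\top A x$, the outer polynomial $g(z) = z^2 + z$, the width, the ReLU activation, and the learning rates and step counts are all illustrative assumptions, and the two-stage schedule below is only a loose analogue of the layerwise procedure analyzed in the paper.

```python
# Sketch of the setting: hierarchical target h = g(p(x)) with p of degree
# k = 2 and g of degree q = 2, fit by a three-layer network trained
# layerwise on the square loss. All concrete choices are illustrative.
import torch

torch.manual_seed(0)
d, n = 32, 4096                        # input dimension, sample size

A = torch.randn(d, d) / d              # coefficients of the quadratic feature p
p = lambda X: ((X @ A) * X).sum(dim=1) # p(x) = x^T A x
g = lambda z: z ** 2 + z               # outer degree-2 polynomial
X = torch.randn(n, d)                  # x ~ N(0, I_d), standard Gaussian inputs
y = g(p(X))                            # hierarchical labels h(x) = g(p(x))

# Three-layer network: x -> ReLU(W x) -> ReLU(V .) -> linear readout a.
m = 512
W = (torch.randn(m, d) / d ** 0.5).requires_grad_()
V = (torch.randn(m, m) / m ** 0.5).requires_grad_()
a = (torch.randn(m) / m).requires_grad_()

def net(X):
    return torch.relu(torch.relu(X @ W.T) @ V.T) @ a

def train(params, steps, lr):
    """Gradient descent on the empirical square loss over `params` only."""
    opt = torch.optim.SGD(params, lr=lr)
    loss = None
    for _ in range(steps):
        opt.zero_grad()
        loss = ((net(X) - y) ** 2).mean()
        loss.backward()
        opt.step()
    return loss.item()

# Layerwise schedule: stage 1 freezes the readout a and adapts the inner
# layers W, V (the feature-learning stage in which the network recovers p);
# stage 2 fits the readout a on the learned features.
loss1 = train([W, V], steps=500, lr=1e-2)
loss2 = train([a], steps=2000, lr=1e-2)
print(f"square loss after stage 1: {loss1:.4f}, after stage 2: {loss2:.4f}")
```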
Keywords
Hierarchical polynomials, feature learning, three-layer networks, sample complexity, gradient descent