Pruning is Optimal for Learning Sparse Features in High-Dimensions

Vural, Nuri Mert; Erdogdu, Murat A.

Session: 3D-1 (Neural Networks), Tuesday, Jul 02, 16:30-17:45

Abstract: Neural network pruning has attracted significant attention in deep learning. While it is commonly observed in practice that pruning networks to a certain level of sparsity can improve the quality of the learned features, a theoretical explanation of this phenomenon remains elusive. In this work, we investigate this phenomenon by demonstrating that a broad class of statistical models can be optimally learned in high dimensions using pruned neural networks trained with gradient-based methods. We consider learning both single-index and multi-index models of the form $y = g(V^T x) + \epsilon$, where $g$ is a degree-$p$ polynomial and $V \in \mathbb{R}^{d \times r}$, with $r \ll d$, is the matrix of relevant model directions. We assume that $V$ satisfies an extension of $\ell_q$-sparsity to matrices and show that pruning neural networks proportionally to the sparsity level of $V$ improves their sample complexity compared to unpruned networks. Furthermore, we establish Correlational Statistical Query (CSQ) lower bounds in this setting that take the sparsity level of $V$ into account. We show that if the sparsity level of $V$ exceeds a certain threshold, then for a class of link functions, training pruned networks with a gradient-based algorithm achieves sample complexity matching the CSQ lower bound. In the same scenario, however, our results imply that basis-independent methods, such as models trained via standard gradient descent with rotationally invariant random initialization, provably achieve only suboptimal sample complexity.
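
For intuition on the setting, the following is a minimal NumPy sketch, not the authors' algorithm: it draws data from a single-index instance $y = g(v^T x) + \epsilon$ with an $s$-sparse direction $v$ and a degree-2 polynomial link, computes a per-coordinate correlational statistic as a simplified stand-in for one gradient step on first-layer weights, and then magnitude-prunes to a number of coordinates proportional to the sparsity. All parameter choices ($d$, $n$, $s$, the link $g$, the pruning budget $k$) are illustrative assumptions.

    import numpy as np

    rng = np.random.default_rng(0)

    # Illustrative single-index setup (r = 1): y = g(v^T x) + eps with an s-sparse v.
    d, n, s = 1000, 5000, 10          # ambient dimension, sample size, sparsity of v
    v = np.zeros(d)
    v[rng.choice(d, size=s, replace=False)] = 1.0 / np.sqrt(s)   # unit-norm, s-sparse direction

    g = lambda z: z**2 - 1            # degree-2 polynomial link (second Hermite polynomial)
    X = rng.standard_normal((n, d))
    y = g(X @ v) + 0.1 * rng.standard_normal(n)

    # Correlational statistic per coordinate: the empirical mean of y * (x_j^2 - 1),
    # whose population value is proportional to v_j^2 for this link. This stands in
    # for a single correlational gradient step on first-layer weights.
    scores = (y[:, None] * (X**2 - 1)).mean(axis=0)

    # Magnitude pruning: keep a number of coordinates proportional to the sparsity level.
    k = 2 * s
    keep = np.argsort(np.abs(scores))[-k:]

    support = np.flatnonzero(v)
    recovered = np.intersect1d(keep, support).size
    print(f"recovered {recovered}/{s} support coordinates after pruning to {k} weights")

In this toy run the surviving coordinates concentrate on the support of $v$, illustrating how pruning proportional to the sparsity of $V$ can preserve the relevant directions; the paper's results formalize when such gradient-based training with pruning achieves optimal sample complexity.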