OVER-PARAMETERIZED MODEL OPTIMIZATION WITH POLYAK-ŁOJASIEWICZ CONDITION

Abstract

This work pursues the optimization of over-parameterized deep models for superior training efficiency and test performance. We first theoretically emphasize the importance of two properties of over-parameterized models, i.e., the convergence gap and the generalization gap. Subsequent analyses unveil that both gaps can be upper-bounded by the ratio of the Lipschitz constant to the Polyak-Łojasiewicz (PL) constant, a crucial quantity known as the condition number. These discoveries lead to a structured pruning method with a novel pruning criterion: we devise a gating network that dynamically detects and masks out poorly-behaved nodes of a deep model during training. This gating network is learned by minimizing the condition number of the target model, a process that can be implemented as an extra regularization loss term. Experimental studies demonstrate that the proposed method outperforms the baselines in terms of both training efficiency and test performance, exhibiting the potential to generalize to a variety of deep network architectures and tasks.

1. INTRODUCTION

Most practical deep models are over-parameterized, with the model size exceeding the training sample size, and can perfectly fit all training points (Du et al., 2018; Vaswani et al., 2019). Recent empirical and theoretical studies demonstrate that over-parameterization plays an essential role in model optimization and generalization (Liu et al., 2021b; Allen-Zhu et al., 2019). Indeed, a plethora of state-of-the-art models prevalent in the community are over-parameterized, such as Transformer-based models for natural language modeling tasks (Brown et al., 2020; Devlin et al., 2018; Liu et al., 2019) and wide residual networks for computer vision tasks (Zagoruyko & Komodakis, 2016). However, training over-parameterized models is usually time-consuming and can take anywhere from hours to weeks. Notwithstanding some prior works (Liu et al., 2022; Belkin, 2021) on theoretical analyses of over-parameterized models, those findings remain siloed from the common practices of training such networks. This work seeks to optimize over-parameterized models in pursuit of superior training efficiency and generalization capability. We first analyze two key theoretical properties of over-parameterized models, namely the convergence gap and the generalization gap, which can be quantified by the convergence rate and the sample complexity, respectively. Theoretical analysis of over-parameterized models is intrinsically challenging because the over-parameterized optimization landscape is often nonconvex, limiting convexity-based analysis.
Inspired by recent research on the convergence analysis of neural networks and other non-linear systems (Bassily et al., 2018; Gupta et al., 2021; Oymak & Soltanolkotabi, 2019), we propose to use the Polyak-Łojasiewicz (PL) condition (Polyak, 1963; Karimi et al., 2016; Liu et al., 2022) as the primary mathematical tool to analyze the convergence rate and sample complexity of over-parameterized models, along with the widely used Lipschitz condition (Allen-Zhu et al., 2019). Our theoretical analysis shows that the aforementioned properties can be controlled by the ratio of the Lipschitz constant to the PL constant, which is referred to as the condition number (Gupta et al., 2021). A small condition number indicates a large decrease in training loss after parameter updates and high algorithmic stability with respect to data perturbation, i.e., fast convergence and good generalization ability. More promisingly, such a pattern can be observed in empirical studies. As shown in Figure 1(a-c), where BERT models were applied to WikiText-2 for language modeling, the training loss of the model with a small condition number (T2) decreases much faster than that of the model with a large condition number (T1), especially when the differences in condition numbers are pronounced (between 40 and 80 epochs); its test performance also improves much faster and is ultimately better. Such theoretical and empirical findings motivate us to formulate a novel regularized optimization problem that adds the minimization of the condition number to the objective function; we call this new additional term PL regularization. In this way, we can directly regularize the condition number while training over-parameterized models, thereby improving their convergence speed and generalization performance. Our empirical analysis further reveals that, given an over-parameterized model, different model components exhibit distinct contributions to model optimization.
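To make these quantities concrete, the sketch below estimates a local PL constant and the resulting condition number on a toy over-parameterized least-squares problem. The PL inequality 0.5*||grad f(w)||^2 >= mu*(f(w) - f*) and the Hessian-based Lipschitz constant are standard; the toy problem sizes, variable names, and estimator are our own illustrative assumptions, not the paper's exact setup.

```python
import numpy as np

def quadratic_loss(w, A, b):
    """Toy objective f(w) = 0.5 * ||A w - b||^2; returns loss and gradient."""
    r = A @ w - b
    return 0.5 * float(r @ r), A.T @ r

def estimate_pl_constant(loss, grad, loss_min=0.0, eps=1e-12):
    """Local PL constant estimate: mu_hat = 0.5 * ||grad||^2 / (f(w) - f*)."""
    return 0.5 * float(grad @ grad) / max(loss - loss_min, eps)

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 50))  # more parameters (50) than samples (20)
b = rng.standard_normal(20)        # over-parameterized => f* = 0 is attainable
w = rng.standard_normal(50)

loss, grad = quadratic_loss(w, A, b)
mu_hat = estimate_pl_constant(loss, grad)       # local PL constant
L_hat = np.linalg.eigvalsh(A.T @ A).max()       # gradient Lipschitz constant
kappa = L_hat / mu_hat                          # condition number L / mu
```

For this quadratic, the estimate lies between the smallest and largest nonzero eigenvalues of A Aᵀ, so mu_hat ≤ L_hat and the resulting kappa is at least 1; a smaller kappa along the training trajectory corresponds to the faster loss decrease discussed above.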
Figure 1(d) plots the heatmap of the condition number for all model heads at epoch 10 and shows that the condition number varies considerably across heads. Given that over-parameterized models contain a large number of redundant parameters, we argue that it is possible to reduce the condition number of an over-parameterized network during training by identifying and masking out poorly-behaved sub-networks with large condition numbers. Figure 1(a-c) illustrates the potential efficacy of this approach. After disabling 75% of the heads of the BERT model T2 according to the condition number ranking at epoch 10, the masked BERT (Masked T) possesses a smaller condition number and achieves faster convergence and better test performance. This phenomenon motivates us to impose PL regularization, and hence improve model optimization, through a pruning approach. More specifically, we introduce a binary mask for periodically sparsifying parameters; the mask is learned via a gating network whose input summarizes the optimization dynamics of sub-networks in terms of PL regularization. An overview of the proposed method is provided in Appendix E.1. This pruning approach to enforcing PL regularization is related to structured pruning works, which focus on compressing model size while maintaining model accuracy. The key difference is that we utilize the condition number to identify the important components, which can thus simultaneously guarantee the convergence and generalization of model training. More importantly, as a consequence of this difference, and in contrast to most pruning works, which obtain a sparse model at a slight cost in accuracy, our method achieves even better test performance than the dense model when no more than 75% of the parameters are pruned. Experimental results demonstrate that our method outperforms state-of-the-art pruning methods in terms of training efficiency and test performance.
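The masking step can be sketched as follows. This is a minimal stand-in for the learned gating network: here the binary mask is derived directly from per-head condition-number scores (the score values below are hypothetical), whereas the paper learns the mask via a gating network trained with PL regularization.

```python
import numpy as np

def condition_number_mask(head_condition_numbers, prune_ratio=0.75):
    """Binary mask that disables the heads with the largest condition numbers.

    head_condition_numbers: per-head condition-number scores (illustrative).
    prune_ratio: fraction of heads to mask out.
    """
    kappa = np.asarray(head_condition_numbers, dtype=float)
    n_prune = int(round(prune_ratio * kappa.size))
    mask = np.ones_like(kappa)
    if n_prune > 0:
        worst = np.argsort(kappa)[-n_prune:]  # indices of the largest kappa
        mask[worst] = 0.0                     # 0 = head masked out
    return mask

# Hypothetical condition numbers for 8 attention heads.
kappas = [3.1, 45.0, 2.2, 17.8, 90.5, 4.0, 33.3, 1.5]
mask = condition_number_mask(kappas, prune_ratio=0.5)
```

In practice the mask would multiply the corresponding head outputs during the forward pass and be refreshed periodically as the per-head condition numbers evolve over training.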
The contributions of this work are threefold:
• We are the first to propose using PL regularization in the training objective function of over-parameterized models. This proposal is founded on a theoretical analysis of the optimization and generalization properties of over-parameterized models, which shows that a small condition number implies fast convergence and good generalization.
• We describe a PL regularization-driven structured pruning method for over-parameterized model optimization. Specifically, we introduce a learnable mask, guided by the PL-based condition number, to dynamically sparsify poorly-behaved sub-networks during model training, improving training efficiency and test performance.
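In symbols, the regularized problem described above can be sketched as follows; the notation (λ for the regularization weight, κ for the condition number) is ours, illustrating the general form of the objective rather than reproducing the paper's exact loss:

```latex
\min_{\theta}\;\; \mathcal{L}(\theta) \;+\; \lambda\,\kappa(\theta),
\qquad
\kappa(\theta) \;=\; \frac{L(\theta)}{\mu(\theta)},
```

where $\mathcal{L}$ is the task loss, $L$ the Lipschitz constant, and $\mu$ the PL constant of the model with parameters $\theta$.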



Shanghai Key Laboratory of Data Science, School of Computer Science, Fudan University, Shanghai, China. School of Mathematics & Statistics, The University of Glasgow, Glasgow, UK. Microsoft Research Asia, Shanghai, China. Department of Engineering Science, University of Oxford, Oxford, England. Department of Electrical Engineering and Computer Science, University of Michigan, Michigan, United States. Department of Computer Science, University of Colorado Boulder, Boulder, Colorado, United States. School of Microelectronics, Fudan University, Shanghai, China. * The corresponding author.



Figure 1: Benefits of PL regularization for BERT optimization. Figures (a-c) show that models with a smaller condition number (e.g., T2 < T1 in general) achieve faster training convergence and better test performance. In addition, pruning heads with a large condition number, i.e., Masked T, reduces the condition number, leading to more rapid and accurate training. Figure (d) shows that T2 heads have different condition numbers; the largest ones are pruned to produce Masked T in Figures (a-c).

