INTERPOLATING COMPRESSED PARAMETER SUBSPACES

Abstract

Though distribution shifts have caused growing concern for machine learning scalability, solutions tend to specialize towards a specific type of distribution shift. Methods for label shift may not succeed against domain or task shift, and vice versa. We find that constructing a Compressed Parameter Subspace (CPS), a geometric structure representing distance-regularized parameters mapped to a set of train-time distributions, can maximize average accuracy over a broad range of distribution shifts concurrently. We show sampling parameters within a CPS can mitigate backdoor, adversarial, permutation, stylization and rotation perturbations. We also show training a hypernetwork representing a CPS can adapt to seen tasks as well as unseen interpolated tasks.

1. INTRODUCTION

Recent work on the geometry of the loss landscape, such as neural subspaces (Wortsman et al., 2021) and mode connectivity (Fort & Jastrzebski, 2019; Draxler et al., 2019; Garipov et al., 2018), discovered properties of robustness between multiple parameters. Departing from constructing subspaces w.r.t. a single/unperturbed input distribution, we investigate the construction of subspaces w.r.t. multiple perturbed distributions, and find improved mappability between shifted distributions and low-loss parameters contained in these subspaces.

Contributions. We share a method to construct a compressed parameter subspace that increases the likelihood that a parameter sampled from the subspace can be mapped to a shifted input distribution. We demonstrate a high average accuracy across distribution shifts in single and multiple test-time settings (Figure 1). We show improved robustness across perturbation types, reduced catastrophic forgetting on Split-CIFAR10/100, and strong capacity for multitask solutions and unseen/distant tasks. As such, we are motivated to contribute a method of adaptation/robustness that can handle multiple types of distribution shifts concurrently.
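As a minimal sketch of what sampling from such a subspace could look like, assuming (as in the neural subspaces of Wortsman et al., 2021) that the subspace is the simplex spanned by N trained endpoint parameter vectors; the function and variable names here are illustrative, not the paper's implementation:

```python
import numpy as np

def sample_cps_params(endpoints, rng=None):
    """Sample one parameter vector from the simplex spanned by trained
    endpoints: one flattened parameter vector per train-time distribution."""
    if rng is None:
        rng = np.random.default_rng(0)
    endpoints = np.asarray(endpoints)                # shape (N, P)
    # Dirichlet weights are nonnegative and sum to 1, so the sample stays
    # inside the convex hull of the endpoints.
    alphas = rng.dirichlet(np.ones(len(endpoints)))
    return alphas @ endpoints                        # shape (P,)
```

Any parameter returned this way is a convex combination of the endpoints; the paper's distance regularization would additionally spread the endpoints apart during training so that distinct samples map to distinct low-loss solutions.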

2.2. PRELIMINARIES

Let X, Y, K denote the input space, coarse label space, and fine label space respectively. Coarse labels are the higher-order labels of fine labels. A base learner function f with parameters θ accepts inputs x and returns predicted labels ȳ = f(θ; x). θ is computed such that it minimizes the loss between the ground-truth and predicted labels: L(θ; x, y) = (1/|x|) Σ^|x|_{i=1} (f(θ; x_i) − y_i)².

Definition 1. A distribution shift is the divergence between a train-time distribution (x_0, y_0) and a test-time distribution (x̂, ŷ), where (x̂, ŷ) is an interpolated distribution between (x_0, y_0) and target distributions (x^∆_i, y^∆_i) such that x̂ = Σ^N_{i=1} α_i x^∆_i and ŷ = y_i where i = arg max_i α_i.

Definition 1.1. A disjoint distribution shift is a distribution shift with one target distribution, |{α_i}| = 2, such that x̂ = α x^∆ + (1 − α) x_0.

Definition 1.2. A joint distribution shift is a distribution shift with multiple target distributions, |{α_i}| > 2, such that x̂ = Σ^N_{i=1} α_i x^∆_i.

For N distributions {x_0 → y_0, x^∆_1 → y^∆_1, ..., x^∆_N → y^∆_N} containing (N − 1) target distributions, we sample interpolation coefficients α_i ∼ [0, 1] s.t. α_i → x_i and Σ^N_{i=1} α_i ≤ N. In CPS training, N is the number of train-time distributions used in training the subspace; in task interpolation, N
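The interpolated test-time distribution of Definition 1 can be sketched as follows; array shapes, names, and the inclusion of x_0 among the mixed components are our own assumptions for illustration, not the paper's implementation:

```python
import numpy as np

def interpolated_shift(x0, targets, y_targets, rng=None):
    """Build one interpolated test-time distribution per Definition 1.

    x0: clean train-time inputs, array of shape (B, ...).
    targets: list of N-1 shifted versions of x0, same shape each.
    y_targets: labels for [x0] + targets (N entries).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    dists = [x0] + list(targets)                     # N distributions total
    alphas = rng.uniform(0.0, 1.0, size=len(dists))  # alpha_i ~ [0, 1]
    # x_hat is the alpha-weighted mixture of the component distributions.
    x_hat = sum(a * d for a, d in zip(alphas, dists))
    # y_hat is the label of the dominant component (i = argmax_i alpha_i).
    y_hat = y_targets[int(np.argmax(alphas))]
    return x_hat, y_hat
```

Setting `targets` to a single shifted distribution recovers the disjoint case of Definition 1.1; two or more recover the joint case of Definition 1.2.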



Figure 1: Shift-optimality: We plot the tradeoff frontier for methods across a range of distribution shift types. For single test-time distributions (Table 1), Clean Accuracy is the accuracy w.r.t. the Clean Test Set; Shifted Accuracy is the average accuracy across different shifts (Backdoor/Adversarial Attack, Random Permutations, Stylization, Rotation). For multiple test-time distributions (Table 2), Clean Accuracy is the average accuracy of the hypernetwork evaluated after each task; Shifted Accuracy is the average accuracy of the hypernetwork evaluated after the last task.

Figure 2: Change in parameter subspace dynamics: Landscape of unique lowest-loss parameters mapped to interpolated inputs. Refer to Appendix A.3.3 for supplementary visualization.



