RECONCILING FEATURE SHARING AND MULTIPLE PREDICTIONS WITH MIMO VISION TRANSFORMERS

Abstract

Multi-input multi-output (MIMO) training improves network performance by optimizing multiple subnetworks simultaneously. In this paper, we propose MixViT, the first MIMO framework for vision transformers, which takes advantage of ViTs' innate mechanisms to share features between subnetworks. This stands in stark contrast to traditional MIMO CNNs, which are limited by their inability to mutualize features. MixViT instead separates subnetworks only in the last layers, thanks to a novel source attribution mechanism that ties tokens to specific subnetworks. We thus retain the benefits of multi-output supervision while training strong features useful to both subnetworks. By fitting multiple subnetworks at the end of a base model, we verify that MixViT leads to significant gains across multiple architectures (ConViT, CaiT) and datasets (CIFAR, TinyImageNet, ImageNet-100, and ImageNet-1k).

1. INTRODUCTION

Training deep architectures has become commonplace in modern machine learning applications over the course of the last decade (Krizhevsky et al., 2012; He et al., 2016). Finding ways to better train these models has therefore become a question of paramount importance and has led to significant work in the literature (Zhang et al., 2018; Shanmugam et al., 2021). Multi-input multi-output (MIMO) architectures provide a promising technique to maximize model performance (Havasi et al., 2021; Rame et al., 2021). These methods train 2 concurrent subnetworks within a base network by mixing 2 inputs into a single one, extracting a shared feature representation with a core network, and retrieving 2 predictions, one for each input. At inference, the same input is used twice to obtain 2 predictions. This strongly benefits model performance, both by emulating ensembling (Lakshminarayanan et al., 2017; Hansen & Salamon, 1990) with subnetwork predictions (Havasi et al., 2021) and by implementing a particular form (Rame et al., 2021; Sun et al., 2022) of mixing samples data augmentation (MSDA) (Zhang et al., 2018; Yun et al., 2019). While one would hope the resulting subnetworks cooperate to learn and share strong generic features, this is usually not the case (Rame et al., 2021): if a feature is used by one subnetwork, the other does not use it (see the visualization given for reference at the bottom of Fig. 1). This in turn causes issues for MIMO networks, such as forcing the use of large base networks. Although Sun et al. (2022) show that this separation stems from the difficulty of reconciling shared representations with the need for distinct predictions, the problem is not trivial to overcome in existing MIMO CNNs. Vision Transformers (ViTs) (Dosovitskiy et al., 2021; Touvron et al., 2021a), with their token-based representation and attention mechanism, offer a new way to address these issues.
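The classical MIMO workflow just described can be sketched minimally as follows. This is an illustrative toy, not the architecture of any cited work: the tiny linear "core network", its dimensions, and all helper names are assumptions made for clarity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative dimensions (not taken from any cited architecture).
D_IN, D_FEAT, N_CLASSES = 32, 64, 10

# A shared "core network" (a single random linear map + ReLU here)
# and one output head per subnetwork.
W_core = rng.normal(size=(2 * D_IN, D_FEAT)) * 0.1
W_head = [rng.normal(size=(D_FEAT, N_CLASSES)) * 0.1 for _ in range(2)]

def mimo_forward(x1, x2):
    """Classical MIMO: mix 2 inputs into a single one, extract one
    shared representation, and read out one prediction per subnetwork."""
    z = np.concatenate([x1, x2], axis=-1)    # mix the 2 inputs
    h = np.maximum(z @ W_core, 0.0)          # shared feature representation
    return [h @ W for W in W_head]           # 2 predictions, one per input

# Training: two independent inputs, each head supervised on its own label.
y1, y2 = mimo_forward(rng.normal(size=D_IN), rng.normal(size=D_IN))

# Inference: the same input is fed twice; averaging the 2 predictions
# emulates an ensemble of the 2 subnetworks.
x = rng.normal(size=D_IN)
p1, p2 = mimo_forward(x, x)
ensemble = (p1 + p2) / 2
```

In this toy, channel-wise concatenation stands in for the input-mixing scheme; the papers cited above each use their own mixing strategy.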
Remarkably, MIMO networks' success has so far remained confined to convolutional neural networks, and the paradigm has yet to extend to emerging vision transformers. This is all the more notable as ViTs have started outperforming CNNs on vision benchmarks (Touvron et al., 2021a; Liu et al., 2021b) and seem well suited to training more efficient MIMO models. In this paper, we propose the first ViT-based MIMO framework, MixViT, which significantly improves upon the performance of standard single-input single-output models at a minimal cost. MixViT modifies the traditional MIMO structure to take full advantage of ViTs' propensity to mutualize features between subnetworks while still retaining the benefit of training distinct predictions (see Fig. 1). It therefore avoids common MIMO issues and even benefits from the strong implicit regularization induced by feature sharing, which proves particularly useful on smaller datasets.

Figure 1: MixViT trains efficient MIMO ViTs. We aim to mutualize features between subnetworks, as in the visualization bubble (instead of the disjoint features of classical MIMO), by reconciling the need to predict distinct outputs with the need for features to describe both inputs. At training, MixViT mixes patch tokens from 2 inputs before feeding them to shared Self-Attention blocks. To help extract 2 predictions from these shared features, we then introduce "source attribution" to associate each attended token with the input/subnetwork it pertains to and aggregate features with a class token. Finally, we get 2 predictions from the class token features with 2 different dense layers. At inference, MixViT feeds a unique input to the model and retrieves two predictions thanks to source attribution.

Our contributions are therefore as follows: 1) We leverage innate properties of Vision Transformers to train efficient subnetworks that share features as needed. 2) We introduce a novel source attribution mechanism to facilitate this late separation of subnetworks. 3) We propose the first working MIMO ViT, as MixViT solves the issues of traditional MIMO frameworks on transformers. MixViT sets a new state of the art on the TinyImageNet dataset and shows strong benefits across multiple architectures (ConViT (D'Ascoli et al., 2021), CaiT (Touvron et al., 2021b)) and datasets (CIFAR-10/100 (Krizhevsky et al., 2009), TinyImageNet (Chrabaszcz et al., 2017), ImageNet-100 (Tian et al., 2020), and ImageNet-1k (Deng et al., 2009)). MixViT adds only marginal costs at training and inference while strongly regularizing the model.

2. MIXVIT

We propose MixViT (Fig. 2), a MIMO framework for vision transformers that takes advantage of ViTs' innate properties to share features between subnetworks while still yielding diverse predictions. MixViT adapts the classical MIMO workflow to ViTs by restructuring the model so that subnetworks share early features and only specialize in the last layers, thanks to our novel "source attribution" mechanism. This section first gives an overview of MixViT (Sec. 2.1) before discussing MixViT's new MIMO structure (Sec. 2.2), our new source attribution mechanism (Sec. 2.3), and its added overhead (Sec. 2.4). Throughout, ViTs are seen as separated into feature extraction blocks that use Self-Attention blocks and classifier blocks that rely on Class-Attention blocks (see Fig. 2). A number of modern vision transformers follow this structure, such as the CaiT (Touvron et al., 2021b) architecture. MixViT is itself architecture-agnostic and easily generalizes to other ViT architectures with minor adaptations.

2.1. MIXVIT OVERVIEW

At training time, the 2 × N patches extracted from the 2 inputs are mixed into N patches with binary masks {M_i} (so one token per position is picked, as shown in Fig. 2) with a ratio λ ∼ β(α, α), following a corrected CutMix (Yun et al., 2019) scheme. The N mixed patches are then encoded into d-dimensional tokens by a linear layer e. Learnable positional embeddings are added to the tokens, which are then fed to the L Self-Attention blocks that serve as feature extraction blocks. This yields N attended tokens t but no indication of which input/subnetwork each token belongs to. We propose a novel source attribution mechanism s_i to specify which input/subnetwork each patch pertains to. The source attribution adds subnetwork-specific information to the tokens, as we elaborate in Sec. 2.3.
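The overview above can be sketched as follows. This is a hedged toy, not the actual MixViT implementation: the Self-Attention and Class-Attention blocks are stubbed out, all dimensions are illustrative, and the additive form of the source embedding is an assumption (the real mechanism is detailed in Sec. 2.3).

```python
import numpy as np

rng = np.random.default_rng(0)
N, P, D, ALPHA = 16, 48, 64, 2.0   # tokens, patch dim, embed dim (illustrative)

# Patches extracted from the 2 inputs: shape (2, N, P).
patches = rng.normal(size=(2, N, P))

# CutMix-style mixing ratio and per-position binary masks M_i:
# mask[n] = 1 keeps input 0's patch at position n, else input 1's,
# so exactly one token per position is picked.
lam = rng.beta(ALPHA, ALPHA)
mask = (rng.random(N) < lam).astype(int)
mixed = np.where(mask[:, None] == 1, patches[0], patches[1])   # (N, P)

# Linear patch embedding e and learnable positional embeddings.
W_e = rng.normal(size=(P, D)) * 0.1
pos = rng.normal(size=(N, D)) * 0.02
tokens = mixed @ W_e + pos

# Stand-in for the L shared Self-Attention blocks (identity here).
attended = tokens

# Source attribution s_i: each attended token receives the embedding of
# the input/subnetwork it pertains to (additive form assumed here).
s = rng.normal(size=(2, D)) * 0.02
attributed = attended + np.where(mask[:, None] == 1, s[0], s[1])

# Class-Attention + 2 dense layers, stubbed with a mean-pool and
# random heads: one prediction per subnetwork.
W_head = [rng.normal(size=(D, 10)) * 0.1 for _ in range(2)]
pooled = attributed.mean(axis=0)
preds = [pooled @ W for W in W_head]
```

Note that, unlike classical MIMO CNNs, both subnetworks here process the same attended tokens; only the source embeddings and the final heads distinguish them, which is what allows the early features to be shared.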

