RECONCILING FEATURE SHARING AND MULTIPLE PREDICTIONS WITH MIMO VISION TRANSFORMERS

Abstract

Multi-input multi-output training improves network performance by optimizing multiple subnetworks simultaneously. In this paper, we propose MixViT, the first MIMO framework for vision transformers, which takes advantage of ViTs' innate mechanisms to share features between subnetworks. This is in stark contrast to traditional MIMO CNNs, which are limited by their inability to mutualize features. Unlike them, MixViT only separates subnetworks in the last layers thanks to a novel source attribution that ties tokens to specific subnetworks. As such, we retain the benefits of multi-output supervision while training strong features useful to both subnetworks. We verify that MixViT leads to significant gains across multiple architectures (ConViT, CaiT) and datasets (CIFAR, TinyImageNet, ImageNet-100, and ImageNet-1k) by fitting multiple subnetworks at the end of a base model.

1. INTRODUCTION

Training deep architectures has become commonplace in modern machine learning applications (Krizhevsky et al., 2012; He et al., 2016) over the course of the last decade. Finding ways to better train these models has therefore become a question of paramount importance, and has led to significant work in the literature (Zhang et al., 2018; Shanmugam et al., 2021). Multi-input multi-output (MIMO) architectures provide a promising technique to maximize model performance (Havasi et al., 2021; Rame et al., 2021). These methods train 2 concurrent subnetworks within a base network by mixing 2 inputs into a single one, extracting a shared feature representation with a core network, and retrieving 2 predictions, one for each input. At inference, we use the same input twice and get 2 predictions. This strongly benefits model performance, both by emulating ensembling (Lakshminarayanan et al., 2017; Hansen & Salamon, 1990) with subnetwork predictions (Havasi et al., 2021) and by implementing a particular form (Rame et al., 2021; Sun et al., 2022) of mixing samples data augmentation (MSDA) (Zhang et al., 2018; Yun et al., 2019). While one would hope the resulting subnetworks cooperate to learn and share strong generic features, this usually is not the case (Rame et al., 2021): if a feature is used by one subnetwork, the other does not use it (see the visualization given for reference at the bottom of Fig. 1). This in turn causes issues for MIMO networks, such as forcing the use of large base networks. Although Sun et al. (2022) shows this separation issue stems from a difficulty in reconciling the shared representations with the need for distinct predictions, this is not a trivial problem to overcome in existing MIMO CNNs. Vision Transformers (Dosovitskiy et al., 2021; Touvron et al., 2021a) (ViT), with their token-based representation and attention mechanism, offer a new way to address this problem.
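To make the MIMO mechanism described above concrete, the following is a minimal NumPy sketch, not the paper's implementation: two inputs are mixed (here by simple averaging as a placeholder mixing function), a shared core extracts one representation, and two separate heads produce one prediction per subnetwork. At inference, the same input is fed twice and the two predictions are averaged. All weight shapes and names (`W_core`, `W_head1`, `W_head2`) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "core network": a single shared feature extractor.
W_core = rng.normal(size=(16, 8))
# Two output heads, one per subnetwork.
W_head1 = rng.normal(size=(8, 3))
W_head2 = rng.normal(size=(8, 3))

def mimo_forward(x1, x2):
    """Mix two inputs, extract a shared representation with the core,
    and return one prediction per subnetwork."""
    mixed = 0.5 * (x1 + x2)          # mixing step (averaging as a stand-in for MSDA-style mixing)
    feats = np.tanh(mixed @ W_core)  # shared feature representation
    return feats @ W_head1, feats @ W_head2

# Training: two independent samples share a single forward pass,
# and each head is supervised with its own sample's label.
x1, x2 = rng.normal(size=16), rng.normal(size=16)
p1, p2 = mimo_forward(x1, x2)

# Inference: feed the same input twice; the two subnetwork
# predictions are ensembled by averaging.
x = rng.normal(size=16)
q1, q2 = mimo_forward(x, x)
ensemble = 0.5 * (q1 + q2)
```

The feature-separation issue discussed above arises because each head is trained against a different sample, which pushes the shared core to allocate disjoint features to each subnetwork.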
Remarkably, MIMO networks' success has until now remained confined to Convolutional Neural Networks, and the paradigm has yet to extend to these emerging vision transformers. This is all the more notable as they have started outperforming CNNs on vision benchmarks (Touvron et al., 2021a; Liu et al., 2021b) and seem well suited to training more efficient MIMO models. In this paper, we propose the first ViT-based MIMO framework, MixViT, which significantly improves upon the performance of standard single-input single-output models at a minimal cost. MixViT modifies the traditional MIMO structure to take full advantage of ViTs' propensity to mutualize features between subnetworks while still retaining the advantage of training distinct predictions (see Fig. 1). It therefore avoids common MIMO issues, and even benefits from strong implicit regularization (due to feature sharing) that proves particularly useful on smaller datasets. At training,

