SYSTEMATIC GENERALISATION WITH GROUP INVARIANT PREDICTIONS

Abstract

We consider situations where the presence of dominant simpler correlations with the target variable in a training set can cause an SGD-trained neural network to rely less on more persistently correlating complex features. When the non-persistent, simpler correlations correspond to non-semantic background factors, a neural network trained on this data can exhibit dramatic failure upon encountering systematic distributional shift, where the correlating background features are recombined with different objects. We perform an empirical study on three synthetic datasets, showing that group invariance methods across inferred partitionings of the training set can lead to significant improvements in such test-time situations. We also suggest a simple invariance penalty, showing with experiments on our setups that it can perform better than alternatives. We find that even without assuming access to any systematically shifted validation sets, one can still find improvements over an ERM-trained reference model.

1. INTRODUCTION

If a training set is biased such that an easier-to-learn feature correlates with the target variable throughout the training set, a modern neural network trained with SGD will use that factor to perform predictions, ignoring co-occurring harder-to-learn complex predictive features (Shah et al., 2020). Without any other criteria, this is arguably desirable behaviour, reflecting Occam's razor. We consider the situation where, although such a simpler correlation is a dominant bias in the training set, a minority group exists within the dataset where the bias does not manifest. In such cases, relying on more complex predictive features which more pervasively explain the data can be preferable to simpler ones that only explain most of it. For example, if all chairs are red, redness ought to be a predictive rule for chairhood (without any other criteria for predictions). However, if some chairs are not red, and all chairs have backs and legs, then one can infer that redness is less relevant. In this paper, we will study object recognition tasks, where the objects correlate strongly with simpler non-semantic background information for a majority of the images, but not for a minority group. There is evidence in the literature that modern CNNs tend to fixate on simpler features such as texture (Geirhos et al., 2019; Brendel & Bethge, 2019), canonical pose (Alcorn et al., 2019), or contextual background cues (Beery et al., 2018). We assume that semantic features in a classification context (ones that humans would agree contribute to their labelling of objects) are more likely to persistently correlate with the target variable, while simpler non-semantic background biases are more likely to exhibit non-persistent correlations in real-life data collection processes. Based on this assumption, we will use combinations of objects and backgrounds to compare test-time performances corresponding to particular distributional shifts.
Consider coloured MNIST digits such that there is a dominant, but not universal, correlation between colour and digit identity for a majority of the images. In the situation we are considering, if the biasing colours in the majority group are not recombined with different digits in the minority group, then there is no signal for the model to disregard these biasing factors, which are retained as important predictive rules. This can lead to poor performance at systematic generalisation (Lake & Baroni, 2018), where an object occurs with another object's biasing factor, and at semantic anomaly detection (Ahmed & Courville, 2020), where a novel object appears with one of the biasing factors. In our example with coloured MNIST, if we colour the minority-group digits with the colours used to bias (different) digits in the majority group, we find a marked improvement at systematically shifted tests over the case where the minority-group colours are different colours altogether (see Table 1). We investigate the role of encouraging robust predictive behaviour across such groups in terms of improved performance at tasks with such distributional shifts.

Table 1: For a coloured MNIST dataset with every digit correlated with a colour 80% of the time, we see poor performance at systematically varying tasks. Performance improves if the minority group combines colours from other biased digits; this provides corrective gradients that promote invariance to colour. Non-systematic shifts are when unseen colours are used, and anomaly detection is measured by decreased predictive confidence for an unseen digit (see Section 2 for more details).
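The colouring scheme above can be sketched as follows. This is a minimal illustration, not our exact data pipeline: `assign_colour` is a hypothetical helper, and the RGB values in `bias_colours` are arbitrary placeholders for the per-digit biasing colours.

```python
import random

def assign_colour(digit, bias_colours, p_bias=0.8, rng=random):
    """Colour a digit: with probability p_bias, use its own biasing colour
    (majority group); otherwise, use the biasing colour of a *different*
    digit (minority group), so colours are recombined across classes."""
    if rng.random() < p_bias:
        return bias_colours[digit]
    other = [c for d, c in bias_colours.items() if d != digit]
    return rng.choice(other)

# One (illustrative) biasing colour per digit class.
bias_colours = {d: (25 * d, 255 - 25 * d, 120) for d in range(10)}
colour = assign_colour(3, bias_colours)
```

Setting `p_bias=0.8` reproduces the 80% correlation of Table 1; the key design choice is that the minority branch draws from other digits' biasing colours rather than from fresh colours, which is what supplies the corrective signal.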
Our experiments suggest that training with cross-group invariance penalties can result in models that have learned to be more reliant on persistent complex correlations without being overwhelmed by simpler, yet less stable features, as indicated by improved performance at systematic generalisation and semantic anomaly detection on our synthetic setups. We find that a recently proposed method (Creager et al., 2020) can be effective at inferring the majority and minority groups along a learned feature-bias, and we use this inferred partition to provide us with groups in the training set in our comparative study. We also suggest a new method for encouraging predictions that rely on persistent correlations across such groups, with the intuition that similar predictive behaviour across the groups should be promoted throughout training. With experiments on three synthetic datasets, we compare the performance of recently proposed invariance penalties and methods, and find that our variant can often perform better at tasks involving such test-time distributional shifts.
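To make the idea of a cross-group invariance penalty concrete, here is a minimal numpy sketch. It is one simple instantiation rather than the exact penalty we propose: it penalises the squared distance between the mean predictive distributions of two inferred groups, so the penalty is zero when both groups are predicted alike on average.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def invariance_penalty(logits, group_ids):
    """Squared L2 distance between the per-group mean predictive
    distributions (groups 0 and 1, e.g. majority vs. inferred minority).
    Added to the ERM loss, it discourages predictions that differ
    systematically across the two partitions."""
    probs = softmax(logits)
    mu0 = probs[group_ids == 0].mean(axis=0)
    mu1 = probs[group_ids == 1].mean(axis=0)
    return float(((mu0 - mu1) ** 2).sum())
```

In practice the group labels would come from an inference procedure such as that of Creager et al. (2020), and the penalty would be computed per minibatch and weighted against the classification loss.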

2. SYSTEMATIC AND NON-SYSTEMATIC GENERALISATION

If we assume that data $x$ is generated via a composition $C$ of semantic factors $h_s$ and non-semantic factors $h_n$, we can use this decomposition, $x = C(h_s, h_n)$, to generate test datasets that capture different scenarios. Although $h_n$ is actually independent of $y$, the independence property $p_D(h_n \mid y) = p_D(h_n)$ fails to hold when there is bias in the dataset $D$ due to $h_n$-$y$ correlations. For a particular target $y$ and our system's prediction $\hat{y}(x)$, we can evaluate the average accuracy $\mathbb{E}\left[\mathbf{1}\{\hat{y}(C(h_s, h_n)) = y\}\right]$ as a measure of generalisation for the following different cases.
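The evaluation metric above can be written as a short empirical estimate. The composition $C$ and the predictor are placeholders here; the toy example simply lets the semantic factor encode the label, so any recombination of $h_n$ leaves accuracy unchanged.

```python
import numpy as np

def average_accuracy(predict, compose, h_s, h_n, y):
    """Empirical estimate of E[1{y_hat(C(h_s, h_n)) = y}] over a test
    set given as arrays of semantic factors h_s, non-semantic factors
    h_n, and labels y."""
    x = compose(h_s, h_n)
    return float(np.mean(predict(x) == y))

# Toy illustration: compose packs (h_s, h_n) into one integer, and the
# predictor reads back only the semantic part.
h_s = np.arange(5)
h_n = np.array([3, 1, 4, 1, 5])
acc = average_accuracy(lambda x: x // 10, lambda s, n: s * 10 + n,
                       h_s, h_n, y=h_s)
```

Each test set in Figure 1 corresponds to a different choice of the joint distribution over $(h_s, h_n)$ fed to this estimate.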
Figure 1: COLOURED MNIST training and test sets for evaluating generalisation under non-semantic marginal shift and systematic shift, and anomaly detection. (a) Training set; (b) In-distribution generalisation set $T_g$, where the test set is coloured following the same scheme as the training set; (c) Systematic-shift generalisation set $T_s$, where we colour the test set with the biasing colours, but such that no digit is coloured with its own biasing colour; (d) Non-systematic-shift generalisation set $T_n$, where the test set is coloured with random colours that are different from any of the colours seen in the training set; and (e) Semantic anomaly detection set $T_a$, where we colour the held-out digits of the test set randomly with the biasing colours.
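The systematic-shift colouring for $T_s$ amounts to a derangement of the training-time digit-to-colour map: every digit receives the biasing colour of some other digit. A minimal sketch, with a hypothetical helper name and rejection sampling for the derangement:

```python
import random

def systematic_shift_colours(bias_colours, rng=None):
    """Build the digit -> colour map for the systematic-shift test set
    T_s: shuffle the assignment until no digit keeps its own biasing
    colour (a derangement of the training-time map)."""
    rng = rng or random.Random(0)
    digits = sorted(bias_colours)
    while True:
        perm = digits[:]
        rng.shuffle(perm)
        if all(p != d for p, d in zip(perm, digits)):
            return {d: bias_colours[p] for d, p in zip(digits, perm)}
```

$T_n$ would instead sample colours disjoint from the training palette, and $T_a$ would apply the training palette to held-out digit classes.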

