MACHINE LEARNING FROM EXPLANATIONS

Abstract

Machine learning needs large amounts of (labeled) data; otherwise, models may fail to learn the right behavior for different sub-populations or, even worse, may pick up spurious correlations in the training data, leading to brittle models. In addition, with small training datasets, there is huge variability among models learned on randomly sampled training sets, which makes the whole process less reliable. However, collecting large amounts of useful, representative data and training on large datasets are both very costly. In this paper, we present a technique for training reliable classification models on small datasets, assuming access to simple explanations (e.g., a subset of important input features) for the labeled data. We also propose a two-stage training pipeline that optimizes the model's output and fine-tunes its attention in an interleaved manner, helping the model discover the reason behind the provided explanations while learning from the data. We show that our training pipeline enables faster convergence to better models, especially when there is severe class imbalance or there are spurious features in the training data.

1. INTRODUCTION

Machine learning achieves excellent performance on many challenging tasks (Dosovitskiy et al., 2020; Liu et al., 2021; Zhang et al., 2020), reaching or even outperforming humans (Silver et al., 2016; Li et al., 2022) in controlled experiments. However, real-life performance is often drastically worse, especially when there are natural spurious correlations in the training data (Arjovsky et al., 2019; Ribeiro et al., 2016), and in the many cases where the training dataset is small, heterogeneous, or unbalanced. All this makes it hard for existing learning algorithms to learn the right set of robust rules that generalize well from training data. This is a serious issue in many critical applications where machine learning promises to assist humans. For example, the data used to train human-level medical AI models often contain spurious features, and models that overly rely on these features to classify medical images produce untrustworthy diagnoses in practice (Rieger et al., 2020). To make matters worse, models tend to extract more spurious reasons from smaller subsets of the data, making them even less useful for minority groups (Sagawa et al., 2020). If a model does not learn the true reason, it usually finds reasons that favor the majority in the training set, because it is an empirical risk minimizer. Even if the minority group is up-weighted to prevent it from being left out, a common practice in long-tailed classification, the model still cannot classify minorities well on unseen data, because the extracted reasons do not generalize. This discrepancy in group performance also underlies the algorithmic (un)fairness problem in machine learning. Collecting more data helps with some obvious problems, but at significant cost (for obtaining large amounts of high-quality, accurately labeled data, and for training models on large datasets).
However, more data does not necessarily yield stable and generalizable models. We observe that running the same training algorithm on the same dataset can yield models of similar accuracy that have learned drastically different decision functions and have low prediction agreement among themselves (Ross et al., 2017; D'Amour et al., 2020; Watson et al., 2022). This comes as no surprise, given the high dimensionality of the input and parameter spaces relative to the few available labels for (classification) tasks. Fully eliminating the ambiguity in parameter space requires datasets of huge size, which prohibitively increases the cost of applying machine learning to most real-world problems. Collecting more data is also not always possible. For example, for (rare) disease diagnoses, there are not enough new patients every year to expand the datasets. In manufacturing, defect detection systems for new products likewise have little data available. There is thus a need for better learning algorithms in small-data regimes. In this paper, we show that guiding machine learning with simple explanations can significantly improve its performance, reduce its sample complexity, and increase its stability. We assume that for some training data points, in addition to their labels, we also have an expert explanation for the assigned labels. The explanation can have different degrees of complexity; in its simplest form, a small subset of input features is highlighted as containing the high-level reason for the assigned label. We propose an effective algorithm for machine learning with explanations, where we guide the model to identify the latent features most consistent with the explanations while optimizing the model's overall accuracy based on labels.
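The observation that equally accurate models can disagree widely can be made concrete with a small simulation. The sketch below is illustrative (all names and the 90% accuracy figure are assumptions, not results from the paper): two models with identical accuracy but independent error sets can disagree on far more inputs than either gets wrong.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
y_true = rng.integers(0, 2, size=n)  # hypothetical binary labels

def simulate_model_preds(y, acc, rng):
    # Simulate a model of the given accuracy by flipping a (1 - acc)
    # fraction of the true labels at random positions.
    preds = y.copy()
    flip = rng.random(len(y)) > acc
    preds[flip] = 1 - preds[flip]
    return preds

# Two "runs" with the same 90% accuracy but independent error sets.
preds_a = simulate_model_preds(y_true, 0.9, rng)
preds_b = simulate_model_preds(y_true, 0.9, rng)

# Expected agreement is roughly 0.9*0.9 + 0.1*0.1 = 0.82, i.e., the two
# equally accurate models disagree on ~18% of inputs.
agreement = float(np.mean(preds_a == preds_b))
```

This is an upper bound on the phenomenon in practice, since real training runs share the data and are not fully independent, but it shows why similar test accuracy says little about whether two models learned the same decision function.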
Our approach significantly outperforms the baselines in convergence rate, accuracy, robustness to spurious correlations, and stability (which influences generalizability), especially for heterogeneous data. Under these settings, we show that baselines not only struggle to learn complex tasks, but also fail to learn the right reasons for classifying data on extremely simple tasks (e.g., detecting geometric shapes). Thus, we argue that incorporating explanations into learning algorithms is necessary if we aim to deploy trustworthy ML models that do not latch onto spurious or counter-intuitive signals. Some prior works have explored using prior knowledge to improve machine learning (Ross et al., 2017; Rieger et al., 2020; Schramowski et al., 2020; Shao et al., 2021). Although their objective is to be "right for the right reasons", these methods actually penalize models for learning the wrong reasons. As we show in our analysis, this does not necessarily result in learning the right reasons, and thus offers limited advantage over learning from labels alone. It is also fundamentally different from our proposed approach. Conceptually, providing the explanation for the right reasons is much easier than enumerating all possible ways the model might make mistakes, and doing so during the data collection phase. If there are known spurious correlations, a more straightforward and more effective solution (compared to penalizing the model for learning them) would be to remove them and train models on clean data (Friedrich et al., 2022).

2. LEARNING FROM EXPLANATIONS

2.1 PROBLEM STATEMENT

Given a labeled dataset D = {(x_i, y_i)}, we also have access to an explanation e(x) ⊂ x, a subset of the input features, for the label of each input point x. The explanations are informative enough to sufficiently explain the labels. We want to train a model that produces outputs similar to the given labels and bases its decisions on reasons close to the given explanations.
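The assumed data format can be sketched as follows. This is a minimal illustration (the feature count, mask positions, and labeling rule are hypothetical, not from the paper): each example pairs an input x with a label y and a binary mask e marking the explanatory subset e(x) ⊂ x.

```python
import numpy as np

rng = np.random.default_rng(0)

d = 16                        # number of input features (illustrative)
x = rng.normal(size=d)        # input point x
e = np.zeros(d, dtype=bool)
e[[2, 5, 11]] = True          # expert marks a small subset e(x) ⊂ x as the reason

# The explanation is "sufficient": here the label is fully determined
# by the explanatory features alone, ignoring the rest of x.
y = int(x[e].sum() > 0)

dataset = [(x, y, e)]         # D = {(x_i, y_i)} augmented with explanations e(x_i)
```

For images, e would instead be a per-pixel mask over the highlighted region; the structure of the triple (x, y, e) is otherwise the same.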

2.2. WHY DO PREVIOUS METHODS NOT WORK WELL?

In this work, we focus on the image domain, where a wide range of model explanation methods have been studied. Some prior works (Ross et al., 2017; Rieger et al., 2020; Schramowski et al., 2020; Shao et al., 2021) address a problem similar to ours. Explanations in their settings are bounding boxes of either the main object or the spurious features, which differ from our definition of "informative and sufficient subsets of input features". Technically, they all adopt a loss-based approach, adding an explanation misalignment loss to the label loss:

L_joint = L_label + λ L_expl. (1)

The additional loss is computed as the difference between the feature attributions attr(x) produced by a model explanation method and the given explanation e(x):

L_expl = ||attr(x) − e(x)||. (2)

While it is tempting to reuse these existing algorithms for our problem, in practice they do not yield higher test accuracy than vanilla training. The first reason is that using model explanation methods to generate models' attributions is questionable: Adebayo et al. (2018) have shown that many popular saliency-map-based explanations do not even pass basic sanity checks, so using them as a proxy for the model's attention is unreliable. The second reason concerns training. The joint loss is typically optimized via gradient descent, but the gradients of the two loss terms may point in different directions, creating a race condition that pulls and pushes the model into a bad local optimum. If the two gradients counteract each other, the weights are updated with negligible aggregated gradients.
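The loss-based approach of Eqs. (1) and (2) can be sketched in a few lines of PyTorch. This is an illustrative implementation under stated assumptions: the model, input shapes, the choice of plain input gradients as attr(x), and λ = 0.1 are all placeholders, not the cited papers' exact setups.

```python
import torch
import torch.nn.functional as F

# Toy model and a batch of 8x8 single-channel "images" (illustrative shapes).
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 2))
x = torch.randn(8, 1, 8, 8, requires_grad=True)
y = torch.randint(0, 2, (8,))

# Explanation mask e(x): here a fixed central patch, as a stand-in for an
# expert-provided bounding box or feature subset.
e = torch.zeros_like(x)
e[:, :, 2:5, 2:5] = 1.0

logits = model(x)
label_loss = F.cross_entropy(logits, y)          # L_label

# attr(x): gradient of the labeled-class score w.r.t. the input.
# create_graph=True keeps this differentiable so L_expl can be optimized.
class_scores = logits.gather(1, y[:, None]).sum()
attr = torch.autograd.grad(class_scores, x, create_graph=True)[0]

lam = 0.1
expl_loss = (attr - e).norm()                    # L_expl, Eq. (2)
joint_loss = label_loss + lam * expl_loss        # L_joint, Eq. (1)

# A single gradient step mixes both terms; when their gradients conflict,
# the aggregated update can be negligible, as discussed above.
joint_loss.backward()
```

Note that the same structure applies regardless of which attribution method plays the role of attr(x); swapping input gradients for a saliency-map method changes the attribution line but not the race condition between the two loss terms.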

