RISK CONTROL FOR ONLINE LEARNING MODELS

Abstract

To provide rigorous uncertainty quantification for online learning models, we develop a framework for constructing uncertainty sets that provably control risk, such as the coverage of confidence intervals, the false negative rate, or the F1 score, in the online setting. This extends conformal prediction to a larger class of online learning problems. Our method guarantees risk control at any user-specified level even when the underlying data distribution shifts drastically, even adversarially, over time in an unknown fashion. The technique we propose is highly flexible, as it can be applied with any base online learning algorithm (e.g., a deep neural network trained online), requiring minimal implementation effort and essentially zero additional computational cost. We further extend our approach to control multiple risks simultaneously, so the prediction sets we generate are valid for all given risks. To demonstrate the utility of our method, we conduct experiments on real-world tabular time-series data sets showing that the proposed method rigorously controls various natural risks. Furthermore, we show how to construct valid intervals for an online image-depth estimation problem that previous sequential calibration schemes cannot handle.

1. INTRODUCTION

To confidently deploy learning models in high-stakes applications, we need both high predictive accuracy and reliable safeguards to handle unanticipated changes in the underlying data-generating process. Reasonable accuracy on a fixed validation set is not enough, as argued by Sullivan (2015); we must also quantify uncertainty to correctly handle hard input points and to account for shifting distributions. For example, consider autonomous driving, where we have a real-time view of the car's surroundings. To operate such an autonomous system successfully, we should measure the distance between the car and nearby objects, e.g., via a sensor that outputs a depth image whose pixels represent the distance of the objects in the scene from the camera. Figure 1a displays a colored image of a road and Figure 1b presents its corresponding depth map. Since high-resolution depth measurements often require a longer acquisition time than capturing a colored image, online estimation models have been developed to predict the depth map from a given RGB image (Patil et al., 2020; Zhang et al., 2020). The goal of these methods is to artificially speed up depth-sensing acquisition. However, making decisions based solely on an estimate of the depth map is insufficient, as the predictive model may not be accurate enough. Furthermore, the distribution can vary greatly and drastically over time, causing the online model to output highly inaccurate and unreliable predictions. In these situations, it is necessary to design a predictive system that reflects the range of plausible outcomes, reporting the uncertainty in the prediction. To this end, we encode uncertainty in a rigorous manner via prediction intervals/sets that augment point predictions and have long-range error control. In the autonomous driving example, the uncertainty in the depth map estimate is represented by depth-valued uncertainty intervals.
In this paper, we introduce a novel calibration framework that can wrap any online learning algorithm (e.g., an LSTM model trained online) to construct prediction sets with guaranteed validity. Formally, consider an online learning setting in which we are given a data stream $\{(X_t, Y_t)\}_{t\in\mathbb{N}}$ in a sequential fashion, where $X_t \in \mathcal{X}$ is a feature vector and $Y_t \in \mathcal{Y}$ is a target variable. In single-output regression settings $\mathcal{Y} = \mathbb{R}$, while in classification tasks $\mathcal{Y}$ is the finite set of all class labels. The input $X_t$ is commonly a feature vector, i.e., $\mathcal{X} = \mathbb{R}^p$, although it may take different forms, as in the depth sensing task, where $X_t \in \mathbb{R}^{M \times N \times 3}$ is an RGB image and $Y_t \in \mathbb{R}^{M \times N}$ is the ground-truth depth. Consider a loss function $L(Y_t, \hat{C}_t(X_t)) \in \mathbb{R}$ that measures the error of the estimated prediction set $\hat{C}_t(X_t) \subseteq \mathcal{Y}$ with respect to the true outcome $Y_t$. Importantly, at each time step $t \in \mathbb{N}$, given all previously observed samples $\{(X_i, Y_i)\}_{i=1}^{t-1}$ along with the test feature vector $X_t$, our goal is to construct a prediction set $\hat{C}_t(X_t)$ guaranteed to attain any user-specified risk level $r$:

$$R(\hat{C}) = \lim_{T \to \infty} \frac{1}{T} \sum_{t=1}^{T} L(Y_t, \hat{C}_t(X_t)) = r. \qquad (1)$$

For instance, a natural choice for the loss $L$ in the depth sensing task is the image miscoverage loss:

$$L_{\text{image miscoverage}}(Y_t, \hat{C}(X_t)) = \frac{1}{MN} \left| \left\{ (m,n) : Y_t^{m,n} \notin \hat{C}^{m,n}(X_t) \right\} \right|. \qquad (2)$$

In words, $L_{\text{image miscoverage}}(Y_t, \hat{C}(X_t))$ is the ratio of pixels miscovered by the intervals $\hat{C}^{m,n}(X_t)$, where $(m,n)$ is the pixel's location. Hence, the resulting risk for the loss in (2) measures the average image miscoverage rate across the prediction sets $\{\hat{C}_t(X_t)\}_{t=1}^{\infty}$, and $r = 20\%$ is a possible choice for the desired miscoverage frequency. Another example of a loss function that is attractive in multi-label classification problems is the false negative proportion, whose corresponding risk is the false negative rate.
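To make the image miscoverage loss in (2) concrete, the following is a minimal NumPy sketch; the function name, argument layout, and the toy arrays are our own illustration rather than code from the paper. It computes the fraction of pixels whose true depth falls outside the corresponding per-pixel interval.

```python
import numpy as np

def image_miscoverage(y, lo, hi):
    """Fraction of pixels whose ground-truth depth y[m, n] falls outside
    the per-pixel interval [lo[m, n], hi[m, n]] (the loss in Eq. (2))."""
    missed = (y < lo) | (y > hi)  # boolean M x N mask of miscovered pixels
    return missed.mean()

# Toy 2x2 depth map: exactly one of the four pixels is miscovered.
y  = np.array([[1.0, 2.0], [3.0, 4.0]])
lo = np.array([[0.5, 1.5], [2.5, 4.5]])  # interval (1,1) lies above y = 4.0
hi = np.array([[1.5, 2.5], [3.5, 5.0]])
print(image_miscoverage(y, lo, hi))  # 0.25
```

Averaging this loss over the stream of prediction sets yields the empirical risk that the calibration procedure drives toward the target level $r$.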
In this work, we introduce rolling risk control (Rolling RC): the first calibration procedure that forms prediction sets in online settings achieving any pre-specified risk level in the sense of (1) without making any assumptions on the data distribution, as guaranteed by Theorem 1. We accomplish this by building on the mathematical foundations of adaptive conformal inference (ACI) (Gibbs & Candes, 2021), a groundbreaking conformal calibration scheme that constructs prediction sets for an arbitrary time-varying data distribution. The uncertainty sets generated by ACI are guaranteed to have valid long-range coverage, a special case of (1) with the 0-1 loss (indicator function) defined in Section 2. Importantly, one cannot simply plug an arbitrary loss function into ACI and achieve risk control. The reason is that ACI works with conformity scores, a measure of goodness-of-fit, which are relevant only to the 0-1 loss and do not exist in the general risk-controlling setting. Therefore, Rolling RC broadens the set of problems that ACI can tackle, allowing the analyst to control an arbitrary loss. Furthermore, the technique we propose in Section 3.3 is guaranteed to control multiple risks, and thus constructs sets that are valid for all given risks over long-range windows in time. Additionally, the proposed online calibration scheme is lightweight and can be integrated with any online learning model with essentially zero added complexity. Lastly, in Section 3.2.1 we carefully investigate design choices that allow our method to adapt quickly to distributional shifts. Indeed, the experiments conducted on real benchmark data sets, presented in Section 4, demonstrate that sophisticated design choices lead to improved performance.



Figure 1: Online depth estimation. The input frame, ground-truth depth map, estimated depth image, and interval size at time step t = 8020. All values are in meters.

