A LEARNING BASED HYPOTHESIS TEST FOR HARMFUL COVARIATE SHIFT

Abstract

The ability to quickly and accurately identify covariate shift at test time is a critical and often overlooked component of safe machine learning systems deployed in high-risk domains. While methods exist for detecting when predictions should not be made on out-of-distribution test examples, identifying distribution-level differences between training and test time can help determine when a model should be removed from the deployment setting and retrained. In this work, we define harmful covariate shift (HCS) as a change in distribution that may weaken the generalization of a predictive model. To detect HCS, we use the discordance between an ensemble of classifiers trained to agree on training data and disagree on test data. We derive a loss function for training this ensemble and show that the disagreement rate and entropy of the ensemble represent powerful discriminative statistics for HCS. Empirically, we demonstrate the ability of our method to detect harmful covariate shift with statistical certainty on a variety of high-dimensional datasets. Across numerous domains and modalities, we show state-of-the-art performance compared to existing methods, particularly when the number of observed test samples is small.

1. INTRODUCTION

Machine learning models operate on the assumption, often incorrect in practice, that they will be deployed on data distributed identically to the data they were trained on. The violation of this assumption is known as distribution shift and often results in significant degradation of performance [Bickel et al., 2009; Rabanser et al., 2019; Otles et al., 2021; Ovadia et al., 2019]. There are several cases where a mismatch between training and deployment data has had very real consequences for human beings. In healthcare, machine learning models have been deployed for predicting the likelihood of sepsis. Yet, as Habib et al. [2021] show, such models can be miscalibrated for large groups of individuals, directly affecting the quality of care they experience. The deployment of classifiers in the criminal justice system [Hao, 2019], hiring and recruitment pipelines [Dastin, 2018] and self-driving cars [Smiley, 2022] has likewise seen humans affected by the failures of learned models. The need for methods that quickly detect, characterize and respond to distribution shift is therefore a fundamental problem in trustworthy machine learning. We study a special case of distribution shift, commonly known as covariate shift, which considers shifts only in the distribution of input data P(X) while the relation between inputs and outputs P(Y|X) remains fixed. In a standard deployment setting where ground-truth labels are not available, covariate shift is the only type of distribution shift that can be identified.
For practitioners, regulatory agencies and individuals to have faith in deployed predictive models without the need for laborious manual audits, we need methods for the identification of covariate shift that are sample-efficient (identifying shifts from a small number of samples), informed (identifying shifts relevant to the domain and learning algorithm), model-agnostic (identifying shifts regardless of the functional class of the predictive model) and statistically sound (identifying true shifts while avoiding false positives with high confidence). We build on recent progress in understanding model performance under covariate shift using the PQ-learning framework [Goldwasser et al., 2020], a framework for selective classifiers that may either predict on or reject a given sample and that provides strong performance guarantees on arbitrary test distributions. Our work uses and extends PQ-learning to develop a practical, model-based hypothesis test, named the Detectron, to identify potentially harmful covariate shifts given any existing classification model already in deployment. Our work makes the following key contributions:

• We show how to construct an ensemble of classifiers that maximize out-of-domain disagreement while behaving consistently in the training domain. We propose the disagreement cross entropy for models learned via continuous gradient-based methods (e.g., neural networks), as well as a generalization for those learned via discrete optimization (e.g., random forests).

• We show that the rejection rate and the entropy of the learned ensemble can be used to define a model-aware hypothesis test for covariate shift, the Detectron, which in idealized settings can provably detect covariate shift.

• On high-dimensional image and tabular data, using both neural networks and gradient-boosted decision trees, our method outperforms state-of-the-art techniques for detecting covariate shift, particularly when given access to as few as ten test examples.
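The first contribution centers on a loss that rewards disagreement with the base model outside the training domain while standard cross entropy preserves agreement inside it. The following is a minimal, self-contained sketch of that idea in pure Python; the function names and the exact form of the disagreement term are illustrative, not necessarily the form derived in this paper:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def agreement_ce(probs, label):
    """Standard cross entropy: low when the classifier agrees with `label`.
    Applied on training data to keep g consistent with the base model f."""
    return -math.log(probs[label] + 1e-12)

def disagreement_ce(probs, pseudo_label):
    """Illustrative disagreement term: penalizes the probability mass g
    places on f's pseudo-label, so minimizing it pushes g toward predicting
    any other class on test data. (A hedged sketch of the disagreement
    objective, not the paper's exact disagreement cross entropy.)"""
    return -math.log(1.0 - probs[pseudo_label] + 1e-12)

# Toy check: a classifier confident in f's pseudo-label (class 0) incurs a
# high disagreement loss; one confident in another class incurs a low one.
p_agree = softmax([5.0, 0.0, 0.0])
p_disagree = softmax([0.0, 5.0, 0.0])
assert disagreement_ce(p_agree, 0) > disagreement_ce(p_disagree, 0)
assert agreement_ce(p_agree, 0) < agreement_ce(p_disagree, 0)
```

Combining the two terms, one per data source, yields a single objective that can be minimized with ordinary gradient descent, which is the continuous setting the first bullet refers to.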
[Figure 1 schematic. Recoverable panel text: pretrain model f with data (x_i, y_i) ∼ P; finetune f → g_Q to disagree on observed data {x_1, . . . , x_m} ∼ Q_X; finetune f → g_P to disagree on i.i.d. data {x_1, . . . , x_m} ∼ P_X; compare the disagreement rates φ_P = P_X[g_P(x) ≠ f(x)] and φ_Q = Q_X[g_Q(x) ≠ f(x)] as a hypothesis test for covariate shift (H_0: P = Q vs. H_a: P ≠ Q); g(·) is constrained to agree with f on the training set.]

2. BACKGROUND AND RELATED WORK

Covariate Shift Detection. Covariate shift is the tendency for the distribution at test time p_test(x) to differ from that seen during training p_train(x) while the underlying prediction concept remains fixed, i.e., p_train(y|x) = p_test(y|x). Many methods for detecting shift apply dimensionality reduction followed by statistical hypothesis tests for distributional differences between the outputs on a reference and a target set [Rabanser et al., 2019]. Rabanser et al. show that using the softmax outputs of a pretrained classifier as low-dimensional representations for performing univariate KS tests, a method known as black box shift detection (BBSD) [Lipton et al., 2018], is effective at confidently identifying several synthetic covariate shifts in imaging data (e.g. crops, rotations) given approximately 200 i.i.d. samples. However, applying statistical tests to non-invertible representations of data can never be guaranteed to capture arbitrary covariate shifts, as there may always exist multiple distributions that collapse to the same test statistic [Zhang et al., 2021]. Kifer et al. [2004]; Ben-David et al. [2006] introduce some of the earliest learning-theoretic approaches for identifying and correcting covariate shift based on discriminative learning with finite samples. More recent approaches for covariate shift detection, including classifier two-sample tests [Lopez-Paz and Oquab, 2017], deep kernel MMD [Liu et al., 2020] and H-Divergence [Zhao et al., 2022], rely on analyzing the outputs of unsupervised learning models (see Appendix subsection E.3 for more details). In our work, we take a transductive learning approach and construct a method that directly uses the structure of a supervised classification problem to improve the statistical power for detecting shifts.
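The BBSD baseline discussed above can be sketched in a few lines: reduce the data with a pretrained classifier's softmax outputs, then run a univariate two-sample KS test per output dimension. To stay dependency-free, this sketch substitutes a permutation null on the maximum KS statistic for the Bonferroni-corrected asymptotic p-values used in the literature; `bbsd_test` and its parameters are illustrative names:

```python
import bisect
import random

def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: the maximum vertical gap
    between the empirical CDFs of samples a and b."""
    a, b = sorted(a), sorted(b)
    gap = 0.0
    for v in a + b:
        fa = bisect.bisect_right(a, v) / len(a)
        fb = bisect.bisect_right(b, v) / len(b)
        gap = max(gap, abs(fa - fb))
    return gap

def bbsd_test(feats_train, feats_test, n_perm=200, alpha=0.05, seed=0):
    """BBSD-style shift test on low-dimensional representations (rows of
    e.g. softmax outputs): per-dimension KS statistics, with a permutation
    null on the maximum statistic as a self-contained stand-in for the
    Bonferroni-corrected asymptotic test."""
    rng = random.Random(seed)
    k = len(feats_train[0])

    def max_ks(xs, ys):
        return max(ks_statistic([r[j] for r in xs], [r[j] for r in ys])
                   for j in range(k))

    observed = max_ks(feats_train, feats_test)
    pooled = list(feats_train) + list(feats_test)
    n, hits = len(feats_train), 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        hits += max_ks(pooled[:n], pooled[n:]) >= observed
    p_value = (1 + hits) / (1 + n_perm)
    return p_value, p_value < alpha
```

On a clearly mean-shifted one-dimensional feature this rejects with p near 1/(n_perm + 1), while on i.i.d. splits it should reject at a rate no greater than alpha, which is the behavior the sample-size comparisons in this work measure against.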



Code available at https://github.com/rgklab/detectron




Figure 1: Overview of the Detectron: Starting with a base classifier f trained on labeled samples from distribution P, we train new Constrained Disagreement Classifiers (CDCs) g_P and g_Q on small sets of unseen samples from P as well as an unknown distribution Q. CDCs aim to maximize classification disagreement on unseen data while constrained to classify consistently with f on their original training set. The rate ϕ at which CDCs disagree is a powerful and sample-efficient statistic for identifying covariate shift P ̸ = Q.
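The decision rule implied by the figure compares the disagreement rate observed on Q against disagreement rates obtained by running the same procedure on held-out i.i.d. samples from P. A minimal pure-Python sketch of that comparison follows; it omits CDC training and the entropy variant of the statistic, and `detect_shift` with its empirical p-value is an illustrative construction, not the full Detectron procedure:

```python
def disagreement_rate(preds_g, preds_f):
    """Fraction of points on which a CDC g disagrees with the base model f
    (the statistic phi in the figure)."""
    return sum(a != b for a, b in zip(preds_g, preds_f)) / len(preds_f)

def detect_shift(phi_null, phi_q, alpha=0.05):
    """One-sided empirical test: is the disagreement rate phi_q observed on
    Q unusually high relative to the calibration rates phi_null collected
    from held-out i.i.d. samples of P? Returns (p_value, reject H0: P = Q)."""
    p_value = (1 + sum(phi >= phi_q for phi in phi_null)) / (1 + len(phi_null))
    return p_value, p_value < alpha
```

For example, with 100 calibration runs whose disagreement rates all stay below 0.05, an observed phi_q of 0.5 yields an empirical p-value of 1/101, and H_0: P = Q is rejected at alpha = 0.05.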

Out of Distribution Detection. Out-of-distribution (OOD) detection focuses on identifying when a specific data point x′ admits low likelihood under the original training distribution (p_train(x′) ≈ 0), a useful tool to have at inference time. Ren et al. [2019]; Morningstar et al. [2021] represent

