LEARNING EXPLANATIONS THAT ARE HARD TO VARY

Abstract

In this paper, we investigate the principle that good explanations are hard to vary in the context of deep learning. We show that averaging gradients across examples, akin to a logical OR (∨) of patterns, can favor memorization and 'patchwork' solutions that sew together different strategies instead of identifying invariances. To inspect this, we first formalize a notion of consistency for minima of the loss surface, which measures to what extent a minimum appears only when examples are pooled. We then propose and experimentally validate a simple alternative algorithm based on a logical AND (∧), which focuses on invariances and prevents memorization on a set of real-world tasks. Finally, using a synthetic dataset with a clear distinction between invariant and spurious mechanisms, we dissect learning signals and compare this approach to well-established regularizers.

Consider the top of Figure 1, which shows a view from above of the loss surface obtained as we vary a two-dimensional parameter vector θ = (θ₁, θ₂), for a fictional dataset containing two observations x_A and x_B. Note the two global minima on the top-right and bottom-left. Depending on the initial values of θ, marked as white circles, gradient descent converges to one of the two minima. Judging solely by the value of the loss function, which is zero in both cases, the two minima look equally good.

However, looking at the loss surfaces for x_A and x_B separately, as shown below, a crucial difference between these two minima appears: starting from the same initial parameter configurations and following the gradient of the loss, ∇_θ L(θ, x_i), the probability of finding the minimum on the top-right is zero in either case. In contrast, the minimum in the lower-left corner has a significant overlap across the two loss surfaces, so gradient descent can converge to it even when training on x_A (or x_B) alone.
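To make this argument concrete, here is a minimal numpy sketch in the same spirit. Everything in it (the two toy losses, the shared minimum S at (-2, -2), the solution lines θ₁ = 1 and θ₂ = 1, the learning rate) is an illustrative assumption of ours, not a construction from the paper: each example is fit anywhere on its own line and at one shared point, and the analogue of the top-right minimum, (1, 1), is found only when gradients are averaged over both examples.

```python
import numpy as np

# Shared minimum: the only solution both examples agree on over a whole basin
# (an illustrative choice, standing in for the lower-left minimum of Figure 1).
S = np.array([-2.0, -2.0])

def loss_and_grad_A(theta):
    # x_A is fit anywhere on the line theta_1 = 1, and also at the shared point S.
    l_line = (theta[0] - 1.0) ** 2
    l_well = float(np.sum((theta - S) ** 2))
    if l_line < l_well:
        return l_line, np.array([2.0 * (theta[0] - 1.0), 0.0])
    return l_well, 2.0 * (theta - S)

def loss_and_grad_B(theta):
    # x_B is fit anywhere on the line theta_2 = 1, and also at S.
    l_line = (theta[1] - 1.0) ** 2
    l_well = float(np.sum((theta - S) ** 2))
    if l_line < l_well:
        return l_line, np.array([0.0, 2.0 * (theta[1] - 1.0)])
    return l_well, 2.0 * (theta - S)

def descend(grad_fns, theta, lr=0.1, steps=1000):
    # Plain gradient descent on the average of the given per-example losses.
    for _ in range(steps):
        g = np.mean([fn(theta)[1] for fn in grad_fns], axis=0)
        theta = theta - lr * g
    return theta

theta0 = np.array([3.0, 2.5])
print(descend([loss_and_grad_A], theta0))                   # -> (1.0, 2.5): some point on A's line
print(descend([loss_and_grad_B], theta0))                   # -> (3.0, 1.0): some point on B's line
print(descend([loss_and_grad_A, loss_and_grad_B], theta0))  # -> (1.0, 1.0): appears only when pooled

theta0 = np.array([-3.5, -3.0])
print(descend([loss_and_grad_A, loss_and_grad_B], theta0))  # -> (-2.0, -2.0): the shared minimum S
```

Training on either example alone lands on an arbitrary point of that example's solution line, so the probability of hitting (1, 1) is zero; the pooled objective, by contrast, makes (1, 1) an attractor even though no individual loss surface singles it out.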
Note that after averaging there is no way to tell what the two loss surfaces looked like: are we destroying information that is potentially important? In this paper, we argue that the answer is yes. In particular, we hypothesize that if the goal is to find invariant mechanisms in the data, these can be identified by finding explanations (e.g., model parameters) that are hard to vary across examples. A notion of invariance implies something that stays the same as something else changes.
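As a sketch of what a logical AND over examples could look like in practice, the following example zeroes out gradient components whose signs disagree across environments before averaging. This is one simple realization of the idea alluded to in the abstract; the threshold tau and the exact masking rule here are our assumptions, not a specification of the paper's algorithm.

```python
import numpy as np

def and_mask_combine(grads, tau=1.0):
    """Combine per-environment gradients with an AND-style sign-agreement mask.

    grads: array of shape (n_envs, n_params). A component survives only if at
    least a fraction `tau` of environments agree on its sign; all remaining
    components are zeroed, so no single environment can push them alone.
    (Assumed masking rule for illustration, not the paper's exact method.)
    """
    grads = np.asarray(grads)
    signs = np.sign(grads)
    # Fraction of environments agreeing with the majority sign, per component.
    agreement = np.abs(signs.sum(axis=0)) / grads.shape[0]
    mask = agreement >= tau
    return mask * grads.mean(axis=0)

# Two environments disagree on the first component, agree on the second.
gA = np.array([+0.9, -0.5])
gB = np.array([-0.8, -0.4])
print(and_mask_combine([gA, gB]))  # -> [ 0.   -0.45]: AND keeps only consistent directions
print(np.mean([gA, gB], axis=0))   # -> [ 0.05 -0.45]: OR-like averaging keeps the inconsistent one
```

Under plain averaging, the inconsistent first component still receives a small update (the stronger environment wins); the AND-style mask instead updates only directions on which the environments agree.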




Figure 1: Loss landscapes of a two-parameter model. Averaging gradients forgoes information that can identify patterns shared across different environments.

We assume that the data come from different environments: an invariant mechanism is shared across all of them and generalizes out of distribution (o.o.d.), but might be hard to model; each environment also contains spurious explanations that are easy to spot ('shortcuts') but do not generalize o.o.d. From the point of view of causal modeling, such invariant mechanisms can be interpreted as conditional distributions of the targets given causal features of the inputs; invariance of these conditionals is expected if they represent causal mechanisms, that is, stable properties of the physical world (see e.g. Hoover (1990)). Generalizing o.o.d. therefore means that the predictor should perform equally well on data coming from different settings, as long as they share the causal mechanisms.
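To illustrate this setup, one can generate environments in which an invariant feature causes the label everywhere, while a shortcut feature is predictive only with an environment-specific reliability. The generator below is a hypothetical sketch written for this explanation, not the paper's synthetic dataset; the feature names, correlation levels, and noise scales are all assumptions.

```python
import numpy as np

def make_environment(n, spurious_corr, rng):
    """Hypothetical environment generator (not the paper's synthetic dataset).

    The invariant mechanism (x_inv -> y) is identical in every environment;
    the spurious feature x_sp matches the label only with an environment-
    specific probability, so it looks predictive within a single environment.
    """
    y = rng.integers(0, 2, size=n)
    x_inv = y + 0.5 * rng.normal(size=n)                         # stable, noisy cause
    flip = rng.random(n) > spurious_corr                         # break the shortcut sometimes
    x_sp = np.where(flip, 1 - y, y) + 0.1 * rng.normal(size=n)   # easy-to-spot shortcut
    return np.stack([x_inv, x_sp], axis=1), y

rng = np.random.default_rng(0)
train_envs = [make_environment(1000, c, rng) for c in (0.95, 0.85)]
test_env = make_environment(1000, 0.1, rng)  # shortcut reversed at test time
```

A classifier that latches onto x_sp scores well in both training environments, where the shortcut matches the label 95% and 85% of the time, but fails on the test environment where the correlation is reversed; one that relies on x_inv performs equally well in all three, which is exactly the o.o.d. behavior expected of the invariant mechanism.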

