NEAREST NEIGHBOR MACHINE TRANSLATION

Abstract

We introduce k-nearest-neighbor machine translation (kNN-MT), which predicts tokens with a nearest neighbor classifier over a large datastore of cached examples, using representations from a neural translation model for similarity search. This approach requires no additional training and scales to give the decoder direct access to billions of examples at test time, resulting in a highly expressive model that consistently improves performance across many settings. Simply adding nearest neighbor search improves a state-of-the-art German-English translation model by 1.5 BLEU. kNN-MT allows a single model to be adapted to diverse domains by using a domain-specific datastore, improving results by an average of 9.2 BLEU over zero-shot transfer and achieving new state-of-the-art results, without training on these domains. A massively multilingual model can also be specialized for particular language pairs, with improvements of 3 BLEU for translating from English into German and Chinese. Qualitatively, kNN-MT is easily interpretable; it combines source and target context to retrieve highly relevant examples.

1. INTRODUCTION

Non-parametric methods have recently been successfully applied to tasks such as language modeling (Khandelwal et al., 2020) and question answering (Guu et al., 2020; Lewis et al., 2020). They allow models that are (1) expressive, because they can use an arbitrary amount of data at test time; (2) adaptable, because predictions can be controlled by changing the datastore; and (3) interpretable, because the data used to make the prediction can be directly inspected.

We introduce kNN-MT, a simple non-parametric method for machine translation (MT) using nearest neighbor retrieval. kNN-MT can be added to any pre-trained neural translation model without further training, and significantly improves performance for in-domain, out-of-domain, and multilingual evaluations. More specifically, kNN-MT interpolates the target-token softmax distribution from a neural MT model with a multinomial generated using nearest neighbor search over examples cached in a datastore. The cache is over translation contexts (i.e. the complete source and a prefix of the target), and is indexed by hidden states computed from the base MT model. We hypothesize that contexts which are close in representation space are more likely to be followed by the same target word. We show this is true not only for the original training data, thereby improving base model performance, but also across a range of different bi-text corpora, allowing for simple and effective model adaptation.

Our work builds upon recent results showing the effectiveness of nearest neighbor methods in unconditional language models (Khandelwal et al., 2020). We generalize to conditional language models by using both source and target context, and show that nearest neighbor models can be effective for generation in addition to density estimation. Compared to prior work on non-parametric methods for MT, our approach is arguably simpler (in that it requires no training, as compared to Gu et al.
(2018)) and more expressive (in that it provides access to billions of key-value pairs during inference, as compared to Zhang et al. (2018); Gu et al. (2018)). Extensive experiments show that kNN-MT scales to datastores containing billions of tokens, improving results across a range of settings. For example, it improves a state-of-the-art German-English translation model by 1.5 BLEU. kNN-MT can also be used to adapt a single model to diverse domains by simply adding a domain-specific datastore, improving results by an average of 9.2 BLEU over the base model out-of-domain, and even outperforming existing models that train on these domains. Finally, language-pair-specific datastores are used to adapt a multilingual model to particular language pairs, with improvements of 3 BLEU for translating English into German and Chinese. We find that retrievals from kNN-MT are typically highly contextually relevant.

Figure 1: An illustration of how the kNN distribution is computed. The datastore, which is constructed offline, consists of representations of training set translation contexts and corresponding target tokens for every example in the parallel data: keys k_j = f(s^(n), t^(n)_{1:i-1}) and values v_j = t^(n)_i. During generation, the query representation q = f(x, ŷ_{1:i-1}), conditioned on the test input as well as previously generated tokens, is used to retrieve the k nearest neighbors from the datastore, along with the corresponding target tokens. The distance from the query, d_j = d(k_j, q), is used to compute a distribution over the retrieved targets after applying a softmax temperature: p(k_j) ∝ exp(-d_j/T), aggregated as p_kNN(y_i) = Σ_j 1[y_i = v_j] p(k_j). This distribution is the final kNN distribution.
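As a concrete illustration, the conversion from retrieved distances to the kNN distribution described in Figure 1 can be sketched in a few lines of NumPy. The helper below is a hypothetical sketch, assuming the k squared-L2 distances, the retrieved target-token ids, and the temperature are already in hand:

```python
import numpy as np

def knn_distribution(distances, values, vocab_size, temperature=10.0):
    """Turn k retrieved neighbors into a distribution over the vocabulary.

    distances:   squared-L2 distances d_j from the query to each neighbor key.
    values:      target-token ids v_j stored alongside each key.
    temperature: softmax temperature T applied to the negative distances.
    """
    scores = -np.asarray(distances, dtype=float) / temperature  # -d_j / T
    probs = np.exp(scores - scores.max())                       # stable softmax
    probs /= probs.sum()
    p_knn = np.zeros(vocab_size)
    np.add.at(p_knn, np.asarray(values), probs)  # sum repeated target tokens
    return p_knn

# Two of three neighbors share target token 7; their probabilities aggregate.
p = knn_distribution(distances=[1.0, 1.5, 4.0], values=[7, 7, 3], vocab_size=10)
```

The `np.add.at` call performs the aggregation over multiple occurrences of the same vocabulary item, so closer and more frequent neighbors both increase a token's probability.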

2. NEAREST NEIGHBOR MACHINE TRANSLATION

kNN-MT involves augmenting the decoder of a pre-trained machine translation model with a nearest neighbor retrieval mechanism, allowing the model direct access to a datastore of cached examples. The translation is generated word-by-word: at each time step, we find the most similar contexts in the datastore and compute a distribution over the corresponding target tokens, as shown in Figure 1. This distribution is then interpolated with the output distribution from the pre-trained MT model.

More specifically, given an input sequence of tokens s = (s_1, ..., s_{M1}) in a source language, a neural MT model outputs a sequence of tokens t = (t_1, ..., t_{M2}) in the target language. When using autoregressive decoders, the output distribution for each token t_i in the target sequence is conditioned on the entire source sequence as well as the previous target tokens, p(t_i | s, t_{1:i-1}). Let (s, t_{1:i-1}) be the translation context and t_i be the target token.

Datastore creation. Our datastore is constructed offline and consists of a set of key-value pairs. The key is a high-dimensional representation of the entire translation context computed by the MT decoder, f(s, t_{1:i-1}), where f represents a mapping from the input to an intermediate representation of the decoder. The value is the corresponding ground-truth target token t_i. For a parallel text collection (S, T), the representations are generated by a single forward pass over each example, and the complete datastore is defined as follows:

    (K, V) = {(f(s, t_{1:i-1}), t_i), ∀ t_i ∈ t | (s, t) ∈ (S, T)}    (1)

Tokens from the source language are not stored directly as values in the datastore; conditioning on the source is implicit via the keys, and the values are only target-language tokens.

Generation. At test time, given a source x, the model outputs a distribution over the vocabulary p_MT(y_i | x, ŷ_{1:i-1}) for the target y_i at every step of generation, where ŷ represents the generated tokens.
The model also outputs the representation f(x, ŷ_{1:i-1}), which is used to query the datastore for the k nearest neighbors N according to squared-L2 distance d. In practice, the search over billions of key-value pairs is carried out using FAISS (Johnson et al., 2017), a library for fast nearest neighbor search in high-dimensional spaces. The retrieved set is converted into a distribution over the vocabulary by applying a softmax with temperature T to the negative distances and aggregating over multiple occurrences of the same vocabulary item.
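Putting the pieces together, datastore creation, nearest neighbor search, and interpolation with the base model's distribution can be sketched as below. This is a toy illustration under stated assumptions: random vectors stand in for the decoder representations f(s, t_{1:i-1}), a brute-force NumPy scan stands in for FAISS, and the interpolation weight `lam` and all dimensions are hypothetical placeholders rather than values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, DIM = 50, 16

# --- Datastore creation (offline): stand-ins for f(s, t_{1:i-1}) and t_i ---
keys = rng.normal(size=(1000, DIM))          # context representations (keys)
values = rng.integers(0, VOCAB, size=1000)   # ground-truth next tokens (values)

def search(query, k=8):
    """Brute-force squared-L2 search; FAISS plays this role at billion scale."""
    d = ((keys - query) ** 2).sum(axis=1)
    nearest = np.argsort(d)[:k]
    return d[nearest], values[nearest]

# --- One generation step: query with f(x, ŷ_{1:i-1}), then interpolate ---
def knn_mt_step(query, p_mt, lam=0.5, temperature=10.0, k=8):
    dists, vals = search(query, k)
    scores = -dists / temperature
    probs = np.exp(scores - scores.max())    # softmax over negative distances
    probs /= probs.sum()
    p_knn = np.zeros(VOCAB)
    np.add.at(p_knn, vals, probs)            # aggregate repeated target tokens
    return lam * p_knn + (1.0 - lam) * p_mt  # interpolated output distribution

p_final = knn_mt_step(keys[0], p_mt=np.full(VOCAB, 1.0 / VOCAB))
```

Querying with a key that is itself in the datastore (as above) makes its stored target token the strongest retrieval, which mirrors why kNN-MT helps even on the original training data.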

