EM-NETWORK: LEARNING BETTER LATENT VARIABLE FOR SEQUENCE-TO-SEQUENCE MODELS

Abstract

In a sequence-to-sequence (seq2seq) framework, the use of an unobserved latent variable, such as a latent alignment or representation, is important for addressing the mismatch between the source input and target output sequences. Existing seq2seq literature typically learns the latent space by consuming only the source input, which might produce a sub-optimal latent variable for predicting the target. In this paper, we introduce EM-Network, which can yield a promising latent variable by leveraging the target sequence as the model's additional training input. The target input serves as guidance that provides target-side context and reduces the candidates of the latent variable. The proposed framework is trained in a new self-distillation setup, allowing the original sequence model to benefit from the latent variable of the EM-Network. Specifically, the EM-Network's prediction serves as a soft label for training the inner sequence model, which takes only the source as input. We conduct comprehensive experiments on two types of seq2seq models: connectionist temporal classification (CTC) for speech recognition and attention-based encoder-decoder (AED) for machine translation. Experimental results demonstrate that the EM-Network significantly advances the current state-of-the-art approaches. It improves over the best prior work on speech recognition and establishes state-of-the-art performance on the WMT'14 and IWSLT'14 datasets.

1. INTRODUCTION

Throughout the literature on deep learning, sequence-to-sequence (seq2seq) learning has achieved great success in a wide range of applications, especially in speech and natural language processing. Given a source-target pair (x, y), the task of seq2seq learning is to learn a function mapping a source sequence x to a target sequence y, which generally suffers from source-target mismatch problems, e.g., unequal length, different domain, and modality mismatch. To deal with this issue, learning the latent variable z and improving its quality are critically important in sequence modeling. For example, in automatic speech recognition (ASR), the connectionist temporal classification (CTC) (Graves et al., 2006) model defines a latent alignment z to learn the mapping between the speech feature x and the word sequence y, as shown in Figure 1. For the counterpart task of speech synthesis, some studies (Kim et al., 2020; 2021) have proposed learning an internal alignment z between the text x and the corresponding speech y. In natural language processing (NLP), BERT-family models (Devlin et al., 2019; Liu et al., 2019; Lan et al., 2020) employ masked language modeling (MLM) to learn contextualized representations z, where randomly masked tokens are predicted given the context of the other tokens. From the perspective of z, self-supervised learning (SSL)-based pre-training enables the model to obtain a robust latent representation z from the input x, which offers the desired performance for predicting the target y on the downstream task. However, it is difficult to learn the optimal latent variable for the learning task. For example, in the case of ASR, CTC models often converge to sub-optimal alignment distributions and produce over-confident predictions (Liu et al., 2018; Yu et al., 2021; Miao et al., 2015).
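To make the alignment problem concrete, the CTC objective marginalizes over every monotonic alignment between x and y using the standard forward (alpha) recursion, rather than enumerating the exponentially many paths. The following is a minimal, self-contained sketch of that recursion (not the paper's implementation); the blank index, vocabulary, and toy inputs are illustrative assumptions.

```python
import math

BLANK = 0  # assumed index of the CTC blank token


def ctc_log_prob(log_probs, target):
    """Compute log P(y|x) by summing over all CTC alignments
    with the forward (alpha) recursion.

    log_probs: T x V nested list of per-frame log-probabilities.
    target: label sequence without blanks, e.g. [1, 2].
    """
    # Extended target with blanks interleaved: ^ a ^ b ^
    ext = [BLANK]
    for label in target:
        ext += [label, BLANK]
    S, T = len(ext), len(log_probs)

    NEG_INF = float("-inf")

    def logsumexp(*xs):
        m = max(xs)
        if m == NEG_INF:
            return NEG_INF
        return m + math.log(sum(math.exp(x - m) for x in xs))

    # Initialize: the first frame may emit the leading blank or the first label.
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            paths = [alpha[s]]          # stay on the same symbol
            if s > 0:
                paths.append(alpha[s - 1])  # advance by one symbol
            # Skip over a blank when consecutive labels differ.
            if s > 1 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                paths.append(alpha[s - 2])
            new[s] = logsumexp(*paths) + log_probs[t][ext[s]]
        alpha = new

    # Valid alignments end on the last label or the trailing blank.
    return logsumexp(alpha[S - 1], alpha[S - 2])
```

With two frames, a vocabulary of {blank, a}, and uniform per-frame probabilities of 0.5, the alignments collapsing to "a" are (a, a), (blank, a), and (a, blank), so P(y|x) = 0.75; the recursion recovers this sum without enumerating paths.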
Since there is an exponential number of possible alignment paths, and the alignment information between source and target sequences is rarely available during training, settling on the optimal alignment is quite challenging. From the feature perspective, a powerful representation is important to achieve the desired performance. In recent NLP studies, including machine translation (MT), a large pre-trained language model is often required to obtain contextualized representations (Zhu et al., 2020; Wu et al., 2021; Xu et al., 2021). In this paper, we propose a novel framework termed EM-Network that can effectively improve the quality of the latent variable z and thus the overall quality of the seq2seq model. In particular, EM-Network encapsulates two key components. First, the proposed framework leverages the target sequence as the model's additional training input, where the target input y is used as guidance to capture the target-side context and reduce the candidates of z. Second, based on the usage of the target input, we present a new self-distillation strategy as collaborative training, where the original sequence model can benefit from the EM-Network's knowledge in a one-stage manner. The prediction of the EM-Network serves as a soft label for training the inner sequence model, which consumes only the source input. The proposed self-distillation acts as a sort of regularization for the seq2seq model, and its performance gain comes from a deep mutual learning scheme (Zhang et al., 2018), where the students learn collaboratively and teach each other. However, the main difference from previous mutual learning approaches (Zhang et al., 2018; Liang et al., 2021) is that the proposed method utilizes the target input when training the EM-Network, instead of merely considering the ground truth as the sole target in training.
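The soft-label mechanism described above can be sketched as a KL-divergence term between the teacher-mode distribution (which also sees the target) and the student-mode distribution (source only). This is a minimal illustration under our own assumptions about logits and temperature, not the paper's exact loss:

```python
import math


def soft_label_loss(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) over softmax outputs.

    The teacher distribution (EM-Network in teacher mode) acts as a
    fixed soft label here, mirroring a stop-gradient: only the student
    would be updated through this term in training.
    """
    def softmax(logits, tau):
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp((l - m) / tau) for l in logits]
        z = sum(exps)
        return [e / z for e in exps]

    p = softmax(teacher_logits, temperature)  # soft labels
    q = softmax(student_logits, temperature)  # student prediction
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
```

The loss is zero when the student matches the teacher exactly and positive otherwise, so minimizing it pulls the source-only model toward the target-informed distribution.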
Since the target input is used as guidance to provide the target-side information, the prediction of the EM-Network (teacher mode) is more accurate than that of the inner sequence model (student mode). Therefore, the sequence model can effectively benefit from the soft labels of the EM-Network, which is discussed further in Section 6. In addition, we apply the proposed self-distillation to the CTC framework, an unexplored area in mutual learning research. How the EM-Network (teacher mode) models the conditional probability depends on whether the latent variable is explicitly defined, as shown in Figure 1. The CTC computation adopts the alignment z, and it is difficult to settle on the optimal alignment with the conventional framework. The proposed EM-Network computes the posterior P(z|x, y) for the loss, which aims at predicting a better CTC alignment z by leveraging both the source and target inputs. Therefore, the CTC model distilled from the EM-Network does not have to consider the exponential number of possible CTC alignments. We theoretically show that the proposed objective function can serve as a Q-function and is justified from the perspective of an EM-like algorithm. For the attention-based encoder-decoder (AED), where there is no explicit latent alignment, it is challenging to directly apply the same training scheme as the EM-Network for CTC. Simply taking the target y as an additional input admits an obvious but trivial solution, where the model converges to the conditional probability P(y|x, y) = δ(y). Inspired by MLM, we present an alternative that employs a masked version of the target ỹ as the additional input instead of the whole target y. The EM-Network for AED computes the posterior P(ỹ|x, y) for the loss and provides more robust contextualized representations that can benefit the learning task. We conduct comprehensive experiments on multiple benchmarks, including ASR and MT tasks.
The CTC and AED models are considered for the ASR and MT tasks, respectively. Experimental results demonstrate that the EM-Network improves over the best prior work on ASR and establishes state-of-the-art performance on the WMT'14 and IWSLT'14 datasets.



Figure 1: (Left) A conventional sequence model converts the source x into the target y through the latent variable z. (Middle) EM-Network can estimate a promising latent variable z by using the additional training input y. For the AED model, a masked version of the target ỹ is used instead of the whole target y. (Right) EM-Network is trained in a self-distillation setup, where the EM-Network's prediction serves as soft labels for training the original sequence model.

