ENERGY-BASED TEST SAMPLE ADAPTATION FOR DOMAIN GENERALIZATION

Abstract

In this paper, we propose energy-based sample adaptation at test time for domain generalization. Where previous works adapt their models to target domains, we adapt the unseen target samples to source-trained models. To this end, we design a discriminative energy-based model, which is trained on source domains to jointly model the conditional distribution for classification and data distribution for sample adaptation. The model is optimized to simultaneously learn a classifier and an energy function. To adapt target samples to source distributions, we iteratively update the samples by energy minimization with stochastic gradient Langevin dynamics. Moreover, to preserve the categorical information in the sample during adaptation, we introduce a categorical latent variable into the energy-based model. The latent variable is learned from the original sample before adaptation by variational inference and fixed as a condition to guide the sample update. Experiments on six benchmarks for classification of images and microblog threads demonstrate the effectiveness of our proposal.

1. INTRODUCTION

Deep neural networks are vulnerable to domain shifts and generalize poorly to test samples that do not resemble those from the training distribution (Recht et al., 2019; Zhou et al., 2021; Krueger et al., 2021; Shen et al., 2022). To deal with such shifts, domain generalization has been proposed (Muandet et al., 2013; Gulrajani & Lopez-Paz, 2020; Cha et al., 2021). Domain generalization strives to learn a model exclusively on source domains that generalizes well to unseen target domains. The major challenge stems from large domain shifts and the unavailability of any target-domain data during training. To address the problem, domain-invariant learning has been widely studied, e.g., (Motiian et al., 2017; Zhao et al., 2020; Nguyen et al., 2021), based on the assumption that invariant representations obtained on source domains remain valid for unseen target domains. However, since the target data is inaccessible during training, an "adaptivity gap" (Dubey et al., 2021) likely exists between representations from the source and target domains. Therefore, recent works adapt the classification model with target samples at test time, either by further fine-tuning model parameters (Sun et al., 2020; Wang et al., 2021) or by introducing an extra network module for adaptation (Dubey et al., 2021). Rather than adapting the model to target domains, Xiao et al. (2022) adapt the classifier for each sample at test time. Nevertheless, a single sample cannot adjust the whole model, given the large number of model parameters and the limited information contained in the sample, which makes it challenging for their method to handle large domain gaps. Instead, we propose to adapt each target sample to the source distributions, which does not require any fine-tuning or parameter updates of the source model. In this paper, we propose energy-based test sample adaptation for domain generalization.
The method is motivated by the fact that energy-based models (Hinton, 2002; LeCun et al., 2006) flexibly model complex data distributions and allow efficient sampling from the modeled distribution by Langevin dynamics (Du & Mordatch, 2019; Welling & Teh, 2011). Specifically, we define a new discriminative energy-based model as the composition of a classifier and a neural-network-based energy function in the data space, which are trained simultaneously on the source domains. The trained model iteratively updates the representation of each target sample by gradient-based energy minimization through Langevin dynamics, which eventually adapts the sample to the source data distribution. The adapted target samples are then predicted by the classifier trained jointly within the discriminative energy-based model. For both efficient energy minimization and classification, we deploy the energy function on the input feature space rather than on raw images. Since Langevin dynamics tends to draw samples randomly from the distribution modeled by the energy function, it cannot guarantee category equivalence. To maintain the category information of the target samples during adaptation and promote better classification performance, we further introduce a categorical latent variable into our energy-based model. The model learns the latent variable to explicitly carry categorical information by variational inference in the classification model. We utilize the latent variable as a conditional categorical attribute, as in compositional generation (Du et al., 2020a; Nie et al., 2021), to guide the sample adaptation and preserve the categorical information of the original sample. At inference time, we simply ensemble the predictions obtained by adapting the unseen target sample to each source domain as the final domain generalization result.
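The core adaptation step described above can be sketched in a few lines. The following is a minimal numpy illustration, not the paper's implementation: the source feature distribution is stood in for by a toy quadratic energy over a 2-D feature vector, and `sgld_adapt` performs the Langevin update z ← z − (η/2)∇E(z) + noise. All function names, the energy choice, and the hyperparameters are illustrative assumptions.

```python
import numpy as np

def sgld_adapt(z, energy_grad, n_steps=100, step_size=0.2, noise_scale=0.01, rng=None):
    """Adapt a target feature z toward low-energy (source-like) regions
    by stochastic gradient Langevin dynamics:
        z <- z - (step_size / 2) * dE/dz + sqrt(step_size) * noise."""
    rng = rng or np.random.default_rng(0)
    z = z.copy()
    for _ in range(n_steps):
        z = z - 0.5 * step_size * energy_grad(z) \
              + np.sqrt(step_size) * noise_scale * rng.standard_normal(z.shape)
    return z

# Toy stand-in for a source domain: features near mu have low energy.
mu = np.array([1.0, -2.0])
energy = lambda z: 0.5 * np.sum((z - mu) ** 2)   # E(z) = ||z - mu||^2 / 2
energy_grad = lambda z: z - mu                   # analytic gradient dE/dz

z_target = np.array([5.0, 4.0])                  # out-of-distribution target feature
z_adapted = sgld_adapt(z_target, energy_grad)

assert energy(z_adapted) < energy(z_target)      # adaptation lowered the energy
```

In the actual method the gradient would come from backpropagating through a learned neural energy function on the feature space, and the adapted feature would then be fed to the jointly trained classifier.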
We conduct experiments on six benchmarks for classification of images and microblog threads to demonstrate the promise and effectiveness of our method for domain generalization.

2. METHODOLOGY

In domain generalization, we are provided source and target domains as non-overlapping distributions on the joint space $\mathcal{X} \times \mathcal{Y}$, where $\mathcal{X}$ and $\mathcal{Y}$ denote the input and label space, respectively. Given a dataset with $S$ source domains $\mathcal{D}_s = \{\mathcal{D}_s^i\}_{i=1}^{S}$ and $T$ target domains $\mathcal{D}_t = \{\mathcal{D}_t^i\}_{i=1}^{T}$, a model is trained only on $\mathcal{D}_s$ and required to generalize well on $\mathcal{D}_t$. Following the multi-source domain generalization setting (Li et al., 2017; Zhou et al., 2021), we assume there are multiple source domains with the same label space, which makes it possible to mimic domain shifts during training. In this work, we propose energy-based test sample adaptation, which adapts target samples to source distributions to tackle the domain gap between target and source data. The rationale behind our model is that adapting the target samples to the source data distributions reduces the domain shifts and thereby improves the prediction of the target data by source models, as shown in Figure 1 (left). Since the target data is never seen during training, we mimic domain shifts during the training stage to learn the sample adaptation procedure. By doing so, the model acquires the ability to adapt each target sample to the source distribution at inference time. In this section, we first provide a preliminary on energy-based models and then present our energy-based test sample adaptation.
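The "mimic domain shifts during training" idea is commonly realized by episodically holding out some source domains each iteration. The helper below is a hypothetical sketch of such a split, not code from the paper's released repository; the domain names and the function are purely illustrative.

```python
import random

def episodic_split(source_domains, n_meta_target=1, seed=0):
    """Per training iteration, hold out some source domains as a
    'meta-target' to mimic the train/test domain shift; the remaining
    'meta-source' domains play the role of the seen source distributions
    that samples are adapted toward. (Hypothetical helper.)"""
    rng = random.Random(seed)
    domains = list(source_domains)
    rng.shuffle(domains)
    meta_target = domains[:n_meta_target]
    meta_source = domains[n_meta_target:]
    return meta_source, meta_target

# Example with three illustrative source domains:
meta_src, meta_tgt = episodic_split(["art", "cartoon", "photo"], n_meta_target=1)
assert len(meta_tgt) == 1 and len(meta_src) == 2
```

At each iteration, samples from the meta-target domains would be adapted toward the meta-source distributions, so that the adaptation procedure is learned under a simulated domain gap.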

2.1. ENERGY-BASED MODEL PRELIMINARY

Energy-based models (LeCun et al., 2006) represent any probability distribution $p(x)$ for $x \in \mathbb{R}^D$ as $p_\theta(x) = \frac{\exp(-E_\theta(x))}{Z_\theta}$, where $E_\theta(x): \mathbb{R}^D \rightarrow \mathbb{R}$ is known as the energy function that maps each input sample to a scalar and $Z_\theta = \int \exp(-E_\theta(x)) \, \mathrm{d}x$ denotes the partition function. However, $Z_\theta$ is usually intractable since it requires an integral over the entire input space of $x$. Thus, we cannot train the parameters $\theta$ of the energy-based model by directly maximizing the log-likelihood $\log p_\theta(x) = -E_\theta(x) - \log Z_\theta$. Nevertheless, the log-likelihood has the derivative (Du & Mordatch, 2019; Song & Kingma, 2021):

$$\frac{\partial \log p_\theta(x)}{\partial \theta} = \mathbb{E}_{p_d(x)}\Big[-\frac{\partial E_\theta(x)}{\partial \theta}\Big] + \mathbb{E}_{p_\theta(x)}\Big[\frac{\partial E_\theta(x)}{\partial \theta}\Big], \tag{1}$$

where the first expectation is taken over the data distribution $p_d(x)$ and the second over the model distribution $p_\theta(x)$. The objective function in eq. (1) encourages the model to assign low energy to samples from the real data distribution while assigning high energy to those from the model distribution. To do so, we need to draw samples from $p_\theta(x)$, which is challenging and usually approximated by MCMC methods (Hinton, 2002). An effective MCMC method used in recent works (Du & Mordatch, 2019; Nijkamp et al., 2019; Xiao et al., 2021b; Grathwohl et al., 2020) is Stochastic Gradient Langevin
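The two-term gradient of eq. (1) can be turned into a training loop directly: lower the energy of data samples, raise the energy of samples drawn from the model via Langevin dynamics. The following is a toy scalar illustration under the assumed energy $E_\theta(x) = \frac{1}{2}(x - \theta)^2$ (so $p_\theta$ is a unit-variance Gaussian centred at $\theta$ and all gradients are analytic); it is a didactic sketch, not the paper's training procedure.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy energy E_theta(x) = 0.5 * (x - theta)^2 with analytic gradients.
def dE_dtheta(x, theta):
    return theta - x

def dE_dx(x, theta):
    return x - theta

def sgld_sample(theta, n, steps=50, step=0.1):
    """Approximate samples from p_theta via stochastic gradient
    Langevin dynamics, starting from random noise."""
    x = rng.standard_normal(n)
    for _ in range(steps):
        x = x - 0.5 * step * dE_dx(x, theta) \
              + np.sqrt(step) * rng.standard_normal(n)
    return x

data = rng.normal(3.0, 1.0, size=256)    # "real" data centred at 3
theta, lr = 0.0, 0.5
for _ in range(100):
    x_model = sgld_sample(theta, 256)
    # Eq. (1): E_data[-dE/dtheta] + E_model[dE/dtheta]
    grad = -dE_dtheta(data, theta).mean() + dE_dtheta(x_model, theta).mean()
    theta += lr * grad                   # gradient ascent on the log-likelihood

assert abs(theta - 3.0) < 0.5            # theta has moved near the data mean
```

For this quadratic energy the gradient reduces to (data mean − model-sample mean), which makes the "pull down energy on data, push up energy on model samples" behaviour of eq. (1) easy to see.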



Code available: https://github.com/zzzx1224/EBTSA-ICLR2023.

