META-LEARNING ADAPTIVE DEEP KERNEL GAUSSIAN PROCESSES FOR MOLECULAR PROPERTY PREDICTION

Abstract

We propose Adaptive Deep Kernel Fitting with Implicit Function Theorem (ADKF-IFT), a novel framework for learning deep kernel Gaussian processes (GPs) by interpolating between meta-learning and conventional deep kernel learning. Our approach employs a bilevel optimization objective where we meta-learn generally useful feature representations across tasks, in the sense that task-specific GP models estimated on top of such features achieve the lowest possible predictive loss on average. We solve the resulting nested optimization problem using the implicit function theorem (IFT). We show that our ADKF-IFT framework contains previously proposed Deep Kernel Learning (DKL) and Deep Kernel Transfer (DKT) as special cases. Although ADKF-IFT is a completely general method, we argue that it is especially well-suited for drug discovery problems and demonstrate that it significantly outperforms previous state-of-the-art methods on a variety of real-world few-shot molecular property prediction tasks and out-of-domain molecular property prediction and optimization tasks.

1. INTRODUCTION

Many real-world applications require machine learning algorithms to make robust predictions with well-calibrated uncertainty given very limited training data. One important example is drug discovery, where practitioners not only want models to accurately predict biochemical/physicochemical properties of molecules, but also want to use models to guide the search for novel molecules with desirable properties, leveraging techniques such as Bayesian optimization (BO) which heavily rely on accurate uncertainty estimates (Frazier, 2018). Despite the meteoric rise of neural networks over the past decade, their notoriously overconfident and unreliable uncertainty estimates (Szegedy et al., 2013) make them generally ineffective surrogate models for BO. Instead, most contemporary BO implementations use Gaussian processes (GPs) (Rasmussen & Williams, 2006) as surrogate models due to their analytically-tractable and generally reliable uncertainty estimates, even on small datasets. Traditionally, GPs are fit on hand-engineered features (e.g., molecular fingerprints), which can limit their predictive performance on complex, structured, high-dimensional data where designing informative features is challenging (e.g., molecules). Naturally, a number of works have proposed to improve performance by instead fitting GPs on features learned by a deep neural network: a family of models generally called deep kernel GPs. However, there is no clear consensus about how to train these models: maximizing the GP marginal likelihood (Hinton & Salakhutdinov, 2007; Wilson et al., 2016b) has been shown to overfit on small datasets (Ober et al., 2021), while meta-learning (Patacchiola et al., 2020) and fully-Bayesian approaches (Ober et al., 2021) avoid this at the cost of making strong, often unrealistic assumptions. This suggests that there is demand for new, better techniques for training deep kernel GPs.
In this work, we present a novel, general framework called Adaptive Deep Kernel Fitting with Implicit Function Theorem (ADKF-IFT) for training deep kernel GPs, which we believe is especially well-suited to small datasets. ADKF-IFT trains a subset of the model parameters with a meta-learning loss and separately adapts the remaining parameters on each task using maximum marginal likelihood. In contrast to previous methods, which use a single loss for all parameters, ADKF-IFT is able to exploit the implicit regularization of meta-learning to prevent overfitting while avoiding the strong assumptions of a pure meta-learning approach, which may lead to underfitting. The key contributions and outline of the paper are as follows:

1. As our main technical contribution, we present the general ADKF-IFT framework and its natural formulation as a bilevel optimization problem (Section 3.1), then explain how the implicit function theorem (IFT) can be used to efficiently solve it with gradient-based methods in a few-shot learning setting (Section 3.2).

2. We show how ADKF-IFT can be viewed as a generalization and unification of previous approaches based purely on single-task learning (Wilson et al., 2016b) or purely on meta-learning (Patacchiola et al., 2020) for training deep kernel GPs (Section 3.3).

3. We propose a specific practical instantiation of ADKF-IFT wherein all feature extractor parameters are meta-learned, which has a clear interpretation and obviates the need for any Hessian approximations. We argue why this particular instantiation is well-suited to retain the best properties of previously proposed methods (Section 3.4).

4. Motivated by the general demand for better GP models in chemistry, we perform an extensive empirical evaluation of ADKF-IFT on several chemical tasks, finding that it significantly improves upon previous state-of-the-art methods (Section 5).

2. BACKGROUND AND NOTATION

Gaussian Processes (GPs) are tools for specifying Bayesian priors over functions (Rasmussen & Williams, 2006). A GP(m_θ(·), c_θ(·, ·)) is fully specified by a mean function m_θ(·) and a symmetric positive-definite covariance function c_θ(·, ·). The covariance function encodes the inductive bias (e.g., smoothness) of a GP. One advantage of GPs is that it is easy to perform principled model selection for the hyperparameters θ ∈ Θ using the marginal likelihood p(y | X, θ) evaluated on the training data (X, y), and to obtain closed-form probabilistic predictions p(y* | X*, X, y, θ) for the test data (X*, y*); we refer the reader to Rasmussen & Williams (2006) for more details.

Deep Kernel Gaussian Processes are GPs whose covariance function is constructed by first using a neural network feature extractor f_φ with parameters φ ∈ Φ to create feature representations h = f_φ(x), h′ = f_φ(x′) of the input points x, x′, then feeding these feature representations into a standard base kernel c_θ(h, h′) (e.g., an RBF kernel) (Hinton & Salakhutdinov, 2007; Wilson et al., 2016b;a; Bradshaw et al., 2017; Calandra et al., 2016). The complete covariance function is therefore k_ψ(x, x′) = c_θ(f_φ(x), f_φ(x′)) with learnable parameters ψ = (θ, φ).

Few-shot Learning refers to learning on many related tasks when each task has few labelled examples (Miller et al., 2000; Lake et al., 2011). In the standard problem setup, one is given a set of training tasks D = {T_t}_{t=1}^T (a meta-dataset) and some unseen test tasks D* = {T*}. Each task T = {(x_i, y_i)}_{i=1}^{N_T} is a set of points in the domain X (e.g., the space of molecules) with corresponding labels (continuous, categorical, etc.), and is partitioned into a support set S_T ⊆ T for training and a query set Q_T = T \ S_T for testing. Typically, the total number of training tasks T = |D| is large, while the size of each support set |S_T| is small.
Models for few-shot learning are typically trained to accurately predict Q_T given S_T for T ∈ D during a meta-training phase, then evaluated by their prediction error on Q_T* given S_T* for unseen test tasks T* ∈ D* during a meta-testing phase.
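To make the deep kernel construction concrete, the following sketch builds k_ψ(x, x′) = c_θ(f_φ(x), f_φ(x′)) from a small (untrained, random-weight) MLP feature extractor and an RBF base kernel, then computes the closed-form GP regression posterior on a toy support set. This is an illustrative NumPy sketch, not the paper's implementation; the network sizes, lengthscale, and noise level are all arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# f_phi: a fixed (untrained) one-hidden-layer MLP feature extractor.
# In a real deep kernel GP these weights phi would be learned.
W1 = rng.normal(size=(5, 16)); b1 = rng.normal(size=16)
W2 = rng.normal(size=(16, 4)); b2 = rng.normal(size=4)

def f_phi(X):
    """Map raw inputs of shape (n, 5) to feature representations h = f_phi(x)."""
    return np.tanh(X @ W1 + b1) @ W2 + b2

def c_theta(H1, H2, lengthscale=1.0, outputscale=1.0):
    """RBF base kernel evaluated on extracted features (base kernel params theta)."""
    sq = ((H1[:, None, :] - H2[None, :, :]) ** 2).sum(-1)
    return outputscale * np.exp(-0.5 * sq / lengthscale**2)

def k_psi(X1, X2):
    """Complete deep kernel: base kernel composed with the feature extractor."""
    return c_theta(f_phi(X1), f_phi(X2))

# Closed-form GP regression posterior on a toy support set (X_train, y_train).
X_train = rng.normal(size=(8, 5)); y_train = rng.normal(size=8)
X_test = rng.normal(size=(3, 5))
noise = 1e-2

K = k_psi(X_train, X_train) + noise * np.eye(8)       # noisy train covariance
K_star = k_psi(X_test, X_train)                        # test-train covariance
L = np.linalg.cholesky(K)
alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
mean = K_star @ alpha                                  # posterior predictive mean
var = np.diag(k_psi(X_test, X_test) - K_star @ np.linalg.solve(K, K_star.T))
```

In the few-shot setting above, K would be built from the support set S_T and the predictive mean and variance would be evaluated on the query set Q_T.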

3.1. THE GENERAL ADKF-IFT FRAMEWORK FOR LEARNING DEEP KERNEL GPS

Let A_Θ and A_Φ respectively be the sets of base kernel and feature extractor parameters for a deep kernel GP, and denote the set of all parameters by A_Ψ = A_Θ ∪ A_Φ. The key idea of the general ADKF-IFT framework is that only a subset of the parameters A_Ψadapt ⊆ A_Ψ is adapted to each individual task by minimizing a training loss L_T, while the remaining parameters A_Ψmeta = A_Ψ \ A_Ψadapt are meta-learned during a meta-training phase to yield the best possible validation loss L_V on average over many related training tasks (after A_Ψadapt is separately adapted to each of these tasks). This can be naturally formalized as the following bilevel optimization problem:

    ψ*_meta = argmin_{ψ_meta} E_{p(T)}[L_V(ψ_meta, ψ*_adapt(ψ_meta, S_T), T)],    (1)

such that

    ψ*_adapt(ψ_meta, S_T) = argmin_{ψ_adapt} L_T(ψ_meta, ψ_adapt, S_T).    (2)
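The bilevel structure of Eqs. (1)–(2) can be illustrated with a deliberately simple scalar example. Here the inner and outer losses are assumed quadratics chosen for illustration (not the paper's GP marginal likelihood and predictive loss), so the inner problem has a closed-form solution and the IFT hypergradient d L_V / d ψ_meta = (∂L_V/∂ψ_adapt) · dψ*_adapt/dψ_meta, with dψ*_adapt/dψ_meta = −[∂²L_T/∂ψ_adapt²]⁻¹ [∂²L_T/∂ψ_meta∂ψ_adapt], can be checked against finite differences.

```python
import numpy as np

lam, target = 0.5, 2.0  # assumed constants for this toy example

def L_T(psi_meta, psi_adapt):
    """Inner (training) loss: pulls psi_adapt toward psi_meta with a ridge penalty."""
    return (psi_adapt - psi_meta) ** 2 + lam * psi_adapt**2

def inner_solve(psi_meta):
    """arg min over psi_adapt of L_T; closed form for this quadratic loss."""
    return psi_meta / (1.0 + lam)

def L_V(psi_adapt):
    """Outer (validation) loss; depends on psi_meta only through psi_adapt*."""
    return (psi_adapt - target) ** 2

def hypergradient_ift(psi_meta):
    """Gradient of the outer objective w.r.t. psi_meta via the IFT."""
    psi_adapt = inner_solve(psi_meta)
    d2_adapt_adapt = 2.0 + 2.0 * lam   # d^2 L_T / d psi_adapt^2
    d2_meta_adapt = -2.0               # d^2 L_T / d psi_meta d psi_adapt
    # IFT response: d psi_adapt* / d psi_meta = -(H_aa)^{-1} H_ma
    dadapt_dmeta = -d2_meta_adapt / d2_adapt_adapt
    dLV_dadapt = 2.0 * (psi_adapt - target)
    return dLV_dadapt * dadapt_dmeta

def objective(psi_meta):
    """Full bilevel objective: outer loss at the inner optimum."""
    return L_V(inner_solve(psi_meta))

# Check the IFT hypergradient against central finite differences.
psi, eps = 1.3, 1e-5
fd = (objective(psi + eps) - objective(psi - eps)) / (2 * eps)
ift = hypergradient_ift(psi)
```

In ADKF-IFT itself the inner optimum has no closed form, which is precisely why the IFT is needed: it gives the gradient of the outer objective without differentiating through the inner optimization trajectory.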

