NO PAIRS LEFT BEHIND: IMPROVING METRIC LEARNING WITH REGULARIZED TRIPLET OBJECTIVE

Abstract

We propose a novel formulation of the triplet objective function that improves metric learning without additional sample mining or overhead costs. Our approach explicitly regularizes the distance between the positive and negative samples in a triplet with respect to the anchor-negative distance. As an initial validation, we show that our method (called No Pairs Left Behind [NPLB]) improves upon the traditional and current state-of-the-art triplet objective formulations on standard benchmark datasets. To show the effectiveness and potential of NPLB on real-world complex data, we evaluate our approach on a large-scale healthcare dataset (UK Biobank), demonstrating that the embeddings learned by our model significantly outperform all other current representations on tested downstream tasks. Additionally, we provide a new model-agnostic single-time health risk definition that, when used in tandem with the learned representations, achieves the most accurate prediction of subjects' future health complications. Our results indicate that NPLB is a simple, yet effective framework for improving existing deep metric learning models, showcasing the potential implications of metric learning in more complex applications, especially in the biological and healthcare domains. Our code package as well as tutorial notebooks are available on our public repository: <revealed after the double blind reviews>.

1. INTRODUCTION

Metric learning is the task of encoding similarity-based embeddings in which similar samples are mapped close together in space and dissimilar ones far apart (Xing et al., 2002; Wang et al., 2019; Roth et al., 2020). Deep metric learning (DML) has shown success in many domains, including computer vision (Hermans et al., 2017; Vinyals et al., 2016; Wang et al., 2018b) and natural language processing (Reimers & Gurevych, 2019; Mueller & Thyagarajan, 2016; Benajiba et al., 2019). Many DML models utilize paired samples to learn useful embeddings based on distance comparisons. The most common architectures among these techniques are the Siamese (Bromley et al., 1993) and triplet networks (Hoffer & Ailon, 2015). The main components of these models are: (1) the strategies for constructing training tuples and (2) the objectives that the model must minimize. Though many studies have focused on improving sampling strategies (Wu et al., 2017; Ge, 2018; Shrivastava et al., 2016; Kalantidis et al., 2020; Zhu et al., 2021), modifying the objective function has attracted less attention. Given that learning representations with triplets very often yields better results than with pairs using the same network (Hoffer & Ailon, 2015; Balntas et al., 2016), our work focuses on improving triplet-based DML through a simple yet effective modification of the traditional objective. Modifying DML loss functions often requires mining additional samples, identifying new quantities (e.g., identifying class centers iteratively throughout training (He et al., 2018)), or computing quantities with costly overheads (Balntas et al., 2016), which may limit their applications. In this work, we aim to provide an easy and intuitive modification of the traditional triplet loss that is motivated by metric learning on more complex datasets and by the notion of density and uniformity of each class.
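As a reference point for the modifications discussed below, the traditional triplet objective (Hoffer & Ailon, 2015) for a single triplet can be sketched as follows. This is a minimal NumPy version; the margin value and function name are illustrative, not taken from the paper's implementation.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Traditional triplet objective: push the anchor-positive distance
    to be at least `margin` smaller than the anchor-negative distance."""
    d_ap = np.linalg.norm(anchor - positive)  # anchor-positive distance
    d_an = np.linalg.norm(anchor - negative)  # anchor-negative distance
    return max(d_ap - d_an + margin, 0.0)     # hinge: zero once the margin is met
```

Note that this objective only compares the two anchor-based distances; the positive-negative distance never enters the loss, which is the gap the approaches below address.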
Our proposed variation of the triplet loss leverages all pairwise distances between existing pairs in traditional triplets (positive, negative, and anchor) to encourage denser clusters and better separability between classes. This allows for improving already existing triplet-based DML architectures using implementations in standard deep learning (DL) libraries (e.g. TensorFlow), enabling a wider usage of the methods and improvements presented in this work. Many ML algorithms are developed for and tested on datasets such as MNIST (LeCun, 1998) or ImageNet (Deng et al., 2009), which often lack the intricacies and nuances of data in other fields, such as health-related domains (Lee & Yoon, 2017). Unfortunately, this can have direct consequences when we try to understand how ML can help improve care for patients (e.g. diagnosis or prognosis). In this work, we demonstrate that DML algorithms can be effective in learning embeddings from complex healthcare datasets. We provide a novel DML objective function and show that our model's learned embeddings improve downstream tasks, such as classifying subjects and predicting future health risk from a single time point. More specifically, we build upon the DML-learned embeddings to formulate a new mathematical definition of patient health risk from a single time point which, to the best of our knowledge, does not currently exist. To show the effectiveness of our model and health risk definition, we evaluate our methodology on a large-scale complex public dataset, the UK Biobank (UKB) (Bycroft et al., 2018), demonstrating the implications of our work for both healthcare and the ML community. In summary, our most important contributions can be described as follows. 1) We present a novel triplet objective function that improves model learning without any additional sample mining or overhead computational costs.
2) We demonstrate the effectiveness of our approach on a large-scale complex public dataset (UK Biobank) and on conventional benchmark datasets (MNIST, Fashion MNIST (Xiao et al., 2017), and CIFAR10 (Krizhevsky, 2010)), demonstrating the potential of DML in domains that have traditionally received less attention. 3) We provide a novel definition of patient health risk from a single time point, demonstrating the real-world impact of our approach by predicting currently healthy subjects' future risks using only a single lab visit, a challenging but crucial task in healthcare.
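Since the exact NPLB objective is introduced later in the paper (§3), the following is only an illustrative sketch of the idea stated above: a traditional triplet term plus a regularizer that ties the positive-negative distance to the anchor-negative distance. The function name, the balancing weight `lam`, and the absolute-difference form of the regularizer are all assumptions for illustration, not the paper's actual formulation.

```python
import numpy as np

def nplb_loss_sketch(anchor, positive, negative, margin=1.0, lam=0.5):
    """Illustrative only; the exact NPLB objective is defined in Sec. 3.
    Assumption: a regularizer penalizes the gap between d(positive, negative)
    and d(anchor, negative), so all three pairwise distances shape the loss."""
    d_ap = np.linalg.norm(anchor - positive)
    d_an = np.linalg.norm(anchor - negative)
    d_pn = np.linalg.norm(positive - negative)
    triplet = max(d_ap - d_an + margin, 0.0)  # traditional triplet term
    reg = abs(d_pn - d_an)                    # hypothetical distance regularizer
    return triplet + lam * reg
```

With `lam=0` this reduces to the traditional triplet loss; the extra term makes the positive-negative distance an explicit part of the objective rather than a byproduct of the anchor-based comparisons.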

2. BACKGROUND AND RELATED WORK

Contrastive learning aims to minimize the distance between two samples if they belong to the same class (i.e., are similar). As a result, contrastive models require two samples to be inputted before calculating the loss and updating their parameters. This can be thought of as passing two samples through two parallel models with tied weights, hence the name Siamese or Twin networks (Bromley et al., 1993). Triplet networks (Hoffer & Ailon, 2015) build upon this idea to rank positive and negative samples relative to an anchor, thus requiring the model to produce mappings for all three before the optimization step (hence the name triplets). Modification of Triplet Loss: Due to their success and importance, triplet networks have attracted increasing attention in recent years. Though the majority of proposed improvements focus on the sampling and selection of triplets, some studies (Balntas et al., 2016; Zhao et al., 2019; Kim & Park, 2021; Nguyen et al., 2022) have proposed modifications of the traditional triplet loss of Hoffer & Ailon (2015). Similar to our work, Multi-level Distance Regularization (MDR) (Kim & Park, 2021) seeks to regularize the DML loss function. MDR regularizes the pairwise distances between embedding vectors into multiple levels based on their similarity. The goal of MDR is to disturb the optimization of the pairwise distances among examples, discouraging positive pairs from getting too close and negative pairs from becoming too distant. A drawback of regularization methods is the choice of the hyperparameter that balances the regularization term, though adaptive balancing methods can be used (Chen et al., 2018; Heydari et al., 2019). Most related to our work, Balntas et al. (2016) modified the traditional objective by explicitly accounting for the distance between the positive and negative samples (which the traditional triplet function does not consider), and applied their model to learn local feature descriptors using shallow convolutional neural networks. They introduced the idea of the "in-triplet hard negative": swapping the anchor and the positive sample whenever the positive sample is closer to the negative sample than the anchor is, thus improving on the performance of traditional triplet networks (we refer to this approach as Distance Swap). Though this method uses the positive-negative distance to choose the anchor, it does not explicitly enforce the model to regularize that distance, which was the main issue with the original formulation. Our work addresses this pitfall by using the notions of local density and uniformity (defined later in §3) to explicitly regularize the distance between the positive and negative samples using the distance between the anchor and the negative. As a result, our approach ensures better inter-class separability while encouraging denser intra-class embeddings. In addition to MDR and Distance Swap, we benchmark our approach against three related and widely used metric learning algorithms, namely LiftedStruct (Song et al., 2015), N-Pair Loss (Sohn, 2016), and InfoNCE (Oord et al., 2018a). Due to space constraints, and given the popularity of these methods, we provide an overview of these algorithms in Appendix E.
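The in-triplet hard negative (Distance Swap) of Balntas et al. (2016) can be sketched as follows. Taking the smaller of the anchor-negative and positive-negative distances is equivalent to swapping the anchor and the positive when the positive lies closer to the negative; the function name and margin value here are illustrative.

```python
import numpy as np

def triplet_loss_with_swap(anchor, positive, negative, margin=1.0):
    """Triplet loss with the in-triplet hard negative (Distance Swap):
    use whichever of d(anchor, negative) and d(positive, negative) is
    smaller, i.e., the harder negative comparison for this triplet."""
    d_ap = np.linalg.norm(anchor - positive)
    d_an = np.linalg.norm(anchor - negative)
    d_pn = np.linalg.norm(positive - negative)
    d_neg = min(d_an, d_pn)  # swap occurs implicitly when d_pn < d_an
    return max(d_ap - d_neg + margin, 0.0)
```

When the positive is closer to the negative than the anchor is, the loss is strictly larger than the plain triplet loss, which is how the swap surfaces hard cases; it still does not add an explicit penalty on the positive-negative distance itself.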

