TABULAR DEEP LEARNING WHEN d ≫ n BY USING AN AUXILIARY KNOWLEDGE GRAPH

Anonymous authors
Paper under double-blind review

Abstract

Machine learning models exhibit strong performance on datasets with abundant labeled samples. However, for tabular datasets with a very large number d of features but a limited number n of samples (i.e., d ≫ n), machine learning models struggle to achieve strong performance. Our key insight is that even in tabular datasets with limited labeled data, input features often represent real-world entities about which there is abundant prior information that can be structured as an auxiliary knowledge graph (KG). For example, in a tabular medical dataset where every input feature is the amount of a gene in a patient's tumor and the label is the patient's survival, there is an auxiliary knowledge graph connecting gene names with drug, disease, and human anatomy nodes. We therefore propose PLATO, a machine learning model for tabular data with d ≫ n and an auxiliary KG with input features as nodes. PLATO uses a multilayer perceptron (MLP) with two methodological components to predict the output labels from the tabular data and the auxiliary KG. First, PLATO predicts the parameters in the first layer of the MLP from the auxiliary KG. PLATO thereby reduces the number of trainable parameters in the MLP and integrates auxiliary information about the input features. Second, PLATO predicts different parameters in the first layer of the MLP for every input sample, thereby increasing the MLP's representational capacity by allowing it to use different prior information for every input sample. Compared against 10 state-of-the-art baselines on 6 d ≫ n datasets, PLATO exceeds or matches the prior state-of-the-art, achieving performance improvements of up to 10.19%. Overall, PLATO uses an auxiliary KG about input features to enable tabular deep learning prediction when d ≫ n.

1. INTRODUCTION

Machine learning models have reached state-of-the-art performance in domains with abundant labeled data like computer vision (Wortsman et al., 2022; Deng et al., 2009) and natural language processing (Wang et al., 2019; Devlin et al., 2019; Ramesh et al., 2022). However, for tabular datasets in which the number d of features vastly exceeds the number n of samples, machine learning models struggle to achieve strong performance (Hastie et al., 2009; Liu et al., 2017). Unfortunately, many high impact domains like chemistry (Guyon et al., 2004), biology (Iorio et al., 2016; Yang et al., 2012; Garnett et al., 2012; Gao et al., 2015), and physics (Kasieczka et al., 2021) produce datasets with high-dimensional features but limited labeled samples due to the high time and labor costs associated with experiments. In chemistry, for example, mass spectrometry datasets can have tens of thousands of features but only tens or hundreds of samples (Guyon et al., 2004). For these and other tabular datasets with d ≫ n, the performance of machine learning systems is currently limited.
To date, deep learning approaches for tabular data have focused on data regimes with far more samples than features (n ≫ d) (Grinsztajn et al., 2022; Gorishniy et al., 2021; Shwartz-Ziv & Armon, 2022). In the low-data regime with far more features than samples (d ≫ n), the dominant approaches are classical statistical methods (Hastie et al., 2009). These statistical methods reduce the dimensionality of the input space (Abdi & Williams, 2010; Liu et al., 2017; Van der Maaten & Hinton, 2008; Van Der Maaten et al., 2009), select features (Tibshirani, 1996; Climente-González et al., 2019; Freidling et al., 2021; Meier et al., 2008), impose regularization penalties on parameter magnitudes (Marquardt & Snee, 1975), or use ensembles of weak tree-based models (Friedman, 2001; Chen & Guestrin, 2016; Ke et al., 2017; Lou & Obukhov, 2017; Prokhorenkova et al., 2018).
Here, we present a novel problem setting and framework for tabular deep learning when d ≫ n (Figure 1). Our key insight is that even in tabular settings with limited labeled data, input features often represent real-world entities about which there is abundant prior information which can be structured as an auxiliary knowledge graph (KG). We propose a novel problem setting in which every input feature of a tabular dataset corresponds to a node in an auxiliary KG (Figure 1a). For example, consider a tabular medical dataset in which every row is a cancer patient, every column is a gene, every value is the amount of that gene in the patient's tumor, and the task is to predict the patient's survival. For this tabular dataset, there exists an auxiliary KG which consists of each gene's function, the relationships between genes, how a gene affects a part of human anatomy, and how human anatomy itself is structured. Note that the KG does not capture the relationships between input data samples but instead captures the relationships between input features.
Within our novel problem setting, we propose PLATO, a deep learning method for tabular data with d ≫ n and an auxiliary KG with input features as nodes (Figure 1b-e). PLATO uses a modified multilayer perceptron (MLP) to predict the output labels from the input samples and the auxiliary KG with two methodological components. First, the parameters in the first layer of the MLP are predicted from the auxiliary KG and the input sample rather than learned from just the tabular data. PLATO thereby integrates prior information about the input features from the auxiliary KG and drastically reduces the number of trainable parameters in the MLP. Second, the parameters in the first layer of the MLP are predicted differently for every sample by using the auxiliary KG and the sample values. PLATO thereby increases the representational capacity of the MLP and enables effective predictions.
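To make these two components concrete, the following is a minimal sketch, not PLATO's actual architecture, of a first layer whose weights are predicted from KG node embeddings. It assumes that every input feature j already has a fixed k-dimensional KG embedding (random placeholders here). The shared projection `proj`, the sample-dependent gate `gate`, and the specific gating form are hypothetical choices made only for illustration. The sketch shows that the number of trainable first-layer parameters depends on k and the hidden width h, but not on d.

```python
import math
import random


def matvec(M, v):
    """Multiply a matrix M (list of rows) by a vector v."""
    return [sum(m * a for m, a in zip(row, v)) for row in M]


def first_layer_from_kg(x, node_emb, proj, gate):
    """Hidden pre-activations of a first layer with KG-predicted weights.

    x        : one input sample, length d
    node_emb : d KG node embeddings of length k (assumed precomputed, not trained)
    proj     : h x k trainable matrix mapping an embedding to a weight vector
    gate     : length-k trainable vector for a sample-dependent gate (hypothetical)
    """
    h = len(proj)
    hidden = [0.0] * h
    for x_j, e_j in zip(x, node_emb):
        # Gate in (0, 1) that depends on the feature value and its embedding,
        # so the effective weights differ from sample to sample.
        c_j = 1.0 / (1.0 + math.exp(-x_j * sum(g * e for g, e in zip(gate, e_j))))
        # Weight vector for feature j, predicted from its embedding.
        w_j = matvec(proj, e_j)
        for i in range(h):
            hidden[i] += x_j * c_j * w_j[i]
    return hidden


random.seed(0)
d, k, h = 2000, 8, 4  # toy sizes: d features, k-dim embeddings, h hidden units
node_emb = [[random.gauss(0, 1) for _ in range(k)] for _ in range(d)]
proj = [[random.gauss(0, 0.1) for _ in range(k)] for _ in range(h)]
gate = [random.gauss(0, 0.1) for _ in range(k)]
x = [random.gauss(0, 1) for _ in range(d)]

hidden = first_layer_from_kg(x, node_emb, proj, gate)
dense_params = d * h          # weights a standard dense first layer would train
predicted_params = h * k + k  # trainable values in this sketch: proj and gate
print(len(hidden), dense_params, predicted_params)  # 4 8000 40
```

With these toy sizes, a dense first layer would train d·h = 8000 weights, whereas the sketch trains only h·k + k = 40 values, regardless of how large d becomes.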
We evaluate PLATO on 6 d ≫ n datasets. We choose computational biology because it is a rich domain for d ≫ n in which we can construct a single knowledge graph to serve as a unified backbone for many distinct tabular datasets with distinct input features. We compare PLATO to 10 state-of-the-art baselines spanning dimensionality reduction, feature selection, classic statistical models, deep tabular learning methods, and parameter-prediction methods. Following a rigorous evaluation protocol from the tabular deep learning literature (Grinsztajn et al., 2022; Gorishniy et al., 2021), PLATO exceeds or matches the prior state-of-the-art on all 6 datasets, achieving performance improvements of up to 10.19%. Ablation studies further demonstrate the necessity of each methodological component of PLATO. Ultimately, PLATO uses an auxiliary KG about input features to enable tabular deep learning prediction when d ≫ n.

2. RELATED WORK

Tabular deep learning methods. In contrast to PLATO's setting, tabular deep learning methods have been developed for settings with far more samples than features (i.e., n ≫ d). Recent tabular deep learning benchmarks ignore datasets with high numbers of features and low numbers of samples (Grinsztajn et al., 2022; Gorishniy et al., 2021; Shwartz-Ziv & Armon, 2022). In the n ≫ d setting, various categories of deep tabular models have been benchmarked. We select the state-of-the-art models to compare against PLATO. First, decision tree models like NODE (Popov et al., 2020) make decision trees differentiable to enable gradient-based optimization (Hazimeh et al., 2020; Kontschieder et al., 2015; Yang et al., 2018). Second, tabular transformer architectures use an attention mechanism to select and learn interactions among features. These include TabNet (Arik & Pfister, 2021), TabTransformer (Huang et al., 2020), FT-Transformer (Gorishniy et al., 2021), and others (Song et al., 2019; Somepalli et al., 2021; Kossen et al., 2021).
d ≫ n methods. For PLATO's setting in which d ≫ n, various tabular machine learning approaches exist (Hastie et al., 2009). First, dimensionality reduction techniques like PCA (Abdi & Williams, 2010) aim to reduce the dimensionality of the input data while preserving as much of the variance in the data as possible (Liu et al., 2017; Van der Maaten & Hinton, 2008; Van Der Maaten et al., 2009). Second, feature selection approaches select a parsimonious set of features, leading to a smaller feature space. Classical feature selection approaches include LASSO (Tibshirani, 1996) and its variants (Climente-González et al., 2019; Freidling et al., 2021; Meier et al., 2008). For feature selection with deep learning, Stochastic Gates (Yamada et al., 2020) are among the best performing of many variants (Balın et al., 2019; Lu et al., 2018). Finally, classical tree-based models like XGBoost learn ensembles of weak decision trees to make an overall prediction (Friedman, 2001; Chen & Guestrin, 2016; Ke et al., 2017; Prokhorenkova et al., 2018).
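As an illustration of the classical feature selection family, the following minimal sketch fits a LASSO by cyclic coordinate descent on synthetic d ≫ n data. The synthetic data, the penalty value `lam`, and all helper names are assumptions chosen for illustration. This is a generic textbook procedure and not the configuration of any baseline evaluated in this paper.

```python
import random


def soft_threshold(a, lam):
    """Shrink a toward zero by lam; values within [-lam, lam] become exactly 0."""
    if a > lam:
        return a - lam
    if a < -lam:
        return a + lam
    return 0.0


def objective(X, y, w, lam):
    """LASSO objective: (1 / 2n) * ||y - Xw||^2 + lam * ||w||_1."""
    n = len(X)
    sq = sum((y[i] - sum(xij * wj for xij, wj in zip(X[i], w))) ** 2 for i in range(n))
    return sq / (2 * n) + lam * sum(abs(wj) for wj in w)


def lasso_cd(X, y, lam, n_iters=100):
    """Minimize the LASSO objective by cyclic coordinate descent."""
    n, d = len(X), len(X[0])
    w = [0.0] * d
    r = list(y)  # residual y - Xw, which equals y while w is all zeros
    z = [sum(X[i][j] ** 2 for i in range(n)) / n for j in range(d)]
    for _ in range(n_iters):
        for j in range(d):
            if z[j] == 0.0:
                continue  # constant-zero column: leave its weight at 0
            # Correlation of feature j with the residual that excludes feature j.
            rho = sum(X[i][j] * r[i] for i in range(n)) / n + z[j] * w[j]
            new_wj = soft_threshold(rho, lam) / z[j]
            delta = new_wj - w[j]
            if delta != 0.0:
                for i in range(n):
                    r[i] -= delta * X[i][j]
                w[j] = new_wj
    return w


random.seed(0)
n, d = 20, 200  # far more features than samples
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
# Only the first two features carry signal in this synthetic example.
y = [3.0 * row[0] - 2.0 * row[1] + 0.1 * random.gauss(0, 1) for row in X]

lam = 0.5
w = lasso_cd(X, y, lam)
selected = [j for j, wj in enumerate(w) if wj != 0.0]
print(len(selected), "of", d, "features selected")
```

The L1 penalty sets many coefficients exactly to zero, which is what makes LASSO a feature selector. Unlike the KG-based approach proposed here, it uses no prior information about what the features represent.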

