TABULAR DEEP LEARNING WHEN d ≫ n BY USING AN AUXILIARY KNOWLEDGE GRAPH

Anonymous authors
Paper under double-blind review

Abstract

Machine learning models exhibit strong performance on datasets with abundant labeled samples. However, for tabular datasets with an extremely large number d of features but a limited number n of samples (i.e., d ≫ n), machine learning models struggle to achieve strong performance. Here, our key insight is that even in tabular datasets with limited labeled data, input features often represent real-world entities about which there is abundant prior information, which can be structured as an auxiliary knowledge graph (KG). For example, in a tabular medical dataset where every input feature is the amount of a gene in a patient's tumor and the label is the patient's survival, there is an auxiliary knowledge graph connecting gene names with drug, disease, and human anatomy nodes. We therefore propose PLATO, a machine learning model for tabular data with d ≫ n and an auxiliary KG with input features as nodes. PLATO predicts the output labels from the tabular data and the auxiliary KG using a multilayer perceptron (MLP) with two methodological components. First, PLATO predicts the parameters in the first layer of the MLP from the auxiliary KG, thereby reducing the number of trainable parameters in the MLP and integrating auxiliary information about the input features. Second, PLATO predicts different parameters in the first layer of the MLP for every input sample, thereby increasing the MLP's representational capacity by allowing it to use different prior information for every input sample. Across 10 state-of-the-art baselines and 6 d ≫ n datasets, PLATO exceeds or matches the prior state-of-the-art, achieving performance improvements of up to 10.19%. Overall, PLATO uses an auxiliary KG about input features to enable tabular deep learning prediction when d ≫ n.
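The two components above can be sketched as follows. This is a minimal illustrative sketch, not PLATO's actual implementation: it assumes each input feature has a KG-derived embedding, and a small hypernetwork (whose parameters `W_hyper` and the fixed projection `P` are hypothetical names introduced here) maps those embeddings, conditioned on the current sample, to that sample's first-layer weights. The key point it demonstrates is that the first layer carries only k × h trainable parameters instead of d × h, which matters when d ≫ n.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, k, h = 1000, 20, 16, 32   # features, samples, KG-embedding dim, hidden dim

X = rng.normal(size=(n, d))      # tabular data with d >> n
E = rng.normal(size=(d, k))      # KG-derived embedding for each input feature

# Hypothetical hypernetwork parameters: the only trainable first-layer
# parameters are k * h, not d * h.
W_hyper = rng.normal(size=(k, h)) * 0.1
# Fixed (untrained) random projection used to condition on the sample.
P = rng.normal(size=(d, k)) / np.sqrt(d)

def forward(x):
    """Sample-specific first layer with KG-predicted weights."""
    c = np.tanh(x @ P)          # (k,) sample-conditioning vector
    W1 = (E * c) @ W_hyper      # (d, h) first-layer weights for THIS sample
    return np.tanh(x @ W1)      # (h,) hidden representation

H = np.stack([forward(x) for x in X])
print(H.shape)  # (20, 32)
```

Because `W1` depends on both the KG embeddings `E` and the sample-conditioning vector `c`, each sample is processed with different first-layer weights, mirroring the second component described in the abstract.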

1. INTRODUCTION

Machine learning models have reached state-of-the-art performance in domains with abundant labeled data like computer vision (Wortsman et al., 2022; Deng et al., 2009) and natural language processing (Wang et al., 2019; Devlin et al., 2019; Ramesh et al., 2022). However, for tabular datasets in which the number d of features vastly exceeds the number n of samples, machine learning models struggle to achieve strong performance (Hastie et al., 2009; Liu et al., 2017). Unfortunately, many high-impact domains like chemistry (Guyon et al., 2004), biology (Iorio et al., 2016; Yang et al., 2012; Garnett et al., 2012; Gao et al., 2015), and physics (Kasieczka et al., 2021) produce datasets with high-dimensional features but limited labeled samples due to the high time and labor costs associated with experiments. In chemistry, for example, mass spectrometry datasets can have tens of thousands of features but only tens or hundreds of samples (Guyon et al., 2004). For these and other tabular datasets with d ≫ n, the performance of machine learning systems is currently limited.

To date, deep learning approaches for tabular data have focused on data regimes with far more samples than features (n ≫ d) (Grinsztajn et al., 2022; Gorishniy et al., 2021; Shwartz-Ziv & Armon, 2022). In the low-data regime with far more features than samples (d ≫ n), the dominant approaches are classical statistical methods (Hastie et al., 2009). These statistical methods reduce the dimensionality of the input space (Abdi & Williams, 2010; Liu et al., 2017; Van der Maaten & Hinton, 2008; Van Der Maaten et al., 2009), select features (Tibshirani, 1996; Climente-González et al., 2019; Freidling et al., 2021; Meier et al., 2008), impose regularization penalties on parameter magnitudes (Marquardt & Snee, 1975), or use ensembles of weak tree-based models (Friedman, 2001; Chen & Guestrin, 2016; Ke et al., 2017; Lou & Obukhov, 2017; Prokhorenkova et al., 2018).

