MET: MASKED ENCODING FOR TABULAR DATA

Abstract

We propose Masked Encoding for Tabular Data (MET) for learning self-supervised representations from tabular data. Unlike structured domains such as images, audio, and text, tabular self-supervised learning (tabular-SSL) is more challenging because each tabular dataset can have a completely different structure among its features (or coordinates), which is hard to identify a priori. MET attempts to circumvent this problem by assuming the following hypothesis: the observed tabular data features come from a latent graphical model, and the downstream tasks are significantly easier to solve in the latent space. Based on this hypothesis, MET uses random-masking-based encoders to learn a positional embedding for each coordinate, which in turn captures the latent structure between coordinates. Through experiments on a toy dataset generated from a linear graphical model, we show that MET is indeed able to capture the latent graphical model. Practically, through extensive experiments on multiple benchmarks for tabular data, we demonstrate that MET significantly outperforms all the baselines. For example, on Criteo, a large-scale click-prediction dataset, MET achieves as much as a 5% improvement over the current state-of-the-art (SOTA), whereas purely supervised learning approaches have advanced SOTA by at most 2% in the last six years. Furthermore, averaged over nine datasets, MET is around 3.9% more accurate than the next best method, gradient-boosted decision trees, which are widely considered SOTA for the tabular setting.

1. INTRODUCTION

Self-supervised pre-training (SSL) followed by supervised fine-tuning has emerged as the state-of-the-art approach in multiple domains such as natural language processing (NLP) (Devlin et al., 2019), computer vision (Chen et al., 2020b), and speech/audio processing (Baevski et al., 2020). However, despite the presence of extensive raw, unlabeled data in a variety of critical tabular-heavy domains like finance and marketing, it has been challenging to extend SSL-based pre-training approaches to tabular data. Broadly speaking, there are two dominant approaches to SSL: (i) reconstruction of masked inputs, and (ii) invariance to certain augmentations/transformations, also known as contrastive learning. Most existing tabular-SSL methods (Verma et al., 2020; Ucar et al., 2021) have adopted the second approach of contrastive learning. The underlying structure and semantics of specific domains such as images remain somewhat static irrespective of the dataset, so one can design generalizable domain-specific augmentations like cropping, rotating, and resizing. Tabular data, however, lacks such a fixed input vocabulary (such as pixels in images) and semantic structure, and thus lacks augmentations that generalize across datasets. Consequently, only a limited number of augmentations have been proposed for the tabular setting, such as mix-up, adding random (Gaussian) noise, and selecting subsets of features (Verma et al., 2020; Ucar et al., 2021). In this paper, we hypothesize the following: for any tabular dataset, (i) there is a latent (i.e., unknown/unobserved) graphical model that captures the relations between different coordinates/features, and (ii) classification is easier in the latent space. For example, in the CovType dataset, where the task is to predict the type of forest (e.g., deciduous, alpine)
given features such as elevation and soil type, extensive research in mountain and forest science has established that there are very specific relations among different features (Catry et al., 2009; Badía et al., 2016), and learning and leveraging these relations could yield significant improvements in the classification accuracy of machine learning models.

Figure 1: MET framework for tabular-SSL. Given an input x, we mask a fraction γ of its coordinates to obtain the masked input x_S. We then concatenate each coordinate value x_j with a learnable positional encoding PE_j to obtain the input embedding ψ(x_S). This embedding is passed through the encoder T_θ_enc. The encoder output T_θ_enc(ψ(x_S)) is then concatenated with the masked-token representations ψ_S and given as input to the decoder T_θ_dec. Finally, the decoder output is projected back to the input space using a projection head h_ω to obtain the reconstructed input x̂, and the reconstruction loss is optimized end-to-end.

Based on this hypothesis, we propose a masking-based reconstruction approach for self-supervised learning on tabular datasets. More concretely, for every unlabeled data point, we randomly choose a fraction of the coordinates, mask their values, and train a model to predict the values of the masked coordinates from the remaining unmasked ones. We use a transformer architecture with learnable (positional) embeddings for each coordinate, which capture the relations between different coordinates. While the masked-reconstruction task with a transformer architecture has been used successfully for SSL in computer vision (He et al., 2021) and NLP (Devlin et al., 2019), to the best of our knowledge this is the first work to successfully apply the paradigm to tabular datasets.
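The input construction described above (mask a fraction γ of coordinates, then embed each visible coordinate by concatenating its scalar value with a learnable positional encoding) can be sketched as follows. This is an illustrative NumPy sketch under our own assumptions; the function and variable names are ours, and in MET the positional encodings are learned jointly with the transformer rather than fixed random vectors as here.

```python
import numpy as np

rng = np.random.default_rng(0)

def build_masked_embeddings(x, pos_emb, mask_ratio=0.7):
    """Sketch of MET-style input construction for one example.

    x        : (d,) one tabular example with d coordinates
    pos_emb  : (d, k) positional encodings, one k-dim vector per coordinate
               (learnable parameters in the actual model)
    Returns the embeddings of the unmasked coordinates -- each row is the
    scalar value x_j concatenated with PE_j, giving k+1 dims -- together
    with the masked and unmasked coordinate indices.
    """
    d = x.shape[0]
    n_mask = int(round(mask_ratio * d))
    perm = rng.permutation(d)
    masked_idx = np.sort(perm[:n_mask])
    unmasked_idx = np.sort(perm[n_mask:])
    # Concatenate each visible coordinate value x_j with its encoding PE_j.
    unmasked_emb = np.concatenate(
        [x[unmasked_idx, None], pos_emb[unmasked_idx]], axis=1
    )
    return unmasked_emb, masked_idx, unmasked_idx

d, k = 10, 4
x = rng.standard_normal(d)
pos_emb = rng.standard_normal((d, k))
emb, masked_idx, unmasked_idx = build_masked_embeddings(x, pos_emb, mask_ratio=0.7)
print(emb.shape)  # 3 visible coordinates, each a (1 + 4)-dim embedding
```

The resulting sequence of per-coordinate embeddings is what the encoder consumes; the decoder then receives the encoder output together with representations of the masked positions and is trained to reconstruct the original values at those positions.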
In particular, we demonstrate, through experiments on a simple toy tabular setting, how the positional embeddings in a transformer, learned with the masked reconstruction task, can capture the dependency structure across features. Further, on a real-world forest cover type classification dataset, we show that features with the most correlated positional embeddings indeed have meaningful relations between them, as corroborated by extensive work in the forest science literature. We evaluate the performance of MET through extensive experiments on nine datasets spanning a wide range of sizes, class counts, and difficulty. Our experiments show that MET outperforms current SOTA tabular-SSL methods like DACL (Verma et al., 2020), SubTab (Ucar et al., 2021), and VIME (Yoon et al., 2020), as well as SOTA supervised tabular algorithms such as gradient-boosted decision trees (GBDT) (Mason et al., 1999), on all of these datasets. MET gives an average accuracy improvement of 3.9% over the second-best algorithm, GBDT. For example, on Criteo, a popular large-scale click-through-rate prediction dataset with 45 million examples, MET achieves a 5% improvement in AUROC over the current SOTA (Wang et al., 2021). To put this in perspective, the SOTA on Criteo has improved by less than 2% over the last six years (Leaderboard, 2022). Furthermore, on some datasets, MET trained with about 20% of the labeled train-set is as effective as standard supervised learning methods trained on the full labeled train-set.

To summarize, in this paper we propose MET, a masking-based reconstruction approach with a transformer architecture, and demonstrate its effectiveness for tabular-SSL. Conceptually, through experiments on a toy setting, we show that this approach can learn the relations between different coordinates in the dataset, which helps in downstream classification.
Practically, we show through extensive experiments on several popular tabular datasets that MET significantly outperforms all the current SOTA tabular-SSL baselines as well as SOTA supervised approaches.

2. RELATED WORK

Self-Supervised Learning (SSL): SSL has shown promising results not only in regimes where labeled training data is scarce, but also great empirical success in training large-scale models across domains such as natural language processing and computer vision. SSL methods can be broadly classified into two categories: pretext-task-based approaches and contrastive-learning-based approaches. Pretext-based SSL approaches solve a "pretext" task, such as reconstruction from a masked or noisy input, to learn the underlying distribution of the unlabeled data.

