EMB-GAM: AN INTERPRETABLE AND EFFICIENT PREDICTOR USING PRE-TRAINED LANGUAGE MODELS

Abstract

Deep learning models have achieved impressive prediction performance but often sacrifice interpretability and speed, critical considerations in high-stakes domains and compute-limited settings. In contrast, generalized additive models (GAMs) can maintain interpretability and speed but often suffer from poor prediction performance due to their inability to effectively capture feature interactions. This work aims to bridge this gap by using pre-trained neural language models to extract embeddings from each input before aggregating them and learning a linear model in the embedding space. The final model (which we call Emb-GAM) is a transparent, linear function of its input features and feature interactions. Leveraging the language model allows Emb-GAM to learn far fewer linear coefficients, model larger interactions, dramatically speed up inference, and generalize well to novel inputs (e.g. unseen ngrams in text). Across a variety of natural-language-processing datasets, Emb-GAM achieves strong prediction performance without sacrificing interpretability or speed. All code is made available on GitHub.

1. INTRODUCTION

Large neural language models (LLMs) have demonstrated impressive predictive performance due to their ability to learn complex, non-linear relationships between variables. However, the inability of humans to understand these relationships has led LLMs to be characterized as black boxes, often limiting their use in high-stakes applications such as science (Angermueller et al., 2016), medicine (Kornblith et al., 2022), and policy-making (Brennan & Oliver, 2013). Moreover, the use of black-box models such as LLMs has come under increasing scrutiny in settings where users require explanations or where models struggle with issues such as fairness (Dwork et al., 2012) and regulatory pressure (Goodman & Flaxman, 2016). Simultaneously, recent black-box models have grown to massive sizes, making them costly and difficult to deploy, particularly on edge devices such as mobile phones.

As an alternative to large black-box models, transparent models, such as generalized additive models (Hastie & Tibshirani, 1986) and rule-based models (Breiman et al., 1984), can maintain interpretability. Transparent models also tend to be faster and more computationally efficient than black-box models. While transparent models can sometimes perform as well as black-box models (e.g. Rudin et al., 2022; Ha et al., 2021; Mignan & Broccardo, 2019), in many settings, such as natural-language processing (NLP), there remains a large performance gap between transparent models and black-box models.

This work aims to minimize this gap by leveraging a pre-trained LLM to learn a more effective transparent model. Specifically, we extract LLM embeddings for different feature interactions (e.g. ngrams in text) and then learn a generalized additive model on top of these embeddings. The final learned model (which we call Emb-GAM) is a transparent, linear function of its input features and feature interactions, but the use of the LLM allows Emb-GAM to intelligently reduce its number of learned parameters (see Fig. 1).
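The pipeline just described (extract an embedding per ngram, aggregate, fit a linear model) can be sketched in a few lines. The snippet below is a minimal illustration, not the paper's implementation: a deterministic hash-based random vector stands in for the pre-trained LLM embedding, and the tiny dataset is invented for demonstration.

```python
import hashlib
import numpy as np
from sklearn.linear_model import LogisticRegression

DIM = 64

def embed(ngram: str) -> np.ndarray:
    # Hypothetical stand-in for a pre-trained LM embedding: a fixed
    # pseudo-random vector per ngram (the paper uses real LLM embeddings).
    seed = int.from_bytes(hashlib.sha256(ngram.encode()).digest()[:4], "big")
    return np.random.default_rng(seed).standard_normal(DIM)

def extract_ngrams(text: str, max_n: int = 2) -> list:
    # All unigrams and bigrams of the whitespace-tokenized text.
    toks = text.lower().split()
    return [" ".join(toks[i:i + n])
            for n in range(1, max_n + 1)
            for i in range(len(toks) - n + 1)]

def featurize(text: str) -> np.ndarray:
    # Sum ngram embeddings -> one fixed-size vector per document.
    grams = extract_ngrams(text)
    return np.sum([embed(g) for g in grams], axis=0) if grams else np.zeros(DIM)

# Made-up toy dataset for illustration only.
texts = ["great movie loved it", "terrible plot boring acting",
         "loved the acting", "boring and terrible"]
labels = [1, 0, 1, 0]

X = np.stack([featurize(t) for t in texts])
clf = LogisticRegression().fit(X, labels)

# Per-ngram contribution to the logit: <coef, embed(ngram)>.
contrib = {g: float(clf.coef_ @ embed(g)) for g in extract_ngrams(texts[0])}
```

Because the document representation is a sum of ngram embeddings, the prediction logit decomposes exactly into per-ngram contributions plus an intercept; this additive structure is what makes the learned model a GAM.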
Rather than learning a linear model over all possible feature interactions (whose number scales exponentially with the order of the interaction and the feature dimension), Emb-GAM requires learning only a fixed set of linear coefficients (the size of the embedding extracted by the LLM). As a result, Emb-GAM can efficiently model high-order interactions, generalize well to novel interactions, and even vary the number of features used at test time for prediction. Moreover, inference
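Because each ngram contributes a single scalar to the prediction (the inner product of the learned coefficients with its embedding), inference can be reduced to a lookup table plus a sum. A minimal sketch under the same assumptions as before: the hash-based `embed` is a hypothetical stand-in for the LLM, and the coefficients and vocabulary are made up.

```python
import hashlib
import numpy as np

DIM = 64

def embed(ngram: str) -> np.ndarray:
    # Hypothetical stand-in for a pre-trained LM embedding.
    seed = int.from_bytes(hashlib.sha256(ngram.encode()).digest()[:4], "big")
    return np.random.default_rng(seed).standard_normal(DIM)

# Pretend these coefficients came from a fitted Emb-GAM linear layer.
coef = np.random.default_rng(0).standard_normal(DIM)

# Precompute one scalar score per ngram seen during training.
vocab = ["great movie", "loved", "terrible plot", "boring"]
score = {g: float(coef @ embed(g)) for g in vocab}

def predict_logit(text_ngrams):
    # Cached ngrams cost only a dictionary lookup; unseen ngrams are
    # still scored on the fly via their embedding, so the model
    # generalizes beyond the cached vocabulary.
    return sum(score[g] if g in score else float(coef @ embed(g))
               for g in text_ngrams)
```

Caching the per-ngram scalars removes the LLM from the inference path entirely for known ngrams, which is what allows a dramatic speedup over running the LLM on each input.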




