UNCERTAINTY ESTIMATION AND CALIBRATION WITH FINITE-STATE PROBABILISTIC RNNS

Abstract

Uncertainty quantification is crucial for building reliable and trustworthy machine learning systems. We propose to estimate uncertainty in recurrent neural networks (RNNs) via stochastic discrete state transitions over recurrent timesteps. The uncertainty of the model can be quantified by running a prediction several times, each time sampling from the recurrent state transition distribution, leading to potentially different results if the model is uncertain. Alongside uncertainty quantification, our proposed method offers several advantages in different settings. The proposed method can (1) learn deterministic and probabilistic automata from data, (2) learn well-calibrated models on real-world classification tasks, (3) improve the performance of out-of-distribution detection, and (4) control the exploration-exploitation trade-off in reinforcement learning. An implementation is available.

1. INTRODUCTION

Machine learning models are well-calibrated if the probability associated with the predicted class reflects the likelihood of that prediction being correct. The output probabilities of modern neural networks are often poorly calibrated (Guo et al., 2017). For instance, typical neural networks with a softmax activation tend to assign high probabilities to out-of-distribution samples (Gal & Ghahramani, 2016b). Providing uncertainty estimates is important for model interpretability as it allows users to assess the extent to which they can trust a given prediction (Jiang et al., 2018). Moreover, well-calibrated output probabilities are crucial in several use cases. For instance, when monitoring medical time-series data (see Figure 1(a)), hospital staff should be alerted when there is a low-confidence prediction concerning a patient's health status.

Bayesian neural networks (BNNs), which place a prior distribution on the model's parameters, are a popular approach to modeling uncertainty. BNNs often require more parameters, rely on approximate inference, and depend crucially on the choice of prior (Gal, 2016; Lakshminarayanan et al., 2017). Applying dropout both during training and inference can be interpreted as a BNN and provides a more efficient method for uncertainty quantification (Gal & Ghahramani, 2016b). The dropout probability, however, needs to be tuned and therefore leads to a trade-off between predictive error and calibration error.

Sidestepping the challenges of Bayesian NNs, we propose an orthogonal approach to quantifying the uncertainty of recurrent neural networks (RNNs). At each time step, based on the current hidden (and cell) state, the model computes a probability distribution over a finite set of states. The next state of the RNN is then drawn from this distribution. We use the Gumbel softmax trick (Gumbel, 1954; Maddison et al., 2017; Jang et al., 2017) to perform Monte-Carlo gradient estimation.
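The resulting uncertainty estimate comes from repeating a stochastic forward pass and summarizing the resulting class probabilities. The following is a minimal sketch of that idea; `mc_uncertainty` and the callable `stochastic_predict` are hypothetical names for illustration, not the paper's implementation:

```python
import numpy as np

def mc_uncertainty(stochastic_predict, x, n_samples=10):
    """Estimate prediction uncertainty by running a stochastic forward
    pass several times on the same input. `stochastic_predict` is
    assumed to return a probability vector over classes and may give
    different outputs on repeated calls because state transitions are
    sampled. Returns per-class mean and standard deviation."""
    probs = np.stack([stochastic_predict(x) for _ in range(n_samples)])
    return probs.mean(axis=0), probs.std(axis=0)
```

A large standard deviation across runs then signals an uncertain prediction, as in the ECG example of Figure 1(a).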
Inspired by the effectiveness of temperature scaling (Guo et al., 2017), which is usually applied to trained models, we learn the temperature τ of the Gumbel softmax distribution during training to control the concentration of the state transition distribution. Learning τ as a parameter can be seen as a form of entropy regularization (Szegedy et al., 2016; Pereyra et al., 2017; Jang et al., 2017). The resulting model, which we name ST-τ, defines for every input sequence a probability distribution over state-transition paths similar to a probabilistic state machine. To estimate the model's uncertainty for a prediction, ST-τ is run multiple times to compute the mean and variance of the prediction probabilities.

We explore the behavior of ST-τ in a variety of tasks and settings. First, we show that ST-τ can learn deterministic and probabilistic automata from data. Second, we demonstrate on real-world classification tasks that ST-τ learns well-calibrated models. Third, we show that ST-τ is competitive in out-of-distribution detection tasks. Fourth, in a reinforcement learning task, we find that ST-τ is able to trade off exploration and exploitation better than existing methods. The out-of-distribution detection and reinforcement learning tasks in particular are not amenable to post-hoc calibration approaches (Guo et al., 2017) and, therefore, require a method such as ours that calibrates the probabilities during training.
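The Gumbel softmax relaxation with temperature τ can be sketched as follows. This is a generic NumPy illustration of the standard trick (the function name is ours, and in the paper τ would be a learnable parameter rather than an argument): low τ pushes samples towards one-hot vectors, while high τ pushes them towards the uniform distribution.

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """Draw one relaxed categorical sample: perturb the logits with
    Gumbel(0, 1) noise and apply a softmax at temperature tau."""
    u = rng.uniform(size=logits.shape)
    g = -np.log(-np.log(u + 1e-20) + 1e-20)   # Gumbel(0, 1) noise
    y = (logits + g) / tau                     # temperature-scaled perturbed logits
    y = y - y.max(axis=-1, keepdims=True)      # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)   # relaxed one-hot sample
```

Because the sample is a differentiable function of the logits and of τ, gradients can flow through the sampling step, which is what makes learning τ during training possible.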

2.1. BACKGROUND

An RNN is a function f defined through a neural network with parameters w that is applied over time steps: at time step t, it reuses the hidden state h_{t-1} of the previous time step and the current input x_t to compute a new state h_t, that is, f : (h_{t-1}, x_t) → h_t. Some RNN variants such as LSTMs have memory cells c_t and apply the function f : (h_{t-1}, c_{t-1}, x_t) → h_t at each step. A vanilla RNN maps two identical input sequences to the same state, and it is therefore not possible to measure the uncertainty of a prediction by running inference multiple times. Furthermore, it is known that passing h_t through a softmax transformation leads to overconfident predictions on out-of-distribution samples and poorly calibrated probabilities (Guo et al., 2017).

In a Bayesian RNN the weight matrices w are drawn from a distribution and, therefore, the output is an average over an infinite number of models. Unlike vanilla RNNs, Bayesian RNNs are stochastic and it is possible to compute the mean and variance of a prediction. Using a prior to integrate out the parameters during training also has a regularizing effect. However, there are two major and often debated challenges of BNNs: the right choice of prior and the efficient approximation of the posterior.

With this paper, we side-step these challenges and model the uncertainty of an RNN through probabilistic state transitions between a finite number of k learnable states s_1, ..., s_k. Given a state h_t, we compute a probability distribution over the learnable states. Hence, for the same state and input, the RNN might move to different states in different runs. Instead of integrating over possible weights, as in the case of BNNs, we sum over all possible state sequences and weigh the classification probabilities by the probabilities of these sequences. Figure 2 illustrates the proposed approach and contrasts it with vanilla and Bayesian RNNs. The proposed method combines two building blocks.
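A single probabilistic transition of this kind can be sketched as follows. This is an illustrative simplification under our own assumptions (dot-product scoring, hard sampling), not the paper's model; ST-τ instead uses the Gumbel softmax relaxation so that gradients can flow through the sampling step:

```python
import numpy as np

def stochastic_transition(h, states, rng):
    """One probabilistic state transition: score each of the k learnable
    states against the current hidden state h, normalize the scores into
    a distribution, and sample the next state from it."""
    scores = states @ h                        # (k,) dot-product scores
    scores = scores - scores.max()             # numerical stability
    p = np.exp(scores)
    p = p / p.sum()                            # softmax over the k states
    idx = rng.choice(len(states), p=p)         # draw the next state index
    return states[idx], p
```

Running the same sequence twice can thus visit different state sequences, which is exactly what makes repeated inference informative about the model's uncertainty.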
The first is state-regularization (Wang & Niepert, 2019), a way to compute a probability distribution over a finite set of states in an RNN. State-regularization, however, is deterministic, and we therefore use the second building block, the Gumbel softmax trick (Gumbel, 1954; Maddison et al., 2017; Jang et al., 2017), to sample from a categorical distribution. Combining the two blocks allows us to create



Figure 1: (a) Prediction uncertainty of ST-τ, our proposed method, for an ECG time-series based on 10 runs. To the left of the red line, ST-τ classifies a heart beat as normal. To the right of the red line, ST-τ makes wrong predictions. Due to the drop in its certainty, however, it can alert medical personnel. (b) Given a sentence with negative sentiment, ST-τ reads the sentence word by word. The y-axis presents the model's confidence that the sentence has a negative sentiment. After the first few words, the model leans towards a negative sentiment but is uncertain about its prediction. After the word "mess," its uncertainty drops and it predicts the sentiment as negative.

* Equal contribution.

† Work done at NEC Laboratories Europe.

