A PROBABILISTIC MODEL FOR DISCRIMINATIVE AND NEURO-SYMBOLIC SEMI-SUPERVISED LEARNING

Anonymous

Abstract

Strong progress has been achieved in semi-supervised learning (SSL) by combining several underlying methods, some that pertain to properties of the data distribution p(x), others to the model outputs p(y|x), e.g. minimising the entropy of unlabelled predictions. Focusing on the latter, we fill a gap in the literature by introducing a probabilistic model for discriminative semi-supervised learning, mirroring the classical generative model. Our model theoretically explains several SSL methods as inducing (approximate) strong priors over the parameters of p(y|x). Applying this same probabilistic model to tasks in which labels represent binary attributes, we also theoretically justify a family of neuro-symbolic SSL approaches, taking a step towards bridging the divide between statistical learning and logical reasoning.

1. INTRODUCTION

In semi-supervised learning (SSL), a mapping is learned that predicts labels y for data points x from a dataset of labelled pairs (x_l, y_l) and unlabelled x_u. SSL is of practical importance since unlabelled data are often cheaper to acquire and/or more abundant than labelled data. For unlabelled data to help predict labels, the distribution of x must contain information relevant to the prediction (Chapelle et al., 2006; Zhu & Goldberg, 2009). State-of-the-art SSL algorithms (e.g. Berthelot et al., 2019b;a) combine underlying methods, including some that leverage properties of the distribution p(x), and others that rely on the label distribution p(y|x). The latter include entropy minimisation (Grandvalet & Bengio, 2005), mutual exclusivity (Sajjadi et al., 2016a; Xu et al., 2018) and pseudo-labelling (Lee, 2013), which add functions of unlabelled data predictions to a typical discriminative supervised loss function. Whilst these methods each have their own rationale, we propose a formal probabilistic model that unifies them as a family of discriminative semi-supervised learning (DSSL) methods.

Neuro-symbolic learning (NSL) is a broad field that looks to combine logical reasoning and statistical machine learning, e.g. neural networks. Approaches often introduce neural networks into a logical framework (Manhaeve et al., 2018), or logic into statistical learning models (Rocktäschel et al., 2015). Several works combine NSL with semi-supervised learning (Xu et al., 2018; van Krieken et al., 2019) but lack rigorous justification. We show that our probabilistic model for discriminative SSL extends to the case where label components obey logical rules, theoretically justifying neuro-symbolic SSL approaches that augment a supervised loss function with a function based on logical constraints.

Central to this work are ground truth parameters {θ_x}_{x∈X} of the distributions p(y|x), as predicted by models such as neural networks.
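To make the shape of such objectives concrete, the following is a minimal sketch (illustrative function names, weights and threshold, not the paper's notation) of a DSSL loss that adds entropy minimisation and confidence-thresholded pseudo-labelling terms on unlabelled predictions to a standard supervised cross-entropy:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax over class logits."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def dssl_loss(logits_l, y_l, logits_u, w_ent=0.5, w_pseudo=0.5, tau=0.95):
    """Supervised cross-entropy plus two functions of unlabelled
    predictions: entropy minimisation (Grandvalet & Bengio, 2005) and
    pseudo-labelling (Lee, 2013). Weights and threshold are illustrative."""
    # Cross-entropy on labelled data.
    p_l = softmax(logits_l)
    ce = -np.mean(np.log(p_l[np.arange(len(y_l)), y_l] + 1e-12))

    # Mean entropy of unlabelled predictions: low when confident.
    p_u = softmax(logits_u)
    ent = -np.mean(np.sum(p_u * np.log(p_u + 1e-12), axis=1))

    # Pseudo-labels: treat confident argmax predictions as hard targets.
    conf = p_u.max(axis=1)
    mask = conf >= tau
    pseudo = -np.mean(np.log(conf[mask] + 1e-12)) if mask.any() else 0.0

    return ce + w_ent * ent + w_pseudo * pseudo
```

Both unlabelled terms decrease as predictions move towards confident (low-entropy) distributions, which is the shared behaviour the probabilistic model in this work seeks to explain.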
For example, θ_x may be a multinomial parameter vector specifying the distribution over all labels associated with a given x. Since each data point x has a specific label distribution defined by θ_x, sampling from p(x) induces an implicit distribution over parameters, p(θ). If known, the distribution p(θ) serves as a prior over all model predictions θ̂_x: for labelled samples it may provide little additional information, but for unlabelled data it may allow predictions to be evaluated and the model improved. As such, p(θ) provides a potential basis for semi-supervised learning. We show that, in practice, p(θ) can avoid much of the complexity of p(x) and have a concise analytical form known a priori. In principle, p(θ) can also be estimated from the parameters learned for labelled data (fitting the intuition that predictions for unlabelled data should be consistent with those of labelled data).

We refer to SSL methods that rely on p(θ) as discriminative and formalise them with a hierarchical probabilistic model, analogous to that for generative approaches. Recent results (Berthelot et al., 2019b;a) demonstrate that discriminative SSL is orthogonal and complementary to methods that rely on p(x), such as data augmentation and consistency regularisation (Sajjadi et al., 2016b; Laine & Aila, 2017; Tarvainen & Valpola, 2017; Miyato et al., 2018).
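A toy numerical sketch of the induced distribution p(θ), under an assumed two-class one-dimensional Gaussian mixture for p(x) (class means, seed and thresholds are all illustrative choices, not from the paper): here θ_x = p(y=1|x) follows from Bayes' rule, and sampling x ~ p(x) shows p(θ) concentrating near the simplex vertices {0, 1} — exactly the structure that an entropy-minimisation term rewards.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical p(x): two well-separated Gaussian classes in 1-D.
n = 10_000
y = rng.integers(0, 2, size=n)
x = rng.normal(loc=np.where(y == 1, 4.0, -4.0), scale=1.0)

def gauss(v, mu):
    """Unit-variance normal density centred at mu."""
    return np.exp(-0.5 * (v - mu) ** 2) / np.sqrt(2.0 * np.pi)

# Ground-truth theta_x = p(y=1|x) by Bayes' rule; sampling x ~ p(x)
# induces the implicit distribution p(theta).
theta = gauss(x, 4.0) / (gauss(x, 4.0) + gauss(x, -4.0))

# p(theta) concentrates near the vertices {0, 1}, even though p(x)
# itself is a mixture: most points have near-deterministic labels.
frac_confident = np.mean((theta < 0.05) | (theta > 0.95))

def entropy(p):
    """Binary entropy of p(y=1|x)."""
    p = np.clip(p, 1e-12, 1.0 - 1e-12)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))

# Small under this p(theta): confident label distributions dominate.
mean_entropy = entropy(theta).mean()
```

Despite p(x) being a mixture that a model would have to estimate, p(θ) here has a simple, sharply bimodal form — illustrating the claim that p(θ) can avoid much of the complexity of p(x).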

