REM: ROUTING ENTROPY MINIMIZATION FOR CAPSULE NETWORKS

Abstract

Capsule Networks are biologically-inspired neural network models, but their interpretability still needs further investigation. One of their main innovations lies in the routing mechanism, which extracts a parse tree: its purpose is to explicitly build relationships between capsules. However, their true potential has not surfaced yet: these relationships are extremely heterogeneous and difficult to understand, as the parse trees extracted for inputs of the same class differ greatly from each other. One school of thought, giving up on this aspect, proposes less interpretable versions of Capsule Networks without routing. This paper proposes REM, a technique which minimizes the entropy of the parse tree-like structure. We accomplish this by driving the distribution of the model parameters towards low-entropy configurations, using a pruning mechanism as a proxy. Thanks to REM, we generate significantly fewer parse trees, with essentially no performance loss, and show that Capsule Networks build stronger and more stable relationships between capsules.

1. INTRODUCTION

Capsule Networks (CapsNets) (Sabour et al., 2017; Hinton et al., 2018; Kosiorek et al., 2019) were recently introduced to overcome the shortcomings of Convolutional Neural Networks (CNNs). CNNs lose the spatial relationships between an object's parts because of max pooling layers, which progressively drop spatial information (Sabour et al., 2017). Furthermore, CNNs are also commonly regarded as "black-box" models: most of the techniques providing interpretation of the model are post-hoc, producing localized maps that highlight the regions of the image that are important for predicting objects (Selvaraju et al., 2017). CapsNets attempt to preserve and leverage an image representation as a hierarchy of parts, carving out a parse tree from the network. This is possible thanks to the iterative routing mechanism (Sabour et al., 2017), which models the connections between capsules. Routing can be seen as a parallel attention mechanism, where each active capsule can choose a capsule in the layer above to be its parent in the tree (Sabour et al., 2017). Therefore, CapsNets can produce interpretable representations encoded in the architecture itself (Sabour et al., 2017), yet can still be successfully applied to a number of applicative tasks (Zhao et al., 2019; Paoletti et al., 2018; Afshar et al., 2018). However, understanding what really happens inside a CapsNet is still an open challenge. For a given input image, there are too many active, co-coupled capsules, making the connections produced by the routing algorithm difficult to understand: the coupling coefficients typically have similar values, so the potential of the routing algorithm is not exploited (Gu & Tresp, 2020). We would instead like a given image to activate fewer, stronger connections between capsules, so that understanding and interpreting the part-whole relationships becomes a more straightforward process. To encourage this, we impose sparsity and entropy constraints.
Furthermore, the backward and forward passes of a CapsNet come at an enormous computational cost, since the number of trainable parameters is very high. For example, the CapsNet model deployed on the MNIST dataset by Sabour et al. (2017) is composed of an encoder and a decoder; the full architecture has 8.2M parameters. Do we really need such a large number of trainable parameters to achieve competitive results on such a task? Recently, many pruning methods have been applied to CNNs to reduce the complexity of the networks by enforcing sparse topologies (Tartaglione et al., 2018; Molchanov et al., 2017; Louizos et al., 2018): is it possible to tailor one of these approaches not only to lower the parameter count, but also to aid the model's interpretability? This work introduces REM (Routing Entropy Minimization) for CapsNets, which moves some steps towards the interpretability of the routing algorithm of CapsNets. Pruning can effectively reduce the overall entropy of the connections of the parse tree-like structure encoded in a CapsNet, because in low pruning regimes it removes the noisy couplings which would otherwise cause the entropy to increase considerably. We collect the coupling coefficients and study their frequency and cardinality, observing a lower intra-class conditional entropy: the pruned version adds a missing explicit prior to the routing mechanism, grounding the couplings of the unused primary capsules and disallowing fluctuations, while matching the baseline performance on the validation/test set. This implies that there are significantly fewer parse trees, hence more stable ones, for the pruned models. The rest of the paper is organized as follows: in Section 2 we introduce the basic concepts of CapsNets and related work, in Section 3 we describe our technique, REM, in Section 4 we investigate the effectiveness of our method by testing it on several datasets, and finally we draw the conclusions of our work.
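To make the notion of coupling entropy concrete, the sketch below computes the Shannon entropy of each primary capsule's coupling distribution. The function name `coupling_entropy` and the toy distributions are illustrative placeholders, not part of the paper's code.

```python
import numpy as np

def coupling_entropy(c):
    """Shannon entropy (in bits) of each row of c, where row i holds the
    coupling coefficients (c_i1, ..., c_iJ) of primary capsule i."""
    c = np.clip(c, 1e-12, 1.0)           # avoid log(0)
    return -(c * np.log2(c)).sum(axis=1)

# A uniform coupling (no clear parent, hard to interpret) has maximal
# entropy log2(J); a peaked coupling (one clear parent) has low entropy.
uniform = np.full((1, 4), 0.25)
peaked = np.array([[0.97, 0.01, 0.01, 0.01]])
```

Under this view, REM's goal is to drive the couplings toward the peaked, low-entropy case, where the parse tree has a clear parent for each active capsule.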

2. BACKGROUND AND RELATED WORK

This section first describes the fundamental aspects of CapsNets and their routing algorithm introduced by Sabour et al. (2017). Then, we review the literature, especially the work related to sparsity in CapsNets.
Capsule Networks Fundamentals. CapsNets group neurons into capsules, namely activity vectors, where each capsule accounts for an object or one of its parts. Each element of these vectors encodes a different property of the object, such as its pose, color, deformation, etc. The magnitude of a capsule stands for the probability of existence of that object in the image. Typically, a CapsNet is composed of at least two capsule layers, called PrimaryCaps and DigitCaps (the latter also called OutputCaps), with a total of I and J capsules respectively. The poses u_i of the capsules of layer L, called primary capsules, are built upon convolutional layers. In order to compute the poses of the capsules of the next layer L+1, an iterative routing mechanism is performed. Each capsule i makes a prediction û_{j|i} for the pose of an upper-layer capsule j, thanks to a transformation matrix W_{ij}:
û_{j|i} = W_{ij} u_i .    (1)
Then, the total input s_j of capsule j of the DigitCaps layer is computed as the weighted average of the votes û_{j|i}:
s_j = Σ_i c_{ij} û_{j|i} ,    (2)
where c_{ij} are the coupling coefficients between a primary capsule i and an output capsule j. The pose v_j of an output capsule j is then defined as the normalized, "squashed" s_j:
v_j = squash(s_j) = (‖s_j‖² / (1 + ‖s_j‖²)) · (s_j / ‖s_j‖) .    (3)
So the routing algorithm computes both the poses of the output capsules and the connections between capsules of consecutive layers. The coupling coefficients are computed dynamically by the routing algorithm and depend on the input.
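A minimal numerical sketch of equations (1)-(3) in NumPy may help fix the shapes involved; the capsule counts and dimensions below are illustrative placeholders, not the paper's architecture.

```python
import numpy as np

def squash(s, eps=1e-8):
    """Eq. (3): shrink short vectors toward zero and long vectors
    toward (but never past) unit length."""
    sq_norm = np.sum(s ** 2, axis=-1, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

# Eq. (1): each of I primary capsules votes for the pose of each of
# J output capsules through its transformation matrix W_ij.
I, J, D_in, D_out = 4, 3, 8, 16           # illustrative sizes
rng = np.random.default_rng(0)
u = rng.normal(size=(I, D_in))            # primary capsule poses u_i
W = rng.normal(size=(I, J, D_out, D_in))  # transformation matrices W_ij
u_hat = np.einsum('ijok,ik->ijo', W, u)   # votes û_{j|i} = W_ij u_i

# Eq. (2) with uniform couplings c_ij = 1/I, followed by eq. (3):
c = np.full((I, J), 1.0 / I)
s = np.einsum('ij,ijd->jd', c, u_hat)
v = squash(s)
```

In the actual routing algorithm, described next, the uniform couplings used here are only the starting point and are refined iteratively.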
The coupling coefficients are determined by a "routing softmax" activation function, whose initial logits b_{ij} are the log prior probabilities that the i-th capsule should be coupled to the j-th one:
c_{ij} = softmax(b_{ij}) = exp(b_{ij}) / Σ_k exp(b_{ik}) .    (4)
At the first step of the routing algorithm the logits are all equal; they are then refined by measuring the agreement between the output v_j of the j-th capsule and the prediction û_{j|i} for a given input. The agreement is defined as the scalar product a_{ij} = v_j · û_{j|i}. At each iteration, the update rule for the logits is
b_{ij} ← b_{ij} + a_{ij} .    (5)
The steps defined in equation 2, equation 3, equation 4 and equation 5 are repeated for the t iterations of the routing algorithm. For training, the cross-entropy loss is replaced with the margin loss (Sabour et al., 2017).
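Putting equations (2)-(5) together, the loop below is one possible NumPy rendering of routing-by-agreement; it is a simplified, single-example sketch (no batching, zero-initialized priors), not the authors' implementation.

```python
import numpy as np

def squash(s, eps=1e-8):
    # Eq. (3): squashing non-linearity, output norm always below 1.
    sq = np.sum(s ** 2, axis=-1, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)

def dynamic_routing(u_hat, t=3):
    """Routing-by-agreement over votes u_hat of shape (I, J, D).
    Returns output poses v (J, D) and coupling coefficients c (I, J)."""
    I, J, _ = u_hat.shape
    b = np.zeros((I, J))                   # logits, initially all equal
    for _ in range(t):
        c = np.exp(b) / np.exp(b).sum(axis=1, keepdims=True)  # eq. (4)
        s = np.einsum('ij,ijd->jd', c, u_hat)                 # eq. (2)
        v = squash(s)                                         # eq. (3)
        b = b + np.einsum('jd,ijd->ij', v, u_hat)  # eq. (5), a_ij = v_j · û_{j|i}
    return v, c
```

Each row of the returned c is a distribution over parent capsules; it is exactly these per-input distributions whose entropy REM seeks to minimize.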

