THALAMUS: A BRAIN-INSPIRED ALGORITHM FOR BIOLOGICALLY-PLAUSIBLE CONTINUAL LEARNING AND DISENTANGLED REPRESENTATIONS

Abstract

Animals thrive in constantly changing environments and leverage temporal structure to learn well-factorized causal representations. In contrast, traditional neural networks suffer from forgetting in changing environments, and many methods have been proposed to limit forgetting, each with different trade-offs. Inspired by the brain's thalamocortical circuit, we introduce a simple algorithm that uses optimization at inference time to dynamically generate internal representations of the current task. The algorithm alternates between updating the model weights and updating a latent task embedding, allowing the agent to parse the stream of temporal experience into discrete events and organize learning about them. On a continual learning benchmark, it achieves competitive average accuracy at the end of training by mitigating forgetting; more importantly, by requiring the model to adapt through latent updates, it organizes knowledge into flexible structures with a cognitive interface to control them. Tasks later in the sequence can be solved through knowledge transfer as they become reachable within the well-factorized latent space. The algorithm meets many of the desiderata of an ideal continually learning agent in open-ended environments, and its simplicity suggests fundamental computations in circuits with abundant feedback control loops, such as the thalamocortical circuits in the brain.

1. INTRODUCTION

Animals thrive under constantly changing environmental demands at many time scales. Biological brains seem capable of using these changes advantageously, leveraging the temporal structure to learn causal and well-factorized representations (Collins & Koechlin, 2012; Yu et al., 2021; Herce Castañón et al., 2021). In contrast, traditional neural networks suffer in such settings with sequential experience and display prominent interference between old and new learning, limiting most training paradigms to using shuffled data (McCloskey & Cohen, 1989). Many recent methods have advanced the flexibility of neural networks (for recent reviews, see Parisi et al. (2019); Hadsell et al. (2020); Veniat et al. (2021)). However, in addition to mitigating forgetting, several desirable properties in a continually learning agent have recently been suggested (Hadsell et al., 2020; Veniat et al., 2021), including: accuracy on many tasks at the end of a learning episode, or at least fast adaptation and recovery of accuracy with minimal additional training. The ideal agent would also display knowledge transfer forward, to future tasks, and backward, to previously learned tasks, as well as transfer to tasks with slightly different computations or slightly different input or output distributions (Veniat et al., 2021). The algorithm should scale favorably with the number of tasks and maintain plasticity, or the capacity for further learning. Finally, the agent should ideally be able to function unsupervised, without relying on access to task labels and task boundaries (Hadsell et al., 2020; Rao et al., 2019). We argue for another critical feature: contextual behavior, where the same inputs may require different responses at different times, a feature that might constrain the solution space to be of more relevance to brain function and to the full complexity of the world. A learning agent might struggle to identify reliable contextual signals in a high-dimensional input space, if they are knowable at all, and with many contextual modifiers it might not be feasible to experience all combinations sufficiently to develop associative responses.
Neuroscience experiments have revealed a thalamic role in cognitive flexibility and in switching between behavioral policies (Schmitt et al., 2017; Mukherjee et al., 2021). The prefrontal cortex (PFC), linked to advanced cognition, shows representations of task variables and input-to-output transformations (Johnston et al., 2007; Mansouri et al., 2006; Rougier et al., 2005; Rikhye et al., 2018), while the mediodorsal thalamus shows representations of the task being performed (Schmitt et al., 2017; Rikhye et al., 2018) and of uncertainty about the current task (Mukherjee et al., 2021). The mediodorsal thalamus, devoid of recurrent excitatory connections and with extensive reciprocal connections to PFC, is thought to gate PFC computations by selecting task-relevant representations to enable flexible behavioral switching (Wang & Halassa, 2021; Hummos et al., 2022). The connections from cortex to thalamus engage in error feedback control to contextualize perceptual attention (Briggs, 2020), motor planning (Kao et al., 2021), and cognitive control (Halassa & Kastner, 2017; Wang & Halassa, 2021), and representations of error feedback can be observed in the thalamus (Ide & Li, 2011; Jakob et al., 2021; Wang et al., 2020). Moreover, these thalamic representations can be composed to produce complex behaviors (Logiaco et al., 2021). In this paper, we take inspiration from the thalamocortical circuit and develop a simple algorithm that uses optimization at inference time to produce internally generated contextual signals, allowing the agent to parse its temporal experience into discrete events and organize learning about them (Fig. 1). Our contributions are as follows. We show that a network trained on tasks sequentially using traditional weight updates, with task identifiers provided, can be made to identify tasks dynamically by taking gradient steps in the latent space (latent updates).
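To make the two kinds of updates concrete, the following is a minimal toy sketch, not the paper's architecture: a linear model over the concatenated input and task embedding, trained with ordinary weight updates while task identifiers are provided, then asked to recover the identity of an unlabeled observation purely through gradient steps on z. All names, dimensions, and learning rates here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 3                                          # toy input dimensionality
W = rng.normal(scale=0.1, size=(dim, dim + 2))   # linear model over [x; z], 2-dim latent

# Two toy "tasks": additive shifts of the input, chosen so that a linear
# model with a latent-driven bias can represent both simultaneously.
tasks = {"A": lambda x: x + 1.0, "B": lambda x: x - 1.0}
z_bank = {"A": np.array([1.0, 0.0]), "B": np.array([0.0, 1.0])}  # provided identifiers

def predict(W, x, z):
    return W @ np.concatenate([x, z])

# Phase 1: ordinary weight updates, with the task identifier supplied as z.
for _ in range(3000):
    name = "A" if rng.random() < 0.5 else "B"
    x = rng.normal(size=dim)
    err = predict(W, x, z_bank[name]) - tasks[name](x)
    W -= 0.05 * np.outer(err, np.concatenate([x, z_bank[name]]))

# Phase 2: latent updates -- freeze the weights and take gradient steps on z
# alone to infer which task generated the current observation.
x = rng.normal(size=dim)
y = tasks["A"](x)                  # the model is not told which task this is
z = np.array([0.5, 0.5])           # start from an uncertain embedding
for _ in range(150):
    err = predict(W, x, z) - y
    z -= 0.1 * (W[:, dim:].T @ err)    # gradient of the squared error w.r.t. z only
```

After the latent updates, z lands near the embedding of the task that generated the observation, so inference-time optimization doubles as task identification.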
We then consider unlabeled tasks and simply alternate weight updates and latent updates to arrive at Thalamus, an algorithm capable of parsing sequential experience into events (tasks) and contextualizing its responses through the simple dynamics of gradient descent. The algorithm generalizes to novel tasks, can discover temporal events at arbitrary time scales, and does not require a pre-specified number of events or clusters. Additionally, it does not require a distinction between a training phase and a testing phase and is accordingly suitable for open-ended learning.
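The alternation itself can be sketched as follows. This is an assumed toy reconstruction, not the authors' implementation: a linear model, a single unlabeled event in the stream, and a hypothetical surprise threshold `THRESHOLD`. On each observation, latent updates run first; weight updates fire only when the latent space cannot account for the data.

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 3
W = rng.normal(scale=0.1, size=(dim, dim + 2))   # weights (the "cortex")
z = np.array([0.5, 0.5])                         # persistent latent embedding (the "thalamus")
THRESHOLD = 0.1                                  # hypothetical surprise threshold

def loss(W, x, y, z):
    err = W @ np.concatenate([x, z]) - y
    return float(np.mean(err ** 2))

weight_update_times = []
for t in range(4000):
    x = rng.normal(size=dim)
    y = x + 1.0                     # an unlabeled event in the experience stream
    # 1) Latent updates: try to explain the observation by moving z alone.
    for _ in range(10):
        err = W @ np.concatenate([x, z]) - y
        z -= 0.05 * (W[:, dim:].T @ err)
    # 2) Weight updates fire only when latent updates fail to bring the loss down.
    if loss(W, x, y, z) > THRESHOLD:
        err = W @ np.concatenate([x, z]) - y
        W -= 0.05 * np.outer(err, np.concatenate([x, z]))
        weight_update_times.append(t)
```

Early in the stream the threshold is exceeded and weight updates consolidate the new event; once the event is learned, latent updates alone suffice and weight updates become rare, which is what makes the rule usable without explicit task boundaries.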

2. MODEL

We consider the setting where tasks arrive sequentially, and each task k is described by a dataset D_k of inputs, outputs, and task identifiers (x_k, y_k, i_k). We examine both settings where the task identifier is and is not available to the learning agent as input. As a learning agent, we take a function f_θ with parameters θ that maps an input x_k and a latent task embedding z to a predicted output:

ŷ = f_θ(x_k, z)
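As a concrete (assumed) instantiation, f_θ can be any differentiable network that receives x and z jointly, so that gradients can flow either into θ (weight updates) or into z (latent updates). A minimal two-layer sketch, with illustrative dimensions:

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, z_dim, hidden, out_dim = 4, 2, 16, 4   # illustrative sizes

# theta: parameters of a small two-layer network
theta = (rng.normal(scale=0.1, size=(hidden, in_dim + z_dim)),
         np.zeros(hidden),
         rng.normal(scale=0.1, size=(out_dim, hidden)),
         np.zeros(out_dim))

def f(theta, x, z):
    """f_theta(x, z): the input and the latent task embedding enter together,
    so the same loss can be minimized either over theta or over z."""
    W1, b1, W2, b2 = theta
    h = np.tanh(W1 @ np.concatenate([x, z]) + b1)
    return W2 @ h + b2

y_hat = f(theta, np.ones(in_dim), np.zeros(z_dim))
```

Any architecture with this signature supports the alternation between weight updates and latent updates.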

Figure 1: Model schematic. Alternating between controlled triggering of latent updates and weight updates allows for unsupervised task discovery. Latent updates retrieve previously learned tasks or choose a new embedding for new ones (PFC: prefrontal cortex, MD: mediodorsal thalamus).

