MASKVIT: MASKED VISUAL PRE-TRAINING FOR VIDEO PREDICTION

Abstract

The ability to predict future visual observations conditioned on past observations and motor commands can enable embodied agents to plan solutions to a variety of tasks in complex environments. This work shows that we can create good video prediction models by pre-training transformers via masked visual modeling. Our approach, named MaskViT, is based on two simple design decisions. First, for memory and training efficiency, we use two types of window attention: spatial and spatiotemporal. Second, during training, we mask a variable percentage of tokens instead of a fixed mask ratio. For inference, MaskViT generates all tokens via iterative refinement where we incrementally decrease the masking ratio following a mask scheduling function. On several datasets we demonstrate that MaskViT outperforms prior works in video prediction, is parameter efficient, generates high-resolution videos (256 × 256) and can be easily adapted to perform goal-conditioned video prediction. Further, we demonstrate the benefits of inference speedup (up to 512×) due to iterative decoding by using MaskViT for planning on a real robot. Our work suggests that we can endow embodied agents with powerful predictive models by leveraging the general framework of masked visual modeling with minimal domain knowledge.

1. INTRODUCTION

Evidence from neuroscience suggests that human cognitive and perceptual capabilities are supported by a predictive mechanism that anticipates future events and sensory signals (Tanji & Evarts, 1976; Wolpert et al., 1995). Such a mental model of the world can be used to simulate, evaluate, and select among different possible actions. This process is fast and accurate, even under the computational limitations of biological brains (Wu et al., 2016). Endowing robots with similar predictive capabilities would allow them to plan solutions to multiple tasks in complex and dynamic environments, e.g., via visual model-predictive control (Finn & Levine, 2017; Ebert et al., 2018). Predicting visual observations for embodied agents is, however, challenging and computationally demanding: the model needs to capture the complexity and inherent stochasticity of future events while maintaining an inference speed that supports the robot's actions. Consequently, recent advances in autoregressive generative models, which leverage Transformers (Vaswani et al., 2017) for building neural architectures and learn good representations via self-supervised generative pre-training (Devlin et al., 2019), have not yet translated into comparable benefits for video prediction or robotic applications. We identify three technical challenges in particular. First, memory requirements for the full attention mechanism in Transformers scale quadratically with the length of the input sequence, leading to prohibitively large costs for videos. Second, there is an inconsistency between the video prediction task and autoregressive masked visual pre-training: while training assumes partial knowledge of the ground-truth future frames, at test time the model has to predict a complete sequence of future frames from scratch, leading to poor video prediction quality (Yan et al., 2021; Feichtenhofer et al., 2022). Third, the common autoregressive paradigm, effective in other domains, would be too slow for robotic applications.
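To make the quadratic scaling of full attention concrete, here is a back-of-the-envelope sketch. The 16-frame, 256-tokens-per-frame figures match the tokenization described later in this paper; the comparison against per-frame (spatial-window) attention is illustrative only.

```python
# Back-of-the-envelope attention-memory comparison (illustrative sketch).
# A full self-attention matrix over L tokens has L * L entries per head.

def attention_entries(num_tokens: int) -> int:
    """Number of entries in a single full self-attention matrix."""
    return num_tokens * num_tokens

frames, tokens_per_frame = 16, 256
video_tokens = frames * tokens_per_frame                   # 4096 tokens per clip
full = attention_entries(video_tokens)                     # ~16.8M entries per head
per_frame = frames * attention_entries(tokens_per_frame)   # per-frame windows only

print(full, per_frame, full // per_frame)  # full attention is 16x larger here
```

Restricting attention to windows of fixed token count makes memory grow linearly, rather than quadratically, with video length.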
To address these challenges, we present Masked Video Transformers (MaskViT): a simple, effective and scalable method for video prediction based on masked visual modeling. Since using pixels directly as frame tokens would require an inordinate amount of memory, we use a discrete variational autoencoder (dVAE) (Van Den Oord & Vinyals, 2017; Esser et al., 2021) that compresses frames into a smaller grid of visual tokens. We opt for compression in the spatial (image) domain instead of the spatiotemporal (video) domain, as preserving the correspondence between each original and tokenized video frame allows for flexible conditioning on any subset of frames: initial (past), final (goal), and possibly equally spaced intermediate frames. However, despite operating on tokens, representing 16 frames at 256 tokens per frame still requires 4,096 tokens, incurring prohibitive memory requirements for full attention. Hence, to further reduce memory, MaskViT is composed of alternating transformer layers with non-overlapping window-restricted (Vaswani et al., 2017) spatial and spatiotemporal attention.

To reduce the inconsistency between masked pre-training and the video prediction task, and to speed up inference, we take inspiration from non-autoregressive, iterative decoding methods in generative algorithms from other domains (Sohl-Dickstein et al., 2015; Ho et al., 2020; Nichol & Dhariwal, 2021; Ghazvininejad et al., 2019; Chang et al., 2022). We propose a novel iterative decoding scheme for videos based on a mask scheduling function that specifies, during inference, the number of tokens to be decoded and kept at each iteration. In contrast to autoregressive decoding, which predicts tokens one by one, our iterative decoding scheme is faster because the number of decoding iterations is significantly smaller than the number of tokens.
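A rough sketch of the two window-restricted attention patterns: the token grid is partitioned into non-overlapping windows, and attention is computed independently within each window. The grid size (16 frames of 16 × 16 tokens) follows the numbers above; the specific window shapes (one full frame for spatial attention, a full-length temporal tube over a 4 × 4 patch for spatiotemporal attention) are assumptions for illustration.

```python
import numpy as np

def make_windows(tokens: np.ndarray, wt: int, wh: int, ww: int) -> np.ndarray:
    """Split a (T, H, W, C) token grid into non-overlapping (wt, wh, ww) windows.

    Returns shape (num_windows, wt * wh * ww, C); self-attention would then be
    computed independently inside each window.
    """
    T, H, W, C = tokens.shape
    x = tokens.reshape(T // wt, wt, H // wh, wh, W // ww, ww, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # group the three window axes together
    return x.reshape(-1, wt * wh * ww, C)

tokens = np.zeros((16, 16, 16, 64))                  # (T, H, W, embed_dim)
spatial = make_windows(tokens, wt=1, wh=16, ww=16)   # one window per frame
spatiotemporal = make_windows(tokens, wt=16, wh=4, ww=4)  # tube across all frames
print(spatial.shape, spatiotemporal.shape)  # (16, 256, 64) (16, 256, 64)
```

With these (assumed) shapes, both patterns attend over 256-token windows, so each layer pays the same per-window attention cost while covering complementary axes of the video.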
A few initial tokens are predicted over multiple initial iterations, and then the majority of the remaining tokens can be predicted rapidly over the final few iterations. This brings us closer to the ultimate video prediction task, where only the first frame is known and all tokens for other frames must be inferred. To further close the train-test gap, during training we mask a variable percentage of tokens instead of using a fixed masking ratio. This simulates the different masking ratios MaskViT will encounter during iterative decoding in the actual video prediction task. Through experiments on several publicly available real-world video prediction datasets (Ebert et al., 2017; Geiger et al., 2013; Dasari et al., 2019), we demonstrate that MaskViT achieves competitive or state-of-the-art results on a variety of metrics. Moreover, MaskViT can predict considerably higher-resolution videos (256 × 256) than previous methods. We also show the flexibility of MaskViT by adapting it to predict goal-conditioned video frames. In addition, thanks to iterative decoding, MaskViT is up to 512× faster than autoregressive methods, enabling its application for planning on a real robot (§ 4.5). These results indicate that we can endow embodied agents with powerful predictive models by leveraging the general framework of masked visual modeling with minimal domain knowledge.
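The iterative refinement loop described above can be sketched as follows. This is a minimal illustration, not the paper's implementation: the cosine shape of the mask scheduling function and the token counts are assumptions, and the model call that actually fills in tokens is omitted. The sketch only tracks how many tokens are newly committed at each iteration.

```python
import math

def cosine_schedule(step: int, total_steps: int) -> float:
    """Fraction of tokens still masked after `step` of `total_steps` iterations
    (an assumed schedule shape; the paper only requires that it decreases)."""
    return math.cos(0.5 * math.pi * step / total_steps)

def tokens_committed_per_iteration(num_tokens: int, total_steps: int) -> list:
    """Number of newly decoded-and-kept tokens at each refinement iteration."""
    committed = 0
    per_step = []
    for step in range(1, total_steps + 1):
        # In the real model, ALL masked tokens are re-predicted each iteration;
        # only the most confident ones are kept, per the schedule below.
        still_masked = int(num_tokens * cosine_schedule(step, total_steps))
        per_step.append(num_tokens - still_masked - committed)
        committed = num_tokens - still_masked
    return per_step

steps = tokens_committed_per_iteration(num_tokens=4096, total_steps=8)
print(steps, sum(steps))  # few tokens early, most in the final iterations
```

Note the contrast with autoregressive decoding: here all 4,096 tokens are produced in 8 forward passes rather than 4,096, which is the source of the inference speedup.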




Figure 1: MaskViT. (a) Training: We encode the video frames into latent codes via VQ-GAN. A variable number of tokens in future frames are masked, and the network is trained to predict masked tokens. A block in MaskViT consists of two layers with window-restricted attention: spatial and spatiotemporal. (b) Inference: Videos are generated via iterative refinement where we incrementally decrease the masking ratio following a mask scheduling function. Videos available at this project page.

