Scaling Offline Model-Based RL via Jointly-Optimized World-Action Model Pretraining

Jie Cheng1,2, Ruixi Qiao1, Gang Xiong1,2, Qinghai Miao2, Yingwei Ma3, Binhua Li3, Yongbin Li3*, Yisheng Lv1,2*
1State Key Laboratory of Multimodal Artificial Intelligence Systems, CASIA
2Artificial Intelligence, University of Chinese Academy of Sciences
3Alibaba Group

*Corresponding Authors

TL;DR: A single JOWA-150M agent masters 15 Atari games at 84.7% human-level
and 119.5% DQN-level, and can adapt to novel games with ~4 expert demos.

Abstract

A significant aspiration of offline reinforcement learning (RL) is to develop a generalist agent with high capabilities from large and heterogeneous datasets. However, prior approaches that scale offline RL either rely heavily on expert trajectories or struggle to generalize to diverse unseen tasks. Inspired by the excellent generalization of world models in conditional video generation, we explore the potential of image-observation-based world models for scaling offline RL and enhancing generalization on novel tasks. In this paper, we introduce JOWA: Jointly-Optimized World-Action model, an offline model-based RL agent pretrained on multiple Atari games with 6 billion tokens of data to learn general-purpose representations and decision-making ability. Our method jointly optimizes a world-action model through a shared transformer backbone, which stabilizes temporal difference learning with large models during pretraining. Moreover, we propose a provably efficient and parallelizable planning algorithm to compensate for the Q-value estimation error and thus search out better policies. Experimental results indicate that our largest agent, with 150 million parameters, achieves 78.9% human-level performance on pretrained games using only 10% subsampled offline data, outperforming existing state-of-the-art large-scale offline RL baselines by 31.6% on average. Furthermore, JOWA scales favorably with model capacity and can sample-efficiently transfer to novel games using only 5k offline fine-tuning samples (approximately 4 trajectories) per game, demonstrating superior generalization.

Method Overview

Architecture of JOWA

Architecture of JOWA. We use a shared transformer backbone for both world modeling and Q-value criticism to enable joint optimization. A VQ-VAE tokenizes images into visual tokens. The sum of vocabulary embeddings, position embeddings, and task embeddings forms the input embedding for the transformer backbone. The training loss is the weighted sum of the supervised prediction loss of the world-part module and the conservative distributional TD-loss of the action-part module.
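For intuition, the following is a minimal PyTorch-style sketch of this joint objective, assuming a shared backbone that feeds a world-model head and a distributional Q-head. Module names, hyperparameters, and the simplified Q-head (a single return distribution rather than per-action distributions) are illustrative assumptions, not the released JOWA implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class WorldActionModel(nn.Module):
    def __init__(self, vocab_size=512, n_tasks=15, d_model=256, n_atoms=51):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)        # VQ-VAE visual-token vocabulary
        self.pos_emb = nn.Embedding(1024, d_model)               # position embeddings
        self.task_emb = nn.Embedding(n_tasks, d_model)           # one embedding per game
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, num_layers=4)   # shared backbone
        self.world_head = nn.Linear(d_model, vocab_size)          # next visual-token prediction
        self.q_head = nn.Linear(d_model, n_atoms)                 # distributional Q (per-action heads omitted)

    def forward(self, tokens, task_id):
        pos = torch.arange(tokens.size(1), device=tokens.device)
        x = self.tok_emb(tokens) + self.pos_emb(pos) + self.task_emb(task_id)[:, None, :]
        h = self.backbone(x)
        return self.world_head(h), self.q_head(h)

def joint_loss(model, tokens, task_id, next_tokens, td_target_dist, lambda_world=1.0):
    # Weighted sum of the world-part supervised loss and a distributional TD loss,
    # both back-propagated through the shared transformer backbone.
    world_logits, q_logits = model(tokens, task_id)
    loss_world = F.cross_entropy(world_logits.flatten(0, 1), next_tokens.flatten())
    # Cross-entropy to a projected target distribution; the CQL-style conservative
    # penalty used by JOWA would be added on top of this term.
    log_q = F.log_softmax(q_logits[:, -1], dim=-1)
    loss_td = -(td_target_dist * log_q).sum(-1).mean()
    return lambda_world * loss_world + loss_td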

To compensate for the optimal Q-value estimation error and thus search out better policies, we propose a provably efficient and parallelizable planning algorithm and derive the condition under which the search-based optimal Q-values have a tighter error upper bound than the TD-learning-based optimal Q-values. Planning aids action selection during evaluation and enables sample-efficient transfer to novel games.

We first model the process of finding optimal actions within the imagined Markov Decision Process as a tree search problem, and then extend beam search into a practical and parallelizable planning algorithm. When both the beam width K and the horizon H equal 2, the planning process is illustrated in the figure below:

Planning algorithm
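For intuition, below is a minimal sketch of such a beam-search planner over the imagined MDP. The imagine_step and q_values interfaces are assumed placeholders for the world model and Q-head, and this serial version omits the batched, parallel beam expansion used in practice.

import torch

@torch.no_grad()
def beam_plan(imagine_step, q_values, state, K=2, H=2, gamma=0.99):
    # Returns the first action of the best length-H action sequence.
    #   imagine_step(state, action) -> (next_state, reward)   # world-model rollout
    #   q_values(state)             -> tensor of shape (n_actions,)
    beams = [(0.0, state, None)]   # (discounted return so far, imagined state, first action)
    for h in range(H):
        candidates = []
        for ret, s, first_a in beams:
            q = q_values(s)
            # Expand the K most promising actions under the current Q-estimates.
            for a in torch.topk(q, k=min(K, q.numel())).indices.tolist():
                s2, r = imagine_step(s, a)
                candidates.append((ret + (gamma ** h) * r, s2, a if first_a is None else first_a))
        # Keep the K best partial trajectories, scored by return-so-far plus the
        # bootstrapped value of the imagined next state.
        candidates.sort(key=lambda c: c[0] + (gamma ** (h + 1)) * q_values(c[1]).max().item(),
                        reverse=True)
        beams = candidates[:K]
    return beams[0][2]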

Experiments

Scaling trends

scaling trend

We investigate each algorithm's ability to leverage higher-capacity architectures. The performance of JOWA with planning reliably increases as model size grows and exhibits the steepest scaling curve among all algorithms. Moreover, the proposed planning algorithm improves performance by a large margin, highlighting the great scalability potential of offline model-based RL.

Few-shot fine-tuning

fine tune

We fine-tune pretrained agents on 5 held-out games, using 5k expert-level transitions per game, uniformly subsampled from the last 20% of DQN-Replay, as the benchmark. This tiny amount of data corresponds to approximately 4 trajectories of expert-level DQN-Replay per fine-tuned game on average, which resembles a few-shot learning setting and is extremely challenging.
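As a concrete illustration of this data setup, the snippet below sketches how one might uniformly subsample 5k expert-level transitions per game from the last 20% of a replay buffer; the function name and loading interface are assumptions rather than the released data pipeline.

import numpy as np

def subsample_expert_transitions(replay_transitions, n_samples=5000, expert_frac=0.2, seed=0):
    # replay_transitions: transitions of one game, ordered by DQN training time.
    rng = np.random.default_rng(seed)
    n = len(replay_transitions)
    expert_part = replay_transitions[int((1 - expert_frac) * n):]   # last 20% ~ expert-level
    idx = rng.choice(len(expert_part), size=n_samples, replace=False)
    return [expert_part[i] for i in idx]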

The fine-tuned JOWA-150M attains a 64.7% IQM DQN-normalized score across the 5 held-out games, outperforming baselines by 34.7% on average. These results underscore JOWA's capacity for rapid and sample-efficient transfer to novel games, highlighting the efficacy of its learned general-purpose representations and decision-making capabilities.

Ablation of training choices

ablation

We conduct a series of controlled ablation studies to evaluate the impact of key design choices in JOWA, aiming to offer valuable insights for future research in this domain. To save compute, we consider a subset of 6 games in these experiments. We train all models for 1M gradient steps and fix the model size at 40M parameters. The conclusions are summarized as follows:

(i) Joint optimization and distributional TD-loss are crucial for JOWA;
(ii) CQL regularization, planning with the world model, and task embeddings significantly improve performance;
(iii) We have not observed significant improvements from multiple Q-heads over a single Q-head, for either equal-weight or random-weight summation (see the sketch after this list);
(iv) Using synthetic data generated by the world model in the pre-training phase hurts performance.
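To make (iii) concrete, the sketch below shows one way to aggregate multiple Q-heads by equal-weight or random-weight summation; the module and shapes are illustrative assumptions, not the exact ablated architecture.

import torch
import torch.nn as nn

class MultiQHead(nn.Module):
    def __init__(self, d_model, n_actions, n_heads=4, random_weights=False):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(d_model, n_actions) for _ in range(n_heads)])
        self.random_weights = random_weights

    def forward(self, h):
        qs = torch.stack([head(h) for head in self.heads], dim=0)   # (n_heads, ..., n_actions)
        if self.random_weights:
            w = torch.rand(len(self.heads), device=h.device)
            w = w / w.sum()                                          # random convex weights
        else:
            w = torch.full((len(self.heads),), 1.0 / len(self.heads), device=h.device)
        w = w.view(-1, *([1] * (qs.dim() - 1)))
        return (w * qs).sum(dim=0)                                   # weighted summation over heads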

BibTeX

@article{cheng2024scaling,
    title={Scaling Offline Model-Based RL via Jointly-Optimized World-Action Model Pretraining},
    author={Cheng, Jie and Qiao, Ruixi and Xiong, Gang and Miao, Qinghai and Ma, Yingwei and Li, Binhua and Li, Yongbin and Lv, Yisheng},
    journal={arXiv preprint arXiv:2410.00564},
    year={2024}
}