
Dreaming of Many Worlds: Learning Contextual World Models Aids Zero-Shot Generalization

Authors: Sai Prasanna Raman, Karim Farid, Raghu Rajan, André Biedenkapp

Zero-shot generalization (ZSG) to unseen dynamics is a major challenge for creating generally capable embodied agents. To address the broader challenge, we start with the simpler setting of contextual reinforcement learning (cRL), assuming observability of the context values that parameterize the variation in the system’s dynamics, such as the mass of a robot. To this end, we propose the contextual recurrent state-space model (cRSSM), which augments the world model of Dreamer (v3) to incorporate context for inferring latent Markovian states from the observations and modeling the latent dynamics. Our experiments show that such principled incorporation of the context improves the ZSG of the policies trained on the “dreams” of the world model. We further find qualitatively that our approach allows Dreamer to disentangle the latent state from context, allowing it to extrapolate (factually and counterfactually) its dreams to the many worlds of unseen contexts.


Contextual World Models and the Promise of ZSG

Contextual world models provide a framework that augments world models to cover problems with diverse dynamics, rewards, and tasks. If incorporated properly into model-based reinforcement learning (MbRL), they could offer a promising avenue for zero-shot generalization (ZSG) without the need to adapt weights. The context refers to a set of parameters of a (partially observable) Markov decision process ((PO)MDP) that remains constant during an episode but can vary across episodes, thus influencing both the dynamics and rewards. For instance, factors such as the height of a robot, the mass of a carried load, or the strength of an actuator can serve as contextual parameters. This work focuses on (PO)MDPs with observed context as an initial step toward studying the ZSG problem in MbRL. In particular, we explore how the state-of-the-art MbRL agent, Dreamer (v3), generalizes when observing a novel context it did not encounter during training.
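To make the notion of context concrete, here is a minimal, purely illustrative sketch of a contextual environment in which parameters such as gravity and pole length are drawn once per episode and then held fixed; the class, attribute names, and ranges are assumptions for illustration, not a real benchmark API.

```python
# Purely illustrative sketch of a contextual (PO)MDP: the context (here gravity
# and pole length) is sampled once per episode and held fixed while it shapes
# the dynamics. Class and attribute names are assumptions, not a real benchmark API.
import random


class ContextualCartPole:
    def __init__(self, sample_context):
        self.sample_context = sample_context  # callable returning a context dict
        self.context = None

    def reset(self):
        # The context changes only across episodes, never within one.
        self.context = self.sample_context()
        return self._observe()

    def step(self, action):
        # The transition depends on the state, the action, *and* the context.
        gravity, pole_length = self.context["gravity"], self.context["pole_length"]
        ...  # integrate cart-pole dynamics using these parameters
        return self._observe(), 0.0, False, {}

    def _observe(self):
        # "Observed context" setting: the agent sees the context alongside the observation.
        return {"obs": [0.0, 0.0, 0.0, 0.0], "context": self.context}


env = ContextualCartPole(lambda: {"gravity": random.uniform(5.0, 15.0),
                                  "pole_length": random.uniform(0.4, 0.8)})
```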

ZSG Evaluation Scheme

To systematically study the ZSG performance of the models we evaluate, we assess their ability to generalize in three settings (a small sketch of how such context splits could be constructed follows the list):

  1. Interpolation (I): Evaluation contexts are selected fully within the training range.
  2. Inter+Extrapolation (I+E): Evaluation contexts are selected to be within the training distribution for one context dimension and out-of-distribution (OOD) for another. This evaluation setting only applies to agents trained in the dual context variation setting.
  3. Extrapolation (E): Evaluation contexts are fully OOD, as they are selected outside the training context set limits.
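A minimal sketch, assuming a two-dimensional context (e.g., gravity and pole length) and illustrative training limits, of how the three evaluation splits could be constructed; the ranges and helper names are not the paper's exact protocol.

```python
# Illustrative construction of interpolation (I), inter+extrapolation (I+E),
# and extrapolation (E) evaluation contexts relative to a training range.
import numpy as np

rng = np.random.default_rng(0)
train_low, train_high = np.array([5.0, 0.4]), np.array([15.0, 0.8])  # per-dimension training limits


def sample_contexts(mode, n=10):
    if mode == "I":    # both context dimensions inside the training range
        return rng.uniform(train_low, train_high, size=(n, 2))
    if mode == "I+E":  # one dimension inside, the other out-of-distribution
        inside = rng.uniform(train_low[0], train_high[0], size=n)
        outside = rng.uniform(train_high[1], 2 * train_high[1], size=n)
        return np.stack([inside, outside], axis=1)
    if mode == "E":    # both dimensions outside the training range
        return rng.uniform(train_high, 2 * train_high, size=(n, 2))
    raise ValueError(mode)


for mode in ("I", "I+E", "E"):
    print(mode, sample_contexts(mode, n=3))
```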

Contextual RSSM

We employ a novel contextual recurrent state space model (cRSSM), building on Dreamer’s RSSM world model, and systematically incorporate context based on the contextual POMDP formulation. To showcase the zero-shot generalization (ZSG) capabilities of our approach, we compare it against two naive methods of context integration: Hidden Context, where the model is trained on multiple contexts but without access to them as observable information, and Concat Context, where the context is provided as an additional observation. To evaluate the impact of training on a context distribution versus a single context for Dreamer’s ZSG performance, we include a Default agent trained solely on a fixed ‘default’ context for comparison.
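The following is a minimal sketch, not the authors' code, of how a recurrent state-space model could condition both its prior (imagination) and its posterior (inference) on an observed context, in contrast to Concat Context, which simply appends the context to the observation. The class name, layer sizes, and Gaussian latent parameterization are assumptions for illustration.

```python
# Minimal sketch (not the authors' code) of a context-conditioned RSSM step.
# `ContextualRSSM`, the layer sizes, and the Gaussian latents are illustrative.
import torch
import torch.nn as nn


class ContextualRSSM(nn.Module):
    def __init__(self, obs_dim, action_dim, ctx_dim, hidden_dim=64, state_dim=16):
        super().__init__()
        # Deterministic recurrent core, conditioned on state, action, and context.
        self.gru = nn.GRUCell(state_dim + action_dim + ctx_dim, hidden_dim)
        # Prior ("dreaming"): predict the next latent state without observations.
        self.prior_net = nn.Linear(hidden_dim + ctx_dim, 2 * state_dim)
        # Posterior (inference): additionally condition on the current observation.
        self.post_net = nn.Linear(hidden_dim + obs_dim + ctx_dim, 2 * state_dim)

    @staticmethod
    def _dist(stats):
        mean, log_std = stats.chunk(2, dim=-1)
        return torch.distributions.Normal(mean, log_std.exp())

    def step(self, state, hidden, action, context, obs=None):
        hidden = self.gru(torch.cat([state, action, context], dim=-1), hidden)
        prior = self._dist(self.prior_net(torch.cat([hidden, context], dim=-1)))
        if obs is None:  # imagination: roll out latent dynamics under a given context
            return prior.rsample(), hidden, prior, None
        post = self._dist(self.post_net(torch.cat([hidden, obs, context], dim=-1)))
        return post.rsample(), hidden, prior, post


# Usage: one imagination ("dreaming") step under a chosen (e.g., counterfactual) context.
model = ContextualRSSM(obs_dim=8, action_dim=2, ctx_dim=2)
state, hidden = torch.zeros(1, 16), torch.zeros(1, 64)
action, context = torch.zeros(1, 2), torch.tensor([[9.8, 0.5]])
state, hidden, prior, _ = model.step(state, hidden, action, context)
```

Feeding the context separately to the transition and observation pathways, rather than folding it into the observation as Concat Context does, is what allows the latent state to remain disentangled from the context during both inference and dreaming.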

We evaluate our approach on two environments with different context choices: Cartpole (discrete control), where the context varies through changes in gravity and pole length, and Walker Walk (continuous control), with varying gravity and actuator strength. In both environments, both individually per setting and in aggregate, we demonstrate that our method outperforms naive context incorporation methods and remains robust in terms of ZSG.

Beyond task performance, we visually investigate our method’s ability to semantically understand the context and ground it in the image space, in comparison to the Concat Context method, when asked to extrapolate, factually or counterfactually, to contexts beyond those observed in training. We show that our method can disentangle the state from the context and is more faithful, in both inference and imagination, to the observed (factual or counterfactual) context as a source of ground-truth information.

Conclusion

Our findings show that explicitly conditioning on context significantly improves zero-shot generalization (ZSG), as demonstrated by the superior performance of our cRSSM over naive methods, such as adding context as an observation. This advantage likely stems from the disentanglement of latent states from the context, enabling more robust and generalizable inference. In summary, our approach offers a principled way to leverage contextual information in reinforcement learning tasks for better ZSG.

Check the full paper: https://rlj.cs.umass.edu/2024/papers/RLJ_RLC_2024_167.pdf
