Recurrent World Models
Notes on the paper "Recurrent World Models Facilitate Policy Evolution" [pdf] [website]
They solve an RL problem by first training a world model on data collected by a random agent, and then training an RL agent inside the virtual environment generated by that world model. The world model is based on an RNN, hence the name "Recurrent" World Models.
1. Architecture
This model has the following components:
- VAE Encoder \(V\): takes an observation as input and gives the latent vector \(z_t\)
- Controller \(C\): takes \(z_t\) and \(h_t\) and gives the action to take, \(a_t\)
- Predictive Model \(M\) (MDN-RNN): takes the latent vector \(z_t\), the current hidden state \(h_t\) and the current action \(a_t\) as input, and gives the new hidden state \(h_{t+1} = f(h_t, z_t, a_t)\) and a prediction of the next latent vector \(P(z_{t+1} | z_{t}, h_{t}, a_t)\)
During Inference:
- At each step the agent receives an observation from the environment
- The Vision model (\(V\)) encodes it into a low-dimensional latent vector \(z_t\)
- The Controller (\(C\)) uses \(z_t\) and \(h_t\) to choose the action \(a_t\)
- The Memory RNN (\(M\)) uses \(z_t\), \(a_t\) and \(h_t\) to compute the new hidden state \(h_{t+1}\) (a minimal sketch of this loop follows Figure 1)
Figure 1: Model Use during gameplay/inference
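To make the loop concrete, here is a minimal sketch of inference in plain Python/NumPy. Everything in it is an illustrative stand-in (random stub weights, a fake encoder, a toy RNN cell); the paper's actual components are a convolutional VAE, an LSTM-based MDN-RNN and a linear controller.

```python
# Minimal sketch of the V/M/C loop at inference time (stub components only).
import numpy as np

rng = np.random.default_rng(0)
Z_DIM, H_DIM, A_DIM = 32, 256, 3          # latent, hidden and action sizes (CarRacing setup)

# Stub parameters standing in for the trained networks.
W_c = rng.normal(0, 0.1, size=(A_DIM, Z_DIM + H_DIM))          # controller: a = W_c [z; h] + b_c
b_c = np.zeros(A_DIM)
W_h = rng.normal(0, 0.1, size=(H_DIM, Z_DIM + A_DIM + H_DIM))  # toy recurrent cell for M

def encode(obs):                           # V: observation -> z_t (stub encoder)
    return rng.normal(size=Z_DIM)

def controller(z, h):                      # C: (z_t, h_t) -> a_t
    return np.tanh(W_c @ np.concatenate([z, h]) + b_c)

def rnn_step(z, a, h):                     # M: (z_t, a_t, h_t) -> h_{t+1}
    return np.tanh(W_h @ np.concatenate([z, a, h]))

h = np.zeros(H_DIM)
for t in range(5):
    obs = None                             # would come from env.step(...) in a real rollout
    z = encode(obs)
    a = controller(z, h)
    h = rnn_step(z, a, h)                  # during inference M only updates memory; no sampling of z needed
    print(t, a)
```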
During Training:
- First \(V\) is trained to reconstruct the observation.
- Then \(M\) is trained to model the distribution of the next latent vector \(z_{t+1}\)
- Finally \(C\) is trained to maximize the expected cumulative reward of a rollout:
  - We generate \(z_t\) using the predictive model \(M\) and train the controller on these generated inputs.
  - The controller's action is fed back into the model \(M\) to generate the next state \(z_{t+1}\) and \(h_{t+1}\) (a rough sketch of this dream rollout appears below)
Figure 2: RNN with MDN (Mixture Density Network) output to sample prediction of next latent vector \(z\)
To know when the episode ends, the model \(M\) also predicts the done state in addition to the next observation. The model would also need to predict the reward, but for the problem in the paper the reward is simply 1 for each step survived.
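Below is a rough sketch of such a dream rollout. The `m_step` and `controller` functions are toy stand-ins, not the paper's implementation; the point is only the data flow: after the first latent vector, every "observation" is sampled from \(M\), and the reward is +1 per surviving step.

```python
# Sketch of a "dream" rollout used to train C: after z_0, every observation
# comes from M's own prediction, never from the real environment.
import numpy as np

rng = np.random.default_rng(1)
Z_DIM, H_DIM, A_DIM = 32, 256, 3

def controller(z, h, params):              # linear C; parameters come from the ES optimizer
    W = params.reshape(A_DIM, Z_DIM + H_DIM + 1)
    return np.tanh(W[:, :-1] @ np.concatenate([z, h]) + W[:, -1])

def m_step(z, a, h):
    """Stand-in for the MDN-RNN: returns (sampled z_{t+1}, h_{t+1}, done)."""
    h_next = np.tanh(rng.normal(0, 0.05, H_DIM) + 0.9 * h)   # fake recurrence
    z_next = z + rng.normal(0, 0.1, Z_DIM)                   # fake sample from P(z_{t+1} | z_t, h_t, a_t)
    done = rng.random() < 0.01                               # fake "done" head
    return z_next, h_next, done

def dream_rollout(params, max_steps=2100):
    z, h, total_reward = rng.normal(size=Z_DIM), np.zeros(H_DIM), 0.0
    for _ in range(max_steps):
        a = controller(z, h, params)
        z, h, done = m_step(z, a, h)       # the action is fed back into M
        total_reward += 1.0                # reward = +1 per step survived
        if done:
            break
    return total_reward

print(dream_rollout(rng.normal(0, 0.1, A_DIM * (Z_DIM + H_DIM + 1))))
```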
2. Learning Algorithm
Supervised learning is used to train the encoder and the predictive model, while the controller is trained with reinforcement learning. The controller is just a linear layer with few parameters, so they solve the RL problem with an evolution strategy, specifically the Covariance Matrix Adaptation Evolution Strategy (CMA-ES).
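A minimal sketch of the CMA-ES outer loop, assuming the third-party pycma package (`import cma`) is installed. The fitness function here is a placeholder for where the mean cumulative rollout reward would go (negated, since pycma minimizes).

```python
# Sketch of evolving the linear controller's parameters with CMA-ES (pycma).
import numpy as np
import cma

Z_DIM, H_DIM, A_DIM = 32, 256, 3
N_PARAMS = A_DIM * (Z_DIM + H_DIM + 1)     # weights + bias of the linear controller

def negative_fitness(params):
    # Placeholder: in the real setup this would be minus the mean cumulative
    # reward of several rollouts (dream or real) using these controller params.
    return float(np.sum(params ** 2))

es = cma.CMAEvolutionStrategy(N_PARAMS * [0.0], 0.5)    # start at 0 with initial sigma 0.5
for generation in range(10):
    candidates = es.ask()                               # sample a population of parameter vectors
    es.tell(candidates, [negative_fitness(c) for c in candidates])
    es.disp()
best_params = es.result.xbest
```

CMA-ES only needs these scalar fitness values, which is why each candidate's rollouts can be evaluated in parallel across workers.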
Connection with the brain:
The brain predicts future sensory data given our current motor actions. There is evidence for this from experiments in mice: when the prediction and the actual sensory data mismatch, neural activity increases [See Predictive Coding]. Similarly, the model \(M\) here predicts the sensory input \(z_{t+1}\) that the agent is going to receive. So what is the hidden state vector \(h\) in this analogy? We also want to track and predict the underlying next world state (which is not necessarily observable), and the hidden state vector \(h\) encodes this information. Question: How can we measure how well it does this?
What we perceive in a moment is influenced by the predictions of the future made by our internal model. For example, while reading we sometimes fill in words even when they are missing, or read a sentence one way even though something different is written there. In this method, \(h\) is our internal state and \(z\) is what we perceive. During inference, \(z\) is not influenced by \(h\). Question: How could we modify this method to include this phenomenon of "perception being guided by expectation" seen in humans, and is it a good idea to do so?
Pitfalls of learning in dreams:
The controller can exploit imperfections of the generated environment and the information in the hidden state. Predicting a distribution \(P(z_{t+1})\) and sampling from it with a temperature parameter \(\tau\), instead of predicting a deterministic state \(z_{t+1} = M(z_t, h_t, a_t)\), mitigates this to some extent. The choice of \(\tau\) is important: if we train \(C\) with a low \(\tau\), \(M\) becomes nearly deterministic and rarely transitions to other modes of the Gaussian mixture, so \(C\) can exploit \(M\) more easily.
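A small sketch of temperature-adjusted sampling from one latent dimension's Gaussian mixture. The exact convention is an assumption on my part (divide the mixture logits by \(\tau\) before the softmax and scale the standard deviations by \(\sqrt{\tau}\)), but it reproduces the behaviour described above: as \(\tau \to 0\) the sample collapses onto the dominant mode.

```python
# Sketch: sample one latent dimension from an MDN output with temperature tau.
import numpy as np

rng = np.random.default_rng(2)

def sample_with_temperature(logits, mu, sigma, tau):
    """logits, mu, sigma: arrays of shape (K,) for K mixture components."""
    scaled = logits / tau
    pi = np.exp(scaled - scaled.max())
    pi /= pi.sum()                          # softmax over mixture components
    k = rng.choice(len(pi), p=pi)           # pick a component
    return rng.normal(mu[k], sigma[k] * np.sqrt(tau))

logits = np.array([2.0, 0.0, -1.0])
mu     = np.array([-1.0, 0.0, 3.0])
sigma  = np.array([0.3, 0.5, 0.2])
for tau in (0.1, 1.0, 1.5):
    print(tau, sample_with_temperature(logits, mu, sigma, tau))
# With a very small tau the first (dominant) mode is chosen almost every time,
# which is exactly the near-deterministic regime where C finds M easiest to exploit.
```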
Question: What else can be done to learn in modelled environment without exploiting the imperfections?
- Online training in the real world after training in the simulated environment? Do humans do this? It seems so. When training in the real world, both \(z\) and \(h\) would be used, but \(z\) would be the encoding of the real observation rather than the model's prediction.
Do we also learn in our world model?:
# After all, our agent does not directly observe the reality, but merely sees what the world model lets it see.
In humans, is it always so? We humans also learn in the environment that the world model inside our brain generates. When we plan, we actively generate predictions, take actions, and update both our plan (the controller) and the model. But do we also learn directly in the real world? Perhaps we learn in both places: inside the world model and in the real world (online training) too.
3. Other details
- A benefit of training the VAE independently is that the encoder can be reused for other tasks. The drawback is that unsupervised training of the VAE also encodes task-irrelevant details. Training \(V\) and \(M\) together would be better.
- Benefits of Evolution Strategy (ES):
- only cumulative reward is required
- is highly parallelizable
- How do you output a probability distribution? In this paper it is a mixture of Gaussians, known as a Mixture Density Network (MDN). The next latent vector \(z_{t+1}\) is then sampled from that distribution (a sketch of this parameterization follows this list).
- # Although the VAE model is just a single diagonal Gaussian distribution, a mixture of Gaussians makes it easier to model the logic behind complicated environments with discrete random states.
- The cumulative reward for the environments in this paper is defined to be the number of time steps the agent manages to stay alive. And the environment runs for a maximum of 2100 time steps.
- The model is not used to generate rollouts for planning during inference; the controller acts directly on the current latent state \(z_t\) and hidden state \(h_t\).
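As a concrete illustration of the MDN head mentioned above, here is a hypothetical parameterization (names and shapes are mine, not the paper's code): the RNN output vector is split into mixture logits, means and log standard deviations, giving one K-component 1-D mixture per latent dimension, and \(M\) is trained by minimizing the negative log-likelihood of the observed next latent vector.

```python
# Sketch of an MDN output head over the next latent vector.
import numpy as np

Z_DIM, K = 32, 5                            # latent size and number of mixture components

def mdn_split(rnn_output):
    """Split a (3 * K * Z_DIM,) vector into mixture logits, means and sigmas,
    each of shape (Z_DIM, K): one K-component 1-D mixture per latent dimension."""
    logits, mu, log_sigma = np.split(rnn_output, 3)
    shape = (Z_DIM, K)
    return logits.reshape(shape), mu.reshape(shape), np.exp(log_sigma.reshape(shape))

rng = np.random.default_rng(3)
logits, mu, sigma = mdn_split(rng.normal(size=3 * K * Z_DIM))

# Negative log-likelihood of an observed next latent z_next under the mixture;
# this is the quantity M would be trained to minimize.
z_next = rng.normal(size=Z_DIM)
pi = np.exp(logits - logits.max(axis=1, keepdims=True))
pi /= pi.sum(axis=1, keepdims=True)                              # per-dimension softmax
comp = np.exp(-0.5 * ((z_next[:, None] - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
nll = -np.log((pi * comp).sum(axis=1) + 1e-8).sum()
print(nll)
```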
4. Model yourself too
Muscle Memory & sense of self:
For this problem, training the world model on observations from a random agent was sufficient, but when the environment becomes more complex we need to train the model and the controller iteratively, hand in hand. In that case it might also be useful for the model to predict the actions of the controller. For example, after the controller learns to walk, this skill can be subsumed into the world model (because the world model has higher capacity), and the controller can then learn the higher-level task of navigation.
In humans, say for playing the piano, initially we need conscious effort, but after a while it becomes second nature: muscle memory takes care of translating notes into finger movements, and the conscious effort can go into figuring out which notes to play. I believe that in humans the sense of self is the part of the world model in the brain that predicts the actions of the human itself.
5. What next
Previous work:
- Learning to Think
- PILCO (uses a Gaussian process to model system dynamics, but this only worked for low-dimensional observations)
Follow ups:
- Memory module
- Generative distributed memory
- Generative Temporal Models with Memory
- ES for RL with high dimensional input:
- Autoencoder-augmented Neuroevolution for Visual Doom Playing
- A Neuroevolution Approach to General Atari Game Playing
- Evolving Large-Scale Neural Networks for Vision-Based Reinforcement Learning
- Hierarchical planning, environment generation
- See NVIDIA Cosmos-1 [YouTube]
- Behavioural Replay: the behaviour of a teacher net is compressed into a student net to avoid forgetting old prediction and control skills when learning new ones
- PowerPlay: Training an Increasingly General Problem Solver by Continually Searching for the Simplest Still Unsolvable Problem [2011]
- First Experiments with PowerPlay