Hierarchical Planning with Latent World Models
Notes on the paper: Hierarchical Planning with Latent World Models [pdf][arXiv]
Co-written by me and NotebookLM.
Problem:
Model predictive control (MPC) with learned world models struggles with long-horizon control because of:
- Accumulation of prediction errors over long rollouts
- Exponential growth of the search space with the planning horizon
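To put rough numbers on the second point, here is a back-of-the-envelope count. It assumes, purely for illustration, a discretized action set (real planners sample continuous actions rather than enumerate them) and hypothetical sizes: 10 action choices per step, a 40-step horizon, 4 subgoals and 10 macro-action choices. The function names are made up for this sketch.

```python
def flat_search_size(n_actions: int, horizon: int) -> int:
    """Distinct primitive-action sequences a flat planner would have to consider."""
    return n_actions ** horizon


def hierarchical_search_size(n_actions: int, n_macro: int, horizon: int, n_subgoals: int) -> int:
    """Combined size of the two smaller searches done by a two-level planner.

    The high-level planner picks n_subgoals macro-actions (n_macro choices each); the
    low-level planner only plans the horizon // n_subgoals primitive steps to the first subgoal.
    """
    high_level = n_macro ** n_subgoals
    low_level = n_actions ** (horizon // n_subgoals)
    return high_level + low_level


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    flat = flat_search_size(n_actions=10, horizon=40)
    hier = hierarchical_search_size(n_actions=10, n_macro=10, horizon=40, n_subgoals=4)
    print(f"flat: {flat:.1e}  hierarchical: {hier:.1e}")
```

With these numbers the flat planner faces 10^40 sequences, while the two searches together cover 10^4 + 10^10. The hierarchy still searches exponential spaces, just with much smaller exponents, because each level only looks a few decisions ahead.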
Solution:
Perform hierarchical planning, i.e. planning across multiple time scales. This requires:
- A low-level world model that takes the current state and a primitive action, and predicts the state at the next timestep
- A high-level world model that takes the current state and a macro-action, and predicts the state multiple timesteps into the future
- A high-level planner that uses the high-level world model to generate a sequence of subgoals leading to the final goal
- A low-level planner that uses the low-level world model to generate a sequence of primitive actions that reach the immediate subgoal
Representations:
- Both the high- and low-level world models operate in the same latent space (a compressed representation of the input state), so a subgoal from the high-level planner can be fed directly into the low-level planner.
- Planning with the high-level world model happens in an abstract "macro-action" space, a compressed representation of longer sequences of primitive actions.
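The two-level loop described above can be sketched as follows. This is a toy under stated assumptions, not the paper's implementation: the latent space is a 4-D vector, `low_level_model` and `high_level_model` are hypothetical linear stand-ins for the learned models, one macro-action is assumed to span 5 primitive steps, and both planners use random shooting (sample candidates, keep the best) as a simple substitute for whatever sampling-based optimizer the method actually uses. Costs are L1 distances in the shared latent space, which is why the high-level subgoal can be passed straight to the low-level planner as its target.

```python
import numpy as np

# Toy sizes (hypothetical; real latents are learned, high-dimensional features).
LATENT_DIM = ACTION_DIM = MACRO_DIM = 4
STEPS_PER_MACRO = 5  # hypothetical: one macro-action covers ~5 primitive steps


def low_level_model(z, a):
    """Toy low-level world model: latent after one primitive action."""
    return z + 0.1 * a


def high_level_model(z, m):
    """Toy high-level world model: latent several steps ahead after one macro-action."""
    return z + 0.1 * STEPS_PER_MACRO * m


def plan_subgoals(rng, z0, z_goal, n_subgoals=3, n_samples=256):
    """High-level planner (random shooting): sample macro-action sequences, roll them
    out with the high-level model, keep the rollout that ends closest (L1) to the goal."""
    macros = rng.uniform(-1.0, 1.0, size=(n_samples, n_subgoals, MACRO_DIM))
    z = np.tile(z0, (n_samples, 1))
    rollout = []
    for k in range(n_subgoals):
        z = high_level_model(z, macros[:, k])
        rollout.append(z)
    best = np.argmin(np.abs(z - z_goal).sum(axis=1))
    return [step[best] for step in rollout]  # latent subgoals in the shared space


def plan_actions(rng, z0, subgoal, horizon=STEPS_PER_MACRO, n_samples=256):
    """Low-level planner (random shooting): primitive actions toward the immediate subgoal."""
    actions = rng.uniform(-1.0, 1.0, size=(n_samples, horizon, ACTION_DIM))
    z = np.tile(z0, (n_samples, 1))
    for t in range(horizon):
        z = low_level_model(z, actions[:, t])
    best = np.argmin(np.abs(z - subgoal).sum(axis=1))
    return actions[best]


def hierarchical_mpc(z0, z_goal, n_steps=30, seed=0):
    """Receding-horizon loop: replan both levels, then execute only the first action."""
    rng = np.random.default_rng(seed)
    z = z0.copy()
    for _ in range(n_steps):
        subgoals = plan_subgoals(rng, z, z_goal)
        actions = plan_actions(rng, z, subgoals[0])  # subgoal used directly as the target
        z = low_level_model(z, actions[0])  # stand-in for stepping the real environment
    return z


if __name__ == "__main__":
    start, goal = np.zeros(LATENT_DIM), np.ones(LATENT_DIM)
    final = hierarchical_mpc(start, goal)
    print("L1 distance to goal:", np.abs(start - goal).sum(), "->", np.abs(final - goal).sum())
```

Replanning both levels at every step and executing only the first action is the usual receding-horizon MPC pattern; how often the actual method replans at each level is not stated in these notes.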
Benefits:
- On a real-world Franka robotic arm task, a flat (non-hierarchical) RL baseline has a 0% success rate, while this approach reaches up to 70%.
- More efficient inference than large VLM-based architectures trained on more data.
- Less error accumulation: the low-level model predicts well for up to 1 second, but longer timescales need the high-level model.
Limitations:
- Sensitive to the macro-action dimension. A 4-dimensional space worked best. Higher dimensions lead to complex plans that the low-level model can't complete; with lower dimensions, the planner can't form valid plans.
- No active exploration: the world model is trained only on previously collected offline data. For the Franka arm, this is 130 hours of unlabeled video clips of real robots performing various manipulation tasks.
1. Training Macro Action Representation
The macro-action space is trained jointly with the high-level world model through a learned action encoder that compresses variable-length sequences of primitive, low-level actions into compact latent representations.
Training process:
- Waypoint Selection: During training, the system extracts sequences from offline trajectory data and selects specific "waypoint" states separated by variable time steps.
- Action Encoding: The raw sequence of primitive actions executed between two consecutive waypoints is passed into the action encoder, which summarizes them into a single latent macro-action.
- Joint Optimization via Next-Latent Prediction: These encoded macro-actions are fed into the high-level world model alongside the latent features of the starting waypoint. The high-level model's objective is to predict the latent representation of the next waypoint.
Both the action encoder and the high-level world model are trained simultaneously using a "teacher-forcing" loss function, which minimizes the error (L1 distance) between the predicted next waypoint and the actual next waypoint from the training data.
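The waypoint-selection step can be sketched like this, assuming a trajectory stored as `T` latent states and `T - 1` primitive actions, with gaps drawn uniformly from a hypothetical `[min_gap, max_gap]` range (the paper's actual sampling scheme may differ). Each resulting triple is what the action encoder and high-level model consume: the start waypoint, the variable-length chunk of actions between the two waypoints, and the next waypoint as the teacher-forcing target.

```python
import numpy as np


def sample_waypoint_pairs(latents, actions, min_gap, max_gap, rng):
    """Split one trajectory into (start waypoint, action chunk, next waypoint) triples.

    latents: array of shape (T, D), the encoded state at every timestep.
    actions: array of shape (T - 1, A); actions[t] takes latents[t] to latents[t + 1].
    Gaps between consecutive waypoints are drawn uniformly from [min_gap, max_gap].
    """
    assert len(actions) == len(latents) - 1
    triples = []
    t = 0
    while True:
        gap = int(rng.integers(min_gap, max_gap + 1))  # upper bound is exclusive
        if t + gap >= len(latents):
            break  # drop the leftover tail shorter than the sampled gap
        chunk = actions[t : t + gap]  # primitive actions between the two waypoints
        triples.append((latents[t], chunk, latents[t + gap]))
        t += gap
    return triples


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    latents = rng.normal(size=(50, 8))
    actions = rng.normal(size=(49, 7))
    for start, chunk, target in sample_waypoint_pairs(latents, actions, 3, 10, rng):
        print(start.shape, chunk.shape, target.shape)
```

Because the chunk length changes from triple to triple, the action encoder has to accept variable-length input, which is one reason a sequence model with a summary token is a natural fit.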
Encoder architecture:
Varies based on the environment:
- Franka arm and Push-T (visual robotic manipulation tasks): a transformer-based encoder in which the output at a CLS token[1] is passed through an MLP head to produce the latent action.
- For 2D maze navigation: a simple MLP architecture.
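A minimal NumPy sketch of the CLS-token encoder for the manipulation tasks, assuming a single one-head self-attention layer with sinusoidal positional encodings, random (untrained) weights, and made-up sizes (`ACTION_DIM=7`, `HIDDEN=32`, `MACRO_DIM=4`); the paper's trained encoder is presumably deeper and its exact configuration is not given here. It shows why the CLS token is useful: whatever the sequence length, the CLS output is one fixed-size vector for the MLP head. The CLS token is appended at the end to match the footnote; BERT and ViT prepend it, and with full attention the choice matters little. The maze variant would replace the attention block with a plain MLP.

```python
import numpy as np

# Hypothetical sizes: 7-D primitive actions, 32-D tokens, 4-D macro-actions.
ACTION_DIM, HIDDEN, MACRO_DIM = 7, 32, 4


def init_params(rng, scale=0.1):
    """Random (untrained) weights; a real encoder learns these jointly with the world model."""
    return {
        "embed": rng.normal(0.0, scale, (ACTION_DIM, HIDDEN)),  # action -> token
        "cls": rng.normal(0.0, scale, HIDDEN),                  # learnable CLS token
        "wq": rng.normal(0.0, scale, (HIDDEN, HIDDEN)),
        "wk": rng.normal(0.0, scale, (HIDDEN, HIDDEN)),
        "wv": rng.normal(0.0, scale, (HIDDEN, HIDDEN)),
        "mlp1": rng.normal(0.0, scale, (HIDDEN, HIDDEN)),
        "mlp2": rng.normal(0.0, scale, (HIDDEN, MACRO_DIM)),
    }


def positional_encoding(length, dim):
    """Sinusoidal position codes so the encoder can tell action order apart."""
    pos = np.arange(length)[:, None]
    i = np.arange(dim // 2)[None, :]
    angles = pos / (10000 ** (2 * i / dim))
    pe = np.zeros((length, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def encode_macro_action(params, actions):
    """Compress an (L, ACTION_DIM) primitive-action sequence, any L >= 1, into one macro-action."""
    tokens = actions @ params["embed"] + positional_encoding(len(actions), HIDDEN)
    tokens = np.vstack([tokens, params["cls"]])          # append CLS token: (L + 1, HIDDEN)
    q, k, v = tokens @ params["wq"], tokens @ params["wk"], tokens @ params["wv"]
    attn = softmax(q @ k.T / np.sqrt(HIDDEN))            # single-head full self-attention
    tokens = tokens + attn @ v                           # residual connection
    cls_out = tokens[-1]                                 # CLS output summarizes the sequence
    hidden = np.maximum(cls_out @ params["mlp1"], 0.0)   # MLP head with ReLU
    return hidden @ params["mlp2"]                       # (MACRO_DIM,)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    params = init_params(rng)
    for length in (3, 10):
        print(length, encode_macro_action(params, rng.normal(size=(length, ACTION_DIM))).shape)
```

Only the CLS row is used at the output, so computing the full attention matrix is wasteful for a single layer; it is kept for clarity, since a multi-layer encoder does need every row.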
2. Training World Model
The world models in this framework are trained entirely offline on large, pre-collected datasets of task-agnostic, reward-free trajectories; the model never explores the environment itself.
Data source:
- Real-World Robotic Tasks (Franka Arm): The models are trained on approximately 130 hours of pre-existing, unlabeled video clips of real robots performing various manipulation tasks, taken from the massive open-source DROID and RoboSet datasets.
- Push-T Simulation: The models use a pre-existing offline dataset consisting of 18,500 trajectories of a T-shaped block being pushed.
- Diverse Maze Navigation: The training data consists of 5 million transitions collected beforehand by placing an agent in 25 different maze maps and simply having it execute randomly sampled actions.
Training the Low-Level World Model:
The low-level model learns the fine-grained physics of the environment by predicting the immediate next step based on a primitive action. It is trained using a combination of two loss functions:
- Teacher-Forcing Loss: The model predicts the very next latent state, and the loss is the L1 distance (error) between its prediction and the actual next state from the offline data.
- Multi-Step Autoregressive Rollout Loss: To keep errors from compounding, the model is also trained to predict several steps ahead autoregressively, feeding its own predictions back in as inputs; the loss is the error of the final predicted state at the end of the sequence.
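Both losses can be sketched for a generic predictor `predict(z, a)`, assuming latent states `latents[0..T]` and actions `actions[0..T-1]` from one offline trajectory; the `rollout_weight` factor and the toy data are hypothetical. In the teacher-forcing loss every prediction starts from the true latent; in the rollout loss the model consumes its own predictions and only the final state is compared with the data, which is what exposes compounding error during training.

```python
import numpy as np


def teacher_forcing_loss(predict, latents, actions):
    """Mean L1 error of one-step predictions, each starting from the true latent state."""
    preds = np.stack([predict(latents[t], actions[t]) for t in range(len(actions))])
    return float(np.abs(preds - latents[1:]).mean())


def rollout_loss(predict, latents, actions, horizon):
    """Mean L1 error of the final state after `horizon` autoregressive steps.

    Only latents[0] comes from the data; every later input is the model's own prediction.
    """
    z = latents[0]
    for t in range(horizon):
        z = predict(z, actions[t])
    return float(np.abs(z - latents[horizon]).mean())


def low_level_loss(predict, latents, actions, horizon, rollout_weight=1.0):
    """Combined objective; `rollout_weight` is a hypothetical balancing factor."""
    return teacher_forcing_loss(predict, latents, actions) + rollout_weight * rollout_loss(
        predict, latents, actions, horizon
    )


if __name__ == "__main__":
    # Toy data: true dynamics z_{t+1} = z_t + a_t with constant actions.
    actions = np.ones((10, 3))
    latents = np.vstack([np.zeros(3), np.cumsum(actions, axis=0)])

    def biased(z, a):
        return z + 1.1 * a  # overshoots every step by 10%

    print("one-step:", teacher_forcing_loss(biased, latents, actions))
    for h in (1, 5, 10):
        print(f"rollout h={h}:", rollout_loss(biased, latents, actions, h))
```

For the biased toy model the one-step error stays at about 0.1, while the rollout error grows to about 0.1 * h: a model that looks fine under teacher forcing alone can drift badly over long rollouts.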
Training the High-Level World Model:
The high-level model learns long-horizon dynamics by skipping timesteps and predicting the latent representation of distant "waypoints".
- Take an initial waypoint and the sequence of primitive actions that follows it
- Compress the action sequence into a macro-action
- Predict the future waypoint
- Train with the teacher-forcing loss: minimize the L1 distance between the predicted and actual next waypoint.
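These steps can be sketched as a single loss function over (start waypoint, action chunk, next waypoint) triples. The weights `w_enc`, `w_z` and `w_m` are hypothetical linear stand-ins, and sum pooling replaces the learned transformer or MLP encoder; sum pooling is chosen because, unlike mean pooling, it keeps information about how many steps the chunk spans. In practice an autodiff framework would back-propagate this one loss into the encoder and the world model at the same time, which is what makes the training joint.

```python
import numpy as np


def encode_actions(w_enc, chunk):
    """Hypothetical action encoder: sum-pool a variable-length (L, A) chunk, then project."""
    return chunk.sum(axis=0) @ w_enc                    # (M,)


def high_level_predict(w_z, w_m, z_start, macro):
    """Hypothetical linear high-level world model: jump straight to the next waypoint."""
    return z_start @ w_z + macro @ w_m                  # (D,)


def high_level_loss(params, triples):
    """Teacher-forcing L1 loss: every prediction starts from the true start waypoint."""
    w_enc, w_z, w_m = params
    errors = [
        np.abs(high_level_predict(w_z, w_m, z0, encode_actions(w_enc, chunk)) - z1).mean()
        for z0, chunk, z1 in triples
    ]
    return float(np.mean(errors))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    D = 3
    # Toy waypoints whose true dynamics are z_next = z_start + sum(actions).
    triples = []
    for length in (2, 5, 7):
        z0, chunk = rng.normal(size=D), rng.normal(size=(length, D))
        triples.append((z0, chunk, z0 + chunk.sum(axis=0)))
    eye = np.eye(D)
    print("exact model:", high_level_loss((eye, eye, eye), triples))
    print("ignores macro-action:", high_level_loss((eye, eye, np.zeros((D, D))), triples))
```

The second model ignores the macro-action and predicts that the state stays put, so its loss equals the average displacement between waypoints; the gradient of this loss with respect to `w_m` and `w_enc` is what teaches the encoder which information about the action chunk matters.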
Footnotes:
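[1] CLS (classification token) is a special placeholder token added to the input sequence, here at the end; its output embedding serves as a summary of the whole sequence.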