HWM Train and CEM Eval Push-T #210
Open
Hongyang-Du wants to merge 7 commits into galilai-group:main
Summary
This PR adds Hierarchical World Model support for Push-T using DINOv2 as the frozen visual encoder. It includes HWM training and hierarchical MPC evaluation on top of the existing low-level DINO-WM.
Files

- `stable_worldmodel/wm/hwm/modules.py`: Adds action sequence encoders. The main one, `SequenceEncoder`, compresses low-level action chunks into latent macro-actions used by the high-level world model.
- `stable_worldmodel/wm/hwm/training.py`: Adds the HWM training forward pass. It samples waypoint pairs, builds action chunks between them, encodes those chunks into `latent_action`, and trains the high-level PreJEPA model to predict the target latent state.
- `stable_worldmodel/wm/hwm/__init__.py`: Exports the HWM modules and training forward so they can be imported through `stable_worldmodel.wm.hwm`.
- `scripts/train/hwm.py`: Adds the HWM training entrypoint. It builds the dataset, DINOv2 encoder, PreJEPA predictor, extra encoders, action encoder, optimizer, callbacks, and checkpoint saving.
- `scripts/train/config/hwm.yaml`: Adds the default Push-T HWM training config, including DINOv2 backbone settings, macro-action span, latent action dimension, action encoder hyperparameters, and trainer settings.
- `stable_worldmodel/policy.py`: Adds hierarchical planning components:
  - `HWMCostModel`: high-level CEM cost model over latent macro-actions.
  - `_FixedGoalCostModel`: low-level CEM cost model toward a fixed subgoal embedding.
  - `HWMPolicy`: coordinates L2 subgoal planning, L1 primitive action planning, and action buffering.
- `scripts/plan/eval_hwm.py`: Adds the Push-T HWM evaluation script. It loads the HWM and low-level DINO-WM checkpoints, builds the two CEM solvers, creates `HWMPolicy`, and evaluates through `World.evaluate(...)`.
- `scripts/plan/config/pusht_hwm.yaml`: Adds the Push-T HWM eval config, including checkpoint names, L2/L1 CEM hyperparameters, low-level frameskip, and dataset eval settings.
- `stable_worldmodel/wm/__init__.py`: Exports the new `hwm` module from the world-model package.
- `stable_worldmodel/wm/prejepa/__init__.py`: Exports PreJEPA building blocks needed by HWM loading/evaluation.
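To make the `SequenceEncoder` role concrete, here is a minimal sketch of an action-chunk encoder. It mirrors the PR's description (compress a chunk of low-level actions into one latent macro-action), but the architecture shown here (a GRU over the chunk plus a linear head) and all dimensions are assumptions, not the actual implementation in `stable_worldmodel/wm/hwm/modules.py`:

```python
import torch
import torch.nn as nn


class SequenceEncoder(nn.Module):
    """Compress a chunk of low-level actions into one latent macro-action.

    Illustrative sketch only: the real module may differ in architecture.
    """

    def __init__(self, action_dim: int, hidden_dim: int, latent_action_dim: int):
        super().__init__()
        self.rnn = nn.GRU(action_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, latent_action_dim)

    def forward(self, action_chunk: torch.Tensor) -> torch.Tensor:
        # action_chunk: (batch, chunk_len, action_dim)
        _, h_last = self.rnn(action_chunk)    # h_last: (1, batch, hidden_dim)
        return self.head(h_last.squeeze(0))   # (batch, latent_action_dim)


# Hypothetical Push-T shapes: 2-D (dx, dy) actions, chunks of 8 steps.
enc = SequenceEncoder(action_dim=2, hidden_dim=32, latent_action_dim=16)
z_a = enc(torch.randn(4, 8, 2))
```

A recurrent encoder keeps the interface independent of the chunk length, which matters if the macro-action span is a config knob rather than a fixed constant.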
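The training forward described for `training.py` can be sketched as follows. This is a hedged illustration under simplifying assumptions (a fixed waypoint span instead of sampled pairs, MSE as the prediction loss, and toy stand-ins for the frozen DINOv2 latents and the PreJEPA predictor); the names `action_encoder`, `predictor`, and `hwm_training_step` are placeholders:

```python
import torch
import torch.nn as nn


def hwm_training_step(z_seq, actions, action_encoder, predictor, span=8):
    """One illustrative HWM step.

    z_seq:   (B, T, D) frozen visual latents (e.g. from DINOv2)
    actions: (B, T, A) low-level actions
    """
    B, T, D = z_seq.shape
    i = torch.randint(0, T - span, (B,))             # waypoint start per sample
    idx = i[:, None] + torch.arange(span)            # (B, span) chunk indices
    chunk = actions[torch.arange(B)[:, None], idx]   # action chunk between waypoints
    latent_action = action_encoder(chunk)            # (B, latent_action_dim)
    z_i = z_seq[torch.arange(B), i]                  # source latent state
    z_j = z_seq[torch.arange(B), i + span]           # target latent state
    pred = predictor(torch.cat([z_i, latent_action], dim=-1))
    return nn.functional.mse_loss(pred, z_j)


# Toy dimensions; the real config sets these in scripts/train/config/hwm.yaml.
D, A, L, span = 384, 2, 16, 8
action_encoder = nn.Sequential(nn.Flatten(1), nn.Linear(span * A, L))
predictor = nn.Sequential(nn.Linear(D + L, 256), nn.ReLU(), nn.Linear(256, D))
loss = hwm_training_step(torch.randn(2, 32, D), torch.randn(2, 32, A),
                         action_encoder, predictor, span=span)
```

The key property this illustrates: the high-level model never sees raw pixels or raw action chunks at prediction time, only a latent state and a latent macro-action.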
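Both solvers built by `eval_hwm.py` are CEM optimizers over candidate actions. A generic cross-entropy-method loop looks like the sketch below; the cost function here is a toy stand-in (distance of a shifted latent to the goal latent) rather than the PR's `HWMCostModel` or `_FixedGoalCostModel`, and all hyperparameters are illustrative:

```python
import numpy as np


def cem(cost, dim, iters=15, pop=64, elite=8, seed=0):
    """Minimize cost(x) over R^dim by iteratively refitting a Gaussian."""
    rng = np.random.default_rng(seed)
    mu, sigma = np.zeros(dim), np.ones(dim)
    for _ in range(iters):
        samples = rng.normal(mu, sigma, size=(pop, dim))
        costs = np.array([cost(s) for s in samples])
        best = samples[np.argsort(costs)[:elite]]      # keep the elite set
        mu, sigma = best.mean(0), best.std(0) + 0.05   # refit, with a noise floor
    return mu


# Toy cost: distance between a "predicted" latent (z + action) and the goal
# latent -- a stand-in for rolling the world model forward under a candidate
# macro-action and scoring the result against the goal embedding.
z, goal = np.zeros(4), np.full(4, 2.0)
plan = cem(lambda a: np.linalg.norm(z + a - goal), dim=4)
```

In the hierarchical setup, the L2 solver runs this loop over latent macro-actions and the L1 solver over primitive action sequences, each with its own cost model and hyperparameters from `pusht_hwm.yaml`.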
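The coordination job attributed to `HWMPolicy` (L2 subgoal planning, L1 primitive planning, action buffering) can be summarized with a small control-flow sketch. The solver interfaces below are placeholders, not the PR's actual CEM API; the toy 1-D usage only demonstrates the replan-when-buffer-empty behavior:

```python
from collections import deque


class HWMPolicy:
    """Illustrative hierarchical policy: replan only when the buffer runs dry."""

    def __init__(self, l2_plan, l1_plan):
        self.l2_plan = l2_plan   # (z_now, z_goal) -> subgoal embedding
        self.l1_plan = l1_plan   # (z_now, z_subgoal) -> list of primitive actions
        self.buffer = deque()

    def act(self, z_now, z_goal):
        if not self.buffer:
            z_subgoal = self.l2_plan(z_now, z_goal)            # high-level CEM
            self.buffer.extend(self.l1_plan(z_now, z_subgoal))  # low-level CEM
        return self.buffer.popleft()


# Toy 1-D planners: the subgoal is the midpoint toward the goal, and the low
# level emits three unit steps toward that fixed subgoal.
policy = HWMPolicy(
    l2_plan=lambda z, g: (z + g) / 2,
    l1_plan=lambda z, sg: [1.0 if sg > z else -1.0] * 3,
)
a = policy.act(0.0, 10.0)  # plans once; later calls drain the buffer first
```

Buffering amortizes the expensive two-level CEM solve across several environment steps, which is why the eval config exposes a low-level frameskip.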
How To Train
How To Evaluate