Skip to content

HWM Train and CEM Eval Push-T#210

Open
Hongyang-Du wants to merge 7 commits intogalilai-group:mainfrom
Hongyang-Du:main
Open

HWM Train and CEM Eval Push-T#210
Hongyang-Du wants to merge 7 commits intogalilai-group:mainfrom
Hongyang-Du:main

Conversation

@Hongyang-Du
Copy link
Copy Markdown

Summary

This PR adds Hierarchical World Model support for Push-T using DINOv2 as the frozen visual encoder. It includes HWM training and hierarchical MPC evaluation on top of the existing low-level DINO-WM.

Files

stable_worldmodel/wm/hwm/modules.py

Adds action sequence encoders. The main one, SequenceEncoder, compresses low-level action chunks into latent macro-actions used by the high-level world model.

stable_worldmodel/wm/hwm/training.py

Adds the HWM training forward pass. It samples waypoint pairs, builds action chunks between them, encodes those chunks into latent_action, and trains the high-level PreJEPA model to predict the target latent state.

stable_worldmodel/wm/hwm/__init__.py

Exports the HWM modules and training forward so they can be imported through stable_worldmodel.wm.hwm.

scripts/train/hwm.py

Adds the HWM training entrypoint. It builds the dataset, DINOv2 encoder, PreJEPA predictor, extra encoders, action encoder, optimizer, callbacks, and checkpoint saving.

scripts/train/config/hwm.yaml

Adds the default Push-T HWM training config, including DINOv2 backbone settings, macro-action span, latent action dimension, action encoder hyperparameters, and trainer settings.

stable_worldmodel/policy.py

Adds hierarchical planning components:

  • HWMCostModel: high-level CEM cost model over latent macro-actions.
  • _FixedGoalCostModel: low-level CEM cost model toward a fixed subgoal embedding.
  • HWMPolicy: coordinates L2 subgoal planning, L1 primitive action planning, and action buffering.

scripts/plan/eval_hwm.py

Adds the Push-T HWM evaluation script. It loads the HWM and low-level DINO-WM checkpoints, builds the two CEM solvers, creates HWMPolicy, and evaluates through World.evaluate(...).

scripts/plan/config/pusht_hwm.yaml

Adds the Push-T HWM eval config, including checkpoint names, L2/L1 CEM hyperparameters, low-level frameskip, and dataset eval settings.

stable_worldmodel/wm/__init__.py

Exports the new hwm module from the world-model package.

stable_worldmodel/wm/prejepa/__init__.py

Exports PreJEPA building blocks needed by HWM loading/evaluation.

How To Train

python scripts/train/hwm.py

How To Evaluate

python scripts/plan/eval_hwm.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant