[Feature] WorldModel + WorldModelLoss — general model-based RL abstraction#3783

Merged
vmoens merged 4 commits into pytorch:main from theap06:feat/World_Model
Jun 10, 2026
Conversation

@theap06

@theap06 theap06 commented May 20, 2026

Contributor

Fixes #3774

  • Adds `WorldModel(TensorDictModuleBase)`: a key-driven, architecture-agnostic composition layer for encoder + dynamics + reward/done/decoder heads, with `encode`, `step`, `decode` shortcuts and a `rollout(start_td, policy, horizon)` method whose `[batch, horizon]` output matches `EnvBase.rollout` — making imagined trajectories drop-in compatible with replay buffers, GAE, and loss modules.
  • Adds `WorldModelLoss(LossModule)`: follows the standard `_AcceptedKeys` / `set_keys()` / `forward() → TensorDict` pattern; supports configurable sub-losses (`reward`, `done`, `reconstruction`, `latent`) with per-loss weights and distance-function choice.
  • Existing Dreamer components (`WorldModelWrapper`, `DreamerEnv`, `RSSMRollout`, `DreamerModelLoss/ActorLoss/ValueLoss`) are unchanged.
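
To make the composition and the `[batch, horizon]` layout described above concrete, here is a minimal sketch in plain PyTorch. It does not use the PR's `WorldModel` class, whose constructor signature is not shown in this conversation. `TinyWorldModel`, `imagine`, the key names, and all dimensions are hypothetical, and a plain `dict` stands in for a `TensorDict`.

```python
import torch
from torch import nn


class TinyWorldModel(nn.Module):
    """Hypothetical stand-in: reads and writes a plain dict by key."""

    def __init__(self, obs_dim=4, action_dim=2, latent_dim=8):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)
        self.dynamics = nn.Linear(latent_dim + action_dim, latent_dim)
        self.reward_head = nn.Linear(latent_dim, 1)

    def encode(self, td):
        # observation -> latent
        td["latent"] = self.encoder(td["observation"])
        return td

    def step(self, td):
        # (latent, action) -> next latent and predicted reward
        x = torch.cat([td["latent"], td["action"]], dim=-1)
        nxt = self.dynamics(x)
        return {"latent": nxt, "reward": self.reward_head(nxt)}


def imagine(model, start, policy, horizon):
    """Unroll the model and stack the steps along dim 1: [batch, horizon, ...]."""
    td = model.encode(dict(start))
    steps = []
    for _ in range(horizon):
        td["action"] = policy(td["latent"])
        nxt = model.step(td)
        steps.append({"action": td["action"], **nxt})
        td = {"latent": nxt["latent"]}
    return {k: torch.stack([s[k] for s in steps], dim=1) for k in steps[0]}


torch.manual_seed(0)
model = TinyWorldModel()
policy = lambda latent: torch.zeros(latent.shape[0], 2)  # placeholder policy
traj = imagine(model, {"observation": torch.randn(3, 4)}, policy, horizon=5)
print({k: tuple(v.shape) for k, v in traj.items()})
```

Stacking along dimension 1 gives every entry a leading `[batch, horizon]` shape, which is the property the description relies on for compatibility with replay buffers and loss modules.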

Motivation

TorchRL ships Dreamer-specific model-based components but no general abstraction. Users implementing MBPO, TD-MPC, PlaNet, or any custom world model must hand-wire `TensorDictModule` chains and write bespoke multi-step rollout loops. This adds the missing general layer.

Test plan

  • `python -m pytest test/test_world_model.py -v` — 19 new tests covering forward, encode/step/decode shortcuts, nested keys, rollout shape and early termination, replay buffer compatibility, all four loss types, `set_keys`, per-loss weights, and gradient flow
  • `python -c "from torchrl.modules import WorldModel; from torchrl.objectives import WorldModelLoss; print('OK')"` — import smoke test
  • Existing dreamer-related tests (`pytest test/ -k "dreamer" --ignore=test/llm`) — 409 tests, all passing
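
A gradient-flow check with per-loss weights, of the kind the test plan lists, might look roughly like the sketch below. It is written against plain PyTorch rather than the PR's `WorldModelLoss`; the `world_model_loss` helper, the `DISTANCES` table, and the key names are assumptions made for illustration, and only two of the four sub-losses are shown.

```python
import torch
import torch.nn.functional as F

# Hypothetical mapping from a distance-function name to a torch loss.
DISTANCES = {"l2": F.mse_loss, "l1": F.l1_loss, "smooth_l1": F.smooth_l1_loss}


def world_model_loss(pred, target, weights, distance="l2"):
    """Return one weighted scalar loss per key present in `weights`."""
    dist = DISTANCES[distance]
    return {f"loss_{k}": w * dist(pred[k], target[k]) for k, w in weights.items()}


torch.manual_seed(0)
reward_head = torch.nn.Linear(8, 1)
decoder = torch.nn.Linear(8, 4)
latent = torch.randn(16, 8)

pred = {"reward": reward_head(latent), "reconstruction": decoder(latent)}
target = {"reward": torch.randn(16, 1), "reconstruction": torch.randn(16, 4)}
losses = world_model_loss(
    pred, target, weights={"reward": 1.0, "reconstruction": 0.5}
)

# Backpropagate the summed loss and confirm every head received a gradient.
sum(losses.values()).backward()
grads_ok = all(
    p.grad is not None and bool(p.grad.abs().sum() > 0)
    for m in (reward_head, decoder)
    for p in m.parameters()
)
print(sorted(losses), grads_ok)
```

Returning each term under its own `loss_*` key, instead of a single summed scalar, keeps the sub-losses individually loggable and leaves the final reduction to the caller.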

@pytorch-bot

pytorch-bot Bot commented May 20, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3783

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 6 Pending, 1 Unrelated Failure

As of commit 24f33c9 with merge base 4679d0a (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed, but the failure was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 20, 2026
@github-actions

Contributor

⚠️ PR Title Label Error

PR title must start with a label prefix in brackets (e.g., [BugFix]).

Current title: WorldModel + WorldModelLoss — general model-based RL abstraction

Supported Prefixes (case-sensitive)

Your PR title must start with exactly one of these prefixes:

| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).

@theap06 theap06 changed the title WorldModel + WorldModelLoss — general model-based RL abstraction [Feature] WorldModel + WorldModelLoss — general model-based RL abstraction May 20, 2026
@github-actions github-actions Bot added the Feature label May 20, 2026
@theap06

theap06 commented May 21, 2026

Contributor Author

@vmoens @elin-bdai this was my attempt at creating a World Model abstraction. Lmk if you have any feedback!

@theap06

theap06 commented May 22, 2026

Contributor Author

@vmoens I believe this issue is tied to the flaky tests from earlier. Lmk if the design seems sound to you.

@theap06 theap06 force-pushed the feat/World_Model branch from ddb8e36 to b21a251 Compare May 26, 2026 07:30

@vmoens vmoens left a comment

Collaborator

I'm not sure why WorldModel is not an env, or a module that's used by a model-based env to generate rollouts. That's the usual way to think about this in torchrl (see Dreamer for example). Can you elaborate on this choice?

Comment thread torchrl/modules/tensordict_module/world_models.py
theap06 added 2 commits June 8, 2026 19:09
Per review feedback (@vmoens): move imagined-rollout responsibility off
the WorldModel module and onto a new WorldModelEnv(ModelBasedEnvBase),
matching the pattern used by DreamerEnv / ImaginedEnv. Eliminates the
divergent rollout semantics WorldModel.rollout had vs EnvBase.rollout.
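
The split described in this commit message, where the model only predicts one step and an env-style wrapper owns `reset`, `step`, and the rollout loop, can be sketched as follows. This is an illustration in plain PyTorch, not the PR's `WorldModelEnv` or torchrl's `ModelBasedEnvBase`; `ToyModelEnv`, `rollout`, `toy_dynamics`, and the threshold value are all hypothetical.

```python
import torch


class ToyModelEnv:
    """Hypothetical env-style wrapper: the model predicts, the env owns reset/step."""

    def __init__(self, step_fn, batch_size, obs_dim):
        self.step_fn = step_fn
        self.batch_size = batch_size
        self.obs_dim = obs_dim

    def reset(self):
        return {"observation": torch.zeros(self.batch_size, self.obs_dim)}

    def step(self, td):
        obs, reward, done = self.step_fn(td["observation"], td["action"])
        return {"observation": obs, "reward": reward, "done": done}


def rollout(env, policy, max_steps, break_when_all_done=True):
    """Generic loop that works for any object exposing reset() and step()."""
    td = env.reset()
    steps = []
    for _ in range(max_steps):
        td["action"] = policy(td["observation"])
        nxt = env.step(td)
        steps.append({"action": td["action"], **nxt})
        if break_when_all_done and bool(nxt["done"].all()):
            break  # early termination: every batch element is done
        td = {"observation": nxt["observation"]}
    return {k: torch.stack([s[k] for s in steps], dim=1) for k in steps[0]}


def toy_dynamics(obs, action):
    # Stand-in for a learned model: drift by the action, flag done past a threshold.
    nxt = obs + action
    reward = -nxt.abs().sum(-1, keepdim=True)
    done = nxt.sum(-1, keepdim=True) >= 4.0
    return nxt, reward, done


env = ToyModelEnv(toy_dynamics, batch_size=2, obs_dim=3)
traj = rollout(env, policy=lambda obs: torch.full_like(obs, 0.5), max_steps=10)
print(tuple(traj["observation"].shape))
```

Because the loop depends only on `reset()` and `step()`, the same function can drive a real environment or a model-backed one, which is one way to avoid two rollout implementations with diverging semantics.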
@theap06 theap06 requested a review from vmoens June 9, 2026 02:09

@vmoens vmoens left a comment

Collaborator

Just some minor comments but otherwise LGTM!

Comment thread torchrl/envs/model_based/world_model_env.py Outdated
Comment thread test/test_world_model.py Outdated
Comment thread torchrl/envs/model_based/world_model_env.py Outdated
Comment thread torchrl/envs/model_based/world_model_env.py Outdated
Comment thread torchrl/envs/model_based/world_model_env.py Outdated
Comment thread torchrl/envs/model_based/world_model_env.py Outdated
Comment thread torchrl/objectives/world_model_loss.py Outdated
Per @vmoens review comments on the WorldModelEnv PR:

- WorldModelEnv:
  - Annotate `world_model` parameter as `WorldModel` (via TYPE_CHECKING
    to avoid the torchrl.modules <-> torchrl.envs circular import).
  - Replace `.clone(recurse=False)` with `.copy()`.
  - Drop unnecessary `{}` in `TensorDict(batch_size=..., device=...)`.
  - Iterate spec leaves with `include_nested=True, leaves_only=True`
    so nested observation/reward/done specs survive `_step` end-to-end.
  - Use `full_*_spec.zero()` (and `full_*_spec[key].zero()`) instead of
    hand-rolled `torch.zeros(...)` for default reward/done/terminated
    tensors, keeping shapes / dtypes consistent with the specs.
- WorldModelLoss: drop unused `batch_size=[]` from final TensorDict
  construction.
- Tests: consolidate `test/test_world_model.py` (deleted) into:
  - `test/envs/test_model_based.py` — TestWorldModelForward,
    TestWorldModelEnv (and a `_SpecOnlyEnv` helper).
  - `test/objectives/test_dreamer.py` — TestWorldModelLoss.

No public-API changes. 22/22 affected tests pass locally.
@theap06 theap06 requested a review from vmoens June 10, 2026 07:43

@vmoens vmoens left a comment

Collaborator

LGTM
do you think we should consider subclassing the dreamer classes with these for a more homogeneous implementation?

@vmoens vmoens merged commit 78c12fe into pytorch:main Jun 10, 2026
104 of 108 checks passed
@theap06

theap06 commented Jun 10, 2026

Contributor Author

> LGTM
> do you think we should consider subclassing the dreamer classes with these for a more homogeneous implementation?

I can open an issue with the refactors for the existing dreamer classes.

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Integrations/torch_geometric Integrations Modules Objectives

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC] torchrl.modules.WorldModel — A General TensorDict-Native World Model Abstraction

2 participants