[Feature] WorldModel + WorldModelLoss — general model-based RL abstraction#3783
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3783
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 6 Pending, 1 Unrelated Failure
As of commit 24f33c9 with merge base 4679d0a:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
@vmoens @elin-bdai this was my attempt at creating a World Model abstraction. Lmk if you have any feedback!
@vmoens I believe this issue is tied to the flaky tests from earlier. Lmk if the design seems sound to you.
vmoens
left a comment
I'm not sure why WorldModel is not an env, or a module that's used by a model-based env to generate rollouts. That's the usual way to think about this in torchrl (see Dreamer for example). Can you elaborate on this choice?
Per review feedback (@vmoens): move imagined-rollout responsibility off the WorldModel module and onto a new WorldModelEnv(ModelBasedEnvBase), matching the pattern used by DreamerEnv / ImaginedEnv. This eliminates the divergence between the rollout semantics of WorldModel.rollout and those of EnvBase.rollout.
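The split described above can be sketched in plain Python. This is only an illustration of the design idea, not the torchrl implementation: `ToyWorldModel`, `ToyWorldModelEnv`, the dict-based state and the toy dynamics are all hypothetical stand-ins. The model only predicts one step; the env owns `reset`, `step` and the multi-step `rollout` loop.

```python
from typing import Callable, Dict, List, Union

State = Dict[str, Union[float, bool]]


class ToyWorldModel:
    """Hypothetical one-step predictor (stand-in for a learned model)."""

    def __call__(self, state: State, action: float) -> State:
        # Toy dynamics: the observation moves by the action,
        # the reward penalizes distance from zero.
        next_obs = state["observation"] + action
        return {
            "observation": next_obs,
            "reward": -abs(next_obs),
            "done": abs(next_obs) > 10.0,
        }


class ToyWorldModelEnv:
    """Hypothetical env wrapper: the rollout loop lives here, not on the model."""

    def __init__(self, world_model: ToyWorldModel, initial_observation: float = 0.0):
        self.world_model = world_model
        self._initial_observation = initial_observation
        self._state: State = {}

    def reset(self) -> State:
        self._state = {
            "observation": self._initial_observation,
            "reward": 0.0,
            "done": False,
        }
        return dict(self._state)

    def step(self, action: float) -> State:
        # Delegate the one-step prediction to the model.
        self._state = self.world_model(self._state, action)
        return dict(self._state)

    def rollout(self, max_steps: int, policy: Callable[[State], float]) -> List[State]:
        # Single place where multi-step semantics (early stop on done) are defined.
        state = self.reset()
        trajectory: List[State] = []
        for _ in range(max_steps):
            state = self.step(policy(state))
            trajectory.append(state)
            if state["done"]:
                break
        return trajectory


env = ToyWorldModelEnv(ToyWorldModel())
trajectory = env.rollout(max_steps=3, policy=lambda state: 1.0)
short_trajectory = env.rollout(max_steps=5, policy=lambda state: 6.0)
```

With the loop defined only on the env, there is one set of rollout semantics to test and document, which is the motivation given in the commit message.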
vmoens
left a comment
Just some minor comments but otherwise LGTM!
Per @vmoens review comments on the WorldModelEnv PR:
- WorldModelEnv:
  - Annotate `world_model` parameter as `WorldModel` (via TYPE_CHECKING to avoid the torchrl.modules <-> torchrl.envs circular import).
  - Replace `.clone(recurse=False)` with `.copy()`.
  - Drop unnecessary `{}` in `TensorDict(batch_size=..., device=...)`.
  - Iterate spec leaves with `include_nested=True, leaves_only=True` so nested observation/reward/done specs survive `_step` end-to-end.
  - Use `full_*_spec.zero()` (and `full_*_spec[key].zero()`) instead of hand-rolled `torch.zeros(...)` for default reward/done/terminated tensors, keeping shapes / dtypes consistent with the specs.
- WorldModelLoss: drop unused `batch_size=[]` from final TensorDict construction.
- Tests: consolidate `test/test_world_model.py` (deleted) into:
  - `test/envs/test_model_based.py`: TestWorldModelForward, TestWorldModelEnv (and a `_SpecOnlyEnv` helper).
  - `test/objectives/test_dreamer.py`: TestWorldModelLoss.

No public-API changes. 22/22 affected tests pass locally.
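The point about iterating spec leaves can be illustrated with plain dictionaries. This is an analogy only, not the tensordict or torchrl API: `iter_leaves` and the `specs` layout are hypothetical. It shows why iterating top-level keys alone would miss entries nested under `observation`, whereas a leaves-only traversal with nested (tuple) keys reaches every entry.

```python
from typing import Any, Dict, Iterator, Tuple


def iter_leaves(
    nested: Dict[str, Any], prefix: Tuple[str, ...] = ()
) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    """Yield (path, value) for every non-dict leaf, with tuple paths for nesting."""
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Descend instead of yielding the intermediate container.
            yield from iter_leaves(value, path)
        else:
            yield path, value


# Hypothetical spec layout: shapes stand in for real spec objects.
specs = {
    "observation": {"pixels": (3, 64, 64), "state": (4,)},
    "reward": (1,),
}

top_level_keys = list(specs.keys())
leaf_paths = [path for path, _ in iter_leaves(specs)]
```

Here `top_level_keys` stops at `observation`, while `leaf_paths` also contains `("observation", "pixels")` and `("observation", "state")`.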
vmoens
left a comment
LGTM
Do you think we should consider subclassing the dreamer classes with these for a more homogeneous implementation?
I can open an issue with the refactors for the existing dreamer classes.
Fixes #3774
Motivation
TorchRL ships Dreamer-specific model-based components but no general abstraction. Users implementing MBPO, TD-MPC, PlaNet, or any custom world model must hand-wire `TensorDictModule` chains and write bespoke multi-step rollout loops. This adds the missing general layer.
Test plan