Skip to content

Commit d029fec

Browse files
committed
Refine replayer tests and typing hints
1 parent 1ff1d30 commit d029fec

3 files changed

Lines changed: 52 additions & 97 deletions

File tree

python/rcs/envs/configs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import copy
24
import time
35
from typing import ClassVar, Literal
@@ -37,7 +39,7 @@ class EmptyWorldFR3(SimEnvCreator):
3739
def config(self) -> SimEnvCreatorConfig:
3840
q_home = rcs.ROBOTS[RobotType.FR3].q_home
3941
q_home[-1] = np.pi / 4
40-
robot_cfg = SimRobotConfig[Literal[7]](
42+
robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig(
4143
robot_type=RobotType.FR3,
4244
tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType.FrankaHand],
4345
attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site,
@@ -183,7 +185,7 @@ class EmptyWorldFR3Duo(SimEnvCreator):
183185
gripper_mesh_quaternion_offset: ClassVar[list[float]] = [0, 0, 0.7071068, 0.7071068]
184186

185187
def config(self) -> SimEnvCreatorConfig:
186-
robot_cfg = SimRobotConfig[Literal[7]](
188+
robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig(
187189
tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType("Robotiq2F85")],
188190
robot_type=RobotType.FR3,
189191
attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site,
@@ -224,7 +226,7 @@ def config(self) -> SimEnvCreatorConfig:
224226
joint_limits=rcs.ROBOTS[RobotType.FR3].joint_limits,
225227
q_home=rcs.HOME_POSITIONS["FR3_DUO_LEFT"],
226228
)
227-
robot_cfg_right = copy.deepcopy(robot_cfg)
229+
robot_cfg_right: SimRobotConfig[Literal[7]] = copy.deepcopy(robot_cfg)
228230
robot_cfg_right.q_home = rcs.HOME_POSITIONS["FR3_DUO_RIGHT"]
229231

230232
robot_cfgs: dict[str, SimRobotConfig] = {"left": robot_cfg, "right": robot_cfg_right}

python/rcs/sim/replayer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sim_state(self) -> np.ndarray:
4747
raise KeyError(msg)
4848

4949
@property
50-
def sim_state_spec(self) -> SimStateSchema | None:
50+
def sim_state_schema(self) -> SimStateSchema | None:
5151
if SimEnv.STATE_SCHEMA_KEY in self.info:
5252
return _normalize_sim_state_schema(self.info[SimEnv.STATE_SCHEMA_KEY])
5353

@@ -103,9 +103,9 @@ def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep):
103103
lead_env = None
104104

105105
if lead_env is not None:
106-
lead_env.set_replay_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec)
106+
lead_env.set_replay_state(recorded_step.sim_state, schema=recorded_step.sim_state_schema)
107107
else:
108-
env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, spec=recorded_step.sim_state_spec)
108+
env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, schema=recorded_step.sim_state_schema)
109109

110110

111111
def replay_trajectory(env: gym.Env, recorded_steps: list[RecordedSimStep], headless: bool):

python/tests/test_replayer.py

Lines changed: 44 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,22 @@
33
from typing import Any
44

55
import duckdb
6-
import gymnasium as gym
7-
import mujoco as mj
86
import numpy as np
97
from rcs._core.sim import SimConfig
10-
from rcs.envs.base import RelativeTo, SimEnv
8+
from rcs.envs.base import RelativeTo
119
from rcs.envs.configs import EmptyWorldFR3Duo
1210
from rcs.envs.storage_wrapper import StorageWrapper
1311
from rcs.envs.tasks import PickTaskConfig
14-
from rcs.sim.replayer import (
15-
RecordedSimStep,
16-
load_distinct_uuids,
17-
load_trajectory,
18-
replay_trajectory,
19-
)
20-
from rcs.sim.sim import Sim
12+
from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory
2113

2214

23-
def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -> StorageWrapper:
15+
def _build_env(
16+
output_dir: Path,
17+
*,
18+
with_cameras: bool,
19+
instruction: str = "",
20+
scene_path: Path | None = None,
21+
) -> StorageWrapper:
2422
scene = EmptyWorldFR3Duo()
2523
cfg = scene.config()
2624
cfg.sim_cfg = SimConfig(async_control=True, realtime=False, frequency=30, max_convergence_steps=500)
@@ -29,6 +27,8 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -
2927
if cfg.root_frame_objects is None:
3028
cfg.root_frame_objects = {}
3129
cfg.task_cfg = PickTaskConfig(robot_name="right")
30+
if scene_path is not None:
31+
cfg.scene = str(scene_path)
3232
if not with_cameras:
3333
cfg.camera_cfgs = {}
3434
else:
@@ -50,8 +50,14 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -
5050
)
5151

5252

53-
def _record_source_dataset(dataset_dir: Path, *, limit: int, instruction: str) -> None:
54-
env = _build_env(dataset_dir, with_cameras=False, instruction=instruction)
53+
def _record_source_dataset(
54+
dataset_dir: Path,
55+
*,
56+
limit: int,
57+
instruction: str,
58+
scene_path: Path | None = None,
59+
) -> None:
60+
env = _build_env(dataset_dir, with_cameras=False, instruction=instruction, scene_path=scene_path)
5561
try:
5662
env.reset()
5763
action = {
@@ -103,9 +109,9 @@ def _replay_rows(dataset_dir: Path):
103109
connection.close()
104110

105111

106-
def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None:
112+
def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int, scene_path: Path | None = None) -> None:
107113
source_dir = output_dir.parent / "source"
108-
env = _build_env(output_dir, with_cameras=with_cameras)
114+
env = _build_env(output_dir, with_cameras=with_cameras, scene_path=scene_path)
109115
try:
110116
uuid = load_distinct_uuids(source_dir)[0]
111117
recorded_steps = load_trajectory(source_dir, uuid)[:limit]
@@ -115,48 +121,6 @@ def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None:
115121
env.close()
116122

117123

118-
MINIMAL_XML = """
119-
<mujoco>
120-
<worldbody>
121-
<camera name="main" pos="1 0 0.7" xyaxes="0 1 0 -0.5 0 1"/>
122-
<body name="box" pos="0 0 0.1">
123-
<freejoint name="box_free"/>
124-
<geom type="box" size="0.05 0.05 0.05" rgba="0.2 0.6 0.9 1"/>
125-
</body>
126-
</worldbody>
127-
</mujoco>
128-
"""
129-
130-
131-
class DummyReplayEnv(gym.Env):
132-
def __init__(self, sim: Sim):
133-
super().__init__()
134-
self.sim = sim
135-
self._replay_state = None
136-
137-
def get_wrapper_attr(self, name: str):
138-
return getattr(self, name)
139-
140-
def set_replay_state(self, state: np.ndarray, spec=None):
141-
self._replay_state = (state, spec)
142-
143-
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
144-
super().reset(seed=seed)
145-
mj.mj_resetData(self.sim.model, self.sim.data)
146-
mj.mj_forward(self.sim.model, self.sim.data)
147-
return {}, {}
148-
149-
def step(self, action: dict[str, np.ndarray]):
150-
if self._replay_state is not None:
151-
state, spec = self._replay_state
152-
self.sim.set_state(state, spec)
153-
self._replay_state = None
154-
self.sim.data.qpos[0] += float(action["delta"][0])
155-
self.sim.data.qvel[:] = 0.0
156-
mj.mj_forward(self.sim.model, self.sim.data)
157-
return {}, 0.0, False, False, {}
158-
159-
160124
def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path):
161125
tree = ET.parse(src)
162126
root = tree.getroot()
@@ -183,26 +147,6 @@ def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path):
183147
tree.write(dst)
184148

185149

186-
def _recorded_dummy_step(model_path: Path) -> RecordedSimStep:
187-
sim = Sim(model_path)
188-
state = sim.get_state().copy()
189-
state[0] = 0.125
190-
sim.set_state(state, sim.get_state_schema())
191-
return RecordedSimStep(
192-
step=0,
193-
uuid="dummy-trajectory",
194-
timestamp=None,
195-
observation={},
196-
info={
197-
SimEnv.STATE_KEY: sim.get_state(),
198-
SimEnv.STATE_SCHEMA_KEY: sim.get_state_schema(),
199-
},
200-
action={"delta": np.array([0.0], dtype=np.float64)},
201-
instruction="",
202-
success=False,
203-
)
204-
205-
206150
def _assert_nested_close(actual: Any, expected: Any, *, atol: float = 1e-6):
207151
if isinstance(expected, dict):
208152
assert isinstance(actual, dict)
@@ -296,21 +240,30 @@ def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: P
296240

297241

298242
def test_replayer_restores_sim_state_across_fixed_scene_changes(tmp_path: Path):
299-
base_model_path = tmp_path / "base.xml"
300-
base_model_path.write_text(MINIMAL_XML)
301-
modified_model_path = tmp_path / "modified.xml"
302-
_write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path)
303-
304-
for record_model_path, replay_model_path in (
305-
(base_model_path, modified_model_path),
306-
(modified_model_path, base_model_path),
307-
):
308-
recorded_step = _recorded_dummy_step(record_model_path)
309-
replay_env = DummyReplayEnv(Sim(replay_model_path))
243+
source_scene_path = Path(EmptyWorldFR3Duo().config().scene)
244+
modified_scene_path = tmp_path / "modified_scene.xml"
245+
_write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path)
310246

311-
replay_trajectory(replay_env, [recorded_step], True)
312-
313-
assert np.allclose(replay_env.sim.get_state(), recorded_step.sim_state, atol=1e-9, rtol=0)
247+
for record_scene_path, replay_scene_path in (
248+
(source_scene_path, modified_scene_path),
249+
(modified_scene_path, source_scene_path),
250+
):
251+
case_dir = tmp_path / f"{record_scene_path.stem}-to-{replay_scene_path.stem}"
252+
source_dir = case_dir / "source"
253+
replay_dir = case_dir / "replayed"
254+
255+
_record_source_dataset(source_dir, limit=3, instruction="pick up cube", scene_path=record_scene_path)
256+
_replay_prefix(replay_dir, with_cameras=False, limit=3, scene_path=replay_scene_path)
257+
258+
source_uuid = load_distinct_uuids(source_dir)[0]
259+
replay_uuid = load_distinct_uuids(replay_dir)[0]
260+
source_steps = load_trajectory(source_dir, source_uuid)
261+
replay_steps = load_trajectory(replay_dir, replay_uuid)
262+
263+
assert len(source_steps) == len(replay_steps) == 3
264+
for replay_step, source_step in zip(replay_steps, source_steps, strict=True):
265+
assert replay_step.sim_state_schema == source_step.sim_state_schema
266+
assert np.allclose(replay_step.sim_state, source_step.sim_state, atol=1e-5, rtol=0)
314267

315268

316269
def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Path):

0 commit comments

Comments
 (0)