Skip to content

Commit 330e711

Browse files
committed
fix: avoid numpy truth ambiguity in pick task
1 parent 69c9365 commit 330e711

2 files changed

Lines changed: 64 additions & 6 deletions

File tree

python/rcs/envs/tasks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ def __init__(self, env, robot_name: str, shared2world: rcs.common.Pose, obj_join
3737
def step(self, action: dict[str, Any]): # type: ignore
3838
obs, reward, _, truncated, info = super().step(action)
3939

40-
if (
41-
self._gripper.get_normalized_width() > 0.01
42-
and self._gripper.get_normalized_width() < 0.99
43-
and obs[self.robot_name]["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED
44-
):
40+
gripper_closed = obs[self.robot_name]["gripper"][0] == GripperWrapper.BINARY_GRIPPER_CLOSED[0]
41+
42+
if self._gripper.get_normalized_width() > 0.01 and self._gripper.get_normalized_width() < 0.99 and gripper_closed:
4543
self._gripper_closing += 1
4644
else:
4745
self._gripper_closing = 0
@@ -61,7 +59,7 @@ def step(self, action: dict[str, Any]): # type: ignore
6159
# NOTE: 4 depends on the time passing between each step.
6260
is_grasped = (
6361
self._gripper_closing >= 4 # gripper is closing since more than 4 steps
64-
and obs[self.robot_name]["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED # command is still close
62+
and gripper_closed # command is still close
6563
and tcp_to_obj_dist <= 0.01 # tcp to cube center is max 1cm
6664
)
6765
success = obj_to_goal_dist <= 0.022 and info[self.robot_name]["is_grasped"]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import warnings
2+
3+
import gymnasium as gym
4+
import numpy as np
5+
6+
import rcs
7+
from rcs.envs.tasks import PickObjSuccessWrapper
8+
9+
10+
class _DummyJoint:
11+
def __init__(self):
12+
self.qpos = np.array([0.0, 0.0, 0.0])
13+
14+
15+
class _DummyData:
16+
def __init__(self):
17+
self._joint = _DummyJoint()
18+
self.qvel = np.zeros(1)
19+
20+
def joint(self, _name: str):
21+
return self._joint
22+
23+
24+
class _DummySim:
25+
def __init__(self):
26+
self.data = _DummyData()
27+
28+
29+
class _DummyGripper:
30+
def get_normalized_width(self):
31+
return 0.5
32+
33+
34+
class _DummyEnv(gym.Env):
35+
def __init__(self):
36+
super().__init__()
37+
self.sim = _DummySim()
38+
self.gripper = {"robot": _DummyGripper()}
39+
40+
def get_wrapper_attr(self, name: str):
41+
return getattr(self, name)
42+
43+
def step(self, _action):
44+
obs = {"robot": {"gripper": np.array([0.0]), "tquat": np.zeros(7)}}
45+
info = {"robot": {"is_grasped": False}}
46+
return obs, 0.0, False, False, info
47+
48+
def reset(self, *, seed=None, options=None):
49+
super().reset(seed=seed)
50+
return {}, {}
51+
52+
53+
def test_pick_obj_success_wrapper_step_avoids_numpy_truth_ambiguity():
54+
wrapper = PickObjSuccessWrapper(_DummyEnv(), "robot", rcs.common.Pose())
55+
56+
with warnings.catch_warnings(record=True) as caught:
57+
warnings.simplefilter("error", DeprecationWarning)
58+
wrapper.step({})
59+
60+
assert not caught

0 commit comments

Comments
 (0)