Source code for mikasa_robo_suite.vla.memory_envs.blink_count_button_press_vla
"""Blink-counting and button-press tasks for the VLA benchmark."""
from typing import Any, Dict, List, Union
import numpy as np
import sapien
import torch
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building.actors.common import _build_by_type
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
from transforms3d.euler import euler2quat
from mikasa_robo_suite.vla.utils import shapes
[docs]
class BlinkCountButtonPressVLABaseEnv(BaseEnv):
    """Count a visual cue and reproduce it with discrete button presses.

    The robot first observes a lamp blinking a sampled number of times. After
    the cue ends, it must press the button exactly that many times. This task
    is simple to understand but sensitive to temporal memory and to clean press
    cycles, because repeated partial contacts should not be mistaken for new
    presses.

    Episode flow:
    - The lamp waits briefly, then blinks `N` times.
    - A second "phase" lamp shows its red bulb during the cue and its green
      bulb once the cue is over.
    - After the cue phase, the robot starts pressing the red button.
    - Each press must be followed by a release and lift before the next one.
    - When done counting, the robot presses the black button to submit.

    Success (`success=True`):
    - Success is produced only when the black submit button is pressed.
    - At submit time, the counted number of valid red-button presses must
      exactly match the target blink count.

    Failure (`failed=True`):
    - The number of raw red-button presses exceeds the target blink count, or
    - the submit button is pressed while the valid press count does not match
      the target.
    - `step` marks the episode as terminated on either success or failure.

    How to customize:
    - `BLINK_COUNT_RANGE` changes the memory difficulty by changing how many
      blinks the agent may need to remember.
    - `PRE_BLINK_OFF_STEPS` changes how long the task waits before cue onset.
    - `BLINK_ON_STEPS` and `BLINK_OFF_STEPS` change the timing pattern of each
      blink and therefore how easy the cue is to parse visually.
    - `BUTTON_*` parameters change the physical button geometry and the press
      detection thresholds.
    - `REQUIRED_LIFT_HEIGHT` changes how much the end effector must lift after a
      press before the next press can be counted reliably.
    """

    # Natural-language task description; returned in the `evaluate` info dict.
    LANGUAGE_INSTRUCTION = (
        "Count how many times the blue lamp blinks, press the red button exactly that many times "
        "when the red lamp turns green, then press the black button to submit your answer."
    )
    SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
    agent: Union[Panda, PandaWristCam]

    # z coordinate used to park actors (and hidden lamp bulbs) out of view.
    HEIGHT_OFFSET = 1000.0

    # Cue schedule. Each pair is an inclusive [low, high] range, sampled per
    # environment in `_initialize_episode`.
    BLINK_COUNT_RANGE: List[int] = [1, 5]
    PRE_BLINK_OFF_STEPS: List[int] = [2, 4]
    BLINK_ON_STEPS: List[int] = [1, 2]
    BLINK_OFF_STEPS: List[int] = [2, 4]

    # Button geometry.
    BUTTON_BASE_HALF_SIZE = np.array([0.065, 0.065, 0.015], dtype=np.float32)
    BUTTON_CAP_RADIUS = 0.03
    BUTTON_CAP_HALF_HEIGHT = 0.014
    # Maximum press depth; the computed depth is clamped to this value.
    BUTTON_CAP_TRAVEL = BUTTON_CAP_HALF_HEIGHT
    # Fractions of BUTTON_CAP_TRAVEL at which a press is registered and at
    # which the button counts as released again.
    BUTTON_PRESS_EVENT_RATIO = 0.35
    BUTTON_RELEASE_READY_RATIO = 0.2
    # The TCP must be within this XY distance of a button to depress it.
    BUTTON_PRESS_XY_RADIUS = 0.065
    # Height above the cap top at which the computed press depth starts to grow.
    BUTTON_PRESS_Z_MARGIN = 0.03
    # y offset of the submit button relative to the counting button.
    CONFIRM_BUTTON_Y_OFFSET = 0.16
    # A raw press is confirmed once the TCP rises this far above the height at
    # which the press started, minus LIFT_CONFIRM_TOL.
    REQUIRED_LIFT_HEIGHT = 0.1
    LIFT_CONFIRM_TOL = 0.015

    # Lamp geometry, shared by the blink lamp and the phase lamp.
    LAMP_BASE_RADIUS = 0.018
    LAMP_BASE_HALF_HEIGHT = 0.008
    LAMP_STEM_RADIUS = 0.004
    LAMP_STEM_HALF_HEIGHT = 0.020
    LAMP_BULB_RADIUS = 0.012
    # x offset of the blink lamp relative to the counting button.
    INDICATOR_FORWARD_OFFSET = 0.22
    # NOTE(review): despite the name, `_initialize_episode` subtracts this from
    # the y coordinate of the phase lamp, not x.
    PHASE_INDICATOR_X_OFFSET = 0.2
    # z coordinate assigned to both lamps.
    INDICATOR_HEIGHT = 0.0

    # RGBA colours in [0, 1].
    DEFAULT_BLINK_COLOR = np.array([0, 0, 255, 255], dtype=np.float32) / 255.0
    PHASE_WAIT_COLOR = np.array([255, 0, 0, 255], dtype=np.float32) / 255.0
    PHASE_READY_COLOR = np.array([0, 255, 0, 255], dtype=np.float32) / 255.0

    # Dense-reward coefficients (see `compute_dense_reward`).
    ACTION_L2_COEF = 0.01
    ACTION_DELTA_L2_COEF = 0.03
    QVEL_L2_COEF = 0.01
    # Reward assigned on success; also the divisor of the normalized reward.
    SUCCESS_BONUS = 30.0
    # Subtracted from the reward while `failed` is set.
    FAILURE_PENALTY = 25.0
def __init__(
    self,
    *args,
    robot_uids="panda_wristcam",
    robot_init_qpos_noise=0.02,
    blink_color=None,
    **kwargs,
):
    """Store task options and construct the underlying environment.

    Args:
        *args: Forwarded to the base-class constructor.
        robot_uids: Robot identifier forwarded to the base class.
        robot_init_qpos_noise: Standard deviation of the noise added to the
            arm joints at reset; also passed to the table scene builder.
        blink_color: Optional RGB or RGBA colour of the lit blink-lamp bulb,
            either in [0, 1] or in [0, 255]. ``None`` selects
            ``DEFAULT_BLINK_COLOR``. A colour is treated as 0-255 when any
            component exceeds 1.0.
        **kwargs: Forwarded to the base-class constructor.

    Raises:
        ValueError: If ``blink_color`` is not a 1-D sequence of 3 or 4 values.
    """
    self.robot_init_qpos_noise = robot_init_qpos_noise
    if blink_color is None:
        blink_color_arr = self.DEFAULT_BLINK_COLOR.copy()
    else:
        blink_color_arr = np.asarray(blink_color, dtype=np.float32)
        if blink_color_arr.ndim != 1 or blink_color_arr.shape[0] not in (3, 4):
            raise ValueError(
                f"blink_color must have 3 (RGB) or 4 (RGBA) components, got shape {blink_color_arr.shape}"
            )
        # Rescale BEFORE appending alpha: appending 1.0 to a 0-255 RGB colour
        # and then dividing by 255 would leave alpha at 1/255.
        if float(np.max(blink_color_arr)) > 1.0:
            blink_color_arr = blink_color_arr / 255.0
        if blink_color_arr.shape[0] == 3:
            blink_color_arr = np.concatenate([blink_color_arr, np.array([1.0], dtype=np.float32)])
    self.blink_color = blink_color_arr
    super().__init__(*args, robot_uids=robot_uids, **kwargs)
@property
def _default_sim_config(self):
    """Return the simulation config with explicit GPU buffer capacities."""
    gpu_memory = GPUMemoryConfig(
        found_lost_pairs_capacity=2**25,
        max_rigid_patch_count=2**18,
    )
    return SimConfig(gpu_memory_config=gpu_memory)
@property
def _default_sensor_configs(self):
    """Return the single 128x128 observation camera, "base_camera"."""
    camera_pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
    base_camera = CameraConfig("base_camera", camera_pose, 128, 128, np.pi / 2, 0.01, 100)
    return [base_camera]
@property
def _default_human_render_camera_configs(self):
    """Return the 512x512 camera used for human-facing renders."""
    eye = [0.6, 0.7, 0.6]
    target = [0.0, 0.0, 0.15]
    render_pose = sapien_utils.look_at(eye, target)
    return CameraConfig("render_camera", render_pose, 512, 512, 1, 0.01, 100)
def _load_agent(self, options: dict):
    """Load the robot with its base at x = -0.615 on the table frame."""
    robot_base_pose = sapien.Pose(p=[-0.615, 0, 0])
    super()._load_agent(options, robot_base_pose)
def _load_scene(self, options: dict):
    """Build the table, both buttons, both lamps, and the per-env state buffers.

    Every task actor is kinematic and is created at ``z = HEIGHT_OFFSET``;
    `_initialize_episode` assigns the per-episode poses. Actors are created
    in this order: counting-button base, counting-button cap, submit-button
    base, submit-button cap, blink lamp, phase lamp.

    Args:
        options: Not used by this method.
    """
    self.table_scene = TableSceneBuilder(
        self,
        robot_init_qpos_noise=self.robot_init_qpos_noise,
    )
    self.table_scene.build()

    parked_pose = sapien.Pose(p=[0.0, 0.0, self.HEIGHT_OFFSET])

    # Counting (red) button, then submit (black) button.
    self.button_base = self._build_button_base("button_base", [55, 64, 78, 255], parked_pose)
    self.button_cap = self._build_button_cap("button_cap", [210, 80, 80, 255], parked_pose)
    self.confirm_button_base = self._build_button_base("confirm_button_base", [40, 40, 40, 255], parked_pose)
    self.confirm_button_cap = self._build_button_cap("confirm_button_cap", [20, 20, 20, 255], parked_pose)

    # Both lamps share geometry and differ only in name and bulb colours.
    shared_lamp_kwargs = dict(
        scene=self.scene,
        body_type="kinematic",
        add_collision=False,
        initial_pose=parked_pose,
        base_radius=self.LAMP_BASE_RADIUS,
        base_half_height=self.LAMP_BASE_HALF_HEIGHT,
        stem_radius=self.LAMP_STEM_RADIUS,
        stem_half_height=self.LAMP_STEM_HALF_HEIGHT,
        bulb_radius=self.LAMP_BULB_RADIUS,
    )

    lamp_parts = shapes.build_color_switch_lamp(
        name="blink_lamp",
        bulb_on_color=self.blink_color,
        **shared_lamp_kwargs,
    )
    self.lamp_body = lamp_parts["body"]
    self.lamp_bulb_off = lamp_parts["bulb_off"]
    self.lamp_bulb_on = lamp_parts["bulb_on"]

    phase_lamp_parts = shapes.build_color_switch_lamp(
        name="phase_lamp",
        bulb_off_color=self.PHASE_WAIT_COLOR,
        bulb_on_color=self.PHASE_READY_COLOR,
        **shared_lamp_kwargs,
    )
    self.phase_lamp_body = phase_lamp_parts["body"]
    # The phase lamp's "off" bulb is the red (wait) bulb and its "on" bulb is
    # the green (ready) bulb.
    self.phase_lamp_red = phase_lamp_parts["bulb_off"]
    self.phase_lamp_green = phase_lamp_parts["bulb_on"]
    for bulb, rgba in (
        (self.phase_lamp_red, self.PHASE_WAIT_COLOR),
        (self.phase_lamp_green, self.PHASE_READY_COLOR),
    ):
        shapes._set_actor_visual_rgba(
            bulb,
            rgba,
            emission_scale=3.0,
            remove_textures=True,
        )

    self._allocate_episode_buffers()


def _build_button_base(self, name, rgba_255, initial_pose):
    """Create a kinematic box actor of size BUTTON_BASE_HALF_SIZE with the given 0-255 colour."""
    builder = self.scene.create_actor_builder()
    builder.add_box_collision(half_size=self.BUTTON_BASE_HALF_SIZE)
    builder.add_box_visual(
        half_size=self.BUTTON_BASE_HALF_SIZE,
        material=sapien.render.RenderMaterial(base_color=np.array(rgba_255) / 255.0),
    )
    return _build_by_type(
        builder,
        name=name,
        body_type="kinematic",
        initial_pose=initial_pose,
    )


def _build_button_cap(self, name, rgba_255, initial_pose):
    """Create a kinematic cylinder actor with the button-cap dimensions and the given 0-255 colour."""
    builder = self.scene.create_actor_builder()
    builder.add_cylinder_collision(radius=self.BUTTON_CAP_RADIUS, half_length=self.BUTTON_CAP_HALF_HEIGHT)
    builder.add_cylinder_visual(
        radius=self.BUTTON_CAP_RADIUS,
        half_length=self.BUTTON_CAP_HALF_HEIGHT,
        material=sapien.render.RenderMaterial(base_color=np.array(rgba_255) / 255.0),
    )
    return _build_by_type(
        builder,
        name=name,
        body_type="kinematic",
        initial_pose=initial_pose,
    )


def _allocate_episode_buffers(self):
    """Allocate the per-environment tensors that hold cue schedules and press state."""
    n = self.num_envs
    d = self.device

    def i64(*shape):
        return torch.zeros(shape, dtype=torch.int64, device=d)

    def f32(*shape):
        return torch.zeros(shape, dtype=torch.float32, device=d)

    def flag(initial=False):
        make = torch.ones if initial else torch.zeros
        return make(n, dtype=torch.bool, device=d)

    # Cue timing.
    self.cue_steps_per_env = i64(n)
    self.empty_steps_per_env = i64(n)
    self.pre_blink_steps_per_env = i64(n)
    self.max_blinks = int(self.BLINK_COUNT_RANGE[1])
    self.blink_on_schedule = i64(n, self.max_blinks)
    self.blink_off_schedule = i64(n, self.max_blinks)
    self.blink_start_steps = i64(n, self.max_blinks)

    # Press counting state.
    self.target_blinks = i64(n)
    self.press_count = i64(n)
    self.raw_press_count = i64(n)
    self.press_ready = flag(True)
    self.pending_press = flag()
    self.new_raw_press_event = flag()
    self.new_press_event = flag()
    self.new_release_event = flag()
    self.failed = flag()
    self.submit_attempted = flag()
    self.submit_success = flag()
    self.confirm_press_ready = flag(True)
    self.new_confirm_press_event = flag()

    # Button geometry caches.
    self.button_xy = f32(n, 2)
    self.button_cap_unpressed_z = f32(n)
    self.button_top_z = f32(n)
    self.button_press_depth = f32(n)
    self.confirm_button_xy = f32(n, 2)
    self.confirm_button_cap_unpressed_z = f32(n)
    self.confirm_button_top_z = f32(n)
    self.confirm_button_press_depth = f32(n)
    self.press_start_tcp_z = f32(n)

    # Visible / parked positions of the lamp bulbs.
    self.indicator_on_pos = f32(n, 3)
    self.indicator_off_pos = f32(n, 3)
    self.phase_indicator_on_pos = f32(n, 3)
    self.phase_indicator_off_pos = f32(n, 3)

    # Fixed orientation applied to both button caps.
    self.button_cap_quat = torch.tensor(euler2quat(0, np.pi / 2, 0), dtype=torch.float32, device=d)
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
    """Reset the environments in `env_idx`: place actors, sample the cue, clear state.

    Args:
        env_idx: Indices of the environments being reset.
        options: Not used by this method.

    Raises:
        NotImplementedError: If `self.robot_uids` is neither "panda" nor
            "panda_wristcam".
    """
    with torch.device(self.device):
        b = len(env_idx)
        self.table_scene.initialize(env_idx)
        env_idx = env_idx.to(self.device)
        self.task_cue = None
        self.reward_dict = None

        # Counting button: x in [-0.15, -0.05), y in [-0.12, 0.12), resting
        # on the table (z = base half height).
        # NOTE(review): poses below have batch size b, not num_envs; presumably
        # the framework applies them only to the envs being reset - confirm.
        button_xyz = torch.zeros((b, 3), device=self.device)
        button_xyz[..., 0] = torch.rand((b,), device=self.device) * 0.10 - 0.15
        button_xyz[..., 1] = (torch.rand((b,), device=self.device) - 0.5) * 0.24
        button_xyz[..., 2] = float(self.BUTTON_BASE_HALF_SIZE[2])
        button_base_q = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(b, 1)
        self.button_base.set_pose(Pose.create_from_pq(p=button_xyz, q=button_base_q))

        # The cap centre sits one cap half-height above the top of the base.
        unpressed_z = float(self.BUTTON_BASE_HALF_SIZE[2]) * 2.0 + self.BUTTON_CAP_HALF_HEIGHT
        cap_xyz = button_xyz.clone()
        cap_xyz[..., 2] = unpressed_z
        cap_q = self.button_cap_quat.unsqueeze(0).repeat(b, 1)
        self.button_cap.set_pose(Pose.create_from_pq(p=cap_xyz, q=cap_q))

        # Submit button: offset along y from the counting button, clamped to
        # [-0.28, 0.28].
        confirm_button_xyz = button_xyz.clone()
        confirm_button_xyz[..., 1] += self.CONFIRM_BUTTON_Y_OFFSET
        confirm_button_xyz[..., 1] = torch.clamp(confirm_button_xyz[..., 1], -0.28, 0.28)
        self.confirm_button_base.set_pose(Pose.create_from_pq(p=confirm_button_xyz, q=button_base_q))
        confirm_cap_xyz = confirm_button_xyz.clone()
        confirm_cap_xyz[..., 2] = unpressed_z
        self.confirm_button_cap.set_pose(Pose.create_from_pq(p=confirm_cap_xyz, q=cap_q))

        # Blink lamp: offset along x from the counting button. The "on" bulb
        # starts parked HEIGHT_OFFSET above the lamp; the "off" bulb is shown.
        indicator_on_xyz = button_xyz.clone()
        indicator_on_xyz[..., 0] += self.INDICATOR_FORWARD_OFFSET
        indicator_on_xyz[..., 2] = self.INDICATOR_HEIGHT
        indicator_off_xyz = indicator_on_xyz.clone()
        indicator_off_xyz[..., 2] += self.HEIGHT_OFFSET
        self.lamp_body.set_pose(Pose.create_from_pq(p=indicator_on_xyz, q=button_base_q))
        self.lamp_bulb_off.set_pose(Pose.create_from_pq(p=indicator_on_xyz, q=button_base_q))
        self.lamp_bulb_on.set_pose(Pose.create_from_pq(p=indicator_off_xyz, q=button_base_q))

        # Phase lamp: offset from the blink lamp; red bulb shown, green parked.
        # NOTE(review): PHASE_INDICATOR_X_OFFSET is applied to the y axis here.
        phase_indicator_on_xyz = indicator_on_xyz.clone()
        phase_indicator_on_xyz[..., 1] -= self.PHASE_INDICATOR_X_OFFSET
        phase_indicator_off_xyz = phase_indicator_on_xyz.clone()
        phase_indicator_off_xyz[..., 2] += self.HEIGHT_OFFSET
        self.phase_lamp_body.set_pose(Pose.create_from_pq(p=phase_indicator_on_xyz, q=button_base_q))
        self.phase_lamp_red.set_pose(Pose.create_from_pq(p=phase_indicator_on_xyz, q=button_base_q))
        self.phase_lamp_green.set_pose(Pose.create_from_pq(p=phase_indicator_off_xyz, q=button_base_q))

        # Cache the geometry that `evaluate` needs for press detection and
        # for swapping bulbs.
        self.button_xy[env_idx] = button_xyz[:, :2]
        self.button_cap_unpressed_z[env_idx] = unpressed_z
        self.button_top_z[env_idx] = unpressed_z + self.BUTTON_CAP_HALF_HEIGHT
        self.button_press_depth[env_idx] = 0.0
        self.confirm_button_xy[env_idx] = confirm_button_xyz[:, :2]
        self.confirm_button_cap_unpressed_z[env_idx] = unpressed_z
        self.confirm_button_top_z[env_idx] = unpressed_z + self.BUTTON_CAP_HALF_HEIGHT
        self.confirm_button_press_depth[env_idx] = 0.0
        self.press_start_tcp_z[env_idx] = 0.0
        self.indicator_on_pos[env_idx] = indicator_on_xyz
        self.indicator_off_pos[env_idx] = indicator_off_xyz
        self.phase_indicator_on_pos[env_idx] = phase_indicator_on_xyz
        self.phase_indicator_off_pos[env_idx] = phase_indicator_off_xyz

        # Sample the cue. `torch.randint` excludes `high`, hence the +1 to
        # make the configured ranges inclusive.
        blink_count = torch.randint(
            low=self.BLINK_COUNT_RANGE[0],
            high=self.BLINK_COUNT_RANGE[1] + 1,
            size=(b,),
            device=self.device,
            dtype=torch.int64,
        )
        pre_steps = torch.randint(
            low=self.PRE_BLINK_OFF_STEPS[0],
            high=self.PRE_BLINK_OFF_STEPS[1] + 1,
            size=(b,),
            device=self.device,
            dtype=torch.int64,
        )
        on_schedule = torch.randint(
            low=self.BLINK_ON_STEPS[0],
            high=self.BLINK_ON_STEPS[1] + 1,
            size=(b, self.max_blinks),
            device=self.device,
            dtype=torch.int64,
        )
        off_schedule = torch.randint(
            low=self.BLINK_OFF_STEPS[0],
            high=self.BLINK_OFF_STEPS[1] + 1,
            size=(b, self.max_blinks),
            device=self.device,
            dtype=torch.int64,
        )

        # Zero the durations of blink slots beyond each env's sampled count so
        # they add nothing to the cue length.
        blink_idx = torch.arange(self.max_blinks, device=self.device).unsqueeze(0)
        valid_blink_mask = blink_idx < blink_count.unsqueeze(1)
        on_schedule = on_schedule * valid_blink_mask.to(torch.int64)
        off_schedule = off_schedule * valid_blink_mask.to(torch.int64)
        blink_durations = on_schedule + off_schedule
        # Exclusive cumulative sum: the start offset of each blink after the
        # initial dark period.
        start_offsets = torch.cumsum(blink_durations, dim=1) - blink_durations
        blink_start_steps = pre_steps.unsqueeze(1) + start_offsets
        cue_steps = pre_steps + torch.sum(blink_durations, dim=1)

        self.target_blinks[env_idx] = blink_count
        self.pre_blink_steps_per_env[env_idx] = pre_steps
        self.blink_on_schedule[env_idx] = on_schedule
        self.blink_off_schedule[env_idx] = off_schedule
        self.blink_start_steps[env_idx] = blink_start_steps
        self.cue_steps_per_env[env_idx] = cue_steps
        self.empty_steps_per_env[env_idx] = 0

        # Clear the press / submit state machine for the reset envs.
        self.press_count[env_idx] = 0
        self.raw_press_count[env_idx] = 0
        self.press_ready[env_idx] = True
        self.pending_press[env_idx] = False
        self.new_raw_press_event[env_idx] = False
        self.new_press_event[env_idx] = False
        self.new_release_event[env_idx] = False
        self.failed[env_idx] = False
        self.submit_attempted[env_idx] = False
        self.submit_success[env_idx] = False
        self.confirm_press_ready[env_idx] = True
        self.new_confirm_press_event[env_idx] = False
        # Rebuilt from the targets of all envs, not only those being reset.
        self.oracle_info = self.target_blinks.to(torch.uint8)

        if self.robot_uids in ("panda", "panda_wristcam"):
            qpos = np.array([0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04])
            # Noise is added to every joint except the last two entries.
            qpos[:-2] += self._episode_rng.normal(0, self.robot_init_qpos_noise, len(qpos) - 2)
            self.agent.reset(qpos)
            self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
        else:
            raise NotImplementedError(self.robot_uids)

        # Reset the action history used by the reward's action-delta penalty.
        # NOTE(review): `env_idx.max()` raises on an empty `env_idx` - confirm
        # that callers never pass one.
        if hasattr(self, "_prev_action") and torch.is_tensor(self._prev_action):
            if self._prev_action.shape[0] >= int(env_idx.max().item()) + 1:
                self._prev_action[env_idx] = 0
[docs]
def evaluate(self):
    """Advance the lamps, buttons and press-counting state; return the info dict.

    This method has side effects: it moves the lamp bulbs and button caps and
    updates the per-environment press / submit state stored on `self`.

    Returns:
        dict: Per-environment tensors describing the task state (success,
        failure, phase flags, counters, one-step event flags, press depths),
        plus `task_cue`, `language_instruction`, `oracle_info` and
        `reward_dict`.
    """
    elapsed = self.elapsed_steps.to(torch.int64)
    if elapsed.dim() > 1:
        elapsed = elapsed.squeeze(-1)

    # Cue phase lasts until the sampled cue length has elapsed; everything
    # afterwards is the action phase.
    cue_mask = elapsed < self.cue_steps_per_env
    action_mask = ~cue_mask

    # The lamp is lit when the current step falls inside the "on" window of
    # any valid blink during the cue phase.
    elapsed_col = elapsed.unsqueeze(1)
    blink_idx = torch.arange(self.max_blinks, device=self.device).unsqueeze(0)
    valid_blinks = blink_idx < self.target_blinks.unsqueeze(1)
    start = self.blink_start_steps
    end = start + self.blink_on_schedule
    in_on_window = (elapsed_col >= start) & (elapsed_col < end)
    light_on = torch.any(valid_blinks & in_on_window, dim=1) & cue_mask
    shapes._set_actor_visual_rgba(
        self.lamp_bulb_on,
        self.blink_color,
        emission_scale=20.0,
        remove_textures=True,
    )

    # Swap bulbs: the bulb that should be visible goes to the lamp position,
    # the other is parked HEIGHT_OFFSET above it.
    off_pose = self.lamp_bulb_off.pose.raw_pose.clone()
    off_pose[light_on, :3] = self.indicator_off_pos[light_on]
    off_pose[~light_on, :3] = self.indicator_on_pos[~light_on]
    self.lamp_bulb_off.pose = off_pose
    on_pose = self.lamp_bulb_on.pose.raw_pose.clone()
    on_pose[light_on, :3] = self.indicator_on_pos[light_on]
    on_pose[~light_on, :3] = self.indicator_off_pos[~light_on]
    self.lamp_bulb_on.pose = on_pose

    # Phase lamp: red bulb during the cue, green bulb in the action phase.
    phase_ready = action_mask
    phase_red_pose = self.phase_lamp_red.pose.raw_pose.clone()
    phase_red_pose[phase_ready, :3] = self.phase_indicator_off_pos[phase_ready]
    phase_red_pose[~phase_ready, :3] = self.phase_indicator_on_pos[~phase_ready]
    self.phase_lamp_red.pose = phase_red_pose
    phase_green_pose = self.phase_lamp_green.pose.raw_pose.clone()
    phase_green_pose[phase_ready, :3] = self.phase_indicator_on_pos[phase_ready]
    phase_green_pose[~phase_ready, :3] = self.phase_indicator_off_pos[~phase_ready]
    self.phase_lamp_green.pose = phase_green_pose

    # Press depth is derived from the TCP position alone: how far the TCP is
    # below (cap top + Z margin), clamped to the cap travel, and zero when
    # the TCP is outside the XY press radius.
    tcp_pos = self.agent.tcp.pose.p
    tcp_xy = tcp_pos[:, :2]
    tcp_z = tcp_pos[:, 2]
    xy_dist = torch.linalg.norm(tcp_xy - self.button_xy, axis=1)
    raw_depth = self.button_top_z + self.BUTTON_PRESS_Z_MARGIN - tcp_z
    depth = torch.clamp(raw_depth, min=0.0, max=self.BUTTON_CAP_TRAVEL)
    depth = depth * (xy_dist < self.BUTTON_PRESS_XY_RADIUS).float()
    self.button_press_depth = depth
    confirm_xy_dist = torch.linalg.norm(tcp_xy - self.confirm_button_xy, axis=1)
    confirm_raw_depth = self.confirm_button_top_z + self.BUTTON_PRESS_Z_MARGIN - tcp_z
    confirm_depth = torch.clamp(confirm_raw_depth, min=0.0, max=self.BUTTON_CAP_TRAVEL)
    confirm_depth = confirm_depth * (confirm_xy_dist < self.BUTTON_PRESS_XY_RADIUS).float()
    self.confirm_button_press_depth = confirm_depth

    # Lower each kinematic cap by its press depth.
    cap_pose = self.button_cap.pose.raw_pose.clone()
    cap_pose[:, 0:2] = self.button_xy
    cap_pose[:, 2] = self.button_cap_unpressed_z - depth
    cap_pose[:, 3:7] = self.button_cap_quat.repeat(cap_pose.shape[0], 1)
    self.button_cap.pose = cap_pose
    confirm_cap_pose = self.confirm_button_cap.pose.raw_pose.clone()
    confirm_cap_pose[:, 0:2] = self.confirm_button_xy
    confirm_cap_pose[:, 2] = self.confirm_button_cap_unpressed_z - confirm_depth
    confirm_cap_pose[:, 3:7] = self.button_cap_quat.repeat(confirm_cap_pose.shape[0], 1)
    self.confirm_button_cap.pose = confirm_cap_pose

    # Separate press and release thresholds give hysteresis.
    pressed = depth >= (self.BUTTON_CAP_TRAVEL * self.BUTTON_PRESS_EVENT_RATIO)
    released = depth <= (self.BUTTON_CAP_TRAVEL * self.BUTTON_RELEASE_READY_RATIO)

    # A release re-arms the button.
    self.new_release_event = (~self.press_ready) & released & action_mask
    self.press_ready = self.press_ready | self.new_release_event

    # Raw press: button pressed while armed, in the action phase, not failed,
    # and with no earlier press still awaiting its lift. The TCP height at
    # that moment is the reference for the required lift.
    self.new_raw_press_event = pressed & self.press_ready & action_mask & (~self.failed) & (~self.pending_press)
    self.press_start_tcp_z[self.new_raw_press_event] = tcp_z[self.new_raw_press_event]
    self.raw_press_count = self.raw_press_count + self.new_raw_press_event.to(torch.int64)
    self.pending_press = self.pending_press | self.new_raw_press_event
    self.press_ready = self.press_ready & (~self.new_raw_press_event)

    # A pending press becomes a counted press once the button has been
    # released (press_ready re-armed) and the TCP has lifted to within
    # LIFT_CONFIRM_TOL of press_start_tcp_z + REQUIRED_LIFT_HEIGHT.
    lift_target_z = self.press_start_tcp_z + self.REQUIRED_LIFT_HEIGHT
    at_lift_target = tcp_z >= (lift_target_z - self.LIFT_CONFIRM_TOL)
    self.new_press_event = self.pending_press & at_lift_target & self.press_ready & action_mask & (~self.failed)
    self.press_count = self.press_count + self.new_press_event.to(torch.int64)
    self.pending_press = self.pending_press & (~self.new_press_event)

    # Submit button: only the first press in the action phase counts as a
    # submission attempt.
    confirm_pressed = confirm_depth >= (self.BUTTON_CAP_TRAVEL * self.BUTTON_PRESS_EVENT_RATIO)
    confirm_released = confirm_depth <= (self.BUTTON_CAP_TRAVEL * self.BUTTON_RELEASE_READY_RATIO)
    self.confirm_press_ready = self.confirm_press_ready | (confirm_released & action_mask)
    self.new_confirm_press_event = (
        confirm_pressed & self.confirm_press_ready & action_mask & (~self.submit_attempted)
    )
    self.confirm_press_ready = self.confirm_press_ready & (~self.new_confirm_press_event)

    # Failure: more raw presses than the target, or a submission with the
    # wrong count. Success: a submission with the right count while not failed.
    self.failed = self.failed | (self.raw_press_count > self.target_blinks)
    count_correct = self.press_count == self.target_blinks
    self.submit_success = self.submit_success | (self.new_confirm_press_event & count_correct & (~self.failed))
    self.failed = self.failed | (self.new_confirm_press_event & (~count_correct))
    self.submit_attempted = self.submit_attempted | self.new_confirm_press_event
    success = action_mask & self.submit_success

    # Vector from the TCP to the counting-button cap; read by
    # `compute_dense_reward`.
    self.obj_to_goal_pos = self.button_cap.pose.p - self.agent.tcp.pose.p
    return {
        "success": success,
        "failed": self.failed,
        "submit_attempted": self.submit_attempted,
        "submit_success": self.submit_success,
        "action_mask": action_mask,
        "phase_ready": phase_ready,
        "light_on": light_on,
        "press_count": self.press_count,
        "raw_press_count": self.raw_press_count,
        "target_blinks": self.target_blinks,
        "new_raw_press_event": self.new_raw_press_event,
        "new_press_event": self.new_press_event,
        "new_release_event": self.new_release_event,
        "new_confirm_press_event": self.new_confirm_press_event,
        "press_ready": self.press_ready,
        "pending_press": self.pending_press,
        "lift_target_z": lift_target_z,
        "at_lift_target": at_lift_target,
        "xy_dist_to_button": xy_dist,
        "press_depth": self.button_press_depth,
        "xy_dist_to_confirm_button": confirm_xy_dist,
        "confirm_press_depth": self.confirm_button_press_depth,
        "obj_to_goal_pos": self.obj_to_goal_pos,
        "task_cue": self.task_cue,
        "language_instruction": self.LANGUAGE_INSTRUCTION,
        "oracle_info": self.oracle_info,
        "reward_dict": self.reward_dict,
    }
def _get_obs_extra(self, info: Dict):
    """Return extra observations: the TCP pose, plus privileged state in state modes.

    Args:
        info: The dict produced by `evaluate` for the current step.

    Returns:
        dict: Always contains `tcp_pose`. In the "state" and "state_dict"
        observation modes it also contains button and lamp poses, counters,
        phase flags and `oracle_info`.
    """
    extra = {"tcp_pose": self.agent.tcp.pose.raw_pose}
    if self._obs_mode not in ("state", "state_dict"):
        return extra

    # Report the pose of whichever bulb is currently shown on each lamp.
    lamp_is_on = info["light_on"]
    shown_bulb_pose = self.lamp_bulb_off.pose.raw_pose.clone()
    shown_bulb_pose[lamp_is_on] = self.lamp_bulb_on.pose.raw_pose[lamp_is_on]

    ready = info["phase_ready"]
    shown_phase_pose = self.phase_lamp_red.pose.raw_pose.clone()
    shown_phase_pose[ready] = self.phase_lamp_green.pose.raw_pose[ready]

    # Insertion order is kept identical to the key order consumers may rely on.
    extra["button_base_pose"] = self.button_base.pose.raw_pose
    extra["button_cap_pose"] = self.button_cap.pose.raw_pose
    extra["confirm_button_base_pose"] = self.confirm_button_base.pose.raw_pose
    extra["confirm_button_cap_pose"] = self.confirm_button_cap.pose.raw_pose
    extra["indicator_pose"] = shown_bulb_pose
    extra["phase_indicator_pose"] = shown_phase_pose
    extra["press_count"] = self.press_count
    extra["target_blinks"] = self.target_blinks
    extra["action_mask"] = info["action_mask"]
    extra["phase_ready"] = ready
    extra["press_ready"] = info["press_ready"]
    extra["pending_press"] = info["pending_press"]
    extra["submit_attempted"] = info["submit_attempted"]
    extra["submit_success"] = info["submit_success"]
    extra["at_lift_target"] = info["at_lift_target"]
    extra["oracle_info"] = self.oracle_info
    return extra
[docs]
def step(self, action):
    """Step the environment and also terminate on task success or failure.

    The base-class `terminated` value is OR-ed with `info["success"]` and
    `info["failed"]` when those keys are present. If `terminated` is a
    tensor the combination is element-wise; otherwise it is reduced to a
    single Python bool.

    Returns:
        The `(obs, reward, terminated, truncated, info)` tuple from the base
        class, with `terminated` updated as described.
    """
    obs, reward, terminated, truncated, info = super().step(action)
    if not isinstance(info, dict):
        return obs, reward, terminated, truncated, info

    outcome_flags = (info.get("success", None), info.get("failed", None))

    if torch.is_tensor(terminated):
        terminated = terminated.to(dtype=torch.bool)
        for outcome in outcome_flags:
            if outcome is None:
                continue
            if not torch.is_tensor(outcome):
                outcome = torch.as_tensor(outcome, device=terminated.device)
            terminated = terminated | outcome.to(dtype=torch.bool)
    else:
        terminated = bool(terminated)
        for outcome in outcome_flags:
            if outcome is None:
                continue
            # `or` keeps the short-circuit: the flag is not inspected once
            # the episode is already terminated.
            terminated = terminated or (
                bool(outcome.any().item()) if torch.is_tensor(outcome) else bool(outcome)
            )

    return obs, reward, terminated, truncated, info
[docs]
def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
    """Compute the shaped per-environment reward for the current step.

    The reward is zero during the cue phase. In the action phase it combines
    reaching / pressing terms for the counting button, lift and release terms
    while a press is pending, event bonuses for raw and confirmed presses,
    reaching / pressing terms for the submit button once the count matches,
    and smoothness penalties. `FAILURE_PENALTY` is subtracted while `failed`
    is set and the reward is overwritten with `SUCCESS_BONUS` on success.

    Side effects: stores every component in `self.reward_dict` and updates
    `self._prev_action`.

    Args:
        obs: Not used by this method.
        action: The action just applied; converted to a tensor if needed.
        info: The dict produced by `evaluate` for the current step.

    Returns:
        A tensor with one reward per environment.
    """
    tcp_pos = self.agent.tcp.pose.p
    tcp_xy = tcp_pos[:, :2]
    tcp_z = tcp_pos[:, 2]
    xy_dist = torch.linalg.norm(tcp_xy - self.button_xy, axis=1)
    z_above_button = torch.clamp(tcp_z - self.button_top_z, min=0.0)
    # `obj_to_goal_pos` is refreshed by `evaluate` on every step.
    tcp_to_button_dist = torch.linalg.norm(self.obj_to_goal_pos, axis=1)
    tcp_to_confirm_dist = torch.linalg.norm(self.confirm_button_cap.pose.p - self.agent.tcp.pose.p, axis=1)
    reaching_reward = 1 - torch.tanh(8.0 * tcp_to_button_dist)
    confirm_reaching_reward = 1 - torch.tanh(8.0 * tcp_to_confirm_dist)
    press_progress_reward = torch.clamp(info["press_depth"] / self.BUTTON_CAP_TRAVEL, min=0.0, max=1.0)
    confirm_press_progress_reward = torch.clamp(
        info["confirm_press_depth"] / self.BUTTON_CAP_TRAVEL, min=0.0, max=1.0
    )

    # Fraction of the target count reached so far (1.0 when exact).
    target = torch.clamp(info["target_blinks"].float(), min=1.0)
    count_error = torch.abs(info["target_blinks"].float() - info["press_count"].float())
    count_progress = 1.0 - torch.clamp(count_error / target, min=0.0, max=1.0)

    new_raw_press_reward = info["new_raw_press_event"].float()
    new_press_reward = info["new_press_event"].float()
    new_release_reward = info["new_release_event"].float()
    new_confirm_press_reward = info["new_confirm_press_event"].float()

    # Phase gates: cue phase; action phase with no pending press (phase 1:
    # approach and press); action phase with a pending press (phase 2:
    # release and lift).
    pending_press_f = info["pending_press"].float()
    phase_ready = info["action_mask"].float()
    cue_phase = 1.0 - phase_ready
    action_phase1 = phase_ready * (1.0 - pending_press_f)
    action_phase2 = phase_ready * pending_press_f

    lift_target_z = self.press_start_tcp_z + self.REQUIRED_LIFT_HEIGHT
    lift_gap = torch.clamp(lift_target_z - tcp_z, min=0.0)
    lift_reward = 1.0 - torch.tanh(4.0 * lift_gap)
    vertical_alignment_reward = 1.0 - torch.tanh(8.0 * xy_dist)

    if not torch.is_tensor(action):
        action = torch.as_tensor(action, device=self.device)
    # (Re)start the action history when it is missing or its shape changed.
    if not hasattr(self, "_prev_action") or self._prev_action is None or self._prev_action.shape != action.shape:
        self._prev_action = torch.zeros_like(action)
    delta_action = action - self._prev_action
    action_l2 = torch.linalg.norm(action, axis=1)
    delta_action_l2 = torch.linalg.norm(delta_action, axis=1)
    # The last two joint velocities are excluded from the penalty.
    qvel_l2 = torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1)
    smooth_penalty = (
        self.ACTION_L2_COEF * torch.tanh(2.0 * action_l2)
        + self.ACTION_DELTA_L2_COEF * torch.tanh(5.0 * delta_action_l2)
        + self.QVEL_L2_COEF * torch.tanh(2.0 * qvel_l2)
    )

    # NOTE(review): the three cue-phase terms (cue_reach_reward,
    # cue_no_press_reward, early_press_penalty) are scaled by cue_phase and
    # the total is multiplied by phase_ready below, so they contribute 0 to
    # the returned reward. Left unchanged; confirm whether that is intended.
    cue_reach_reward = cue_phase * reaching_reward
    cue_no_press_reward = cue_phase * (1.0 - press_progress_reward)
    action_count_reward = action_phase1 * count_progress
    # The event bonuses are gated on the action phase only. `evaluate` has
    # already updated `pending_press` when it reports an event: it is True on
    # the step of a raw press and False on the step of a confirmed press, so
    # gating these bonuses by action_phase1 / action_phase2 would make them
    # identically zero.
    action_new_raw_press_reward = phase_ready * new_raw_press_reward
    button_reach_reward = action_phase1 * reaching_reward
    button_press_reward = action_phase1 * press_progress_reward
    phase1_lift_penalty = action_phase1 * torch.clamp(
        (z_above_button - 0.20) / self.REQUIRED_LIFT_HEIGHT, min=0.0, max=1.0
    )
    phase1_hover_penalty = action_phase1 * torch.clamp(
        (z_above_button - 0.06) / self.REQUIRED_LIFT_HEIGHT, min=0.0, max=1.0
    )
    phase2_lift_reward = action_phase2 * lift_reward
    phase2_vertical_reward = action_phase2 * vertical_alignment_reward
    phase2_release_reward = action_phase2 * new_release_reward
    confirm_cycle_reward = phase_ready * new_press_reward
    # Submit-button shaping is active only once the count matches the target.
    submit_phase = phase_ready * (info["press_count"] == info["target_blinks"]).float()
    early_press_penalty = cue_phase * press_progress_reward
    hold_down_penalty = action_phase2 * press_progress_reward

    reward = (
        0.25 * reaching_reward * phase_ready
        + 0.75 * cue_reach_reward
        + 0.75 * cue_no_press_reward
        + 1.5 * button_reach_reward
        + 2.0 * button_press_reward
        + 1.0 * action_count_reward
        + 4.0 * action_new_raw_press_reward
        + 5.0 * phase2_lift_reward
        + 2.0 * phase2_vertical_reward
        + 2.0 * phase2_release_reward
        + 6.0 * confirm_cycle_reward
        + 1.5 * submit_phase * confirm_reaching_reward
        + 2.0 * submit_phase * confirm_press_progress_reward
        + 10.0 * new_confirm_press_reward
        - 1.5 * early_press_penalty
        - 2.0 * hold_down_penalty
        - 1.5 * phase1_lift_penalty
        - 3.0 * phase1_hover_penalty
        - smooth_penalty
    )
    # No shaped reward (and no smoothness penalty) during the cue phase.
    reward = reward * phase_ready
    reward -= self.FAILURE_PENALTY * info["failed"].float()
    reward[info["success"]] = self.SUCCESS_BONUS

    self.reward_dict = {
        "reaching_reward": reaching_reward,
        "press_progress_reward": press_progress_reward,
        "count_progress": count_progress,
        "new_raw_press_reward": new_raw_press_reward,
        "new_press_reward": new_press_reward,
        "new_release_reward": new_release_reward,
        "cue_reach_reward": cue_reach_reward,
        "cue_no_press_reward": cue_no_press_reward,
        "action_phase1": action_phase1,
        "action_phase2": action_phase2,
        "button_reach_reward": button_reach_reward,
        "button_press_reward": button_press_reward,
        "action_count_reward": action_count_reward,
        "action_new_raw_press_reward": action_new_raw_press_reward,
        "phase2_lift_reward": phase2_lift_reward,
        "phase2_vertical_reward": phase2_vertical_reward,
        "phase2_release_reward": phase2_release_reward,
        "confirm_cycle_reward": confirm_cycle_reward,
        "early_press_penalty": early_press_penalty,
        "phase1_lift_penalty": phase1_lift_penalty,
        "phase1_hover_penalty": phase1_hover_penalty,
        "hold_down_penalty": hold_down_penalty,
        "tcp_to_button_dist": tcp_to_button_dist,
        "tcp_to_confirm_dist": tcp_to_confirm_dist,
        "xy_dist_to_button": xy_dist,
        "lift_gap": lift_gap,
        "confirm_reaching_reward": confirm_reaching_reward,
        "confirm_press_progress_reward": confirm_press_progress_reward,
        "new_confirm_press_reward": new_confirm_press_reward,
        "submit_phase": submit_phase,
        "press_count": info["press_count"].float(),
        "target_blinks": info["target_blinks"].float(),
        "failed": info["failed"].float(),
        "action_l2": action_l2,
        "delta_action_l2": delta_action_l2,
        "qvel_l2": qvel_l2,
        "smooth_penalty": smooth_penalty,
    }
    self._prev_action = action.detach()
    return reward
[docs]
def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
    """Return the dense reward divided by `SUCCESS_BONUS`."""
    dense_reward = self.compute_dense_reward(obs=obs, action=action, info=info)
    return dense_reward / self.SUCCESS_BONUS
# ----- Standard tasks -----
[docs]
@register_env("BlinkCountButtonPressEasy-VLA-v0", max_episode_steps=150)
class BlinkCountButtonPressEasyVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Easy variant: 1 to 3 blinks, registered with a 150-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [1, 3]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 10]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]
[docs]
@register_env("BlinkCountButtonPressMedium-VLA-v0", max_episode_steps=200)
class BlinkCountButtonPressMediumVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Medium variant: 1 to 5 blinks, registered with a 200-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [1, 5]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 10]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]
[docs]
@register_env("BlinkCountButtonPressHard-VLA-v0", max_episode_steps=300)
class BlinkCountButtonPressHardVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Hard variant: 1 to 7 blinks, registered with a 300-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [1, 7]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 10]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]
# ----- Long-horizon tasks -----
[docs]
@register_env("BlinkCountButtonPressEasy-Long-VLA-v0", max_episode_steps=1200)
class BlinkCountButtonPressEasyLongVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Long-horizon easy variant: 1 to 10 blinks, 1200-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [1, 10]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 50]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]
[docs]
@register_env("BlinkCountButtonPressMedium-Long-VLA-v0", max_episode_steps=1200)
class BlinkCountButtonPressMediumLongVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Long-horizon medium variant: 10 to 20 blinks, 1200-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [10, 20]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 50]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]
[docs]
@register_env("BlinkCountButtonPressHard-Long-VLA-v0", max_episode_steps=1200)
class BlinkCountButtonPressHardLongVLAEnv(BlinkCountButtonPressVLABaseEnv):
    """Long-horizon hard variant: 20 to 30 blinks, 1200-step episode limit."""

    # Inclusive [low, high] ranges; see the base class for their meaning.
    BLINK_COUNT_RANGE: List[int] = [20, 30]
    PRE_BLINK_OFF_STEPS: List[int] = [5, 50]
    BLINK_ON_STEPS: List[int] = [3, 3]
    BLINK_OFF_STEPS: List[int] = [3, 5]