# Source code for mikasa_robo_suite.vla.memory_envs.gather_and_recall_vla

"""Gather-and-recall VLA task: move cubes to a disc and remember a lamp flash color."""

from typing import Any, Dict, List, Union

import numpy as np
import sapien
import torch
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.building.actors.common import _build_by_type
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
from transforms3d.euler import euler2quat

from mikasa_robo_suite.vla.utils import shapes


class GatherAndRecallVLABaseEnv(BaseEnv):
    """Move cubes onto a target disc while remembering a brief lamp color flash.

    Cubes start in a cluster on one side of the table, with a target disc on
    the other side. The agent picks up cubes one by one and places them on the
    disc. While the agent is moving cubes, a signal lamp briefly flashes one of
    the ``FLASH_COLORS`` (red, green, or blue by default). After all cubes are
    placed on the disc, the agent must press the button whose color matches
    the flash. Cubes are colored with the same palette as the buttons, so cube
    color is a distractor rather than a cue.

    Flash timing: at reset a trigger count is drawn uniformly from
    ``[1, N_CUBES - 1]``. The flash fires once the number of cubes that have
    ever been detected on the disc reaches that count, so it normally happens
    after the first cube lands and before the last. With ``N_CUBES == 1`` the
    trigger count is always 1 and the flash fires when the only cube lands,
    i.e. at the same moment the press phase begins.

    Episode flow:
        1. MOVE phase: pick and place cubes onto the disc.
        2. During moving, the lamp flashes a random color once (briefly).
        3. PRESS phase: once all cubes are on the disc, press the matching button.

    Success (``success=True``):
        - All cubes detected on the disc AND the correct color button is pressed.

    Failure (``failed=True``):
        - A wrong color button is pressed after all cubes are placed.

    How to customize:
        - ``N_CUBES`` controls difficulty (more cubes = longer distraction,
          harder memory).
        - ``FLASH_DURATION_STEPS`` controls how long the lamp stays on, as an
          inclusive ``[min, max]`` range of steps.
        - ``DISC_RADIUS`` controls the target disc size.
        - ``DISC_ON_THRESH`` controls the XY distance threshold for
          cube-on-disc detection.
    """

    LANGUAGE_INSTRUCTION = (
        "Move all cubes onto the disc. A lamp will briefly flash "
        "while you work. After all cubes are placed, press the button matching the flash color."
    )

    SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
    agent: Union[Panda, PandaWristCam]

    N_CUBES: int = 5
    CUBE_HALF_SIZE: float = 0.02
    DISC_RADIUS: float = 0.12
    DISC_HALF_HEIGHT: float = 0.003
    DISC_ON_THRESH: float = 0.10
    # Hidden actors are parked this far above their visible position.
    HEIGHT_OFFSET: float = 1000.0
    # Inclusive [min, max] number of steps the lamp stays lit.
    FLASH_DURATION_STEPS: List[int] = [8, 14]

    # Flash / button colors (red, green, blue)
    FLASH_COLORS = [
        np.array([255, 0, 0, 255], dtype=np.float32) / 255.0,
        np.array([0, 255, 0, 255], dtype=np.float32) / 255.0,
        np.array([0, 0, 255, 255], dtype=np.float32) / 255.0,
    ]
    # Cube colors must match button colors exactly.
    CUBE_COLORS = FLASH_COLORS

    BUTTON_BASE_HALF_SIZE = np.array([0.04, 0.04, 0.015], dtype=np.float32)
    BUTTON_CAP_RADIUS = 0.025
    BUTTON_CAP_HALF_HEIGHT = 0.014
    BUTTON_CAP_TRAVEL = BUTTON_CAP_HALF_HEIGHT
    BUTTON_PRESS_EVENT_RATIO = 0.35
    BUTTON_PRESS_XY_RADIUS = 0.04
    BUTTON_PRESS_Z_MARGIN = 0.03
    BUTTON_SPACING = 0.14

    # Scene layout
    DISC_X_MIN = 0.00
    DISC_X_MAX = 0.08
    DISC_Y_MIN = 0.10
    DISC_Y_MAX = 0.18
    BUTTON_X_OFFSET_FROM_DISC = -0.22
    CUBE_CLUSTER_X_OFFSET = -0.08
    CUBE_CLUSTER_CENTER_Y = -0.26
    CUBE_CLUSTER_SPACING_SCALE = 4.5
    LAMP_X_OFFSET_FROM_DISC = 0.24
    LAMP_Y_OFFSET_FROM_DISC = 0.06
    LAMP_BASE_RADIUS = 0.018
    LAMP_BASE_HALF_HEIGHT = 0.008
    LAMP_STEM_RADIUS = 0.004
    LAMP_STEM_HALF_HEIGHT = 0.020
    LAMP_BULB_RADIUS = 0.012

    # TCP-to-cube distance below which a cube is treated as "held" for shaping.
    GRASP_THRESH = 0.05
    # Cubes moving faster than this are not counted as resting on the disc.
    CUBE_VEL_THRESH = 0.15

    ACTION_L2_COEF = 0.01
    ACTION_DELTA_L2_COEF = 0.03
    QVEL_L2_COEF = 0.01
    SUCCESS_BONUS = 50.0
    FAILURE_PENALTY = 25.0

    def __init__(self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs):
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sim_config(self):
        return SimConfig(
            gpu_memory_config=GPUMemoryConfig(
                found_lost_pairs_capacity=2**25,
                max_rigid_contact_count=2**21,
                max_rigid_patch_count=2**18,
            )
        )

    @property
    def _default_sensor_configs(self):
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.15])
        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))

    def _load_scene(self, options: dict):
        """Build the table, cubes, disc, buttons and lamp, and allocate state tensors."""
        self.table_scene = TableSceneBuilder(self, robot_init_qpos_noise=self.robot_init_qpos_noise)
        self.table_scene.build()
        default_hidden_pose = sapien.Pose(p=[0.0, 0.0, self.HEIGHT_OFFSET])
        n = self.num_envs
        d = self.device
        n_colors = len(self.FLASH_COLORS)

        # ── Cubes (dynamic) ──────────────────────────────────────────────
        self.cubes = []
        for i in range(self.N_CUBES):
            color = self.CUBE_COLORS[i % len(self.CUBE_COLORS)]
            cube = actors.build_cube(
                self.scene,
                half_size=self.CUBE_HALF_SIZE,
                color=color,
                name=f"gather_cube_{i}",
                body_type="dynamic",
                initial_pose=sapien.Pose(p=[0, 0, self.CUBE_HALF_SIZE]),
            )
            self.cubes.append(cube)

        # ── Target disc (kinematic) ──────────────────────────────────────
        disc_builder = self.scene.create_actor_builder()
        disc_builder.add_cylinder_collision(
            radius=self.DISC_RADIUS,
            half_length=self.DISC_HALF_HEIGHT,
        )
        disc_builder.add_cylinder_visual(
            radius=self.DISC_RADIUS,
            half_length=self.DISC_HALF_HEIGHT,
            material=sapien.render.RenderMaterial(
                base_color=np.array([160, 160, 170, 255], dtype=np.float32) / 255.0,
            ),
        )
        self.disc = _build_by_type(
            disc_builder,
            name="target_disc",
            body_type="kinematic",
            initial_pose=default_hidden_pose,
        )
        # Rotation of pi/2 about Y applied to the disc cylinder.
        self.disc_quat = torch.tensor(
            euler2quat(0, np.pi / 2, 0),
            dtype=torch.float32,
            device=d,
        )

        # ── One colored button per flash color (kinematic) ───────────────
        self.button_bases = []
        self.button_caps = []
        self.button_cap_quat = torch.tensor(
            euler2quat(0, np.pi / 2, 0),
            dtype=torch.float32,
            device=d,
        )
        for i, color in enumerate(self.FLASH_COLORS):
            base_builder = self.scene.create_actor_builder()
            base_builder.add_box_collision(half_size=self.BUTTON_BASE_HALF_SIZE)
            base_builder.add_box_visual(
                half_size=self.BUTTON_BASE_HALF_SIZE,
                material=sapien.render.RenderMaterial(
                    base_color=np.array([55, 64, 78, 255], dtype=np.float32) / 255.0,
                ),
            )
            base = _build_by_type(
                base_builder,
                name=f"btn_base_{i}",
                body_type="kinematic",
                initial_pose=default_hidden_pose,
            )
            self.button_bases.append(base)

            cap_builder = self.scene.create_actor_builder()
            cap_builder.add_cylinder_collision(
                radius=self.BUTTON_CAP_RADIUS,
                half_length=self.BUTTON_CAP_HALF_HEIGHT,
            )
            cap_builder.add_cylinder_visual(
                radius=self.BUTTON_CAP_RADIUS,
                half_length=self.BUTTON_CAP_HALF_HEIGHT,
                material=sapien.render.RenderMaterial(base_color=color),
            )
            cap = _build_by_type(
                cap_builder,
                name=f"btn_cap_{i}",
                body_type="kinematic",
                initial_pose=default_hidden_pose,
            )
            self.button_caps.append(cap)

        # ── Signal lamp (one body/off-bulb, one colored on-bulb per color) ─
        # A full lamp is built per color; only the first body/off-bulb is ever
        # shown, the extra ones are kept parked at the hidden position.
        self.lamp_bulbs_on = []
        self._extra_lamp_bodies = []
        self._extra_lamp_offs = []
        for i, color in enumerate(self.FLASH_COLORS):
            parts = shapes.build_color_switch_lamp(
                scene=self.scene,
                name=f"signal_lamp_{i}",
                body_type="kinematic",
                add_collision=False,
                initial_pose=default_hidden_pose,
                base_radius=self.LAMP_BASE_RADIUS,
                base_half_height=self.LAMP_BASE_HALF_HEIGHT,
                stem_radius=self.LAMP_STEM_RADIUS,
                stem_half_height=self.LAMP_STEM_HALF_HEIGHT,
                bulb_radius=self.LAMP_BULB_RADIUS,
                bulb_on_color=color,
            )
            shapes._set_actor_visual_rgba(
                parts["bulb_on"],
                color,
                emission_scale=20.0,
                remove_textures=True,
            )
            self.lamp_bulbs_on.append(parts["bulb_on"])
            if i == 0:
                self.lamp_body = parts["body"]
                self.lamp_bulb_off = parts["bulb_off"]
            else:
                self._extra_lamp_bodies.append(parts["body"])
                self._extra_lamp_offs.append(parts["bulb_off"])

        # ── State tensors ────────────────────────────────────────────────
        self.cubes_on_disc = torch.zeros((n, self.N_CUBES), dtype=torch.bool, device=d)
        self.flash_color = torch.zeros(n, dtype=torch.int64, device=d)
        self.flash_trigger_count = torch.zeros(n, dtype=torch.int64, device=d)
        self.flash_triggered = torch.zeros(n, dtype=torch.bool, device=d)
        self.flash_start_step = torch.zeros(n, dtype=torch.int64, device=d)
        self.flash_duration = torch.zeros(n, dtype=torch.int64, device=d)
        self.disc_xy = torch.zeros((n, 2), dtype=torch.float32, device=d)
        self.disc_place_pos = torch.zeros((n, 3), dtype=torch.float32, device=d)
        self.lamp_visible_pos = torch.zeros((n, 3), dtype=torch.float32, device=d)
        self.lamp_hidden_pos = torch.zeros((n, 3), dtype=torch.float32, device=d)
        self.buttons_xy = torch.zeros((n_colors, n, 2), dtype=torch.float32, device=d)
        self.button_top_z = torch.zeros(n, dtype=torch.float32, device=d)
        self.button_cap_unpressed_z = torch.zeros(n, dtype=torch.float32, device=d)
        self.pressed_button = torch.full((n,), -1, dtype=torch.int64, device=d)
        self.failed = torch.zeros(n, dtype=torch.bool, device=d)
        self.success_flag = torch.zeros(n, dtype=torch.bool, device=d)

    # ------------------------------------------------------------------
    # Episode initialisation
    # ------------------------------------------------------------------
    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Randomize the layout, flash color/timing and robot pose for ``env_idx``.

        All per-actor setters receive tensors with ``len(env_idx)`` rows so that
        only the environments being reset are written.
        """
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)
            env_idx = env_idx.to(self.device)
            n_colors = len(self.FLASH_COLORS)

            self.task_cue = None
            self.reward_dict = None

            # ── Reset state ──────────────────────────────────────────────
            self.cubes_on_disc[env_idx] = False
            self.flash_triggered[env_idx] = False
            self.flash_start_step[env_idx] = 0
            self.pressed_button[env_idx] = -1
            self.failed[env_idx] = False
            self.success_flag[env_idx] = False

            self.flash_color[env_idx] = torch.randint(
                0,
                n_colors,
                (b,),
                device=self.device,
                dtype=torch.int64,
            )
            # Trigger count in [1, N_CUBES - 1]. For N_CUBES=1, randint(1, 1, ...)
            # is invalid, so the upper bound is clamped and the count is always 1.
            flash_trigger_hi = max(2, self.N_CUBES)
            self.flash_trigger_count[env_idx] = torch.randint(
                1,
                flash_trigger_hi,
                (b,),
                device=self.device,
                dtype=torch.int64,
            )
            lo, hi = self.FLASH_DURATION_STEPS
            self.flash_duration[env_idx] = torch.randint(
                lo,
                hi + 1,
                (b,),
                device=self.device,
                dtype=torch.int64,
            )

            default_q = torch.tensor(
                [1.0, 0.0, 0.0, 0.0],
                device=self.device,
            ).repeat(b, 1)

            # ── Position disc (far side, farther from robot) ─────────────
            disc_xyz = torch.zeros((b, 3), device=self.device)
            disc_xyz[:, 0] = torch.rand(b, device=self.device) * (self.DISC_X_MAX - self.DISC_X_MIN) + self.DISC_X_MIN
            disc_xyz[:, 1] = torch.rand(b, device=self.device) * (self.DISC_Y_MAX - self.DISC_Y_MIN) + self.DISC_Y_MIN
            disc_xyz[:, 2] = self.DISC_HALF_HEIGHT
            disc_q = self.disc_quat.unsqueeze(0).repeat(b, 1)
            self.disc.set_pose(Pose.create_from_pq(p=disc_xyz, q=disc_q))
            self.disc_xy[env_idx] = disc_xyz[:, :2]
            self.disc_place_pos[env_idx, :2] = disc_xyz[:, :2]
            # Center height of a cube resting on top of the disc.
            self.disc_place_pos[env_idx, 2] = self.DISC_HALF_HEIGHT * 2 + self.CUBE_HALF_SIZE

            # ── Position cubes (cluster away from button row) ────────────
            cube_center_x = disc_xyz[:, 0] + self.CUBE_CLUSTER_X_OFFSET
            cube_center_y = self.CUBE_CLUSTER_CENTER_Y
            spacing = self.CUBE_HALF_SIZE * self.CUBE_CLUSTER_SPACING_SCALE
            n_rows = (self.N_CUBES + 2) // 3
            row_center = (n_rows - 1) / 2.0
            zero_vel = torch.zeros((b, 3), device=self.device)
            for i in range(self.N_CUBES):
                row = i // 3
                col = i % 3
                n_in_row = min(3, self.N_CUBES - row * 3)
                col_center = (n_in_row - 1) / 2.0
                cube_xyz = torch.zeros((b, 3), device=self.device)
                # Grid of up to 3 columns along OX; rows stack along OY.
                cube_xyz[:, 0] = cube_center_x + (col - col_center) * spacing
                cube_xyz[:, 1] = cube_center_y + (row - row_center) * spacing
                cube_xyz[:, 0] += (torch.rand(b, device=self.device) - 0.5) * 0.01
                cube_xyz[:, 1] += (torch.rand(b, device=self.device) - 0.5) * 0.01
                cube_xyz[:, 2] = self.CUBE_HALF_SIZE
                self.cubes[i].set_pose(
                    Pose.create_from_pq(p=cube_xyz, q=default_q),
                )
                # Pass exactly b rows, like every other setter here: during a
                # (partial) reset only the envs in env_idx are written, so a
                # full (num_envs, 3) tensor would not match.
                self.cubes[i].set_linear_velocity(zero_vel)
                self.cubes[i].set_angular_velocity(zero_vel)

            # ── Position buttons (closer to robot than disc) ─────────────
            cap_q = self.button_cap_quat.unsqueeze(0).repeat(b, 1)
            base_z = float(self.BUTTON_BASE_HALF_SIZE[2])
            unpressed_z = base_z * 2.0 + self.BUTTON_CAP_HALF_HEIGHT
            self.button_cap_unpressed_z[env_idx] = unpressed_z
            self.button_top_z[env_idx] = unpressed_z + self.BUTTON_CAP_HALF_HEIGHT

            btn_base_x = disc_xyz[:, 0] + self.BUTTON_X_OFFSET_FROM_DISC
            btn_base_y = disc_xyz[:, 1]
            # Buttons are centered on the disc's Y (offsets -1, 0, +1 for 3 colors).
            btn_mid = (n_colors - 1) / 2.0
            for btn_idx in range(n_colors):
                btn_xyz = torch.zeros((b, 3), device=self.device)
                btn_xyz[:, 0] = btn_base_x
                btn_xyz[:, 1] = btn_base_y + (btn_idx - btn_mid) * self.BUTTON_SPACING
                btn_xyz[:, 2] = base_z
                self.button_bases[btn_idx].set_pose(
                    Pose.create_from_pq(p=btn_xyz, q=default_q),
                )
                cap_xyz = btn_xyz.clone()
                cap_xyz[:, 2] = unpressed_z
                self.button_caps[btn_idx].set_pose(
                    Pose.create_from_pq(p=cap_xyz, q=cap_q),
                )
                self.buttons_xy[btn_idx, env_idx] = btn_xyz[:, :2]

            # ── Position lamp (farther from disc and robot) ─────────────
            lamp_xyz = torch.zeros((b, 3), device=self.device)
            lamp_xyz[:, 0] = disc_xyz[:, 0] + self.LAMP_X_OFFSET_FROM_DISC
            lamp_xyz[:, 1] = disc_xyz[:, 1] + self.LAMP_Y_OFFSET_FROM_DISC
            lamp_xyz[:, 2] = 0.0
            hidden_xyz = lamp_xyz.clone()
            hidden_xyz[:, 2] += self.HEIGHT_OFFSET
            self.lamp_visible_pos[env_idx] = lamp_xyz
            self.lamp_hidden_pos[env_idx] = hidden_xyz

            self.lamp_body.set_pose(
                Pose.create_from_pq(p=lamp_xyz, q=default_q),
            )
            self.lamp_bulb_off.set_pose(
                Pose.create_from_pq(p=lamp_xyz, q=default_q),
            )
            for on_bulb in self.lamp_bulbs_on:
                on_bulb.set_pose(
                    Pose.create_from_pq(p=hidden_xyz, q=default_q),
                )
            for body in self._extra_lamp_bodies:
                body.set_pose(
                    Pose.create_from_pq(p=hidden_xyz, q=default_q),
                )
            for off in self._extra_lamp_offs:
                off.set_pose(
                    Pose.create_from_pq(p=hidden_xyz, q=default_q),
                )

            # ── Oracle info (flash colour index) ─────────────────────────
            self.oracle_info = self.flash_color.to(torch.uint8)

            # ── Reset robot ──────────────────────────────────────────────
            if self.robot_uids in ("panda", "panda_wristcam"):
                qpos = np.array(
                    [0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04],
                )
                # Noise on the arm joints only; the two gripper joints stay open.
                qpos[:-2] += self._episode_rng.normal(
                    0,
                    self.robot_init_qpos_noise,
                    len(qpos) - 2,
                )
                self.agent.reset(qpos)
                self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
            else:
                raise NotImplementedError(self.robot_uids)

            # Clear the smoothness-penalty history for the reset envs.
            if hasattr(self, "_prev_action") and torch.is_tensor(self._prev_action):
                if self._prev_action.shape[0] >= int(env_idx.max().item()) + 1:
                    self._prev_action[env_idx] = 0

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _button_press_depth(self, btn_idx: int, tcp_xy: torch.Tensor, tcp_z: torch.Tensor) -> torch.Tensor:
        """Return the per-env press depth of button ``btn_idx`` in ``[0, BUTTON_CAP_TRAVEL]``.

        Depth grows as the TCP descends below ``button_top_z + BUTTON_PRESS_Z_MARGIN``
        and is zero unless the TCP is within ``BUTTON_PRESS_XY_RADIUS`` of the
        button center in XY. Shared by ``evaluate`` (cap animation / press
        events) and ``compute_dense_reward`` (press shaping).
        """
        xy_dist = torch.linalg.norm(tcp_xy - self.buttons_xy[btn_idx], dim=1)
        raw_depth = self.button_top_z + self.BUTTON_PRESS_Z_MARGIN - tcp_z
        depth = torch.clamp(raw_depth, min=0.0, max=self.BUTTON_CAP_TRAVEL)
        return depth * (xy_dist < self.BUTTON_PRESS_XY_RADIUS).float()

    # ------------------------------------------------------------------
    # Evaluate
    # ------------------------------------------------------------------
    def evaluate(self):
        """Update task state (cubes, flash, lamp, buttons) and return the info dict.

        Side effects: updates sticky ``cubes_on_disc``, flash bookkeeping,
        lamp/button-cap poses, ``pressed_button``, ``success_flag``, ``failed``
        and ``obj_to_goal_pos``.
        """
        elapsed = self.elapsed_steps.to(torch.int64)
        if elapsed.dim() > 1:
            elapsed = elapsed.squeeze(-1)
        n_colors = len(self.FLASH_COLORS)

        # ── Cube-on-disc detection (sticky + current) ────────────────────
        current_cubes_on_disc = torch.zeros(
            (self.num_envs, self.N_CUBES),
            dtype=torch.bool,
            device=self.device,
        )
        for i in range(self.N_CUBES):
            cube_pos = self.cubes[i].pose.p
            xy_dist = torch.linalg.norm(
                cube_pos[:, :2] - self.disc_xy,
                dim=1,
            )
            # NOTE(review): the height bound is only 0.5 m, so a slow cube held
            # above the disc also counts as "on disc" — confirm this is intended.
            z_on_table = cube_pos[:, 2] < 0.5
            cube_vel = torch.linalg.norm(
                self.cubes[i].linear_velocity,
                dim=1,
            )
            on_disc = (xy_dist < self.DISC_ON_THRESH) & z_on_table & (cube_vel < self.CUBE_VEL_THRESH)
            current_cubes_on_disc[:, i] = on_disc
            self.cubes_on_disc[:, i] = self.cubes_on_disc[:, i] | on_disc

        # Sticky count drives flash triggering and progress; success uses the
        # current (per-step) status.
        n_on_disc = self.cubes_on_disc.sum(dim=1)
        all_on_disc = current_cubes_on_disc.all(dim=1)

        # ── Flash triggering ─────────────────────────────────────────────
        trigger = (~self.flash_triggered) & (n_on_disc >= self.flash_trigger_count)
        self.flash_triggered = self.flash_triggered | trigger
        self.flash_start_step[trigger] = elapsed[trigger]
        flash_elapsed = elapsed - self.flash_start_step
        flash_active = self.flash_triggered & (flash_elapsed >= 0) & (flash_elapsed < self.flash_duration)

        # ── Lamp control ─────────────────────────────────────────────────
        # During the flash the off-bulb is hidden and the matching on-bulb shown.
        off_pose = self.lamp_bulb_off.pose.raw_pose.clone()
        off_pose[flash_active, :3] = self.lamp_hidden_pos[flash_active]
        off_pose[~flash_active, :3] = self.lamp_visible_pos[~flash_active]
        self.lamp_bulb_off.pose = off_pose

        for color_idx in range(n_colors):
            color_mask = flash_active & (self.flash_color == color_idx)
            on_pose = self.lamp_bulbs_on[color_idx].pose.raw_pose.clone()
            on_pose[color_mask, :3] = self.lamp_visible_pos[color_mask]
            on_pose[~color_mask, :3] = self.lamp_hidden_pos[~color_mask]
            self.lamp_bulbs_on[color_idx].pose = on_pose

        # ── Button press detection ───────────────────────────────────────
        tcp_pos = self.agent.tcp.pose.p
        tcp_xy = tcp_pos[:, :2]
        tcp_z = tcp_pos[:, 2]

        for btn_idx in range(n_colors):
            btn_xy = self.buttons_xy[btn_idx]
            depth = self._button_press_depth(btn_idx, tcp_xy, tcp_z)

            # Visual cap depression
            cap_pose = self.button_caps[btn_idx].pose.raw_pose.clone()
            cap_pose[:, 0:2] = btn_xy
            cap_pose[:, 2] = self.button_cap_unpressed_z - depth
            cap_pose[:, 3:7] = self.button_cap_quat.repeat(
                cap_pose.shape[0],
                1,
            )
            self.button_caps[btn_idx].pose = cap_pose

            # Detect press (only after all cubes placed, first press only)
            pressed = depth >= (self.BUTTON_CAP_TRAVEL * self.BUTTON_PRESS_EVENT_RATIO)
            new_press = pressed & all_on_disc & (self.pressed_button == -1)
            self.pressed_button[new_press] = btn_idx

        # ── Success / failure ────────────────────────────────────────────
        button_pressed_mask = self.pressed_button >= 0
        correct = button_pressed_mask & (self.pressed_button == self.flash_color) & all_on_disc
        wrong = button_pressed_mask & (self.pressed_button != self.flash_color)
        self.success_flag = self.success_flag | correct
        self.failed = self.failed | wrong
        success = self.success_flag & all_on_disc

        # ── Reaching target (obj_to_goal_pos) ────────────────────────────
        obj_to_goal = torch.zeros_like(tcp_pos)

        # Move phase: reach toward the nearest cube not currently on the disc,
        # or toward the disc when one is (approximately) held. The current
        # status is used rather than the sticky flag so a cube knocked off the
        # disc is targeted again instead of leaving a zero (max-reward) target.
        min_dist = torch.full(
            (self.num_envs,),
            float("inf"),
            device=self.device,
        )
        holding_any = torch.zeros(
            self.num_envs,
            dtype=torch.bool,
            device=self.device,
        )
        for i in range(self.N_CUBES):
            unplaced = ~current_cubes_on_disc[:, i] & ~all_on_disc
            if not unplaced.any():
                continue
            cube_pos = self.cubes[i].pose.p
            diff = cube_pos - tcp_pos
            dist = torch.linalg.norm(diff, dim=1)
            # "Grasped" is approximated by TCP proximity.
            grasped = unplaced & (dist < self.GRASP_THRESH)
            holding_any = holding_any | grasped
            reach_mask = unplaced & ~grasped
            closer = reach_mask & (dist < min_dist)
            obj_to_goal[closer] = diff[closer]
            min_dist = torch.where(closer, dist, min_dist)

        # If holding a cube, target the disc instead
        if holding_any.any():
            disc_target = self.disc_place_pos - tcp_pos
            obj_to_goal[holding_any] = disc_target[holding_any]

        # Press phase: target the correct button
        for btn_idx in range(n_colors):
            btn_mask = all_on_disc & (self.flash_color == btn_idx) & ~button_pressed_mask
            if btn_mask.any():
                btn_pos = torch.zeros(self.num_envs, 3, device=self.device)
                btn_pos[:, :2] = self.buttons_xy[btn_idx]
                btn_pos[:, 2] = self.button_top_z + 0.005
                obj_to_goal[btn_mask] = (btn_pos - tcp_pos)[btn_mask]

        self.obj_to_goal_pos = obj_to_goal

        return {
            "success": success,
            "failed": self.failed,
            "n_on_disc": n_on_disc,
            "all_on_disc": all_on_disc,
            "flash_active": flash_active,
            "flash_triggered": self.flash_triggered,
            "flash_color": self.flash_color,
            "pressed_button": self.pressed_button,
            "obj_to_goal_pos": obj_to_goal,
            "task_cue": self.task_cue,
            "language_instruction": self.LANGUAGE_INSTRUCTION,
            "oracle_info": self.oracle_info,
            "reward_dict": self.reward_dict,
        }

    # ------------------------------------------------------------------
    # Observation extras
    # ------------------------------------------------------------------
    def _get_obs_extra(self, info: Dict):
        """Return the TCP pose, plus privileged task state in state obs modes."""
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self._obs_mode in ["state", "state_dict"]:
            obs.update(
                obj_to_goal_pos=self.obj_to_goal_pos,
                oracle_info=self.oracle_info,
                disc_pose=self.disc.pose.raw_pose,
                cubes_on_disc=self.cubes_on_disc,
                n_on_disc=info["n_on_disc"],
                all_on_disc=info["all_on_disc"],
                flash_active=info["flash_active"],
                flash_color=self.flash_color,
                pressed_button=self.pressed_button,
            )
            for i in range(self.N_CUBES):
                obs[f"cube_{i}_pose"] = self.cubes[i].pose.raw_pose
        return obs

    # ------------------------------------------------------------------
    # Step override (terminate on success or failure)
    # ------------------------------------------------------------------
    def step(self, action):
        """Step the base env, then also terminate on ``info["success"]`` or ``info["failed"]``."""
        obs, reward, terminated, truncated, info = super().step(action)
        if isinstance(info, dict):
            success = info.get("success", None)
            failed = info.get("failed", None)
            if torch.is_tensor(terminated):
                terminated = terminated.to(dtype=torch.bool)
                if torch.is_tensor(success):
                    terminated = terminated | success.to(dtype=torch.bool)
                if torch.is_tensor(failed):
                    terminated = terminated | failed.to(dtype=torch.bool)
            else:
                terminated = bool(terminated)
                if success is not None:
                    if torch.is_tensor(success):
                        terminated = terminated or bool(success.any().item())
                    else:
                        terminated = terminated or bool(success)
                if failed is not None:
                    if torch.is_tensor(failed):
                        terminated = terminated or bool(failed.any().item())
                    else:
                        terminated = terminated or bool(failed)
        return obs, reward, terminated, truncated, info

    # ------------------------------------------------------------------
    # Dense reward
    # ------------------------------------------------------------------
    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Shaped reward: reaching + placement progress + correct-button press - smoothness.

        Failed envs receive ``-FAILURE_PENALTY`` and successful envs
        ``SUCCESS_BONUS``, overriding the shaped value. Also stores the
        components in ``self.reward_dict`` and the action in ``self._prev_action``.
        """
        n_on_disc = info["n_on_disc"].float()
        all_on_disc_f = info["all_on_disc"].float()

        tcp_to_obj_dist = torch.linalg.norm(self.obj_to_goal_pos, dim=1)
        reaching_reward = 1 - torch.tanh(8.0 * tcp_to_obj_dist)

        # Progress reward: fraction of cubes placed
        place_progress = n_on_disc / self.N_CUBES

        # Button press reward (correct button depth)
        button_press_reward = torch.zeros(self.num_envs, device=self.device)
        tcp_xy = self.agent.tcp.pose.p[:, :2]
        tcp_z = self.agent.tcp.pose.p[:, 2]
        for btn_idx in range(len(self.FLASH_COLORS)):
            depth = self._button_press_depth(btn_idx, tcp_xy, tcp_z)
            depth_norm = torch.clamp(
                depth / self.BUTTON_CAP_TRAVEL,
                min=0.0,
                max=1.0,
            )
            correct_mask = (self.flash_color == btn_idx) & info["all_on_disc"]
            button_press_reward += correct_mask.float() * depth_norm

        # Smoothness penalty
        if not torch.is_tensor(action):
            action = torch.as_tensor(action, device=self.device)
        if not hasattr(self, "_prev_action") or self._prev_action is None or self._prev_action.shape != action.shape:
            self._prev_action = torch.zeros_like(action)
        delta_action = action - self._prev_action
        action_l2 = torch.linalg.norm(action, dim=1)
        delta_action_l2 = torch.linalg.norm(delta_action, dim=1)
        # Arm joints only; the last two entries are excluded.
        qvel_l2 = torch.linalg.norm(
            self.agent.robot.get_qvel()[..., :-2],
            dim=1,
        )
        smooth_penalty = (
            self.ACTION_L2_COEF * torch.tanh(2.0 * action_l2)
            + self.ACTION_DELTA_L2_COEF * torch.tanh(5.0 * delta_action_l2)
            + self.QVEL_L2_COEF * torch.tanh(2.0 * qvel_l2)
        )

        reward = (
            2.0 * reaching_reward
            + 10.0 * place_progress
            + 3.0 * all_on_disc_f * reaching_reward
            + 5.0 * all_on_disc_f * button_press_reward
            - smooth_penalty
        )
        reward[info["failed"]] = -self.FAILURE_PENALTY
        reward[info["success"]] = self.SUCCESS_BONUS

        self.reward_dict = {
            "reaching_reward": reaching_reward,
            "place_progress": place_progress,
            "n_on_disc": n_on_disc,
            "button_press_reward": button_press_reward,
            "smooth_penalty": smooth_penalty,
            "action_l2": action_l2,
            "delta_action_l2": delta_action_l2,
            "qvel_l2": qvel_l2,
        }
        self._prev_action = action.detach()
        return reward

    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Dense reward scaled by ``SUCCESS_BONUS``."""
        return self.compute_dense_reward(obs=obs, action=action, info=info) / self.SUCCESS_BONUS
# ── Standard variants ────────────────────────────────────────────────────
@register_env("GatherAndRecall1-VLA-v0", max_episode_steps=200)
class GatherAndRecall1VLAEnv(GatherAndRecallVLABaseEnv):
    """Gather-and-recall with a single cube (200-step limit).

    With one cube the flash trigger count is always 1, so the lamp flashes
    when the only cube lands on the disc, i.e. as the press phase begins.
    """

    N_CUBES = 1
@register_env("GatherAndRecall3-VLA-v0", max_episode_steps=400)
class GatherAndRecall3VLAEnv(GatherAndRecallVLABaseEnv):
    """Gather-and-recall with 3 cubes (400-step limit).

    The flash fires after 1 or 2 cubes have reached the disc.
    """

    N_CUBES = 3
@register_env("GatherAndRecall5-VLA-v0", max_episode_steps=600)
class GatherAndRecall5VLAEnv(GatherAndRecallVLABaseEnv):
    """Gather-and-recall with 5 cubes (600-step limit).

    The flash fires after 1 to 4 cubes have reached the disc.
    """

    N_CUBES = 5
@register_env("GatherAndRecall7-VLA-v0", max_episode_steps=800)
class GatherAndRecall7VLAEnv(GatherAndRecallVLABaseEnv):
    """Gather-and-recall with 7 cubes (800-step limit).

    The flash fires after 1 to 6 cubes have reached the disc.
    """

    N_CUBES = 7
@register_env("GatherAndRecall9-VLA-v0", max_episode_steps=1000)
class GatherAndRecall9VLAEnv(GatherAndRecallVLABaseEnv):
    """Gather-and-recall with 9 cubes (1000-step limit).

    The flash fires after 1 to 8 cubes have reached the disc.
    """

    N_CUBES = 9