# Source code for mikasa_robo_suite.vla.memory_envs.batteries_checker_hard_vla

"""Batteries Checker tasks for the VLA memory benchmark."""

from typing import Any, Dict, Union

import numpy as np
import sapien
import torch
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building.actors.common import _build_by_type
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
from transforms3d.euler import euler2quat

from mikasa_robo_suite.vla.utils import shapes


class BatteriesCheckerVLABaseEnv(BaseEnv):
    """Battery-testing task with memory over repeated check cycles.

    The scene contains a tray of batteries, a socket with a lamp, and a button.
    The robot must test batteries one at a time, remember which ones worked,
    and only then move on to the next candidate. Because each battery must be
    returned to its original slot before the next confirmation, the task mixes
    memory with careful sequential manipulation.

    Episode flow:
        - Pick one battery from the tray and insert it into the socket.
        - Read the lamp outcome: lit means the battery is working.
        - Remove the battery and return it to its original tray slot.
        - Press the button to mark that this battery has been checked.

    Success (`success=True`):
        - Every working battery must be discovered through the full
          insert-return-confirm procedure. Partial progress does not count as
          success.

    How to customize:
        - `ACTIVE_BATTERY_COUNT` controls how many batteries are present in the
          episode and therefore how long the search can become.
        - `WORKING_BATTERY_COUNT` controls how many of those batteries are true
          positives that the agent must eventually identify.
        - `SOCKET_INSERT_XY_TOL` and `SOCKET_INSERT_Z_TOL` control how precisely
          a battery must be placed before the environment counts it as inserted.
        - `SLOT_RETURN_XY_TOL` and `SLOT_RETURN_Z_TOL` control how accurately
          the battery must be put back into its home slot.
        - `BUTTON_*` parameters control the size, travel, and press thresholds
          of the confirmation button.
        - `LAMP_AFTERGLOW_STEPS` controls how long the lamp remains visibly on
          after a successful working-battery test.

    Implementation notes:
        - The per-environment state machine (`action_stage`) cycles through
          `STAGE_INSERT -> STAGE_RETURN -> STAGE_CONFIRM -> STAGE_INSERT`; all
          transitions happen inside :meth:`evaluate`.
        - Objects that should not be visible (inactive batteries, the unused
          lamp bulb) are parked `HEIGHT_OFFSET` metres above the scene instead
          of being removed.
    """

    LANGUAGE_INSTRUCTION = "Find all working batteries by inserting each one into the socket, observing the lamp result, returning it from the socket to its initial slot, and then pressing the button to confirm."

    SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
    agent: Union[Panda, PandaWristCam]

    # Z offset used to park hidden actors far away from the workspace.
    HEIGHT_OFFSET = 1000.0

    # --- Tray / battery layout ------------------------------------------------
    TRAY_ROWS = 5
    TRAY_COLS = 3
    NUM_BATTERIES = TRAY_ROWS * TRAY_COLS
    ACTIVE_BATTERY_COUNT = 15
    WORKING_BATTERY_COUNT = 3
    SLOT_SPACING_X = 0.052
    SLOT_SPACING_Y = 0.050
    BATTERY_RADIUS = 0.010
    BATTERY_HALF_HEIGHT = 0.030
    BATTERY_COLOR = np.array([70, 190, 90, 255], dtype=np.float32) / 255.0
    BATTERY_STATIC_FRICTION = 2.0
    BATTERY_DYNAMIC_FRICTION = 2.0
    BATTERY_RESTITUTION = 0.0
    TRAY_HALF_HEIGHT = BATTERY_HALF_HEIGHT * 0.5
    TRAY_PADDING_X = 0.028
    TRAY_PADDING_Y = 0.026
    TRAY_COLOR = np.array([72, 86, 108, 255], dtype=np.float32) / 255.0
    SLOT_VISUAL_COLOR = np.array([42, 48, 62, 255], dtype=np.float32) / 255.0

    # --- Socket -----------------------------------------------------------------
    SOCKET_HALF_SIZE = np.array([0.048, 0.040, BATTERY_HALF_HEIGHT * 0.5], dtype=np.float32)
    SOCKET_SLOT_RADIUS = BATTERY_RADIUS * 1.6
    SOCKET_COLOR = np.array([88, 92, 98, 255], dtype=np.float32) / 255.0
    SOCKET_SLOT_COLOR = np.array([32, 32, 36, 255], dtype=np.float32) / 255.0
    SOCKET_X_OFFSET_FROM_TRAY = 0.2

    # --- Button -----------------------------------------------------------------
    BUTTON_X_OFFSET_FROM_TRAY = 0.00
    BUTTON_Y_OFFSET_FROM_TRAY = 0.24
    BUTTON_BASE_HALF_SIZE = np.array([0.075, 0.075, 0.015], dtype=np.float32)
    BUTTON_CAP_RADIUS = 0.033
    BUTTON_CAP_HALF_HEIGHT = 0.014
    BUTTON_CAP_TRAVEL = BUTTON_CAP_HALF_HEIGHT
    # Fraction of the travel at which a press event fires / re-arms.
    BUTTON_PRESS_EVENT_RATIO = 0.35
    BUTTON_RELEASE_READY_RATIO = 0.2
    BUTTON_PRESS_XY_RADIUS = 0.075
    BUTTON_PRESS_Z_MARGIN = 0.030

    # --- Placement tolerances ---------------------------------------------------
    SOCKET_INSERT_XY_TOL = 0.010
    SOCKET_INSERT_Z_TOL = 0.025
    SLOT_RETURN_XY_TOL = 0.023
    SLOT_RETURN_Z_TOL = 0.024

    # --- Lamp -------------------------------------------------------------------
    LAMP_BASE_RADIUS = 0.018
    LAMP_BASE_HALF_HEIGHT = 0.008
    LAMP_STEM_RADIUS = 0.004
    LAMP_STEM_HALF_HEIGHT = 0.020
    LAMP_BULB_RADIUS = 0.012
    LAMP_X_OFFSET_FROM_SOCKET = 0.08
    LAMP_Y_OFFSET_FROM_SOCKET = 0.0
    LAMP_HEIGHT = 0.0
    LAMP_OFF_COLOR = np.array([255, 255, 255, 255], dtype=np.float32) / 255.0
    LAMP_ON_COLOR = np.array([255, 236, 110, 255], dtype=np.float32) / 255.0
    LAMP_AFTERGLOW_STEPS = 7

    # --- Stage identifiers of the per-env state machine -------------------------
    STAGE_INSERT = 0
    STAGE_RETURN = 1
    STAGE_CONFIRM = 2

    # --- Reward shaping coefficients --------------------------------------------
    ACTION_L2_COEF = 0.01
    ACTION_DELTA_L2_COEF = 0.02
    QVEL_L2_COEF = 0.01
    SUCCESS_BONUS = 40.0

    def __init__(
        self,
        *args,
        robot_uids="panda_wristcam",
        robot_init_qpos_noise=0.02,
        **kwargs,
    ):
        """Store the reset noise scale and delegate construction to `BaseEnv`.

        Args:
            robot_uids: Robot identifier forwarded to `BaseEnv`.
            robot_init_qpos_noise: Standard deviation of the Gaussian noise
                added to the arm joints at reset.
        """
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sim_config(self):
        """Simulation config with enlarged GPU contact/patch buffers."""
        return SimConfig(
            gpu_memory_config=GPUMemoryConfig(
                found_lost_pairs_capacity=2**25,
                max_rigid_patch_count=2**20,
                max_rigid_contact_count=2**21,
            )
        )

    @property
    def _default_sensor_configs(self):
        """Single 128x128 observation camera named ``base_camera``."""
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        """512x512 camera named ``render_camera`` used for human rendering."""
        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.15])
        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        """Load the robot with its base at x = -0.615."""
        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))

    def _build_tray(self, initial_pose: sapien.Pose):
        """Build the visual-only kinematic tray with one dark disc per slot."""
        tray_half_x = (self.TRAY_ROWS - 1) * self.SLOT_SPACING_X * 0.5 + self.TRAY_PADDING_X
        tray_half_y = (self.TRAY_COLS - 1) * self.SLOT_SPACING_Y * 0.5 + self.TRAY_PADDING_Y

        builder = self.scene.create_actor_builder()
        # No collision shape is added: the tray is purely a visual marker.
        builder.add_box_visual(
            half_size=[tray_half_x, tray_half_y, self.TRAY_HALF_HEIGHT],
            material=sapien.render.RenderMaterial(base_color=self.TRAY_COLOR),
        )
        for r in range(self.TRAY_ROWS):
            for c in range(self.TRAY_COLS):
                local_x = (r - (self.TRAY_ROWS - 1) * 0.5) * self.SLOT_SPACING_X
                local_y = (c - (self.TRAY_COLS - 1) * 0.5) * self.SLOT_SPACING_Y
                slot_pose = sapien.Pose(
                    p=[local_x, local_y, self.TRAY_HALF_HEIGHT + 0.0015],
                    q=euler2quat(0, np.pi / 2, 0),
                )
                builder.add_cylinder_visual(
                    pose=slot_pose,
                    radius=self.BATTERY_RADIUS * 1.2,
                    half_length=0.0015,
                    material=sapien.render.RenderMaterial(base_color=self.SLOT_VISUAL_COLOR),
                )
        return _build_by_type(
            builder,
            name="battery_tray",
            body_type="kinematic",
            initial_pose=initial_pose,
        )

    def _build_socket_box(self, initial_pose: sapien.Pose):
        """Build the kinematic socket: a floor plus four walls around a cavity."""
        builder = self.scene.create_actor_builder()
        outer_hx, outer_hy, outer_hz = [float(v) for v in self.SOCKET_HALF_SIZE]

        # Cavity half-extents: slightly wider than the slot radius, but always
        # leaving room for walls and never collapsing below 3 mm.
        inner_hx = min(self.SOCKET_SLOT_RADIUS * 1.15, outer_hx - 0.004)
        inner_hy = min(self.SOCKET_SLOT_RADIUS * 1.15, outer_hy - 0.004)
        inner_hx = max(inner_hx, 0.003)
        inner_hy = max(inner_hy, 0.003)

        # Floor of the cavity.
        bottom_hz = max(0.004, outer_hz * 0.35)
        builder.add_box_collision(
            half_size=[outer_hx, outer_hy, bottom_hz],
            pose=sapien.Pose(p=[0.0, 0.0, -outer_hz + bottom_hz]),
        )

        # Four walls: +x, -x, +y, -y.
        wall_x_hx = max(0.002, 0.5 * (outer_hx - inner_hx))
        wall_y_hy = max(0.002, 0.5 * (outer_hy - inner_hy))
        wall_hz = outer_hz
        builder.add_box_collision(
            half_size=[wall_x_hx, outer_hy, wall_hz],
            pose=sapien.Pose(p=[inner_hx + wall_x_hx, 0.0, 0.0]),
        )
        builder.add_box_collision(
            half_size=[wall_x_hx, outer_hy, wall_hz],
            pose=sapien.Pose(p=[-(inner_hx + wall_x_hx), 0.0, 0.0]),
        )
        builder.add_box_collision(
            half_size=[inner_hx, wall_y_hy, wall_hz],
            pose=sapien.Pose(p=[0.0, inner_hy + wall_y_hy, 0.0]),
        )
        builder.add_box_collision(
            half_size=[inner_hx, wall_y_hy, wall_hz],
            pose=sapien.Pose(p=[0.0, -(inner_hy + wall_y_hy), 0.0]),
        )

        builder.add_box_visual(
            half_size=self.SOCKET_HALF_SIZE,
            material=sapien.render.RenderMaterial(base_color=self.SOCKET_COLOR),
        )
        slot_visual_pose = sapien.Pose(
            p=[0.0, 0.0, self.SOCKET_HALF_SIZE[2] + 0.0015],
            q=euler2quat(0, np.pi / 2, 0),
        )
        builder.add_cylinder_visual(
            pose=slot_visual_pose,
            radius=self.SOCKET_SLOT_RADIUS,
            half_length=0.0015,
            material=sapien.render.RenderMaterial(base_color=self.SOCKET_SLOT_COLOR),
        )
        return _build_by_type(
            builder,
            name="battery_socket_box",
            body_type="kinematic",
            initial_pose=initial_pose,
        )

    def _build_button(self, initial_pose: sapien.Pose):
        """Build the button as two kinematic actors and return `(base, cap)`."""
        base_builder = self.scene.create_actor_builder()
        base_builder.add_box_collision(half_size=self.BUTTON_BASE_HALF_SIZE)
        base_builder.add_box_visual(
            half_size=self.BUTTON_BASE_HALF_SIZE,
            material=sapien.render.RenderMaterial(base_color=np.array([58, 66, 80, 255]) / 255.0),
        )
        button_base = _build_by_type(
            base_builder,
            name="battery_button_base",
            body_type="kinematic",
            initial_pose=initial_pose,
        )

        cap_builder = self.scene.create_actor_builder()
        cap_builder.add_cylinder_collision(
            radius=self.BUTTON_CAP_RADIUS, half_length=self.BUTTON_CAP_HALF_HEIGHT
        )
        cap_builder.add_cylinder_visual(
            radius=self.BUTTON_CAP_RADIUS,
            half_length=self.BUTTON_CAP_HALF_HEIGHT,
            material=sapien.render.RenderMaterial(base_color=np.array([205, 110, 110, 255]) / 255.0),
        )
        button_cap = _build_by_type(
            cap_builder,
            name="battery_button_cap",
            body_type="kinematic",
            initial_pose=initial_pose,
        )
        return button_base, button_cap

    def _load_scene(self, options: dict):
        """Create all actors (parked at the hidden pose) and allocate state buffers."""
        self.table_scene = TableSceneBuilder(self, robot_init_qpos_noise=self.robot_init_qpos_noise)
        self.table_scene.build()

        hidden_pose = sapien.Pose(p=[0.0, 0.0, self.HEIGHT_OFFSET])
        self.tray = self._build_tray(hidden_pose)
        self.socket_box = self._build_socket_box(hidden_pose)
        self.button_base, self.button_cap = self._build_button(hidden_pose)

        lamp_parts = shapes.build_color_switch_lamp(
            scene=self.scene,
            name="battery_checker_lamp",
            body_type="kinematic",
            add_collision=False,
            initial_pose=hidden_pose,
            base_radius=self.LAMP_BASE_RADIUS,
            base_half_height=self.LAMP_BASE_HALF_HEIGHT,
            stem_radius=self.LAMP_STEM_RADIUS,
            stem_half_height=self.LAMP_STEM_HALF_HEIGHT,
            bulb_radius=self.LAMP_BULB_RADIUS,
            bulb_off_color=self.LAMP_OFF_COLOR,
            bulb_on_color=self.LAMP_ON_COLOR,
        )
        self.lamp_body = lamp_parts["body"]
        self.lamp_bulb_off = lamp_parts["bulb_off"]
        self.lamp_bulb_on = lamp_parts["bulb_on"]
        shapes._set_actor_visual_rgba(
            self.lamp_bulb_on, self.LAMP_ON_COLOR, emission_scale=3.0, remove_textures=True
        )

        battery_mat = sapien.physx.PhysxMaterial(
            static_friction=self.BATTERY_STATIC_FRICTION,
            dynamic_friction=self.BATTERY_DYNAMIC_FRICTION,
            restitution=self.BATTERY_RESTITUTION,
        )
        # All NUM_BATTERIES actors are always created; unused ones stay hidden.
        self.batteries = []
        for i in range(self.NUM_BATTERIES):
            builder = self.scene.create_actor_builder()
            builder.add_cylinder_collision(
                radius=self.BATTERY_RADIUS,
                half_length=self.BATTERY_HALF_HEIGHT,
                material=battery_mat,
            )
            builder.add_cylinder_visual(
                radius=self.BATTERY_RADIUS,
                half_length=self.BATTERY_HALF_HEIGHT,
                material=sapien.render.RenderMaterial(base_color=self.BATTERY_COLOR),
            )
            battery = _build_by_type(
                builder,
                name=f"battery_{i}",
                body_type="dynamic",
                initial_pose=hidden_pose,
            )
            self.batteries.append(battery)

        n = self.num_envs
        d = self.device

        # State machine and bookkeeping, one row per parallel environment.
        self.action_stage = torch.zeros(n, dtype=torch.int64, device=d)
        self.active_battery_idx = torch.full((n,), -1, dtype=torch.int64, device=d)
        self.working_mask = torch.zeros((n, self.NUM_BATTERIES), dtype=torch.bool, device=d)
        self.checked_mask = torch.zeros((n, self.NUM_BATTERIES), dtype=torch.bool, device=d)
        self.found_working_mask = torch.zeros((n, self.NUM_BATTERIES), dtype=torch.bool, device=d)
        self.active_mask = torch.zeros((n, self.NUM_BATTERIES), dtype=torch.bool, device=d)
        self.active_battery_count = torch.zeros(n, dtype=torch.int64, device=d)
        self.target_working_count = torch.zeros(n, dtype=torch.int64, device=d)
        self.checked_count = torch.zeros(n, dtype=torch.int64, device=d)
        self.found_working_count = torch.zeros(n, dtype=torch.int64, device=d)
        oracle_width = int(np.clip(self.WORKING_BATTERY_COUNT, 1, self.NUM_BATTERIES))
        self.oracle_info = torch.zeros((n, oracle_width), dtype=torch.uint8, device=d)

        # One-step event flags, refreshed on every evaluate() call.
        self.new_insert_event = torch.zeros(n, dtype=torch.bool, device=d)
        self.new_return_event = torch.zeros(n, dtype=torch.bool, device=d)
        self.new_confirm_event = torch.zeros(n, dtype=torch.bool, device=d)
        self.new_button_press_event = torch.zeros(n, dtype=torch.bool, device=d)
        # True for envs whose confirm on this step registered a new working
        # battery; recorded in evaluate() because active_battery_idx is cleared
        # there before the reward is computed.
        self.new_found_working_event = torch.zeros(n, dtype=torch.bool, device=d)

        # Cached per-episode geometry.
        self.tray_slot_positions = torch.zeros((n, self.NUM_BATTERIES, 3), device=d)
        self.socket_slot_pos = torch.zeros((n, 3), device=d)
        self.lamp_on_pos = torch.zeros((n, 3), device=d)
        self.lamp_off_pos = torch.zeros((n, 3), device=d)
        self.button_xy = torch.zeros((n, 2), device=d)
        self.button_cap_unpressed_z = torch.zeros(n, device=d)
        self.button_top_z = torch.zeros(n, device=d)
        self.button_press_depth = torch.zeros(n, device=d)
        self.lamp_afterglow_steps = torch.zeros(n, dtype=torch.int64, device=d)
        self.press_ready = torch.ones(n, dtype=torch.bool, device=d)

        self.button_cap_quat = torch.tensor(euler2quat(0, np.pi / 2, 0), dtype=torch.float32, device=d)
        self.battery_quat = torch.tensor(euler2quat(0, np.pi / 2, 0), dtype=torch.float32, device=d)

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Randomize the layout and reset all task state for the envs in `env_idx`."""
        with torch.device(self.device):
            b = len(env_idx)
            env_idx = env_idx.to(self.device)
            self.table_scene.initialize(env_idx)
            self.task_cue = None
            self.reward_dict = None

            # --- Tray ---------------------------------------------------------
            tray_xyz = torch.zeros((b, 3), device=self.device)
            tray_xyz[:, 0] = torch.rand((b,), device=self.device) * 0.08 - 0.09
            tray_xyz[:, 1] = torch.rand((b,), device=self.device) * 0.10 - 0.05
            tray_xyz[:, 2] = self.TRAY_HALF_HEIGHT
            tray_q = torch.tensor(
                euler2quat(0, 0, np.pi / 2), dtype=torch.float32, device=self.device
            ).repeat(b, 1)
            self.tray.set_pose(Pose.create_from_pq(p=tray_xyz, q=tray_q))

            row_offsets = (
                torch.arange(self.TRAY_ROWS, device=self.device, dtype=torch.float32)
                - (self.TRAY_ROWS - 1) * 0.5
            ) * self.SLOT_SPACING_X
            col_offsets = (
                torch.arange(self.TRAY_COLS, device=self.device, dtype=torch.float32)
                - (self.TRAY_COLS - 1) * 0.5
            ) * self.SLOT_SPACING_Y
            rr, cc = torch.meshgrid(row_offsets, col_offsets, indexing="ij")
            local_xy = torch.stack([rr.reshape(-1), cc.reshape(-1)], dim=1)

            # The tray is yawed by 90 degrees, so tray-local (x, y) maps to
            # world (-y, x) around the tray centre.
            tray_slot_pos = torch.zeros((b, self.NUM_BATTERIES, 3), device=self.device)
            tray_slot_pos[:, :, 0] = tray_xyz[:, 0:1] - local_xy[:, 1].unsqueeze(0)
            tray_slot_pos[:, :, 1] = tray_xyz[:, 1:2] + local_xy[:, 0].unsqueeze(0)
            tray_slot_pos[:, :, 2] = self.TRAY_HALF_HEIGHT * 2.0
            self.tray_slot_positions[env_idx] = tray_slot_pos

            # --- Which batteries take part in this episode --------------------
            active_count = int(np.clip(self.ACTIVE_BATTERY_COUNT, 1, self.NUM_BATTERIES))
            working_count = int(np.clip(self.WORKING_BATTERY_COUNT, 1, active_count))
            active_mask_local = torch.zeros((b, self.NUM_BATTERIES), dtype=torch.bool, device=self.device)
            active_mask_local[:, :active_count] = True
            self.active_mask[env_idx] = active_mask_local
            self.active_battery_count[env_idx] = active_count
            self.target_working_count[env_idx] = working_count

            hidden_pos = torch.zeros((b, 3), device=self.device)
            hidden_pos[:, 2] = self.HEIGHT_OFFSET
            for i, battery in enumerate(self.batteries):
                pos_i = tray_slot_pos[:, i, :]
                present_i = active_mask_local[:, i]
                # Inactive batteries are parked at the hidden pose.
                pos_i = torch.where(present_i.unsqueeze(-1), pos_i, hidden_pos)
                q_i = self.battery_quat.unsqueeze(0).repeat(b, 1)
                battery.set_pose(Pose.create_from_pq(p=pos_i, q=q_i))
                battery.set_linear_velocity(torch.zeros((b, 3), device=self.device))
                battery.set_angular_velocity(torch.zeros((b, 3), device=self.device))

            # --- Socket -------------------------------------------------------
            socket_xyz = torch.zeros((b, 3), device=self.device)
            socket_xyz[:, 0] = tray_xyz[:, 0] + self.SOCKET_X_OFFSET_FROM_TRAY
            socket_xyz[:, 1] = tray_xyz[:, 1]
            socket_xyz[:, 2] = float(self.SOCKET_HALF_SIZE[2])
            self.socket_box.set_pose(Pose.create_from_pq(p=socket_xyz, q=tray_q))
            socket_slot_pos = socket_xyz.clone()
            socket_slot_pos[:, 2] = float(self.SOCKET_HALF_SIZE[2]) * 2.0
            self.socket_slot_pos[env_idx] = socket_slot_pos

            # --- Lamp: the "off" bulb starts visible, the "on" bulb hidden ---
            lamp_pos = socket_xyz.clone()
            lamp_pos[:, 0] += self.LAMP_X_OFFSET_FROM_SOCKET
            lamp_pos[:, 1] += self.LAMP_Y_OFFSET_FROM_SOCKET
            lamp_pos[:, 2] = self.LAMP_HEIGHT
            lamp_off = lamp_pos.clone()
            lamp_off[:, 2] += self.HEIGHT_OFFSET
            self.lamp_body.set_pose(Pose.create_from_pq(p=lamp_pos, q=tray_q))
            self.lamp_bulb_off.set_pose(Pose.create_from_pq(p=lamp_pos, q=tray_q))
            self.lamp_bulb_on.set_pose(Pose.create_from_pq(p=lamp_off, q=tray_q))
            # lamp_on_pos is the visible bulb location, lamp_off_pos the parked one.
            self.lamp_on_pos[env_idx] = lamp_pos
            self.lamp_off_pos[env_idx] = lamp_off

            # --- Button -------------------------------------------------------
            button_xyz = torch.zeros((b, 3), device=self.device)
            button_xyz[:, 0] = tray_xyz[:, 0] + self.BUTTON_X_OFFSET_FROM_TRAY
            button_xyz[:, 1] = tray_xyz[:, 1] + self.BUTTON_Y_OFFSET_FROM_TRAY
            button_xyz[:, 2] = float(self.BUTTON_BASE_HALF_SIZE[2])
            self.button_base.set_pose(Pose.create_from_pq(p=button_xyz, q=tray_q))

            unpressed_z = float(self.BUTTON_BASE_HALF_SIZE[2]) * 2.0 + self.BUTTON_CAP_HALF_HEIGHT
            cap_xyz = button_xyz.clone()
            cap_xyz[:, 2] = unpressed_z
            cap_q = self.button_cap_quat.unsqueeze(0).repeat(b, 1)
            self.button_cap.set_pose(Pose.create_from_pq(p=cap_xyz, q=cap_q))

            self.button_xy[env_idx] = button_xyz[:, :2]
            self.button_cap_unpressed_z[env_idx] = unpressed_z
            self.button_top_z[env_idx] = unpressed_z + self.BUTTON_CAP_HALF_HEIGHT
            self.button_press_depth[env_idx] = 0.0
            self.lamp_afterglow_steps[env_idx] = 0

            # --- Task state ---------------------------------------------------
            self.action_stage[env_idx] = self.STAGE_INSERT
            self.active_battery_idx[env_idx] = -1
            # Inactive batteries are pre-marked as checked so they are never
            # treated as pending candidates.
            self.checked_mask[env_idx] = ~active_mask_local
            self.found_working_mask[env_idx] = False
            self.checked_count[env_idx] = 0
            self.found_working_count[env_idx] = 0
            self.press_ready[env_idx] = True

            # Pick `working_count` distinct active batteries at random: inactive
            # ones get score -1 so top-k can never select them.
            rand_scores = torch.rand((b, self.NUM_BATTERIES), device=self.device)
            rand_scores = torch.where(
                active_mask_local,
                rand_scores,
                torch.full_like(rand_scores, -1.0),
            )
            working_idx = torch.topk(
                rand_scores, k=working_count, dim=1, largest=True, sorted=False
            ).indices
            working_mask_local = torch.zeros(
                (b, self.NUM_BATTERIES), dtype=torch.bool, device=self.device
            )
            working_mask_local.scatter_(1, working_idx, True)
            self.working_mask[env_idx] = working_mask_local
            # 255 marks unused oracle columns.
            self.oracle_info[env_idx] = 255
            self.oracle_info[env_idx, :working_count] = working_idx.to(torch.uint8)

            self.new_insert_event[env_idx] = False
            self.new_return_event[env_idx] = False
            self.new_confirm_event[env_idx] = False
            self.new_button_press_event[env_idx] = False
            self.new_found_working_event[env_idx] = False

            # --- Robot --------------------------------------------------------
            if self.robot_uids in ("panda", "panda_wristcam"):
                qpos = np.array(
                    [0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04]
                )
                # Noise is applied to the arm joints only, not the two fingers.
                qpos[:-2] += self._episode_rng.normal(0, self.robot_init_qpos_noise, len(qpos) - 2)
                self.agent.reset(qpos)
                self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
            else:
                raise NotImplementedError(self.robot_uids)

            # Forget the previous action of the reset envs so the action-delta
            # penalty does not leak across episodes. An empty env_idx is a no-op
            # (max() on an empty tensor would raise).
            if (
                hasattr(self, "_prev_action")
                and torch.is_tensor(self._prev_action)
                and env_idx.numel() > 0
            ):
                if self._prev_action.shape[0] >= int(env_idx.max().item()) + 1:
                    self._prev_action[env_idx] = 0

    def _stack_battery_poses(self):
        """Return `(positions, raw_poses, grasp_flags)` stacked over batteries on dim 1."""
        pos = torch.stack([bat.pose.p for bat in self.batteries], dim=1)
        raw_pose = torch.stack([bat.pose.raw_pose for bat in self.batteries], dim=1)
        grasp = torch.stack([self.agent.is_grasping(bat) for bat in self.batteries], dim=1)
        return pos, raw_pose, grasp

    def _snap_battery_to_socket(self, env_i: int, bat_id: int):
        """Place battery `bat_id` of env `env_i` exactly at the socket centre, at rest.

        Only row `env_i` of the pose and velocity tensors is modified; the same
        battery in every other parallel environment is written back unchanged.
        """
        battery = self.batteries[bat_id]

        raw_pose = battery.pose.raw_pose.clone()
        raw_pose[env_i, :3] = self.socket_slot_pos[env_i]
        raw_pose[env_i, 3:7] = self.battery_quat
        battery.pose = raw_pose

        lin_vel = battery.linear_velocity.clone()
        ang_vel = battery.angular_velocity.clone()
        lin_vel[env_i] = 0.0
        ang_vel[env_i] = 0.0
        battery.set_linear_velocity(lin_vel)
        battery.set_angular_velocity(ang_vel)

    def evaluate(self):
        """Advance the task state machine and return the info dictionary.

        Detects insert, return and confirm events, moves the lamp bulbs and the
        button cap accordingly, updates the checked / found bookkeeping and
        computes `success`.

        Returns:
            dict: success flag, masks, stage, counters, one-step event flags,
            button depth, TCP-to-socket vector, language instruction, oracle
            info and the most recent reward breakdown.
        """
        cue_mask = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
        action_mask = torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

        self.new_insert_event[:] = False
        self.new_return_event[:] = False
        self.new_confirm_event[:] = False
        self.new_found_working_event[:] = False

        battery_pos, battery_raw_pose, grasp_mask = self._stack_battery_poses()

        # --- Socket occupancy -------------------------------------------------
        socket_xy = self.socket_slot_pos[:, :2]
        socket_z = self.socket_slot_pos[:, 2]
        dist_xy_to_socket = torch.linalg.norm(battery_pos[:, :, :2] - socket_xy.unsqueeze(1), dim=2)
        dist_z_to_socket = torch.abs(battery_pos[:, :, 2] - socket_z.unsqueeze(1))
        in_socket = (
            (dist_xy_to_socket <= self.SOCKET_INSERT_XY_TOL)
            & (dist_z_to_socket <= self.SOCKET_INSERT_Z_TOL)
            & (~grasp_mask)
            & self.active_mask
        )
        socket_has_battery = in_socket.any(dim=1)
        # argmax over a 0/1 mask: index of the first battery in the socket
        # (0 when the socket is empty, hence the separate `socket_has_battery`).
        socket_battery_idx = torch.argmax(in_socket.to(torch.int64), dim=1)
        socket_has_working = (in_socket & self.working_mask).any(dim=1)

        self.lamp_afterglow_steps = torch.clamp(
            self.lamp_afterglow_steps - action_mask.to(torch.int64), min=0
        )

        # --- Stage INSERT -> RETURN ------------------------------------------
        stage_insert_mask = action_mask & (self.action_stage == self.STAGE_INSERT)
        insert_events = stage_insert_mask & socket_has_battery
        self.new_insert_event = insert_events
        snapped_any = bool(insert_events.any())
        if snapped_any:
            ins_idx = torch.where(insert_events)[0]
            self.action_stage[ins_idx] = self.STAGE_RETURN
            self.active_battery_idx[ins_idx] = socket_battery_idx[ins_idx]
            working_ins_idx = ins_idx[socket_has_working[ins_idx]]
            if working_ins_idx.numel() > 0:
                self.lamp_afterglow_steps[working_ins_idx] = self.LAMP_AFTERGLOW_STEPS

            # Snap inserted battery to exact socket center so extraction is reliable.
            for ei in ins_idx.tolist():
                self._snap_battery_to_socket(ei, int(socket_battery_idx[ei].item()))

        # --- Lamp: swap the visible bulb by exchanging bulb positions --------
        lamp_on = action_mask & (socket_has_working | (self.lamp_afterglow_steps > 0))
        lamp_off_pose = self.lamp_bulb_off.pose.raw_pose.clone()
        lamp_on_pose = self.lamp_bulb_on.pose.raw_pose.clone()
        lamp_off_pose[lamp_on, :3] = self.lamp_off_pos[lamp_on]
        lamp_off_pose[~lamp_on, :3] = self.lamp_on_pos[~lamp_on]
        lamp_on_pose[lamp_on, :3] = self.lamp_on_pos[lamp_on]
        lamp_on_pose[~lamp_on, :3] = self.lamp_off_pos[~lamp_on]
        self.lamp_bulb_off.pose = lamp_off_pose
        self.lamp_bulb_on.pose = lamp_on_pose

        # --- Button: the cap depth follows the TCP height above the button ---
        tcp_pos = self.agent.tcp.pose.p
        tcp_xy = tcp_pos[:, :2]
        tcp_z = tcp_pos[:, 2]
        button_xy_dist = torch.linalg.norm(tcp_xy - self.button_xy, dim=1)
        raw_depth = self.button_top_z + self.BUTTON_PRESS_Z_MARGIN - tcp_z
        depth = torch.clamp(raw_depth, min=0.0, max=self.BUTTON_CAP_TRAVEL)
        depth = depth * (button_xy_dist <= self.BUTTON_PRESS_XY_RADIUS).float()
        self.button_press_depth = depth

        cap_pose = self.button_cap.pose.raw_pose.clone()
        cap_pose[:, 0:2] = self.button_xy
        cap_pose[:, 2] = self.button_cap_unpressed_z - depth
        cap_pose[:, 3:7] = self.button_cap_quat.repeat(cap_pose.shape[0], 1)
        self.button_cap.pose = cap_pose

        # Hysteresis: a press fires once, then the button must be released
        # below the release threshold before it can fire again.
        pressed = depth >= (self.BUTTON_CAP_TRAVEL * self.BUTTON_PRESS_EVENT_RATIO)
        released = depth <= (self.BUTTON_CAP_TRAVEL * self.BUTTON_RELEASE_READY_RATIO)
        self.press_ready = self.press_ready | ((~self.press_ready) & released & action_mask)
        self.new_button_press_event = pressed & self.press_ready & action_mask
        self.press_ready = self.press_ready & (~self.new_button_press_event)

        # --- Stage RETURN -> CONFIRM -----------------------------------------
        stage_return_mask = action_mask & (self.action_stage == self.STAGE_RETURN)
        if stage_return_mask.any():
            r_idx = torch.where(stage_return_mask)[0]
            active_idx = self.active_battery_idx[r_idx]
            valid = active_idx >= 0
            if valid.any():
                r_idx = r_idx[valid]
                active_idx = active_idx[valid]
                cur_pos = battery_pos[r_idx, active_idx]
                home_pos = self.tray_slot_positions[r_idx, active_idx]
                return_xy = torch.linalg.norm(cur_pos[:, :2] - home_pos[:, :2], dim=1)
                return_z = torch.abs(cur_pos[:, 2] - home_pos[:, 2])
                not_grasped = ~grasp_mask[r_idx, active_idx]
                not_in_socket = ~in_socket[r_idx, active_idx]
                returned = (
                    (return_xy <= self.SLOT_RETURN_XY_TOL)
                    & (return_z <= self.SLOT_RETURN_Z_TOL)
                    & not_grasped
                    & not_in_socket
                )
                if returned.any():
                    returned_idx = r_idx[returned]
                    self.new_return_event[returned_idx] = True
                    self.action_stage[returned_idx] = self.STAGE_CONFIRM

        # --- Stage CONFIRM -> INSERT -----------------------------------------
        stage_confirm_mask = action_mask & (self.action_stage == self.STAGE_CONFIRM)
        confirm_events = stage_confirm_mask & self.new_button_press_event
        self.new_confirm_event = confirm_events
        if confirm_events.any():
            c_idx = torch.where(confirm_events)[0]
            active_idx = self.active_battery_idx[c_idx]
            valid = active_idx >= 0
            if valid.any():
                c_idx = c_idx[valid]
                active_idx = active_idx[valid]
                already_checked = self.checked_mask[c_idx, active_idx]
                newly_checked = ~already_checked
                newly_found = self.working_mask[c_idx, active_idx] & newly_checked
                self.checked_mask[c_idx, active_idx] = True
                self.found_working_mask[c_idx, active_idx] = (
                    self.found_working_mask[c_idx, active_idx] | newly_found
                )
                # Remember the discovery for the reward: active_battery_idx is
                # cleared just below, so it cannot be reconstructed later.
                self.new_found_working_event[c_idx] = newly_found
                self.checked_count[c_idx] = self.checked_count[c_idx] + newly_checked.to(torch.int64)
                self.action_stage[c_idx] = self.STAGE_INSERT
                self.active_battery_idx[c_idx] = -1

        self.found_working_count = torch.sum(self.found_working_mask, dim=1).to(torch.int64)

        # Success requires all checked batteries to be back near their home slots.
        # Use relaxed tolerances (3x stage transition thresholds) — battery is
        # clearly on the tray but may be slightly shifted after placement.
        success_return_xy_tol = self.SLOT_RETURN_XY_TOL * 3.0
        success_return_z_tol = self.SLOT_RETURN_Z_TOL * 3.0
        # Re-read positions only if a battery was snapped during this call.
        if snapped_any:
            current_pos = torch.stack([bat.pose.p for bat in self.batteries], dim=1)
        else:
            current_pos = battery_pos
        home_offset = current_pos - self.tray_slot_positions
        too_far = (torch.linalg.norm(home_offset[:, :, :2], dim=2) > success_return_xy_tol) | (
            torch.abs(home_offset[:, :, 2]) > success_return_z_tol
        )
        # Only check active batteries (inactive ones are pre-marked checked but hidden).
        checked_active = self.checked_mask & self.active_mask
        all_returned = ~(checked_active & too_far).any(dim=1)

        all_checked = self.checked_count >= self.active_battery_count
        success = (
            action_mask
            & (self.found_working_count >= self.target_working_count)
            & all_checked
            & all_returned
        )

        # NOTE: despite its name this is the vector from the TCP to the socket slot.
        self.obj_to_goal_pos = self.socket_slot_pos - self.agent.tcp.pose.p
        self.battery_raw_pose = battery_raw_pose
        self.socket_battery_idx = socket_battery_idx
        self.socket_has_working = socket_has_working

        return {
            "success": success,
            "action_mask": action_mask,
            "cue_mask": cue_mask,
            "action_stage": self.action_stage,
            "active_battery_idx": self.active_battery_idx,
            "socket_has_battery": socket_has_battery,
            "socket_battery_idx": socket_battery_idx,
            "socket_has_working": socket_has_working,
            "active_battery_count": self.active_battery_count,
            "checked_count": self.checked_count,
            "found_working_count": self.found_working_count,
            "target_working_count": self.target_working_count,
            "all_checked": all_checked,
            "new_insert_event": self.new_insert_event,
            "new_return_event": self.new_return_event,
            "new_confirm_event": self.new_confirm_event,
            "new_button_press_event": self.new_button_press_event,
            "button_press_depth": self.button_press_depth,
            "obj_to_goal_pos": self.obj_to_goal_pos,
            "task_cue": self.task_cue,
            "language_instruction": self.LANGUAGE_INSTRUCTION,
            "oracle_info": self.oracle_info,
            "reward_dict": self.reward_dict,
        }

    def _get_obs_extra(self, info: Dict):
        """Return the TCP pose and, in state observation modes, privileged scene state."""
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self._obs_mode in ["state", "state_dict"]:
            # battery_raw_pose is cached by evaluate().
            battery_pose_flat = self.battery_raw_pose.reshape(self.num_envs, -1)
            obs.update(
                battery_pose=battery_pose_flat,
                tray_pose=self.tray.pose.raw_pose,
                socket_pose=self.socket_box.pose.raw_pose,
                button_base_pose=self.button_base.pose.raw_pose,
                button_cap_pose=self.button_cap.pose.raw_pose,
                lamp_off_pose=self.lamp_bulb_off.pose.raw_pose,
                lamp_on_pose=self.lamp_bulb_on.pose.raw_pose,
                action_stage=info["action_stage"],
                action_mask=info["action_mask"],
                active_mask=self.active_mask,
                oracle_info=self.oracle_info,
            )
        return obs

    def step(self, action):
        """Step the base environment and force `terminated` wherever `success` is set."""
        obs, reward, terminated, truncated, info = super().step(action)
        if isinstance(info, dict) and "success" in info:
            success = info["success"]
            if torch.is_tensor(success):
                success = success.to(dtype=torch.bool)
                if torch.is_tensor(terminated):
                    terminated = terminated.to(dtype=torch.bool) | success
                else:
                    terminated = bool(terminated) or bool(success.any().item())
            else:
                terminated = bool(terminated) or bool(success)
        return obs, reward, terminated, truncated, info

    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Compute the stage-dependent shaped reward for the current step.

        The reward combines stage-gated shaping terms (reach/carry while
        inserting, return distance while returning, button reach/press while
        confirming), one-off event bonuses, a penalty for button presses that
        do not confirm anything, and a smoothness penalty. Successful
        environments receive exactly `SUCCESS_BONUS`.

        Args:
            obs: Unused; kept for the environment reward interface.
            action: Action applied on this step, one row per environment.
            info: Dictionary produced by :meth:`evaluate` for this step.

        Returns:
            torch.Tensor: One reward value per environment. The individual
            terms are also stored in `self.reward_dict`.
        """
        env_range = torch.arange(self.num_envs, device=self.device)
        tcp_pos = self.agent.tcp.pose.p
        # Stage flags are read after evaluate() applied this step's transitions.
        stage_insert = (info["action_stage"] == self.STAGE_INSERT).float()
        stage_return = (info["action_stage"] == self.STAGE_RETURN).float()
        stage_confirm = (info["action_stage"] == self.STAGE_CONFIRM).float()
        action_mask_f = info["action_mask"].float()

        # --- INSERT stage: reach the nearest unchecked battery, carry it -----
        battery_pos = torch.stack([bat.pose.p for bat in self.batteries], dim=1)
        dist_tcp_to_battery = torch.linalg.norm(battery_pos - tcp_pos.unsqueeze(1), dim=2)
        unchecked_mask = (~self.checked_mask).float()
        # Checked batteries are pushed out of the min() with a large offset.
        unchecked_dist = dist_tcp_to_battery + (1.0 - unchecked_mask) * 1e6
        min_unchecked_dist = torch.min(unchecked_dist, dim=1).values
        has_unchecked = (unchecked_mask.sum(dim=1) > 0).float()
        reach_unchecked_reward = (1 - torch.tanh(4.0 * min_unchecked_dist)) * has_unchecked

        grasp_mask = torch.stack([self.agent.is_grasping(bat) for bat in self.batteries], dim=1)
        has_grasp = grasp_mask.any(dim=1)
        grasp_idx = torch.argmax(grasp_mask.to(torch.int64), dim=1)
        grasped_battery_pos = battery_pos[env_range, grasp_idx]
        grasped_to_socket_dist = torch.linalg.norm(grasped_battery_pos - self.socket_slot_pos, dim=1)
        carry_to_socket_reward = (1 - torch.tanh(4.0 * grasped_to_socket_dist)) * has_grasp.float()

        # --- RETURN stage: bring the active battery back to its slot ---------
        # The clamp keeps the gather valid for envs without an active battery;
        # their value is gated out by `stage_return` below.
        active_idx = torch.clamp(self.active_battery_idx, min=0)
        active_pos = battery_pos[env_range, active_idx]
        active_home = self.tray_slot_positions[env_range, active_idx]
        return_dist = torch.linalg.norm(active_pos - active_home, dim=1)
        return_reward = 1 - torch.tanh(6.0 * return_dist)

        # --- CONFIRM stage: reach and press the button -----------------------
        # Distance is measured to the top centre of the unpressed button cap.
        # (`self.obj_to_goal_pos` points at the socket, so it must not be used.)
        button_target = torch.cat([self.button_xy, self.button_top_z.unsqueeze(1)], dim=1)
        tcp_to_button_dist = torch.linalg.norm(button_target - tcp_pos, dim=1)
        button_reach_reward = 1 - torch.tanh(8.0 * tcp_to_button_dist)
        button_press_reward = torch.clamp(
            info["button_press_depth"] / self.BUTTON_CAP_TRAVEL, min=0.0, max=1.0
        )

        # A press is "wrong phase" when it did not produce a confirm event. The
        # post-transition stage cannot be used here: a valid confirm has already
        # moved the env back to STAGE_INSERT.
        wrong_phase_press = info["new_button_press_event"] & (~info["new_confirm_event"])
        # Recorded by evaluate() before it clears active_battery_idx.
        newly_found_working = info["new_confirm_event"] & self.new_found_working_event

        # --- Smoothness penalty ----------------------------------------------
        if not torch.is_tensor(action):
            action = torch.as_tensor(action, device=self.device)
        if (
            not hasattr(self, "_prev_action")
            or self._prev_action is None
            or self._prev_action.shape != action.shape
        ):
            self._prev_action = torch.zeros_like(action)
        delta_action = action - self._prev_action
        action_l2 = torch.linalg.norm(action, dim=1)
        delta_action_l2 = torch.linalg.norm(delta_action, dim=1)
        # The last two joints (gripper fingers) are excluded from the velocity term.
        qvel_l2 = torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], dim=1)
        smooth_penalty = (
            self.ACTION_L2_COEF * torch.tanh(2.0 * action_l2)
            + self.ACTION_DELTA_L2_COEF * torch.tanh(5.0 * delta_action_l2)
            + self.QVEL_L2_COEF * torch.tanh(2.0 * qvel_l2)
        )

        reward = (
            stage_insert * (1.2 * reach_unchecked_reward + 2.0 * carry_to_socket_reward)
            + stage_return * (3.0 * return_reward)
            + stage_confirm * (1.8 * button_reach_reward + 2.5 * button_press_reward)
            + 5.0 * info["new_insert_event"].float()
            + 6.0 * info["new_return_event"].float()
            + 7.0 * info["new_confirm_event"].float()
            + 12.0 * newly_found_working.float()
            - 2.0 * wrong_phase_press.float()
            - smooth_penalty
        )
        reward = reward * action_mask_f
        reward[info["success"]] = self.SUCCESS_BONUS

        self.reward_dict = {
            "stage_insert": stage_insert,
            "stage_return": stage_return,
            "stage_confirm": stage_confirm,
            "reach_unchecked_reward": reach_unchecked_reward,
            "carry_to_socket_reward": carry_to_socket_reward,
            "return_reward": return_reward,
            "button_reach_reward": button_reach_reward,
            "button_press_reward": button_press_reward,
            "new_insert_event": info["new_insert_event"].float(),
            "new_return_event": info["new_return_event"].float(),
            "new_confirm_event": info["new_confirm_event"].float(),
            "newly_found_working": newly_found_working.float(),
            "wrong_phase_press": wrong_phase_press.float(),
            "checked_count": info["checked_count"].float(),
            "found_working_count": info["found_working_count"].float(),
            "target_working_count": info["target_working_count"].float(),
            "action_l2": action_l2,
            "delta_action_l2": delta_action_l2,
            "qvel_l2": qvel_l2,
            "smooth_penalty": smooth_penalty,
        }
        self._prev_action = action.detach()
        return reward

    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Return the dense reward divided by `SUCCESS_BONUS`."""
        return self.compute_dense_reward(obs=obs, action=action, info=info) / self.SUCCESS_BONUS
@register_env("BatteriesCheckerHard-3-VLA-v0", max_episode_steps=1080)
class BatteriesChecker3VLAEnv(BatteriesCheckerVLABaseEnv):
    """Battery checker variant with 3 active batteries, 1 of which works."""

    # Only the class-level task size differs from the base environment.
    ACTIVE_BATTERY_COUNT = 3
    WORKING_BATTERY_COUNT = 1
@register_env("BatteriesCheckerHard-6-VLA-v0", max_episode_steps=2160)
class BatteriesChecker6VLAEnv(BatteriesCheckerVLABaseEnv):
    """Battery checker variant with 6 active batteries, 3 of which work."""

    # Only the class-level task size differs from the base environment.
    ACTIVE_BATTERY_COUNT = 6
    WORKING_BATTERY_COUNT = 3
@register_env("BatteriesCheckerHard-9-VLA-v0", max_episode_steps=3240)
class BatteriesChecker9VLAEnv(BatteriesCheckerVLABaseEnv):
    """Battery checker variant with 9 active batteries, 5 of which work."""

    # Only the class-level task size differs from the base environment.
    ACTIVE_BATTERY_COUNT = 9
    WORKING_BATTERY_COUNT = 5
@register_env("BatteriesCheckerHard-12-VLA-v0", max_episode_steps=4320)
class BatteriesChecker12VLAEnv(BatteriesCheckerVLABaseEnv):
    """Battery checker variant with 12 active batteries, 7 of which work."""

    # Only the class-level task size differs from the base environment.
    ACTIVE_BATTERY_COUNT = 12
    WORKING_BATTERY_COUNT = 7
# NOTE(review): the 3/6/9/12-battery variants budget 360 steps per battery
# (1080, 2160, 3240, 4320). This variant reuses 4320 rather than 5400 —
# confirm whether that cap is intentional before changing it.
@register_env("BatteriesCheckerHard-15-VLA-v0", max_episode_steps=4320)
class BatteriesChecker15VLAEnv(BatteriesCheckerVLABaseEnv):
    """Battery checker variant with 15 active batteries, 9 of which work."""

    # Only the class-level task size differs from the base environment.
    ACTIVE_BATTERY_COUNT = 15
    WORKING_BATTERY_COUNT = 9