# Source code for module mikasa_robo_suite.vla.memory_envs.trace_shape_seq_vla

"""Trace-shape-sequence procedural memory tasks for the VLA benchmark."""

from typing import Any, Dict, List, Union

import numpy as np
import sapien
import torch
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.building.actors.common import _build_by_type
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
from transforms3d.euler import euler2quat

from mikasa_robo_suite.vla.utils import shapes


class TraceShapeSeqVLABaseEnv(BaseEnv):
    """Watch a sequence of red traces, then reproduce all traces in order.

    The robot observes multiple demonstrations in sequence. For each element, the red cube
    traces one shape (circle / square / triangle depending on the difficulty variant).
    During the action phase, the robot must reproduce the same sequence with the green cube.
    After finishing all traces, the robot must press the submit button.

    Episode phases (driven by ``elapsed_steps`` in :meth:`evaluate`):
    - pre-demo: white lamp on, red cube rests on the first waypoint of shape 0;
    - demo: red lamp on, red cube walks the waypoints of every shape in the sequence;
    - action: green lamp on, red cube is lifted away, green-cube progress is tracked.

    Success (`success=True`):
    - Every sequence element must be completed in order.
    - A sequence element is complete only when all its checkpoints are visited and the
      contour is closed (return near checkpoint[0]).
    - After all elements are complete, the robot must press the button.
    """

    LANGUAGE_INSTRUCTION = (
        "Watch the red cube trace a sequence of shapes. When the lamp turns green, "
        "pick up the green cube and trace the same sequence in order. "
        "After finishing all shapes, press the button to submit your answer."
    )

    SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
    agent: Union[Panda, PandaWristCam]

    # z-coordinate used to park actors far above the table while they are "off".
    HEIGHT_OFFSET = 1000.0

    # Shape ids stored in ``shape_sequence`` (-1 marks unused sequence slots).
    SHAPE_CIRCLE = 0
    SHAPE_SQUARE = 1
    SHAPE_TRIANGLE = 2
    AVAILABLE_SHAPES: List[int] = [0]

    # Sequence length is sampled uniformly from [MIN, MAX] (inclusive).
    MIN_SEQUENCE_LENGTH = 2
    MAX_SEQUENCE_LENGTH = 5

    NUM_WAYPOINTS = 64
    NUM_CHECKPOINTS = 12
    CHECKPOINT_THRESH = 0.035
    # Inclusive [low, high] range for the number of pre-demo steps.
    PRE_DEMO_STEPS: List[int] = [3, 8]
    STEPS_PER_WAYPOINT = 1

    CUBE_HALFSIZE = 0.02
    # Circle radius / square half side / triangle circumradius.
    SHAPE_RADIUS_RANGE = [0.078, 0.13]
    SHAPE_CENTER_X_RANGE = [-0.15, -0.05]
    SHAPE_CENTER_Y_RANGE = [-0.10, 0.10]
    GREEN_CUBE_OFFSET_X = -0.16

    LAMP_BASE_RADIUS = 0.018
    LAMP_BASE_HALF_HEIGHT = 0.008
    LAMP_STEM_RADIUS = 0.004
    LAMP_STEM_HALF_HEIGHT = 0.020
    LAMP_BULB_RADIUS = 0.012
    LAMP_OFFSET_X = 0.25

    BUTTON_BASE_HALF_SIZE = np.array([0.065, 0.065, 0.015], dtype=np.float32)
    BUTTON_CAP_RADIUS = 0.03
    BUTTON_CAP_HALF_HEIGHT = 0.014
    BUTTON_CAP_TRAVEL = BUTTON_CAP_HALF_HEIGHT
    # A press fires at >= 35% of travel; the button re-arms once depth drops to <= 20%.
    BUTTON_PRESS_EVENT_RATIO = 0.35
    BUTTON_RELEASE_READY_RATIO = 0.2
    BUTTON_PRESS_XY_RADIUS = 0.065
    BUTTON_PRESS_Z_MARGIN = 0.03
    BUTTON_OFFSET_FROM_LAMP_X = 0.02
    BUTTON_OFFSET_FROM_LAMP_Y = 0.16

    SUCCESS_BONUS = 40.0
    ACTION_L2_COEF = 0.01
    ACTION_DELTA_L2_COEF = 0.03
    QVEL_L2_COEF = 0.01

    def __init__(
        self,
        *args,
        robot_uids="panda_wristcam",
        robot_init_qpos_noise=0.02,
        **kwargs,
    ):
        """Store the initial-joint noise scale and forward everything else to ``BaseEnv``."""
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sim_config(self):
        """Simulation config with enlarged GPU contact buffers."""
        return SimConfig(
            gpu_memory_config=GPUMemoryConfig(
                found_lost_pairs_capacity=2**25,
                max_rigid_patch_count=2**18,
            )
        )

    @property
    def _default_sensor_configs(self):
        """A single 128x128 ``base_camera`` looking down at the workspace."""
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        """A 512x512 ``render_camera`` used for human-facing renders."""
        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.15])
        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        """Load the robot with its base placed at x=-0.615."""
        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))

    def _load_scene(self, options: dict):
        """Build table, cubes, lamps and submit button; allocate per-env state tensors."""
        self.table_scene = TableSceneBuilder(
            self,
            robot_init_qpos_noise=self.robot_init_qpos_noise,
        )
        self.table_scene.build()

        default_initial_pose = sapien.Pose(p=[0.0, 0.0, self.HEIGHT_OFFSET])

        # Red cube is kinematic: its pose is scripted every step in ``evaluate``.
        self.red_cube = actors.build_cube(
            self.scene,
            half_size=self.CUBE_HALFSIZE,
            color=np.array([220, 50, 50, 255]) / 255.0,
            name="red_cube",
            body_type="kinematic",
            initial_pose=default_initial_pose,
        )
        # Green cube is dynamic: the robot manipulates it.
        self.green_cube = actors.build_cube(
            self.scene,
            half_size=self.CUBE_HALFSIZE,
            color=np.array([50, 220, 50, 255]) / 255.0,
            name="green_cube",
            body_type="dynamic",
            initial_pose=sapien.Pose(p=[0, 0, self.CUBE_HALFSIZE]),
        )

        lamp_kw = dict(
            body_type="kinematic",
            add_collision=False,
            initial_pose=default_initial_pose,
            base_radius=self.LAMP_BASE_RADIUS,
            base_half_height=self.LAMP_BASE_HALF_HEIGHT,
            stem_radius=self.LAMP_STEM_RADIUS,
            stem_half_height=self.LAMP_STEM_HALF_HEIGHT,
            bulb_radius=self.LAMP_BULB_RADIUS,
        )
        lp_white = shapes.build_color_switch_lamp(
            scene=self.scene,
            name="lamp_white",
            bulb_off_color=np.array([245, 245, 245, 255]) / 255.0,
            bulb_on_color=np.array([245, 245, 245, 255]) / 255.0,
            **lamp_kw,
        )
        lp_red = shapes.build_color_switch_lamp(
            scene=self.scene,
            name="lamp_red",
            bulb_off_color=np.array([245, 245, 245, 255]) / 255.0,
            bulb_on_color=np.array([255, 0, 0, 255]) / 255.0,
            **lamp_kw,
        )
        lp_green = shapes.build_color_switch_lamp(
            scene=self.scene,
            name="lamp_green",
            bulb_off_color=np.array([245, 245, 245, 255]) / 255.0,
            bulb_on_color=np.array([0, 255, 0, 255]) / 255.0,
            **lamp_kw,
        )
        # One visible body plus three bulbs; "on" is emulated by moving a bulb to the lamp
        # position and "off" by parking it at HEIGHT_OFFSET (see ``evaluate``).
        self.lamp_body = lp_white["body"]
        self.lamp_white = lp_white["bulb_off"]
        self.lamp_red = lp_red["bulb_on"]
        self.lamp_green = lp_green["bulb_on"]
        shapes._set_actor_visual_rgba(
            self.lamp_red,
            np.array([255, 0, 0, 255]) / 255.0,
            emission_scale=20.0,
            remove_textures=True,
        )
        shapes._set_actor_visual_rgba(
            self.lamp_green,
            np.array([0, 255, 0, 255]) / 255.0,
            emission_scale=20.0,
            remove_textures=True,
        )
        # Parts never shown; they stay parked at the "off" position.
        self._lamp_aux = [
            lp_red["body"],
            lp_green["body"],
            lp_white["bulb_on"],
            lp_red["bulb_off"],
            lp_green["bulb_off"],
        ]

        base_builder = self.scene.create_actor_builder()
        base_builder.add_box_collision(half_size=self.BUTTON_BASE_HALF_SIZE)
        base_builder.add_box_visual(
            half_size=self.BUTTON_BASE_HALF_SIZE,
            material=sapien.render.RenderMaterial(base_color=np.array([55, 64, 78, 255]) / 255.0),
        )
        self.button_base = _build_by_type(
            base_builder,
            name="trace_shape_seq_button_base",
            body_type="kinematic",
            initial_pose=default_initial_pose,
        )

        cap_builder = self.scene.create_actor_builder()
        cap_builder.add_cylinder_collision(radius=self.BUTTON_CAP_RADIUS, half_length=self.BUTTON_CAP_HALF_HEIGHT)
        cap_builder.add_cylinder_visual(
            radius=self.BUTTON_CAP_RADIUS,
            half_length=self.BUTTON_CAP_HALF_HEIGHT,
            material=sapien.render.RenderMaterial(base_color=np.array([210, 80, 80, 255]) / 255.0),
        )
        self.button_cap = _build_by_type(
            cap_builder,
            name="trace_shape_seq_button_cap",
            body_type="kinematic",
            initial_pose=default_initial_pose,
        )

        n = self.num_envs
        d = self.device
        # Phase timing (in steps) per env.
        self.pre_demo_steps_per_env = torch.zeros(n, dtype=torch.int64, device=d)
        self.demo_steps_per_env = torch.zeros(n, dtype=torch.int64, device=d)
        self.cue_steps_per_env = torch.zeros(n, dtype=torch.int64, device=d)
        # Sequence description and tracing progress.
        self.sequence_len = torch.zeros(n, dtype=torch.int64, device=d)
        self.shape_sequence = torch.full((n, self.MAX_SEQUENCE_LENGTH), -1, dtype=torch.int64, device=d)
        self.active_shape_idx = torch.zeros(n, dtype=torch.int64, device=d)
        self.waypoints = torch.zeros(
            n,
            self.MAX_SEQUENCE_LENGTH,
            self.NUM_WAYPOINTS,
            2,
            dtype=torch.float32,
            device=d,
        )
        self.checkpoints = torch.zeros(
            n,
            self.MAX_SEQUENCE_LENGTH,
            self.NUM_CHECKPOINTS,
            2,
            dtype=torch.float32,
            device=d,
        )
        self.checkpoint_visited = torch.zeros(
            n,
            self.MAX_SEQUENCE_LENGTH,
            self.NUM_CHECKPOINTS,
            dtype=torch.bool,
            device=d,
        )
        self.shape_closed = torch.zeros(n, self.MAX_SEQUENCE_LENGTH, dtype=torch.bool, device=d)
        self.shape_center_xy = torch.zeros(n, self.MAX_SEQUENCE_LENGTH, 2, dtype=torch.float32, device=d)
        # Lamp placement.
        self.lamp_on_pos = torch.zeros(n, 3, dtype=torch.float32, device=d)
        self.lamp_off_pos = torch.zeros(n, 3, dtype=torch.float32, device=d)
        # Submit-button geometry and press state machine.
        self.button_xy = torch.zeros((n, 2), dtype=torch.float32, device=d)
        self.button_base_z = torch.zeros(n, dtype=torch.float32, device=d)
        self.button_cap_unpressed_z = torch.zeros(n, dtype=torch.float32, device=d)
        self.button_top_z = torch.zeros(n, dtype=torch.float32, device=d)
        self.button_press_depth = torch.zeros(n, dtype=torch.float32, device=d)
        self.button_pressed = torch.zeros(n, dtype=torch.bool, device=d)
        self.button_pressable = torch.zeros(n, dtype=torch.bool, device=d)
        self.press_ready = torch.ones(n, dtype=torch.bool, device=d)
        self.new_submit_event = torch.zeros(n, dtype=torch.bool, device=d)
        self.submit_success_latched = torch.zeros(n, dtype=torch.bool, device=d)
        # Cap orientation: 90 degrees about y, so the cylinder axis points up.
        self.button_cap_quat = torch.tensor(euler2quat(0, np.pi / 2, 0), dtype=torch.float32, device=d)

    def _generate_waypoints(self, shape_type, center_xy, radius, rotation, b):
        """Return ``(b, NUM_WAYPOINTS, 2)`` xy waypoints tracing one shape per env.

        Args:
            shape_type: ``(b,)`` int tensor of SHAPE_* ids; unknown ids leave zeros.
            center_xy: ``(b, 2)`` shape centers.
            radius: ``(b,)`` circle radius / square half side / triangle circumradius.
            rotation: ``(b,)`` rotation in radians about the center.
            b: batch size.

        Waypoints are spaced uniformly in the contour parameter ``t`` in [0, 1); the last
        waypoint does not repeat the first.
        """
        n = self.NUM_WAYPOINTS
        device = self.device
        waypoints = torch.zeros(b, n, 2, device=device)
        t = torch.linspace(0, 1.0, n + 1, device=device)[:-1]

        cm = shape_type == self.SHAPE_CIRCLE
        if cm.any():
            angles = t.unsqueeze(0) * 2 * np.pi + rotation[cm].unsqueeze(1)
            r = radius[cm].unsqueeze(1)
            waypoints[cm, :, 0] = center_xy[cm, 0:1] + r * torch.cos(angles)
            waypoints[cm, :, 1] = center_xy[cm, 1:2] + r * torch.sin(angles)

        sm = shape_type == self.SHAPE_SQUARE
        if sm.any():
            b_sq = sm.sum().item()
            s = radius[sm].unsqueeze(1)
            rot = rotation[sm]
            lx = torch.zeros(b_sq, n, device=device)
            ly = torch.zeros(b_sq, n, device=device)
            # Each quarter of t walks one side, counter-clockwise from (-s, -s).
            for side in range(4):
                lo, hi = side * 0.25, (side + 1) * 0.25
                mask = (t >= lo) & (t < hi)
                nm = mask.sum().item()
                frac = (t[mask] - lo) / 0.25
                if side == 0:
                    lx[:, mask] = -s.expand(-1, nm) + 2 * s * frac.unsqueeze(0)
                    ly[:, mask] = (-s).expand(-1, nm)
                elif side == 1:
                    lx[:, mask] = s.expand(-1, nm)
                    ly[:, mask] = -s.expand(-1, nm) + 2 * s * frac.unsqueeze(0)
                elif side == 2:
                    lx[:, mask] = s.expand(-1, nm) - 2 * s * frac.unsqueeze(0)
                    ly[:, mask] = s.expand(-1, nm)
                else:
                    lx[:, mask] = (-s).expand(-1, nm)
                    ly[:, mask] = s.expand(-1, nm) - 2 * s * frac.unsqueeze(0)
            cos_r = torch.cos(rot).unsqueeze(1)
            sin_r = torch.sin(rot).unsqueeze(1)
            waypoints[sm, :, 0] = center_xy[sm, 0:1] + lx * cos_r - ly * sin_r
            waypoints[sm, :, 1] = center_xy[sm, 1:2] + lx * sin_r + ly * cos_r

        tm = shape_type == self.SHAPE_TRIANGLE
        if tm.any():
            b_tr = tm.sum().item()
            r = radius[tm].unsqueeze(1)
            rot = rotation[tm]
            v_angles = torch.tensor([0, 2 * np.pi / 3, 4 * np.pi / 3], device=device)
            lx = torch.zeros(b_tr, n, device=device)
            ly = torch.zeros(b_tr, n, device=device)
            # Each third of t interpolates linearly between consecutive vertices.
            for side in range(3):
                lo = side / 3.0
                hi = (side + 1) / 3.0
                mask = (t >= lo) & (t < hi)
                frac = (t[mask] - lo) * 3.0
                a0, a1 = v_angles[side], v_angles[(side + 1) % 3]
                x0 = r * float(np.cos(a0.item()))
                y0 = r * float(np.sin(a0.item()))
                x1 = r * float(np.cos(a1.item()))
                y1 = r * float(np.sin(a1.item()))
                lx[:, mask] = x0 + (x1 - x0) * frac.unsqueeze(0)
                ly[:, mask] = y0 + (y1 - y0) * frac.unsqueeze(0)
            cos_r = torch.cos(rot).unsqueeze(1)
            sin_r = torch.sin(rot).unsqueeze(1)
            waypoints[tm, :, 0] = center_xy[tm, 0:1] + lx * cos_r - ly * sin_r
            waypoints[tm, :, 1] = center_xy[tm, 1:2] + lx * sin_r + ly * cos_r

        return waypoints

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Sample a new shape sequence and reset every actor/state tensor for ``env_idx``."""
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)
            env_idx = env_idx.to(self.device)
            self.reward_dict = None

            # --- sample the sequence (slots beyond seq_len keep the -1 sentinel) ---
            shape_choices = torch.tensor(self.AVAILABLE_SHAPES, device=self.device, dtype=torch.int64)
            seq_len = torch.randint(
                self.MIN_SEQUENCE_LENGTH,
                self.MAX_SEQUENCE_LENGTH + 1,
                (b,),
                device=self.device,
                dtype=torch.int64,
            )
            shape_seq = torch.full((b, self.MAX_SEQUENCE_LENGTH), -1, dtype=torch.int64, device=self.device)
            # NOTE: kept as a per-env loop so the RNG draw order (and thus seeded layouts) is unchanged.
            for i in range(b):
                length = int(seq_len[i].item())
                pick_idx = torch.randint(0, len(self.AVAILABLE_SHAPES), (length,), device=self.device)
                shape_seq[i, :length] = shape_choices[pick_idx]
            self.sequence_len[env_idx] = seq_len
            self.shape_sequence[env_idx] = shape_seq
            self.active_shape_idx[env_idx] = 0

            # --- geometry: one shared center per env, per-element radius and rotation ---
            center_x = (
                torch.rand(b, device=self.device) * (self.SHAPE_CENTER_X_RANGE[1] - self.SHAPE_CENTER_X_RANGE[0])
                + self.SHAPE_CENTER_X_RANGE[0]
            )
            center_y = (
                torch.rand(b, device=self.device) * (self.SHAPE_CENTER_Y_RANGE[1] - self.SHAPE_CENTER_Y_RANGE[0])
                + self.SHAPE_CENTER_Y_RANGE[0]
            )
            center_xy = torch.stack([center_x, center_y], dim=-1).unsqueeze(1).repeat(1, self.MAX_SEQUENCE_LENGTH, 1)
            self.shape_center_xy[env_idx] = center_xy

            radius = (
                torch.rand(b, self.MAX_SEQUENCE_LENGTH, device=self.device)
                * (self.SHAPE_RADIUS_RANGE[1] - self.SHAPE_RADIUS_RANGE[0])
                + self.SHAPE_RADIUS_RANGE[0]
            )
            rotation = torch.rand(b, self.MAX_SEQUENCE_LENGTH, device=self.device) * 2 * np.pi

            all_waypoints = torch.zeros(
                b,
                self.MAX_SEQUENCE_LENGTH,
                self.NUM_WAYPOINTS,
                2,
                dtype=torch.float32,
                device=self.device,
            )
            for s_idx in range(self.MAX_SEQUENCE_LENGTH):
                shape_for_gen = shape_seq[:, s_idx].clone()
                invalid = shape_for_gen < 0
                shape_for_gen[invalid] = self.SHAPE_CIRCLE
                waypoints_s = self._generate_waypoints(
                    shape_for_gen,
                    center_xy[:, s_idx],
                    radius[:, s_idx],
                    rotation[:, s_idx],
                    b,
                )
                # Unused slots collapse onto the center point.
                if invalid.any():
                    center_repeat = center_xy[:, s_idx][invalid].unsqueeze(1).repeat(1, self.NUM_WAYPOINTS, 1)
                    waypoints_s[invalid] = center_repeat
                all_waypoints[:, s_idx] = waypoints_s
            self.waypoints[env_idx] = all_waypoints

            # NOTE: with the defaults (64 waypoints, 12 checkpoints) step=5, so checkpoints sit at
            # waypoint indices 0..55 and the closing arc 55 -> 0 is longer than the others.
            step = max(1, self.NUM_WAYPOINTS // self.NUM_CHECKPOINTS)
            cp_idx = torch.arange(0, self.NUM_WAYPOINTS, step, device=self.device)[: self.NUM_CHECKPOINTS]
            self.checkpoints[env_idx] = all_waypoints[:, :, cp_idx]
            self.checkpoint_visited[env_idx] = False
            self.shape_closed[env_idx] = False

            # --- phase timing ---
            pre_demo = torch.randint(
                self.PRE_DEMO_STEPS[0],
                self.PRE_DEMO_STEPS[1] + 1,
                (b,),
                device=self.device,
                dtype=torch.int64,
            )
            per_shape_demo_steps = self.NUM_WAYPOINTS * self.STEPS_PER_WAYPOINT
            demo_steps = seq_len * per_shape_demo_steps
            self.pre_demo_steps_per_env[env_idx] = pre_demo
            self.demo_steps_per_env[env_idx] = demo_steps
            self.cue_steps_per_env[env_idx] = pre_demo + demo_steps

            # --- cubes ---
            red_xyz = torch.zeros(b, 3, device=self.device)
            red_xyz[:, :2] = all_waypoints[:, 0, 0]
            red_xyz[:, 2] = self.CUBE_HALFSIZE
            self.red_cube.set_pose(Pose.create_from_pq(p=red_xyz, q=[1, 0, 0, 0]))

            green_xyz = torch.zeros(b, 3, device=self.device)
            green_xyz[:, 0] = center_x + self.GREEN_CUBE_OFFSET_X
            green_xyz[:, 1] = center_y
            green_xyz[:, 2] = self.CUBE_HALFSIZE
            self.green_cube.set_pose(Pose.create_from_pq(p=green_xyz, q=[1, 0, 0, 0]))

            # --- lamp: body + white bulb visible, colored bulbs parked ---
            lamp_pos = torch.zeros(b, 3, device=self.device)
            lamp_pos[:, 0] = center_x + self.LAMP_OFFSET_X
            lamp_pos[:, 1] = center_y
            lamp_pos[:, 2] = 0.0
            lamp_off = lamp_pos.clone()
            lamp_off[:, 2] += self.HEIGHT_OFFSET
            self.lamp_on_pos[env_idx] = lamp_pos
            self.lamp_off_pos[env_idx] = lamp_off
            lamp_q = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(b, 1)
            self.lamp_body.set_pose(Pose.create_from_pq(p=lamp_pos, q=lamp_q))
            self.lamp_white.set_pose(Pose.create_from_pq(p=lamp_pos, q=lamp_q))
            self.lamp_red.set_pose(Pose.create_from_pq(p=lamp_off, q=lamp_q))
            self.lamp_green.set_pose(Pose.create_from_pq(p=lamp_off, q=lamp_q))
            for aux in self._lamp_aux:
                aux.set_pose(Pose.create_from_pq(p=lamp_off, q=lamp_q))

            # --- submit button, placed relative to the lamp ---
            button_xy = torch.zeros((b, 2), device=self.device)
            button_xy[:, 0] = lamp_pos[:, 0] + self.BUTTON_OFFSET_FROM_LAMP_X
            button_xy[:, 1] = lamp_pos[:, 1] + self.BUTTON_OFFSET_FROM_LAMP_Y
            base_z = torch.full((b,), float(self.BUTTON_BASE_HALF_SIZE[2]), device=self.device)
            cap_unpressed_z = torch.full(
                (b,),
                float(self.BUTTON_BASE_HALF_SIZE[2]) * 2.0 + self.BUTTON_CAP_HALF_HEIGHT,
                device=self.device,
            )
            button_base_xyz = torch.zeros((b, 3), device=self.device)
            button_base_xyz[:, :2] = button_xy
            button_base_xyz[:, 2] = base_z
            button_cap_xyz = torch.zeros((b, 3), device=self.device)
            button_cap_xyz[:, :2] = button_xy
            button_cap_xyz[:, 2] = cap_unpressed_z
            button_base_q = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(b, 1)
            button_cap_q = self.button_cap_quat.unsqueeze(0).repeat(b, 1)

            self.button_xy[env_idx] = button_xy
            self.button_base_z[env_idx] = base_z
            self.button_cap_unpressed_z[env_idx] = cap_unpressed_z
            self.button_top_z[env_idx] = cap_unpressed_z + self.BUTTON_CAP_HALF_HEIGHT
            self.button_press_depth[env_idx] = 0.0
            self.button_pressed[env_idx] = False
            self.button_pressable[env_idx] = False
            self.press_ready[env_idx] = True
            self.new_submit_event[env_idx] = False
            self.submit_success_latched[env_idx] = False
            self.button_base.set_pose(Pose.create_from_pq(p=button_base_xyz, q=button_base_q))
            self.button_cap.set_pose(Pose.create_from_pq(p=button_cap_xyz, q=button_cap_q))

            # ``.to`` with the same dtype returns the same tensor, so these alias shape_sequence.
            self.oracle_info = self.shape_sequence.to(torch.int64)
            self.task_cue = self.shape_sequence.to(torch.int64)

            if self.robot_uids in ("panda", "panda_wristcam"):
                qpos = np.array([0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04])
                qpos[:-2] += self._episode_rng.normal(0, self.robot_init_qpos_noise, len(qpos) - 2)
                self.agent.reset(qpos)
                self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
            else:
                raise NotImplementedError(self.robot_uids)

            # Zero the action history used by the smoothness penalty for the reset envs.
            # Guard against an empty env_idx, where ``env_idx.max()`` would raise.
            prev_action = getattr(self, "_prev_action", None)
            if torch.is_tensor(prev_action) and env_idx.numel() > 0:
                if prev_action.shape[0] >= int(env_idx.max().item()) + 1:
                    prev_action[env_idx] = 0

    def _update_checkpoint_progress(self, green_xy: torch.Tensor, action_mask: torch.Tensor) -> None:
        """Vectorized per-env update of checkpoint visits, shape closure and the active shape.

        For every env in the action phase whose (clamped-to-0) active shape index is still
        inside its sequence: mark checkpoints of the active shape within CHECKPOINT_THRESH of
        the green cube as visited; if all are visited and the cube is near checkpoint[0],
        mark the shape closed and, unless it is the last shape, advance ``active_shape_idx``.
        Envs that do not qualify are left untouched. Avoids per-env ``.item()`` syncs.
        """
        batch = torch.arange(self.num_envs, device=self.device)
        active = torch.clamp(self.active_shape_idx, min=0)
        updatable = action_mask & (self.sequence_len > 0) & (active < self.sequence_len)
        # For updatable envs idx == active (active < seq_len <= MAX); the clamp only keeps
        # indexing in-bounds for envs that are masked out anyway.
        idx = torch.clamp(active, max=self.MAX_SEQUENCE_LENGTH - 1)

        cp = self.checkpoints[batch, idx]  # (num_envs, NUM_CHECKPOINTS, 2)
        dist = torch.linalg.norm(green_xy.unsqueeze(1) - cp, dim=-1)
        hit = (dist < self.CHECKPOINT_THRESH) & updatable.unsqueeze(1)
        visited = self.checkpoint_visited[batch, idx] | hit
        self.checkpoint_visited[batch, idx] = visited

        closed_now = updatable & visited.all(dim=1) & (dist[:, 0] < self.CHECKPOINT_THRESH)
        self.shape_closed[batch, idx] = self.shape_closed[batch, idx] | closed_now

        advance = closed_now & (idx + 1 < self.sequence_len)
        self.active_shape_idx.copy_(torch.where(advance, idx + 1, self.active_shape_idx))

    def evaluate(self):
        """Advance the scripted scene for the current step and compute task progress.

        Switches lamps and moves the red cube according to the phase, updates tracing
        progress from the green cube, animates the submit button from the TCP height,
        and returns the info dict consumed by rewards, observations and ``step``.
        """
        elapsed = self.elapsed_steps.to(torch.int64)
        if elapsed.dim() > 1:
            elapsed = elapsed.squeeze(-1)
        pre_demo_mask = elapsed < self.pre_demo_steps_per_env
        demo_mask = (~pre_demo_mask) & (elapsed < self.cue_steps_per_env)
        action_mask = elapsed >= self.cue_steps_per_env

        # Exactly one bulb is at the lamp position; the others are parked.
        for lamp_actor, on_mask in [
            (self.lamp_white, pre_demo_mask),
            (self.lamp_red, demo_mask),
            (self.lamp_green, action_mask),
        ]:
            pose = lamp_actor.pose.raw_pose.clone()
            pose[on_mask, :3] = self.lamp_on_pos[on_mask]
            pose[~on_mask, :3] = self.lamp_off_pos[~on_mask]
            lamp_actor.pose = pose

        red_pose = self.red_cube.pose.raw_pose.clone()
        red_pose[pre_demo_mask, 0] = self.waypoints[pre_demo_mask, 0, 0, 0]
        red_pose[pre_demo_mask, 1] = self.waypoints[pre_demo_mask, 0, 0, 1]
        red_pose[pre_demo_mask, 2] = self.CUBE_HALFSIZE
        if demo_mask.any():
            demo_elapsed = (elapsed[demo_mask] - self.pre_demo_steps_per_env[demo_mask]).clamp(min=0)
            per_shape_demo_steps = self.NUM_WAYPOINTS * self.STEPS_PER_WAYPOINT
            demo_shape_idx = demo_elapsed // per_shape_demo_steps
            max_demo_shape_idx = torch.clamp(self.sequence_len[demo_mask] - 1, min=0)
            demo_shape_idx = torch.minimum(demo_shape_idx, max_demo_shape_idx)
            wp_idx = ((demo_elapsed % per_shape_demo_steps) // self.STEPS_PER_WAYPOINT).clamp(
                max=self.NUM_WAYPOINTS - 1
            )
            batch_idx = torch.arange(self.waypoints.shape[0], device=self.device)[demo_mask]
            red_xy = self.waypoints[batch_idx, demo_shape_idx, wp_idx]
            red_pose[demo_mask, 0] = red_xy[:, 0]
            red_pose[demo_mask, 1] = red_xy[:, 1]
            red_pose[demo_mask, 2] = self.CUBE_HALFSIZE
        # Hide the demonstrator once the action phase starts.
        red_pose[action_mask, 2] = self.CUBE_HALFSIZE + self.HEIGHT_OFFSET
        self.red_cube.pose = red_pose

        green_xy = self.green_cube.pose.p[:, :2]
        self._update_checkpoint_progress(green_xy, action_mask)

        shape_range = torch.arange(self.MAX_SEQUENCE_LENGTH, device=self.device).unsqueeze(0)
        valid_shape_mask = shape_range < self.sequence_len.unsqueeze(1)
        # Unused slots count as closed so they never block completion.
        all_shapes_closed = torch.where(valid_shape_mask, self.shape_closed, torch.ones_like(self.shape_closed)).all(
            dim=1
        )
        closed_count = (self.shape_closed & valid_shape_mask).sum(dim=1).float()
        sequence_progress = closed_count / torch.clamp(self.sequence_len.float(), min=1.0)

        batch = torch.arange(self.num_envs, device=self.device)
        safe_active_idx = torch.clamp(self.active_shape_idx, min=0, max=self.MAX_SEQUENCE_LENGTH - 1)
        active_checkpoints = self.checkpoints[batch, safe_active_idx]
        active_visited = self.checkpoint_visited[batch, safe_active_idx]
        active_visit_fraction = active_visited.float().mean(dim=1)
        active_all_visited = active_visited.all(dim=1)
        start_checkpoint_dist = torch.linalg.norm(green_xy - active_checkpoints[:, 0], dim=-1)
        is_active_contour_closed = active_all_visited & (start_checkpoint_dist < self.CHECKPOINT_THRESH)

        tcp_pos = self.agent.tcp.pose.p
        tcp_xy = tcp_pos[:, :2]
        tcp_z = tcp_pos[:, 2]

        # --- button: depth follows the TCP when it is above the cap in xy (action phase only) ---
        self.button_pressable = all_shapes_closed & action_mask
        xy_dist_to_button = torch.linalg.norm(tcp_xy - self.button_xy, dim=1)
        raw_depth = self.button_top_z + self.BUTTON_PRESS_Z_MARGIN - tcp_z
        depth = torch.clamp(raw_depth, min=0.0, max=self.BUTTON_CAP_TRAVEL)
        depth = depth * (xy_dist_to_button < self.BUTTON_PRESS_XY_RADIUS).float()
        depth = depth * action_mask.float()
        self.button_press_depth = depth

        cap_pose = self.button_cap.pose.raw_pose.clone()
        cap_pose[:, 0:2] = self.button_xy
        cap_pose[:, 2] = self.button_cap_unpressed_z - depth
        cap_pose[:, 3:7] = self.button_cap_quat.repeat(cap_pose.shape[0], 1)
        self.button_cap.pose = cap_pose

        base_pose = self.button_base.pose.raw_pose.clone()
        base_pose[:, 0:2] = self.button_xy
        base_pose[:, 2] = self.button_base_z
        base_pose[:, 3:7] = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(base_pose.shape[0], 1)
        self.button_base.pose = base_pose

        # Edge-triggered press with hysteresis: fire once, re-arm only after release.
        pressed = depth >= (self.BUTTON_CAP_TRAVEL * self.BUTTON_PRESS_EVENT_RATIO)
        released = depth <= (self.BUTTON_CAP_TRAVEL * self.BUTTON_RELEASE_READY_RATIO)
        self.press_ready = self.press_ready | (released & action_mask)
        self.new_submit_event = pressed & self.press_ready & action_mask
        self.press_ready = self.press_ready & (~self.new_submit_event)
        self.button_pressed = self.button_pressed | self.new_submit_event

        submitted = self.new_submit_event
        successful_submit_event = submitted & all_shapes_closed
        self.submit_success_latched = self.submit_success_latched | successful_submit_event
        success = self.submit_success_latched
        failed_submit = submitted & (~all_shapes_closed)

        # Goal for the TCP: nearest unvisited checkpoint, or the button once all shapes are closed.
        dist_to_active_cp = torch.linalg.norm(green_xy.unsqueeze(1) - active_checkpoints, dim=-1)
        dist_to_active_cp = dist_to_active_cp + active_visited.float() * 1000.0
        nearest_idx = dist_to_active_cp.min(dim=1).indices
        nearest_cp_xy = active_checkpoints[batch, nearest_idx]

        target_pos = torch.zeros((self.num_envs, 3), dtype=torch.float32, device=self.device)
        target_pos[:, :2] = nearest_cp_xy
        target_pos[:, 2] = self.CUBE_HALFSIZE
        button_target_mask = all_shapes_closed & (~self.button_pressed) & action_mask
        target_pos[button_target_mask, :2] = self.button_xy[button_target_mask]
        target_pos[button_target_mask, 2] = self.button_top_z[button_target_mask] + 0.005
        self.obj_to_goal_pos = target_pos - tcp_pos

        active_shape_id = self.shape_sequence[batch, safe_active_idx]

        return {
            "success": success,
            "failed_submit": failed_submit,
            "submitted": submitted,
            "successful_submit_event": successful_submit_event,
            "action_mask": action_mask,
            "demo_mask": demo_mask,
            "pre_demo_mask": pre_demo_mask,
            "sequence_progress": sequence_progress,
            "all_shapes_closed": all_shapes_closed,
            "active_shape_idx": safe_active_idx,
            "active_shape_id": active_shape_id,
            "active_visit_fraction": active_visit_fraction,
            "active_all_visited": active_all_visited,
            "is_active_contour_closed": is_active_contour_closed,
            "start_checkpoint_dist": start_checkpoint_dist,
            "button_pressed": self.button_pressed,
            "button_pressable": self.button_pressable,
            "button_press_depth": self.button_press_depth,
            "xy_dist_to_button": xy_dist_to_button,
            "submit_success_latched": self.submit_success_latched,
            "obj_to_goal_pos": self.obj_to_goal_pos,
            "task_cue": self.task_cue,
            "language_instruction": self.LANGUAGE_INSTRUCTION,
            "oracle_info": self.oracle_info,
            "reward_dict": self.reward_dict,
        }

    def _get_obs_extra(self, info: Dict):
        """TCP pose always; privileged task state only in ``state``/``state_dict`` obs modes."""
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self._obs_mode in ["state", "state_dict"]:
            obs.update(
                red_cube_pose=self.red_cube.pose.raw_pose,
                green_cube_pose=self.green_cube.pose.raw_pose,
                action_mask=info["action_mask"],
                sequence_progress=info["sequence_progress"],
                active_shape_idx=info["active_shape_idx"],
                active_shape_id=info["active_shape_id"],
                sequence_len=self.sequence_len,
                button_xy=self.button_xy,
                button_pressed=self.button_pressed,
                oracle_info=self.oracle_info,
            )
        return obs

    def step(self, action):
        """Step the base env, additionally terminating on any submit or on success."""
        obs, reward, terminated, truncated, info = super().step(action)
        if isinstance(info, dict):
            submitted = info.get("submitted", None)
            success = info.get("success", None)
            if torch.is_tensor(terminated):
                term_bool = terminated.to(dtype=torch.bool)
                if torch.is_tensor(submitted):
                    term_bool = term_bool | submitted.to(dtype=torch.bool)
                elif submitted is not None:
                    term_bool = term_bool | bool(submitted)
                if torch.is_tensor(success):
                    term_bool = term_bool | success.to(dtype=torch.bool)
                elif success is not None:
                    term_bool = term_bool | bool(success)
                terminated = term_bool
            else:
                terminated = bool(terminated) or bool(submitted) or bool(success)
        return obs, reward, terminated, truncated, info

    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Shaped reward for tracing progress and submitting; see ``reward_dict`` for terms.

        ``failed_submit`` overrides the reward with -2.0 and ``success`` with SUCCESS_BONUS.
        Also records the current action for the next step's action-delta penalty.
        """
        tcp_pos = self.agent.tcp.pose.p
        green_pos = self.green_cube.pose.p
        tcp_to_cube_dist = torch.linalg.norm(tcp_pos - green_pos, dim=-1)
        reaching_reward = 1 - torch.tanh(5.0 * tcp_to_cube_dist)
        # Proximity proxy (TCP within 5 cm of the cube), not a contact-based grasp check.
        is_grasping = (tcp_to_cube_dist < 0.05).float()

        active_visit_fraction = info["active_visit_fraction"]
        sequence_progress = info["sequence_progress"]
        active_all_visited = info["active_all_visited"].float()
        start_checkpoint_dist = info["start_checkpoint_dist"]
        closure_reward = 1 - torch.tanh(5.0 * start_checkpoint_dist)

        button_target_mask = info["all_shapes_closed"] & (~info["button_pressed"]) & info["action_mask"]
        button_target_pos = torch.zeros_like(tcp_pos)
        button_target_pos[:, :2] = self.button_xy
        button_target_pos[:, 2] = self.button_top_z + 0.005
        tcp_to_button_dist = torch.linalg.norm(button_target_pos - tcp_pos, dim=-1)
        button_reach_reward = 1 - torch.tanh(8.0 * tcp_to_button_dist)
        button_press_reward = torch.clamp(info["button_press_depth"] / self.BUTTON_CAP_TRAVEL, min=0.0, max=1.0)

        if not torch.is_tensor(action):
            action = torch.as_tensor(action, device=self.device)
        if not hasattr(self, "_prev_action") or self._prev_action is None or self._prev_action.shape != action.shape:
            self._prev_action = torch.zeros_like(action)
        delta_action = action - self._prev_action
        action_l2 = torch.linalg.norm(action, dim=-1)
        delta_action_l2 = torch.linalg.norm(delta_action, dim=-1)
        # Gripper joints (last two) are excluded from the velocity penalty.
        qvel_l2 = torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], dim=-1)
        smooth_penalty = (
            self.ACTION_L2_COEF * torch.tanh(2.0 * action_l2)
            + self.ACTION_DELTA_L2_COEF * torch.tanh(5.0 * delta_action_l2)
            + self.QVEL_L2_COEF * torch.tanh(2.0 * qvel_l2)
        )

        act_f = info["action_mask"].float()
        cue_f = 1.0 - act_f
        button_target_f = button_target_mask.float()

        reward = (
            0.5 * cue_f * reaching_reward
            + 1.0 * act_f * reaching_reward
            + 1.8 * act_f * is_grasping
            + 2.5 * act_f * active_visit_fraction
            + 3.0 * act_f * sequence_progress
            + 2.0 * act_f * active_all_visited * closure_reward
            + 1.5 * button_target_f * button_reach_reward
            + 2.5 * button_target_f * button_press_reward
            - smooth_penalty
        )
        reward = torch.where(
            info["failed_submit"],
            torch.full_like(reward, -2.0),
            reward,
        )
        reward[info["success"]] = self.SUCCESS_BONUS

        self.reward_dict = {
            "reaching_reward": reaching_reward,
            "is_grasping": is_grasping,
            "active_visit_fraction": active_visit_fraction,
            "sequence_progress": sequence_progress,
            "closure_reward": closure_reward,
            "button_reach_reward": button_reach_reward,
            "button_press_reward": button_press_reward,
            "smooth_penalty": smooth_penalty,
            "start_checkpoint_dist": start_checkpoint_dist,
        }
        self._prev_action = action.detach()
        return reward

    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Dense reward divided by SUCCESS_BONUS (success maps to 1.0)."""
        return self.compute_dense_reward(obs=obs, action=action, info=info) / self.SUCCESS_BONUS
@register_env("TraceShapeSeqEasy-VLA-v0", max_episode_steps=1500)
class TraceShapeSeqEasyVLAEnv(TraceShapeSeqVLABaseEnv):
    """Easy variant: every element of the sequence is a circle."""

    AVAILABLE_SHAPES: List[int] = [TraceShapeSeqVLABaseEnv.SHAPE_CIRCLE]
@register_env("TraceShapeSeqMedium-VLA-v0", max_episode_steps=1500)
class TraceShapeSeqMediumVLAEnv(TraceShapeSeqVLABaseEnv):
    """Medium variant: sequence elements are drawn from circles and squares."""

    AVAILABLE_SHAPES: List[int] = [
        TraceShapeSeqVLABaseEnv.SHAPE_CIRCLE,
        TraceShapeSeqVLABaseEnv.SHAPE_SQUARE,
    ]
@register_env("TraceShapeSeqHard-VLA-v0", max_episode_steps=1500)
class TraceShapeSeqHardVLAEnv(TraceShapeSeqVLABaseEnv):
    """Hard variant: sequence elements are drawn from circles, squares and triangles."""

    AVAILABLE_SHAPES: List[int] = [
        TraceShapeSeqVLABaseEnv.SHAPE_CIRCLE,
        TraceShapeSeqVLABaseEnv.SHAPE_SQUARE,
        TraceShapeSeqVLABaseEnv.SHAPE_TRIANGLE,
    ]