# Source code for mikasa_robo_suite.vla.memory_envs.shell_game_shuffle_touch_vla

"""Shell-game shuffle-and-touch tasks for the VLA benchmark."""

from typing import Any, Dict, List, Union

import numpy as np
import sapien
import torch
from mani_skill import ASSET_DIR
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.io_utils import load_json
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.actor import Actor
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import GPUMemoryConfig, SimConfig

# Module-level flag: `_load_scene` sets it to True the first time it prints the
# "less parallel environments than models" notice, so that notice is printed at
# most once per process.
WARNED_ONCE = False


class ShellGameShuffleTouchVLABaseEnv(BaseEnv):
    """Track one target cup through a shell-game shuffle.

    Three mugs sit in a row of slots spaced `MIN_DIST` apart along y. The robot is shown which
    slot holds the red ball, the mugs then swap places several times, and the robot must touch the
    mug whose starting slot held the ball once the shuffle is over.

    Episode flow (phase lengths and the number of swaps are sampled per environment at reset):
        - Cue phase: the red ball is shown at its starting slot while the mugs are parked
          `HEIGHT_OFFSET` above their slots.
        - Shuffle phase: the ball is parked `HEIGHT_OFFSET` above its start position and the mugs
          perform pairwise swaps along arcs of height `SWAP_ARC_HEIGHT`.
        - Manipulation phase: the mugs rest at their final slots, the ball is moved to the final
          slot of the target mug, and the robot must touch that mug.

    Success (`success=True`) is only reported during the manipulation phase and requires:
        - the TCP to be within `GOAL_THRESH` of the target mug,
        - the robot to be static,
        - no mug to have moved more than `MUG_DISPLACEMENT_SUCCESS_THRESH` in the xy-plane since
          the manipulation phase began.

    How to customize:
        - `CUE_PHASE_STEPS` changes the observation time before the shuffle begins.
        - `SHUFFLE_PHASE_STEPS` changes the overall duration of the shuffle.
        - `NUM_SWAPS` changes how many swaps the agent must track.
        - `SWAP_ARC_HEIGHT` changes the vertical arc used during swapping.
        - `MIN_DIST` changes spacing between cup slots.
        - `BALL_RADIUS` and `GOAL_THRESH` affect object geometry and touch tolerance.

    `CUE_PHASE_STEPS`, `SHUFFLE_PHASE_STEPS` and `NUM_SWAPS` are inclusive `[low, high]` ranges.
    Partial resets (a subset of `env_idx`) are supported: all per-environment episode state lives
    in buffers spanning every sub-scene.
    """

    LANGUAGE_INSTRUCTION = (
        "Observe which cup hides the ball, track the cups as they shuffle, then touch the correct cup."
    )
    SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
    agent: Union[Panda, PandaWristCam]

    BALL_RADIUS = 0.02  # radius of the red ball
    MIN_DIST = 0.2  # y-distance between neighbouring mug slots
    HEIGHT_OFFSET = 1000  # z-offset used to park hidden objects away from the workspace
    MUG_SCALE = 1.3  # uniform scale applied to the mug's collision and visual shapes
    GOAL_THRESH = 0.08  # max TCP-to-target-mug distance that counts as a touch
    MUG_DISPLACEMENT_PENALTY_COEF = 0.1  # weight of the mug-displacement term in the dense reward
    MUG_DISPLACEMENT_SUCCESS_THRESH = 0.05  # max xy mug displacement still allowing success
    CUE_PHASE_STEPS: List[int] = [1, 5]
    SHUFFLE_PHASE_STEPS: List[int] = [20, 35]
    NUM_SWAPS: List[int] = [2, 4]
    SWAP_ARC_HEIGHT = 0.06  # peak height of a mug's arc during a swap
    # Not read anywhere in this class.
    ACTION_L2_COEF = 0.0
    ACTION_DELTA_L2_COEF = 0.0
    QVEL_L2_COEF = 0.0

    def __init__(
        self,
        *args,
        robot_uids="panda_wristcam",
        robot_init_qpos_noise=0.02,
        num_envs=1,
        reconfiguration_freq=None,
        **kwargs,
    ):
        """Create the environment.

        Args:
            robot_uids: robot to load (see `SUPPORTED_ROBOTS`).
            robot_init_qpos_noise: standard deviation of the Gaussian noise added to the initial
                arm joint positions (also forwarded to the table scene builder).
            num_envs: number of parallel sub-scenes.
            reconfiguration_freq: forwarded to `BaseEnv`; when `None` it defaults to 1 for a
                single environment and 0 otherwise.
        """
        self.robot_init_qpos_noise = robot_init_qpos_noise
        self.model_id = None
        self.all_model_ids = np.array(list(load_json(ASSET_DIR / "assets/mani_skill2_ycb/info_pick_v0.json").keys()))
        if reconfiguration_freq is None:
            reconfiguration_freq = 1 if num_envs == 1 else 0
        super().__init__(
            *args,
            robot_uids=robot_uids,
            reconfiguration_freq=reconfiguration_freq,
            num_envs=num_envs,
            **kwargs,
        )

    @property
    def _default_sim_config(self):
        return SimConfig(
            gpu_memory_config=GPUMemoryConfig(
                found_lost_pairs_capacity=2**25,
                max_rigid_patch_count=2**21,
                max_rigid_contact_count=2**22,
            )
        )

    @property
    def _default_sensor_configs(self):
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.15])
        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))

    def _initialize_mug(self, model_ids, id_cup, name_suffix):
        """Build one scaled copy of YCB model `id_cup` per entry of `model_ids` and merge them.

        `model_ids` only determines how many copies are built (copy `i` goes to sub-scene `i`);
        every copy uses `id_cup`.

        Returns:
            `(mug, objs)`: the merged actor (registered in the state dict) and the per-scene actors.
        """
        objs: List[Actor] = []
        for i, _ in enumerate(model_ids):
            builder = actors.get_actor_builder(self.scene, id=f"ycb:{id_cup}")
            for record in (*builder.collision_records, *builder.visual_records):
                if hasattr(record, "scale") and record.scale is not None:
                    record.scale = (np.asarray(record.scale, dtype=np.float32) * self.MUG_SCALE).tolist()
            builder.initial_pose = sapien.Pose(p=[0, 0, 0])
            builder.set_scene_idxs([i])
            objs.append(builder.build(name=f"{id_cup}-{name_suffix}-{i}"))
            # Only the merged actor below is tracked in the state dict.
            self.remove_from_state_dict_registry(objs[-1])
        mug = Actor.merge(objs, name=f"mug_{name_suffix}")
        self.add_to_state_dict_registry(mug)
        return mug, objs

    def _load_scene(self, options: dict):
        """Build the table, the three mugs (left/center/right) and the red ball."""
        global WARNED_ONCE
        self.table_scene = TableSceneBuilder(
            env=self,
            robot_init_qpos_noise=self.robot_init_qpos_noise,
        )
        self.table_scene.build()
        # Only the number of draws matters below (every mug copy uses `id_cup`); the draw is kept
        # so the batched episode RNG is consumed as before.
        model_ids = self._batched_episode_rng.choice(self.all_model_ids, replace=True)
        if (
            self.num_envs > 1
            and self.num_envs < len(self.all_model_ids)
            and self.reconfiguration_freq <= 0
            and not WARNED_ONCE
        ):
            WARNED_ONCE = True
            print(
                "There are less parallel environments than total available models to sample. "
                "Not all models will be used during interaction even after resets unless you call "
                "env.reset(options=dict(reconfigure=True)) or set reconfiguration_freq >= 1."
            )
        id_cup = "025_mug"
        self.mug_left, self._objs_1 = self._initialize_mug(model_ids, id_cup, "left")
        self.mug_center, self._objs_2 = self._initialize_mug(model_ids, id_cup, "center")
        self.mug_right, self._objs_3 = self._initialize_mug(model_ids, id_cup, "right")
        self.red_ball = actors.build_sphere(
            self.scene,
            radius=self.BALL_RADIUS,
            color=np.array([255, 0, 0, 255]) / 255,
            name="red_ball",
            body_type="dynamic",
            initial_pose=sapien.Pose(p=[0, 0, self.BALL_RADIUS]),
        )

    def _after_reconfigure(self, options: dict):
        """Cache `object_zs[i] = -(min z of the first collision mesh's bounding box)` for every mug copy.

        `_initialize_episode` reads the first `num_envs` entries (the left mugs) as the mug spawn
        height; all copies share one model and scale, so presumably the values coincide.
        """
        num_objects = len(self._objs_1) + len(self._objs_2) + len(self._objs_3)
        self.object_zs = torch.empty(num_objects, device=self.device)
        for idx, obj in enumerate((*self._objs_1, *self._objs_2, *self._objs_3)):
            collision_mesh = obj.get_first_collision_mesh()
            self.object_zs[idx] = -collision_mesh.bounding_box.bounds[0, 2]

    def _ensure_phase_buffers(self, env_idx: torch.Tensor):
        """Allocate, or grow, the per-environment buffers that hold episode state.

        Buffers always cover every sub-scene (at least `self.num_envs` rows) because `evaluate`
        indexes all environments at once, while partial resets only write the rows in `env_idx`.
        Existing contents are kept when a buffer has to grow.
        """
        n = self.num_envs
        if env_idx.numel() > 0:
            n = max(n, int(env_idx.max().item()) + 1)
        max_swaps = self.NUM_SWAPS[1]
        # name -> (trailing shape, dtype); dtype None means torch's default floating dtype.
        specs = {
            "cue_steps_per_env": ((), torch.int64),
            "empty_steps_per_env": ((), torch.int64),
            "shuffle_steps_per_env": ((), torch.int64),
            "num_swaps_per_env": ((), torch.int64),
            "steps_per_swap_per_env": ((), torch.int64),
            "swap_pairs": ((max_swaps, 2), torch.long),
            "slot_of_mug": ((max_swaps + 1, 3), torch.long),
            "slot_positions": ((3, 3), None),
            "cup_with_ball_number": ((), torch.uint8),
            "oracle_info": ((), torch.uint8),
            "ball_initial_pose": ((3,), None),
            "mug_quat": ((4,), None),
            "_mug_left_manip_ref": ((3,), None),
            "_mug_center_manip_ref": ((3,), None),
            "_mug_right_manip_ref": ((3,), None),
            "_mug_ref_ready": ((), torch.bool),
        }
        for name, (tail, dtype) in specs.items():
            buf = getattr(self, name, None)
            if buf is None:
                setattr(self, name, torch.zeros(n, *tail, dtype=dtype, device=self.device))
            elif buf.shape[0] < n:
                pad = torch.zeros(n - buf.shape[0], *tail, dtype=dtype, device=self.device)
                setattr(self, name, torch.cat([buf, pad]))

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Reset the environments in `env_idx`.

        Places the mugs and the ball, resets the robot, samples phase lengths and the swap
        sequence, and pre-computes where every mug ends up. All per-environment state is written
        through `env_idx`, so other environments keep their running episodes.
        """
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)
            self.task_cue = None
            self.reward_dict = None
            self._ensure_phase_buffers(env_idx)

            # Mug/slot (0=left, 1=center, 2=right) that starts with the ball. The batched RNG draws
            # one value per sub-scene, so keep only the rows being reset.
            drawn = torch.as_tensor(np.asarray(self._batched_episode_rng.choice([0, 1, 2])), device=self.device)
            if drawn.shape[0] != b:
                drawn = drawn[env_idx]
            cup_with_ball = drawn.to(dtype=torch.uint8)
            self.cup_with_ball_number[env_idx] = cup_with_ball

            # Random row centre in the xy-plane; the three slots are MIN_DIST apart along y.
            xyz = torch.zeros((b, 3))
            xyz[:, :2] = torch.rand((b, 2)) * 0.2 - 0.1
            xyz[:, 2] = self.object_zs[env_idx]
            q = torch.tensor([0, 1, 0.5, 0]).repeat(b, 1)
            left_pos = xyz + torch.tensor([0, -self.MIN_DIST, 0]).repeat(b, 1)
            center_pos = xyz.clone()
            right_pos = xyz + torch.tensor([0, self.MIN_DIST, 0]).repeat(b, 1)
            self.mug_left.set_pose(Pose.create_from_pq(p=left_pos, q=q))
            self.mug_center.set_pose(Pose.create_from_pq(p=center_pos, q=q))
            self.mug_right.set_pose(Pose.create_from_pq(p=right_pos, q=q))
            self._mug_ref_ready[env_idx] = False
            self.mug_quat[env_idx] = q / q.norm(dim=-1, keepdim=True)

            # Ball starts at the slot of mug `cup_with_ball`, at height BALL_RADIUS.
            slot_y_offsets = torch.tensor([-self.MIN_DIST, 0.0, self.MIN_DIST])
            offsets = torch.zeros((b, 3), device=xyz.device)
            offsets[:, 1] = slot_y_offsets[cup_with_ball.long()]
            offsets[:, 2] = self.BALL_RADIUS - self.object_zs[env_idx]
            ball_xyz = xyz + offsets
            self.red_ball.set_pose(Pose.create_from_pq(p=ball_xyz, q=[1, 0, 0, 0]))
            self.ball_initial_pose[env_idx] = ball_xyz

            if self.robot_uids in ("panda", "panda_wristcam"):
                qpos = np.array([0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04])
                qpos[:-2] += self._episode_rng.normal(
                    0,
                    self.robot_init_qpos_noise,
                    len(qpos) - 2,
                )
                self.agent.reset(qpos)
                self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
            elif self.robot_uids == "xmate3_robotiq":
                qpos = np.array([0, 0.6, 0, 1.3, 0, 1.3, -1.57, 0, 0])
                qpos[:-2] += self._episode_rng.normal(
                    0,
                    self.robot_init_qpos_noise,
                    len(qpos) - 2,
                )
                self.agent.reset(qpos)
                self.agent.robot.set_root_pose(sapien.Pose([-0.562, 0, 0]))
            else:
                raise NotImplementedError(self.robot_uids)

            # Phase lengths and swap count, each sampled from an inclusive [low, high] range.
            cue_lo, cue_hi = self.CUE_PHASE_STEPS
            shuf_lo, shuf_hi = self.SHUFFLE_PHASE_STEPS
            swap_lo, swap_hi = self.NUM_SWAPS
            cue_steps = torch.randint(cue_lo, cue_hi + 1, (b,), device=self.device, dtype=torch.int64)
            shuffle_steps = torch.randint(shuf_lo, shuf_hi + 1, (b,), device=self.device, dtype=torch.int64)
            num_swaps = torch.randint(swap_lo, swap_hi + 1, (b,), device=self.device, dtype=torch.int64)
            steps_per_swap = shuffle_steps // torch.clamp(num_swaps, min=1)
            self.cue_steps_per_env[env_idx] = cue_steps
            self.shuffle_steps_per_env[env_idx] = shuffle_steps
            self.empty_steps_per_env[env_idx] = shuffle_steps
            self.num_swaps_per_env[env_idx] = num_swaps
            self.steps_per_swap_per_env[env_idx] = steps_per_swap
            self.slot_positions[env_idx] = torch.stack(
                [left_pos, center_pos, right_pos],
                dim=1,
            )

            # Simulate the swap sequence once. mug_at_slot[e, s] is the mug in slot s, and
            # slot_of_mug_all[e, k, m] is the slot of mug m after the first k swaps.
            max_swaps = self.NUM_SWAPS[1]
            all_pairs = torch.tensor([[0, 1], [0, 2], [1, 2]], device=self.device)
            pair_indices = torch.randint(0, 3, (b, max_swaps), device=self.device)
            local_swap_pairs = all_pairs[pair_indices]
            self.swap_pairs[env_idx] = local_swap_pairs
            batch_r = torch.arange(b, device=self.device)
            arange3 = torch.arange(3, device=self.device).unsqueeze(0).expand(b, -1)
            mug_at_slot = arange3.clone()
            slot_of_mug_all = torch.zeros(b, max_swaps + 1, 3, dtype=torch.long, device=self.device)
            slot_of_mug_all[:, 0].scatter_(1, mug_at_slot, arange3)
            for s in range(max_swaps):
                # Environments with fewer than s + 1 swaps keep their arrangement.
                active = s < num_swaps
                slot_a = local_swap_pairs[:, s, 0]
                slot_b = local_swap_pairs[:, s, 1]
                mug_a = mug_at_slot[batch_r, slot_a]
                mug_b = mug_at_slot[batch_r, slot_b]
                new_mas = mug_at_slot.clone()
                new_mas[batch_r[active], slot_a[active]] = mug_b[active]
                new_mas[batch_r[active], slot_b[active]] = mug_a[active]
                mug_at_slot = new_mas
                slot_of_mug_all[:, s + 1].scatter_(1, mug_at_slot, arange3)
            self.slot_of_mug[env_idx] = slot_of_mug_all

            # Oracle answer: slot where the ball's mug ends up. Rebuilt rather than mutated in
            # place so info dicts handed out earlier keep their values.
            final_ball_slot = slot_of_mug_all[batch_r, num_swaps, cup_with_ball.long()]
            oracle_info = self.oracle_info.clone()
            oracle_info[env_idx] = final_ball_slot.to(torch.uint8)
            self.oracle_info = oracle_info

    def _update_ball_pose(self, bi, cue_mask, shuffle_mask, manip_mask):
        """Show the ball at its start (cue), park it HEIGHT_OFFSET up (shuffle), or move it to the target mug's final slot (manipulation)."""
        ball_pose = self.red_ball.pose.raw_pose.clone()
        ball_pose[cue_mask, :3] = self.ball_initial_pose[cue_mask]
        hidden_ball_pos = self.ball_initial_pose.clone()
        hidden_ball_pos[:, 2] += self.HEIGHT_OFFSET
        ball_pose[shuffle_mask, :3] = hidden_ball_pos[shuffle_mask]
        final_ball_slot = self.slot_of_mug[
            bi,
            self.num_swaps_per_env,
            self.cup_with_ball_number.long(),
        ]
        ball_final_pos = self.slot_positions[bi, final_ball_slot].clone()
        ball_final_pos[:, 2] = self.ball_initial_pose[:, 2]
        ball_pose[manip_mask, :3] = ball_final_pos[manip_mask]
        self.red_ball.pose = ball_pose

    def _update_mug_poses(self, bi, elapsed, cue_end, cue_mask, shuffle_mask, manip_mask):
        """Place each mug: parked above its start slot (cue), along its swap arc (shuffle), or at its final slot (manipulation)."""
        num_swaps = self.num_swaps_per_env
        sps = torch.clamp(self.steps_per_swap_per_env, min=1)
        steps_into_shuffle = torch.clamp(elapsed - cue_end, min=0)
        raw_idx = steps_into_shuffle // sps
        # shuffle_steps is not always a multiple of num_swaps, so the shuffle phase can outlast the
        # last swap by a few steps; during those steps every mug must hold its final slot.
        past_all = raw_idx >= num_swaps
        cur_swap = torch.where(
            past_all,
            torch.clamp(num_swaps - 1, min=0),
            torch.clamp(raw_idx, max=self.NUM_SWAPS[1] - 1),
        ).long()
        next_swap = torch.clamp(cur_swap + 1, max=self.NUM_SWAPS[1])

        # Smoothstep-eased progress through the current swap, in [0, 1].
        t = torch.clamp((steps_into_shuffle - cur_swap * sps).float() / sps.float(), 0.0, 1.0)
        t[past_all] = 1.0
        progress = t * t * (3.0 - 2.0 * t)
        arc_z = self.SWAP_ARC_HEIGHT * torch.sin(torch.pi * progress)
        arc_z[past_all] = 0.0
        cos_t = torch.cos(torch.pi * progress)
        sin_t = torch.sin(torch.pi * progress)

        for m, mug in enumerate((self.mug_left, self.mug_center, self.mug_right)):
            final_slot = self.slot_of_mug[bi, num_swaps, m]
            before_slot = torch.where(past_all, final_slot, self.slot_of_mug[bi, cur_swap, m])
            after_slot = torch.where(past_all, final_slot, self.slot_of_mug[bi, next_swap, m])
            start_pos = self.slot_positions[bi, before_slot]
            end_pos = self.slot_positions[bi, after_slot]
            # Half-turn about the midpoint of the two slots: progress 0 -> start_pos,
            # progress 1 -> end_pos. Two swapped mugs stay on opposite sides of the midpoint.
            mid = (start_pos + end_pos) * 0.5
            half = (start_pos - end_pos) * 0.5
            perp_x = -half[:, 1]
            perp_y = half[:, 0]
            anim_pos = mid.clone()
            anim_pos[:, 0] += half[:, 0] * cos_t + perp_x * sin_t
            anim_pos[:, 1] += half[:, 1] * cos_t + perp_y * sin_t
            anim_pos[:, 2] = start_pos[:, 2] + arc_z

            final_pos = self.slot_positions[bi, final_slot]
            cue_pos = self.slot_positions[:, m].clone()
            cue_pos[:, 2] += self.HEIGHT_OFFSET

            new_pose = mug.pose.raw_pose.clone()
            new_pose[cue_mask, :3] = cue_pos[cue_mask]
            new_pose[shuffle_mask, :3] = anim_pos[shuffle_mask]
            # NOTE(review): mugs are re-pinned to their final slots on every evaluate() during
            # manipulation, so the displacement metrics presumably stay ~0 - confirm intended.
            new_pose[manip_mask, :3] = final_pos[manip_mask]
            new_pose[:, 3:7] = self.mug_quat
            mug.pose = new_pose

    def evaluate(self):
        """Advance the scripted ball/mug motion for the current step and compute task metrics.

        Returns:
            Info dict with `obj_to_goal_pos`, `is_obj_placed`, `is_robot_static`,
            `mug_max_displacement`, `is_mug_displacement_ok`, `success` (only ever True in the
            manipulation phase), `task_cue`, `language_instruction`, `oracle_info` and
            `reward_dict`.
        """
        elapsed = self.elapsed_steps.to(torch.int64)
        if elapsed.dim() > 1:
            elapsed = elapsed.squeeze(-1)
        bi = torch.arange(elapsed.shape[0], device=self.device)

        # Per-env phases: [0, cue_end) cue, [cue_end, shuffle_end) shuffle, then manipulation.
        cue_end = self.cue_steps_per_env
        shuffle_end = cue_end + self.shuffle_steps_per_env
        cue_mask = elapsed < cue_end
        shuffle_mask = (elapsed >= cue_end) & (elapsed < shuffle_end)
        manip_mask = elapsed >= shuffle_end
        self.manip_mask = manip_mask

        self._update_ball_pose(bi, cue_mask, shuffle_mask, manip_mask)
        self._update_mug_poses(bi, elapsed, cue_end, cue_mask, shuffle_mask, manip_mask)

        # Per-env one-hot selection of the mug that started over the ball.
        self.left_mask = (self.cup_with_ball_number == 0).unsqueeze(-1)
        self.center_mask = (self.cup_with_ball_number == 1).unsqueeze(-1)
        self.right_mask = (self.cup_with_ball_number == 2).unsqueeze(-1)
        tcp_pos = self.agent.tcp.pose.p
        self.obj_to_goal_pos = (
            (self.mug_left.pose.p - tcp_pos) * self.left_mask
            + (self.mug_center.pose.p - tcp_pos) * self.center_mask
            + (self.mug_right.pose.p - tcp_pos) * self.right_mask
        )
        self.is_obj_placed = torch.linalg.norm(self.obj_to_goal_pos, axis=1) <= self.GOAL_THRESH
        self.is_robot_static = self.agent.is_static(0.2)

        # Snapshot mug positions on the first manipulation-phase step of each env; displacement is
        # measured against this reference afterwards.
        just_entered_manip = manip_mask & (~self._mug_ref_ready)
        if torch.any(just_entered_manip):
            self._mug_left_manip_ref[just_entered_manip] = self.mug_left.pose.p[just_entered_manip]
            self._mug_center_manip_ref[just_entered_manip] = self.mug_center.pose.p[just_entered_manip]
            self._mug_right_manip_ref[just_entered_manip] = self.mug_right.pose.p[just_entered_manip]
            self._mug_ref_ready[just_entered_manip] = True

        left_displacement = torch.linalg.norm(self.mug_left.pose.p[:, :2] - self._mug_left_manip_ref[:, :2], dim=-1)
        center_displacement = torch.linalg.norm(
            self.mug_center.pose.p[:, :2] - self._mug_center_manip_ref[:, :2], dim=-1
        )
        right_displacement = torch.linalg.norm(self.mug_right.pose.p[:, :2] - self._mug_right_manip_ref[:, :2], dim=-1)
        self.mug_max_displacement = (
            torch.maximum(torch.maximum(left_displacement, center_displacement), right_displacement)
            * self._mug_ref_ready.float()
        )
        self.is_mug_displacement_ok = self.mug_max_displacement <= self.MUG_DISPLACEMENT_SUCCESS_THRESH

        # Touching the right mug only counts once the shuffle is over, in every reward mode.
        success = self.is_obj_placed & self.is_robot_static & self.is_mug_displacement_ok & manip_mask
        return {
            "obj_to_goal_pos": self.obj_to_goal_pos,
            "is_obj_placed": self.is_obj_placed,
            "is_robot_static": self.is_robot_static,
            "mug_max_displacement": self.mug_max_displacement,
            "is_mug_displacement_ok": self.is_mug_displacement_ok,
            "success": success,
            "task_cue": self.task_cue,
            "language_instruction": self.LANGUAGE_INSTRUCTION,
            "oracle_info": self.oracle_info,
            "reward_dict": self.reward_dict,
        }

    def _get_obs_extra(self, info: Dict):
        """TCP pose always; target-mug pose, ball pose and oracle slot in state observation modes."""
        self.obj_pose = (
            self.mug_left.pose.raw_pose * self.left_mask
            + self.mug_center.pose.raw_pose * self.center_mask
            + self.mug_right.pose.raw_pose * self.right_mask
        )
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self.obs_mode in ["state", "state_dict"]:
            obs.update(
                obj_pose=self.obj_pose,
                ball_pose=self.red_ball.pose.raw_pose,
                oracle_info=self.oracle_info,
            )
        return obs

    def step(self, action):
        """Step the environment, clearing termination for environments that are successful.

        With tensor flags, `terminated` becomes `terminated & ~success` element-wise. With a
        scalar `terminated`, it is cleared if any environment (tensor) or the episode (scalar)
        is successful.
        """
        obs, reward, terminated, truncated, info = super().step(action)
        if isinstance(info, dict) and "success" in info:
            success = info["success"]
            if torch.is_tensor(success):
                success = success.to(dtype=torch.bool)
                if torch.is_tensor(terminated):
                    terminated = terminated.to(dtype=torch.bool) & (~success)
                else:
                    terminated = bool(terminated) and not bool(success.any().item())
            else:
                terminated = bool(terminated) and not bool(success)
        return obs, reward, terminated, truncated, info

    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
        """Dense reward, zero outside the manipulation phase and 3.0 on success.

        Otherwise: reaching term + static term (while touching) + 1 (static while touching),
        minus a tanh penalty on mug displacement. Also records the terms in `self.reward_dict`.
        """
        # `evaluate` already gates success by phase; kept so externally built info dicts are gated too.
        info["success"] *= self.manip_mask
        tcp_to_obj_dist = torch.linalg.norm(self.obj_to_goal_pos, axis=1)
        reaching_reward = 1 - torch.tanh(5.0 * tcp_to_obj_dist)
        static_reward = 1 - torch.tanh(5.0 * torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1))
        reward = (
            reaching_reward
            + static_reward * info["is_obj_placed"]
            + info["is_robot_static"] * info["is_obj_placed"]
        )
        mug_shift_penalty = self.MUG_DISPLACEMENT_PENALTY_COEF * torch.tanh(10.0 * info["mug_max_displacement"])
        reward -= mug_shift_penalty
        reward *= self.manip_mask
        reward[info["success"]] = 3.0
        self.reward_dict = {
            "reaching_reward": reaching_reward,
            "static_reward": static_reward,
            "tcp_to_obj_dist": tcp_to_obj_dist,
            "mug_max_displacement": info["mug_max_displacement"],
            "mug_shift_penalty": mug_shift_penalty,
            "is_mug_displacement_ok": info["is_mug_displacement_ok"],
            'info["is_obj_placed"]': info["is_obj_placed"],
        }
        return reward

    def compute_normalized_dense_reward(
        self,
        obs: Any,
        action: torch.Tensor,
        info: Dict,
    ):
        """Dense reward divided by its success value (3.0)."""
        return self.compute_dense_reward(obs=obs, action=action, info=info) / 3.0
@register_env("ShellGameShuffleTouch-VLA-v0", max_episode_steps=60, asset_download_ids=["ycb"])
class ShellGameShuffleTouchVLAEnv(ShellGameShuffleTouchVLABaseEnv):
    """Standard shell game: 2-4 swaps over 20-35 shuffle steps after a 1-5 step cue, 60-step episodes."""

    NUM_SWAPS = [2, 4]
    SHUFFLE_PHASE_STEPS = [20, 35]
    CUE_PHASE_STEPS = [1, 5]
@register_env("ShellGameShuffleTouch-Long-VLA-v0", max_episode_steps=600, asset_download_ids=["ycb"])
class ShellGameShuffleTouchLongVLAEnv(ShellGameShuffleTouchVLABaseEnv):
    """Long-horizon shell game: 5-15 swaps over 100-400 shuffle steps after a 10-100 step cue, 600-step episodes."""

    NUM_SWAPS = [5, 15]
    SHUFFLE_PHASE_STEPS = [100, 400]
    CUE_PHASE_STEPS = [10, 100]