"""Remember-shape tasks for the VLA memory benchmark."""
from typing import Any, Dict, List
import numpy as np
import sapien
import torch
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
from mikasa_robo_suite.vla.utils import shapes
[docs]
class RememberShapeVLABaseEnv(BaseEnv):
"""Remember one target shape and find it again after a delay.
The cue presents a single target geometry. After a memory phase, several
shapes reappear in new positions and the robot must identify the matching
geometry regardless of where it was originally shown.
Episode flow:
- One target shape is shown at the center as the cue.
- All shapes disappear during the memory phase.
- All shapes reappear and the robot selects the matching one.
Success (`success=True`):
- The robot must reach the object whose geometry matches the cue.
How to customize:
- `SHAPES` changes how many different shapes appear in the scene.
- `SHAPE_MAPPING` changes which procedural geometries are actually used.
- `CUE_PHASE_STEPS` and `EMPTY_PHASE_STEPS` control cue duration and memory delay.
- `GOAL_THRESH` changes how strict the final reach criterion is.
- `SHAPE_SCALE` changes the size of the generated objects.
"""
LANGUAGE_INSTRUCTION = "Observe the object's shape, wait, then touch the object of the same shape."
SUPPORTED_ROBOTS = ["panda", "panda_wristcam"]
SHAPES = 3
GOAL_THRESH = 0.05
SHAPE_SCALE = 0.02
COLOR = [0, 0, 255, 255]
# Number of steps to freeze object velocities after shapes teleport onto the
# table (start of manipulation phase). Without this, Sapien GPU contact
# resolution needs 1-2 solver iterations to settle the z-contact, during
# which gravity pulls the objects slightly below their target z, producing
# a visible "fall from above the table" artefact.
APPEAR_SETTLE_STEPS = 3
SHAPE_MAPPING = {
0: "cube",
1: "sphere",
2: "cylinder",
3: "cross",
4: "torus",
5: "star",
6: "pyramide",
7: "t_shape",
8: "crescent",
}
CUE_PHASE_STEPS: List[int] = [1, 5]
EMPTY_PHASE_STEPS: List[int] = [1, 5]
MANIP_MIN_SHAPE_DISTANCE = 0.09
MANIP_WIDTH_AXIS = 1
MANIP_WIDTH_SCALE = 2
MANIP_WIDTH_CLAMP = 0.5
ACTION_L2_COEF = 0.02
ACTION_DELTA_L2_COEF = 0.05
QVEL_L2_COEF = 0.01
def __init__(self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs):
self.shape_dict = dict(list(self.SHAPE_MAPPING.items())[: self.SHAPES])
self.robot_init_qpos_noise = robot_init_qpos_noise
self.initial_poses = {}
super().__init__(*args, robot_uids=robot_uids, **kwargs)
@property
def _default_sim_config(self):
return SimConfig(
gpu_memory_config=GPUMemoryConfig(
found_lost_pairs_capacity=2**25,
max_rigid_patch_count=2**21,
max_rigid_contact_count=2**22,
)
)
@property
def _default_sensor_configs(self):
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]
@property
def _default_human_render_camera_configs(self):
pose = sapien_utils.look_at([0.5, 1, 1], [-0.3, 0, 0])
return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)
def _load_agent(self, options: dict):
super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))
def _build_shape_actor(self, shape_name: str, key: int, color):
common = dict(
color=color,
body_type="dynamic",
initial_pose=sapien.Pose(p=[0, 0, self.SHAPE_SCALE]),
)
s = self.SHAPE_SCALE
builders = {
"cube": lambda: actors.build_cube(self.scene, half_size=s, name=f"cube_{key}", **common),
"sphere": lambda: actors.build_sphere(self.scene, radius=s, name=f"sphere_{key}", **common),
"cylinder": lambda: actors.build_cylinder(
self.scene, radius=s, half_length=s, name=f"cylinder_{key}", **common
),
"cross": lambda: shapes.build_cross(
self.scene, arm_length=s * 1.5, width=s * 0.75, name=f"cross_{key}", **common
),
"torus": lambda: shapes.build_torus(self.scene, radius=s, tube_radius=s / 2, name=f"torus_{key}", **common),
"star": lambda: shapes.build_star(
self.scene, radius=s * 1.5, thickness=s * 0.75, name=f"star_{key}", **common
),
"pyramide": lambda: shapes.build_pyramid(
self.scene, base_size=s, height=s, name=f"pyramide_{key}", **common
),
"t_shape": lambda: shapes.build_t_shape(
self.scene, width=s * 2, height=s * 2, thickness=s * 0.75, name=f"t_shape_{key}", **common
),
"crescent": lambda: shapes.build_crescent(
self.scene, outer_radius=s, height=s, thickness=s / 2, name=f"crescent_{key}", **common
),
}
if shape_name not in builders:
raise NotImplementedError(shape_name)
return builders[shape_name]()
def _load_scene(self, options: dict):
self.table_scene = TableSceneBuilder(self, robot_init_qpos_noise=self.robot_init_qpos_noise)
self.table_scene.build()
color = np.array(self.COLOR) / 255.0
self.shape_actors = {}
self.shape_resting_z = {}
for key, shape_name in self.shape_dict.items():
self.shape_actors[key] = self._build_shape_actor(shape_name, key, color)
actor_q = self._get_shape_quaternion(shape_name)
self.shape_resting_z[key] = self._compute_actor_resting_z(
self.shape_actors[key], actor_q
)
def _ensure_phase_buffers(self, env_idx: torch.Tensor):
target_size = int(env_idx.max().item()) + 1
if not hasattr(self, "cue_steps_per_env") or self.cue_steps_per_env is None:
self.cue_steps_per_env = torch.zeros(target_size, dtype=torch.int64, device=self.device)
self.empty_steps_per_env = torch.zeros(target_size, dtype=torch.int64, device=self.device)
return
current_size = self.cue_steps_per_env.shape[0]
if target_size > current_size:
pad = target_size - current_size
self.cue_steps_per_env = torch.cat(
[self.cue_steps_per_env, torch.zeros(pad, dtype=torch.int64, device=self.device)], dim=0
)
self.empty_steps_per_env = torch.cat(
[self.empty_steps_per_env, torch.zeros(pad, dtype=torch.int64, device=self.device)], dim=0
)
def _get_shape_quaternion(self, shape_name: str):
if shape_name == "cylinder":
return [0.7071068, 0, 0.7071068, 0]
return [1, 0, 0, 0]
@staticmethod
def _quat_to_rotmat(q):
"""sapien (w, x, y, z) → 3x3 rotation matrix."""
w, x, y, z = float(q[0]), float(q[1]), float(q[2]), float(q[3])
return np.array(
[
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
],
dtype=np.float64,
)
def _compute_actor_resting_z(self, actor, actor_quat) -> float:
"""Z-coordinate of the actor's origin such that the lowest collision
point sits exactly at z=0 in world frame, assuming the actor is rotated
by `actor_quat` (sapien w,x,y,z).
Probes the actor's actual collision-shape geometry rather than relying
on hardcoded per-shape constants. Necessary because shapes like torus
and crescent build their collision primitives from vertically-rotated
cylinder segments where the world-frame z extent is `half_length`
(≈ 2π·radius / segments / 2), not `radius` as one might naively expect.
"""
obj = actor._objs[0] if hasattr(actor, "_objs") else actor
body = obj.find_component_by_type(sapien.physx.PhysxRigidDynamicComponent)
if body is None:
return self.SHAPE_SCALE
actor_R = self._quat_to_rotmat(actor_quat)
min_z = float("inf")
for shape in body.collision_shapes:
local_pose = shape.local_pose
if hasattr(shape, "half_size"):
hs = np.array(shape.half_size, dtype=np.float64)
elif hasattr(shape, "half_length") and hasattr(shape, "radius"):
# Sapien cylinder: default along X axis with given half_length & radius.
hs = np.array([shape.half_length, shape.radius, shape.radius], dtype=np.float64)
elif hasattr(shape, "radius"):
r = float(shape.radius)
hs = np.array([r, r, r], dtype=np.float64)
else:
continue
local_R = self._quat_to_rotmat(local_pose.q)
corners = np.array(
[
[sx * hs[0], sy * hs[1], sz * hs[2]]
for sx in (-1, 1)
for sy in (-1, 1)
for sz in (-1, 1)
],
dtype=np.float64,
)
local_world = corners @ local_R.T + np.array(local_pose.p, dtype=np.float64)
actor_world = local_world @ actor_R.T
min_z = min(min_z, float(actor_world[:, 2].min()))
if min_z == float("inf"):
return self.SHAPE_SCALE
return -min_z
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
self.table_scene.initialize(env_idx)
self.task_cue = None
self.reward_dict = None
if hasattr(self, "_prev_action") and self._prev_action is not None:
if torch.is_tensor(self._prev_action) and self._prev_action.shape[0] >= int(env_idx.max().item()) + 1:
self._prev_action[env_idx] = 0
self.true_shape_indices = self._batched_episode_rng.choice(list(self.shape_dict.keys()))
self.true_shape_indices = torch.from_numpy(self.true_shape_indices).to(
device=self.device, dtype=torch.uint8
)
xyz_initial = torch.zeros((b, 3))
# Per-shape center pose: z is the actor-specific resting height
# (computed from collision geometry in _load_scene) so the shape
# sits flush on the table when shown at the cue (xy=0,0).
self.center_pose = {}
for key in self.shape_dict:
cp = xyz_initial.clone()
cp[..., 2] = self.shape_resting_z[key]
self.center_pose[key] = cp[0].unsqueeze(0)
for key, shape_name in self.shape_dict.items():
xyz = xyz_initial.clone()
if self.SHAPES != 3:
angle = np.pi * (key - (len(self.shape_dict) // 2)) / len(self.shape_dict)
radius = 0.3
xyz[..., 0] = radius * np.cos(angle) - 0.25
xyz[..., 1] = radius * np.sin(angle)
if self.SHAPES in [5, 9]:
xyz[..., 1] -= (key - (len(self.shape_dict) // 2)) * 0.025
else:
xyz[..., 1] -= (key - (len(self.shape_dict) // 2)) * 0.1
# Per-shape z: probed from collision geometry so each shape's
# bottom touches the table exactly. Using SHAPE_SCALE for all
# shapes would leave short ones (cross/torus/…) hovering.
xyz[..., 2] = self.shape_resting_z[key]
q = self._get_shape_quaternion(shape_name)
self.shape_actors[key].set_pose(Pose.create_from_pq(p=xyz, q=q))
self.initial_poses[key] = xyz.clone()
min_distance = self.SHAPE_SCALE * 3
max_attempts = 50
for env_i in range(b):
positions = [self.initial_poses[key][env_i].clone() for key in self.initial_poses]
for i in range(len(positions)):
attempt = 0
while attempt < max_attempts:
noise = torch.randn(2, device=self.device) * self.SHAPE_SCALE * 0.5
new_pos = positions[i].clone()
new_pos[:2] += noise
valid = all(torch.norm(new_pos[:2] - positions[j][:2]) >= min_distance for j in range(i))
if valid:
positions[i] = new_pos
break
attempt += 1
shuffled_indices = torch.randperm(len(positions))
for key, idx in zip(self.initial_poses, shuffled_indices):
# Take only x,y from the shuffled position; z must remain
# the *destination* shape's resting z. With per-shape
# resting heights, shuffling the full (x,y,z) tuple would
# mix z values across shapes (e.g. cube ends up at the
# cross's z and visibly hovers/penetrates the table).
new_pos = positions[idx].clone()
new_pos[2] = self.shape_resting_z[key]
self.initial_poses[key][env_i] = new_pos
pose = self.shape_actors[key].pose.raw_pose.clone()
pose[env_i, :3] = new_pos
self.shape_actors[key].pose = pose
self.oracle_info = self.true_shape_indices
if self.robot_uids in ("panda", "panda_wristcam"):
qpos = np.array([0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04])
qpos[:-2] += self._episode_rng.normal(0, self.robot_init_qpos_noise, len(qpos) - 2)
self.agent.reset(qpos)
self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
else:
raise NotImplementedError(self.robot_uids)
self._ensure_phase_buffers(env_idx)
cue_lo, cue_hi = self.CUE_PHASE_STEPS
empty_lo, empty_hi = self.EMPTY_PHASE_STEPS
self.cue_steps_per_env[env_idx] = torch.randint(
low=cue_lo,
high=cue_hi + 1,
size=(b,),
device=self.device,
dtype=torch.int64,
)
self.empty_steps_per_env[env_idx] = torch.randint(
low=empty_lo,
high=empty_hi + 1,
size=(b,),
device=self.device,
dtype=torch.int64,
)
if self.SHAPES not in (5, 9):
self._spread_shapes_for_manipulation_phase(env_idx)
def _spread_shapes_for_manipulation_phase(self, env_idx: torch.Tensor):
keys = list(self.initial_poses.keys())
max_attempts = 80
jitter_scale = self.SHAPE_SCALE * 0.6
for env_i in env_idx.tolist():
positions = [self.initial_poses[key][env_i].clone() for key in keys]
for i in range(len(positions)):
attempts = 0
while attempts < max_attempts:
valid = all(
torch.norm(positions[i][:2] - positions[j][:2]) >= self.MANIP_MIN_SHAPE_DISTANCE
for j in range(i)
)
if valid:
break
positions[i][:2] += torch.randn(2, device=self.device) * jitter_scale
attempts += 1
for key, pos in zip(keys, positions):
pos[self.MANIP_WIDTH_AXIS] = torch.clamp(
pos[self.MANIP_WIDTH_AXIS],
-self.MANIP_WIDTH_CLAMP,
self.MANIP_WIDTH_CLAMP,
)
self.initial_poses[key][env_i] = pos
current_pose = self.shape_actors[key].pose.raw_pose.clone()
current_pose[env_i, :3] = pos
self.shape_actors[key].pose = current_pose
self._stretch_width_axis(env_i, keys)
def _stretch_width_axis(self, env_i: int, keys):
axis = self.MANIP_WIDTH_AXIS
axis_vals = torch.stack([self.initial_poses[key][env_i, axis] for key in keys], dim=0)
center = axis_vals.mean()
for key in keys:
pos = self.initial_poses[key][env_i].clone()
pos[axis] = center + (pos[axis] - center) * self.MANIP_WIDTH_SCALE
pos[axis] = torch.clamp(pos[axis], -self.MANIP_WIDTH_CLAMP, self.MANIP_WIDTH_CLAMP)
self.initial_poses[key][env_i] = pos
current_pose = self.shape_actors[key].pose.raw_pose.clone()
current_pose[env_i, :3] = pos
self.shape_actors[key].pose = current_pose
[docs]
def evaluate(self):
self.original_poses = {key: self.shape_actors[key].pose.raw_pose.clone() for key in self.shape_actors}
elapsed_steps = self.elapsed_steps.to(torch.int64)
cue_end = self.cue_steps_per_env
empty_end = cue_end + self.empty_steps_per_env
empty_mask = (elapsed_steps >= cue_end) & (elapsed_steps < empty_end)
manip_mask = elapsed_steps >= empty_end
hidden_poses = {}
hidden_phase_mask = ~manip_mask
appeared_mask = manip_mask & (elapsed_steps == empty_end)
for key in self.shape_dict:
hidden_poses[key] = self.shape_actors[key].pose.raw_pose.clone()
hidden_poses[key][hidden_phase_mask, 2] = 1000
self.shape_actors[key].pose = hidden_poses[key]
for key in self.shape_dict:
true_mask = self.true_shape_indices == key
b_ = hidden_poses[key].shape[0]
hidden_poses[key][true_mask & hidden_phase_mask, :3] = self.center_pose[key].repeat(b_, 1)[
true_mask & hidden_phase_mask, :3
]
hidden_poses[key][true_mask & empty_mask, 2] = 1000
self.shape_actors[key].pose = hidden_poses[key]
# Freeze velocities AND re-pin pose for APPEAR_SETTLE_STEPS steps after
# shapes teleport onto the table. Without re-pinning the pose every
# settle step, GPU contact resolution shifts each shape by a small
# delta (different sign per shape due to different inertia tensors)
# — short shapes end up visibly above or below the table for the
# first 1-2 frames. Re-applying initial_poses every settle step keeps
# the visual flush; the pinning stops once the solver is converged.
settle_mask = manip_mask & (elapsed_steps < empty_end + self.APPEAR_SETTLE_STEPS)
for key in self.shape_dict:
# Re-pin xyz on every step in the settle window (NOT just on the
# first appeared step) so contact-solver drift can't push the
# shape off the table surface visibly.
hidden_poses[key][settle_mask, :3] = self.initial_poses[key][settle_mask, :3]
self.shape_actors[key].pose = hidden_poses[key]
lock_mask = hidden_phase_mask | settle_mask
if bool(lock_mask.any().item()):
lin_vel = self.shape_actors[key].linear_velocity.clone()
ang_vel = self.shape_actors[key].angular_velocity.clone()
lin_vel[lock_mask] = 0
ang_vel[lock_mask] = 0
self.shape_actors[key].set_linear_velocity(lin_vel)
self.shape_actors[key].set_angular_velocity(ang_vel)
self.masks = {key: (self.true_shape_indices == key).unsqueeze(-1) for key in self.shape_dict}
self.obj_to_goal_pos = torch.zeros_like(
self.shape_actors[0].pose.p,
device=self.shape_actors[0].pose.p.device,
dtype=self.shape_actors[0].pose.p.dtype,
)
for key in self.shape_dict:
self.obj_to_goal_pos += (self.shape_actors[key].pose.p - self.agent.tcp.pose.p) * self.masks[key]
is_obj_placed = torch.linalg.norm(self.obj_to_goal_pos, axis=1) <= self.GOAL_THRESH
is_robot_static = self.agent.is_static(0.2)
return {
"obj_to_goal_pos": self.obj_to_goal_pos,
"is_obj_placed": is_obj_placed,
"is_robot_static": is_robot_static,
"success": is_obj_placed & is_robot_static,
"task_cue": self.task_cue,
"language_instruction": self.LANGUAGE_INSTRUCTION,
"oracle_info": self.oracle_info,
"reward_dict": self.reward_dict,
}
def _get_obs_extra(self, info: Dict):
obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
if self._obs_mode in ["state", "state_dict"]:
obs["oracle_info"] = self.oracle_info
for key in self.shape_actors:
obs[f"goal_{key}_pose"] = self.shape_actors[key].pose.p * self.masks[key]
return obs
[docs]
def step(self, action):
obs, reward, terminated, truncated, info = super().step(action)
if isinstance(info, dict) and "success" in info:
success = info["success"]
if torch.is_tensor(success):
success = success.to(dtype=torch.bool)
if torch.is_tensor(terminated):
terminated = terminated.to(dtype=torch.bool) & (~success)
else:
terminated = bool(terminated) and (not bool(success.any().item()))
else:
terminated = bool(terminated) and (not bool(success))
return obs, reward, terminated, truncated, info
[docs]
def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
tcp_to_obj_dist = torch.linalg.norm(self.obj_to_goal_pos, axis=1)
reaching_reward = 1 - torch.tanh(10.0 * tcp_to_obj_dist)
qvel = self.agent.robot.get_qvel()[..., :-2]
qvel_l2 = torch.linalg.norm(qvel, axis=1)
static_reward = 1 - torch.tanh(5 * qvel_l2)
if not torch.is_tensor(action):
action = torch.as_tensor(action, device=self.device)
if not hasattr(self, "_prev_action") or self._prev_action is None or self._prev_action.shape != action.shape:
self._prev_action = torch.zeros_like(action)
delta_action = action - self._prev_action
action_l2 = torch.linalg.norm(action, axis=1)
delta_action_l2 = torch.linalg.norm(delta_action, axis=1)
if hasattr(self, "elapsed_steps") and torch.is_tensor(self.elapsed_steps):
first_step_mask = self.elapsed_steps <= 1
delta_action_l2 = torch.where(first_step_mask, torch.zeros_like(delta_action_l2), delta_action_l2)
smooth_penalty = (
self.ACTION_L2_COEF * torch.tanh(2.0 * action_l2)
+ self.ACTION_DELTA_L2_COEF * torch.tanh(5.0 * delta_action_l2)
+ self.QVEL_L2_COEF * torch.tanh(2.0 * qvel_l2)
)
reached = tcp_to_obj_dist < self.GOAL_THRESH
reward = (
1.0 * reaching_reward
+ 0.5 * static_reward
+ 0.5 * info["is_robot_static"] * info["is_obj_placed"]
- smooth_penalty
)
reward[info["success"]] = 3.0
self.reward_dict = {
"tcp_to_obj_dist": tcp_to_obj_dist,
"reaching_reward": reaching_reward,
"is_robot_static": info["is_robot_static"],
"reached": reached,
"success": info["success"],
"static_reward": static_reward,
"action_l2": action_l2,
"delta_action_l2": delta_action_l2,
"qvel_l2": qvel_l2,
"smooth_penalty": smooth_penalty,
"obj_to_goal_pos_y": info["obj_to_goal_pos"][:, 1],
"obj_to_goal_pos_x": info["obj_to_goal_pos"][:, 0],
}
self._prev_action = action.detach()
return reward
[docs]
def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
return self.compute_dense_reward(obs=obs, action=action, info=info) / 3.0
# ----- Standard tasks -----
[docs]
@register_env("RememberShape3-VLA-v0", max_episode_steps=25)
class RememberShape3VLAEnv(RememberShapeVLABaseEnv):
SHAPES = 3
CUE_PHASE_STEPS = [1, 5]
EMPTY_PHASE_STEPS = [1, 5]
[docs]
@register_env("RememberShape5-VLA-v0", max_episode_steps=25)
class RememberShape5VLAEnv(RememberShapeVLABaseEnv):
SHAPES = 5
CUE_PHASE_STEPS = [1, 5]
EMPTY_PHASE_STEPS = [1, 5]
[docs]
@register_env("RememberShape9-VLA-v0", max_episode_steps=25)
class RememberShape9VLAEnv(RememberShapeVLABaseEnv):
SHAPES = 9
CUE_PHASE_STEPS = [1, 5]
EMPTY_PHASE_STEPS = [1, 5]
# ----- Long-horizon tasks -----
[docs]
@register_env("RememberShape3-Long-VLA-v0", max_episode_steps=600)
class RememberShape3LongVLAEnv(RememberShapeVLABaseEnv):
SHAPES = 3
CUE_PHASE_STEPS = [10, 100]
EMPTY_PHASE_STEPS = [50, 450]
[docs]
@register_env("RememberShape5-Long-VLA-v0", max_episode_steps=600)
class RememberShape5LongVLAEnv(RememberShapeVLABaseEnv):
SHAPES = 5
CUE_PHASE_STEPS = [10, 100]
EMPTY_PHASE_STEPS = [50, 450]
[docs]
@register_env("RememberShape9-Long-VLA-v0", max_episode_steps=600)
class RememberShape9LongVLAEnv(RememberShapeVLABaseEnv):
SHAPES = 9
CUE_PHASE_STEPS = [10, 100]
EMPTY_PHASE_STEPS = [50, 450]