Source code for mikasa_robo_suite.vla.memory_envs.intercept_grab_vla

"""Interception-and-grasp tasks for the VLA benchmark."""

from typing import Any, Dict, Union

import numpy as np
import sapien
import torch
from mani_skill.agents.robots.panda.panda import Panda
from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig


[docs] class InterceptGrabVLABaseEnv(BaseEnv): """Intercept a moving ball and finish with a stable grasp. A ball is launched across the table with random velocity. The robot has no separate observation phase: it must react immediately, intercept the ball, close the gripper around it, and settle into a stable final state. Episode flow: - The ball is spawned with a sampled initial velocity. - The robot moves to intercept its trajectory. - The robot grasps the ball and stabilizes. Success (`success=True`): - The ball must be grasped and the robot must be static at the end of the episode step. How to customize: - `VELOCITY_RANGE` changes how fast the ball moves and therefore how much anticipation the policy needs. - `BALL_RADIUS` changes grasp geometry and contact difficulty. - `ACTION_L2_COEF`, `ACTION_DELTA_L2_COEF`, and `QVEL_L2_COEF` can be used to regularize aggressive or jerky behavior. """ LANGUAGE_INSTRUCTION = "Intercept the rolling ball and grasp it to stop it." SUPPORTED_ROBOTS = ["panda", "panda_wristcam"] agent: Union[Panda, PandaWristCam] VELOCITY_RANGE = (0.0, 0.0) BALL_RADIUS: float = 0.02 ACTION_L2_COEF = 0.0 ACTION_DELTA_L2_COEF = 0.0 QVEL_L2_COEF = 0.0 def __init__(self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs): self.robot_init_qpos_noise = robot_init_qpos_noise super().__init__(*args, robot_uids=robot_uids, **kwargs) @property def _default_sim_config(self): return SimConfig( gpu_memory_config=GPUMemoryConfig( found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18, ) ) @property def _default_sensor_configs(self): pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1]) return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)] @property def _default_human_render_camera_configs(self): pose = sapien_utils.look_at([0.5, 1, 1], [-0.3, 0, 0]) return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100) def _load_agent(self, options: dict): super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0])) def _load_scene(self, options: dict): self.table_scene = TableSceneBuilder( self, robot_init_qpos_noise=self.robot_init_qpos_noise, ) self.table_scene.build() self.ball = actors.build_sphere( self.scene, radius=self.BALL_RADIUS, color=np.array([1, 0, 0, 1]), name="ball", body_type="dynamic", initial_pose=sapien.Pose(p=[0, 0, self.BALL_RADIUS]), ) self.reached_status = torch.zeros(self.num_envs, dtype=torch.float32) def _initialize_episode(self, env_idx: torch.Tensor, options: dict): self.reached_status = self.reached_status.to(self.device) with torch.device(self.device): b = len(env_idx) self.table_scene.initialize(env_idx) self.task_cue = None self.reward_dict = None xyz = torch.zeros((b, 3)) xyz[..., 0] = (torch.rand((b,)) * 2 - 1) * 0.15 - 0.2 xyz[..., 1] = torch.rand((b,)) * 0.25 - 1.0 + 0.1 xyz[..., 2] = self.BALL_RADIUS self.ball.set_pose(Pose.create_from_pq(p=xyz, q=[1, 0, 0, 0])) initial_velocity = torch.zeros((b, 3)) initial_velocity[..., 0] = torch.rand((b,)) * 0.05 min_vel, max_vel = self.VELOCITY_RANGE initial_velocity[..., 1] = torch.rand((b,)) * (max_vel - min_vel) + min_vel self.ball.set_linear_velocity(initial_velocity) self.oracle_info = initial_velocity if self.robot_uids in ("panda", "panda_wristcam"): qpos = np.array([0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04]) qpos[:-2] += self._episode_rng.normal( 0, self.robot_init_qpos_noise, len(qpos) - 2, ) self.agent.reset(qpos) self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0])) else: raise NotImplementedError(self.robot_uids) self.reached_status[env_idx] = 0.0
[docs] def evaluate(self): is_ball_grasped = self.agent.is_grasping(self.ball) is_robot_static = self.agent.is_static(0.2) return { "success": is_ball_grasped & is_robot_static, "task_cue": self.task_cue, "language_instruction": self.LANGUAGE_INSTRUCTION, "oracle_info": self.oracle_info, "is_ball_grasped": is_ball_grasped, "is_robot_static": is_robot_static, "reward_dict": self.reward_dict, }
def _get_obs_extra(self, info: Dict): obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose) if self._obs_mode in ["state", "state_dict"]: obs.update( ball_pose=self.ball.pose.raw_pose, oracle_info=self.oracle_info, is_ball_grasped=info["is_ball_grasped"], ball_linear_vel=self.ball.linear_velocity, ball_angular_vel=self.ball.angular_velocity, tcp_to_ball_pos=self.ball.pose.p - self.agent.tcp.pose.p, gripper_width=self.agent.robot.get_qpos()[:, -2:].sum(dim=1), relative_velocity=self.ball.linear_velocity - self.agent.tcp.linear_velocity, ) return obs
[docs] def step(self, action): obs, reward, terminated, truncated, info = super().step(action) if isinstance(info, dict) and "success" in info: success = info["success"] if torch.is_tensor(success): success = success.to(dtype=torch.bool) if torch.is_tensor(terminated): terminated = terminated.to(dtype=torch.bool) & (~success) else: terminated = bool(terminated) and not bool(success.any().item()) else: terminated = bool(terminated) and not bool(success) return obs, reward, terminated, truncated, info
[docs] def compute_dense_reward(self, obs: Any, action: Array, info: Dict): tcp_to_obj_dist = torch.linalg.norm(self.ball.pose.p - self.agent.tcp.pose.p, axis=1) reward = 10 * (1 - torch.tanh(3 * tcp_to_obj_dist)) reward = 10 * (1 - torch.tanh(3 * tcp_to_obj_dist)) current_gripper_width = torch.sum(self.agent.robot.get_qpos()[:, -2:], axis=1) is_ball_grasped = info["is_ball_grasped"] dist_tcp_to_ball_x = torch.abs(self.agent.tcp.pose.p[:, 0] - self.ball.pose.p[:, 0]) dist_tcp_to_ball_y = torch.abs(self.agent.tcp.pose.p[:, 1] - self.ball.pose.p[:, 1]) dist_tcp_to_ball_z = torch.abs(self.agent.tcp.pose.p[:, 2] - self.ball.pose.p[:, 2]) ball_in_gripper = ( (dist_tcp_to_ball_x <= self.BALL_RADIUS * 2.0) & (dist_tcp_to_ball_y <= self.BALL_RADIUS * 2.0) & (dist_tcp_to_ball_z <= self.BALL_RADIUS * 1.5) & (tcp_to_obj_dist <= self.BALL_RADIUS * 2.0) & (current_gripper_width >= self.BALL_RADIUS * 2) ) reward[ball_in_gripper] += 10.0 very_close_mask = tcp_to_obj_dist <= self.BALL_RADIUS * 1.25 reward[very_close_mask] += 30.0 * (1 - tcp_to_obj_dist[very_close_mask] / (self.BALL_RADIUS * 3)) optimal_width = 2 * self.BALL_RADIUS width_error = torch.abs(current_gripper_width[ball_in_gripper] - optimal_width) width_error = torch.clamp(width_error - 0.02, min=0.0) closing_reward = 90.0 * torch.exp(-5.0 * width_error) reward[ball_in_gripper] += closing_reward gripper_attempt_mask = ball_in_gripper & (current_gripper_width <= optimal_width + 0.02) reward[gripper_attempt_mask] += 90.0 reward[is_ball_grasped] += 200 v = torch.linalg.norm(self.ball.linear_velocity, axis=1) av = torch.linalg.norm(self.ball.angular_velocity, axis=1) static_reward = 1 - torch.tanh(v * 10 + av) robot_static_reward = self.agent.is_static(0.2) reward[info["success"]] = 300 self.reward_dict = { "tcp_to_obj_dist": tcp_to_obj_dist, "is_grasped_reward": is_ball_grasped, "ball_in_gripper": ball_in_gripper, "dist_tcp_to_ball_x": dist_tcp_to_ball_x, "dist_tcp_to_ball_y": dist_tcp_to_ball_y, "dist_tcp_to_ball_z": dist_tcp_to_ball_z, "current_gripper_width": current_gripper_width, "very_close_mask": very_close_mask, "static_reward": static_reward, "robot_static_reward": robot_static_reward, "success": info["success"], } return reward
[docs] def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict): return self.compute_dense_reward(obs=obs, action=action, info=info) / 300.0
[docs] @register_env("InterceptGrabSlow-VLA-v0", max_episode_steps=60) class InterceptGrabSlowVLAEnv(InterceptGrabVLABaseEnv): VELOCITY_RANGE = (0.25, 0.5)
[docs] @register_env("InterceptGrabMedium-VLA-v0", max_episode_steps=60) class InterceptGrabMediumVLAEnv(InterceptGrabVLABaseEnv): VELOCITY_RANGE = (0.5, 0.75)
[docs] @register_env("InterceptGrabFast-VLA-v0", max_episode_steps=60) class InterceptGrabFastVLAEnv(InterceptGrabVLABaseEnv): VELOCITY_RANGE = (0.75, 1.0)