Source code for mikasa_robo_suite.vla.memory_envs.batteries_checker_easy_vla
"""Easy Batteries Checker variants for the VLA benchmark."""
import torch
from mani_skill.utils.registration import register_env
from .batteries_checker_hard_vla import BatteriesCheckerVLABaseEnv
[docs]
class BatteriesCheckerEasyVLABaseEnv(BatteriesCheckerVLABaseEnv):
"""Simplified version of Batteries Checker.
In this task, the robot checks batteries one by one using the socket and
the lamp. A tested battery does not need to be placed back manually: once
the check is complete, the environment returns it to its original tray slot
automatically. This keeps the memory component but removes part of the
manipulation burden.
Episode flow:
- The robot picks one battery and inserts it into the socket.
- The lamp reveals whether that battery is working.
- The environment snaps the battery back to its home slot.
- The robot presses the button to confirm that this battery was checked.
Success (`success=True`):
- Same criterion as the hard variant: all working batteries must be found
through completed check-confirm cycles.
How to customize:
- `ACTIVE_BATTERY_COUNT` changes how many batteries are present and therefore
how much search the agent must do.
- `WORKING_BATTERY_COUNT` changes how many positive findings exist in the tray.
- The hard-variant socket, tray, and button thresholds still control how
strict insertion and confirmation are.
"""
LANGUAGE_INSTRUCTION = "Find all working batteries by inserting each one into the socket, observing the lamp result, and then pressing the button to confirm."
[docs]
def evaluate(self):
info = super().evaluate()
auto_return_mask = (
info["action_mask"] & (self.action_stage == self.STAGE_RETURN) & (self.active_battery_idx >= 0)
)
if not auto_return_mask.any():
return info
env_ids = torch.where(auto_return_mask)[0]
active_idx = self.active_battery_idx[env_ids]
home_pos = self.tray_slot_positions[env_ids, active_idx]
home_quat = self.battery_quat.unsqueeze(0).repeat(env_ids.shape[0], 1)
for bat_id in torch.unique(active_idx).tolist():
bat_id = int(bat_id)
per_battery_mask = active_idx == bat_id
per_battery_env_ids = env_ids[per_battery_mask]
if per_battery_env_ids.numel() == 0:
continue
raw_pose = self.batteries[bat_id].pose.raw_pose.clone()
raw_pose[per_battery_env_ids, :3] = home_pos[per_battery_mask]
raw_pose[per_battery_env_ids, 3:7] = home_quat[per_battery_mask]
self.batteries[bat_id].pose = raw_pose
zero_vel = torch.zeros((self.num_envs, 3), device=self.device)
self.batteries[bat_id].set_linear_velocity(zero_vel)
self.batteries[bat_id].set_angular_velocity(zero_vel)
self.new_return_event[env_ids] = True
self.action_stage[env_ids] = self.STAGE_CONFIRM
info["new_return_event"] = self.new_return_event
info["action_stage"] = self.action_stage
info["socket_has_battery"][env_ids] = False
info["socket_has_working"][env_ids] = False
info["socket_battery_idx"][env_ids] = -1
return info
[docs]
@register_env("BatteriesCheckerEasy-3-VLA-v0", max_episode_steps=540)
class BatteriesCheckerEasy3VLAEnv(BatteriesCheckerEasyVLABaseEnv):
ACTIVE_BATTERY_COUNT = 3
WORKING_BATTERY_COUNT = 1
[docs]
@register_env("BatteriesCheckerEasy-6-VLA-v0", max_episode_steps=1080)
class BatteriesCheckerEasy6VLAEnv(BatteriesCheckerEasyVLABaseEnv):
ACTIVE_BATTERY_COUNT = 6
WORKING_BATTERY_COUNT = 3
[docs]
@register_env("BatteriesCheckerEasy-9-VLA-v0", max_episode_steps=1620)
class BatteriesCheckerEasy9VLAEnv(BatteriesCheckerEasyVLABaseEnv):
ACTIVE_BATTERY_COUNT = 9
WORKING_BATTERY_COUNT = 5
[docs]
@register_env("BatteriesCheckerEasy-12-VLA-v0", max_episode_steps=2160)
class BatteriesCheckerEasy12VLAEnv(BatteriesCheckerEasyVLABaseEnv):
ACTIVE_BATTERY_COUNT = 12
WORKING_BATTERY_COUNT = 7
[docs]
@register_env("BatteriesCheckerEasy-15-VLA-v0", max_episode_steps=2400)
class BatteriesCheckerEasy15VLAEnv(BatteriesCheckerEasyVLABaseEnv):
ACTIVE_BATTERY_COUNT = 15
WORKING_BATTERY_COUNT = 9