# Source code for mikasa_robo_suite.vla.utils.apply_wrappers

"""Apply the canonical wrapper stack for MIKASA-Robo-VLA environments.

Most VLA tasks share the same outer observation wrappers, but some tasks need
an extra curriculum no-op wrapper or task-specific render overlays. This module
keeps those differences in :data:`VLA_WRAPPER_CONFIGS`.

To support a new VLA env:

1. Register the env itself with Gym/ManiSkill.
2. Add one entry to ``VLA_WRAPPER_CONFIGS`` below.
3. Pick ``overlays`` for the text/debug information rendered on videos.
4. Set ``curriculum_wrapper`` only when the train-data rollout stack freezes
   actions during the env's cue/empty phase.

Example::

    "MyTask-VLA-v0": VLAWrapperConfig(
        overlays=COLOR_MEMORY_OVERLAYS,
        curriculum_wrapper=CurriculumPhaseNoopActionWrapper,
    ),

The wrapper order is inner to outer:

1. ``StateOnlyTensorToDictWrapper`` for every env.
2. Optional ``config.curriculum_wrapper``.
3. Optional ``config.overlays`` when ``include_overlays=True``.
4. ``FlattenRGBDObservationWrapper`` for every env.
5. ``ConvertJointsToEEFXyzRpyGripperWrapper`` for every env.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
from mani_skill.utils import gym_utils

from baselines.ppo.ppo_memtasks import FlattenRGBDObservationWrapper
from mikasa_robo_suite.vla.utils.wrappers import (
    ConvertJointsToEEFXyzRpyGripperWrapper,
    CurriculumPhaseNoopActionWrapper,
    MemoryCapacityInfoWrapper,
    RememberColorInfoWrapper,
    RememberShapeAndColorInfoWrapper,
    RememberShapeInfoWrapper,
    RenderPressProgressInfoWrapper,
    RenderRewardInfoWrapper,
    RenderStepInfoWrapper,
    RenderTimedTransferInfoWrapper,
    RenderTraceShapeDebugWrapper,
    RenderWorkingBatteriesInfoWrapper,
    RotateRenderAngleInfoWrapper,
    ShellGameRenderCupInfoWrapper,
    StateOnlyTensorToDictWrapper,
)

# Names re-exported by ``from mikasa_robo_suite.vla.utils.apply_wrappers import *``.
# The overlay chain constants remain importable explicitly by name.
__all__ = ["apply_mikasa_vla_wrappers", "MIKASA_VLA_ENV_IDS", "VLA_WRAPPER_CONFIGS", "VLAWrapperConfig"]


@dataclass(frozen=True)
class VLAWrapperConfig:
    """Task-specific choices used by :func:`apply_mikasa_vla_wrappers`.

    ``overlays`` accepts any iterable of wrapper classes and is normalized to a
    tuple, so an instance stays hashable and cannot be mutated through an
    aliased list even though the dataclass is frozen.
    """

    # Render-overlay wrapper classes, applied in order (inner to outer) when
    # overlays are enabled. Always stored as a tuple (see __post_init__).
    overlays: tuple[type, ...] = ()
    # Optional extra wrapper class applied directly around
    # StateOnlyTensorToDictWrapper; None means no extra wrapper.
    curriculum_wrapper: Optional[type] = None
    # Control mode the task was collected with. When set and the env's
    # ``control_mode`` differs, a RuntimeWarning is emitted; None disables it.
    expected_control_mode: Optional[str] = None

    def __post_init__(self) -> None:
        # ``frozen=True`` only blocks rebinding fields: a list passed as
        # ``overlays`` would remain mutable and make hash() raise TypeError.
        # object.__setattr__ is the documented way to set a field on a frozen
        # dataclass during initialization.
        if not isinstance(self.overlays, tuple):
            object.__setattr__(self, "overlays", tuple(self.overlays))


# Render overlays are applied in tuple order, inner to outer. Reuse one of these
# chains in VLA_WRAPPER_CONFIGS when a new task renders the same annotations.
#
# Most chains end with the shared step-counter + reward pair below; their
# task-specific overlay is listed first, so it wraps the env before that pair.
STEP_REWARD_OVERLAYS = (RenderStepInfoWrapper, RenderRewardInfoWrapper)

INTERCEPT_OVERLAYS = (RenderStepInfoWrapper, RenderRewardInfoWrapper)
SHELL_GAME_OVERLAYS = (RenderStepInfoWrapper, ShellGameRenderCupInfoWrapper, RenderRewardInfoWrapper)
# The angle overlay sits outside the reward overlay: RotateRenderAngleInfoWrapper
# expects the uint8 frame produced by reward.
ROTATE_OVERLAYS = (*STEP_REWARD_OVERLAYS, RotateRenderAngleInfoWrapper)

COLOR_MEMORY_OVERLAYS = (RememberColorInfoWrapper, *STEP_REWARD_OVERLAYS)
SHAPE_MEMORY_OVERLAYS = (RememberShapeInfoWrapper, *STEP_REWARD_OVERLAYS)
SHAPE_COLOR_MEMORY_OVERLAYS = (RememberShapeAndColorInfoWrapper, *STEP_REWARD_OVERLAYS)
MEMORY_CAPACITY_OVERLAYS = (MemoryCapacityInfoWrapper, *STEP_REWARD_OVERLAYS)
BATTERIES_CHECKER_OVERLAYS = (RenderWorkingBatteriesInfoWrapper, *STEP_REWARD_OVERLAYS)
BLINK_COUNT_OVERLAYS = (RenderPressProgressInfoWrapper, *STEP_REWARD_OVERLAYS)
TRACE_SHAPE_OVERLAYS = (RenderTraceShapeDebugWrapper, *STEP_REWARD_OVERLAYS)
TIMED_TRANSFER_OVERLAYS = (RenderTimedTransferInfoWrapper, *STEP_REWARD_OVERLAYS)


# Add a new supported env here. Each entry describes every task-specific choice
# that apply_mikasa_vla_wrappers needs for that env.
#
# Entries pass VLAWrapperConfig fields positionally as
# (overlays, curriculum_wrapper). No entry below sets expected_control_mode, so
# _warn_if_control_mode_differs never warns for these envs.
#
# In the ShellGame and Remember* families, the -Long variants omit the
# CurriculumPhaseNoopActionWrapper that their short counterparts use.
#
# Insertion order matters: MIKASA_VLA_ENV_IDS is built from these keys in order.
VLA_WRAPPER_CONFIGS: dict[str, VLAWrapperConfig] = {
    # Shell game.
    "ShellGameTouch-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "ShellGamePush-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "ShellGameShuffleTouch-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "ShellGameShuffleColorLampTouch-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "ShellGameColorLampTouch-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "ShellGameShuffleTouch-Long-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS),
    "ShellGameShuffleColorLampTouch-Long-VLA-v0": VLAWrapperConfig(SHELL_GAME_OVERLAYS),
    # Intercept, rotate, and TakeItBack.
    "InterceptSlow-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "InterceptMedium-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "InterceptFast-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "InterceptGrabSlow-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "InterceptGrabMedium-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "InterceptGrabFast-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    "RotateLenientPos-VLA-v0": VLAWrapperConfig(ROTATE_OVERLAYS),
    "RotateLenientPosNeg-VLA-v0": VLAWrapperConfig(ROTATE_OVERLAYS),
    "RotateStrictPos-VLA-v0": VLAWrapperConfig(ROTATE_OVERLAYS),
    "RotateStrictPosNeg-VLA-v0": VLAWrapperConfig(ROTATE_OVERLAYS),
    "TakeItBack-VLA-v0": VLAWrapperConfig(INTERCEPT_OVERLAYS),
    # Remember color.
    "RememberColor3-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberColor5-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberColor9-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberColor3-Long-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS),
    "RememberColor5-Long-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS),
    "RememberColor9-Long-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS),
    # Remember shape.
    "RememberShape3-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShape5-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShape9-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShape3-Long-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS),
    "RememberShape5-Long-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS),
    "RememberShape9-Long-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS),
    # Remember shape and color.
    "RememberShapeAndColor3x2-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShapeAndColor3x3-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShapeAndColor5x3-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "RememberShapeAndColor3x2-Long-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS),
    "RememberShapeAndColor3x3-Long-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS),
    "RememberShapeAndColor5x3-Long-VLA-v0": VLAWrapperConfig(SHAPE_COLOR_MEMORY_OVERLAYS),
    # Find imposter. These reuse the Remember* overlay chains and all use the
    # curriculum no-op wrapper.
    "FindImposterColor3-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterColor5-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterColor9-VLA-v0": VLAWrapperConfig(COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterShape3-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterShape5-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterShape9-VLA-v0": VLAWrapperConfig(SHAPE_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper),
    "FindImposterShapeAndColor3x2-VLA-v0": VLAWrapperConfig(
        SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper
    ),
    "FindImposterShapeAndColor3x3-VLA-v0": VLAWrapperConfig(
        SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper
    ),
    "FindImposterShapeAndColor5x3-VLA-v0": VLAWrapperConfig(
        SHAPE_COLOR_MEMORY_OVERLAYS, CurriculumPhaseNoopActionWrapper
    ),
    # Bunch/sequence/chain memory capacity.
    "BunchOfColors3-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "BunchOfColors5-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "BunchOfColors7-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "BunchOfColors3-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "BunchOfColors5-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "BunchOfColors7-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors3-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors5-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors7-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors3-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors5-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "SeqOfColors7-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors3-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors5-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors7-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors3-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors5-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    "ChainOfColors7-Long-VLA-v0": VLAWrapperConfig(MEMORY_CAPACITY_OVERLAYS),
    # Motion-planning datasets with task-specific overlays.
    "BatteriesCheckerEasy-3-VLA-v0": VLAWrapperConfig(BATTERIES_CHECKER_OVERLAYS),
    "BatteriesCheckerEasy-6-VLA-v0": VLAWrapperConfig(BATTERIES_CHECKER_OVERLAYS),
    "BatteriesCheckerHard-3-VLA-v0": VLAWrapperConfig(BATTERIES_CHECKER_OVERLAYS),
    "BatteriesCheckerHard-6-VLA-v0": VLAWrapperConfig(BATTERIES_CHECKER_OVERLAYS),
    # The pd_joint_pos hold-action wrappers in some motion-planning oracle
    # scripts are upstream of replay. Validation follows the unfiltered pd-ee
    # rollout that exports the published train data.
    # NOTE(review): this appears to explain why the entries below have no
    # curriculum wrapper — confirm with the data-export scripts.
    "BlinkCountButtonPressEasy-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "BlinkCountButtonPressMedium-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "BlinkCountButtonPressHard-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "BlinkCountButtonPressEasy-Long-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "BlinkCountButtonPressMedium-Long-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "BlinkCountButtonPressHard-Long-VLA-v0": VLAWrapperConfig(BLINK_COUNT_OVERLAYS),
    "TraceShapeEasy-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TraceShapeMedium-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TraceShapeHard-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TraceShapeSeqEasy-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TraceShapeSeqMedium-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TraceShapeSeqHard-VLA-v0": VLAWrapperConfig(TRACE_SHAPE_OVERLAYS),
    "TimedTransferEasy-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    "TimedTransferMedium-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    "TimedTransferHard-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    "TimedTransferEasy-Long-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    "TimedTransferMedium-Long-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    "TimedTransferHard-Long-VLA-v0": VLAWrapperConfig(TIMED_TRANSFER_OVERLAYS),
    # GatherAndRecall renders only the shared step/reward overlays.
    "GatherAndRecall1-VLA-v0": VLAWrapperConfig(STEP_REWARD_OVERLAYS),
    "GatherAndRecall3-VLA-v0": VLAWrapperConfig(STEP_REWARD_OVERLAYS),
    "GatherAndRecall5-VLA-v0": VLAWrapperConfig(STEP_REWARD_OVERLAYS),
    "GatherAndRecall7-VLA-v0": VLAWrapperConfig(STEP_REWARD_OVERLAYS),
    "GatherAndRecall9-VLA-v0": VLAWrapperConfig(STEP_REWARD_OVERLAYS),
}

# Immutable snapshot of every supported env id, in VLA_WRAPPER_CONFIGS order.
MIKASA_VLA_ENV_IDS = tuple(VLA_WRAPPER_CONFIGS.keys())


def _resolve_env_id(env: gym.Env) -> str:
    spec = getattr(env, "spec", None)
    if spec is None:
        spec = getattr(env.unwrapped, "spec", None)
    if spec is None or spec.id is None:
        raise ValueError("Cannot resolve env_id. Pass an env created via `gym.make(...)`.")
    return spec.id


def _warn_if_control_mode_differs(env: gym.Env, env_id: str, config: VLAWrapperConfig) -> None:
    expected = config.expected_control_mode
    actual = getattr(env.unwrapped, "control_mode", None)
    if expected is None or actual is None or actual == expected:
        return

    warnings.warn(
        f"{env_id} was collected with control_mode={expected!r}, but this env "
        f"was created with control_mode={actual!r}. The configured curriculum "
        "wrapper may not match the action space.",
        RuntimeWarning,
        stacklevel=3,
    )


def apply_mikasa_vla_wrappers(env: gym.Env, *, include_overlays: bool = True) -> gym.Env:
    """Wrap a freshly created MIKASA-Robo-VLA env in its configured wrapper stack.

    Call this directly on the object returned by ``gym.make``: the env id is
    read from its spec and looked up in :data:`VLA_WRAPPER_CONFIGS`. Supporting
    another env only requires adding a :class:`VLAWrapperConfig` entry there.

    Wrappers are applied inner to outer: ``StateOnlyTensorToDictWrapper``, the
    optional curriculum wrapper, the optional overlays (only when
    ``include_overlays`` is true), ``FlattenRGBDObservationWrapper`` and
    finally ``ConvertJointsToEEFXyzRpyGripperWrapper``.

    Args:
        env: Env created via ``gym.make``.
        include_overlays: When false, the configured render overlays are skipped.

    Returns:
        The outermost wrapper. It exposes ``max_episode_steps`` with the
        configured ManiSkill horizon, including a
        ``gym.make(..., max_episode_steps=...)`` override.

    Raises:
        ValueError: If the env id cannot be resolved or has no entry in
            :data:`VLA_WRAPPER_CONFIGS`.
    """
    env_id = _resolve_env_id(env)
    # Captured from the env as passed in; re-exposed on the outermost wrapper below.
    horizon = gym_utils.find_max_episode_steps_value(env)

    try:
        config = VLA_WRAPPER_CONFIGS[env_id]
    except KeyError as exc:
        raise ValueError(
            f"Unknown env_id={env_id!r}. Add it to VLA_WRAPPER_CONFIGS or use one of: {sorted(VLA_WRAPPER_CONFIGS)}"
        ) from exc

    _warn_if_control_mode_differs(env, env_id, config)

    # Task-dependent inner layers, listed in application (inner-to-outer) order.
    inner_layers = [StateOnlyTensorToDictWrapper]
    if config.curriculum_wrapper is not None:
        inner_layers.append(config.curriculum_wrapper)
    if include_overlays:
        inner_layers.extend(config.overlays)

    wrapped = env
    for layer in inner_layers:
        wrapped = layer(wrapped)

    # Shared outer layers applied to every env.
    wrapped = FlattenRGBDObservationWrapper(
        wrapped, rgb=True, depth=False, state=False, oracle=False, joints=True
    )
    wrapped = ConvertJointsToEEFXyzRpyGripperWrapper(wrapped)
    wrapped.max_episode_steps = horizon
    return wrapped