|
| 1 | +.. _warp-env-migration: |
| 2 | + |
| 3 | +Warp Environment Guide |
| 4 | +====================== |
| 5 | + |
| 6 | +This guide covers the key conventions and patterns used by the warp-first environment |
| 7 | +infrastructure, useful for migrating existing stable environments or creating new ones natively. |
| 8 | + |
| 9 | +.. note:: |
| 10 | + |
| 11 | + The warp environment infrastructure lives in ``isaaclab_experimental`` and |
| 12 | + ``isaaclab_tasks_experimental``. It's an experimental feature. |
| 13 | + |
| 14 | + |
| 15 | +Design Rationale |
| 16 | +~~~~~~~~~~~~~~~~ |
| 17 | + |
| 18 | +The warp environment path is built around `CUDA graph capture |
| 19 | +<https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/cuda-graphs.html>`_. |
| 20 | +A CUDA graph records a sequence of GPU operations (kernel launches, memory copies) during a |
| 21 | +capture phase, then replays the entire sequence with a single launch. This eliminates per-kernel |
| 22 | +CPU overhead — the parameter validation, kernel selection, and buffer setup that normally costs |
| 23 | +20–200 μs per operation is performed once during graph instantiation and reused on every replay |
| 24 | +(~10 μs total). All CPU-side code (Python logic, torch dispatching) executed during capture is |
| 25 | +completely bypassed during replay. See the `Warp concurrency documentation |
| 26 | +<https://nvidia.github.io/warp/deep_dive/concurrency.html>`_ for Warp's graph capture API |
| 27 | +(``wp.ScopedCapture``). |
| 28 | + |
| 29 | +All design decisions in the warp infrastructure follow from this constraint: every operation in the |
| 30 | +step loop must be a GPU kernel launch with stable memory pointers so that the captured graph can |
| 31 | +be replayed without modification. |
| 32 | + |
| 33 | +Key consequences: |
| 34 | + |
| 35 | +- All buffers are **pre-allocated** — no dynamic allocation inside the step loop |
| 36 | +- Data flows through **persistent ``wp.array`` pointers** — never replaced, only overwritten |
| 37 | +- MDP terms are **pure ``@wp.kernel`` functions** — no Python branching on GPU data |
| 38 | +- Reset uses **boolean masks** (``env_mask``) instead of index lists (``env_ids``) to avoid |
| 39 | + variable-length indexing that changes graph topology |
| 40 | + |
| 41 | + |
| 42 | +Project Structure |
| 43 | +~~~~~~~~~~~~~~~~~ |
| 44 | + |
| 45 | +Warp-specific implementations that deviate from stable live in the ``_experimental`` packages: |
| 46 | + |
| 47 | +- ``isaaclab_experimental`` — warp managers, base env classes, warp MDP terms |
| 48 | +- ``isaaclab_tasks_experimental`` — warp task configs and task-specific MDP terms |
| 49 | + |
| 50 | +Any new warp implementation that differs from the stable API belongs in these packages. |
| 51 | +Warp task configs reference Newton physics directly (no ``PresetCfg``) since the warp path |
| 52 | +is Newton-only. |
| 53 | + |
| 54 | + |
| 55 | +Writing Warp MDP Terms |
| 56 | +~~~~~~~~~~~~~~~~~~~~~~ |
| 57 | + |
| 58 | +Imports |
| 59 | +^^^^^^^ |
| 60 | + |
| 61 | +Warp task configs import from the experimental packages: |
| 62 | + |
| 63 | +.. code-block:: python |
| 64 | +
|
| 65 | + # Warp |
| 66 | + from isaaclab_experimental.managers import ObservationTermCfg, RewardTermCfg, SceneEntityCfg |
| 67 | + import isaaclab_experimental.envs.mdp as mdp |
| 68 | +
|
| 69 | +The term config classes have the same interface — only the import path changes. |
| 70 | + |
| 71 | + |
| 72 | +Common Pattern |
| 73 | +^^^^^^^^^^^^^^ |
| 74 | + |
| 75 | +All warp MDP terms (observations, rewards, terminations, events, actions) follow the same |
| 76 | +**kernel + launch** pattern. Stable terms use torch tensors and return results; warp terms |
| 77 | +write into pre-allocated ``wp.array`` output buffers via ``@wp.kernel`` functions: |
| 78 | + |
| 79 | +.. code-block:: python |
| 80 | +
|
| 81 | + # Stable — returns a tensor |
| 82 | + def lin_vel_z_l2(env, asset_cfg) -> torch.Tensor: |
| 83 | + return torch.square(asset.data.root_lin_vel_b[:, 2]) |
| 84 | +
|
| 85 | + # Warp — writes into pre-allocated output |
| 86 | + @wp.kernel |
| 87 | + def _lin_vel_z_l2_kernel(vel: wp.array(...), out: wp.array(dtype=wp.float32)): |
| 88 | + i = wp.tid() |
| 89 | + out[i] = vel[i][2] * vel[i][2] |
| 90 | +
|
| 91 | + def lin_vel_z_l2(env, out, asset_cfg) -> None: |
| 92 | + wp.launch(_lin_vel_z_l2_kernel, dim=env.num_envs, inputs=[..., out]) |
| 93 | +
|
| 94 | +The output buffer shapes differ by term type: |
| 95 | + |
| 96 | +- **Observations**: ``(num_envs, D)`` where D is the observation dimension |
| 97 | +- **Rewards**: ``(num_envs,)`` |
| 98 | +- **Terminations**: ``(num_envs,)`` with dtype ``bool`` |
| 99 | +- **Events**: ``(num_envs,)`` mask — events don't produce output, they modify sim state |
| 100 | + |
| 101 | + |
| 102 | +Observation Terms |
| 103 | +^^^^^^^^^^^^^^^^^ |
| 104 | + |
| 105 | +Since warp terms write into pre-allocated buffers, the observation manager must know each |
| 106 | +term's output dimension at initialization to allocate the correct ``(num_envs, D)`` output |
| 107 | +array. This is resolved via a fallback chain (see |
| 108 | +``ObservationManager._infer_term_dim_scalar`` in |
| 109 | +``isaaclab_experimental/managers/observation_manager.py``): |
| 110 | + |
| 111 | +1. **Explicit ``out_dim`` in decorator** (preferred): |
| 112 | + |
| 113 | + .. code-block:: python |
| 114 | +
|
| 115 | + @generic_io_descriptor_warp(out_dim=3, observation_type="RootState") |
| 116 | + def base_lin_vel(env, out, asset_cfg) -> None: ... |
| 117 | +
|
| 118 | + ``out_dim`` can be an integer, or a string that resolves at initialization: |
| 119 | + |
| 120 | + - ``"joint"`` — number of selected joints from ``asset_cfg`` |
| 121 | + - ``"body:N"`` — N components per selected body from ``asset_cfg`` |
| 122 | + - ``"command"`` — dimension from command manager |
| 123 | + - ``"action"`` — dimension from action manager |
| 124 | + |
| 125 | +2. **``axes`` metadata**: Dimension equals the number of axes listed: |
| 126 | + |
| 127 | + .. code-block:: python |
| 128 | +
|
| 129 | + @generic_io_descriptor_warp(axes=["X", "Y", "Z"], observation_type="RootState") |
| 130 | + def projected_gravity(env, out, asset_cfg) -> None: ... |
| 131 | + # → dimension = 3 |
| 132 | +
|
| 133 | +3. **Legacy params**: ``term_dim``, ``out_dim``, or ``obs_dim`` keys in ``term_cfg.params``. |
| 134 | + |
| 135 | +4. **Asset config fallback**: Count of ``asset_cfg.joint_ids`` (or ``joint_ids_wp``) for |
| 136 | + joint-level terms. |
| 137 | + |
| 138 | + |
| 139 | +Event Terms |
| 140 | +^^^^^^^^^^^ |
| 141 | + |
| 142 | +Events use ``env_mask`` (boolean ``wp.array``) instead of ``env_ids``, and each kernel |
| 143 | +checks the mask to skip non-selected environments: |
| 144 | + |
| 145 | +.. code-block:: python |
| 146 | +
|
| 147 | + def reset_joints_by_offset(env, env_mask, ...): |
| 148 | + wp.launch(_kernel, dim=env.num_envs, inputs=[env_mask, ...]) |
| 149 | +
|
| 150 | + @wp.kernel |
| 151 | + def _kernel(env_mask: wp.array(dtype=wp.bool), ...): |
| 152 | + i = wp.tid() |
| 153 | + if not env_mask[i]: |
| 154 | + return |
| 155 | + # ... modify state for selected envs only |
| 156 | +
|
| 157 | +- RNG uses per-env ``env.rng_state_wp`` (``wp.uint32``) instead of ``torch.rand`` |
| 158 | +- **Startup/prestartup** events use the stable convention ``(env, env_ids, **params)`` |
| 159 | +- **Reset/interval** events use the warp convention ``(env, env_mask, **params)`` |
| 160 | + |
| 161 | + |
| 162 | +Action Terms |
| 163 | +^^^^^^^^^^^^ |
| 164 | + |
| 165 | +Actions follow a **two-stage execution**: ``process_actions`` (called once per env step) scales |
| 166 | +and clips raw actions, and ``apply_actions`` (called once per sim step) writes targets to the |
| 167 | +asset. Both stages use warp kernels with pre-allocated ``_raw_actions`` and ``_processed_actions`` |
| 168 | +buffers. |
| 169 | + |
| 170 | + |
| 171 | +Capture Safety |
| 172 | +^^^^^^^^^^^^^^ |
| 173 | + |
| 174 | +When writing terms that run inside the captured step loop, keep in mind: |
| 175 | + |
| 176 | +- **No ``wp.to_torch``** or torch arithmetic — stay in warp throughout |
| 177 | +- **No lazy-evaluated properties** — use sim-bound (Tier 1) data directly; if a derived |
| 178 | + quantity is needed, compute it inline in the kernel |
| 179 | +- **No dynamic allocation** — all buffers must be pre-allocated in ``__init__`` |
| 180 | + |
| 181 | + |
| 182 | +Parity Testing |
| 183 | +~~~~~~~~~~~~~~ |
| 184 | + |
| 185 | +Two levels of parity testing are used to validate warp terms: |
| 186 | + |
| 187 | +**1. Implementation parity (stable vs warp)** — verifies that the warp kernel produces the |
| 188 | +same result as the stable torch implementation. This is optional for terms that have no stable |
| 189 | +counterpart (e.g. new terms written directly in warp). |
| 190 | + |
| 191 | +.. code-block:: python |
| 192 | +
|
| 193 | + import isaaclab.envs.mdp.observations as stable_obs |
| 194 | + import isaaclab_experimental.envs.mdp.observations as warp_obs |
| 195 | +
|
| 196 | + # Stable baseline |
| 197 | + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) |
| 198 | +
|
| 199 | + # Warp (uncaptured) |
| 200 | + out = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) |
| 201 | + warp_obs.joint_pos(warp_env, out, asset_cfg=cfg) |
| 202 | + actual = wp.to_torch(out) |
| 203 | +
|
| 204 | + torch.testing.assert_close(actual, expected) |
| 205 | +
|
| 206 | +**2. Capture parity (warp vs warp-captured)** — verifies that the term produces identical |
| 207 | +results when replayed from a CUDA graph vs launched directly. A mismatch here indicates capture-unsafe |
| 208 | +code (e.g. stale pointers, dynamic allocation, or lazy property access that doesn't replay). |
| 209 | +This test should always be run, even for terms without a stable counterpart. |
| 210 | + |
| 211 | +.. code-block:: python |
| 212 | +
|
| 213 | + # Warp uncaptured |
| 214 | + out_uncaptured = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) |
| 215 | + warp_obs.joint_pos(warp_env, out_uncaptured, asset_cfg=cfg) |
| 216 | +
|
| 217 | + # Warp captured (graph replay) |
| 218 | + out_captured = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) |
| 219 | + with wp.ScopedCapture() as cap: |
| 220 | + warp_obs.joint_pos(warp_env, out_captured, asset_cfg=cfg) |
| 221 | + wp.capture_launch(cap.graph) |
| 222 | +
|
| 223 | + torch.testing.assert_close(wp.to_torch(out_captured), wp.to_torch(out_uncaptured)) |
| 224 | +
|
| 225 | +See ``source/isaaclab_experimental/test/envs/mdp/`` for complete parity test examples. |
| 226 | + |
| 227 | + |
| 228 | +Available Warp MDP Terms |
| 229 | +~~~~~~~~~~~~~~~~~~~~~~~~ |
| 230 | + |
| 231 | +.. list-table:: |
| 232 | + :header-rows: 1 |
| 233 | + :widths: 20 80 |
| 234 | + |
| 235 | + * - Category |
| 236 | + - Available Terms |
| 237 | + * - Observations (11) |
| 238 | + - | ``base_pos_z`` |
| 239 | + | ``base_lin_vel`` |
| 240 | + | ``base_ang_vel`` |
| 241 | + | ``projected_gravity`` |
| 242 | + | ``joint_pos`` |
| 243 | + | ``joint_pos_rel`` |
| 244 | + | ``joint_pos_limit_normalized`` |
| 245 | + | ``joint_vel`` |
| 246 | + | ``joint_vel_rel`` |
| 247 | + | ``last_action`` |
| 248 | + | ``generated_commands`` |
| 249 | + * - Rewards (16) |
| 250 | + - | ``is_alive`` |
| 251 | + | ``is_terminated`` |
| 252 | + | ``lin_vel_z_l2`` |
| 253 | + | ``ang_vel_xy_l2`` |
| 254 | + | ``flat_orientation_l2`` |
| 255 | + | ``joint_torques_l2`` |
| 256 | + | ``joint_vel_l1`` |
| 257 | + | ``joint_vel_l2`` |
| 258 | + | ``joint_acc_l2`` |
| 259 | + | ``joint_deviation_l1`` |
| 260 | + | ``joint_pos_limits`` |
| 261 | + | ``action_rate_l2`` |
| 262 | + | ``action_l2`` |
| 263 | + | ``undesired_contacts`` |
| 264 | + | ``track_lin_vel_xy_exp`` |
| 265 | + | ``track_ang_vel_z_exp`` |
| 266 | + * - Events (6) |
| 267 | + - | ``reset_joints_by_offset`` |
| 268 | + | ``reset_joints_by_scale`` |
| 269 | + | ``reset_root_state_uniform`` |
| 270 | + | ``push_by_setting_velocity`` |
| 271 | + | ``apply_external_force_torque`` |
| 272 | + | ``randomize_rigid_body_com`` |
| 273 | + * - Terminations (4) |
| 274 | + - | ``time_out`` |
| 275 | + | ``root_height_below_minimum`` |
| 276 | + | ``joint_pos_out_of_manual_limit`` |
| 277 | + | ``illegal_contact`` |
| 278 | + * - Actions (2) |
| 279 | + - | ``JointPositionAction`` |
| 280 | + | ``JointEffortAction`` |
| 281 | +
|
| 282 | +Terms not listed here remain in stable only. When using an env that requires unlisted terms, |
| 283 | +those terms must be implemented in warp first. |
0 commit comments