Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mujoco_warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from mujoco_warp._src.forward import step1 as step1
from mujoco_warp._src.forward import step2 as step2
from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS
from mujoco_warp._src.grad import SOLVER_GRAD_FIELDS as SOLVER_GRAD_FIELDS
from mujoco_warp._src.grad import diff_forward as diff_forward
from mujoco_warp._src.grad import diff_step as diff_step
from mujoco_warp._src.grad import disable_grad as disable_grad
Expand Down
188 changes: 185 additions & 3 deletions mujoco_warp/_src/adjoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""custom adjoint definitions for MuJoCo Warp autodifferentiation.

This module centralizes all ``@wp.func_grad`` registrations. It must be
imported before any tape recording so that custom gradients are registered
with Warp's AD system.
This module centralizes all ``@wp.func_grad`` registrations and the
implicit differentiation adjoint for the constraint solver.

Import this module via ``grad.py``; do not import it directly.
"""

import warp as wp

from mujoco_warp._src import math
from mujoco_warp._src import support
from mujoco_warp._src import types
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_func
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func
from mujoco_warp._src.warp_util import cache_kernel


@wp.func_grad(math.quat_integrate)
Expand Down Expand Up @@ -105,3 +109,181 @@ def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
wp.adjoint[q] += adj_q_val
wp.adjoint[v] += adj_v_val
wp.adjoint[dt] += adj_dt_val


# ---------------------------------------------------------------------------
# Solver implicit differentiation adjoint
# ---------------------------------------------------------------------------

# Largest nv handled by the single-tile dense Cholesky kernel; larger
# systems take the blocked factorization/solve path (see _solve_hessian_system).
_BLOCK_CHOLESKY_DIM = 32


# Element-wise copy of a (world, dof) gradient array; used to pass an
# adjoint through unchanged when the solver acted as the identity.
@wp.kernel
def _copy_grad_kernel(
  # In:
  src: wp.array2d(dtype=float),
  # Out:
  dst: wp.array2d(dtype=float),
):
  wid, did = wp.tid()
  dst[wid, did] = src[wid, did]


@cache_kernel
def _adjoint_cholesky_tile(nv: int):
  # Builds a kernel that solves one dense (nv x nv) system per world with a
  # single-tile Cholesky factorization; nv is baked in at codegen time.
  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # In:
    H: wp.array3d(dtype=float),
    b: wp.array2d(dtype=float),
    # Out:
    out: wp.array2d(dtype=float),
  ):
    wid = wp.tid()
    N = wp.static(nv)
    h_tile = wp.tile_load(H[wid], shape=(N, N))
    rhs = wp.tile_load(b[wid], shape=(N,))
    factor = wp.tile_cholesky(h_tile)
    sol = wp.tile_cholesky_solve(factor, rhs)
    wp.tile_store(out[wid], sol)

  return kernel


@cache_kernel
def _adjoint_cholesky_blocked(tile_size: int, matrix_size: int):
  """Builds a kernel running a blocked triangular solve per world.

  Solve-only path: the Cholesky factor was already computed and stored by
  the solver, so only forward/back substitution runs here. ``tile_size``
  and ``matrix_size`` are compile-time constants specialized into the
  generated solve function via ``wp.static``.
  """

  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # In:
    hfactor: wp.array3d(dtype=float),  # Cholesky factor, one matrix per world
    b: wp.array3d(dtype=float),  # right-hand side, (matrix_size, 1) per world
    nv_runtime: int,  # actual (unpadded) problem size
    # Out:
    out: wp.array3d(dtype=float),  # solution, (matrix_size, 1) per world
  ):
    worldid = wp.tid()
    # wp.static resolves the specialized solve function at codegen time.
    wp.static(create_blocked_cholesky_solve_func(tile_size, matrix_size))(
      hfactor[worldid], b[worldid], nv_runtime, out[worldid]
    )

  return kernel


@cache_kernel
def _adjoint_cholesky_full_blocked(tile_size: int, matrix_size: int):
  """Builds a kernel that factorizes and solves per world in one launch.

  Used when the solver stored no Cholesky factor: computes the blocked
  factorization of ``H`` into ``hfactor_tmp`` and then runs the blocked
  triangular solves with that fresh factor.
  """

  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # In:
    H: wp.array3d(dtype=float),  # Hessian, one matrix per world
    b: wp.array3d(dtype=float),  # right-hand side, (matrix_size, 1) per world
    nv_runtime: int,  # actual (unpadded) problem size
    hfactor_tmp: wp.array3d(dtype=float),  # scratch for the Cholesky factor
    # Out:
    out: wp.array3d(dtype=float),  # solution, (matrix_size, 1) per world
  ):
    worldid = wp.tid()
    # Both helpers are specialized at codegen time via wp.static: first
    # factorize H, then substitute with the just-computed factor.
    wp.static(create_blocked_cholesky_func(tile_size))(
      H[worldid], nv_runtime, hfactor_tmp[worldid]
    )
    wp.static(create_blocked_cholesky_solve_func(tile_size, matrix_size))(
      hfactor_tmp[worldid], b[worldid], nv_runtime, out[worldid]
    )

  return kernel


# Writes ones on the padded portion of the Hessian diagonal (rows with
# dof index >= nv) before the full factorize-and-solve path runs.
@wp.kernel
def _padding_h_adjoint(
  nv: int,
  H_out: wp.array3d(dtype=float),
):
  wid, padid = wp.tid()
  row = nv + padid
  H_out[wid, row, row] = 1.0


def _solve_hessian_system(m: types.Model, d: types.Data, b, out):
  """Solve H * x = b using stored solver Hessian.

  Dispatches on problem size:
    * nv <= _BLOCK_CHOLESKY_DIM: single-tile dense Cholesky per world.
    * otherwise: blocked Cholesky on the padded (nv_pad) system, reusing
      d.solver_hfactor when the solver stored a factor, else factorizing
      d.solver_h on the fly.

  Args:
    m: model; supplies nv, nv_pad and block_dim launch configuration.
    d: data; supplies solver_h / solver_hfactor and nworld.
    b: right-hand side per world. On the blocked path it is reshaped to
      (nworld, nv_pad, 1) — assumes the buffer holds nv_pad entries per
      world; TODO confirm caller buffers are padded accordingly.
    out: solution buffer, same layout as b; written in place.
  """
  if m.nv <= _BLOCK_CHOLESKY_DIM:
    # Small system: the whole (nv x nv) Hessian fits in one tile.
    wp.launch_tiled(
      _adjoint_cholesky_tile(m.nv),
      dim=d.nworld,
      inputs=[d.solver_h, b],
      outputs=[out],
      block_dim=m.block_dim.update_gradient_cholesky,
    )
  else:
    # Blocked kernels operate on (nworld, nv_pad, 1) column vectors.
    b_3d = b.reshape((d.nworld, m.nv_pad, 1))
    out_3d = out.reshape((d.nworld, m.nv_pad, 1))

    if d.solver_hfactor.shape[1] > 0:
      # Solve-only using stored Cholesky factor
      wp.launch_tiled(
        _adjoint_cholesky_blocked(
          types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad
        ),
        dim=d.nworld,
        inputs=[d.solver_hfactor, b_3d, m.nv],
        outputs=[out_3d],
        block_dim=m.block_dim.update_gradient_cholesky_blocked,
      )
    else:
      # Full factorize + solve (no stored factor)
      # Pad diagonal for stability
      if m.nv_pad > m.nv:
        wp.launch(
          _padding_h_adjoint,
          dim=(d.nworld, m.nv_pad - m.nv),
          inputs=[m.nv],
          outputs=[d.solver_h],
        )
      # NOTE(review): scratch factor is allocated on every call and on the
      # default device — consider caching it (and passing device=...) if
      # this shows up in profiles.
      hfactor_tmp = wp.zeros(
        (d.nworld, m.nv_pad, m.nv_pad), dtype=float
      )
      wp.launch_tiled(
        _adjoint_cholesky_full_blocked(
          types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad
        ),
        dim=d.nworld,
        inputs=[d.solver_h, b_3d, m.nv, hfactor_tmp],
        outputs=[out_3d],
        block_dim=m.block_dim.update_gradient_cholesky_blocked,
      )


def solver_implicit_adjoint(m: types.Model, d: types.Data):
  """Implicit differentiation adjoint for constraint solver.

  Called during tape backward. Reads d.qacc.grad (set by downstream),
  solves H * v = adj_qacc, and writes d.qacc_smooth.grad = M * v.

  Falls back to copying the adjoint through unchanged when the solver
  acted as the identity (no constraints) or stored no Hessian (CG).
  """
  nv = m.nv
  if nv == 0:
    return

  # Identity fall-through, one launch covers both cases:
  #   * d.njmax == 0: solver was identity (qacc = qacc_smooth).
  #   * non-Newton (CG) solver: no Hessian stored, so implicit
  #     differentiation is unavailable.
  if d.njmax == 0 or m.opt.solver != types.SolverType.NEWTON:
    wp.launch(
      _copy_grad_kernel,
      dim=(d.nworld, nv),
      inputs=[d.qacc.grad],
      outputs=[d.qacc_smooth.grad],
    )
    return

  # Solve H * v = adj_qacc using the solver's stored Hessian.
  v = wp.zeros((d.nworld, m.nv_pad), dtype=float)
  _solve_hessian_system(m, d, d.qacc.grad, v)

  # adj_qacc_smooth = M * v
  support.mul_m(m, d, d.qacc_smooth.grad, v)
40 changes: 34 additions & 6 deletions mujoco_warp/_src/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
# TODO(team): can we assume static timesteps?

# Clone arrays used as both input and output so that Warp's tape retains the
# original values for correct reverse-mode AD.
act_in = wp.clone(d.act)
qvel_prev = wp.clone(d.qvel)
qpos_prev = wp.clone(d.qpos)
# original values for correct reverse-mode AD. Guard with requires_grad so
# non-AD paths pay zero overhead.
act_in = wp.clone(d.act) if d.act.requires_grad else d.act
qvel_prev = wp.clone(d.qvel) if d.qvel.requires_grad else d.qvel
qpos_prev = wp.clone(d.qpos) if d.qpos.requires_grad else d.qpos

# advance activations
wp.launch(
Expand Down Expand Up @@ -947,8 +948,13 @@ def fwd_actuation(m: Model, d: Data):
],
outputs=[d.qfrc_actuator],
)
# clone to break input/output aliasing for correct AD
qfrc_actuator_in = wp.clone(d.qfrc_actuator)
# clone to break input/output aliasing for correct AD; skip when not
# recording a backward tape to avoid unnecessary allocation + copy.
qfrc_actuator_in = (
wp.clone(d.qfrc_actuator)
if d.qfrc_actuator.requires_grad
else d.qfrc_actuator
)
wp.launch(
_qfrc_actuator_gravcomp_limits,
dim=(d.nworld, m.nv),
Expand Down Expand Up @@ -1035,6 +1041,17 @@ def forward(m: Model, d: Data):
fwd_acceleration(m, d, factorize=True)

solver.solve(m, d)

# Record implicit differentiation adjoint on the active tape
tape = wp._src.context.runtime.tape
if tape is not None and d.qpos.requires_grad:
from mujoco_warp._src.adjoint import solver_implicit_adjoint

tape.record_func(
lambda m=m, d=d: solver_implicit_adjoint(m, d),
[d.qacc, d.qacc_smooth],
)

sensor.sensor_acc(m, d)


Expand Down Expand Up @@ -1090,6 +1107,17 @@ def step2(m: Model, d: Data):
fwd_actuation(m, d)
fwd_acceleration(m, d)
solver.solve(m, d)

# Record implicit differentiation adjoint on the active tape
tape = wp._src.context.runtime.tape
if tape is not None and d.qpos.requires_grad:
from mujoco_warp._src.adjoint import solver_implicit_adjoint

tape.record_func(
lambda m=m, d=d: solver_implicit_adjoint(m, d),
[d.qacc, d.qacc_smooth],
)

sensor.sensor_acc(m, d)
# TODO(team): mj_checkAcc

Expand Down
18 changes: 18 additions & 0 deletions mujoco_warp/_src/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
grad_ctrl = d.ctrl.grad
"""

import warnings
from typing import Callable, Optional, Sequence

import warp as wp
Expand All @@ -26,6 +27,7 @@
from mujoco_warp._src.forward import step
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import SolverType

SMOOTH_GRAD_FIELDS: tuple = (
# primary state, user-controlled inputs
Expand Down Expand Up @@ -91,6 +93,10 @@
"sensordata",
)

# Extra Data gradient fields tied to the constraint solver; exported from
# the package root alongside SMOOTH_GRAD_FIELDS.
SOLVER_GRAD_FIELDS: tuple = (
  "qfrc_constraint",
)


def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
"""Enables gradient tracking on Data arrays."""
Expand Down Expand Up @@ -122,12 +128,23 @@ def make_diff_data(
return d


def _warn_if_cg_solver(m: Model, d: Data):
  """Warn if CG solver is used with constraints (gradients will be zero)."""
  uses_newton = m.opt.solver == SolverType.NEWTON
  has_constraints = d.njmax > 0
  if has_constraints and not uses_newton:
    # stacklevel=3 points the warning at the user's diff_step/diff_forward
    # call site rather than at this helper.
    warnings.warn(
      "Differentiable solver requires Newton. CG solver "
      "gradients through constraints will be zero.",
      stacklevel=3,
    )


def diff_step(
m: Model,
d: Data,
loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
"""Runs a differentiable physics step."""
_warn_if_cg_solver(m, d)
tape = wp.Tape()
with tape:
step(m, d)
Expand All @@ -142,6 +159,7 @@ def diff_forward(
loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
"""Runs differentiable forward dynamics (no integration)."""
_warn_if_cg_solver(m, d)
tape = wp.Tape()
with tape:
forward(m, d)
Expand Down
Loading