diff --git a/contrib/render.py b/contrib/render.py index a70aa4e0b..fa0cb6af7 100644 --- a/contrib/render.py +++ b/contrib/render.py @@ -18,7 +18,7 @@ Usage: mjwarp-render [flags] Example: - mjwarp-render benchmark/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512 + mjwarp-render benchmarks/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512 """ import sys @@ -42,6 +42,7 @@ _HEIGHT = flags.DEFINE_integer("height", 512, "render height (pixels)") _RENDER_RGB = flags.DEFINE_bool("rgb", True, "render RGB image") _RENDER_DEPTH = flags.DEFINE_bool("depth", True, "render depth image") +_RENDER_SEG = flags.DEFINE_bool("seg", False, "render segmentation image") _USE_TEXTURES = flags.DEFINE_bool("textures", True, "use textures") _USE_SHADOWS = flags.DEFINE_bool("shadows", False, "use shadows") _DEVICE = flags.DEFINE_string("device", None, "override the default Warp device") @@ -207,6 +208,7 @@ def _main(argv: Sequence[str]): (render_width, render_height), _RENDER_RGB.value, _RENDER_DEPTH.value, + _RENDER_SEG.value, _USE_TEXTURES.value, _USE_SHADOWS.value, enabled_geom_groups=[0, 1, 2], diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 6de859326..a9756e5e7 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -46,12 +46,15 @@ from mujoco_warp._src.forward import rungekutta4 as rungekutta4 from mujoco_warp._src.forward import step1 as step1 from mujoco_warp._src.forward import step2 as step2 +from mujoco_warp._src.grad import COLLISION_GRAD_FIELDS as COLLISION_GRAD_FIELDS from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS from mujoco_warp._src.grad import SOLVER_GRAD_FIELDS as SOLVER_GRAD_FIELDS from mujoco_warp._src.grad import diff_forward as diff_forward from mujoco_warp._src.grad import diff_step as diff_step from mujoco_warp._src.grad import disable_grad as disable_grad from mujoco_warp._src.grad import enable_grad as enable_grad +from mujoco_warp._src.grad import 
enable_smooth_adjoint as enable_smooth_adjoint +from mujoco_warp._src.grad import disable_smooth_adjoint as disable_smooth_adjoint from mujoco_warp._src.grad import make_diff_data as make_diff_data from mujoco_warp._src.inverse import inverse as inverse from mujoco_warp._src.io import create_render_context as create_render_context diff --git a/mujoco_warp/_src/adjoint.py b/mujoco_warp/_src/adjoint.py index 10d79bc2c..341a1320e 100644 --- a/mujoco_warp/_src/adjoint.py +++ b/mujoco_warp/_src/adjoint.py @@ -1,11 +1,14 @@ """custom adjoint definitions for MuJoCo Warp autodifferentiation. -This module centralizes all ``@wp.func_grad`` registrations and the -implicit differentiation adjoint for the constraint solver. +This module centralizes all ``@wp.func_grad`` registrations, the +implicit differentiation adjoint for the constraint solver, and the +smooth constraint adjoint for friction gradient signal. Import this module via ``grad.py`` dont import it directly """ +import os + import warp as wp from mujoco_warp._src import math @@ -13,8 +16,409 @@ from mujoco_warp._src import types from mujoco_warp._src.block_cholesky import create_blocked_cholesky_func from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func +from mujoco_warp._src.collision_smooth import compute_k_imp from mujoco_warp._src.warp_util import cache_kernel +# --------------------------------------------------------------------------- +# Phase 3: efc-level gradient kernels for collision chain +# --------------------------------------------------------------------------- + + +@wp.kernel +def _efc_J_grad_kernel( + # Model: + nv: int, + # Data in: + nefc_in: wp.array(dtype=int), + efc_force_in: wp.array2d(dtype=float), + njmax_in: int, + # In: + v_in: wp.array2d(dtype=float), + # Out: + efc_J_grad_out: wp.array3d(dtype=float), +): + """Compute adj_efc_J[i, j] = v[j] * efc_force[i]. + + From KKT: F(qacc) = M*qacc - qfrc_smooth - J^T*f = 0 + The derivative of J^T*f w.r.t. 
J[i,j] is f[i] * delta, and the + adjoint vector v gives the sensitivity: adj_J[i,j] = v[j] * f[i]. + """ + worldid, efcid, dofid = wp.tid() + if efcid < nefc_in[worldid] and dofid < nv: + efc_J_grad_out[worldid, efcid, dofid] = v_in[worldid, dofid] * efc_force_in[worldid, efcid] + + +@wp.kernel +def _efc_pos_grad_kernel( + # Model: + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + # Data in: + contact_dist_in: wp.array(dtype=float), + contact_includemargin_in: wp.array(dtype=float), + contact_solref_in: wp.array(dtype=wp.vec2), + contact_solimp_in: wp.array(dtype=types.vec5), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # In: + efc_aref_grad_in: wp.array2d(dtype=float), + # Out: + efc_pos_grad_out: wp.array2d(dtype=float), +): + """Compute adj_efc_pos from adj_efc_aref. + + From efc_aref = -k * imp * pos - b * vel, d(aref)/d(pos) = -k*imp. + So adj_efc_pos = adj_efc_aref * (-k * imp). + We iterate over contacts and their first dimension (normal direction). 
+ """ + conid = wp.tid() + if conid >= nacon_in[0]: + return + if not (contact_type_in[conid] & 1): # ContactType.CONSTRAINT + return + + efcid = contact_efc_address_in[conid, 0] + if efcid < 0: + return + + worldid = contact_worldid_in[conid] + timestep = opt_timestep[worldid % opt_timestep.shape[0]] + + solref = contact_solref_in[conid] + solimp = contact_solimp_in[conid] + includemargin = contact_includemargin_in[conid] + pos_val = contact_dist_in[conid] - includemargin + + k_imp = compute_k_imp(opt_disableflags, solref, solimp, pos_val, timestep) + + # d(aref)/d(pos) = -k * imp + daref_dpos = -k_imp[0] * k_imp[1] + + adj_aref = efc_aref_grad_in[worldid, efcid] + efc_pos_grad_out[worldid, efcid] = adj_aref * daref_dpos + + +# --------------------------------------------------------------------------- +# Smooth constraint adjoint: friction Hessian correction kernel +# --------------------------------------------------------------------------- + + +@wp.kernel +def _smooth_hessian_friction_correction( + # Model: + nv: int, + # Contact data: + contact_efc_address_in: wp.array2d(dtype=int), + contact_dim_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Constraint data: + efc_J_in: wp.array3d(dtype=float), + efc_D_in: wp.array2d(dtype=float), + efc_state_in: wp.array2d(dtype=int), + # Parameters: + friction_viscosity: float, + friction_scale: float, + # Out: + H_out: wp.array3d(dtype=float), +): + """Apply friction smoothing correction to the Hessian. + + For each friction constraint row (dimid > 0): + - QUADRATIC (active): delta_D = D * (friction_scale - 1.0) [reduces stiffness] + - Otherwise (SATISFIED etc): delta_D = friction_viscosity [adds viscous term] + + Applies delta_D * J_row^T * J_row to H via atomic_add. 
+ """ + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + # Only process constraint contacts + if not (contact_type_in[conid] & 1): # ContactType.CONSTRAINT = 1 + return + + # Skip normal direction (dimid=0) — only modify friction rows + if dimid == 0: + return + + condim = contact_dim_in[conid] + if condim == 1: + return # frictionless contact, no friction rows + if dimid >= 2 * (condim - 1): + return # beyond valid friction dimensions + + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + + worldid = contact_worldid_in[conid] + + D = efc_D_in[worldid, efcid] + state = efc_state_in[worldid, efcid] + + # Compute delta_D: difference between smooth D and what's currently in H + # QUADRATIC state (value=1): constraint was active, D is in H → reduce it + # SATISFIED state (value=0): constraint was inactive, 0 in H → add viscous + delta_D = float(0.0) + if state == 1: # QUADRATIC + delta_D = D * (friction_scale - 1.0) + else: + delta_D = friction_viscosity + + if delta_D == 0.0: + return + + # Apply delta_D * J_row^T * J_row to H + for i in range(nv): + Ji = efc_J_in[worldid, efcid, i] + if Ji == 0.0: + continue + for j in range(nv): + Jj = efc_J_in[worldid, efcid, j] + if Jj == 0.0: + continue + wp.atomic_add(H_out, worldid, i, j, delta_D * Ji * Jj) + + +# --------------------------------------------------------------------------- +# Smooth constraint adjoint: friction gradient bypass kernel +# --------------------------------------------------------------------------- + + +@wp.kernel +def _friction_bypass_correction( + # Model: + nv: int, + # Contact data: + contact_efc_address_in: wp.array2d(dtype=int), + contact_dim_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Constraint data: + efc_J_in: wp.array3d(dtype=float), + # Solve results: + v_hessian_in: wp.array2d(dtype=float), + v_free_in: wp.array2d(dtype=float), + # Parameters: + 
bypass_kf: float, + # Out: + v_out: wp.array2d(dtype=float), +): + """Friction gradient bypass: restore tangential gradients attenuated by H^{-1}. + + For each friction constraint face (dimid > 0), computes: + delta = J_fric . (v_free - v_hessian) [gradient lost to friction attenuation] + v_out += kf * J_fric^T * delta [inject it back, scaled by kf] + + v_hessian = H^{-1} * adj_qacc (attenuated in friction directions) + v_free = M^{-1} * adj_qacc (what gradient would be without constraints) + + This makes the backward pass produce dflex-like friction gradients while + keeping the forward physics unchanged. + """ + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + # Only process constraint contacts + if not (contact_type_in[conid] & 1): # ContactType.CONSTRAINT = 1 + return + + # Skip normal direction (dimid=0) — only bypass friction rows + if dimid == 0: + return + + condim = contact_dim_in[conid] + if condim == 1: + return # frictionless contact, no friction rows + if dimid >= 2 * (condim - 1): + return # beyond valid friction dimensions + + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + + worldid = contact_worldid_in[conid] + + # Compute delta = J_fric . 
(v_free - v_hessian) for this friction face + delta = float(0.0) + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + delta += J_val * (v_free_in[worldid, dofid] - v_hessian_in[worldid, dofid]) + + # Apply correction: v_out += kf * J_fric^T * delta + if delta != 0.0: + scaled_delta = bypass_kf * delta + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + wp.atomic_add(v_out, worldid, dofid, scaled_delta * J_val) + + +@wp.kernel +def _friction_bypass_correction_normalized( + # Model: + nv: int, + # Contact data: + contact_efc_address_in: wp.array2d(dtype=int), + contact_dim_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Constraint data: + efc_J_in: wp.array3d(dtype=float), + # Solve results: + v_hessian_in: wp.array2d(dtype=float), + v_free_in: wp.array2d(dtype=float), + # Parameters: + bypass_kf: float, + max_ratio: float, + norm_eps: float, + # Out: + v_out: wp.array2d(dtype=float), +): + """Normalized and capped friction bypass correction. + + Projects the free-body delta onto each friction row and injects only a + bounded fraction of that projected component. + + Compared to _friction_bypass_correction this avoids scaling by ||J_row||^2 + and prevents over-injection when contact rows become poorly conditioned. 
+ """ + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + # Only process constraint contacts + if not (contact_type_in[conid] & 1): # ContactType.CONSTRAINT = 1 + return + + # Skip normal direction (dimid=0) - only bypass friction rows + if dimid == 0: + return + + condim = contact_dim_in[conid] + if condim == 1: + return + if dimid >= 2 * (condim - 1): + return + + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + + worldid = contact_worldid_in[conid] + + delta = float(0.0) + j_norm2 = float(0.0) + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + delta += J_val * (v_free_in[worldid, dofid] - v_hessian_in[worldid, dofid]) + j_norm2 += J_val * J_val + + if j_norm2 <= norm_eps: + return + + # Row-normalized projection coefficient. + base_coeff = delta / j_norm2 + coeff = bypass_kf * base_coeff + + # Bound injected magnitude relative to the projected free-body component. + max_coeff = wp.abs(base_coeff) * max_ratio + abs_coeff = wp.abs(coeff) + if abs_coeff > max_coeff and abs_coeff > 0.0: + coeff = coeff * (max_coeff / abs_coeff) + + if coeff == 0.0: + return + + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + wp.atomic_add(v_out, worldid, dofid, coeff * J_val) + + +# Penalty-model adjoint: friction damping kernel +# --------------------------------------------------------------------------- + + +@wp.kernel +def _penalty_friction_damping( + # Model: + nv: int, + # Contact data: + contact_efc_address_in: wp.array2d(dtype=int), + contact_dim_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Constraint data: + efc_J_in: wp.array3d(dtype=float), + # Input: + v_free_in: wp.array2d(dtype=float), + # Parameters: + damping_alpha: float, + # Out: + v_out: wp.array2d(dtype=float), +): + """Apply penalty-model friction damping to the free-body adjoint. 
+ + For each friction face: v_out -= alpha * J_fric^T * (J_fric . v_free) + + This attenuates v in friction directions by factor (1 - alpha), mimicking + dflex's penalty friction gradient where d(v_next)/d(v_prev) has eigenvalues + < 1 in friction-constrained directions. Provides natural BPTT decay that + prevents gradient explosion while preserving gradient direction. + """ + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + if not (contact_type_in[conid] & 1): + return + + # Friction rows only (dimid > 0) + if dimid == 0: + return + + condim = contact_dim_in[conid] + if condim == 1: + return + if dimid >= 2 * (condim - 1): + return + + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + + worldid = contact_worldid_in[conid] + + # Project v_free onto this friction face + proj = float(0.0) + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + proj += J_val * v_free_in[worldid, dofid] + + # Subtract friction damping: v_out -= alpha * J^T * proj + if proj != 0.0: + scaled = damping_alpha * proj + for dofid in range(nv): + J_val = efc_J_in[worldid, efcid, dofid] + if J_val != 0.0: + wp.atomic_add(v_out, worldid, dofid, -scaled * J_val) + @wp.func_grad(math.quat_integrate) def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat): @@ -123,10 +527,21 @@ def _copy_grad_kernel( # In: src: wp.array2d(dtype=float), # Out: - dst: wp.array2d(dtype=float), + dst_out: wp.array2d(dtype=float), ): worldid, dofid = wp.tid() - dst[worldid, dofid] = src[worldid, dofid] + dst_out[worldid, dofid] = src[worldid, dofid] + + +@wp.kernel +def _accumulate_grad_kernel( + # In: + src: wp.array2d(dtype=float), + # Out: + dst_out: wp.array2d(dtype=float), +): + worldid, dofid = wp.tid() + dst_out[worldid, dofid] = dst_out[worldid, dofid] + src[worldid, dofid] @cache_kernel @@ -182,9 +597,7 @@ def kernel( out: wp.array3d(dtype=float), ): worldid = wp.tid() - 
wp.static(create_blocked_cholesky_func(tile_size))( - H[worldid], nv_runtime, hfactor_tmp[worldid] - ) + wp.static(create_blocked_cholesky_func(tile_size))(H[worldid], nv_runtime, hfactor_tmp[worldid]) wp.static(create_blocked_cholesky_solve_func(tile_size, matrix_size))( hfactor_tmp[worldid], b[worldid], nv_runtime, out[worldid] ) @@ -194,7 +607,9 @@ def kernel( @wp.kernel def _padding_h_adjoint( + # Model: nv: int, + # Out: H_out: wp.array3d(dtype=float), ): worldid, elementid = wp.tid() @@ -202,13 +617,26 @@ def _padding_h_adjoint( H_out[worldid, dofid, dofid] = 1.0 -def _solve_hessian_system(m: types.Model, d: types.Data, b, out): - """Solve H * x = b using stored solver Hessian.""" +def _solve_hessian_system(m: types.Model, d: types.Data, b, out, H=None): + """Solve H * x = b using stored solver Hessian or a provided H. + + Args: + m: Model. + d: Data. + b: Right-hand side vector (nworld, nv_pad). + out: Solution vector (nworld, nv_pad). + H: Optional Hessian override. When provided, always factorizes from + scratch (ignores stored d.solver_hfactor). Used by smooth adjoint. 
+ """ + use_stored = H is None + if use_stored: + H = d.solver_h + if m.nv <= _BLOCK_CHOLESKY_DIM: wp.launch_tiled( _adjoint_cholesky_tile(m.nv), dim=d.nworld, - inputs=[d.solver_h, b], + inputs=[H, b], outputs=[out], block_dim=m.block_dim.update_gradient_cholesky, ) @@ -216,74 +644,427 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out): b_3d = b.reshape((d.nworld, m.nv_pad, 1)) out_3d = out.reshape((d.nworld, m.nv_pad, 1)) - if d.solver_hfactor.shape[1] > 0: - # Solve-only using stored Cholesky factor + if use_stored and d.solver_hfactor.shape[1] > 0: + # Solve-only using stored Cholesky factor (original H only) wp.launch_tiled( - _adjoint_cholesky_blocked( - types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad - ), + _adjoint_cholesky_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), dim=d.nworld, inputs=[d.solver_hfactor, b_3d, m.nv], outputs=[out_3d], block_dim=m.block_dim.update_gradient_cholesky_blocked, ) else: - # Full factorize + solve (no stored factor) - # Pad diagonal for stability + # Full factorize + solve if m.nv_pad > m.nv: wp.launch( _padding_h_adjoint, dim=(d.nworld, m.nv_pad - m.nv), inputs=[m.nv], - outputs=[d.solver_h], + outputs=[H], ) - hfactor_tmp = wp.zeros( - (d.nworld, m.nv_pad, m.nv_pad), dtype=float - ) + hfactor_tmp = wp.zeros((d.nworld, m.nv_pad, m.nv_pad), dtype=float) wp.launch_tiled( - _adjoint_cholesky_full_blocked( - types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad - ), + _adjoint_cholesky_full_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), dim=d.nworld, - inputs=[d.solver_h, b_3d, m.nv, hfactor_tmp], + inputs=[H, b_3d, m.nv, hfactor_tmp], outputs=[out_3d], block_dim=m.block_dim.update_gradient_cholesky_blocked, ) -def solver_implicit_adjoint(m: types.Model, d: types.Data): +def solver_implicit_adjoint(m: types.Model, d: types.Data, qacc_array=None, qacc_smooth_ref=None): """Implicit differentiation adjoint for constraint solver. - Called during tape backward. 
Reads d.qacc.grad (set by downstream), - solves H*v = adj_qacc, writes d.qacc_smooth.grad = M*v. + Called during tape backward. Reads qacc_array.grad (set by downstream + integrator adjoint), solves H*v = adj_qacc, accumulates into + qacc_smooth_ref.grad += M*v. + + Args: + m: Model containing static simulation parameters. + d: Data containing mutable simulation state. + qacc_array: The array whose .grad contains the incoming adjoint. + Defaults to d.qacc when called from diff_forward(). + Integrators pass their local qacc array when it differs + from d.qacc (e.g. euler with implicit damping). + qacc_smooth_ref: The qacc_smooth array whose .grad receives the + accumulated adjoint. Captured at record time for + correct gradient isolation when intermediate arrays + are cloned between substeps. Defaults to d.qacc_smooth. """ nv = m.nv if nv == 0: return + if qacc_array is None: + qacc_array = d.qacc + + if qacc_smooth_ref is None: + qacc_smooth_ref = d.qacc_smooth + + adj_qacc = qacc_array.grad + if adj_qacc is None: + return + + debug_level = os.environ.get("MJW_DEBUG_ADJOINT", "0") + if debug_level in ("1", "2"): + import numpy as np + + adj_norm = np.linalg.norm(adj_qacc.numpy()) + print(f"[adjoint] |adj_qacc|={adj_norm:.6e}, njmax={d.njmax}") + + if debug_level == "2" and d.njmax > 0: + import numpy as np + + efc_state_np = d.efc.state.numpy() + nefc_np = d.nefc.numpy() + for w in range(min(d.nworld, 1)): + ne = nefc_np[w] + n_quad = int(np.sum(efc_state_np[w, :ne] == 1)) + n_sat = int(np.sum(efc_state_np[w, :ne] == 0)) + H_np = d.solver_h.numpy()[w, :nv, :nv] + H_diag = np.diag(H_np) + cond_approx = np.max(H_diag) / max(np.min(H_diag[H_diag > 0]), 1e-30) + print( + f"[adjoint:diag] world={w} nefc={ne} QUAD={n_quad} SAT={n_sat}" + f" H_cond~{cond_approx:.1f}" + f" H_diag=[{np.min(H_diag):.3e}, {np.max(H_diag):.3e}]" + ) + if d.njmax == 0: - # Solver was identity (qacc = qacc_smooth), copy adjoint through + # Solver was identity (qacc = qacc_smooth), accumulate 
adjoint through wp.launch( - _copy_grad_kernel, + _accumulate_grad_kernel, dim=(d.nworld, nv), - inputs=[d.qacc.grad], - outputs=[d.qacc_smooth.grad], + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], ) return if m.opt.solver != types.SolverType.NEWTON: # CG solver: no Hessian stored, fall back to identity wp.launch( - _copy_grad_kernel, + _accumulate_grad_kernel, dim=(d.nworld, nv), - inputs=[d.qacc.grad], - outputs=[d.qacc_smooth.grad], + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], ) return # Solve H * v = adj_qacc v = wp.zeros((d.nworld, m.nv_pad), dtype=float) - _solve_hessian_system(m, d, d.qacc.grad, v) + _solve_hessian_system(m, d, adj_qacc, v) + + # adj_qacc_smooth += M * v (accumulate, not overwrite) + tmp = wp.zeros((d.nworld, m.nv_pad), dtype=float) + support.mul_m(m, d, tmp, v) + wp.launch( + _accumulate_grad_kernel, + dim=(d.nworld, nv), + inputs=[tmp], + outputs=[qacc_smooth_ref.grad], + ) + + # Phase 3: compute efc-level gradients for collision chain + _efc_level_gradients(m, d, v) + + +def _efc_level_gradients(m: types.Model, d: types.Data, v): + """Compute efc-level gradients for collision chain (shared by both adjoints).""" + if d.njmax > 0: + efc_J = d.efc.J + if hasattr(efc_J, "grad") and efc_J.grad is not None: + wp.launch( + _efc_J_grad_kernel, + dim=(d.nworld, d.njmax_pad, m.nv_pad), + inputs=[m.nv, d.nefc, d.efc.force, d.njmax, v], + outputs=[efc_J.grad], + ) + + efc_aref = d.efc.aref + efc_pos = d.efc.pos + if hasattr(efc_aref, "grad") and efc_aref.grad is not None and hasattr(efc_pos, "grad") and efc_pos.grad is not None: + wp.launch( + _efc_pos_grad_kernel, + dim=d.naconmax, + inputs=[ + m.opt.timestep, + m.opt.disableflags, + d.contact.dist, + d.contact.includemargin, + d.contact.solref, + d.contact.solimp, + d.contact.efc_address, + d.contact.worldid, + d.contact.type, + d.nacon, + efc_aref.grad, + ], + outputs=[efc_pos.grad], + ) + + +# --------------------------------------------------------------------------- +# Smooth 
constraint adjoint: backward-only friction gradient smoothing +# --------------------------------------------------------------------------- + + +def solver_smooth_adjoint( + m: types.Model, + d: types.Data, + qacc_array=None, + qacc_smooth_ref=None, +): + """Smooth constraint adjoint for friction gradient signal. + + Like solver_implicit_adjoint, but builds a modified Hessian H_smooth that + reduces friction constraint stiffness and adds viscous friction for + SATISFIED constraints. This provides non-zero gradients through the friction + cone dead zone while keeping the forward physics unchanged. + + Parameters are read from d.smooth_friction_viscosity and + d.smooth_friction_scale. Enable via d.smooth_adjoint = 1. + + Args: + m: Model containing static simulation parameters. + d: Data containing mutable simulation state. + qacc_array: The array whose .grad contains the incoming adjoint. + qacc_smooth_ref: The qacc_smooth array whose .grad receives the + accumulated adjoint. + """ + nv = m.nv + if nv == 0: + return + + if qacc_array is None: + qacc_array = d.qacc + + if qacc_smooth_ref is None: + qacc_smooth_ref = d.qacc_smooth + + adj_qacc = qacc_array.grad + if adj_qacc is None: + return + + debug_level = os.environ.get("MJW_DEBUG_ADJOINT", "0") + if debug_level in ("1", "2"): + import numpy as np + + adj_norm = np.linalg.norm(adj_qacc.numpy()) + print(f"[smooth_adjoint] |adj_qacc|={adj_norm:.6e}, njmax={d.njmax}") + + if d.njmax == 0: + wp.launch( + _accumulate_grad_kernel, + dim=(d.nworld, nv), + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], + ) + return + + if m.opt.solver != types.SolverType.NEWTON: + wp.launch( + _accumulate_grad_kernel, + dim=(d.nworld, nv), + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], + ) + return + + # Read smooth adjoint parameters from Data + free_body = getattr(d, "smooth_free_body_adjoint", False) + penalty_alpha = getattr(d, "smooth_penalty_damping_alpha", 0.0) + surrogate = getattr(d, 
"smooth_friction_surrogate_adjoint", False) + surrogate_alpha = float(getattr(d, "smooth_friction_surrogate_alpha", 0.0)) + if surrogate_alpha < 0.0: + surrogate_alpha = 0.0 + elif surrogate_alpha > 1.0: + surrogate_alpha = 1.0 + + if surrogate: + friction_viscosity = getattr(d, "smooth_friction_viscosity", 10.0) + friction_scale = getattr(d, "smooth_friction_scale", 0.01) + + H_smooth = wp.clone(d.solver_h) + + if d.naconmax > 0: + wp.launch( + _smooth_hessian_friction_correction, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + d.contact.efc_address, + d.contact.dim, + d.contact.type, + d.contact.worldid, + d.nacon, + d.efc.J, + d.efc.D, + d.efc.state, + friction_viscosity, + friction_scale, + ], + outputs=[H_smooth], + ) + + v_hessian = wp.zeros((d.nworld, m.nv_pad), dtype=float) + _solve_hessian_system(m, d, adj_qacc, v_hessian, H=H_smooth) + + from mujoco_warp._src.smooth import solve_m + + v_free = wp.zeros((d.nworld, m.nv_pad), dtype=float) + solve_m(m, d, v_free, adj_qacc) + + v = wp.clone(v_hessian) + if d.naconmax > 0: + # Recover only a controlled fraction of the tangential free-body signal. + # alpha=0 keeps the full bypass, alpha=1 leaves the smooth/Newton result. + correction_scale = 1.0 - surrogate_alpha + correction_cap_ratio = 1.0 + correction_norm_eps = 1.0e-8 + wp.launch( + _friction_bypass_correction_normalized, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + d.contact.efc_address, + d.contact.dim, + d.contact.type, + d.contact.worldid, + d.nacon, + d.efc.J, + v_hessian, + v_free, + correction_scale, + correction_cap_ratio, + correction_norm_eps, + ], + outputs=[v], + ) + + elif free_body or penalty_alpha > 0.0: + # Free-body base: v = M^{-1} * adj_qacc + # Eliminates H^{-1} attenuation entirely. 
+ from mujoco_warp._src.smooth import solve_m + + v = wp.zeros((d.nworld, m.nv_pad), dtype=float) + solve_m(m, d, v, adj_qacc) + + # Penalty-model friction damping: attenuate v in friction directions + # by factor (1 - alpha) per face, mimicking dflex's penalty friction + # d(v_next)/d(v_prev) eigenvalues. Provides natural BPTT decay. + if penalty_alpha > 0.0 and d.naconmax > 0: + v_free = wp.clone(v) # save unmodified for projection + wp.launch( + _penalty_friction_damping, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + d.contact.efc_address, + d.contact.dim, + d.contact.type, + d.contact.worldid, + d.nacon, + d.efc.J, + v_free, + penalty_alpha, + ], + outputs=[v], + ) + + else: + # Original smooth adjoint: H_smooth with friction correction + optional bypass + friction_viscosity = getattr(d, "smooth_friction_viscosity", 10.0) + friction_scale = getattr(d, "smooth_friction_scale", 0.01) + bypass_kf = getattr(d, "smooth_friction_bypass_kf", 0.0) + + # Build H_smooth = d.solver_h + friction correction + H_smooth = wp.clone(d.solver_h) + + if d.naconmax > 0: + wp.launch( + _smooth_hessian_friction_correction, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + d.contact.efc_address, + d.contact.dim, + d.contact.type, + d.contact.worldid, + d.nacon, + d.efc.J, + d.efc.D, + d.efc.state, + friction_viscosity, + friction_scale, + ], + outputs=[H_smooth], + ) + + if debug_level == "2": + import numpy as np + + H_np = H_smooth.numpy()[0, :nv, :nv] + H_orig = d.solver_h.numpy()[0, :nv, :nv] + diff = H_np - H_orig + print( + f"[smooth_adjoint:diag] H_smooth diag=" + f"[{np.min(np.diag(H_np)):.3e}, {np.max(np.diag(H_np)):.3e}]" + f" |delta_H|_F={np.linalg.norm(diff):.3e}" + ) + + # Solve H_smooth * v = adj_qacc + v = wp.zeros((d.nworld, m.nv_pad), dtype=float) + _solve_hessian_system(m, d, adj_qacc, v, H=H_smooth) + + if debug_level == "2": + import numpy as np + + v_np = v.numpy()[0, :nv] + print(f"[smooth_adjoint:diag] |v|={np.linalg.norm(v_np):.6e} v={v_np}") + 
+ # Friction gradient bypass: restore tangential gradients attenuated by H^{-1} + if bypass_kf > 0.0 and d.naconmax > 0: + from mujoco_warp._src.smooth import solve_m + + v_free = wp.zeros((d.nworld, m.nv_pad), dtype=float) + solve_m(m, d, v_free, adj_qacc) + + wp.launch( + _friction_bypass_correction, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + m.nv, + d.contact.efc_address, + d.contact.dim, + d.contact.type, + d.contact.worldid, + d.nacon, + d.efc.J, + v, + v_free, + bypass_kf, + ], + outputs=[v], + ) + + if debug_level == "2": + import numpy as np + + v_bypass = v.numpy()[0, :nv] + print( + f"[smooth_adjoint:diag] bypass kf={bypass_kf} " + f"|v_after_bypass|={np.linalg.norm(v_bypass):.6e}" + ) + + # adj_qacc_smooth += M * v + tmp = wp.zeros((d.nworld, m.nv_pad), dtype=float) + support.mul_m(m, d, tmp, v) + wp.launch( + _accumulate_grad_kernel, + dim=(d.nworld, nv), + inputs=[tmp], + outputs=[qacc_smooth_ref.grad], + ) - # adj_qacc_smooth = M * v - support.mul_m(m, d, d.qacc_smooth.grad, v) + # Phase 3: efc-level gradients for collision chain + _efc_level_gradients(m, d, v) diff --git a/mujoco_warp/_src/bvh.py b/mujoco_warp/_src/bvh.py index 0efdc1344..60c4093cf 100644 --- a/mujoco_warp/_src/bvh.py +++ b/mujoco_warp/_src/bvh.py @@ -189,12 +189,12 @@ def _compute_bvh_bounds( upper_out: wp.array(dtype=wp.vec3), group_out: wp.array(dtype=int), ): - world_id, geom_local_id = wp.tid() + worldid, geom_local_id = wp.tid() geom_id = enabled_geom_ids[geom_local_id] - pos = geom_xpos_in[world_id, geom_id] - rot = geom_xmat_in[world_id, geom_id] - size = geom_size[world_id % geom_size.shape[0], geom_id] + pos = geom_xpos_in[worldid, geom_id] + rot = geom_xmat_in[worldid, geom_id] + size = geom_size[worldid % geom_size.shape[0], geom_id] type = geom_type[geom_id] # TODO: Investigate branch elimination with static loop unrolling @@ -218,9 +218,9 @@ def _compute_bvh_bounds( hfield_center = pos + rot[:, 2] * size[2] lower_bound, upper_bound = 
@wp.kernel
def _compute_flex_bvh_bounds(
  # Model:
  flex_vertadr: wp.array(dtype=int),
  flex_vertnum: wp.array(dtype=int),
  flex_edge: wp.array(dtype=wp.vec2i),
  flex_radius: wp.array(dtype=float),
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flex_geom_flexid: wp.array(dtype=int),
  flex_geom_edgeid: wp.array(dtype=int),
  bvh_ngeom: int,
  total_bvh_size: int,
  # Out:
  lower_out: wp.array(dtype=wp.vec3),
  upper_out: wp.array(dtype=wp.vec3),
  group_out: wp.array(dtype=int),
):
  """Write radius-inflated AABBs for flex BVH entries.

  One thread per (world, flex BVH entry). Each world owns `total_bvh_size`
  consecutive output slots; flex entries are written after the first
  `bvh_ngeom` slots of that world. An entry with a non-negative edge id is
  bounded by the two vertices of that edge; otherwise the bound encloses every
  vertex of the flex. The world id is stored as the BVH group.
  """
  worldid, flexlocalid = wp.tid()

  flexid = flex_geom_flexid[flexlocalid]
  edgeid = flex_geom_edgeid[flexlocalid]
  vertadr = flex_vertadr[flexid]

  # bounds are grown by the flex radius along every axis
  r = flex_radius[flexid]
  pad = wp.vec3(r, r, r)

  slot = worldid * total_bvh_size + bvh_ngeom + flexlocalid

  if edgeid < 0:
    # whole flex (2D/3D): bound over all of its vertices
    nvert = flex_vertnum[flexid]
    lo = wp.vec3(MJ_MAXVAL, MJ_MAXVAL, MJ_MAXVAL)
    hi = wp.vec3(-MJ_MAXVAL, -MJ_MAXVAL, -MJ_MAXVAL)
    for k in range(nvert):
      p = flexvert_xpos_in[worldid, vertadr + k]
      lo = wp.min(lo, p)
      hi = wp.max(hi, p)
    lower_out[slot] = lo - pad
    upper_out[slot] = hi + pad
  else:
    # single edge (1D): bound over its two endpoints
    e = flex_edge[edgeid]
    p0 = flexvert_xpos_in[worldid, vertadr + e[0]]
    p1 = flexvert_xpos_in[worldid, vertadr + e[1]]
    lower_out[slot] = wp.min(p0, p1) - pad
    upper_out[slot] = wp.max(p0, p1) + pad

  group_out[slot] = worldid
RenderContext): + total_bvh_size = rc.bvh_ngeom + rc.bvh_nflexgeom + wp.launch( kernel=_compute_bvh_bounds, dim=(d.nworld, rc.bvh_ngeom), @@ -286,7 +364,7 @@ def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): m.geom_size, d.geom_xpos, d.geom_xmat, - rc.bvh_ngeom, + total_bvh_size, rc.enabled_geom_ids, rc.mesh_bounds_size, rc.hfield_bounds_size, @@ -296,6 +374,26 @@ def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): ], ) + if rc.bvh_nflexgeom > 0: + wp.launch( + kernel=_compute_flex_bvh_bounds, + dim=(d.nworld, rc.bvh_nflexgeom), + inputs=[ + m.flex_vertadr, + m.flex_vertnum, + m.flex_edge, + m.flex_radius, + d.flexvert_xpos, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, + rc.bvh_ngeom, + total_bvh_size, + rc.lower, + rc.upper, + rc.group, + ], + ) + rc.bvh.refit() @@ -500,6 +598,12 @@ def build_hfield_bvh( @wp.kernel def accumulate_flex_vertex_normals( # Model: + nflex: int, + flex_dim: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), @@ -509,10 +613,22 @@ def accumulate_flex_vertex_normals( """Accumulate per-vertex normals by summing adjacent face normals.""" worldid, elemid = wp.tid() - elem_base = elemid * 3 - i0 = flex_elem[elem_base + 0] - i1 = flex_elem[elem_base + 1] - i2 = flex_elem[elem_base + 2] + for i in range(nflex): + locid = elemid - flex_elemadr[i] + if locid >= 0 and locid < flex_elemnum[i]: + f = i + break + + if flex_dim[f] == 1 or flex_dim[f] == 3: + return + + local_elemid = elemid - flex_elemadr[f] + elem_adr = flex_elemdataadr[f] + vert_adr = flex_vertadr[f] + elem_base = elem_adr + local_elemid * 3 + i0 = vert_adr + flex_elem[elem_base + 0] + i1 = vert_adr + flex_elem[elem_base + 1] + i2 = vert_adr + flex_elem[elem_base + 2] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] @@ -718,12 +834,11 @@ 
def _build_flex_3d_shells( @wp.kernel -def _update_flex_face_points( +def _update_flex_2d_face_points( # Model: - nflex: int, - flex_dim: wp.array(dtype=int), flex_vertadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_shelldataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_shell: wp.array(dtype=int), @@ -732,149 +847,150 @@ def _update_flex_face_points( flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: flexvert_norm_in: wp.array2d(dtype=wp.vec3), - flex_elemdataadr: wp.array(dtype=int), - flex_faceadr: wp.array(dtype=int), - flex_workadr: wp.array(dtype=int), - flex_worknum: wp.array(dtype=int), - nfaces: int, + flex_id: int, + nface: int, smooth: bool, # Out: face_point_out: wp.array(dtype=wp.vec3), ): worldid, workid = wp.tid() - # identify which flex this work item belongs to - f = int(0) - locid = int(0) - for i in range(nflex): - locid = workid - flex_workadr[i] - if locid >= 0 and locid < flex_worknum[i]: - f = i - break - - dim = flex_dim[f] - face_offset = flex_faceadr[f] - world_face_offset = worldid * nfaces - vert_adr = flex_vertadr[f] + elem_adr = flex_elemdataadr[flex_id] + vert_adr = flex_vertadr[flex_id] + radius = flex_radius[flex_id] + nelem = flex_elemnum[flex_id] + world_face_offset = worldid * nface - if dim == 2: - radius = flex_radius[f] - elem_count = flex_elemnum[f] - - if locid < elem_count: - # 2D element faces - elemid = locid - elem_adr = flex_elemdataadr[f] - ebase = elem_adr + elemid * 3 - i0 = vert_adr + flex_elem[ebase + 0] - i1 = vert_adr + flex_elem[ebase + 1] - i2 = vert_adr + flex_elem[ebase + 2] - - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - - # TODO: Use static conditional - if smooth: - n0 = flexvert_norm_in[worldid, i0] - n1 = flexvert_norm_in[worldid, i1] - n2 = flexvert_norm_in[worldid, i2] - else: - face_nrm = wp.cross(v1 - v0, v2 - v0) - face_nrm = wp.normalize(face_nrm) - n0 = 
face_nrm - n1 = face_nrm - n2 = face_nrm - - p0_pos = v0 + radius * n0 - p1_pos = v1 + radius * n1 - p2_pos = v2 + radius * n2 - - p0_neg = v0 - radius * n0 - p1_neg = v1 - radius * n1 - p2_neg = v2 - radius * n2 - - face_id0 = world_face_offset + face_offset + (2 * elemid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = p0_pos - face_point_out[base0 + 1] = p1_pos - face_point_out[base0 + 2] = p2_pos - - face_id1 = world_face_offset + face_offset + (2 * elemid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = p0_neg - face_point_out[base1 + 1] = p1_neg - face_point_out[base1 + 2] = p2_neg - else: - # 2D shell faces - shellid = locid - elem_count - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + 2 * shellid - i0 = vert_adr + flex_shell[sbase + 0] - i1 = vert_adr + flex_shell[sbase + 1] + if workid < nelem: + # 2D element faces + elemid = workid + ebase = elem_adr + elemid * 3 + i0 = vert_adr + flex_elem[ebase + 0] + i1 = vert_adr + flex_elem[ebase + 1] + i2 = vert_adr + flex_elem[ebase + 2] - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] + v0 = flexvert_xpos_in[worldid, i0] + v1 = flexvert_xpos_in[worldid, i1] + v2 = flexvert_xpos_in[worldid, i2] + # TODO: Use static conditional + if smooth: n0 = flexvert_norm_in[worldid, i0] n1 = flexvert_norm_in[worldid, i1] - - shell_face_offset = face_offset + (2 * elem_count) - face_id0 = world_face_offset + shell_face_offset + (2 * shellid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = v0 + radius * n0 - face_point_out[base0 + 1] = v1 - radius * n1 - face_point_out[base0 + 2] = v1 + radius * n1 - - face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = v1 - radius * n1 - face_point_out[base1 + 1] = v0 + radius * n0 - face_point_out[base1 + 2] = v0 - radius * n0 + n2 = flexvert_norm_in[worldid, i2] + else: + face_nrm = wp.cross(v1 - v0, v2 - v0) + face_nrm = wp.normalize(face_nrm) + n0 = face_nrm + n1 = 
face_nrm + n2 = face_nrm + + p0_pos = v0 + radius * n0 + p1_pos = v1 + radius * n1 + p2_pos = v2 + radius * n2 + + p0_neg = v0 - radius * n0 + p1_neg = v1 - radius * n1 + p2_neg = v2 - radius * n2 + + face_id0 = world_face_offset + (2 * elemid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = p0_pos + face_point_out[base0 + 1] = p1_pos + face_point_out[base0 + 2] = p2_pos + + face_id1 = world_face_offset + (2 * elemid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = p0_neg + face_point_out[base1 + 1] = p1_neg + face_point_out[base1 + 2] = p2_neg else: - # 3D shell faces - shellid = locid - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + shellid * 3 + # 2D shell faces + shell_adr = flex_shelldataadr[flex_id] + shellid = workid - nelem + sbase = shell_adr + 2 * shellid i0 = vert_adr + flex_shell[sbase + 0] i1 = vert_adr + flex_shell[sbase + 1] - i2 = vert_adr + flex_shell[sbase + 2] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - face_id = world_face_offset + face_offset + shellid - fbase = face_id * 3 + n0 = flexvert_norm_in[worldid, i0] + n1 = flexvert_norm_in[worldid, i1] - face_point_out[fbase + 0] = v0 - face_point_out[fbase + 1] = v1 - face_point_out[fbase + 2] = v2 + shell_face_offset = 2 * nelem + face_id0 = world_face_offset + shell_face_offset + (2 * shellid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = v0 + radius * n0 + face_point_out[base0 + 1] = v1 - radius * n1 + face_point_out[base0 + 2] = v1 + radius * n1 + face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = v1 - radius * n1 + face_point_out[base1 + 1] = v0 + radius * n0 + face_point_out[base1 + 2] = v0 - radius * n0 -def build_flex_bvh( - mjm: mujoco.MjModel, mjd: mujoco.MjData, nworld: int, constructor: str = "sah", leaf_size: int = 2 -) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, wp.array, wp.array, int]: - """Create a Warp mesh BVH from 
@wp.kernel
def _update_flex_3d_face_points(
  # Model:
  flex_vertadr: wp.array(dtype=int),
  flex_shelldataadr: wp.array(dtype=int),
  flex_shell: wp.array(dtype=int),
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flex_id: int,
  nface: int,
  # Out:
  face_point_out: wp.array(dtype=wp.vec3),
):
  """Copy current shell-triangle vertex positions of one 3D flex.

  One thread per (world, shell triangle). Each world owns `nface` consecutive
  triangles (3 points each) in `face_point_out`; the shell connectivity stores
  vertex indices local to the flex, so `flex_vertadr[flex_id]` is added to
  address `flexvert_xpos_in`.
  """
  worldid, shellid = wp.tid()

  vertadr = flex_vertadr[flex_id]
  src = flex_shelldataadr[flex_id] + shellid * 3
  dst = (worldid * nface + shellid) * 3

  # three corners of the triangle, in connectivity order
  for k in range(3):
    face_point_out[dst + k] = flexvert_xpos_in[worldid, vertadr + flex_shell[src + k]]
int(mjm.flex_dim[flex_id]) + nelem = int(mjm.flex_elemnum[flex_id]) + nshell = int(mjm.flex_shellnum[flex_id]) - nface = int(flex_faceadr[-1]) - flex_faceadr = flex_faceadr[:-1] + if dim == 2: + nface = 2 * nelem + 2 * nshell + else: + nface = nshell face_point = wp.empty(nface * 3 * nworld, dtype=wp.vec3) face_index = wp.empty(nface * 3 * nworld, dtype=wp.int32) @@ -885,8 +1001,8 @@ def build_flex_bvh( wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(nworld, nflexelemdata // 3), - inputs=[flex_elem, flexvert_xpos], + dim=(nworld, mjm.nflexelem), + inputs=[mjm.nflex, flex_dim, flex_vertadr, flex_elemadr, flex_elemnum, flex_elemdataadr, flex_elem, flexvert_xpos], outputs=[flexvert_norm], ) @@ -896,60 +1012,56 @@ def build_flex_bvh( inputs=[flexvert_norm], ) - for f in range(nflex): - dim = mjm.flex_dim[f] - elem_adr = mjm.flex_elemdataadr[f] - nelem = mjm.flex_elemnum[f] - shell_adr = mjm.flex_shelldataadr[f] - nshell = mjm.flex_shellnum[f] - vert_adr = mjm.flex_vertadr[f] + elem_adr = mjm.flex_elemdataadr[flex_id] + shell_adr = mjm.flex_shelldataadr[flex_id] + vert_adr = mjm.flex_vertadr[flex_id] - if dim == 2: - wp.launch( - kernel=_build_flex_2d_elements, - dim=(nworld, nelem), - inputs=[ - flex_elem, - flexvert_xpos, - flexvert_norm, - elem_adr, - vert_adr, - flex_faceadr[f], - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) - - wp.launch( - kernel=_build_flex_2d_sides, - dim=(nworld, nshell), - inputs=[ - flex_shell, - flexvert_xpos, - flexvert_norm, - shell_adr, - vert_adr, - flex_faceadr[f] + 2 * nelem, - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) - elif dim == 3: - wp.launch( - kernel=_build_flex_3d_shells, - dim=(nworld, nshell), - inputs=[ - flex_shell, - flexvert_xpos, - shell_adr, - vert_adr, - flex_faceadr[f], - nface, - ], - outputs=[face_point, face_index, group], - ) + if dim == 2: + wp.launch( + kernel=_build_flex_2d_elements, + dim=(nworld, nelem), + inputs=[ + 
flex_elem, + flexvert_xpos, + flexvert_norm, + elem_adr, + vert_adr, + 0, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + + wp.launch( + kernel=_build_flex_2d_sides, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + flexvert_norm, + shell_adr, + vert_adr, + 2 * nelem, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + elif dim == 3: + wp.launch( + kernel=_build_flex_3d_shells, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + shell_adr, + vert_adr, + 0, # face_offset + nface, + ], + outputs=[face_point, face_index, group], + ) flex_mesh = wp.Mesh( points=face_point, @@ -967,24 +1079,23 @@ def build_flex_bvh( outputs=[group_root], ) - return ( - flex_mesh, - face_point, - group_root, - flex_shell, - flex_faceadr, - nface, - ) + return flex_mesh, group_root def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): - """Refit the flex BVH.""" + """Refit per-flex BVHs.""" flexvert_norm = wp.zeros(d.flexvert_xpos.shape, dtype=wp.vec3) wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(d.nworld, m.nflexelemdata // 3), + dim=(d.nworld, m.nflexelem), inputs=[ + m.nflex, + m.flex_dim, + m.flex_vertadr, + m.flex_elemadr, + m.flex_elemnum, + m.flex_elemdataadr, m.flex_elem, d.flexvert_xpos, ], @@ -993,32 +1104,49 @@ def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): wp.launch( kernel=normalize_vertex_normals, - dim=(d.nworld, m.nflexvert), + dim=(d.nworld, d.flexvert_xpos.shape[1]), inputs=[flexvert_norm], ) - wp.launch( - kernel=_update_flex_face_points, - dim=(d.nworld, rc.flex_nwork), - inputs=[ - m.nflex, - m.flex_dim, - m.flex_vertadr, - m.flex_elemnum, - m.flex_shelldataadr, - m.flex_elem, - m.flex_shell, - m.flex_radius, - d.flexvert_xpos, - flexvert_norm, - rc.flex_elemdataadr, - rc.flex_faceadr, - rc.flex_workadr, - rc.flex_worknum, - rc.flex_nface, - rc.flex_render_smooth, - ], - outputs=[rc.flex_face_point], 
- ) + for i in range(m.nflex): + if rc.flex_dim_np[i] == 1: + continue + mesh = rc.flex_mesh_registry[i] + nface = mesh.points.shape[0] // (3 * d.nworld) + + if rc.flex_dim_np[i] == 2: + wp.launch( + kernel=_update_flex_2d_face_points, + dim=(d.nworld, nface // 2), + inputs=[ + m.flex_vertadr, + m.flex_elemnum, + m.flex_elemdataadr, + m.flex_shelldataadr, + m.flex_elem, + m.flex_shell, + m.flex_radius, + d.flexvert_xpos, + flexvert_norm, + i, + nface, + rc.flex_render_smooth, + ], + outputs=[mesh.points], + ) + else: + wp.launch( + kernel=_update_flex_3d_face_points, + dim=(d.nworld, nface), + inputs=[ + m.flex_vertadr, + m.flex_shelldataadr, + m.flex_shell, + d.flexvert_xpos, + i, + nface, + ], + outputs=[mesh.points], + ) - rc.flex_mesh.refit() + mesh.refit() diff --git a/mujoco_warp/_src/bvh_test.py b/mujoco_warp/_src/bvh_test.py index ee15d65b5..5241334d3 100644 --- a/mujoco_warp/_src/bvh_test.py +++ b/mujoco_warp/_src/bvh_test.py @@ -33,9 +33,12 @@ def _assert_eq(a, b, name): @dataclasses.dataclass class MinimalRenderContext: bvh_ngeom: int + bvh_nflexgeom: int enabled_geom_ids: wp.array mesh_bounds_size: wp.array hfield_bounds_size: wp.array + flex_geom_flexid: wp.array + flex_geom_edgeid: wp.array lower: wp.array upper: wp.array group: wp.array @@ -53,9 +56,12 @@ def _create_minimal_context(mjm, nworld, enabled_geom_groups=None): return MinimalRenderContext( bvh_ngeom=bvh_ngeom, + bvh_nflexgeom=0, enabled_geom_ids=wp.array(geom_enabled_idx, dtype=int), mesh_bounds_size=wp.zeros(max(mjm.nmesh, 1), dtype=wp.vec3), hfield_bounds_size=wp.zeros(max(mjm.nhfield, 1), dtype=wp.vec3), + flex_geom_flexid=wp.zeros(max(mjm.nflex, 1), dtype=int), + flex_geom_edgeid=wp.zeros(max(mjm.nflex, 1), dtype=int), lower=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), upper=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), group=wp.zeros(nworld * bvh_ngeom, dtype=int), @@ -211,12 +217,18 @@ def test_accumulate_flex_vertex_normals(self): dtype=wp.vec3, ) flex_elem = wp.array([0, 1, 2, 1, 3, 
2], dtype=int) + flex_elemdataadr = wp.array([0], dtype=int) + flex_elemadr = wp.array([0], dtype=int) + flex_elemnum = wp.array([len(flex_elem)], dtype=int) + flex_vertadr = wp.array([0], dtype=int) + flex_dim = wp.array([2], dtype=int) + flex_id = 0 flexvert_norm = wp.zeros((nworld, nvert), dtype=wp.vec3) wp.launch( kernel=bvh.accumulate_flex_vertex_normals, dim=(nworld, nelem), - inputs=[flex_elem, flexvert_xpos], + inputs=[1, flex_dim, flex_vertadr, flex_elemadr, flex_elemnum, flex_elemdataadr, flex_elem, flexvert_xpos], outputs=[flexvert_norm], ) @@ -252,7 +264,7 @@ def test_build_flex_bvh(self): mjm, mjd, m, d = test_data.fixture("flex/floppy.xml") - flex_mesh, face_point, group_root, flex_shell, flex_faceadr, nface = bvh.build_flex_bvh(mjm, mjd, 1) + flex_mesh, face_point, flex_shell, group_root, nface = bvh.build_flex_bvh(mjm, mjd, 1, 0) self.assertNotEqual(flex_mesh.id, wp.uint64(0), "flex_mesh id") diff --git a/mujoco_warp/_src/collision_flex.py b/mujoco_warp/_src/collision_flex.py index 3523730b0..13d8b0e0d 100644 --- a/mujoco_warp/_src/collision_flex.py +++ b/mujoco_warp/_src/collision_flex.py @@ -398,6 +398,7 @@ def _flex_narrowphase_dim2( flex_vertadr: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_radius: wp.array(dtype=float), # Data in: @@ -443,7 +444,7 @@ def _flex_narrowphase_dim2( tri_radius = flex_radius[flexid] tri_margin = flex_margin[flexid] - elem_data_idx = elemid * 3 + elem_data_idx = flex_elemdataadr[flexid] + (elemid - flex_elemadr[flexid]) * 3 v0_local = flex_elem[elem_data_idx] v1_local = flex_elem[elem_data_idx + 1] v2_local = flex_elem[elem_data_idx + 2] @@ -709,6 +710,7 @@ def flex_narrowphase(m: Model, d: Data): m.flex_vertadr, m.flex_elemadr, m.flex_elemnum, + m.flex_elemdataadr, m.flex_elem, m.flex_radius, d.geom_xpos, diff --git a/mujoco_warp/_src/collision_smooth.py b/mujoco_warp/_src/collision_smooth.py 
new file mode 100644 index 000000000..7547e932b --- /dev/null +++ b/mujoco_warp/_src/collision_smooth.py @@ -0,0 +1,825 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Smooth (differentiable) collision recomputation for autodifferentiation. + +This module provides differentiable replacements for the collision pipeline's +contact geometry and constraint assembly. It runs *after* the discrete pipeline +and overwrites contact.{dist, pos, frame} and efc.{J, pos, D, aref, vel} with +smooth values that Warp can differentiate through. + +Supported geom type pairs: + - plane-sphere, sphere-sphere, sphere-capsule + - capsule-capsule (2 contacts), plane-capsule (2 contacts) + +Unsupported types (box, mesh, convex, etc.) are no-ops that keep discrete +values (zero gradient through those contacts). 
class mat23f(wp.types.matrix(shape=(2, 3), dtype=wp.float32)):
  """2x3 float32 matrix; each row holds one vec3 of a two-contact result."""


# ============================================================================
# Smooth distance functions
# ============================================================================


@wp.func
def smooth_plane_sphere(
  # In:
  plane_normal: wp.vec3,
  plane_pos: wp.vec3,
  sphere_pos: wp.vec3,
  sphere_radius: float,
) -> Tuple[float, wp.vec3]:
  """Signed plane-sphere distance and contact position.

  Returns:
    dist: signed distance from the sphere surface to the plane along
      `plane_normal` (negative when penetrating).
    pos: point on the normal line through the sphere centre, halfway between
      the sphere surface and the plane.
  """
  offset = sphere_pos - plane_pos
  dist = wp.dot(offset, plane_normal) - sphere_radius
  depth = sphere_radius + 0.5 * dist
  pos = sphere_pos - plane_normal * depth
  return dist, pos


@wp.func
def smooth_sphere_sphere(
  # In:
  pos1: wp.vec3,
  radius1: float,
  pos2: wp.vec3,
  radius2: float,
) -> Tuple[float, wp.vec3, wp.vec3]:
  """Sphere-sphere distance, contact position and normal (from 1 to 2).

  The centre offset is divided by max(length, 1e-8) instead of branching on a
  zero length, so for coincident centres the returned normal is the zero
  vector rather than a fixed fallback direction.
  """
  delta = pos2 - pos1
  center_dist = wp.length(delta)
  normal = delta / wp.max(center_dist, 1e-8)
  dist = center_dist - (radius1 + radius2)
  # midpoint between the two sphere surfaces along the normal
  pos = pos1 + normal * (radius1 + 0.5 * dist)
  return dist, pos, normal
@wp.func
def smooth_capsule_capsule(
  # In:
  cap1_pos: wp.vec3,
  cap1_axis: wp.vec3,
  cap1_radius: float,
  cap1_half_length: float,
  cap2_pos: wp.vec3,
  cap2_axis: wp.vec3,
  cap2_radius: float,
  cap2_half_length: float,
  margin: float,
) -> Tuple[wp.vec2, mat23f, mat23f]:
  """Capsule-capsule distance returning up to 2 contacts.

  Contact 0 blends the closest-point solution for skew axes with the best
  endpoint-pair candidate used for (near-)parallel axes. Contact 1 exists only
  in the near-parallel regime and is the runner-up endpoint-pair candidate.

  Args:
    cap1_pos, cap2_pos: capsule centres.
    cap1_axis, cap2_axis: capsule axis directions.
    cap1_radius, cap2_radius: capsule radii.
    cap1_half_length, cap2_half_length: half lengths of the centre segments.
    margin: a contact is reported only if its distance is <= margin.

  Returns:
    (dist, pos, normal): per-contact distance (vec2, `wp.inf` for a missing
    contact), position rows and normal rows (mat23f, zero for a missing
    contact).
  """
  contact_dist = wp.vec2(wp.inf, wp.inf)
  contact_pos = mat23f()
  contact_normal = mat23f()

  axis1 = cap1_axis * cap1_half_length
  axis2 = cap2_axis * cap2_half_length
  dif = cap1_pos - cap2_pos

  # 2x2 system for the closest points on the two centre lines
  ma = wp.dot(axis1, axis1)
  mb = -wp.dot(axis1, axis2)
  mc = wp.dot(axis2, axis2)
  u = -wp.dot(axis1, dif)
  v = wp.dot(axis2, dif)
  det = ma * mc - mb * mb

  # regularized determinant: keeps 1/det finite for near-parallel axes
  det_abs = wp.abs(det)
  det_sign = wp.where(det >= 0.0, 1.0, -1.0)
  det_reg = det_sign * wp.max(det_abs, 1e-10)

  # alpha == 1: skew axes (non-parallel path only); alpha < 1: near-parallel
  blend_threshold = 1e-8
  alpha = wp.min(det_abs / wp.max(blend_threshold, 1e-15), 1.0)

  # -- Non-parallel path --
  inv_det = 1.0 / det_reg
  x1_np = (mc * u - mb * v) * inv_det
  x2_np = (ma * v - mb * u) * inv_det

  x1_np = wp.clamp(x1_np, -1.0, 1.0)
  x2_np = wp.clamp(x2_np, -1.0, 1.0)

  # re-solve each parameter against the clamped other one for consistency
  x2_np = wp.clamp((v + mb * x1_np) / wp.max(mc, 1e-10), -1.0, 1.0)
  x1_np = wp.clamp((u - mb * x2_np) / wp.max(ma, 1e-10), -1.0, 1.0)

  vec1_np = cap1_pos + axis1 * x1_np
  vec2_np = cap2_pos + axis2 * x2_np
  dist_np, pos_np, normal_np = smooth_sphere_sphere(vec1_np, cap1_radius, vec2_np, cap2_radius)

  # -- Parallel path: 4 endpoint candidates --
  # candidate 0: x1 = 1
  vec1_a = cap1_pos + axis1
  x2_a = wp.clamp((v - mb) / wp.max(mc, 1e-10), -1.0, 1.0)
  vec2_a = cap2_pos + axis2 * x2_a
  dist_a, pos_a, normal_a = smooth_sphere_sphere(vec1_a, cap1_radius, vec2_a, cap2_radius)

  # candidate 1: x1 = -1
  vec1_b = cap1_pos - axis1
  x2_b = wp.clamp((v + mb) / wp.max(mc, 1e-10), -1.0, 1.0)
  vec2_b = cap2_pos + axis2 * x2_b
  dist_b, pos_b, normal_b = smooth_sphere_sphere(vec1_b, cap1_radius, vec2_b, cap2_radius)

  # candidate 2: x2 = 1
  vec2_c = cap2_pos + axis2
  x1_c = wp.clamp((u - mb) / wp.max(ma, 1e-10), -1.0, 1.0)
  vec1_c = cap1_pos + axis1 * x1_c
  dist_c, pos_c, normal_c = smooth_sphere_sphere(vec1_c, cap1_radius, vec2_c, cap2_radius)

  # candidate 3: x2 = -1
  vec2_d = cap2_pos - axis2
  x1_d = wp.clamp((u + mb) / wp.max(ma, 1e-10), -1.0, 1.0)
  vec1_d = cap1_pos + axis1 * x1_d
  dist_d, pos_d, normal_d = smooth_sphere_sphere(vec1_d, cap1_radius, vec2_d, cap2_radius)

  # Parallel contact 0: closest candidate. Its index is remembered so that the
  # runner-up search can exclude it by identity rather than by distance value.
  best = int(0)
  par_dist0 = dist_a
  par_pos0 = pos_a
  par_normal0 = normal_a

  if dist_b < par_dist0:
    par_dist0 = dist_b
    par_pos0 = pos_b
    par_normal0 = normal_b
    best = 1
  if dist_c < par_dist0:
    par_dist0 = dist_c
    par_pos0 = pos_c
    par_normal0 = normal_c
    best = 2
  if dist_d < par_dist0:
    par_dist0 = dist_d
    par_pos0 = pos_d
    par_normal0 = normal_d
    best = 3

  # Parallel contact 1: closest remaining candidate within margin.
  # Excluding by index (not by `dist != par_dist0`) matters because equal
  # distances are the normal case for parallel capsules; comparing values
  # would discard every tied candidate and leave no second contact.
  par_dist1 = wp.inf
  par_pos1 = wp.vec3(0.0)
  par_normal1 = wp.vec3(1.0, 0.0, 0.0)

  if best != 0 and dist_a <= margin:
    par_dist1 = dist_a
    par_pos1 = pos_a
    par_normal1 = normal_a
  if best != 1 and dist_b <= margin and dist_b < par_dist1:
    par_dist1 = dist_b
    par_pos1 = pos_b
    par_normal1 = normal_b
  if best != 2 and dist_c <= margin and dist_c < par_dist1:
    par_dist1 = dist_c
    par_pos1 = pos_c
    par_normal1 = normal_c
  if best != 3 and dist_d <= margin and dist_d < par_dist1:
    par_dist1 = dist_d
    par_pos1 = pos_d
    par_normal1 = normal_d

  # Contact 0: blend between the non-parallel result and the best candidate
  blend_dist0 = alpha * dist_np + (1.0 - alpha) * par_dist0
  blend_pos0 = alpha * pos_np + (1.0 - alpha) * par_pos0
  blend_normal0 = alpha * normal_np + (1.0 - alpha) * par_normal0
  blend_normal0 = blend_normal0 / wp.max(wp.length(blend_normal0), 1e-8)

  # Contact 1: the non-parallel path has no second contact, so there is
  # nothing finite to blend with. Select instead of computing
  # (1 - alpha) * par_dist1 + alpha * inf, which is NaN for alpha == 0
  # (0 * inf) and +inf for any alpha > 0, i.e. never <= margin.
  blend_dist1 = wp.where(alpha < 1.0, par_dist1, wp.inf)

  if blend_dist0 <= margin:
    contact_dist[0] = blend_dist0
    contact_pos[0] = blend_pos0
    contact_normal[0] = blend_normal0

  if blend_dist1 <= margin:
    contact_dist[1] = blend_dist1
    contact_pos[1] = par_pos1
    contact_normal[1] = par_normal1

  return contact_dist, contact_pos, contact_normal
plane_pos, capsule_pos + segment, capsule_radius) + dist2, pos2 = smooth_plane_sphere(n, plane_pos, capsule_pos - segment, capsule_radius) + + dist = wp.vec2(dist1, dist2) + pos = mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]) + + return dist, pos, frame + + +@wp.func +def smooth_make_frame(normal: wp.vec3) -> wp.mat33: + """Construct contact frame from normal with smooth tangent directions.""" + a = normal / wp.max(wp.length(normal), 1e-8) + + # Gram-Schmidt orthogonalization (same as math.orthogonals but using + # wp.where instead of branching on a[1] for smoother gradients) + y = wp.vec3(0.0, 1.0, 0.0) + z = wp.vec3(0.0, 0.0, 1.0) + b = wp.where((-0.5 < a[1]) and (a[1] < 0.5), y, z) + b = b - a * wp.dot(a, b) + b_len = wp.length(b) + b = b / wp.max(b_len, 1e-8) + c = wp.cross(a, b) + + return wp.mat33( + a[0], + a[1], + a[2], + b[0], + b[1], + b[2], + c[0], + c[1], + c[2], + ) + + +# ============================================================================ +# Smooth contact recomputation kernel +# ============================================================================ + + +@wp.kernel +def _smooth_recompute_kernel( + # Model: + geom_type: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + # Data in: + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_worldid_in: wp.array(dtype=int), + contact_geomcollisionid_in: wp.array(dtype=int), + nacon_in: wp.array(dtype=int), + # Data out: + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), +): + cid = wp.tid() + + if cid >= nacon_in[0]: + return + + geoms = contact_geom_in[cid] + g1 = geoms[0] + g2 = geoms[1] + + # Skip flex contacts (geom id = -1) + if g1 < 0 or g2 < 0: + return + + worldid = contact_worldid_in[cid] + subcid = contact_geomcollisionid_in[cid] + t1 = geom_type[g1] + t2 
= geom_type[g2] + + # Geom poses (differentiable from Phase 1 kinematics) + pos1 = geom_xpos_in[worldid, g1] + pos2 = geom_xpos_in[worldid, g2] + mat1 = geom_xmat_in[worldid, g1] + mat2 = geom_xmat_in[worldid, g2] + + # Geom sizes (model constants — use worldid=0 for batched models) + size_id = worldid % geom_size.shape[0] + size1 = geom_size[size_id, g1] + size2 = geom_size[size_id, g2] + + # Dispatch based on geom type pair + # Geom types: PLANE=0, HFIELD=1, SPHERE=2, CAPSULE=3, ELLIPSOID=4, + # CYLINDER=5, BOX=6, MESH=7, SDF=8 + + handled = False + + # plane-sphere + if t1 == 0 and t2 == 2: + plane_normal = wp.vec3(mat1[0, 2], mat1[1, 2], mat1[2, 2]) + dist, pos = smooth_plane_sphere(plane_normal, pos1, pos2, size2[0]) + frame = smooth_make_frame(plane_normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # sphere-sphere + if not handled and t1 == 2 and t2 == 2: + dist, pos, normal = smooth_sphere_sphere(pos1, size1[0], pos2, size2[0]) + frame = smooth_make_frame(normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # sphere-capsule + if not handled and t1 == 2 and t2 == 3: + cap_axis = wp.vec3(mat2[0, 2], mat2[1, 2], mat2[2, 2]) + dist, pos, normal = smooth_sphere_capsule(pos1, size1[0], pos2, cap_axis, size2[0], size2[1]) + frame = smooth_make_frame(normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # capsule-capsule (2 contacts via geomcollisionid) + if not handled and t1 == 3 and t2 == 3: + cap1_axis = wp.vec3(mat1[0, 2], mat1[1, 2], mat1[2, 2]) + cap2_axis = wp.vec3(mat2[0, 2], mat2[1, 2], mat2[2, 2]) + dists, positions, normals = smooth_capsule_capsule( + pos1, + cap1_axis, + size1[0], + size1[1], + pos2, + cap2_axis, + size2[0], + size2[1], + 1e10, # large margin so we always compute both contacts + ) + if subcid == 0: + contact_dist_out[cid] = 
@wp.func
def compute_k_imp(
  # Model:
  opt_disableflags: int,
  # In:
  solref: wp.vec2,
  solimp: types.vec5,
  pos: float,
  timestep: float,
) -> wp.vec2:
  """Return stiffness and impedance for one constraint row as ``vec2(k, imp)``.

  Args:
    opt_disableflags: disable-flag bitmask; when the REFSAFE bit is not set the
      time constant is floored at ``2 * timestep``.
    solref: ``(timeconst, dampratio)``. When ``solref[0] <= 0`` the stiffness
      is ``-solref[0] / dmax**2`` instead of the time-constant formula.
    solimp: ``(dmin, dmax, width, mid, power)`` impedance parameters; each is
      clamped/floored before use.
    pos: constraint position; only ``abs(pos)`` enters the impedance curve.
    timestep: simulation timestep used for the REFSAFE floor.

  Returns:
    ``wp.vec2(k, imp)``.
  """
  # time constant, optionally floored at two timesteps
  time_const = solref[0]
  if not (opt_disableflags & DisableBit.REFSAFE):
    time_const = wp.max(time_const, 2.0 * timestep)
  damp_ratio = solref[1]

  # sanitize the impedance parameters
  imp_lo = wp.clamp(solimp[0], types.MJ_MINIMP, types.MJ_MAXIMP)
  imp_hi = wp.clamp(solimp[1], types.MJ_MINIMP, types.MJ_MAXIMP)
  imp_width = wp.max(MJ_MINVAL, solimp[2])
  imp_mid = wp.clamp(solimp[3], types.MJ_MINIMP, types.MJ_MAXIMP)
  imp_power = wp.max(1.0, solimp[4])

  # stiffness: time-constant form, or direct form when solref[0] is non-positive
  imp_hi_sq = imp_hi * imp_hi
  k_from_timeconst = 1.0 / (imp_hi_sq * time_const * time_const * damp_ratio * damp_ratio)
  k_direct = -solref[0] / imp_hi_sq
  stiffness = wp.where(solref[0] <= 0.0, k_direct, k_from_timeconst)

  # impedance: piecewise power curve in x = |pos| / width, split at `mid`
  x = wp.abs(pos) / imp_width
  power_m1 = imp_power - 1.0
  below_mid = (1.0 / wp.pow(imp_mid, power_m1)) * wp.pow(x, imp_power)
  above_mid = 1.0 - (1.0 / wp.pow(1.0 - imp_mid, power_m1)) * wp.pow(1.0 - x, imp_power)
  y = wp.where(x < imp_mid, below_mid, above_mid)
  impedance = wp.clamp(imp_lo + y * (imp_hi - imp_lo), imp_lo, imp_hi)
  # past the full width the impedance saturates at dmax
  impedance = wp.where(x > 1.0, imp_hi, impedance)

  return wp.vec2(stiffness, impedance)
(opt_disableflags & DisableBit.REFSAFE): + timeconst = wp.max(timeconst, 2.0 * timestep) + b = 2.0 / (dmax * timeconst) + b = wp.where(solref[1] <= 0.0, -solref[1] / dmax, b) + + D_out[worldid, efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, MJ_MINVAL) + vel_out[worldid, efcid] = vel + aref_out[worldid, efcid] = -k * imp * pos_aref - b * vel + pos_out[worldid, efcid] = pos_aref + margin + + +@wp.kernel +def _smooth_contact_to_efc_kernel( + # Model: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + opt_impratio_invsqrt: wp.array(dtype=float), + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + # Data in: + qvel_in: wp.array2d(dtype=float), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + contact_dist_in: wp.array(dtype=float), + contact_pos_in: wp.array(dtype=wp.vec3), + contact_frame_in: wp.array(dtype=wp.mat33), + contact_includemargin_in: wp.array(dtype=float), + contact_friction_in: wp.array(dtype=types.vec5), + contact_solref_in: wp.array(dtype=wp.vec2), + contact_solimp_in: wp.array(dtype=types.vec5), + contact_dim_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + njmax_in: int, + nacon_in: wp.array(dtype=int), + # Data out: + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), +): + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + # Only process constraint contacts + if not 
(contact_type_in[conid] & 1): # ContactType.CONSTRAINT = 1 + return + + condim = contact_dim_in[conid] + if condim == 1 and dimid > 0: + return + elif condim > 1 and dimid >= 2 * (condim - 1): + return + + # Read efc_address — skip if -1 (not active) + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + if efcid >= njmax_in: + return + + worldid = contact_worldid_in[conid] + timestep = opt_timestep[worldid % opt_timestep.shape[0]] + impratio_invsqrt = opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] + + geom = contact_geom_in[conid] + body1 = geom_bodyid[geom[0]] + body2 = geom_bodyid[geom[1]] + + con_pos = contact_pos_in[conid] + frame = contact_frame_in[conid] + includemargin = contact_includemargin_in[conid] + pos = contact_dist_in[conid] - includemargin + + # Pyramidal invweight computation + body_invweight0_id = worldid % body_invweight0.shape[0] + invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] + + fri0 = float(0.0) + frii = float(0.0) + dimid2 = int(0) + if condim > 1: + dimid2 = dimid / 2 + 1 + friction = contact_friction_in[conid] + fri0 = friction[0] + frii = friction[dimid2 - 1] + invweight = invweight + fri0 * fri0 * invweight + invweight = invweight * 2.0 * fri0 * fri0 * impratio_invsqrt * impratio_invsqrt + + Jqvel = float(0.0) + + # Skip fixed bodies + body1 = body_weldid[body1] + body2 = body_weldid[body2] + + da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) + da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) + + # Dense Jacobian computation (AD requires dense) + da = wp.max(da1, da2) + dofid = int(nv - 1) + + while True: + if dofid < 0: + break + + if dofid == da: + jac1p, jac1r = support.jac_dof( + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + con_pos, + body1, + dofid, + worldid, + ) + jac2p, jac2r = support.jac_dof( + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + con_pos, + body2, + 
def smooth_recompute_contacts(m: types.Model, d: types.Data):
  """Rewrite ``d.contact.{dist, pos, frame}`` via ``_smooth_recompute_kernel``.

  Launches one thread per contact slot (``d.naconmax``). Nothing is launched
  when there are no contact slots.

  Args:
    m: model; supplies geom type, body id and size arrays.
    d: data; supplies geom poses and contact bookkeeping, and owns the contact
      arrays that are overwritten.
  """
  contact_slots = d.naconmax
  if contact_slots != 0:
    contact = d.contact

    kernel_inputs = [
      # Model
      m.geom_type,
      m.geom_bodyid,
      m.geom_size,
      # Data in
      d.geom_xpos,
      d.geom_xmat,
      contact.geom,
      contact.worldid,
      contact.geomcollisionid,
      d.nacon,
    ]
    kernel_outputs = [
      contact.dist,
      contact.pos,
      contact.frame,
    ]

    wp.launch(
      _smooth_recompute_kernel,
      dim=contact_slots,
      inputs=kernel_inputs,
      outputs=kernel_outputs,
    )
m.opt.impratio_invsqrt, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.geom_bodyid, + # Data in + d.qvel, + d.subtree_com, + d.cdof, + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.includemargin, + d.contact.friction, + d.contact.solref, + d.contact.solimp, + d.contact.dim, + d.contact.geom, + d.contact.efc_address, + d.contact.worldid, + d.contact.type, + d.njmax, + d.nacon, + ], + outputs=[ + d.efc.J, + d.efc.pos, + d.efc.D, + d.efc.vel, + d.efc.aref, + ], + ) diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py index 575a769d2..5deb1efd2 100644 --- a/mujoco_warp/_src/derivative.py +++ b/mujoco_warp/_src/derivative.py @@ -15,6 +15,7 @@ import warp as wp +from mujoco_warp._src.support import next_act from mujoco_warp._src.types import BiasType from mujoco_warp._src.types import Data from mujoco_warp._src.types import DisableBit @@ -30,18 +31,24 @@ @wp.kernel def _qderiv_actuator_passive_vel( # Model: + opt_timestep: wp.array(dtype=float), actuator_dyntype: wp.array(dtype=int), actuator_gaintype: wp.array(dtype=int), actuator_biastype: wp.array(dtype=int), actuator_actadr: wp.array(dtype=int), actuator_actnum: wp.array(dtype=int), actuator_forcelimited: wp.array(dtype=bool), + actuator_actlimited: wp.array(dtype=bool), + actuator_dynprm: wp.array2d(dtype=vec10f), actuator_gainprm: wp.array2d(dtype=vec10f), actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_actearly: wp.array(dtype=bool), actuator_forcerange: wp.array2d(dtype=wp.vec2), + actuator_actrange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), + act_dot_in: wp.array2d(dtype=float), actuator_force_in: wp.array2d(dtype=float), # Out: vel_out: wp.array2d(dtype=float), @@ -76,9 +83,24 @@ def _qderiv_actuator_passive_vel( vel = float(bias) if actuator_dyntype[actid] != DynType.NONE: if gain != 0.0: - act_first = 
@wp.kernel
def _qderiv_actuator_passive_actuation_sparse(
  # Model:
  M_rownnz: wp.array(dtype=int),
  M_rowadr: wp.array(dtype=int),
  # Data in:
  moment_rownnz_in: wp.array2d(dtype=int),
  moment_rowadr_in: wp.array2d(dtype=int),
  moment_colind_in: wp.array2d(dtype=int),
  actuator_moment_in: wp.array2d(dtype=float),
  # In:
  vel_in: wp.array2d(dtype=float),
  qMj: wp.array(dtype=int),
  # Out:
  qDeriv_out: wp.array3d(dtype=float),
):
  """Accumulate one actuator's ``moment_i * moment_j * vel`` terms into sparse qDeriv.

  One thread handles one (world, actuator) pair. For every pair of non-zero
  entries (i, j <= i) in the actuator's moment row, the thread scans row
  ``dofi`` of the structure described by ``M_rowadr`` / ``M_rownnz`` for the
  element whose ``qMj`` entry equals ``dofj`` and atomically adds the
  contribution to ``qDeriv_out[worldid, 0, element]``. Pairs with no matching
  element are dropped without writing.

  NOTE(review): contributions are accumulated, not assigned, so the result is
  only meaningful if ``qDeriv_out`` holds defined values on entry — confirm
  that callers passing a ``wp.empty`` buffer have it initialized beforehand.
  """
  worldid, actid = wp.tid()

  vel = vel_in[worldid, actid]
  if vel == 0.0:
    return

  moment_nnz = moment_rownnz_in[worldid, actid]
  moment_adr = moment_rowadr_in[worldid, actid]

  for i in range(moment_nnz):
    adr_i = moment_adr + i
    moment_i = actuator_moment_in[worldid, adr_i]
    if moment_i != 0.0:
      dofi = moment_colind_in[worldid, adr_i]

      # j only runs up to i, so each unordered dof pair is visited once
      for j in range(i + 1):
        adr_j = moment_adr + j
        moment_j = actuator_moment_in[worldid, adr_j]
        if moment_j != 0.0:
          dofj = moment_colind_in[worldid, adr_j]
          contrib = moment_i * moment_j * vel

          # Linear search of row dofi for the element with column dofj
          # TODO: This could be precalculated for improved performance
          row_first = M_rowadr[dofi]
          row_count = M_rownnz[dofi]
          for k in range(row_count):
            elemid = row_first + k
            if qMj[elemid] == dofj:
              # atomic: other actuator threads may target the same element
              wp.atomic_add(qDeriv_out[worldid, 0], elemid, contrib)
              break
qMi, qMj], + outputs=[out], + ) wp.launch( _qderiv_actuator_passive, dim=(d.nworld, qMi.size), diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py index da4d4b3bc..cc4746d29 100644 --- a/mujoco_warp/_src/derivative_test.py +++ b/mujoco_warp/_src/derivative_test.py @@ -209,6 +209,261 @@ def test_step_tendon_serial_chain_no_nan(self): self.assertFalse(np.any(np.isnan(mjd.qpos))) self.assertFalse(np.any(np.isnan(mjd.qvel))) + def test_smooth_vel_sparse_tendon_coupled(self): + """Tests qDeriv kernel with nv > 32 and moment_rownnz > 1. + + Builds a chain of 35 DOFs (forcing sparse path) with a fixed tendon + coupling two joints, producing an actuator with moment_rownnz=2. + """ + # Build a chain long enough to force sparse (nv > 32) + xml = f""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + + mjm, mjd, m, d = test_data.fixture( + xml=xml, + keyframe=0, + overrides={"opt.jacobian": mujoco.mjtJacobian.mjJAC_SPARSE}, + ) + + self.assertTrue(m.is_sparse, "Model should use sparse path (nv > 32)") + + mujoco.mj_step(mjm, mjd) + + out_smooth_vel = wp.zeros((1, 1, m.nM), dtype=float) + mjw.deriv_smooth_vel(m, d, out_smooth_vel) + + mjw_out = np.zeros((m.nv, m.nv)) + for elem, (i, j) in enumerate(zip(m.qM_fullm_i.numpy(), m.qM_fullm_j.numpy())): + mjw_out[i, j] = out_smooth_vel.numpy()[0, 0, elem] + mjw_out[j, i] = out_smooth_vel.numpy()[0, 0, elem] + + mj_qDeriv = np.zeros((mjm.nv, mjm.nv)) + mujoco.mju_sparse2dense(mj_qDeriv, mjd.qDeriv, mjm.D_rownnz, mjm.D_rowadr, mjm.D_colind) + + mj_qM = np.zeros((m.nv, m.nv)) + mujoco.mj_fullM(mjm, mj_qM, mjd.qM) + mj_out = mj_qM - mjm.opt.timestep * mj_qDeriv + + self.assertFalse(np.any(np.isnan(mjw_out))) + _assert_eq(mjw_out, 
mj_out, "qM - dt * qDeriv (sparse tendon coupled)") + + def test_actearly_derivative(self): + """Implicit derivatives should use next activation when actearly is set.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + """, + keyframe=0, + ) + + # both should have same act_dot (ctrl = 1 for integrator dynamics) + _assert_eq(d.act_dot.numpy()[0, 0], d.act_dot.numpy()[0, 1], "act_dot") + + # compute qDeriv using deriv_smooth_vel + out_smooth_vel = wp.zeros(d.qM.shape, dtype=float) + mjw.deriv_smooth_vel(m, d, out_smooth_vel) + mjw_out = out_smooth_vel.numpy()[0, : m.nv, : m.nv] + + # with actearly=true and nonzero act_dot, derivative should differ + # because actearly uses next activation: act + act_dot*dt + # for our model: next_act = 0 + 1*1 = 1, current_act = 0 + # derivative adds gain_vel * act to qDeriv diagonal + # qDeriv = qM - dt * actuator_vel_derivative + # for independent bodies with mass=1: qM diagonal = 1.0 + # actearly=true: vel = gain_vel * next_act = 1 * 1 = 1, out = 1 - 1*1 = 0 + # actearly=false: vel = gain_vel * current_act = 1 * 0 = 0, out = 1 - 1*0 = 1 + self.assertNotAlmostEqual( + mjw_out[0, 0], + mjw_out[1, 1], + msg="actearly=true should use next activation in derivative", + ) + _assert_eq(mjw_out[0, 0], 0.0, "actearly=true: qM - dt*gain_vel*next_act = 1 - 1*1 = 0") + _assert_eq(mjw_out[1, 1], 1.0, "actearly=false: qM - dt*gain_vel*current_act = 1 - 1*0 = 1") + def test_forcerange_clamped_derivative(self): """Implicit integration is more accurate than Euler with active forcerange clamping.""" xml = """ diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 67114e07d..060e472b2 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -18,6 +18,7 @@ import warp as wp from mujoco_warp._src import collision_driver +from mujoco_warp._src import collision_smooth from mujoco_warp._src import constraint from mujoco_warp._src import derivative from mujoco_warp._src import island @@ -26,7 +27,9 @@ from 
mujoco_warp._src import sensor from mujoco_warp._src import smooth from mujoco_warp._src import solver +from mujoco_warp._src import support from mujoco_warp._src import util_misc +from mujoco_warp._src.support import next_act from mujoco_warp._src.support import xfrc_accumulate from mujoco_warp._src.types import MJ_MINVAL from mujoco_warp._src.types import BiasType @@ -128,37 +131,6 @@ def _next_velocity( qvel_out[worldid, dofid] = qvel_in[worldid, dofid] + qacc_scale_in * qacc_in[worldid, dofid] * timestep -# TODO(team): kernel analyzer array slice? -@wp.func -def _next_act( - # Model: - opt_timestep: float, # kernel_analyzer: ignore - actuator_dyntype: int, # kernel_analyzer: ignore - actuator_dynprm: vec10f, # kernel_analyzer: ignore - actuator_actrange: wp.vec2, # kernel_analyzer: ignore - # Data In: - act_in: float, # kernel_analyzer: ignore - act_dot_in: float, # kernel_analyzer: ignore - # In: - act_dot_scale: float, - clamp: bool, -) -> float: - # advance actuation - if actuator_dyntype == DynType.FILTEREXACT: - tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) - act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) - elif actuator_dyntype == DynType.USER: - return act_in - else: - act = act_in + act_dot_scale * act_dot_in * opt_timestep - - # clamp to actrange - if clamp: - act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) - - return act - - @wp.kernel def _next_activation( # Model: @@ -185,7 +157,7 @@ def _next_activation( actadr = actuator_actadr[uid] actnum = actuator_actnum[uid] for j in range(actadr, actadr + actnum): - act = _next_act( + act = next_act( opt_timestep[opt_timestep_id], actuator_dyntype[uid], actuator_dynprm[actuator_dynprm_id, uid], @@ -308,7 +280,15 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) outputs=[d.time], ) - wp.copy(d.qacc_warmstart, d.qacc) + # Use _nograd_copy: warmstart is a numerical hint, not a gradient path. 
@wp.kernel
def _euler_damp_qfrc_dense(
  # Model:
  opt_timestep: wp.array(dtype=float),
  dof_damping: wp.array2d(dtype=float),
  # Out:
  qM_integration_out: wp.array3d(dtype=float),
):
  """Add ``timestep * dof_damping`` to the diagonal of a dense per-world matrix.

  One thread handles one (world, dof) pair and updates the single diagonal
  entry ``qM_integration_out[worldid, dofid, dofid]`` in place.
  """
  worldid, dofid = wp.tid()

  # model arrays may have fewer rows than there are worlds; wrap the index
  dt = opt_timestep[worldid % opt_timestep.shape[0]]
  damping = dof_damping[worldid % dof_damping.shape[0], dofid]

  qM_integration_out[worldid, dofid, dofid] += dt * damping
+496,15 @@ def rungekutta4(m: Model, d: Data): A = [0.5, 0.5, 1.0] # diagonal only B = [1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0] + ad_active = d.qpos.requires_grad qpos_t0 = wp.clone(d.qpos) qvel_t0 = wp.clone(d.qvel) - qvel_rk = wp.zeros((d.nworld, m.nv), dtype=float) - qacc_rk = wp.zeros((d.nworld, m.nv), dtype=float) + qvel_rk = wp.zeros((d.nworld, m.nv), dtype=float, requires_grad=ad_active) + qacc_rk = wp.zeros((d.nworld, m.nv), dtype=float, requires_grad=ad_active) if m.na: act_t0 = wp.clone(d.act) - act_dot_rk = wp.zeros((d.nworld, m.na), dtype=float) + act_dot_rk = wp.zeros((d.nworld, m.na), dtype=float, requires_grad=ad_active) else: act_t0 = None act_dot_rk = None @@ -525,6 +524,7 @@ def rungekutta4(m: Model, d: Data): wp.copy(d.act, act_t0) wp.copy(d.act_dot, act_dot_rk) + _record_solver_adjoint(m, d, qacc_array=qacc_rk) _advance(m, d, qacc_rk, qvel_rk) @@ -532,6 +532,7 @@ def rungekutta4(m: Model, d: Data): def implicit(m: Model, d: Data): """Integrates fully implicit in velocity.""" if ~(m.opt.disableflags | ~(DisableBit.ACTUATION | DisableBit.SPRING | DisableBit.DAMPER)): + ad_active = d.qpos.requires_grad if m.is_sparse: qDeriv = wp.empty((d.nworld, 1, m.nM), dtype=float) qLD = wp.empty((d.nworld, 1, m.nC), dtype=float) @@ -540,10 +541,12 @@ def implicit(m: Model, d: Data): qLD = wp.empty(d.qM.shape, dtype=float) qLDiagInv = wp.empty((d.nworld, m.nv), dtype=float) derivative.deriv_smooth_vel(m, d, qDeriv) - qacc = wp.empty((d.nworld, m.nv), dtype=float) + qacc = wp.empty((d.nworld, m.nv), dtype=float, requires_grad=ad_active) smooth.factor_solve_i(m, d, qDeriv, qLD, qLDiagInv, qacc, d.efc.Ma) + _record_solver_adjoint(m, d, qacc_array=qacc) _advance(m, d, qacc) else: + _record_solver_adjoint(m, d, qacc_array=d.qacc) _advance(m, d, d.qacc) @@ -567,7 +570,15 @@ def fwd_position(m: Model, d: Data, factorize: bool = True): smooth.factor_m(m, d) if m.opt.run_collision_detection: collision_driver.collision(m, d) + # Phase 3: smooth collision 
def _record_fwd_accel_adjoint(m: Model, d: Data):
  """Record a custom adjoint for the M_inv solve in fwd_acceleration.

  When a tape is active and ``d.qpos.requires_grad`` is set, registers a
  callback on the tape that, during backward, computes

    qfrc_smooth.grad += M_inv * qacc_smooth.grad

  by calling ``smooth.solve_m``. The original author's rationale is that the
  dense tile Cholesky kernels are built with ``enable_backward=False``, so the
  tape cannot propagate this gradient by itself, and that M is symmetric so
  ``M_inv^T == M_inv``.

  ``d.qacc_smooth`` and ``d.qfrc_smooth`` are captured at record time, so the
  callback addresses the arrays that were current when this function ran even
  if ``d``'s attributes are rebound later.

  NOTE(review): the solve itself goes through ``smooth.solve_m(m, d, ...)``
  when the tape is replayed, so whatever factorization ``d`` holds at that
  moment is used — confirm it still matches the substep being differentiated.

  Args:
    m: Model containing static simulation parameters.
    d: Data containing mutable simulation state.
  """
  tape = wp._src.context.runtime.tape
  if tape is None or not d.qpos.requires_grad:
    return

  from mujoco_warp._src.adjoint import _accumulate_grad_kernel

  # Capture current array refs for correct gradient isolation across substeps
  qacc_smooth_ref = d.qacc_smooth
  qfrc_smooth_ref = d.qfrc_smooth

  def _adjoint(m=m, d=d, qacc_smooth=qacc_smooth_ref, qfrc_smooth=qfrc_smooth_ref):
    adj_qacc_smooth = qacc_smooth.grad
    if adj_qacc_smooth is None:
      return

    # Scratch buffer for M_inv * adj. zeros_like would otherwise inherit
    # requires_grad from qfrc_smooth and allocate a gradient array that is
    # never read, on every backward callback.
    tmp = wp.zeros_like(qfrc_smooth, requires_grad=False)
    smooth.solve_m(m, d, tmp, adj_qacc_smooth)

    if qfrc_smooth.grad is None:
      # no gradient buffer yet: adopt the result directly
      qfrc_smooth.grad = tmp
    else:
      # accumulate into the existing gradient
      wp.launch(
        _accumulate_grad_kernel,
        dim=(d.nworld, m.nv),
        inputs=[tmp],
        outputs=[qfrc_smooth.grad],
      )

  tape.record_func(_adjoint, [qacc_smooth_ref, qfrc_smooth_ref])
def _record_euler_damp_adjoint(m: Model, d: Data, qacc: wp.array):
  """Record the euler-damping adjoint transformation on the active tape.

  Does nothing unless a tape is active and ``d.qpos.requires_grad`` is set.

  During backward, the recorded callback replaces ``qacc.grad`` with the
  adjoint that accounts for the implicit damping solve:

    Forward: qacc_local = (M + dt*D)^{-1} * M * d.qacc
    Adjoint: adj_d_qacc = M * (M + dt*D)^{-1} * adj_qacc_local

  The callback is meant to run between the backward of ``_advance`` (which
  sets ``qacc.grad``) and the solver adjoint (which reads it); that ordering
  follows from the order in which the caller records them.

  Args:
    m: Model containing static simulation parameters.
    d: Data containing mutable simulation state.
    qacc: integrator-local acceleration array whose ``.grad`` is rewritten.
  """
  tape = wp._src.context.runtime.tape
  if tape is None or not d.qpos.requires_grad:
    return

  # Capture the mass matrix and qacc at record time so the callback uses the
  # arrays that were current for this substep even if d.qM is rebound later.
  qM_ref = d.qM
  qacc_ref = qacc

  def _adjoint(m=m, d=d, qM=qM_ref, qacc_arr=qacc_ref):
    adj_qacc = qacc_arr.grad
    if adj_qacc is None:
      return

    nv = m.nv

    # Step 1: Construct M_damp = M + dt*D.
    # The copy is scratch space: clone/zeros_like inherit requires_grad from
    # their source unless told otherwise, which would allocate gradient
    # buffers nobody reads on every backward callback.
    qM_damp = wp.clone(qM, requires_grad=False)
    if m.is_sparse:
      wp.launch(
        _euler_damp_qfrc_sparse,
        dim=(d.nworld, nv),
        inputs=[m.opt.timestep, m.dof_Madr, m.dof_damping],
        outputs=[qM_damp],
      )
    else:
      wp.launch(
        _euler_damp_qfrc_dense,
        dim=(d.nworld, nv),
        inputs=[m.opt.timestep, m.dof_damping],
        outputs=[qM_damp],
      )

    # Step 2: Solve (M + dt*D) * tmp = adj_qacc
    # d.qLD is used only as a shape/dtype template for the factor buffer
    qLD_tmp = wp.zeros_like(d.qLD, requires_grad=False)
    qLDiagInv_tmp = wp.zeros((d.nworld, nv), dtype=float)
    tmp = wp.zeros((d.nworld, nv), dtype=float)
    smooth.factor_solve_i(m, d, qM_damp, qLD_tmp, qLDiagInv_tmp, tmp, adj_qacc)

    # Step 3: result = M * tmp (using original undamped mass matrix)
    result = wp.zeros((d.nworld, nv), dtype=float)
    support.mul_m(m, d, result, tmp, M=qM)

    # Step 4: Overwrite (not accumulate into) qacc.grad with the corrected adjoint
    wp.copy(qacc_arr.grad, result)

  tape.record_func(_adjoint, [qacc_ref])
def _isolate_intermediates_for_ad(m: Model, d: Data):
  """Rebind the intermediate arrays of ``d`` to freshly allocated storage.

  When a single wp.Tape spans several step() calls, intermediates such as
  qfrc_smooth and qacc_smooth are overwritten every substep while sharing one
  .grad array, so backward accumulates the adjoints of ALL substeps into the
  same memory (~250,000x amplification for 16 substeps). Every array replaced
  here appears as an output of wp.launch in the pipeline; if shared across
  substeps, Warp's adjoint zeroes the output .grad during the later substep's
  backward, corrupting the earlier substep's gradient chain.

  Giving each step its own arrays lets the tape route each substep's adjoint
  to distinct .grad memory. The allocation overhead is ~25KB/step.

  Only called when AD is active on an active tape.

  Args:
    m: Model containing static simulation parameters.
    d: Data whose intermediate arrays are replaced in place.
  """
  nworld = d.nworld

  # Scalar (nworld, width) buffers allocated by explicit shape.
  explicit_width = (
    ("qfrc_smooth", m.nv),
    ("qacc_smooth", m.nv),
    ("qfrc_actuator", m.nv),
    ("actuator_force", m.nu),
    ("qacc", m.nv),
    ("qfrc_bias", m.nv),
    ("qfrc_passive", m.nv),
    ("qLDiagInv", m.nv),
    ("actuator_velocity", m.nu),
  )

  # Buffers using Warp vector/matrix dtypes, or whose shape depends on sparse
  # vs dense storage: zeros_like matches the dtype and shape of the old array.
  zeroed_like_existing = (
    "xpos",
    "xipos",
    "subtree_com",
    "cinert",
    "cdof",
    "cdof_dot",
    "cvel",
    "crb",
    "cacc",
    "qM",
    "qLD",
    "xanchor",
    "xaxis",
    "subtree_linvel",
    "subtree_angmom",
  )

  # Arrays holding rows that are initialized once and not recomputed (world
  # body, static world geoms): copy the contents instead of zeroing them.
  cloned = ("xmat", "ximat", "geom_xpos", "geom_xmat")

  for name, width in explicit_width:
    setattr(d, name, wp.zeros((nworld, width), dtype=float, requires_grad=True))

  for name in zeroed_like_existing:
    setattr(d, name, wp.zeros_like(getattr(d, name), requires_grad=True))

  for name in cloned:
    duplicate = wp.clone(getattr(d, name))
    duplicate.requires_grad = True
    setattr(d, name, duplicate)
sensor.sensor_acc(m, d) # TODO(team): mj_checkAcc diff --git a/mujoco_warp/_src/forward_test.py b/mujoco_warp/_src/forward_test.py index 3b6a9828b..7e031381b 100644 --- a/mujoco_warp/_src/forward_test.py +++ b/mujoco_warp/_src/forward_test.py @@ -40,7 +40,47 @@ def _assert_eq(a, b, name): np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) +_STATIC_WORLD_GEOM_XML = """ + + +""" + + class ForwardTest(parameterized.TestCase): + @parameterized.parameters(False, True) + def test_step_preserves_static_world_geom_transforms(self, use_tape): + """Static world geoms must keep their initialized transforms on diff data.""" + mjm = mujoco.MjModel.from_xml_string(_STATIC_WORLD_GEOM_XML) + m = mjw.put_model(mjm) + d = mjw.make_diff_data(mjm) + mjw.reset_data(m, d) + wp.synchronize() + + floor_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_GEOM, "floor") + eye = np.eye(3, dtype=np.float32) + _assert_eq(d.geom_xmat.numpy()[0, floor_id], eye, "geom_xmat init") + self.assertEqual(int(d.nacon.numpy()[0]), 0) + + if use_tape: + tape = wp.Tape() + with tape: + mjw.step(m, d) + tape.zero() + else: + mjw.step(m, d) + wp.synchronize() + + _assert_eq(d.geom_xmat.numpy()[0, floor_id], eye, "geom_xmat after step") + self.assertEqual(int(d.nacon.numpy()[0]), 0) + @parameterized.product(xml=["humanoid/humanoid.xml", "pendula.xml"]) def test_fwd_velocity(self, xml): _, mjd, m, d = test_data.fixture(xml, qvel_noise=0.01, ctrl_noise=0.1) @@ -649,6 +689,12 @@ def oscillator(m, d): np.testing.assert_allclose(d.act.numpy()[0, 0], np.cos(2 * np.pi * frequency * t_next), atol=1e-3) np.testing.assert_allclose(d.act.numpy()[0, 1], np.sin(2 * np.pi * frequency * t_next), atol=1e-3) + def test_multiflex(self): + """Tests multiflex model with different flex dimensions.""" + _, _, m, d = test_data.fixture("flex/multiflex.xml") + + mjw.forward(m, d) + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/grad.py b/mujoco_warp/_src/grad.py index e30fd0f30..3ed017065 100644 --- 
SOLVER_GRAD_FIELDS: tuple = ("qfrc_constraint",)

COLLISION_GRAD_FIELDS: tuple = (
  # Contact geometry (written by smooth_recompute_contacts)
  "contact.dist",
  "contact.pos",
  "contact.frame",
  # Constraint arrays (written by smooth_contact_to_efc)
  "efc.J",
  "efc.pos",
  "efc.D",
  "efc.aref",
  "efc.vel",
)


def _resolve_field(d: Data, name: str):
  """Resolve a field name against ``d``, following dotted paths.

  ``"qpos"`` resolves to ``d.qpos`` and ``"contact.dist"`` to
  ``d.contact.dist``. Any number of dotted components is followed.

  Args:
    d: Object the lookup starts from.
    name: Attribute name, optionally a dotted path.

  Returns:
    The resolved attribute, or None when a component along the path is missing
    or is itself None.
  """
  obj = d
  for part in name.split("."):
    obj = getattr(obj, part, None)
    # Test against None explicitly: an intermediate container must not be
    # skipped just because it evaluates falsy (defines __len__ or __bool__).
    if obj is None:
      return None
  return obj


def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
  """Enables gradient tracking on Data arrays.

  Args:
    d: Data whose arrays are updated in place.
    fields: Field names (dotted paths allowed) to enable. Defaults to
      SMOOTH_GRAD_FIELDS. Names that do not resolve to a wp.array are ignored.
  """
  if fields is None:
    fields = SMOOTH_GRAD_FIELDS
  for name in fields:
    arr = _resolve_field(d, name)
    if arr is not None and isinstance(arr, wp.array):
      arr.requires_grad = True


def disable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
  """Disables gradient tracking on Data arrays.

  Args:
    d: Data whose arrays are updated in place.
    fields: Field names (dotted paths allowed) to disable. Defaults to the
      smooth, solver and collision field sets combined. Names that do not
      resolve to a wp.array are ignored.
  """
  if fields is None:
    fields = SMOOTH_GRAD_FIELDS + SOLVER_GRAD_FIELDS + COLLISION_GRAD_FIELDS
  for name in fields:
    arr = _resolve_field(d, name)
    if arr is not None and isinstance(arr, wp.array):
      arr.requires_grad = False


def enable_smooth_adjoint(
  d: Data,
  friction_viscosity: float = 10.0,
  friction_scale: float = 0.01,
  friction_bypass_kf: float = 0.0,
  free_body_adjoint: bool = False,
  penalty_damping_alpha: float = 0.0,
  friction_surrogate_adjoint: bool = False,
  friction_surrogate_alpha: float = 0.0,
) -> None:
  """Enable smooth constraint adjoint for friction gradient signal.

  Modifies the backward pass to build a smooth Hessian where friction
  constraint stiffness is reduced (for active/QUADRATIC constraints) and
  a viscous friction term is added (for satisfied/static constraints).
  The forward physics is unchanged.

  This function only stores the settings as ``smooth_*`` attributes on ``d``
  and sets ``d.smooth_adjoint = 1``; no argument validation is performed.

  Args:
    d: Data object (must have gradient tracking enabled).
    friction_viscosity: D value added for SATISFIED friction constraints.
      Higher values give stronger gradient signal at zero velocity.
    friction_scale: Scale factor for QUADRATIC friction constraint D in
      the adjoint Hessian. Lower values reduce friction stiffness more,
      giving larger tangential gradients.
    friction_bypass_kf: Scale for friction gradient bypass. After the
      Hessian solve, restores tangential gradient components that were
      attenuated by H^{-1}. 0=off, 1=full bypass, >1=amplified.
    free_body_adjoint: When True, replaces the solver adjoint entirely
      with v = M^{-1} * adj_qacc (free-body assumption). Eliminates
      all constraint attenuation. Overrides friction_scale/bypass_kf.
    penalty_damping_alpha: Friction damping factor for penalty-model
      adjoint. Attenuates v in friction directions by (1-alpha) per
      face, mimicking dflex's bounded BPTT eigenvalues. Implies
      free-body base (M^{-1}). 0=off, 0.1-0.3=typical.
    friction_surrogate_adjoint: When True, keeps the smooth/Newton solve
      as the baseline but replaces friction-face backward projections
      with a damped tangential recovery toward the free-body solution.
      This preserves solver-informed normal-contact handling while using
      a training-oriented surrogate in tangential directions.
    friction_surrogate_alpha: Tangential damping factor for the friction
      surrogate branch. 0=full tangential recovery, 0.9=10% recovery,
      1=disabled. Values in 0.8-0.95 are the intended range for
      soft-contact ant experiments.
  """
  d.smooth_adjoint = 1
  d.smooth_friction_viscosity = friction_viscosity
  d.smooth_friction_scale = friction_scale
  d.smooth_friction_bypass_kf = friction_bypass_kf
  d.smooth_free_body_adjoint = free_body_adjoint
  d.smooth_penalty_damping_alpha = penalty_damping_alpha
  d.smooth_friction_surrogate_adjoint = friction_surrogate_adjoint
  d.smooth_friction_surrogate_alpha = friction_surrogate_alpha


def disable_smooth_adjoint(d: Data) -> None:
  """Disable smooth constraint adjoint, reverting to standard implicit diff.

  Only the ``smooth_adjoint`` flag is cleared; the other ``smooth_*`` settings
  stored by enable_smooth_adjoint() are left on ``d``.
  """
  d.smooth_adjoint = 0


def _warn_if_cg_solver(m: Model, d: Data):
  """Warn if CG solver is used with constraints (gradients will be zero)."""
  if d.njmax > 0 and m.opt.solver != SolverType.NEWTON:
    warnings.warn(
      "Differentiable solver requires Newton. CG solver gradients through constraints will be zero.",
      stacklevel=3,
    )
def _assert_step_ctrl_grad(
  test_case,
  xml,
  loss_on="qpos",
  keyframe=0,
  atol=_FD_TOL,
  rtol=_FD_TOL,
  eps=1e-3,
  err_msg="AD vs FD mismatch",
):
  """Compare AD dL/dctrl through step() against finite differences.

  The loss is the sum of the selected field after a single mjw.step(). The AD
  gradient is read from ``d.ctrl.grad`` for world 0 and checked to be nonzero
  before being compared with a central-difference estimate.

  Args:
    test_case: TestCase used for the nonzero-gradient assertion.
    xml: Model XML passed to test_data.fixture.
    loss_on: Field the loss sums over, either "qpos" or "xpos".
    keyframe: Keyframe passed to the fixture; None omits the argument.
    atol: Absolute tolerance for the AD vs FD comparison.
    rtol: Relative tolerance for the AD vs FD comparison.
    eps: Finite-difference step size.
    err_msg: Message attached to the AD vs FD comparison failure.

  Raises:
    ValueError: If ``loss_on`` is not "qpos" or "xpos".
  """
  # Reject unknown selectors up front; otherwise a typo would silently test xpos.
  if loss_on not in ("qpos", "xpos"):
    raise ValueError(f"loss_on must be 'qpos' or 'xpos', got {loss_on!r}")

  fixture_kw = {"xml": xml}
  if keyframe is not None:
    fixture_kw["keyframe"] = keyframe
  mjm, mjd, m, d = test_data.fixture(**fixture_kw)
  enable_grad(d)

  if loss_on == "qpos":
    loss_kernel, loss_dim = _sum_qpos_kernel, (d.nworld, mjm.nq)
  else:
    loss_kernel, loss_dim = _sum_xpos_kernel, (d.nworld, m.nbody)

  # AD gradient
  loss = wp.zeros(1, dtype=float, requires_grad=True)
  tape = wp.Tape()
  with tape:
    mjw.step(m, d)
    wp.launch(loss_kernel, dim=loss_dim, inputs=[getattr(d, loss_on), loss])
  tape.backward(loss=loss)
  # Copy before tape.zero() resets the gradient buffers.
  ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy()
  tape.zero()

  # Finite-difference gradient: every evaluation starts from a fresh fixture.
  def eval_loss(ctrl_np):
    _, _, _, d_fd = test_data.fixture(**fixture_kw)
    d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float)
    mjw.step(m, d_fd)
    fd_loss = wp.zeros(1, dtype=float)
    wp.launch(loss_kernel, dim=loss_dim, inputs=[getattr(d_fd, loss_on), fd_loss])
    return fd_loss.numpy()[0]

  fd_grad = _fd_gradient(eval_loss, mjd.ctrl.copy(), eps=eps)

  ad_norm = np.linalg.norm(ad_grad)
  test_case.assertGreater(
    ad_norm,
    1e-6,
    f"AD gradient should be nonzero, got |grad|={ad_norm:.3e}",
  )
  np.testing.assert_allclose(ad_grad, fd_grad, atol=atol, rtol=rtol, err_msg=err_msg)
test_kinematics_grad(self, name, xml): """dL/dqpos through kinematics(): loss = sum(xpos).""" @@ -321,52 +403,15 @@ def eval_loss(ctrl_np): err_msg=f"fwd_actuation grad mismatch ({name})", ) - @absltest.skipIf( - wp.get_device().is_cuda and wp.get_device().arch < 70, - "tile kernels (cuSolverDx) require sm_70+", - ) + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) def test_euler_step_grad(self): """Full Euler step gradient: dL/dctrl through step().""" - xml = _SIMPLE_HINGE_XML - mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) - enable_grad(d) - - loss = wp.zeros(1, dtype=float, requires_grad=True) - tape = wp.Tape() - with tape: - mjw.step(m, d) - wp.launch( - _sum_xpos_kernel, - dim=(d.nworld, m.nbody), - inputs=[d.xpos, loss], - ) - tape.backward(loss=loss) - ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() - tape.zero() - - def eval_loss(ctrl_np): - _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) - enable_grad(d_fd) - d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) - mjw.step(m, d_fd) - l = wp.zeros(1, dtype=float) - wp.launch( - _sum_xpos_kernel, - dim=(d_fd.nworld, m.nbody), - inputs=[d_fd.xpos, l], - ) - return l.numpy()[0] + _assert_step_ctrl_grad(self, _SIMPLE_HINGE_XML, loss_on="xpos", err_msg="euler step grad mismatch") - ctrl_np = mjd.ctrl.copy() - fd_grad = _fd_gradient(eval_loss, ctrl_np) - - np.testing.assert_allclose( - ad_grad, - fd_grad, - atol=_FD_TOL, - rtol=_FD_TOL, - err_msg="euler step grad mismatch", - ) + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_euler_step_grad_free(self): + """Full Euler step gradient for freejoint + hinge model: dL/dctrl.""" + _assert_step_ctrl_grad(self, _FREE_HINGE_XML, loss_on="xpos", err_msg="euler step grad (freejoint+hinge) mismatch") @wp.kernel @@ -555,13 +600,35 @@ def eval_loss_q(q_test): """ +_CONTACT_TANGENTIAL_XML = """ + + +""" + # Tolerance for contact AD tests (relaxed for contacts) _CONTACT_FD_TOL = 1e-2 @wp.kernel def _sum_qpos_kernel( + # Data in: 
qpos_in: wp.array2d(dtype=float), + # In: loss: wp.array(dtype=float), ): worldid, qid = wp.tid() @@ -569,17 +636,28 @@ def _sum_qpos_kernel( class GradSolverAdjointTest(parameterized.TestCase): - @absltest.skipIf( - wp.get_device().is_cuda and wp.get_device().arch < 70, - "tile kernels (cuSolverDx) require sm_70+", - ) - def test_solver_adjoint_contact_step(self): - """dL/dctrl through step() with active contacts (Newton solver).""" - xml = _CONTACT_SLIDE_XML - mjm, mjd, m, d = test_data.fixture(xml=xml) + def _step_ctrl_grad_norm(self, xml, smooth_kwargs, settle_steps=60): + """Return ||d(sum(qpos_next))/d(ctrl)|| on a settled contact state.""" + mjm, _, m, d = test_data.fixture(xml=xml, keyframe=0) enable_grad(d) + mjw.enable_smooth_adjoint(d, **smooth_kwargs) + + # Settle forward-only to get a representative contact state. + for _ in range(settle_steps): + mjw.step(m, d) + + qpos_settled = wp.clone(d.qpos) + qvel_settled = wp.clone(d.qvel) + ctrl_settled = wp.clone(d.ctrl) + + d = mjw.make_diff_data(mjm) + enable_grad(d) + mjw.reset_data(m, d) + wp.copy(d.qpos, qpos_settled) + wp.copy(d.qvel, qvel_settled) + wp.copy(d.ctrl, ctrl_settled) + mjw.enable_smooth_adjoint(d, **smooth_kwargs) - # AD gradient loss = wp.zeros(1, dtype=float, requires_grad=True) tape = wp.Tape() with tape: @@ -590,38 +668,24 @@ def test_solver_adjoint_contact_step(self): inputs=[d.qpos, loss], ) tape.backward(loss=loss) - ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() tape.zero() + return float(np.linalg.norm(grad)) - # Finite-difference gradient - def eval_loss(ctrl_np): - _, _, _, d_fd = test_data.fixture(xml=xml) - enable_grad(d_fd) - d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) - mjw.step(m, d_fd) - l = wp.zeros(1, dtype=float) - wp.launch( - _sum_qpos_kernel, - dim=(d_fd.nworld, mjm.nq), - inputs=[d_fd.qpos, l], - ) - return l.numpy()[0] - - ctrl_np = mjd.ctrl.copy() - fd_grad = _fd_gradient(eval_loss, ctrl_np, 
eps=1e-3) - - np.testing.assert_allclose( - ad_grad, - fd_grad, + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_solver_adjoint_contact_step(self): + """dL/dctrl through step() with active contacts (Newton solver).""" + _assert_step_ctrl_grad( + self, + _CONTACT_SLIDE_XML, + loss_on="qpos", + keyframe=None, atol=_CONTACT_FD_TOL, rtol=_CONTACT_FD_TOL, err_msg="solver adjoint contact step grad mismatch", ) - @absltest.skipIf( - wp.get_device().is_cuda and wp.get_device().arch < 70, - "tile kernels (cuSolverDx) require sm_70+", - ) + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) def test_solver_adjoint_no_active_constraints(self): """No active contacts: solver adjoint should match Phase 1 (unconstrained).""" # Ball high above ground — no contact @@ -640,140 +704,59 @@ def test_solver_adjoint_no_active_constraints(self): """ - mjm, mjd, m, d = test_data.fixture(xml=xml) - enable_grad(d) + _assert_step_ctrl_grad(self, xml, loss_on="qpos", keyframe=None, err_msg="solver adjoint no-contact grad mismatch") - loss = wp.zeros(1, dtype=float, requires_grad=True) - tape = wp.Tape() - with tape: - mjw.step(m, d) - wp.launch( - _sum_qpos_kernel, - dim=(d.nworld, mjm.nq), - inputs=[d.qpos, loss], - ) - tape.backward(loss=loss) - ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() - tape.zero() - - def eval_loss(ctrl_np): - _, _, _, d_fd = test_data.fixture(xml=xml) - enable_grad(d_fd) - d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) - mjw.step(m, d_fd) - l = wp.zeros(1, dtype=float) - wp.launch( - _sum_qpos_kernel, - dim=(d_fd.nworld, mjm.nq), - inputs=[d_fd.qpos, l], - ) - return l.numpy()[0] - - ctrl_np = mjd.ctrl.copy() - fd_grad = _fd_gradient(eval_loss, ctrl_np, eps=1e-3) - - np.testing.assert_allclose( - ad_grad, - fd_grad, - atol=_FD_TOL, - rtol=_FD_TOL, - err_msg="solver adjoint no-contact grad mismatch", - ) - - @absltest.skipIf( - wp.get_device().is_cuda and wp.get_device().arch < 70, - "tile kernels (cuSolverDx) require sm_70+", - ) + 
@absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) def test_solver_adjoint_identity_unconstrained(self): """njmax==0 (constraints disabled): identity pass-through.""" - xml = _SIMPLE_HINGE_XML # has contact/constraint disabled - mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) - enable_grad(d) - - loss = wp.zeros(1, dtype=float, requires_grad=True) - tape = wp.Tape() - with tape: - mjw.step(m, d) - wp.launch( - _sum_xpos_kernel, - dim=(d.nworld, m.nbody), - inputs=[d.xpos, loss], - ) - tape.backward(loss=loss) - ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() - tape.zero() - - def eval_loss(ctrl_np): - _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) - enable_grad(d_fd) - d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) - mjw.step(m, d_fd) - l = wp.zeros(1, dtype=float) - wp.launch( - _sum_xpos_kernel, - dim=(d_fd.nworld, m.nbody), - inputs=[d_fd.xpos, l], - ) - return l.numpy()[0] - - ctrl_np = mjd.ctrl.copy() - fd_grad = _fd_gradient(eval_loss, ctrl_np) - - np.testing.assert_allclose( - ad_grad, - fd_grad, - atol=_FD_TOL, - rtol=_FD_TOL, + _assert_step_ctrl_grad( + self, + _SIMPLE_HINGE_XML, + loss_on="xpos", err_msg="solver adjoint identity (unconstrained) grad mismatch", ) - @absltest.skipIf( - wp.get_device().is_cuda and wp.get_device().arch < 70, - "tile kernels (cuSolverDx) require sm_70+", - ) + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) def test_solver_adjoint_dense_jacobian(self): """Dense jacobian contact model: dL/dctrl through step().""" - xml = _CONTACT_SLIDE_DENSE_XML - mjm, mjd, m, d = test_data.fixture(xml=xml) - enable_grad(d) - - loss = wp.zeros(1, dtype=float, requires_grad=True) - tape = wp.Tape() - with tape: - mjw.step(m, d) - wp.launch( - _sum_qpos_kernel, - dim=(d.nworld, mjm.nq), - inputs=[d.qpos, loss], - ) - tape.backward(loss=loss) - ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() - tape.zero() - - def eval_loss(ctrl_np): - _, _, _, d_fd = test_data.fixture(xml=xml) - enable_grad(d_fd) - d_fd.ctrl 
= wp.array(ctrl_np.reshape(1, -1), dtype=float) - mjw.step(m, d_fd) - l = wp.zeros(1, dtype=float) - wp.launch( - _sum_qpos_kernel, - dim=(d_fd.nworld, mjm.nq), - inputs=[d_fd.qpos, l], - ) - return l.numpy()[0] - - ctrl_np = mjd.ctrl.copy() - fd_grad = _fd_gradient(eval_loss, ctrl_np, eps=1e-3) - - np.testing.assert_allclose( - ad_grad, - fd_grad, + _assert_step_ctrl_grad( + self, + _CONTACT_SLIDE_DENSE_XML, + loss_on="qpos", + keyframe=None, atol=_CONTACT_FD_TOL, rtol=_CONTACT_FD_TOL, err_msg="solver adjoint dense jacobian grad mismatch", ) + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_surrogate_correction_bounded_relative_to_free_body(self): + """Surrogate tangential correction should stay bounded vs free-body.""" + grad_free = self._step_ctrl_grad_norm( + _CONTACT_TANGENTIAL_XML, + smooth_kwargs=dict( + free_body_adjoint=True, + ), + ) + grad_sur_90 = self._step_ctrl_grad_norm( + _CONTACT_TANGENTIAL_XML, + smooth_kwargs=dict( + friction_surrogate_adjoint=True, + friction_surrogate_alpha=0.9, + ), + ) + grad_sur_99 = self._step_ctrl_grad_norm( + _CONTACT_TANGENTIAL_XML, + smooth_kwargs=dict( + friction_surrogate_adjoint=True, + friction_surrogate_alpha=0.99, + ), + ) + + self.assertGreater(grad_free, 1.0e-6) + self.assertLessEqual(grad_sur_90, grad_free * 1.05) + self.assertLessEqual(grad_sur_99, grad_free * 1.05) + class GradUtilTest(absltest.TestCase): def test_enable_disable_grad(self): @@ -813,6 +796,287 @@ def test_make_diff_data_custom_fields(self): self.assertFalse(d.qvel.requires_grad) self.assertFalse(d.ctrl.requires_grad) + def test_enable_backward_module_flags(self): + """Verify enable_backward is set correctly on all AD-relevant modules.""" + from mujoco_warp._src import collision_smooth + from mujoco_warp._src import derivative + from mujoco_warp._src import forward as forward_mod + from mujoco_warp._src import passive + from mujoco_warp._src import smooth + + # Modules that SHOULD have enable_backward=True + for mod in [smooth, 
forward_mod, passive, derivative, collision_smooth]: + opts = wp.get_module_options(mod) + self.assertTrue( + opts.get("enable_backward", False), + f"{mod.__name__} should have enable_backward=True", + ) + + # Modules that should NOT have enable_backward + from mujoco_warp._src import collision_driver + from mujoco_warp._src import constraint + from mujoco_warp._src import solver + + for mod in [constraint, solver, collision_driver]: + opts = wp.get_module_options(mod) + self.assertFalse( + opts.get("enable_backward", False), + f"{mod.__name__} should have enable_backward=False", + ) + + def test_enable_grad_all_smooth_fields(self): + """All SMOOTH_GRAD_FIELDS are toggled by enable_grad.""" + mjm = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML) + d = mjw.make_data(mjm) + + mjw.enable_grad(d) + for name in mjw.SMOOTH_GRAD_FIELDS: + arr = _resolve_field(d, name) + if arr is not None and isinstance(arr, wp.array): + self.assertTrue( + arr.requires_grad, + f"SMOOTH_GRAD_FIELDS field '{name}' not enabled by enable_grad", + ) + + mjw.disable_grad(d) + for name in mjw.SMOOTH_GRAD_FIELDS: + arr = _resolve_field(d, name) + if arr is not None and isinstance(arr, wp.array): + self.assertFalse( + arr.requires_grad, + f"SMOOTH_GRAD_FIELDS field '{name}' not disabled by disable_grad", + ) + + def test_forward_without_grad_no_error(self): + """Forward pipeline without enable_grad works (no errors, no gradients).""" + mjm, mjd, m, d = test_data.fixture(xml=_SIMPLE_HINGE_XML, keyframe=0) + # Do NOT call enable_grad + mjw.kinematics(m, d) + mjw.com_pos(m, d) + mjw.crb(m, d) + + # Verify no requires_grad is set + self.assertFalse(d.qpos.requires_grad) + self.assertFalse(d.xpos.requires_grad) + + def test_diff_step_produces_nonzero_gradients(self): + """diff_step with enable_grad produces nonzero gradients.""" + mjm, mjd, m, d = test_data.fixture(xml=_SIMPLE_HINGE_XML, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with 
tape: + mjw.kinematics(m, d) + mjw.com_pos(m, d) + wp.launch( + _sum_xpos_kernel, + dim=(d.nworld, m.nbody), + inputs=[d.xpos, loss], + ) + tape.backward(loss=loss) + + ad_grad = d.qpos.grad.numpy()[0, : mjm.nq] + # With a non-zero keyframe, kinematics gradients should be nonzero + self.assertTrue( + np.any(np.abs(ad_grad) > 1e-6), + "enable_grad + tape should produce nonzero gradients", + ) + + +# ---- Test models for integrator gradient path ---- + +_HINGE_EULERDAMP_DISABLED_XML = """ + + + + + + + + + + + + + + + + + + + + +""" + +_HINGE_EULERDAMP_ENABLED_XML = """ + + +""" + + +class GradIntegratorTest(parameterized.TestCase): + """Tests that exercise the gradient path through the integrator. + + Unlike test_euler_step_grad (which uses loss on xpos and bypasses the + integrator), these tests use loss on qpos after step(), verifying that + gradients flow through: ctrl -> actuation -> acceleration -> solver adjoint + -> integrator -> qpos. + """ + + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_euler_qpos_grad_no_eulerdamp(self): + """dL/dctrl through step() measured on qpos, eulerdamp disabled.""" + _assert_step_ctrl_grad( + self, + _HINGE_EULERDAMP_DISABLED_XML, + loss_on="qpos", + err_msg="AD vs FD mismatch for dL(qpos)/dctrl (eulerdamp disabled)", + ) + + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_euler_qpos_grad_with_eulerdamp(self): + """dL/dctrl through step() measured on qpos, eulerdamp enabled.""" + _assert_step_ctrl_grad( + self, + _HINGE_EULERDAMP_ENABLED_XML, + loss_on="qpos", + err_msg="AD vs FD mismatch for dL(qpos)/dctrl (eulerdamp enabled)", + ) + + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + def test_multistep_qpos_grad_nonzero(self): + """dL/dctrl through 2 steps produces nonzero gradient.""" + xml = _HINGE_EULERDAMP_DISABLED_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + 
mjw.step(m, d) + mjw.step(m, d) + wp.launch( + _sum_qpos_kernel, + dim=(d.nworld, mjm.nq), + inputs=[d.qpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + # Multi-step AD vs FD accuracy is limited by shared-array accumulation + # across steps (a known Warp tape limitation). Here we just verify the + # gradient is nonzero — single-step FD accuracy is tested above. + self.assertTrue( + np.linalg.norm(ad_grad) > 1e-6, + f"Multi-step AD gradient should be nonzero, got |grad|={np.linalg.norm(ad_grad):.3e}", + ) + + +_HINGE_EULERDAMP_HIGH_DAMPING_SPARSE_XML = """ + + +""" + +_HINGE_EULERDAMP_HIGH_DAMPING_DENSE_XML = """ + + +""" + +_HINGE_EULERDAMP_ENABLED_DENSE_XML = """ + + +""" + + +class GradEulerDampStressTest(parameterized.TestCase): + """Stress tests for the euler damping adjoint with high damping and dense jacobian.""" + + @absltest.skipIf(_REQUIRES_GPU, _REQUIRES_GPU_REASON) + @parameterized.named_parameters( + ("high_damp_sparse", _HINGE_EULERDAMP_HIGH_DAMPING_SPARSE_XML), + ("high_damp_dense", _HINGE_EULERDAMP_HIGH_DAMPING_DENSE_XML), + ("normal_damp_dense", _HINGE_EULERDAMP_ENABLED_DENSE_XML), + ) + def test_euler_damp_adjoint(self, xml): + """dL/dctrl through step() with eulerdamp enabled, AD matches FD.""" + _assert_step_ctrl_grad(self, xml, loss_on="qpos", err_msg="AD vs FD mismatch for euler damp adjoint") + if __name__ == "__main__": absltest.main() diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index aca9c94fe..11924efa2 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -669,6 +669,167 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> return int(valid_sizes[np.searchsorted(valid_sizes, njmax)]) +def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int: + """Returns the number of unique DOFs in the kinematic tree union of two bodies.""" + body1 = mjm.body_weldid[body1] + body2 = mjm.body_weldid[body2] + da1 = 
def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int:
  """Returns the number of unique DOFs in the kinematic tree union of two bodies.

  Both bodies are first mapped through ``body_weldid``. The two DOF ancestor
  chains are then walked towards the root via ``dof_parentid``; a DOF that
  lies on both chains is counted once.
  """
  body1 = mjm.body_weldid[body1]
  body2 = mjm.body_weldid[body2]
  # Last DOF of each body; negative when the body has no DOFs.
  da1 = mjm.body_dofadr[body1] + mjm.body_dofnum[body1] - 1
  da2 = mjm.body_dofadr[body2] + mjm.body_dofnum[body2] - 1
  nnz = 0
  while da1 >= 0 or da2 >= 0:
    # Always advance the chain(s) sitting at the larger DOF index, so a DOF
    # shared by both chains is consumed by both in the same iteration.
    da = max(da1, da2)
    if da1 == da:
      da1 = mjm.dof_parentid[da1]
    if da2 == da:
      da2 = mjm.dof_parentid[da2]
    nnz += 1
  return nnz


def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int:
  """Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian.

  Assumes all equality, friction, and limit constraints are active and computes
  their non-zeros. For contacts, assumes njmax contact rows at the maximum
  body-pair non-zeros from all enabled collision pairs.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nconmax: Maximum number of contacts per world. Not used by the estimate.
    njmax: Maximum number of constraint rows per world.

  Returns:
    Estimated number of non-zeros in the constraint Jacobian, at least 1 and
    capped at ``njmax * mjm.nv``.
  """
  total_nnz = 0

  def _eq_bodies(i):
    """Returns body pair for equality constraint i."""
    obj1id, obj2id = mjm.eq_obj1id[i], mjm.eq_obj2id[i]
    if mjm.eq_objtype[i] == mujoco.mjtObj.mjOBJ_SITE:
      return mjm.site_bodyid[obj1id], mjm.site_bodyid[obj2id]
    return obj1id, obj2id

  # equality constraints (assume all active)
  for i in range(mjm.neq):
    eq_type = mjm.eq_type[i]

    if eq_type == mujoco.mjtEq.mjEQ_CONNECT:
      total_nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(i))

    elif eq_type == mujoco.mjtEq.mjEQ_WELD:
      total_nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(i))

    elif eq_type == mujoco.mjtEq.mjEQ_JOINT:
      total_nnz += 2 if mjm.eq_obj2id[i] >= 0 else 1

    elif eq_type == mujoco.mjtEq.mjEQ_TENDON:
      obj1id = mjm.eq_obj1id[i]
      obj2id = mjm.eq_obj2id[i]
      rownnz1 = mjm.ten_J_rownnz[obj1id] if obj1id < mjm.ntendon else 0
      if obj2id >= 0 and obj2id < mjm.ntendon:
        # Two tendons: count the union of the columns touched by either row.
        rowadr1 = mjm.ten_J_rowadr[obj1id]
        rowadr2 = mjm.ten_J_rowadr[obj2id]
        rownnz2 = mjm.ten_J_rownnz[obj2id]
        cols = set()
        for j in range(rownnz1):
          cols.add(mjm.ten_J_colind[rowadr1 + j])
        for j in range(rownnz2):
          cols.add(mjm.ten_J_colind[rowadr2 + j])
        total_nnz += len(cols)
      else:
        total_nnz += rownnz1

    elif eq_type == mujoco.mjtEq.mjEQ_FLEX:
      obj1id = mjm.eq_obj1id[i]
      if obj1id < mjm.nflex:
        edge_start = mjm.flex_edgeadr[obj1id]
        edge_count = mjm.flex_edgenum[obj1id]
        for e in range(edge_count):
          total_nnz += mjm.flexedge_J_rownnz[edge_start + e]

  # friction constraints
  total_nnz += int((mjm.dof_frictionloss > 0).sum())
  for i in range(mjm.ntendon):
    if mjm.tendon_frictionloss[i] > 0:
      total_nnz += mjm.ten_J_rownnz[i]

  # limit constraints (assume all active)
  for i in range(mjm.njnt):
    if mjm.jnt_limited[i]:
      jnt_type = mjm.jnt_type[i]
      if jnt_type == mujoco.mjtJoint.mjJNT_BALL:
        total_nnz += 3
      elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE):
        total_nnz += 1
  for i in range(mjm.ntendon):
    if mjm.tendon_limited[i]:
      total_nnz += mjm.ten_J_rownnz[i]

  # contact constraints: njmax rows at max body-pair non-zeros
  max_contact_nnz = 0

  # contact pairs
  for i in range(mjm.npair):
    g1, g2 = mjm.pair_geom1[i], mjm.pair_geom2[i]
    b1, b2 = mjm.geom_bodyid[g1], mjm.geom_bodyid[g2]
    max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, b1, b2))

  # Per-body OR of the geoms' contype / conaffinity bitmasks. Some geom pair
  # across bodies A and B passes the (contype & conaffinity) filter exactly when
  # the OR-ed masks of A and B pass it, so body pairs can be filtered directly
  # instead of scanning every geom pair.
  body_contype = {}
  body_conaffinity = {}
  for g in range(mjm.ngeom):
    b = int(mjm.geom_bodyid[g])
    body_contype[b] = body_contype.get(b, 0) | int(mjm.geom_contype[g])
    body_conaffinity[b] = body_conaffinity.get(b, 0) | int(mjm.geom_conaffinity[g])
  bodies_with_geoms = sorted(body_contype)

  # filter geom-geom pairs (unique body pairs, filtered)
  for idx, bi in enumerate(bodies_with_geoms):
    for bj in bodies_with_geoms[idx + 1 :]:
      # Two bodies that are both welded to the world are skipped.
      if mjm.body_weldid[bi] == 0 and mjm.body_weldid[bj] == 0:
        continue
      if not ((body_contype[bi] & body_conaffinity[bj]) or (body_contype[bj] & body_conaffinity[bi])):
        continue
      max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, bi, bj))

  # flex vertex contacts
  for fi in range(mjm.nflex):
    fct = mjm.flex_contype[fi]
    fca = mjm.flex_conaffinity[fi]

    vert_start = mjm.flex_vertadr[fi]
    vert_count = mjm.flex_vertnum[fi]
    flex_bodies = {mjm.flex_vertbodyid[vert_start + v] for v in range(vert_count)}

    # Bodies owning at least one geom that passes the filter against this flex.
    geom_bodies = {b for b in bodies_with_geoms if (fct & body_conaffinity[b]) or (body_contype[b] & fca)}

    for fb in flex_bodies:
      for gb in geom_bodies:
        if fb != gb:
          max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, fb, gb))

    # flex self-collision
    if mjm.flex_selfcollide[fi]:
      flex_body_list = sorted(flex_bodies)
      for idx1 in range(len(flex_body_list)):
        for idx2 in range(idx1 + 1, len(flex_body_list)):
          max_contact_nnz = max(
            max_contact_nnz,
            _body_pair_nnz(mjm, flex_body_list[idx1], flex_body_list[idx2]),
          )

  total_nnz += njmax * max_contact_nnz

  return int(min(max(total_nnz, 1), njmax * mjm.nv))
if _alloc_h else wp.empty((nworld, 0, 0), dtype=float) ) d.solver_hfactor = ( wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float) @@ -935,9 +1096,11 @@ def put_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax - # TODO(team): heuristic for constraint Jacobian number of non-zeros - if njmax_nnz is None or not is_sparse(mjm): - njmax_nnz = njmax * mjm.nv + if njmax_nnz is None: + if is_sparse(mjm): + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed @@ -987,23 +1150,27 @@ def put_data( efc = types.Constraint(**efc_kwargs) if is_sparse(mjm): - # TODO(team): process efc_J sparsity structure for nv row shift - efc.J_rownnz = wp.array(np.full((nworld, njmax), mjm.nv, dtype=int), dtype=int) - efc.J_rowadr = wp.array( - np.tile(np.arange(0, njmax * mjm.nv, mjm.nv) if mjm.nv else np.zeros(njmax, dtype=int), (nworld, 1)), dtype=int - ) - efc.J_colind = wp.array(np.tile(np.arange(mjm.nv), (nworld, njmax)).reshape((nworld, 1, -1))[:, :, :njmax_nnz], dtype=int) - - mj_efc_J = np.zeros((mjd.nefc, mjm.nv)) + J_rownnz = np.zeros(njmax, dtype=np.int32) + J_rowadr = np.zeros(njmax, dtype=np.int32) + J_colind = np.zeros(njmax_nnz, dtype=np.int32) + J = np.zeros(njmax_nnz, dtype=np.float64) if mjd.nefc: if mujoco.mj_isSparse(mjm): - mujoco.mju_sparse2dense(mj_efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind) + J_rownnz[: mjd.nefc] = mjd.efc_J_rownnz[: mjd.nefc] + J_rowadr[: mjd.nefc] = mjd.efc_J_rowadr[: mjd.nefc] + nnz = int(mjd.efc_J_rownnz[: mjd.nefc].sum()) + J_colind[:nnz] = mjd.efc_J_colind[:nnz] + J[:nnz] = mjd.efc_J[:nnz] else: - mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - efc_J = np.zeros((njmax, mjm.nv), dtype=float) - efc_J[: mjd.nefc, : mjm.nv] = mj_efc_J - efc_J_flat = np.tile(efc_J.reshape(-1), (nworld, 1, 1)).reshape((nworld, 1, -1))[:, :, :njmax_nnz] - efc.J = wp.array(efc_J_flat, 
dtype=float) + dense_J = mjd.efc_J.reshape((-1, mjm.nv))[: mjd.nefc] + mujoco.mju_dense2sparse( + J[: mjd.nefc * mjm.nv], dense_J, J_rownnz[: mjd.nefc], J_rowadr[: mjd.nefc], J_colind[: mjd.nefc * mjm.nv] + ) + + efc.J_rownnz = wp.array(np.tile(J_rownnz, (nworld, 1)), dtype=int) + efc.J_rowadr = wp.array(np.tile(J_rowadr, (nworld, 1)), dtype=int) + efc.J_colind = wp.array(np.tile(J_colind, (nworld, 1)).reshape((nworld, 1, -1)), dtype=int) + efc.J = wp.array(np.tile(J, (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) else: efc.J_rownnz = wp.zeros((nworld, 0), dtype=int) efc.J_rowadr = wp.zeros((nworld, 0), dtype=int) @@ -1069,9 +1236,7 @@ def put_data( _alloc_h = mjm.opt.solver == mujoco.mjtSolver.mjSOL_NEWTON _alloc_hfactor = _alloc_h and mjm.nv > 32 # _BLOCK_CHOLESKY_DIM d.solver_h = ( - wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float) - if _alloc_h - else wp.empty((nworld, 0, 0), dtype=float) + wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float) if _alloc_h else wp.empty((nworld, 0, 0), dtype=float) ) d.solver_hfactor = ( wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float) @@ -2567,64 +2732,28 @@ def create_render_context( hfield_bounds_size_arr = wp.array(hfield_bounds_size, dtype=wp.vec3) # Flex BVHs - flex_bvh_id = wp.uint64(0) - flex_group_root = wp.zeros(nworld, dtype=int) - flex_mesh = None - flex_face_point = None - flex_elemdataadr = None - flex_shell = None - flex_shelldataadr = None - flex_faceadr = None - flex_nface = 0 - flex_radius = None - flex_vertflexid = None - flex_workadr = None - flex_worknum = None - flex_nwork = 0 - - if mjm.nflex > 0: - ( - fmesh, - face_point, - flex_group_roots, - flex_shell_data, - flex_faceadr_data, - flex_nface, - ) = bvh.build_flex_bvh(mjm, mjd, nworld) - - flex_mesh = fmesh - flex_bvh_id = fmesh.id - flex_face_point = face_point - flex_group_root = flex_group_roots - flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int) - flex_shell = flex_shell_data - 
flex_shelldataadr = wp.array(mjm.flex_shelldataadr, dtype=int) - flex_faceadr = wp.array(flex_faceadr_data, dtype=int) - flex_radius = wp.array(mjm.flex_radius, dtype=float) - - # Compute flex_vertflexid: maps each flex vertex to its flex index - flex_vertflexid_data = np.zeros(mjm.nflexvert, dtype=np.int32) - for flexid in range(mjm.nflex): - vert_start = mjm.flex_vertadr[flexid] - vert_end = vert_start + mjm.flex_vertnum[flexid] - flex_vertflexid_data[vert_start:vert_end] = flexid - flex_vertflexid = wp.array(flex_vertflexid_data, dtype=int) - - # precompute work item layout for unified refit kernel - nflex = mjm.nflex - workadr = np.zeros(nflex, dtype=np.int32) - worknum = np.zeros(nflex, dtype=np.int32) - cumsum = 0 - for f in range(nflex): - workadr[f] = cumsum - if mjm.flex_dim[f] == 2: - worknum[f] = mjm.flex_elemnum[f] + mjm.flex_shellnum[f] - else: - worknum[f] = mjm.flex_shellnum[f] - cumsum += worknum[f] - flex_workadr = wp.array(workadr, dtype=int) - flex_worknum = wp.array(worknum, dtype=int) - flex_nwork = int(cumsum) + nflex = mjm.nflex + flex_registry = {} + + # Scene BVH flex primitives: 1D → one capsule per edge, 2D/3D → one box per flex + flex_geom_flexid = [] + flex_geom_edgeid = [] + flex_bvh_id = np.full(nflex, 0, dtype=wp.uint64) + flex_group_root = np.zeros((nflex, nworld), dtype=int) + + for f in range(nflex): + if mjm.flex_dim[f] == 1: + edge_adr = mjm.flex_edgeadr[f] + flex_geom_flexid.extend([f] * mjm.flex_edgenum[f]) + flex_geom_edgeid.extend([edge_adr + e for e in range(mjm.flex_edgenum[f])]) + flex_group_root[f] = np.zeros(nworld, dtype=int) + else: + flex_geom_flexid.append(f) + flex_geom_edgeid.append(-1) + fmesh, group_root = bvh.build_flex_bvh(mjm, mjd, nworld, f) + flex_registry[f] = fmesh + flex_bvh_id[f] = fmesh.id + flex_group_root[f] = group_root.numpy() textures_registry = [] for i in range(mjm.ntex): @@ -2743,26 +2872,20 @@ def create_render_context( hfield_registry=hfield_registry, hfield_bvh_id=hfield_bvh_id_arr, 
@wp.func
def quat_z2vec(vec: wp.vec3) -> wp.quat:
  """Compute quaternion performing rotation from z-axis to given vector.

  Quaternions in this module are stored scalar-first (w, x, y, z): ``quat_inv``
  keeps component 0 and negates components 1-3. The constant quaternions used
  for the degenerate cases below follow that layout.

  Args:
    vec: Target direction; it does not need to be normalized.

  Returns:
    Quaternion rotating (0, 0, 1) onto the direction of ``vec``. The identity
    is returned when ``vec`` is shorter than ``MJ_MINVAL`` or already points
    along +z; a half-turn about the x axis is returned when it points along -z.
  """
  # default result: no rotation (scalar part first)
  quat = wp.quat(1.0, 0.0, 0.0, 0.0)

  # normalize vector; if too small, no rotation
  norm = wp.length(vec)
  if norm < types.MJ_MINVAL:
    return quat
  vec = vec / norm

  # rotation axis is z x vec = (-vec.y, vec.x, 0); its length is sin(angle)
  axis = wp.vec3(-vec[1], vec[0], 0.0)
  a = wp.length(axis)

  # almost parallel
  if a < types.MJ_MINVAL:
    # opposite: 180 deg rotation around x axis, i.e. (w, x, y, z) = (0, 1, 0, 0)
    if vec[2] < 0.0:
      quat = wp.quat(0.0, 1.0, 0.0, 0.0)
    return quat

  # make quaternion from angle and axis; vec[2] is cos(angle)
  axis = axis / a
  angle = wp.atan2(a, vec[2])
  quat = axis_angle_to_quat(axis, angle)

  return quat
@wp.func
def ray_flex_with_bvh_anyhit(
  # In:
  flex_bvh_id: wp.array(dtype=wp.uint64),
  flexid: int,
  group_root: int,
  pnt: wp.vec3,
  vec: wp.vec3,
  max_t: float,
) -> bool:
  """Reports whether the ray hits the given flex anywhere within max_t.

  The wp.Mesh for the flex must already be constructed and its id stored in
  flex_bvh_id at index flexid. Flex are already in world space, so the ray is
  queried as given, restricted to the group_root subtree.
  """
  mesh_id = flex_bvh_id[flexid]
  any_hit = wp.mesh_query_ray_anyhit(mesh_id, pnt, vec, max_t, group_root)
  return any_hit
diff --git a/mujoco_warp/_src/render.py b/mujoco_warp/_src/render.py index dbaf001b0..47a67cbd5 100644 --- a/mujoco_warp/_src/render.py +++ b/mujoco_warp/_src/render.py @@ -23,6 +23,7 @@ from mujoco_warp._src.ray import ray_cylinder from mujoco_warp._src.ray import ray_ellipsoid from mujoco_warp._src.ray import ray_flex_with_bvh +from mujoco_warp._src.ray import ray_flex_with_bvh_anyhit from mujoco_warp._src.ray import ray_mesh_with_bvh from mujoco_warp._src.ray import ray_mesh_with_bvh_anyhit from mujoco_warp._src.ray import ray_plane @@ -90,17 +91,26 @@ def cast_ray( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + flex_bvh_ngeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, ) -> Tuple[int, float, wp.vec3, float, float, int, int]: @@ -114,91 +124,127 @@ def cast_ray( query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + flex_bvh_ngeom while wp.bvh_query_next(query, bounds_nr, dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + d = float(-1.0) hit_mesh_id = int(-1) u = float(0.0) v = float(0.0) f = int(-1) n = wp.vec3(0.0, 0.0, 0.0) + hit_geom_id = int(-1) + + if local_id < bvh_ngeom: + gi = 
enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX + + hit_geom_id = gi # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = 
ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: d, n, u, v, f, hit_mesh_id = ray_mesh_with_bvh( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) + if gtype == GeomType.FLEX: + hit_geom_id = -2 + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + hit_mesh_id = flexid + else: + flex_gr = flex_group_root[worldid, flexid] + d, n, u, v, f = ray_flex_with_bvh(flex_bvh_id, flexid, flex_gr, ray_origin_world, ray_dir_world, dist) + if d >= 0.0: + hit_mesh_id = flexid if d >= 0.0 and d < dist: dist = d normal = n - geom_id = gi + geom_id = hit_geom_id bary_u = u bary_v = v face_idx = f @@ -213,17 +259,26 @@ def cast_ray_first_hit( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: 
wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, max_dist: float, @@ -231,81 +286,119 @@ def cast_ray_first_hit( """A simpler version of casting rays that only checks for the first hit.""" query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + bvh_nflexgeom while wp.bvh_query_next(query, bounds_nr, max_dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + + d = float(-1.0) + n = wp.vec3(0.0, 0.0, 0.0) + + if local_id < bvh_ngeom: + gi = enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, 
gi], ray_origin_world, ray_dir_world, max_dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: hit = ray_mesh_with_bvh_anyhit( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + 
geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, max_dist, ) d = 0.0 if hit else -1.0 + if gtype == GeomType.FLEX: + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + else: + hit = ray_flex_with_bvh_anyhit( + flex_bvh_id, + flexid, + flex_group_root[worldid, flexid], + ray_origin_world, + ray_dir_world, + max_dist, + ) + d = 0.0 if hit else -1.0 if d >= 0.0 and d < max_dist: return True @@ -319,18 +412,27 @@ def compute_lighting( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: use_shadows: bool, bvh_id: wp.uint64, group_root: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), - world_id: int, + worldid: int, mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), lightactive: bool, lighttype: int, lightcastshadow: bool, @@ -381,15 +483,24 @@ def compute_lighting( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, group_root, - world_id, + worldid, bvh_ngeom, 
+ bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, shadow_origin, L, max_t, @@ -431,6 +542,9 @@ def _render_megakernel( light_type: wp.array2d(dtype=int), light_castshadow: wp.array2d(dtype=bool), light_active: wp.array2d(dtype=bool), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), mesh_faceadr: wp.array(dtype=int), mat_texid: wp.array3d(dtype=int), mat_texrepeat: wp.array2d(dtype=wp.vec2), @@ -442,10 +556,12 @@ def _render_megakernel( cam_xmat_in: wp.array2d(dtype=wp.mat33), light_xpos_in: wp.array2d(dtype=wp.vec3), light_xdir_in: wp.array2d(dtype=wp.vec3), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: nrender: int, use_shadows: bool, bvh_ngeom: int, + bvh_nflexgeom: int, cam_res: wp.array(dtype=wp.vec2i), cam_id_map: wp.array(dtype=int), ray: wp.array(dtype=wp.vec3), @@ -457,8 +573,8 @@ def _render_megakernel( render_seg: wp.array(dtype=bool), bvh_id: wp.uint64, group_root: wp.array(dtype=int), - flex_bvh_id: wp.uint64, - flex_group_root: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), mesh_facetexcoord: wp.array(dtype=wp.vec3i), @@ -466,26 +582,28 @@ def _render_megakernel( mesh_texcoord_offsets: wp.array(dtype=int), hfield_bvh_id: wp.array(dtype=wp.uint64), flex_rgba: wp.array(dtype=wp.vec4), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), textures: wp.array(dtype=wp.Texture2D), # Out: rgb_out: wp.array2d(dtype=wp.uint32), depth_out: wp.array2d(dtype=float), seg_out: wp.array2d(dtype=int), ): - world_idx, ray_idx = wp.tid() + worldid, rayid = wp.tid() - # Map global ray_idx -> (cam_idx, ray_idx_local) using cumulative sizes + # Map global rayid -> (cam_idx, rayid_local) using cumulative sizes cam_idx = int(-1) - ray_idx_local = int(-1) + 
rayid_local = int(-1) accum = int(0) for i in range(nrender): num_i = cam_res[i][0] * cam_res[i][1] - if ray_idx < accum + num_i: + if rayid < accum + num_i: cam_idx = i - ray_idx_local = ray_idx - accum + rayid_local = rayid - accum break accum += num_i - if cam_idx == -1 or ray_idx_local < 0: + if cam_idx == -1 or rayid_local < 0: return if not render_rgb[cam_idx] and not render_depth[cam_idx] and not render_seg[cam_idx]: @@ -495,17 +613,17 @@ def _render_megakernel( mujoco_cam_id = cam_id_map[cam_idx] if wp.static(rc.use_precomputed_rays): - ray_dir_local_cam = ray[ray_idx] + ray_dir_local_cam = ray[rayid] else: img_w = cam_res[cam_idx][0] img_h = cam_res[cam_idx][1] - px = ray_idx_local % img_w - py = ray_idx_local // img_w + px = rayid_local % img_w + py = rayid_local // img_w ray_dir_local_cam = compute_ray( cam_projection[mujoco_cam_id], - cam_fovy[world_idx % cam_fovy.shape[0], mujoco_cam_id], + cam_fovy[worldid % cam_fovy.shape[0], mujoco_cam_id], cam_sensorsize[mujoco_cam_id], - cam_intrinsic[world_idx % cam_intrinsic.shape[0], mujoco_cam_id], + cam_intrinsic[worldid % cam_intrinsic.shape[0], mujoco_cam_id], img_w, img_h, px, @@ -513,41 +631,37 @@ def _render_megakernel( wp.static(rc.znear), ) - ray_dir_world = cam_xmat_in[world_idx, mujoco_cam_id] @ ray_dir_local_cam - ray_origin_world = cam_xpos_in[world_idx, mujoco_cam_id] + ray_dir_world = cam_xmat_in[worldid, mujoco_cam_id] @ ray_dir_local_cam + ray_origin_world = cam_xpos_in[worldid, mujoco_cam_id] geom_id, dist, normal, u, v, f, mesh_id = cast_ray( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, - group_root[world_idx], - world_idx, + group_root[worldid], + worldid, bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, ray_origin_world, ray_dir_world, ) - if wp.static(m.nflex > 0): - d, n, u, v, f = ray_flex_with_bvh( - 
flex_bvh_id, - flex_group_root[world_idx], - ray_origin_world, - ray_dir_world, - dist, - ) - if d >= 0.0 and d < dist: - dist = d - normal = n - geom_id = -2 - if render_seg[cam_idx] and geom_id != -1: - seg_out[world_idx, seg_adr[cam_idx] + ray_idx_local] = geom_id + seg_out[worldid, seg_adr[cam_idx] + rayid_local] = geom_id # Early Out if geom_id == -1: @@ -558,7 +672,7 @@ def _render_megakernel( # In camera-local coordinates, the optical axis is -Z. The Z-component of the # normalized ray direction is negative, so -ray_dir_local_cam[2] gives cos(θ) # between the ray and the optical axis. - depth_out[world_idx, depth_adr[cam_idx] + ray_idx_local] = dist * (-ray_dir_local_cam[2]) + depth_out[worldid, depth_adr[cam_idx] + rayid_local] = dist * (-ray_dir_local_cam[2]) if not render_rgb[cam_idx]: return @@ -567,31 +681,30 @@ def _render_megakernel( hit_point = ray_origin_world + ray_dir_world * dist if geom_id == -2: - # TODO: Currently flex textures are not supported, and only the first rgba value - # is used until further flex support is added. 
- color = flex_rgba[0] - elif geom_matid[world_idx % geom_matid.shape[0], geom_id] == -1: - color = geom_rgba[world_idx % geom_rgba.shape[0], geom_id] + # We encode flex_id in mesh_id for flex ray hits during cast_ray + color = flex_rgba[mesh_id] + elif geom_matid[worldid % geom_matid.shape[0], geom_id] == -1: + color = geom_rgba[worldid % geom_rgba.shape[0], geom_id] else: - color = mat_rgba[world_idx % mat_rgba.shape[0], geom_matid[world_idx % geom_matid.shape[0], geom_id]] + color = mat_rgba[worldid % mat_rgba.shape[0], geom_matid[worldid % geom_matid.shape[0], geom_id]] base_color = wp.vec3(color[0], color[1], color[2]) hit_color = base_color if wp.static(rc.use_textures): if geom_id != -2: - mat_id = geom_matid[world_idx % geom_matid.shape[0], geom_id] + mat_id = geom_matid[worldid % geom_matid.shape[0], geom_id] if mat_id >= 0: - tex_id = mat_texid[world_idx % mat_texid.shape[0], mat_id, 1] + tex_id = mat_texid[worldid % mat_texid.shape[0], mat_id, 1] if tex_id >= 0: tex_color = sample_texture( geom_type, mesh_faceadr, geom_id, - mat_texrepeat[world_idx % mat_texrepeat.shape[0], mat_id], + mat_texrepeat[worldid % mat_texrepeat.shape[0], mat_id], textures[tex_id], - geom_xpos_in[world_idx, geom_id], - geom_xmat_in[world_idx, geom_id], + geom_xpos_in[worldid, geom_id], + geom_xmat_in[worldid, geom_id], mesh_facetexcoord, mesh_texcoord, mesh_texcoord_offsets, @@ -616,21 +729,30 @@ def _render_megakernel( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, use_shadows, bvh_id, - group_root[world_idx], + group_root[worldid], bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, - world_idx, + worldid, mesh_bvh_id, hfield_bvh_id, - light_active[world_idx % light_active.shape[0], l], - light_type[world_idx % light_type.shape[0], l], - light_castshadow[world_idx % light_castshadow.shape[0], l], - light_xpos_in[world_idx, l], - light_xdir_in[world_idx, l], + flex_geom_flexid, + flex_geom_edgeid, + 
flex_bvh_id, + flex_group_root, + light_active[worldid % light_active.shape[0], l], + light_type[worldid % light_type.shape[0], l], + light_castshadow[worldid % light_castshadow.shape[0], l], + light_xpos_in[worldid, l], + light_xdir_in[worldid, l], normal, hit_point, ) @@ -639,7 +761,7 @@ def _render_megakernel( hit_color = wp.min(result, wp.vec3(1.0, 1.0, 1.0)) hit_color = wp.max(hit_color, wp.vec3(0.0, 0.0, 0.0)) - rgb_out[world_idx, rgb_adr[cam_idx] + ray_idx_local] = pack_rgba_to_uint32( + rgb_out[worldid, rgb_adr[cam_idx] + rayid_local] = pack_rgba_to_uint32( hit_color[0] * 255.0, hit_color[1] * 255.0, hit_color[2] * 255.0, @@ -662,6 +784,9 @@ def _render_megakernel( m.light_type, m.light_castshadow, m.light_active, + m.flex_vertadr, + m.flex_edge, + m.flex_radius, m.mesh_faceadr, m.mat_texid, m.mat_texrepeat, @@ -672,9 +797,11 @@ def _render_megakernel( d.cam_xmat, d.light_xpos, d.light_xdir, + d.flexvert_xpos, rc.nrender, rc.use_shadows, rc.bvh_ngeom, + rc.bvh_nflexgeom, rc.cam_res, rc.cam_id_map, rc.ray, @@ -695,6 +822,8 @@ def _render_megakernel( rc.mesh_texcoord_offsets, rc.hfield_bvh_id, rc.flex_rgba, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, rc.textures, ], outputs=[ diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 0254ec21f..595409b82 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -41,6 +41,9 @@ wp.set_module_options({"enable_backward": True}) +_nograd_copy_2d = support._nograd_copy + + # kernel_analyzer: off @wp.func def _process_joint( @@ -126,94 +129,113 @@ def _kinematics_branch( jntadr = body_jntadr[bodyid] jntnum = body_jntnum[bodyid] + # Check for freejoint — handled separately because it reads position and + # quaternion directly from qpos rather than composing with the parent + # transform. 
We use an integer flag instead of ``continue`` because + # Warp's AD replay for ``continue`` inside a dynamic for-loop emits a + # goto that skips all adjoint code for that iteration, zeroing gradients. + is_free = int(0) if jntnum == 1: jnt_type_ = jnt_type[jntadr] if jnt_type_ == JointType.FREE: - qadr = jnt_qposadr[jntadr] - xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) - xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) - xquat = wp.normalize(xquat) - - xpos_out[worldid, bodyid] = xpos - xquat_out[worldid, bodyid] = xquat - xanchor_out[worldid, jntadr] = xpos - xaxis_out[worldid, jntadr] = jnt_axis[worldid % jnt_axis.shape[0], jntadr] - continue + is_free = int(1) - # regular or no joints - # apply fixed translation and rotation relative to parent - jnt_pos_id = worldid % jnt_pos.shape[0] - pid = body_parentid[bodyid] + if is_free == int(1): + qadr = jnt_qposadr[jntadr] + xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) + xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) + xquat = wp.normalize(xquat) - # mocap bodies have world body as parent - mocapid = body_mocapid[bodyid] - if mocapid >= 0: - xpos = mocap_pos_in[worldid, mocapid] - xquat = mocap_quat_in[worldid, mocapid] + xanchor_out[worldid, jntadr] = xpos + xaxis_out[worldid, jntadr] = jnt_axis[worldid % jnt_axis.shape[0], jntadr] else: - xpos = body_pos[worldid % body_pos.shape[0], bodyid] - xquat = body_quat[worldid % body_quat.shape[0], bodyid] + # regular or no joints + # apply fixed translation and rotation relative to parent + jnt_pos_id = worldid % jnt_pos.shape[0] + pid = body_parentid[bodyid] + + # mocap bodies have world body as parent + mocapid = body_mocapid[bodyid] + if mocapid >= 0: + xpos = mocap_pos_in[worldid, mocapid] + xquat = mocap_quat_in[worldid, mocapid] + else: + xpos = body_pos[worldid % body_pos.shape[0], bodyid] + xquat = body_quat[worldid % body_quat.shape[0], bodyid] + + if pid >= 0: + xpos = 
math.rot_vec_quat(xpos, xquat_out[worldid, pid]) + xpos_out[worldid, pid] + xquat = math.mul_quat(xquat_out[worldid, pid], xquat) + + # Unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in Warp's AD. + if jntnum >= 1: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 2: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 1, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 3: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 2, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 4: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 3, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) - if pid >= 0: - xpos = math.rot_vec_quat(xpos, xquat_out[worldid, pid]) + xpos_out[worldid, pid] - xquat = math.mul_quat(xquat_out[worldid, pid], xquat) + xquat = wp.normalize(xquat) - # Unrolled joint processing — avoids nested dynamic-range loop which - # produces incorrect gradients in Warp's AD. 
- if jntnum >= 1: - xpos, xquat = _process_joint( - xpos, xquat, jntadr, jnt_pos_id, worldid, qpos0, jnt_type, jnt_qposadr, jnt_pos, jnt_axis, qpos, xanchor_out, xaxis_out - ) - if jntnum >= 2: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 1, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) - if jntnum >= 3: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 2, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) - if jntnum >= 4: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 3, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) - - xquat = wp.normalize(xquat) xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = xquat @@ -2151,20 +2173,18 @@ def _comvel_branch( jntid = body_jntadr[bodyid] jntnum = body_jntnum[bodyid] - if jntnum == 0: - cvel_out[worldid, bodyid] = cvel - continue - - # unrolled joint processing — avoids nested dynamic-range loop which - # produces incorrect gradients in warp's AD + # Use if/else instead of ``continue`` — Warp's AD replay for + # ``continue`` inside a dynamic for-loop skips adjoint code. 
if jntnum >= 1: + # unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in warp's AD cvel, dofid = _process_joint_vel(cvel, dofid, jntid, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 2: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 3: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 4: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 2: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 3: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 4: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out) cvel_out[worldid, bodyid] = cvel @@ -2851,7 +2871,13 @@ def transmission(m: Model, d: Data): ) -@wp.kernel +# Sparse solve kernels have enable_backward=False because Warp's auto-AD +# for in-place operations (x used as both input and output) accumulates +# rather than replaces gradients, producing ~2x the correct result. +# The manual _record_fwd_accel_adjoint callback handles the correct backward +# (qacc_smooth.grad -> qfrc_smooth.grad via M^{-1}) for both sparse and dense. +# This matches the dense Cholesky kernels which also have enable_backward=False. 
+@wp.kernel(enable_backward=False) def _solve_LD_sparse_x_acc_up( # In: L: wp.array3d(dtype=float), @@ -2865,7 +2891,7 @@ def _solve_LD_sparse_x_acc_up( wp.atomic_sub(x[worldid], i, L[worldid, 0, Madr_ki] * x[worldid, k]) -@wp.kernel +@wp.kernel(enable_backward=False) def _solve_LD_sparse_qLDiag_mul( # In: D: wp.array2d(dtype=float), @@ -2876,7 +2902,7 @@ def _solve_LD_sparse_qLDiag_mul( out[worldid, dofid] = out[worldid, dofid] * D[worldid, dofid] -@wp.kernel +@wp.kernel(enable_backward=False) def _solve_LD_sparse_x_acc_down( # In: L: wp.array3d(dtype=float), @@ -2898,8 +2924,16 @@ def _solve_LD_sparse( x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), ): - """Computes sparse backsubstitution: x = inv(L'*D*L)*y.""" - wp.copy(x, y) + """Computes sparse backsubstitution: x = inv(L'*D*L)*y. + + The solve kernels have enable_backward=False to avoid Warp's auto-AD + accumulation bug with in-place operations (x used as both input and output + gives ~2x the correct gradient). The manual _record_fwd_accel_adjoint + callback handles the qacc_smooth -> qfrc_smooth gradient path instead. + """ + # Use _nograd_copy_2d so the initial y->x copy doesn't create an auto-AD + # gradient path (the manual adjoint handles qacc_smooth -> qfrc_smooth). 
+ wp.launch(_nograd_copy_2d, dim=(d.nworld, m.nv), inputs=[y], outputs=[x]) for qLD_updates in reversed(m.qLD_updates): wp.launch(_solve_LD_sparse_x_acc_up, dim=(d.nworld, qLD_updates.size), inputs=[L, qLD_updates], outputs=[x]) diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py index 58a255c55..3feb7a733 100644 --- a/mujoco_warp/_src/solver.py +++ b/mujoco_warp/_src/solver.py @@ -2066,8 +2066,8 @@ def kernel( gauss_cost += (efc_Ma_in[worldid, ii] - qfrc_smooth_in[worldid, ii]) * ( qacc_in[worldid, ii] - qacc_smooth_in[worldid, ii] ) - wp.atomic_add(ctx_gauss_out, worldid, gauss_cost) - wp.atomic_add(ctx_cost_out, worldid, gauss_cost) + wp.atomic_add(ctx_gauss_out, worldid, 0.5 * gauss_cost) + wp.atomic_add(ctx_cost_out, worldid, 0.5 * gauss_cost) return kernel @@ -3317,7 +3317,12 @@ def init_context(m: types.Model, d: types.Data, ctx: SolverContext | InverseCont @event_scope def solve(m: types.Model, d: types.Data): if d.njmax == 0 or m.nv == 0: - wp.copy(d.qacc, d.qacc_smooth) + wp.launch( + support._nograd_copy, + dim=(d.nworld, m.nv), + inputs=[d.qacc_smooth], + outputs=[d.qacc], + ) d.solver_niter.fill_(0) else: ctx = create_solver_context(m, d) @@ -3327,9 +3332,19 @@ def solve(m: types.Model, d: types.Data): def _solve(m: types.Model, d: types.Data, ctx: SolverContext): """Finds forces that satisfy constraints.""" if not (m.opt.disableflags & types.DisableBit.WARMSTART): - wp.copy(d.qacc, d.qacc_warmstart) + wp.launch( + support._nograd_copy, + dim=(d.nworld, m.nv), + inputs=[d.qacc_warmstart], + outputs=[d.qacc], + ) else: - wp.copy(d.qacc, d.qacc_smooth) + wp.launch( + support._nograd_copy, + dim=(d.nworld, m.nv), + inputs=[d.qacc_smooth], + outputs=[d.qacc], + ) # context init_context(m, d, ctx, grad=True) diff --git a/mujoco_warp/_src/solver_test.py b/mujoco_warp/_src/solver_test.py index 4561c6ebd..5487771fb 100644 --- a/mujoco_warp/_src/solver_test.py +++ b/mujoco_warp/_src/solver_test.py @@ -677,7 +677,7 @@ def 
test_solver_retained_state(self, solver_, jacobian): qacc = d.qacc.numpy()[0] # Verify Jaref = efc_J @ qacc - efc_aref - if SPARSE_CONSTRAINT_JACOBIAN: + if m.is_sparse: efc_J_raw = d.efc.J.numpy()[0, 0] colind = d.efc.J_colind.numpy()[0, 0] rownnz = d.efc.J_rownnz.numpy()[0] diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index d9a70490b..c1b73beb8 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -18,18 +18,68 @@ import warp as wp from mujoco_warp._src.math import motion_cross +from mujoco_warp._src.types import MJ_MINVAL from mujoco_warp._src.types import ConeType from mujoco_warp._src.types import Data +from mujoco_warp._src.types import DynType from mujoco_warp._src.types import JointType from mujoco_warp._src.types import Model from mujoco_warp._src.types import State from mujoco_warp._src.types import vec5 +from mujoco_warp._src.types import vec10f from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope wp.set_module_options({"enable_backward": False}) +# Copy kernel invisible to tape backward. Used instead of wp.copy() when a +# manual adjoint callback (e.g. record_func) already handles the backward path. +# wp.copy is a Warp built-in whose backward IS tracked regardless of +# module-level enable_backward, causing double-counting with the manual adjoint. +@wp.kernel(enable_backward=False) +def _nograd_copy( + # In: + src: wp.array2d(dtype=float), + # Out: + dst_out: wp.array2d(dtype=float), +): + worldid, idx = wp.tid() + if idx < src.shape[1]: + dst_out[worldid, idx] = src[worldid, idx] + + +# TODO(team): kernel analyzer array slice? 
+@wp.func +def next_act( + # Model: + opt_timestep: float, # kernel_analyzer: ignore + actuator_dyntype: int, # kernel_analyzer: ignore + actuator_dynprm: vec10f, # kernel_analyzer: ignore + actuator_actrange: wp.vec2, # kernel_analyzer: ignore + # Data In: + act_in: float, # kernel_analyzer: ignore + act_dot_in: float, # kernel_analyzer: ignore + # In: + act_dot_scale: float, + clamp: bool, +) -> float: + # advance actuation + if actuator_dyntype == DynType.FILTEREXACT: + tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) + act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) + elif actuator_dyntype == DynType.USER: + return act_in + else: + act = act_in + act_dot_scale * act_dot_in * opt_timestep + + # clamp to actrange + if clamp: + act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) + + return act + + @cache_kernel def mul_m_sparse(check_skip: bool): @wp.kernel(module="unique") diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 21ac957bb..85f85a6f9 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -344,6 +344,7 @@ class GeomType(enum.IntEnum): BOX: box MESH: mesh SDF: sdf + FLEX: flex """ PLANE = mujoco.mjtGeom.mjGEOM_PLANE @@ -355,6 +356,7 @@ class GeomType(enum.IntEnum): BOX = mujoco.mjtGeom.mjGEOM_BOX MESH = mujoco.mjtGeom.mjGEOM_MESH SDF = mujoco.mjtGeom.mjGEOM_SDF + FLEX = mujoco.mjtGeom.mjGEOM_FLEX # unsupported: NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE @@ -980,7 +982,8 @@ class Model: flex_edgenum: number of edges (nflex,) flex_elemadr: first element address (nflex,) flex_elemnum: number of elements (nflex,) - flex_elemedgeadr: first element address (nflex,) + flex_elemdataadr: first element vertex id address (nflex,) + flex_elemedgeadr: first element edge id address (nflex,) flex_shellnum: number of shells (nflex,) flex_shelldataadr: first shell data address (nflex,) flex_vertbodyid: vertex body ids (nflexvert,) @@ -1366,6 +1369,7 @@ class Model: flex_edgenum: 
array("nflex", int) flex_elemadr: array("nflex", int) flex_elemnum: array("nflex", int) + flex_elemdataadr: array("nflex", int) flex_elemedgeadr: array("nflex", int) flex_shellnum: array("nflex", int) flex_shelldataadr: array("nflex", int) @@ -1773,6 +1777,15 @@ class Data: njmax_nnz: number of non-zeros in constraint Jacobian nacon: number of detected contacts (across all worlds) (1,) ncollision: collision count from broadphase (1,) + solver_h: solver retained Hessian for backward pass + solver_hfactor: solver retained factored Hessian for backward pass + solver_Jaref: solver retained Jacobian reference for backward pass + smooth_adjoint: enable smooth constraint adjoint (0=off, 1=on) + smooth_friction_viscosity: D value for SATISFIED friction constraints in smooth adjoint + smooth_friction_scale: D scale factor for QUADRATIC friction constraints in smooth adjoint + smooth_friction_surrogate_adjoint: replace friction-face backward projections + with damped free-body targets while keeping solver-informed normal handling + smooth_friction_surrogate_alpha: damping factor for the friction surrogate """ solver_niter: array("nworld", int) @@ -1898,21 +1911,12 @@ class RenderContext: hfield_registry: hfield BVH id to warp mesh mapping hfield_bvh_id: hfield BVH ids hfield_bounds_size: hfield bounds half-extents - flex_mesh: flex mesh + flex_mesh_registry: per-flex mesh BVH registry (prevents garbage collection) flex_rgba: flex rgba - flex_bvh_id: flex BVH id - flex_face_point: flex face points - flex_faceadr: flex face addresses - flex_nface: number of flex faces - flex_nwork: total flex work items for refit - flex_group_root: flex group roots - flex_elemdataadr: flex element data addresses - flex_shell: flex shell data - flex_shelldataadr: flex shell data addresses - flex_radius: flex radius - flex_workadr: flex work item addresses for refit - flex_worknum: flex work item counts for refit + flex_bvh_id: per-flex BVH ids + flex_group_root: per-flex group roots (nworld x 
n_flex_bvh) flex_render_smooth: whether to render flex meshes smoothly + flex_dim: flex dimension per flex (1D/2D/3D) bvh: scene BVH bvh_id: scene BVH id lower: lower bounds @@ -1922,10 +1926,8 @@ class RenderContext: ray: rays rgb_data: RGB data rgb_adr: RGB addresses - rgb_size: per-camera RGB buffer sizes depth_data: depth data depth_adr: depth addresses - depth_size: per-camera depth buffer sizes render_rgb: per-camera RGB render flags render_depth: per-camera depth render flags seg_data: segmentation data (per-pixel geom IDs) @@ -1955,21 +1957,15 @@ class RenderContext: hfield_registry: dict hfield_bvh_id: array("nhfield", wp.uint64) hfield_bounds_size: array("nhfield", wp.vec3) - flex_mesh: wp.Mesh + flex_mesh_registry: dict flex_rgba: array("nflex", wp.vec4) - flex_bvh_id: wp.uint64 - flex_face_point: array("*", wp.vec3) - flex_faceadr: array("nflex", int) - flex_nface: int - flex_nwork: int - flex_group_root: array("nworld", int) - flex_elemdataadr: array("nflex", int) - flex_shell: array("*", int) - flex_shelldataadr: array("nflex", int) - flex_radius: array("nflex", float) - flex_workadr: array("nflex", int) - flex_worknum: array("nflex", int) + flex_bvh_id: array("*", wp.uint64) + flex_group_root: array("nworld", "*", int) flex_render_smooth: bool + bvh_nflexgeom: int + flex_dim_np: array("nflex", int) + flex_geom_flexid: array("*", int) + flex_geom_edgeid: array("*", int) bvh: wp.Bvh bvh_id: wp.uint64 lower: array("*", wp.vec3) diff --git a/mujoco_warp/test_data/flex/floppy.xml b/mujoco_warp/test_data/flex/floppy.xml index dfdb973dc..fa5fa6d63 100644 --- a/mujoco_warp/test_data/flex/floppy.xml +++ b/mujoco_warp/test_data/flex/floppy.xml @@ -27,6 +27,7 @@ + diff --git a/mujoco_warp/test_data/flex/multiflex.xml b/mujoco_warp/test_data/flex/multiflex.xml new file mode 100644 index 000000000..bd7423d6c --- /dev/null +++ b/mujoco_warp/test_data/flex/multiflex.xml @@ -0,0 +1,42 @@ + + diff --git a/mujoco_warp/test_data/flex/rope.xml 
b/mujoco_warp/test_data/flex/rope.xml new file mode 100644 index 000000000..8f07a1611 --- /dev/null +++ b/mujoco_warp/test_data/flex/rope.xml @@ -0,0 +1,31 @@ + + diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py index bae97bf99..67d98a12c 100644 --- a/mujoco_warp/testspeed.py +++ b/mujoco_warp/testspeed.py @@ -18,7 +18,7 @@ Usage: mjwarp-testspeed [flags] Example: - mjwarp-testspeed benchmark/humanoid/humanoid.xml --nworld 4096 -o "opt.solver=cg" + mjwarp-testspeed benchmarks/humanoid/humanoid.xml --nworld 4096 -o "opt.solver=cg" """ import dataclasses diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py index 28797ac84..5659ad6b8 100644 --- a/mujoco_warp/viewer.py +++ b/mujoco_warp/viewer.py @@ -18,7 +18,7 @@ Usage: mjwarp-viewer [flags] Example: - mjwarp-viewer benchmark/humanoid/humanoid.xml -o "opt.solver=cg" + mjwarp-viewer benchmarks/humanoid/humanoid.xml -o "opt.solver=cg" """ import copy