Skip to content

Commit e5fe9b5

Browse files
authored
update test_get_data_into_io_test_models (#1189)
* update test_get_data_into_io_test_models * fix tendon_armature indexing * fix more indexing bugs * fix formatting
1 parent 324fddc commit e5fe9b5

File tree

3 files changed

+60
-53
lines changed

3 files changed

+60
-53
lines changed

mujoco_warp/_src/forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def _actuator_force(
647647
act = act_in[worldid, act_last]
648648
act_dot = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL)
649649
elif dyntype == DynType.MUSCLE:
650-
dynprm = actuator_dynprm[worldid, uid]
650+
dynprm = actuator_dynprm[worldid % actuator_dynprm.shape[0], uid]
651651
act = act_in[worldid, act_last]
652652
act_dot = util_misc.muscle_dynamics(ctrl, act, dynprm)
653653
else: # DynType.NONE

mujoco_warp/_src/io_test.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from absl.testing import parameterized
2626

2727
import mujoco_warp as mjwarp
28+
from mujoco_warp import ConeType
29+
from mujoco_warp import IntegratorType
2830
from mujoco_warp import test_data
2931
from mujoco_warp._src import warp_util
3032
from mujoco_warp._src.io import set_length_range
@@ -272,53 +274,58 @@ def test_get_data_into(self, nworld, world_id):
272274
field,
273275
)
274276

275-
@parameterized.parameters(*_IO_TEST_MODELS)
276-
def test_get_data_into_io_test_models(self, xml):
277+
@parameterized.product(
278+
xml=_IO_TEST_MODELS,
279+
cone=list(ConeType),
280+
integrator=list(IntegratorType),
281+
)
282+
def test_get_data_into_io_test_models(self, xml, cone, integrator):
277283
"""Tests get_data_into for field coverage across diverse model types."""
278-
mjm, mjd, _, d = test_data.fixture(xml)
279-
280-
# Create fresh MjData to verify get_data_into populates it correctly
281-
mjd_result = mujoco.MjData(mjm)
282-
283-
mjwarp.get_data_into(mjd_result, mjm, d)
284-
285-
# Compare key fields, including flex/tendon data not covered by humanoid.xml
286-
for field in [
287-
"qpos",
288-
"qvel",
289-
"qacc",
290-
"ctrl",
291-
"act",
292-
"flexvert_xpos",
293-
"flexedge_length",
294-
"flexedge_velocity",
295-
"ten_length",
296-
"ten_velocity",
297-
"actuator_length",
298-
"actuator_velocity",
299-
"actuator_force",
300-
"xpos",
301-
"xquat",
302-
"geom_xpos",
303-
"tree_island",
304-
]:
305-
if field == "tree_island" and d.nisland.numpy()[0] == 0:
306-
continue
307-
if getattr(mjd, field).size > 0:
284+
mjm, _, m, d = test_data.fixture(xml, nworld=2, overrides={"opt.cone": cone, "opt.integrator": integrator})
285+
mjwarp.step(m, d)
286+
287+
for world_id in range(2):
288+
# Create reference MjData from warp data (resizes contact/efc fields internally)
289+
mjd = mujoco.MjData(mjm)
290+
mjwarp.get_data_into(mjd, mjm, d, world_id=world_id)
291+
292+
# Compare key fields, including flex/tendon data not covered by humanoid.xml
293+
for field in [
294+
"qpos",
295+
"qvel",
296+
"qacc",
297+
"ctrl",
298+
"act",
299+
"flexvert_xpos",
300+
"flexedge_length",
301+
"flexedge_velocity",
302+
"ten_length",
303+
"ten_velocity",
304+
"actuator_length",
305+
"actuator_velocity",
306+
"actuator_force",
307+
"xpos",
308+
"xquat",
309+
"geom_xpos",
310+
"tree_island",
311+
]:
312+
if field == "tree_island" and d.nisland.numpy()[0] == 0:
313+
continue
314+
if getattr(mjd, field).size > 0:
315+
_assert_eq(
316+
getattr(mjd, field).reshape(-1),
317+
getattr(d, field).numpy()[world_id].reshape(-1),
318+
f"{field} (model: {xml}, world: {world_id})",
319+
)
320+
321+
# flexedge_J
322+
if xml == "flex/floppy.xml":
308323
_assert_eq(
309-
getattr(mjd_result, field).reshape(-1),
310-
getattr(mjd, field).reshape(-1),
311-
f"{field} (model: {xml})",
324+
mjd.flexedge_J.reshape(-1),
325+
d.flexedge_J.numpy()[world_id].reshape(-1),
326+
f"flexedge_J (world: {world_id})",
312327
)
313328

314-
# flexedge_J
315-
if xml == "flex/floppy.xml":
316-
_assert_eq(
317-
mjd_result.flexedge_J.reshape(-1),
318-
d.flexedge_J.numpy()[0].reshape(-1),
319-
"flexedge_J",
320-
)
321-
322329
def test_ellipsoid_fluid_model(self):
323330
mjm = mujoco.MjModel.from_xml_string(
324331
"""

mujoco_warp/_src/smooth.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ def _tendon_armature(
908908
if is_sparse: # is_sparse is not batched
909909
madr_ij = dof_Madr[dofid]
910910

911-
armature = tendon_armature[worldid, tenid]
911+
armature = tendon_armature[worldid % tendon_armature.shape[0], tenid]
912912

913913
if armature == 0.0:
914914
return
@@ -1321,7 +1321,7 @@ def _cfrc_ext_equality(
13211321
)
13221322

13231323
id = efc_id_in[worldid, efcid]
1324-
eq_data_ = eq_data[worldid, id]
1324+
eq_data_ = eq_data[worldid % eq_data.shape[0], id]
13251325
body_semantic = eq_objtype[id] == ObjType.BODY
13261326

13271327
obj1 = eq_obj1id[id]
@@ -1342,7 +1342,7 @@ def _cfrc_ext_equality(
13421342
else:
13431343
offset = wp.vec3(eq_data_[3], eq_data_[4], eq_data_[5])
13441344
else:
1345-
offset = site_pos[worldid, obj1]
1345+
offset = site_pos[worldid % site_pos.shape[0], obj1]
13461346

13471347
# transform point on body1: local -> global
13481348
pos = xmat_in[worldid, bodyid1] @ offset + xpos_in[worldid, bodyid1]
@@ -1364,7 +1364,7 @@ def _cfrc_ext_equality(
13641364
else:
13651365
offset = wp.vec3(eq_data_[0], eq_data_[1], eq_data_[2])
13661366
else:
1367-
offset = site_pos[worldid, obj2]
1367+
offset = site_pos[worldid % site_pos.shape[0], obj2]
13681368

13691369
# transform point on body2: local -> global
13701370
pos = xmat_in[worldid, bodyid2] @ offset + xpos_in[worldid, bodyid2]
@@ -1559,7 +1559,7 @@ def _tendon_dot(
15591559
):
15601560
worldid, tenid = wp.tid()
15611561

1562-
armature = tendon_armature[worldid, tenid]
1562+
armature = tendon_armature[worldid % tendon_armature.shape[0], tenid]
15631563
if armature == 0.0:
15641564
return
15651565

@@ -1717,7 +1717,7 @@ def _tendon_bias_coef(
17171717
):
17181718
worldid, tenid, dofid = wp.tid()
17191719

1720-
armature = tendon_armature[worldid, tenid]
1720+
armature = tendon_armature[worldid % tendon_armature.shape[0], tenid]
17211721
if armature == 0.0:
17221722
return
17231723

@@ -1741,7 +1741,7 @@ def _tendon_bias_qfrc(
17411741
):
17421742
worldid, tenid, dofid = wp.tid()
17431743

1744-
armature = tendon_armature[worldid, tenid]
1744+
armature = tendon_armature[worldid % tendon_armature.shape[0], tenid]
17451745
if armature == 0.0:
17461746
return
17471747

@@ -2078,7 +2078,7 @@ def _transmission(
20782078
siteid = trnid[0]
20792079
refid = trnid[1]
20802080

2081-
gear = actuator_gear[worldid, actid]
2081+
gear = actuator_gear[actuator_gear_id, actid]
20822082
site_quat_id = worldid % site_quat.shape[0]
20832083
gear_translation = wp.spatial_top(gear)
20842084
gear_rotational = wp.spatial_bottom(gear)
@@ -2943,7 +2943,7 @@ def _spatial_geom_tendon(
29432943
# get geom information
29442944
geom_xpos = geom_xpos_in[worldid, wrap_objid_geom]
29452945
geom_xmat = geom_xmat_in[worldid, wrap_objid_geom]
2946-
geomsize = geom_size[worldid, wrap_objid_geom][0]
2946+
geomsize = geom_size[worldid % geom_size.shape[0], wrap_objid_geom][0]
29472947
geom_type = wrap_type[wrap_adr]
29482948

29492949
# get body ids for site-geom-site instances

0 commit comments

Comments
 (0)