Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mujoco_warp/_src/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,8 @@ def fwd_position(m: Model, d: Data, factorize: bool = True):
if m.opt.run_collision_detection:
collision_driver.collision(m, d)
constraint.make_constraint(m, d)
# TODO(team): remove False after island features are more complete
if False and not (m.opt.disableflags & DisableBit.ISLAND):
# TODO(team): m.opt.disableflags & DisableBit.ISLAND
if m.opt.use_islands:
island.island(m, d)
smooth.transmission(m, d)

Expand Down
15 changes: 14 additions & 1 deletion mujoco_warp/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks):
opt.contact_sensor_maxmatch = mjm.numeric_data[mjm.numeric_adr[contact_sensor_maxmatch_id]]
else:
opt.contact_sensor_maxmatch = 64
# TODO(team): remove and use opt.disableflags logic
opt.use_islands = False

# place opt on device
for f in dataclasses.fields(types.Option):
Expand Down Expand Up @@ -1145,7 +1147,7 @@ def put_data(
continue
shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
val = np.zeros(shape, dtype=f.type.dtype)
if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force", "island"):
val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)

Expand Down Expand Up @@ -1206,6 +1208,8 @@ def put_data(
# island arrays
"nisland": None,
"tree_island": None,
"dof_island": None,
"island_dofadr": None,
}
for f in dataclasses.fields(types.Data):
if f.name in d_kwargs:
Expand Down Expand Up @@ -1234,6 +1238,11 @@ def put_data(
# island arrays
d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int)
d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int)
d.dof_island = wp.array(np.tile(mjd.dof_island, (nworld, 1)), dtype=int)
val = np.zeros((nworld, mjm.ntree), dtype=int)
if mjd.nisland > 0:
val[:, : mjd.nisland] = np.tile(mjd.island_dofadr, (nworld, 1))
d.island_dofadr = wp.array(val, dtype=int)

d.nacon = wp.array([mjd.ncon * nworld], dtype=int)

Expand Down Expand Up @@ -1430,6 +1439,7 @@ def get_data_into(
result.efc_frictionloss[:] = d.efc.frictionloss.numpy()[world_id, efc_idx]
result.efc_state[:] = d.efc.state.numpy()[world_id, efc_idx]
result.efc_force[:] = d.efc.force.numpy()[world_id, efc_idx]
result.efc_island[:] = d.efc.island.numpy()[world_id, efc_idx]

# rne_postconstraint
result.cacc[:] = d.cacc.numpy()[world_id]
Expand All @@ -1453,6 +1463,8 @@ def get_data_into(
result.nisland = nisland
if nisland:
result.tree_island[:] = d.tree_island.numpy()[world_id]
result.dof_island[:] = d.dof_island.numpy()[world_id]
result.island_dofadr[:nisland] = d.island_dofadr.numpy()[world_id, :nisland]


def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None):
Expand Down Expand Up @@ -2527,6 +2539,7 @@ def override_model(model: types.Model | mujoco.MjModel, overrides: dict[str, Any
"opt.ls_parallel",
"opt.graph_conditional",
"opt.contact_sensor_maxmatch",
"opt.use_islands",
}
mj_only_fields = {"opt.jacobian"}

Expand Down
Loading
Loading