Description
Note: I am using Warp through the MJX interface, but I still wanted to open the issue here because I believe the problem is on the Warp side.
@bmabsout and I are running into a hash collision that causes objects to fall through the ground. Unfortunately, I cannot share the model that triggers it because it includes company property. I have tried to reproduce the issue with other models and failed, but I still wanted to let you know about it (perhaps it is easier for you to reproduce). I apologize for the inconvenience. Here is the issue in detail.
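Process: first I create the Warp model and data via:

```python
from mujoco import mjx

# mj_model is the compiled mujoco.MjModel; the elided arguments are set per experiment.
model = mjx.put_model(mj_model, impl=mjx.Impl.WARP)
data = mjx.make_data(model, impl=mjx.Impl.WARP, nconmax=..., njmax=..., device=...)
```

and then simply step the environment.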
Experiment steps:
- Clean the Warp cache.
- Start a model with a sphere (and a robot I am not allowed to share) and compile it. The primitive narrowphase module is compiled:
  ```
  Module _primitive_narrowphase__locals__primitive_narrowphase_112e6761 fa4a16c load on device 'cuda:0' took 506.52 ms (compiled)
  ```
- Step the environment; everything works well.
- Do NOT clean the Warp cache.
- Start a model with a cube (instead of the sphere) and compile it. The primitive narrowphase module is loaded from the cache, even though we switched from a sphere to a cube:
  ```
  Module _primitive_narrowphase__locals__primitive_narrowphase_112e6761 fa4a16c load on device 'cuda:0' took 0.51 ms (cached)
  ```
- Step the environment: the cube falls through the ground.
- Clean ONLY the module mentioned above (`_primitive_narrowphase__locals__primitive_narrowphase_`) from the cache.
- Start a model with a cube and compile it. The kernel is recompiled and everything works well.
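The selective cleanup in the last step can be scripted instead of done by hand. This is a minimal sketch, assuming the kernel cache is a directory in which each compiled module is stored as an entry whose name contains the module name; `remove_cached_modules` is a hypothetical helper, not part of Warp (Warp's own `wp.clear_kernel_cache()` wipes the entire cache instead):

```python
import shutil
from pathlib import Path


def remove_cached_modules(cache_dir, name_fragment):
    """Delete kernel-cache entries whose name contains ``name_fragment``.

    Hypothetical helper: assumes each compiled module is stored as a file or
    directory directly under ``cache_dir`` with the module name in its name.
    Returns the sorted names of the removed entries.
    """
    removed = []
    # Materialize the listing first so deleting entries does not disturb iteration.
    for entry in list(Path(cache_dir).iterdir()):
        if name_fragment not in entry.name:
            continue
        if entry.is_dir():
            shutil.rmtree(entry)
        else:
            entry.unlink()
        removed.append(entry.name)
    return sorted(removed)
```

In practice `cache_dir` would presumably be `wp.config.kernel_cache_dir` (populated once Warp is initialized) and `name_fragment` the module name prefix from the log lines above. Matching on a substring rather than the full name also removes stale entries that carry other hash suffixes.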
I am fairly confident the issue is in this line. We have a local fix that replaces that line with:
```python
import warp as wp


def order_independent_hash(items):
    """Collision-free hash for any iterable of ints, order-independent."""
    hash_val = 0
    for item in sorted(items):  # sorting ensures order-independence
        # upper Szudzik pairing as a collision-free hash
        hash_val = max(hash_val, item) ** 2 + min(hash_val, item)
    return hash_val


# primitive_collisions_types comes from the surrounding code in the original file.
types_hash = order_independent_hash(map(order_independent_hash, primitive_collisions_types))
unique_kernel_name = f"primitive_narrowphase_{types_hash}"

# The existing kernel definition follows this decorator unchanged.
@wp.kernel(module=unique_kernel_name, enable_backward=False)
```
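For comparison, a deterministic module name can also be derived by hashing a canonical representation of the type pairs with a cryptographic digest instead of an arithmetic fold. This is a sketch, not the reporter's fix or MuJoCo Warp's code: it assumes `primitive_collisions_types` is an iterable of tuples of ints, and `primitive_kernel_name` is a hypothetical name. Like the fix above, it ignores the order of the pairs and of the types within each pair.

```python
import hashlib


def primitive_kernel_name(collision_types):
    """Derive a module name that changes whenever the set of type pairs changes.

    Assumes ``collision_types`` is an iterable of tuples of ints (for example,
    pairs of geom type ids). Types within each pair and the pairs themselves
    are sorted, so the result does not depend on iteration order.
    """
    canonical = sorted(tuple(sorted(int(t) for t in pair)) for pair in collision_types)
    # Truncate the SHA-256 hex digest to 16 hex characters (64 bits).
    digest = hashlib.sha256(repr(canonical).encode("utf-8")).hexdigest()[:16]
    return f"primitive_narrowphase_{digest}"
```

A truncated SHA-256 digest is not strictly collision-free either, but 64 bits make an accidental collision among a handful of configurations vanishingly unlikely, and it avoids edge cases of the fold: for instance, `order_independent_hash([])` and `order_independent_hash([0])` both return 0, and `[1]` and `[0, 1]` both return 1. Whether such cases can arise in practice depends on what `primitive_collisions_types` actually contains.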
Relevant versions:
- mujoco 3.5.0
- mujoco-mjx 3.5.0
- warp-lang 1.11.0