track-mjx is a package for training control policies through motion imitation using deep reinforcement learning. It is part of MIMIC-MJX, along with STAC-MJX (a tool for performing inverse kinematics on markerless motion-tracking data).
Please use the latest stable version of track-mjx (v0.0.1) for the notebook demos and the rodent training example.
track-mjx v1 will soon include all body models, along with their notebooks and training logic. From v1 onward, track-mjx will rely on vnl-playground for environment and task logic. vnl-playground is installed together with the other required libraries in the installation steps below. For more information on the environment and task logic, see vnl-playground.
- Python 3.11 or 3.12
- uv package manager (recommended) or pip
- CUDA 12.x or 13.x (for GPU support, optional)
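Before creating the environment, you can confirm that your interpreter is one of the listed Python versions. This is a minimal stdlib sketch, not part of track-mjx; the helper name is_supported is hypothetical.

```python
import sys

# Python versions listed as supported in the requirements above.
SUPPORTED = {(3, 11), (3, 12)}


def is_supported(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) pair is one of the supported versions."""
    return tuple(version_info[:2]) in SUPPORTED


if __name__ == "__main__":
    major, minor = sys.version_info[:2]
    status = "supported" if is_supported() else "not in the supported list (3.11, 3.12)"
    print(f"Python {major}.{minor}: {status}")
```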
If you don't have uv installed:
# Linux/macOS
curl -LsSf https://astral.sh/uv/install.sh | sh
# Or using pip
pip install uv
- Clone the repository:
git clone https://github.com/talmolab/track-mjx.git
cd track-mjx
- Create and activate a virtual environment:
uv venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
- Install the package with the optional dependencies that match your hardware. CUDA 12, CUDA 13, and CPU-only configurations are supported. This will take a few minutes:
For CUDA 12.x:
uv pip install -e ".[cuda12]"
For CUDA 13.x:
uv pip install -e ".[cuda13]"
For CPU-only:
uv pip install -e .
For development, include the [dev] extras in addition to the hardware optional dependencies:
uv pip install -e ".[cuda13,dev]"
- Verify the installation:
python -c "import jax; print(f'JAX version: {jax.__version__}'); print(f'Available devices: {jax.devices()}')"
- Register the environment as a Jupyter kernel:
python -m ipykernel install --user --name=track-mjx --display-name="Python (track-mjx)"
- Test the environment:
Run the tests in notebooks/test_setup.ipynb. They check that MuJoCo, JAX, and GPU support appear to be working.
If you prefer using pip instead of uv:
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
pip install -e ".[cuda13]" # or ".[cuda12]", or "." for CPU-only
CUDA version mismatch:
- Check your CUDA version:
nvcc --version or nvidia-smi
- Ensure you install the matching JAX CUDA version (cuda12 or cuda13)
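Choosing the extra can be scripted from the "CUDA Version: X.Y" field in the nvidia-smi header (the highest CUDA version the driver supports). The sketch below is a hypothetical helper, not part of track-mjx; mapping CUDA 13 or newer to cuda13, CUDA 12 to cuda12, and anything else to a CPU-only install is an assumption based on the supported configurations listed above.

```python
import re
import shutil
import subprocess


def extra_for_cuda(smi_output: str) -> str:
    """Map nvidia-smi's 'CUDA Version: X.Y' field to an install extra.

    Returns "cuda13", "cuda12", or "" (CPU-only) when no supported version is found.
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    if match is None:
        return ""  # no driver output: fall back to a CPU-only install
    major = int(match.group(1))
    if major >= 13:
        return "cuda13"
    if major == 12:
        return "cuda12"
    return ""  # older than CUDA 12 is not a supported configuration


def install_command(extra: str) -> str:
    """Build the corresponding uv install command."""
    return f'uv pip install -e ".[{extra}]"' if extra else "uv pip install -e ."


if __name__ == "__main__":
    smi = shutil.which("nvidia-smi")
    output = subprocess.run([smi], capture_output=True, text=True, check=False).stdout if smi else ""
    print(install_command(extra_for_cuda(output)))
```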
Import errors:
- Verify the virtual environment is activated
- Try reinstalling:
uv pip install --force-reinstall -e ".[cuda13]"
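To check which environment a given interpreter belongs to, you can inspect the variables set by activation: the venv activate script sets VIRTUAL_ENV, and conda activate sets CONDA_PREFIX. This is a minimal stdlib sketch with a hypothetical helper name; note that an activated conda base environment also sets CONDA_PREFIX, so confirm that the reported path is the track-mjx environment.

```python
import os
import sys


def active_environment(environ=None, prefix=None, base_prefix=None) -> str:
    """Describe the active Python environment.

    Checks VIRTUAL_ENV (venv activation), then CONDA_PREFIX (conda activate),
    then compares sys.prefix with sys.base_prefix, which differ inside a venv
    even when it was not activated via its script.
    """
    environ = os.environ if environ is None else environ
    prefix = sys.prefix if prefix is None else prefix
    base_prefix = sys.base_prefix if base_prefix is None else base_prefix
    if environ.get("VIRTUAL_ENV"):
        return f"venv: {environ['VIRTUAL_ENV']}"
    if environ.get("CONDA_PREFIX"):
        return f"conda: {environ['CONDA_PREFIX']}"
    if prefix != base_prefix:
        return f"venv (not activated via script): {prefix}"
    return "none (system or base interpreter)"


if __name__ == "__main__":
    print("Interpreter:", sys.executable)
    print("Environment:", active_environment())
```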
GPU not detected:
- Verify CUDA installation:
nvidia-smi
- Check that JAX can see GPUs:
python -c "import jax; print(jax.devices())"
Expected output:
- GPU: should show cuda or gpu devices
- CPU: should show a cpu device
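The same check can be scripted by reading the platform attribute of each device returned by jax.devices(). The helper summarize_platforms below is hypothetical, not part of track-mjx; it treats "gpu" or "cuda" as a GPU and a list containing only "cpu" as CPU-only.

```python
import importlib.util


def summarize_platforms(platforms) -> str:
    """Classify a list of JAX device platform names as "gpu", "cpu", or "unknown"."""
    names = [p.lower() for p in platforms]
    if any(p in ("gpu", "cuda") for p in names):
        return "gpu"
    if names and all(p == "cpu" for p in names):
        return "cpu"
    return "unknown"


if __name__ == "__main__":
    if importlib.util.find_spec("jax") is None:
        print("JAX is not installed in this interpreter; is the environment activated?")
    else:
        import jax

        platforms = [d.platform for d in jax.devices()]
        print(platforms, "->", summarize_platforms(platforms))
```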
Alternatively, install with conda:
- Clone the repository:
git clone https://github.com/talmolab/track-mjx.git && cd track-mjx
- Create a new development environment with conda (this creates the necessary base environment):
conda env create -f environment.yml
- Activate the environment:
conda activate track-mjx
- Install the package with the desired CUDA version:
If your machine supports up to CUDA 13:
pip install -e ".[cuda13]"
If your machine supports up to CUDA 12:
pip install -e ".[cuda12]"
If your machine only has a CPU:
pip install -e .
- Test the environment:
Run the tests in notebooks/test_setup.ipynb. They check that MuJoCo, JAX, and GPU support appear to be working.
The main training entrypoint is defined in scripts/train.py and relies on the config in track_mjx/config/rodent-full-clips.yaml.
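Training needs the rodent reference clips on disk. The sketch below checks for the file at the path where the hf_hub_download command in this section stores it (local_dir plus data/rodent/rodent_reference_clips.h5) and can optionally fetch it with that same call. The helper names are hypothetical, and it assumes huggingface_hub is installed in the track-mjx environment; whether the training config reads this exact path is also an assumption.

```python
from pathlib import Path

# Repository and file path taken from the hf_hub_download command in this README.
REPO_ID = "talmolab/MIMIC-MJX"
CLIPS_FILE = "data/rodent/rodent_reference_clips.h5"


def clips_path(local_dir=".") -> Path:
    """Where hf_hub_download(..., local_dir=local_dir) places the clips file."""
    return Path(local_dir) / CLIPS_FILE


def ensure_reference_clips(local_dir=".", download=False) -> Path:
    """Return the clips path, optionally downloading it if it is missing."""
    path = clips_path(local_dir)
    if path.is_file():
        return path
    if not download:
        raise FileNotFoundError(f"{path} not found; download the reference clips first")
    from huggingface_hub import hf_hub_download  # assumed to be installed

    return Path(hf_hub_download(repo_id=REPO_ID, repo_type="dataset",
                                filename=CLIPS_FILE, local_dir=local_dir))
```

Calling ensure_reference_clips(download=True) from the repository root would download the file only when it is not already present.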
To download the data, either run notebooks/rodent_demo.ipynb or execute the following command in a terminal:
python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='talmolab/MIMIC-MJX', repo_type='dataset', filename='data/rodent/rodent_reference_clips.h5', local_dir='.')"
Using uv:
uv run python -m track_mjx.train --config-name rodent-full-clips.yaml
Using conda:
conda activate track-mjx
python -m track_mjx.train --config-name rodent-full-clips.yaml
We provide generic scripts to train policies on any task registered in vnl-playground.
scripts/train_task.py trains an end-to-end MLP policy using Brax PPO. It supports both the default JAX/MJX backend and the Warp backend (for full-collision body models).
# Any registered task
python scripts/train_task.py --task RodentBowlEscape
# With PPO overrides
python scripts/train_task.py --task RodentRearing --num_timesteps 1e8 --entropy_cost 0.1
# With env config overrides
python scripts/train_task.py --task RodentBowlEscape --env "target_speed=1.5"
# Warp backend (full-collision body model)
python scripts/train_task.py --task RodentBowlEscape --env "mujoco_impl=warp"
scripts/train_highlvl.py trains a high-level policy that outputs latent intentions to a frozen, pretrained mimic decoder. The decoder converts these intentions into naturalistic motor commands.
# Any registered task
python scripts/train_highlvl.py --task RodentBowlEscape --mimic_checkpoint <checkpoint_id>
# With PPO overrides
python scripts/train_highlvl.py --task RodentRearing \
--mimic_checkpoint <checkpoint_id> --num_timesteps 1e8 --entropy_cost 0.1
Both scripts support --policy_hidden_sizes, --value_hidden_sizes, --env (for env config overrides), and standard PPO hyperparameter flags. Run with --help for full usage.
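The --env values in the examples above take the form key=value (target_speed=1.5, mujoco_impl=warp). As an illustration of how such overrides can be turned into typed config values, here is a hypothetical parser. It assumes comma-separated pairs and is not the parsing actually done by scripts/train_task.py, which may differ; check --help for the real syntax.

```python
import ast


def parse_env_overrides(spec: str) -> dict:
    """Parse 'key=value[,key=value...]' into a dict of env config overrides.

    Values that are Python literals are converted with ast.literal_eval
    (1.5 -> float, 2 -> int, True -> bool); anything else stays a string.
    """
    overrides = {}
    for item in filter(None, (part.strip() for part in spec.split(","))):
        key, sep, raw = item.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"expected key=value, got {item!r}")
        try:
            value = ast.literal_eval(raw.strip())
        except (ValueError, SyntaxError):
            value = raw.strip()  # bare words such as warp stay strings
        overrides[key.strip()] = value
    return overrides
```

Splitting on commas keeps the sketch short, but it means tuple-valued overrides would need a different separator.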
If you use track-mjx in your research, please cite our paper:
@misc{zhang2025mimicmjxneuromechanicalemulationanimal,
title={MIMIC-MJX: Neuromechanical Emulation of Animal Behavior},
author={Charles Y. Zhang and Yuanjia Yang and Aidan Sirbu and Elliott T. T. Abe and Emil Wärnberg and Eric J. Leonardis and Diego E. Aldarondo and Adam Lee and Aaditya Prasad and Jason Foat and Kaiwen Bian and Joshua Park and Rusham Bhatt and Hutton Saunders and Akira Nagamori and Ayesha R. Thanawalla and Kee Wui Huang and Fabian Plum and Hendrik K. Beck and Steven W. Flavell and David Labonte and Blake A. Richards and Bingni W. Brunton and Eiman Azim and Bence P. Ölveczky and Talmo D. Pereira},
year={2025},
eprint={2511.20532},
archivePrefix={arXiv},
primaryClass={q-bio.NC},
url={https://arxiv.org/abs/2511.20532},
}
This package is distributed under a BSD 3-Clause License and can be used without restrictions. See LICENSE for details.