This project requires Python 3.10 or later and a working JAX installation. To install JAX, refer to the JAX installation instructions. Then upgrade pip and install the remaining dependencies:
pip install --upgrade pip
pip install -r requirements.txt
There are three main scripts. Each has a number of command-line arguments, which you can list by running: python <script_name>.py --help.
To train an agent, run the train.py script. This creates a folder under results/ that contains a config file. At the end of training, a tasks.png visualization should also be created automatically. See --help for more information on the hyperparameters.
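Because each training run gets its own folder under results/, it can be handy to look up the newest one before passing it to the other scripts. The helper below is only a sketch: latest_run is a hypothetical name that is not part of this project, and it assumes that every run is a direct subfolder of results/ and that the modification time reflects the order of the runs.

```python
from pathlib import Path
from typing import Optional


def latest_run(results_dir: str = "results") -> Optional[Path]:
    """Return the most recently modified subfolder of results_dir, or None.

    Hypothetical helper: assumes each training run is a direct subfolder.
    """
    root = Path(results_dir)
    if not root.is_dir():
        return None
    runs = [p for p in root.iterdir() if p.is_dir()]
    # max() with default=None also covers an empty results/ directory.
    return max(runs, key=lambda p: p.stat().st_mtime, default=None)


if __name__ == "__main__":
    print(latest_run())
```

The returned path could then be given to a script that accepts a run folder, for example through --run_path.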
To optimize in the latent space with human feedback, run the humanfeedback.py script. You can specify the run folder with --run_path or the environment with --env. See --help for more information.
When the script finishes, a pathhf.npy file should be created, along with a plot of the path through the latent space.
To linearly interpolate between behaviors, run the interpolation.py script. It loads the successes.npz file created after training the agent, computes the barycenter of each task in the latent space, and starts the visualization. Use the slider to move between behaviors.
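The idea behind the interpolation can be sketched in a few lines. This is not the code of interpolation.py: the layout of successes.npz is not documented here, so the latent vectors and task labels below are made-up values, and barycenters and lerp are hypothetical helper names. Plain Python lists are used so the sketch runs without any dependency.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

Vector = List[float]


def barycenters(latents: Sequence[Sequence[float]],
                task_ids: Sequence[int]) -> Dict[int, Vector]:
    """Return the mean latent vector (barycenter) of each task."""
    groups = defaultdict(list)
    for z, task in zip(latents, task_ids):
        groups[task].append(z)
    # zip(*zs) iterates over coordinates, so each mean is taken per dimension.
    return {task: [sum(col) / len(zs) for col in zip(*zs)]
            for task, zs in groups.items()}


def lerp(a: Sequence[float], b: Sequence[float], t: float) -> Vector:
    """Linear interpolation: returns a at t=0 and b at t=1."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


# Made-up 2D latent vectors for two tasks (illustrative values only).
latents = [[0.0, 0.0], [2.0, 0.0], [4.0, 4.0], [6.0, 4.0]]
task_ids = [0, 0, 1, 1]

centers = barycenters(latents, task_ids)
midpoint = lerp(centers[0], centers[1], 0.5)
print(centers)   # {0: [1.0, 0.0], 1: [5.0, 4.0]}
print(midpoint)  # [3.0, 2.0]
```

In this picture, the slider position plays the role of t: 0 corresponds to the barycenter of one task, 1 to the barycenter of the other, and intermediate values give latent points on the segment between them.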

