
Commit 997aba7

Merge pull request #5 from juelg/feat/openpi

feat: support for the openpi model family

2 parents: 6726fe3 + 5e06661

File tree: 2 files changed (+81 −0 lines)

README.md (20 additions, 0 deletions)
````diff
@@ -100,13 +100,33 @@ pip install git+https://github.com/juelg/agents.git
 
 For more details, see the [OpenVLA github page](https://github.com/openvla/openvla).
 
+### OpenPi / Pi0
+To use OpenPi, create a new conda environment:
+```shell
+conda create -n openpi python=3.11 -y
+conda activate openpi
+```
+Clone the repo and install it:
+```shell
+git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
+# Or, if you have already cloned the repo:
+git submodule update --init --recursive
+# install dependencies
+GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
+```
+For more details, see the [openpi github page](https://github.com/Physical-Intelligence/openpi).
+
+
 ## Usage
 To start an agents server, use the `start-server` command, where `kwargs` is a dictionary of the constructor arguments of the policy you want to start, e.g.
 ```shell
 # octo
 python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}'
 # openvla
 python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}'
+# openpi
+python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "<path to checkpoint>/{checkpoint_step}", "train_config_name": "pi0_rcs", "checkpoint_step": <checkpoint_step>}' # leave the "{checkpoint_step}" placeholder as-is; it is filled in with the value of "checkpoint_step". "train_config_name" selects the openpi training config.
 ```
 
 There is also the `run-eval-during-training` command to evaluate a model during training, i.e. a single checkpoint.
````
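Assuming the `openpi` environment from the install steps above is active, one quick way to sanity-check the install is to resolve a training config through the same entry point the new policy class imports (`openpi.training.config.get_config`). This check is a suggestion, not part of the commit:

```shell
# should print the resolved training config without raising an ImportError
python -c "from openpi.training import config; print(config.get_config('pi0_droid'))"
```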

src/agents/policies.py (61 additions, 0 deletions)
````diff
@@ -124,6 +124,66 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
         return info
 
 
+class OpenPiModel(Agent):
+
+    def __init__(
+        self,
+        train_config_name: str = "pi0_droid",
+        default_checkpoint_path: str = "gs://openpi-assets/checkpoints/pi0_droid",
+        execution_horizon: int = 20,
+        **kwargs,
+    ) -> None:
+        super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs)
+        from openpi.training import config
+
+        logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}")
+        self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step)
+
+        self.cfg = config.get_config(train_config_name)
+        self.execution_horizon = execution_horizon
+
+        self.chunk_counter = self.execution_horizon  # start "exhausted" so the first act() runs inference
+        self._cached_action_chunk = None
+
+    def initialize(self):
+        from openpi.policies import policy_config
+        from openpi.shared import download
+
+        checkpoint_dir = download.maybe_download(self.openpi_path)
+
+        # Create a trained policy.
+        self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir)
+
+    def act(self, obs: Obs) -> Act:
+        if self.chunk_counter < self.execution_horizon:
+            self.chunk_counter += 1  # replay the next action from the cached chunk
+            return Act(action=self._cached_action_chunk[self.chunk_counter - 1])
+
+        else:
+            self.chunk_counter = 1  # chunk[0] is consumed by the return below
+            observation = {f"observation/{k}": np.copy(v).transpose(2, 0, 1) for k, v in obs.cameras.items()}
+            observation.update(
+                {
+                    # openpi expects 0 as gripper open and 1 as closed
+                    "observation/state": np.concatenate([obs.info["joints"], [1 - obs.gripper]]),
+                    "prompt": self.instruction,
+                }
+            )
+            action_chunk = self.policy.infer(observation)["actions"]
+
+            # convert the gripper action back into the agents format
+            action_chunk[:, -1] = 1 - action_chunk[:, -1]
+            self._cached_action_chunk = action_chunk
+
+            return Act(action=action_chunk[0])
+
+    def reset(self, obs: Obs, instruction: Any):
+        super().reset(obs, instruction)
+        self.chunk_counter = self.execution_horizon
+        self._cached_action_chunk = None
+        return {}
+
+
 class OpenVLAModel(Agent):
     # === Utilities ===
     SYSTEM_PROMPT = (
@@ -457,4 +517,5 @@ def act(self, obs: Obs) -> Act:
     openvla=OpenVLAModel,
     octodist=OctoActionDistribution,
     openvladist=OpenVLADistribution,
+    openpi=OpenPiModel,
 )
````
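The `act` method above implements a receding-horizon replay: `policy.infer` is called once, then `execution_horizon` actions are served from the cached chunk before the next inference. A minimal standalone sketch of that pattern, with a dummy policy standing in for openpi (the class name, shapes, and callable here are illustrative, not from the commit):

```python
import numpy as np


class ChunkReplay:
    """Serve `horizon` actions from each inferred chunk before re-inferring."""

    def __init__(self, infer, horizon: int = 20):
        self.infer = infer      # callable: obs -> (chunk_len, action_dim) array
        self.horizon = horizon
        self.counter = horizon  # start "exhausted" so the first act() infers
        self.chunk = None

    def act(self, obs) -> np.ndarray:
        if self.counter >= self.horizon:
            self.chunk = self.infer(obs)  # run inference for a fresh chunk
            self.counter = 0
        action = self.chunk[self.counter]  # replay the next cached action
        self.counter += 1
        return action


# dummy policy returning a 50-step chunk of 7-dim actions (shapes are illustrative)
replay = ChunkReplay(lambda obs: np.zeros((50, 7)), horizon=20)
a0 = replay.act(obs=None)  # call 1: triggers inference, returns chunk[0]
a1 = replay.act(obs=None)  # call 2: served from cache, returns chunk[1]
```

The trade-off behind `execution_horizon` is latency versus reactivity: a larger horizon amortizes inference cost over more steps, while a smaller one lets the policy react to new observations sooner.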
