Skip to content

Commit 35d2bfc

Browse files
Add Tau^2 Bench training environment
1 parent 12f94df commit 35d2bfc

File tree

16 files changed

+1274
-0
lines changed

16 files changed

+1274
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Description
2+
3+
Please note that this is a dummy resource environment, as the tau^2 agent needs to be run under a different response_api_agents, and the implementations and instructions can be found at https://github.com/NVIDIA-NeMo/Gym/tree/main/responses_api_agents/tau2_agent.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from pydantic import BaseModel
16+
17+
from fastapi import FastAPI
18+
19+
from nemo_gym.base_resources_server import (
20+
SimpleResourcesServer,
21+
BaseResourcesServerConfig,
22+
BaseVerifyRequest,
23+
BaseVerifyResponse,
24+
)
25+
26+
27+
class Tau2BenchResourcesServerConfig(BaseResourcesServerConfig):
28+
pass
29+
30+
31+
class Tau2BenchResourcesServer(SimpleResourcesServer):
32+
config: Tau2BenchResourcesServerConfig
33+
34+
def setup_webserver(self) -> FastAPI:
35+
app = super().setup_webserver()
36+
37+
# Additional server routes go here! e.g.:
38+
# app.post("/get_weather")(self.get_weather)
39+
40+
return app
41+
42+
async def verify(self, body: BaseVerifyRequest) -> BaseVerifyResponse:
43+
return BaseVerifyResponse(**body.model_dump(), reward=1.0)
44+
45+
46+
if __name__ == "__main__":
47+
Tau2BenchResourcesServer.run_webserver()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
tau2_bench_resources_server:
2+
resources_servers:
3+
tau2_bench:
4+
entrypoint: app.py
5+
domain: agent
6+
tau2_agent:
7+
responses_api_agents:
8+
tau2_agent:
9+
entrypoint: app.py
10+
resources_server:
11+
type: resources_servers
12+
name: tau2_bench_resources_server
13+
model_server:
14+
type: responses_api_models
15+
name: policy_model
16+
# user_model_server:
17+
# type: responses_api_models
18+
# name: policy_model
19+
concurrency: 16
20+
tau2_domain: airline
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
*train.jsonl
2+
*validation.jsonl
3+
*train_prepare.jsonl
4+
*validation_prepare.jsonl
5+
*example_prepare.jsonl
6+
tau2/
7+
simulations/
8+
*rollouts*.jsonl
9+
*rollouts*.json
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{"id": 0, "task_domain": "retail", "task_id": 0, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
2+
{"id": 1, "task_domain": "retail", "task_id": 1, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
3+
{"id": 2, "task_domain": "retail", "task_id": 2, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
4+
{"id": 3, "task_domain": "retail", "task_id": 3, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
5+
{"id": 4, "task_domain": "retail", "task_id": 4, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
6+
{"id": 5, "task_domain": "retail", "task_id": 5, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
7+
{"id": 6, "task_domain": "retail", "task_id": 6, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
8+
{"id": 7, "task_domain": "retail", "task_id": 7, "responses_create_params": {"input": []}, "agent_ref": {"type": "responses_api_agents", "name": "tau2_agent"}}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-e nemo-gym[dev] @ ../../
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from unittest.mock import MagicMock
16+
17+
from nemo_gym.server_utils import ServerClient
18+
from resources_servers.tau2_bench.app import (
19+
Tau2BenchResourcesServer,
20+
Tau2BenchResourcesServerConfig,
21+
)
22+
23+
24+
class TestApp:
25+
def test_sanity(self) -> None:
26+
config = Tau2BenchResourcesServerConfig(
27+
name="tau2_bench_agent",
28+
host="0.0.0.0",
29+
port=8080,
30+
entrypoint="",
31+
)
32+
Tau2BenchResourcesServer(config=config, server_client=MagicMock(spec=ServerClient))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
results/
2+
data/
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
Tau2 agent — how to run experiments
2+
=================================
3+
4+
This document shows the minimal steps to run tau2 experiments locally.
5+
6+
*Steps*
7+
1) Configure your API Key
8+
```bash
9+
echo "policy_base_url: https://api.openai.com/v1
10+
policy_api_key: your-openai-api-key
11+
policy_model_name: gpt-4.1-2025-04-14" > env.yaml
12+
```
13+
14+
2) Setup Tau^2 data
15+
16+
- Download the `tau2` folder (https://github.com/sierra-research/tau2-bench/tree/main/data/tau2).
17+
- Save it to `resources_servers/tau2_bench/data/`.
18+
- Configure data path (*don't forget* to modify the path accordingly):
19+
```bash
20+
export TAU2_DATA_DIR="/your_path/to/resources_servers/tau2_bench/data/"
21+
```
22+
23+
3) Launch the NemoGym server
24+
- In the *first terminal*, launch the server.
25+
26+
Example server for `openai_model`:
27+
```bash
28+
config_paths="responses_api_agents/tau2_agent/configs/tau2_agent.yaml,\
29+
responses_api_models/openai_model/configs/openai_model.yaml,\
30+
resources_servers/tau2_bench/configs/tau2_bench.yaml"
31+
32+
ng_run "+config_paths=[$config_paths]" \
33+
+tau2_agent.responses_api_agents.tau2_agent.resources_server.name=tau2_bench_resources_server
34+
```
35+
36+
Example server for `vllm_model`:
37+
```bash
38+
config_paths="responses_api_agents/tau2_agent/configs/tau2_agent.yaml,\
39+
responses_api_models/vllm_model/configs/vllm_model.yaml,\
40+
resources_servers/tau2_bench/configs/tau2_bench.yaml"
41+
42+
ng_run "+config_paths=[$config_paths]" \
43+
+tau2_agent.responses_api_agents.tau2_agent.resources_server.name=tau2_bench_resources_server \
44+
+policy_model.responses_api_models.vllm_model.return_token_id_information=true
45+
```
46+
47+
4) Prepare experiment input
48+
- Prepare an input JSONL file describing which domain/task(s) to run. Set the path in the `input_jsonl_fpath`. An example is in `resources_servers/tau2_bench/data/example_retail_demo.jsonl`
49+
50+
5) Collect rollouts from Tau^2 Bench (separate terminal)
51+
- In the *second (separate) terminal*, launch the rollout script to kick off the experiment:
52+
53+
```bash
54+
ng_collect_rollouts +agent_name=tau2_agent \
55+
+input_jsonl_fpath=resources_servers/tau2_bench/data/example_retail_demo.jsonl \
56+
+output_jsonl_fpath=resources_servers/tau2_bench/data/example_retail_demo_rollouts.jsonl \
57+
+limit=1 \
58+
+num_repeats=1 \
59+
+num_samples_in_parallel=null
60+
```

responses_api_agents/tau2_agent/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)