Skip to content

Commit 59ebb1b

Browse files
committed
add rewoo
Signed-off-by: Christian Munley <cmunley@nvidia.com> Signed-off-by: cmunley1 <cmunley@nvidia.com>
1 parent 2995019 commit 59ebb1b

File tree

6 files changed

+415
-0
lines changed

6 files changed

+415
-0
lines changed

resources_servers/reasoning_gym/configs/reasoning_gym.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,22 @@ reasoning_gym_langgraph_agent:
4343
version: 1.0.0
4444
artifact_fpath: train_knights_knaves.jsonl
4545
license: Apache 2.0
46+
# ReWOO agent entry for the reasoning_gym environment: runs rewoo_agent.py
# against the `reasoning_gym` resources server and the `policy_model` model
# server, with the knights-and-knaves JSONL file registered as the train split.
reasoning_gym_rewoo_agent:
47+
responses_api_agents:
48+
langgraph_agent:
49+
entrypoint: rewoo_agent.py
50+
resources_server:
51+
type: resources_servers
52+
name: reasoning_gym
53+
model_server:
54+
type: responses_api_models
55+
name: policy_model
56+
datasets:
57+
- name: train
58+
type: train
59+
jsonl_fpath: resources_servers/reasoning_gym/data/train_knights_knaves.jsonl
60+
gitlab_identifier:
61+
dataset_name: knights_knaves_reasoning_gym
62+
version: 1.0.0
63+
artifact_fpath: train_knights_knaves.jsonl
64+
license: Apache 2.0
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Stand-alone ReWOO agent entry: same entrypoint and model server as the
# environment-specific configs, but with the resources server name left open.
rewoo_agent:
2+
responses_api_agents:
3+
langgraph_agent:
4+
entrypoint: rewoo_agent.py
5+
resources_server:
6+
type: resources_servers
7+
# NOTE(review): `???` looks like a "value must be supplied" placeholder
# (OmegaConf-style mandatory value) - confirm, and override it with the
# target resources server name when composing configs.
name: ???
8+
model_server:
9+
type: responses_api_models
10+
name: policy_model

responses_api_agents/langgraph_agent/reflection_agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""
16+
See: https://github.com/langchain-ai/langgraph/blob/23961cff61a42b52525f3b20b4094d8d2fba1744/docs/docs/tutorials/reflection/reflection.ipynb
17+
Reflection agent: generate, critique, revise loop.
18+
19+
Generates an initial answer, critiques it, then revises. Repeats until
20+
an <answer> tag is found or max_reflections is reached.
21+
22+
Graph: generate -> should_continue? -> reflect -> generate (revised) -> ...
23+
"""
24+
1525
from typing import Annotated, TypedDict
1626

1727
from app import LangGraphAgentAdapter, LangGraphAgentConfig
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
See https://github.com/langchain-ai/langgraph/blob/23961cff61a42b52525f3b20b4094d8d2fba1744/docs/docs/tutorials/rewoo/rewoo.ipynb
17+
ReWOO (Reasoning Without Observation) agent.
18+
19+
Generates a full plan with variable substitution in a single LLM call,
20+
then executes steps sequentially, substituting prior results. Finally,
21+
a solver synthesizes all results into a final answer.
22+
23+
Graph: plan -> worker -> (loop for each step) -> solve -> END
24+
"""
25+
26+
import re
27+
from typing import Annotated, List, TypedDict
28+
29+
from app import LangGraphAgentAdapter, LangGraphAgentConfig
30+
from fastapi import Request
31+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
32+
from langgraph.graph import END, StateGraph
33+
from langgraph.graph.message import add_messages
34+
from pydantic import ConfigDict
35+
36+
from nemo_gym.base_resources_server import BaseRunRequest, BaseVerifyRequest, BaseVerifyResponse
37+
from nemo_gym.openai_utils import NeMoGymEasyInputMessage, NeMoGymResponse, NeMoGymResponseCreateParamsNonStreaming
38+
from nemo_gym.server_utils import get_response_json, raise_for_status
39+
40+
41+
# Maps LangChain-style message roles ("human", "ai", "system") to
# chat-API-style roles ("user", "assistant", "system").
# NOTE(review): not referenced anywhere in the visible part of this module -
# confirm whether it is still needed.
ROLE_MAP = {"human": "user", "ai": "assistant", "system": "system"}
42+
43+
# Planner prompt template. Its single format field is {task}. It asks the model
# for a list of "Plan: ... #E<n> = Tool[input]" lines, which STEP_REGEX parses;
# LLM[...] is the only tool the prompt offers.
PLAN_PROMPT = """For the following task, make plans that can solve the problem step by step. For each plan, indicate \
44+
which external tool together with tool input to retrieve evidence. You can store the evidence into a \
45+
variable #E that can be called by later tools. (Plan, #E1, Plan, #E2, Plan, ...)
46+
47+
Tools can be one of the following:
48+
(1) LLM[input]: A pretrained LLM. Useful when you need to act with general world knowledge, \
49+
reasoning, and common sense. Input can be any instruction.
50+
51+
For example,
52+
Task: Thomas, Toby, and Rebecca worked a total of 157 hours in one week. Thomas worked x \
53+
hours. Toby worked 10 hours less than twice what Thomas worked, and Rebecca worked 8 hours \
54+
less than Toby. How many hours did Rebecca work?
55+
Plan: Translate the problem into algebraic expressions and solve. #E1 = LLM[Solve x + (2x - 10) + ((2x - 10) - 8) = 157]
56+
Plan: Find out the number of hours Thomas worked. #E2 = LLM[What is x, given #E1]
57+
Plan: Calculate the number of hours Rebecca worked. #E3 = LLM[Calculate (2 * #E2 - 10) - 8]
58+
59+
Begin!
60+
Describe your plans with rich details. Each Plan should be followed by only one #E.
61+
62+
Task: {task}"""
63+
64+
# Solver prompt template. Format fields: {plan} (the plan steps with their
# collected evidence) and {task} (the original task text). It instructs the
# model to wrap the final answer in <answer></answer> tags.
SOLVE_PROMPT = """Solve the following task or problem. To solve the problem, we have made step-by-step Plan and \
65+
retrieved corresponding Evidence to each Plan. Use them with caution since long evidence might \
66+
contain irrelevant information.
67+
68+
{plan}
69+
70+
Now solve the question or task according to provided Evidence above. Respond with the answer \
71+
directly. Wrap your final answer in <answer></answer> tags.
72+
73+
Task: {task}
74+
Response:"""
75+
76+
# Regex to match: Plan: <reasoning> #E1 = Tool[argument]
# Capture groups, in order: (1) the reasoning text, (2) the evidence variable
# such as "#E1", (3) the tool name, (4) the tool input between the brackets.
# Group 4 is `[^\]]+`, so it stops at the first "]" - a tool input that itself
# contains "]" is truncated there.
77+
STEP_REGEX = r"Plan:\s*(.+)\s*(#E\d+)\s*=\s*(\w+)\s*\[([^\]]+)\]"
78+
79+
80+
class ReWOOAgentConfig(LangGraphAgentConfig):
    # Adds no fields of its own: the ReWOO agent is configured entirely by the
    # inherited LangGraphAgentConfig settings.
    ...
82+
83+
84+
class ReWOORunRequest(BaseRunRequest):
    # Body accepted by ReWOOAgent.run. extra="allow" keeps any additional
    # fields on the request instead of rejecting them.
    model_config = ConfigDict(extra="allow")
86+
87+
88+
class ReWOOVerifyRequest(BaseVerifyRequest):
    # Payload posted to the resources server's /verify endpoint. extra="allow"
    # keeps any additional fields carried over from the run request.
    model_config = ConfigDict(extra="allow")
90+
91+
92+
class ReWOOVerifyResponse(BaseVerifyResponse):
    # Result returned by ReWOOAgent.run, parsed from the /verify response.
    # extra="allow" keeps any additional fields the resources server returns.
    model_config = ConfigDict(extra="allow")
94+
95+
96+
class ReWOOState(TypedDict):
    # State carried between the plan / worker / solve graph nodes.

    # Conversation log; merged across node updates by the add_messages reducer.
    messages: Annotated[list[BaseMessage], add_messages]
    # Accumulated prompt messages and model output items, in call order.
    nemo_outputs: list
    # Cookies returned by the most recent model-server call.
    cookies: dict
    # Original request; copied with a new "input" for every model call.
    request_body: NeMoGymResponseCreateParamsNonStreaming
    # Parsed response of the most recent model call (None before the first call).
    last_model_response: NeMoGymResponse
    # Task text extracted from the incoming request.
    task: str
    # Raw planner output text.
    plan_string: str
    # STEP_REGEX matches: (reasoning, "#E<n>", tool name, tool input) tuples.
    steps: List
    # Maps each evidence variable ("#E<n>") to the text produced for that step.
    results: dict
    # Index into `steps` of the next step the worker will execute.
    current_step: int
107+
108+
109+
def _extract_text(outputs):
110+
return "".join(c.text for o in outputs if o.type == "message" for c in o.content if c.type == "output_text")
111+
112+
113+
class ReWOOAgent(LangGraphAgentAdapter):
    """ReWOO agent: plan once, execute each planned step, then solve.

    The planner produces every step up front; each step may reference the
    results of earlier steps through ``#E<n>`` placeholders, which are filled
    in just before the step is executed. A final solver call turns the plan
    and the collected evidence into the answer.
    """

    config: ReWOOAgentConfig

    @staticmethod
    def _substitute_evidence(text: str, results: dict) -> str:
        """Replace each ``#E<n>`` placeholder in *text* with its recorded result.

        The substitution is a single pass over whole ``#E<digits>`` tokens, so
        ``#E1`` never rewrites the prefix of ``#E10`` and placeholders that
        happen to appear inside inserted evidence are not expanded again.
        Placeholders without a recorded result are left untouched.
        """
        return re.sub(r"#E\d+", lambda m: results.get(m.group(0), m.group(0)), text)

    async def _call_model(self, state, prompt):
        """Send *prompt* to the model server and return ``(response, cookies)``.

        The request is a copy of ``state["request_body"]`` whose ``input`` is
        the prompt as a user message followed by ``state["nemo_outputs"]``.

        Raises whatever ``raise_for_status`` raises for a failed HTTP call.
        """
        input_messages = [NeMoGymEasyInputMessage(role="user", content=prompt)]
        # NOTE(review): the new prompt is placed BEFORE the accumulated outputs
        # here, while the graph nodes record it AFTER them in nemo_outputs -
        # confirm this ordering is intended.
        request_body = state["request_body"].model_copy(update={"input": input_messages + state["nemo_outputs"]})
        resp = await self.server_client.post(
            server_name=self.config.model_server.name,
            url_path="/v1/responses",
            json=request_body,
            cookies=state["cookies"],
        )
        await raise_for_status(resp)
        return NeMoGymResponse.model_validate(await resp.json()), resp.cookies

    def build_graph(self):
        """Assemble and compile the ReWOO state graph.

        Nodes:
            plan   -- one model call producing the whole plan; steps are parsed
                      from its text with STEP_REGEX.
            worker -- executes the step at ``current_step`` and stores its result.
            solve  -- asks the model for the final answer from plan + evidence.

        Routing: after ``plan`` and after every ``worker`` run, ``route_worker``
        selects ``worker`` while steps remain and ``solve`` otherwise, so a
        plan with no parseable steps goes straight to ``solve``.
        """
        graph = StateGraph(ReWOOState)

        async def plan(state):
            task = state["task"]
            prompt = PLAN_PROMPT.format(task=task)
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=prompt)

            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            # Each match is a (reasoning, "#E<n>", tool, tool_input) tuple.
            matches = re.findall(STEP_REGEX, text)

            return {
                "messages": [HumanMessage(content=prompt), AIMessage(content=text)],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
                "plan_string": text,
                "steps": matches,
                "results": {},
                "current_step": 0,
            }

        async def worker(state):
            step_idx = state["current_step"]
            _, step_name, _tool, tool_input = state["steps"][step_idx]

            # Variable substitution: replace #E1, #E2, etc. with prior results.
            prompt = self._substitute_evidence(tool_input, state["results"])
            step_prompt = f"Step {step_name}: {prompt}"
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=step_prompt)

            # NOTE(review): the model receives the bare tool input, while the
            # recorded message carries the "Step #E<n>: " prefix - confirm the
            # difference is intended.
            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            new_results = {**state["results"], step_name: text}

            return {
                "messages": [
                    HumanMessage(content=step_prompt),
                    AIMessage(content=text),
                ],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
                "results": new_results,
                "current_step": step_idx + 1,
            }

        async def solve(state):
            # Build plan string with evidence substituted.
            results = state["results"]
            sections = []
            for reasoning, step_name, tool, tool_input in state["steps"]:
                resolved_input = self._substitute_evidence(tool_input, results)
                evidence = results.get(step_name, "N/A")
                sections.append(
                    f"Plan: {reasoning}\n{step_name} = {tool}[{resolved_input}]\nEvidence: {evidence}\n\n"
                )
            plan_with_evidence = "".join(sections)

            prompt = SOLVE_PROMPT.format(plan=plan_with_evidence, task=state["task"])
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=prompt)

            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            return {
                "messages": [HumanMessage(content=prompt), AIMessage(content=text)],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
            }

        def route_worker(state):
            if state["current_step"] >= len(state["steps"]):
                return "solve"
            return "worker"

        routes = {"worker": "worker", "solve": "solve"}

        graph.add_node("plan", plan)
        graph.add_node("worker", worker)
        graph.add_node("solve", solve)
        graph.set_entry_point("plan")
        # Route out of "plan" as well: when the planner output yields no
        # parseable steps, entering "worker" would index an empty list.
        graph.add_conditional_edges("plan", route_worker, routes)
        graph.add_conditional_edges("worker", route_worker, routes)
        graph.add_edge("solve", END)

        return graph.compile()

    async def get_initial_state(self, body: NeMoGymResponseCreateParamsNonStreaming, cookies: dict) -> dict:
        """Build the initial graph state from the incoming request.

        The task is ``body.input`` when it is a string; otherwise it is the
        content of the last input message whose role is ``"user"`` or
        ``"human"`` and whose content is a plain string (``""`` if none).
        Messages may be objects with attributes or plain dicts.
        """
        # Extract task text from input
        if isinstance(body.input, str):
            task = body.input
        else:
            task = ""
            for msg in body.input:
                content = getattr(msg, "content", None) or (msg.get("content") if isinstance(msg, dict) else "")
                role = getattr(msg, "role", None) or (msg.get("role") if isinstance(msg, dict) else "user")
                # Later matches overwrite earlier ones, so the last one wins.
                if role in ["user", "human"] and isinstance(content, str):
                    task = content

        return {
            "messages": [HumanMessage(content=task)],
            "nemo_outputs": [],
            "cookies": cookies,
            "request_body": body,
            "last_model_response": None,
            "task": task,
            "plan_string": "",
            "steps": [],
            "results": {},
            "current_step": 0,
        }

    def extract_outputs(self, final_state: dict) -> list:
        """Return the accumulated prompt and model output items from the final state."""
        return final_state["nemo_outputs"]

    async def run(self, request: Request, body: ReWOORunRequest) -> ReWOOVerifyResponse:
        """Run one episode: seed the session, get a response, then verify it.

        Posts the request body to the resources server's ``/seed_session``,
        posts ``body.responses_create_params`` to this agent's own
        ``/v1/responses``, and finally posts the request plus the response to
        the resources server's ``/verify``. Cookies returned by each call are
        forwarded to the next one.

        Raises whatever ``raise_for_status`` raises for a failed HTTP call.
        """
        cookies = request.cookies

        seed = await self.server_client.post(
            server_name=self.config.resources_server.name,
            url_path="/seed_session",
            json=body.model_dump(),
            cookies=cookies,
        )
        await raise_for_status(seed)
        cookies = seed.cookies

        resp = await self.server_client.post(
            server_name=self.config.name, url_path="/v1/responses", json=body.responses_create_params, cookies=cookies
        )
        await raise_for_status(resp)

        verify_request = ReWOOVerifyRequest.model_validate(
            body.model_dump() | {"response": await get_response_json(resp)}
        )

        verify = await self.server_client.post(
            server_name=self.config.resources_server.name,
            url_path="/verify",
            json=verify_request.model_dump(),
            cookies=resp.cookies,
        )
        await raise_for_status(verify)
        return ReWOOVerifyResponse.model_validate(await get_response_json(verify))
276+
277+
278+
# Script entry point: start the agent's web server when this file is executed
# directly rather than imported.
if __name__ == "__main__":
279+
ReWOOAgent.run_webserver()

responses_api_agents/langgraph_agent/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)