Skip to content

Commit 8912efa

Browse files
Christian Munley (cmunley1)
authored and committed
updates
Signed-off-by: Christian Munley <cmunley@> Signed-off-by: cmunley1 <cmunley@nvidia.com>
1 parent 59ebb1b commit 8912efa

File tree

2 files changed

+493
-0
lines changed

2 files changed

+493
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
Orchestrator agent: decompose > dispatch sub-agents > synthesize.
17+
18+
Asks the model to decompose a problem into sub-tasks, solves each
19+
sub-task with an independent LLM call, then synthesizes a final answer.
20+
21+
Graph: decompose -> dispatch (loop per subtask) -> synthesize -> END
22+
"""
23+
24+
import re
25+
from typing import Annotated, List, TypedDict
26+
27+
from app import LangGraphAgentAdapter, LangGraphAgentConfig
28+
from fastapi import Request
29+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
30+
from langgraph.graph import END, StateGraph
31+
from langgraph.graph.message import add_messages
32+
from pydantic import ConfigDict
33+
34+
from nemo_gym.base_resources_server import BaseRunRequest, BaseVerifyRequest, BaseVerifyResponse
35+
from nemo_gym.openai_utils import NeMoGymEasyInputMessage, NeMoGymResponse, NeMoGymResponseCreateParamsNonStreaming
36+
from nemo_gym.server_utils import get_response_json, raise_for_status
37+
38+
39+
# Prompt asking the model to break the task into 2-4 self-contained sub-tasks,
# emitted in a fixed "SUBTASK n: ..." format so they can be parsed with SUBTASK_REGEX.
DECOMPOSE_PROMPT = """Break the following problem into 2-4 independent sub-tasks that can each be solved separately. \
For each sub-task, write it as a self-contained question that can be answered without context from the others.

Format your response exactly as:
SUBTASK 1: <question>
SUBTASK 2: <question>
SUBTASK 3: <question>

If the problem is simple enough to solve directly, just write:
SUBTASK 1: <the original problem>

Problem: {task}"""

# Prompt asking the model to merge the per-sub-task results into one final answer,
# wrapped in <answer></answer> tags for downstream extraction.
SYNTHESIZE_PROMPT = """You decomposed a problem into sub-tasks and solved each one. \
Now combine the sub-task results into a final answer to the original problem.

Original problem: {task}

{subtask_results}

Synthesize these results into a single final answer. Show your reasoning, then wrap your final answer \
in <answer></answer> tags."""

# Captures the question text after "SUBTASK <n>:"; "." does not cross newlines,
# so each findall match is one sub-task line.
SUBTASK_REGEX = r"SUBTASK\s+\d+:\s*(.+)"
63+
64+
65+
class OrchestratorAgentConfig(LangGraphAgentConfig):
    """Configuration for the orchestrator agent."""

    # Upper bound on sub-tasks kept from the decomposition step; extras are dropped.
    max_subtasks: int = 4
67+
68+
69+
class OrchestratorRunRequest(BaseRunRequest):
    """Run request; extra fields are allowed so task-specific payloads pass through unchanged."""

    model_config = ConfigDict(extra="allow")
71+
72+
73+
class OrchestratorVerifyRequest(BaseVerifyRequest):
    """Verify request; extra fields are allowed so run-request fields plus the response pass through."""

    model_config = ConfigDict(extra="allow")
75+
76+
77+
class OrchestratorVerifyResponse(BaseVerifyResponse):
    """Verify response; extra fields are allowed so verifier-specific outputs pass through."""

    model_config = ConfigDict(extra="allow")
79+
80+
81+
class OrchestratorState(TypedDict):
    """Graph state threaded through decompose -> dispatch -> synthesize."""

    # LangChain chat transcript; add_messages reducer appends rather than replaces.
    messages: Annotated[list[BaseMessage], add_messages]
    # Accumulated prompt/response items re-sent to the model server on each call.
    nemo_outputs: list
    # Session cookies forwarded on every server call (refreshed from each response).
    cookies: dict
    # Template request; its "input" field is overridden per model call.
    request_body: NeMoGymResponseCreateParamsNonStreaming
    # Most recent raw model response.
    last_model_response: NeMoGymResponse
    # The original user problem being orchestrated.
    task: str
    # Parsed sub-task questions (at most config.max_subtasks of them).
    subtasks: List[str]
    # Maps "subtask_<n>" -> that sub-task's result text.
    subtask_results: dict
    # Index of the next sub-task to dispatch.
    current_subtask: int
91+
92+
93+
def _extract_text(outputs):
94+
return "".join(c.text for o in outputs if o.type == "message" for c in o.content if c.type == "output_text")
95+
96+
97+
# TODO: Use LangGraph's Send() API for true parallel worker dispatch instead of
98+
# sequential loop. See langgraphs workflows.md "Orchestrator-Worker" pattern.
99+
class OrchestratorAgent(LangGraphAgentAdapter):
    """Orchestrator-pattern agent: decompose a problem, solve each sub-task with an
    independent model call, then synthesize the results into one final answer.

    Graph shape: decompose -> dispatch (looped once per sub-task) -> synthesize -> END.
    """

    config: OrchestratorAgentConfig

    async def _call_model(self, state, prompt):
        """POST one prompt (plus accumulated context) to the model server.

        Returns a tuple of (validated NeMoGymResponse, response cookies).
        Raises via raise_for_status on a non-success HTTP status.
        """
        input_messages = [NeMoGymEasyInputMessage(role="user", content=prompt)]
        # NOTE(review): the new prompt is placed BEFORE the accumulated outputs here,
        # while the per-node state updates append it AFTER the history — confirm the
        # model server expects this ordering (history usually precedes the newest message).
        request_body = state["request_body"].model_copy(update={"input": input_messages + state["nemo_outputs"]})
        resp = await self.server_client.post(
            server_name=self.config.model_server.name,
            url_path="/v1/responses",
            json=request_body,
            cookies=state["cookies"],
        )
        await raise_for_status(resp)
        return NeMoGymResponse.model_validate(await resp.json()), resp.cookies

    def build_graph(self):
        """Build and compile the decompose/dispatch/synthesize state graph."""
        graph = StateGraph(OrchestratorState)

        async def decompose(state):
            # Ask the model to split the task into sub-tasks and parse them out.
            task = state["task"]
            prompt = DECOMPOSE_PROMPT.format(task=task)
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=prompt)

            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            matches = re.findall(SUBTASK_REGEX, text)
            # Keep at most max_subtasks parsed questions.
            subtasks = [m.strip() for m in matches[: self.config.max_subtasks]]

            # If no subtasks parsed, use the original task
            if not subtasks:
                subtasks = [task]

            return {
                "messages": [HumanMessage(content=prompt), AIMessage(content=text)],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
                "subtasks": subtasks,
                "subtask_results": {},
                "current_subtask": 0,
            }

        async def dispatch(state):
            # Solve the current sub-task with an independent model call, then
            # advance the sub-task cursor. (Sequential; see TODO above about Send().)
            idx = state["current_subtask"]
            subtask = state["subtasks"][idx]
            prompt = f"Solve the following sub-task completely. Show your work.\n\nSub-task: {subtask}"
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=prompt)

            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            # Record this result under a 1-based "subtask_<n>" key without mutating prior state.
            new_results = {**state["subtask_results"], f"subtask_{idx + 1}": text}

            return {
                "messages": [HumanMessage(content=prompt), AIMessage(content=text)],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
                "subtask_results": new_results,
                "current_subtask": idx + 1,
            }

        async def synthesize(state):
            # Combine all sub-task results into a final answer via one last model call.
            task = state["task"]
            results_text = "\n\n".join(
                f"--- Sub-task {i + 1}: {state['subtasks'][i]} ---\nResult: {state['subtask_results'].get(f'subtask_{i + 1}', 'N/A')}"
                for i in range(len(state["subtasks"]))
            )
            prompt = SYNTHESIZE_PROMPT.format(task=task, subtask_results=results_text)
            prompt_msg = NeMoGymEasyInputMessage(role="user", content=prompt)

            nemo_response, cookies = await self._call_model(state, prompt)
            text = _extract_text(nemo_response.output)

            return {
                "messages": [HumanMessage(content=prompt), AIMessage(content=text)],
                "nemo_outputs": state["nemo_outputs"] + [prompt_msg] + nemo_response.output,
                "cookies": cookies,
                "last_model_response": nemo_response,
                "request_body": state["request_body"],
            }

        def route_dispatch(state):
            # Loop back into dispatch until every sub-task has been solved.
            if state["current_subtask"] >= len(state["subtasks"]):
                return "synthesize"
            return "dispatch"

        graph.add_node("decompose", decompose)
        graph.add_node("dispatch", dispatch)
        graph.add_node("synthesize", synthesize)
        graph.set_entry_point("decompose")
        # decompose also routes through route_dispatch so an empty subtask list
        # (impossible today — decompose falls back to [task]) would still terminate.
        graph.add_conditional_edges("decompose", route_dispatch, {"dispatch": "dispatch", "synthesize": "synthesize"})
        graph.add_conditional_edges("dispatch", route_dispatch, {"dispatch": "dispatch", "synthesize": "synthesize"})
        graph.add_edge("synthesize", END)

        return graph.compile()

    async def get_initial_state(self, body: NeMoGymResponseCreateParamsNonStreaming, cookies: dict) -> dict:
        """Derive the initial OrchestratorState from the incoming request body.

        The task is the raw string input, or the LAST user/human message when the
        input is a message list (earlier user messages are overwritten by the loop).
        """
        if isinstance(body.input, str):
            task = body.input
        else:
            task = ""
            for msg in body.input:
                # Messages may be objects or plain dicts; handle both shapes.
                content = getattr(msg, "content", None) or (msg.get("content") if isinstance(msg, dict) else "")
                role = getattr(msg, "role", None) or (msg.get("role") if isinstance(msg, dict) else "user")
                if role in ["user", "human"] and isinstance(content, str):
                    task = content

        return {
            "messages": [HumanMessage(content=task)],
            "nemo_outputs": [],
            "cookies": cookies,
            "request_body": body,
            "last_model_response": None,
            "task": task,
            "subtasks": [],
            "subtask_results": {},
            "current_subtask": 0,
        }

    def extract_outputs(self, final_state: dict) -> list:
        """Return the accumulated model-visible conversation items from the final state."""
        return final_state["nemo_outputs"]

    async def run(self, request: Request, body: OrchestratorRunRequest) -> OrchestratorVerifyResponse:
        """Full episode: seed the resources server, run the agent, then verify.

        Cookies are threaded through each hop so the session persists across the
        seed -> responses -> verify sequence.
        """
        cookies = request.cookies

        seed = await self.server_client.post(
            server_name=self.config.resources_server.name,
            url_path="/seed_session",
            json=body.model_dump(),
            cookies=cookies,
        )
        await raise_for_status(seed)
        cookies = seed.cookies

        # Invoke this agent's own /v1/responses endpoint to execute the graph.
        resp = await self.server_client.post(
            server_name=self.config.name, url_path="/v1/responses", json=body.responses_create_params, cookies=cookies
        )
        await raise_for_status(resp)

        # Verify request = original run request fields + the agent's response.
        verify_request = OrchestratorVerifyRequest.model_validate(
            body.model_dump() | {"response": await get_response_json(resp)}
        )

        verify = await self.server_client.post(
            server_name=self.config.resources_server.name,
            url_path="/verify",
            json=verify_request.model_dump(),
            cookies=resp.cookies,
        )
        await raise_for_status(verify)
        return OrchestratorVerifyResponse.model_validate(await get_response_json(verify))
254+
255+
256+
if __name__ == "__main__":
    # Start the agent's web server (entry point presumably provided by
    # LangGraphAgentAdapter — not visible in this file).
    OrchestratorAgent.run_webserver()

0 commit comments

Comments
 (0)