Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/reference/cli-commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,49 @@ NeMo Gym Server Status:

---

### `ng_stop` / `nemo_gym_stop`

Stops running servers cleanly, without having to kill processes manually or press Ctrl+C in multiple terminals. Exactly one of `all`, `name`, or `port` must be specified; `force` can be combined with any of them.

**Examples**
```bash
# Stop all servers
ng_stop +all=true

# Stop specific server by name
ng_stop +name=example_single_tool_call

# Stop server on port 8001
ng_stop +port=8001

# Force stop all servers
ng_stop +all=true +force=true
```

**Parameters**

```{list-table}
:header-rows: 1
:widths: 25 10 65

* - Parameter
  - Type
  - Description
* - `all`
  - bool
  - Stop all servers.
* - `name`
  - str
  - Stop specific server by name.
* - `port`
  - int
  - Stop specific server by port.
* - `force`
  - bool
  - Force stop the specified server(s).
```


## Getting Help

For detailed help on any command, run it with `+help=true` or `+h=true`:
Expand Down
117 changes: 98 additions & 19 deletions nemo_gym/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
GlobalConfigDictParserConfig,
get_global_config_dict,
)
from nemo_gym.server_status import StatusCommand
from nemo_gym.server_commands import StatusCommand, StopCommand
from nemo_gym.server_utils import (
HEAD_SERVER_KEY_NAME,
HeadServer,
Expand Down Expand Up @@ -129,6 +129,13 @@ class RunConfig(BaseNeMoGymCLIConfig):
)


class StopConfig(BaseNeMoGymCLIConfig):
    """
    Options for stopping running servers.

    `all`, `name` and `port` each select which server(s) to stop; `force` requests a forced stop.
    """

    all: bool = Field(
        description="Stop all running servers",
        default=False,
    )
    name: Optional[str] = Field(
        description="Stop server by name",
        default=None,
    )
    port: Optional[int] = Field(
        description="Stop server on specific port",
        default=None,
    )
    force: bool = Field(
        description="Force stop unresponsive servers",
        default=False,
    )


class TestConfig(RunConfig):
"""
Test a specific server module by running its pytest suite and optionally validating example data.
Expand Down Expand Up @@ -287,24 +294,44 @@ def poll(self) -> None:
if not self._head_server_thread.is_alive():
raise RuntimeError("Head server finished unexpectedly!")

# Clean up processes that have stopped
processes_to_delete = []
for process_name, process in self._processes.items():
if process.poll() is not None:
proc_out, proc_err = process.communicate()
print_str = f"Process `{process_name}` finished unexpectedly!"

if isinstance(proc_out, bytes):
proc_out = proc_out.decode("utf-8")
print_str = f"""{print_str}
Process `{process_name}` stdout:
{proc_out}
"""
if isinstance(proc_err, bytes):
proc_err = proc_err.decode("utf-8")
print_str = f"""{print_str}
Process `{process_name}` stderr:
{proc_err}"""
poll_result = process.poll()

if poll_result is not None:
# Assume the process exited
exit_code = poll_result

try:
proc_out, proc_err = process.communicate()
except:
proc_out, proc_err = None, None

if exit_code <= 0:
processes_to_delete.append(process_name)
else:
print_str = f"Process `{process_name}` finished unexpectedly!"

if isinstance(proc_out, bytes):
proc_out = proc_out.decode("utf-8")
print_str = f"""{print_str}
Process `{process_name}` stdout:
{proc_out}
"""
if isinstance(proc_err, bytes):
proc_err = proc_err.decode("utf-8")
print_str = f"""{print_str}
Process `{process_name}` stderr:
{proc_err}"""

raise RuntimeError(print_str)

for process_name in processes_to_delete:
del self._processes[process_name]

raise RuntimeError(print_str)
if not self._processes:
raise KeyboardInterrupt()

def wait_for_spinup(self) -> None:
sleep_interval = 3
Expand Down Expand Up @@ -358,10 +385,29 @@ def shutdown(self) -> None:

def run_forever(self) -> None:
async def sleep():
poll_interval = 60
sleep_interval = 1
secs_since_last_poll = 0

# Indefinitely
while True:
self.poll()
await asyncio.sleep(60)
if secs_since_last_poll >= poll_interval:
self.poll()
secs_since_last_poll = 0

alive_count = 0
for proc in self._processes.values():
if proc.poll() is None: # still running
alive_count += 1

if self._processes and alive_count == 0:
print(f"\n{'#' * 100}")
print("All servers stopped. Shutting down head server...")
print(f"{'#' * 100}\n")
return

await asyncio.sleep(sleep_interval)
secs_since_last_poll += sleep_interval

try:
asyncio.run(sleep())
Expand Down Expand Up @@ -947,3 +993,36 @@ def version(): # pragma: no cover
Memory: {sys_info["memory_gb"]} GB"""

print(output)


def stop():  # pragma: no cover
    """
    CLI entry point for `ng_stop` / `nemo_gym_stop`: stop running servers.

    Reads the global config dict, validates it as a `StopConfig`, and requires exactly one
    selector out of `+all=true`, `+name=<name>` or `+port=<port>`. `+force=true` is not a
    selector and may be combined with any of them. The stop results are handed to
    `StopCommand.display_results`.

    Raises:
        SystemExit: with code 1 when no selector, or more than one selector, is given;
            with code 0 once the results have been displayed.
    """
    global_config_dict = get_global_config_dict()
    config = StopConfig.model_validate(global_config_dict)

    stop_cmd = StopCommand()

    # Validation to prevent multiple options from being set.
    # `name` and `port` count as "set" whenever they are not None.
    options_set = sum([config.all, config.name is not None, config.port is not None])

    if options_set == 0:
        print("Error: Must specify one of: '+all=<bool>', '+name=<name>', or '+port=<port>'")
        print("\nUsage:")
        print(" ng_stop +all=true # Stop all servers")
        print(" ng_stop +name=example_single_tool_call # Stop specific server")
        print(" ng_stop +port=8001 # Stop server on port 8001")
        print(" ng_stop +all=true +force=true # Force stop all servers")
        # `raise SystemExit` instead of the `site`-provided `exit()` helper, which is meant
        # for interactive use and is absent when the `site` module is not loaded.
        raise SystemExit(1)

    if options_set > 1:
        print("Error: Can only specify one of: '+all=<bool>', '+name=<name>', or '+port=<port>'")
        raise SystemExit(1)

    # Dispatch with the same `is not None` test used for counting above. A truthiness test
    # would let falsy values (`+port=0`, an empty name) pass validation yet match no branch,
    # leaving `results` unbound.
    if config.all:
        results = stop_cmd.stop_all(config.force)
    elif config.name is not None:
        results = stop_cmd.stop_by_name(config.name, config.force)
    else:
        # Exactly one selector is set and it is neither `all` nor `name`, so it is `port`.
        results = stop_cmd.stop_by_port(config.port, config.force)

    stop_cmd.display_results(results)
    raise SystemExit(0)
Loading