Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ share/
shell/
ssl/
lib64

# VS Code
.vscode/
119 changes: 87 additions & 32 deletions modules/coact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
"""

from loguru import logger
from typing import Any, Generator, Iterator, List, Optional, TYPE_CHECKING
from typing import Any, Iterator, Optional, Sequence, TypedDict
from functools import wraps
from string import Template
import re
import math
import sys
import os

import click
import json
Expand All @@ -29,9 +28,6 @@
from .base import GraphQlMixin, common_options, graphql_options, configure_logging_from_verbose
from .utils.graphql import GraphQlClient

if TYPE_CHECKING:
from typing import IO

# get local timezone
_now = pdl.now()

Expand All @@ -41,6 +37,24 @@
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


class OveragePoint(TypedDict):
facility: str
cluster: str
qos: str
window_mins: int
percentages: Sequence[float]
percent_used: float
held: bool | None
over: bool
change: bool
purchased_nodes: int | None

class FacilityNodeUsage(TypedDict):
    """Substitution values for the ``sacctmgr`` node-limit command template."""

    # Facility whose account is being modified.
    facility: str
    # Cluster on which the facility's account lives.
    cluster: str
    # Value written to ``GrpTRES=node=``: 0 blocks jobs, -1 is treated as
    # unlimited by the caller, and a positive count restores that many nodes.
    nodes: int


def parse_datetime(value: Any, timezone=_now.timezone, force_tz: bool = False):
"""Parse various datetime formats into pendulum DateTime objects."""
dt = None
Expand Down Expand Up @@ -817,7 +831,20 @@ def slurm_recalculate(ctx, date, verbose, username, password_file):
@click.option('--influxdb-password', default=None, help='InfluxDB password')
@click.option('--influxdb-database', default='coact', help='InfluxDB database name (default: coact)')
@click.pass_context
def overage(ctx, date, verbose, username, password_file, windows, threshold, dry_run, influxdb_url, influxdb_username, influxdb_password, influxdb_database):
def overage(
ctx,
date: str,
verbose: int,
username: str,
password_file: str,
windows: Sequence[int],
threshold: float,
dry_run: bool,
influxdb_url: str,
influxdb_username: str | None,
influxdb_password: str | None,
influxdb_database: str
):
"""Recalculate the usage numbers from slurm jobs in Coact."""
configure_logging_from_verbose(verbose)
ctx.obj['verbose'] = verbose
Expand All @@ -837,7 +864,7 @@ def overage(ctx, date, verbose, username, password_file, windows, threshold, dry
data.append(point)
# Toggle job blocking only if held state needs to change
if point['held'] is not None and point['change']:
toggle_job_blocking(execute=not dry_run, **point)
toggle_job_blocking(execute=not dry_run, point=point)

# Bulk send all points to InfluxDB using raw requests
if influxdb_url is not None and len(data) > 0:
Expand Down Expand Up @@ -869,18 +896,42 @@ def overage(ctx, date, verbose, username, password_file, windows, threshold, dry
logger.error(f"Failed to send data to InfluxDB: {e}")


def toggle_job_blocking(execute: bool = False, **xargs) -> bool:
def toggle_job_blocking(point: OveragePoint, execute: bool = False) -> bool:
"""Enable/disable job blocking for overaged allocations."""
template = Template("sacctmgr modify -i account name=$facility:_regular_@$cluster set GrpTRES=node=$nodes")
xargs["nodes"] = 0 if xargs["over"] else -1
logger.trace(f"{xargs['facility']} job holding must be toggled... execute={execute}")
cmd = template.safe_substitute(**xargs)

# Determine node count based on blocking state
if point['over']:
# Blocking: set to 0
nodes = 0
else:
# Unblocking: use purchased nodes or fallback to unlimited
nodes = point.get('purchased_nodes', -1)
if nodes is None:
nodes = -1
logger.warning(f"No purchased node count available for {point['facility']}@{point['cluster']}, using unlimited")
elif nodes > 0:
logger.info(f"Restoring {nodes} nodes for {point['facility']}@{point['cluster']}")
else:
logger.warning(f"Invalid node count {nodes} for {point['facility']}@{point['cluster']}, using unlimited")
nodes = -1

facility_usage = FacilityNodeUsage(
facility=point['facility'],
cluster=point['cluster'],
nodes=nodes
)

logger.info(f"Job blocking toggle for {facility_usage['facility']}@{facility_usage['cluster']}: nodes={nodes} (over={point['over']}, execute={execute})")
cmd = template.safe_substitute(**facility_usage)
logger.info(f"Command: {cmd}")

if execute:
try:
for l in subprocess.check_output(cmd.split()).split(b"\n"):
logger.trace(f"{l}")
result = subprocess.check_output(cmd.split())
for line in result.split(b"\n"):
if line.strip():
logger.debug(f"sacctmgr output: {line.decode().strip()}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to toggle job blocking: {e}")
return False
Expand All @@ -898,7 +949,7 @@ def __init__(self, username: str, password_file: str, windows: list, threshold:
self.threshold = threshold
self.dry_run = dry_run

def get(self, date: str) -> Iterator[dict]:
def get(self, date: str) -> Iterator[OveragePoint]:
"""Run the overage calculation process."""
self.back_channel = self.connect_graph_ql(
username=self.username,
Expand All @@ -913,7 +964,7 @@ def get(self, date: str) -> Iterator[dict]:
def get_data(self) -> dict:
"""Fetch usage data from GraphQL."""
per_window_template = Template(
"""_$key: facilityRecentComputeUsage(pastMinutes:$minutes) { cluster: clustername, facility, percentUsed }"""
"""_$key: facilityRecentComputeUsage(pastMinutes:$minutes) { cluster: clustername, facility, percentUsed, purchasedNodes }"""
)
logger.trace(f"Fetching windows {self.windows}")
all_windows = []
Expand All @@ -940,16 +991,19 @@ def format_data(self, result: dict) -> dict:
current[f] = {}
for item in k["allocs"]:
c = item["cluster"].lower()
current[f][c] = {"held": None, "percentUsed": []}
current[f][c] = {"held": None, "percentUsed": [], "purchasedNodes": None}
del result["repos"]

for time, array in result.items():
logger.trace(f"Looking at time {time} with {array}")
for a in array:
f = a["facility"].lower()
c = a["cluster"].lower()
logger.trace(f"Setting {f} {c} to {a['percentUsed']}")
logger.trace(f"Setting {f} {c} to {a['percentUsed']} (nodes: {a.get('purchasedNodes')})")
current[f][c]["percentUsed"].append(int(a["percentUsed"]))
# Store purchased nodes (use the value from any time window since it's constant)
if a.get("purchasedNodes") is not None and current[f][c]["purchasedNodes"] is None:
current[f][c]["purchasedNodes"] = a["purchasedNodes"]

logger.trace(f"Overages: {current}")

Expand Down Expand Up @@ -980,14 +1034,15 @@ def format_data(self, result: dict) -> dict:

return current

def overaged(self, data: dict, threshold: float = 100.0) -> Iterator[dict]:
def overaged(self, data: dict, threshold: float = 100.0) -> Iterator[OveragePoint]:
"""Check which allocations are over threshold and yield point objects."""
logger.trace(f"Determining overages with threshold {threshold}%...")
for fac, d in data.items():
logger.trace(f"Looping facility {fac}...")
for clust, m in d.items():
percentages = m["percentUsed"]
logger.trace(f"Sublooping {clust}, {percentages}")
purchased_nodes = m.get("purchasedNodes")
logger.trace(f"Sublooping {clust}, {percentages}, purchased_nodes: {purchased_nodes}")
over = False
for p in percentages:
if p >= threshold:
Expand All @@ -998,23 +1053,23 @@ def overaged(self, data: dict, threshold: float = 100.0) -> Iterator[dict]:
if m["held"] is None:
change = False
if len(percentages) > 0:
logger.info(f"{fac:16} {clust:12} qos=regular held={m['held'] if m['held'] is not None else '-':1} over={over:1} change={change:1} {values}")
logger.info(f"{fac:16} {clust:12} qos=regular held={m['held'] if m['held'] is not None else '-':1} over={over:1} change={change:1} nodes={purchased_nodes or 'N/A':>5} {values}")

# Yield a point for each window
for idx, pct in enumerate(percentages):
window_duration = self.windows[idx] if idx < len(self.windows) else idx
yield {
"facility": fac.lower(),
"cluster": clust.lower(),
"qos": "regular",
"window_mins": window_duration,
"percentages": percentages,
"percent_used": pct,
"held": bool(m["held"]) if m["held"] is not None else None,
"over": bool(over),
"change": bool(change),
}

yield OveragePoint(
facility=fac.lower(),
cluster=clust.lower(),
qos="regular",
window_mins=window_duration,
percentages=percentages,
percent_used=pct,
held=bool(m["held"]) if m["held"] is not None else None,
over=bool(over),
change=bool(change),
purchased_nodes=purchased_nodes
)



Expand Down