diff --git a/beeflow/common/cwl/cwl.py b/beeflow/common/cwl/cwl.py index b504e3a41..e6f5f84c9 100644 --- a/beeflow/common/cwl/cwl.py +++ b/beeflow/common/cwl/cwl.py @@ -382,6 +382,7 @@ class SlurmRequirement: qos: str = None reservation: str = None load_from_file: str = None + sbatch: str = None def dump(self): """Dump MPI requirement to dictionary.""" @@ -400,6 +401,8 @@ def dump(self): sched_dump['beeflow:SlurmRequirement']['signal'] = self.signal if self.load_from_file: sched_dump['beeflow:SlurmRequirement']['load_from_file'] = self.load_from_file + if self.sbatch: + sched_dump['beeflow:SlurmRequirement']['sbatch'] = self.sbatch return sched_dump def __repr__(self): diff --git a/beeflow/common/cwl/examples/comd_sbatch.py b/beeflow/common/cwl/examples/comd_sbatch.py new file mode 100644 index 000000000..f37963bde --- /dev/null +++ b/beeflow/common/cwl/examples/comd_sbatch.py @@ -0,0 +1,39 @@ +"""COMD driver for CWL generator.""" +import pathlib +from beeflow.common.cwl.workflow import (Task, Input, Output, MPI, Charliecloud, + Workflow, Slurm, Script) + + +def main(): + """Recreate the COMD workflow.""" + # Specifies the comd task + comd_task = Task(name="comd", + base_command="/CoMD/bin/CoMD-mpi -e", + stdout="comd.txt", + stderr="comd.err", + # list of Input objects + # The 2s and 40s are the actual value we want these to be + # this is how one sets input parameters. Prefix is just the + inputs=[Input("i", "int", 2, prefix="-i"), + Input("j", "int", 2, prefix="-j"), + Input("k", "int", 2, prefix="-k"), + Input("x", "int", 40, prefix="-x"), + Input("y", "int", 40, prefix="-y"), + Input("z", "int", 40, prefix="-z"), + Input("pot_dir", "string", "/CoMD/pots", prefix="--potDir")], + # List of Output objects. + # In this case we just have a file that represents stdout. + # The important part here is the source field that states + # this output comes from this task + outputs=[Output("comd_stdout", "File", source="comd/comd_stdout")], + hints=[ + Script(pre_script="comd_pre.sh"), + # Pass an sbatch script + Slurm(sbatch="run.sh"), + ]) + workflow = Workflow("comd", [comd_task]) + workflow.dump_wf("comd") + workflow.dump_yaml("comd") + +if __name__ == "__main__": + main() diff --git a/beeflow/common/cwl/workflow.py b/beeflow/common/cwl/workflow.py index a18fc3789..5f99ec9b2 100644 --- a/beeflow/common/cwl/workflow.py +++ b/beeflow/common/cwl/workflow.py @@ -91,7 +91,7 @@ def requirement(self): return SlurmRequirement(time_limit=self.time_limit, account=self.account, partition=self.partition, qos=self.qos, reservation=self.reservation, signal=self.signal, - load_from_file=self.load_from_file) + load_from_file=self.load_from_file, sbatch=self.sbatch) @dataclass diff --git a/beeflow/common/parser/parser.py b/beeflow/common/parser/parser.py index e0f993928..75d809153 100644 --- a/beeflow/common/parser/parser.py +++ b/beeflow/common/parser/parser.py @@ -446,6 +446,9 @@ def parse_requirements(self, requirements, as_hints=False): # Load in the dockerfile at parse time if "dockerFile" in items: self._read_requirement_file("dockerFile", items) + # Load the sbatch if added at parse time. This might not work well? + if "sbatch" in items: + self._read_requirement_file("sbatch", items) # Load in pre/post scripts and make sure shell option is defined in cwl file if "pre_script" in items and items["enabled"]: if "shell" in items: diff --git a/beeflow/common/worker/slurm_worker.py b/beeflow/common/worker/slurm_worker.py index 168c04816..784baec5a 100644 --- a/beeflow/common/worker/slurm_worker.py +++ b/beeflow/common/worker/slurm_worker.py @@ -11,6 +11,7 @@ import getpass import requests_unixsocket import requests +import string from beeflow.common import log as bee_logging import beeflow.common.worker.utils as worker_utils @@ -172,6 +173,16 @@ def build_text(self, task): # Get task requirements requirements = self.get_task_requirements(task) + sbatch_script = task.get_requirement('beeflow:SlurmRequirement', 'sbatch') + + # If we have an sbatch script defined just return that + if sbatch_script: + stdout_path, stderr_path = self.resolve_stdout_stderr(task) + sbatch_template = string.Template(sbatch_script) + changes = {"output":stdout_path, "error":stderr_path} + sbatch_script = sbatch_template.safe_substitute(changes) + #log.info("SBATCH") + return sbatch_script pre_script, post_script = None, None if requirements['scripts_enabled']: # We use StringIO here to properly break the script up into lines with readlines diff --git a/beeflow/common/worker/worker.py b/beeflow/common/worker/worker.py index 40888d9ea..d23455165 100644 --- a/beeflow/common/worker/worker.py +++ b/beeflow/common/worker/worker.py @@ -75,6 +75,8 @@ def task_save_path(self, task): def write_script(self, task): """Build task script; returns filename of script.""" + # If the user has provided an sbatch script this will just return + # the contents of that script as a string task_text = self.build_text(task) task_archive_dir = self.task_save_path(task) os.makedirs(task_archive_dir,exist_ok=True) diff --git a/coverage.svg b/coverage.svg index 5b4170033..9e299d059 100644 --- a/coverage.svg +++ b/coverage.svg @@ -1 +1 @@ -coverage: 71.51%coverage71.51% \ No newline at end of file +coverage: 71.47%coverage71.47% \ No newline at end of file