Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions beeflow/common/cwl/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ class SlurmRequirement:
qos: str = None
reservation: str = None
load_from_file: str = None
sbatch: str = None

def dump(self):
"""Dump MPI requirement to dictionary."""
Expand All @@ -400,6 +401,8 @@ def dump(self):
sched_dump['beeflow:SlurmRequirement']['signal'] = self.signal
if self.load_from_file:
sched_dump['beeflow:SlurmRequirement']['load_from_file'] = self.load_from_file
if self.sbatch:
sched_dump['beeflow:SlurmRequirement']['sbatch'] = self.sbatch
return sched_dump

def __repr__(self):
Expand Down
39 changes: 39 additions & 0 deletions beeflow/common/cwl/examples/comd_sbatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""COMD driver for CWL generator."""
import pathlib
from beeflow.common.cwl.workflow import (Task, Input, Output, MPI, Charliecloud,
Workflow, Slurm, Script)


def main():
"""Recreate the COMD workflow."""
# Specifies the comd task
comd_task = Task(name="comd",
base_command="/CoMD/bin/CoMD-mpi -e",
stdout="comd.txt",
stderr="comd.err",
# list of Input objects
# The 2s and 40s are the actual value we want these to be
# this is how one sets input parameters. Prefix is just the
inputs=[Input("i", "int", 2, prefix="-i"),
Input("j", "int", 2, prefix="-j"),
Input("k", "int", 2, prefix="-k"),
Input("x", "int", 40, prefix="-x"),
Input("y", "int", 40, prefix="-y"),
Input("z", "int", 40, prefix="-z"),
Input("pot_dir", "string", "/CoMD/pots", prefix="--potDir")],
# List of Output objects.
# In this case we just have a file that represents stdout.
# The important part here is the source field that states
# this output comes from this task
outputs=[Output("comd_stdout", "File", source="comd/comd_stdout")],
hints=[
Script(pre_script="comd_pre.sh"),
# Pass an sbatch script
Slurm(sbatch="run.sh"),
])
workflow = Workflow("comd", [comd_task])
workflow.dump_wf("comd")
workflow.dump_yaml("comd")

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion beeflow/common/cwl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def requirement(self):
return SlurmRequirement(time_limit=self.time_limit, account=self.account,
partition=self.partition, qos=self.qos, reservation=self.reservation,
signal=self.signal,
load_from_file=self.load_from_file)
load_from_file=self.load_from_file, sbatch=self.sbatch)


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions beeflow/common/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def parse_requirements(self, requirements, as_hints=False):
# Load in the dockerfile at parse time
if "dockerFile" in items:
self._read_requirement_file("dockerFile", items)
# Load the sbatch if added at parse time. This might not work well?
if "sbatch" in items:
self._read_requirement_file("sbatch", items)
# Load in pre/post scripts and make sure shell option is defined in cwl file
if "pre_script" in items and items["enabled"]:
if "shell" in items:
Expand Down
11 changes: 11 additions & 0 deletions beeflow/common/worker/slurm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import getpass
import requests_unixsocket
import requests
import string

from beeflow.common import log as bee_logging
import beeflow.common.worker.utils as worker_utils
Expand Down Expand Up @@ -172,6 +173,16 @@ def build_text(self, task):
# Get task requirements
requirements = self.get_task_requirements(task)

sbatch_script = task.get_requirement('beeflow:SlurmRequirement', 'sbatch')

# If we have an sbatch script defined just return that
if sbatch_script:
stdout_path, stderr_path = self.resolve_stdout_stderr(task)
sbatch_template = string.Template(sbatch_script)
changes = {"output":stdout_path, "error":stderr_path}
sbatch_script = sbatch_template.safe_substitute(changes)
#log.info("SBATCH")
return sbatch_script
pre_script, post_script = None, None
if requirements['scripts_enabled']:
# We use StringIO here to properly break the script up into lines with readlines
Expand Down
2 changes: 2 additions & 0 deletions beeflow/common/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def task_save_path(self, task):

def write_script(self, task):
"""Build task script; returns filename of script."""
# If the user has provided an sbatch script this will just return
# the contents of that script as a string
task_text = self.build_text(task)
task_archive_dir = self.task_save_path(task)
os.makedirs(task_archive_dir,exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.