diff --git a/beeflow/common/cwl/cwl.py b/beeflow/common/cwl/cwl.py
index b504e3a41..e6f5f84c9 100644
--- a/beeflow/common/cwl/cwl.py
+++ b/beeflow/common/cwl/cwl.py
@@ -382,6 +382,7 @@ class SlurmRequirement:
qos: str = None
reservation: str = None
load_from_file: str = None
+ sbatch: str = None
def dump(self):
"""Dump MPI requirement to dictionary."""
@@ -400,6 +401,8 @@ def dump(self):
sched_dump['beeflow:SlurmRequirement']['signal'] = self.signal
if self.load_from_file:
sched_dump['beeflow:SlurmRequirement']['load_from_file'] = self.load_from_file
+ if self.sbatch:
+ sched_dump['beeflow:SlurmRequirement']['sbatch'] = self.sbatch
return sched_dump
def __repr__(self):
diff --git a/beeflow/common/cwl/examples/comd_sbatch.py b/beeflow/common/cwl/examples/comd_sbatch.py
new file mode 100644
index 000000000..f37963bde
--- /dev/null
+++ b/beeflow/common/cwl/examples/comd_sbatch.py
@@ -0,0 +1,39 @@
+"""COMD driver for CWL generator."""
+import pathlib
+from beeflow.common.cwl.workflow import (Task, Input, Output, MPI, Charliecloud,
+ Workflow, Slurm, Script)
+
+
+def main():
+ """Recreate the COMD workflow."""
+ # Specifies the comd task
+ comd_task = Task(name="comd",
+ base_command="/CoMD/bin/CoMD-mpi -e",
+ stdout="comd.txt",
+ stderr="comd.err",
+ # list of Input objects
+ # The 2s and 40s are the actual values we want these inputs to take;
+ # this is how one sets input parameters. Prefix is the command-line flag.
+ inputs=[Input("i", "int", 2, prefix="-i"),
+ Input("j", "int", 2, prefix="-j"),
+ Input("k", "int", 2, prefix="-k"),
+ Input("x", "int", 40, prefix="-x"),
+ Input("y", "int", 40, prefix="-y"),
+ Input("z", "int", 40, prefix="-z"),
+ Input("pot_dir", "string", "/CoMD/pots", prefix="--potDir")],
+ # List of Output objects.
+ # In this case we just have a file that represents stdout.
+ # The important part here is the source field that states
+ # this output comes from this task
+ outputs=[Output("comd_stdout", "File", source="comd/comd_stdout")],
+ hints=[
+ Script(pre_script="comd_pre.sh"),
+ # Pass an sbatch script
+ Slurm(sbatch="run.sh"),
+ ])
+ workflow = Workflow("comd", [comd_task])
+ workflow.dump_wf("comd")
+ workflow.dump_yaml("comd")
+
+if __name__ == "__main__":
+ main()
diff --git a/beeflow/common/cwl/workflow.py b/beeflow/common/cwl/workflow.py
index a18fc3789..5f99ec9b2 100644
--- a/beeflow/common/cwl/workflow.py
+++ b/beeflow/common/cwl/workflow.py
@@ -91,7 +91,7 @@ def requirement(self):
return SlurmRequirement(time_limit=self.time_limit, account=self.account,
partition=self.partition, qos=self.qos, reservation=self.reservation,
signal=self.signal,
- load_from_file=self.load_from_file)
+ load_from_file=self.load_from_file, sbatch=self.sbatch)
@dataclass
diff --git a/beeflow/common/parser/parser.py b/beeflow/common/parser/parser.py
index e0f993928..75d809153 100644
--- a/beeflow/common/parser/parser.py
+++ b/beeflow/common/parser/parser.py
@@ -446,6 +446,9 @@ def parse_requirements(self, requirements, as_hints=False):
# Load in the dockerfile at parse time
if "dockerFile" in items:
self._read_requirement_file("dockerFile", items)
+ # Load the sbatch script file contents at parse time (same handling as dockerFile).
+ if "sbatch" in items:
+ self._read_requirement_file("sbatch", items)
# Load in pre/post scripts and make sure shell option is defined in cwl file
if "pre_script" in items and items["enabled"]:
if "shell" in items:
diff --git a/beeflow/common/worker/slurm_worker.py b/beeflow/common/worker/slurm_worker.py
index 168c04816..784baec5a 100644
--- a/beeflow/common/worker/slurm_worker.py
+++ b/beeflow/common/worker/slurm_worker.py
@@ -11,6 +11,7 @@
import getpass
import requests_unixsocket
import requests
+import string
from beeflow.common import log as bee_logging
import beeflow.common.worker.utils as worker_utils
@@ -172,6 +173,16 @@ def build_text(self, task):
# Get task requirements
requirements = self.get_task_requirements(task)
+ sbatch_script = task.get_requirement('beeflow:SlurmRequirement', 'sbatch')
+
+ # If an sbatch script is defined, substitute paths and return it directly
+ if sbatch_script:
+ stdout_path, stderr_path = self.resolve_stdout_stderr(task)
+ sbatch_template = string.Template(sbatch_script)
+ changes = {"output":stdout_path, "error":stderr_path}
+ sbatch_script = sbatch_template.safe_substitute(changes)
+ # Return the substituted script verbatim as the batch submission text
+ return sbatch_script
pre_script, post_script = None, None
if requirements['scripts_enabled']:
# We use StringIO here to properly break the script up into lines with readlines
diff --git a/beeflow/common/worker/worker.py b/beeflow/common/worker/worker.py
index 40888d9ea..d23455165 100644
--- a/beeflow/common/worker/worker.py
+++ b/beeflow/common/worker/worker.py
@@ -75,6 +75,8 @@ def task_save_path(self, task):
def write_script(self, task):
"""Build task script; returns filename of script."""
+ # If the user has provided an sbatch script this will just return
+ # the contents of that script as a string
task_text = self.build_text(task)
task_archive_dir = self.task_save_path(task)
os.makedirs(task_archive_dir,exist_ok=True)
diff --git a/coverage.svg b/coverage.svg
index 5b4170033..9e299d059 100644
--- a/coverage.svg
+++ b/coverage.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file