Skip to content

Commit 76c621f

Browse files
Enhance task decorator: Add support for pre-script and post-script in ProActive Scheduler (#62)
1 parent ba7b850 commit 76c621f

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

proactive/decorators.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ class TaskDecorator:
88
def __init__(self, language):
99
self.language = language
1010

11-
def __call__(self, name=None, depends_on=None, runtime_env=None, virtual_env=None, input_files=None, output_files=None):
11+
def __call__(self, name=None, depends_on=None, runtime_env=None, virtual_env=None, input_files=None, output_files=None, prescript=None, postscript=None):
1212
def decorator(func):
13-
return task(name=name, depends_on=depends_on, language=self.language, runtime_env=runtime_env, virtual_env=virtual_env, input_files=input_files, output_files=output_files)(func)
13+
return task(name=name, depends_on=depends_on, language=self.language, runtime_env=runtime_env, virtual_env=virtual_env, input_files=input_files, output_files=output_files, prescript=prescript, postscript=postscript)(func)
1414
return decorator
1515

1616
class LoopDecorator:
@@ -80,7 +80,26 @@ def decorator(func):
8080

8181
branch = BranchDecorator()
8282

83-
def task(name=None, depends_on=None, language='Python', runtime_env=None, virtual_env=None, input_files=None, output_files=None):
83+
class ScriptDecorator:
84+
def __init__(self):
85+
self.languages = [
86+
'python', 'groovy', 'bash', 'shell', 'r', 'powershell', 'perl', 'ruby',
87+
'windows_cmd', 'javascript', 'scalaw', 'docker_compose', 'dockerfile',
88+
'kubernetes', 'php', 'vbscript', 'jython'
89+
]
90+
for lang in self.languages:
91+
setattr(self, lang, self.create_decorator(lang))
92+
93+
def create_decorator(self, language):
94+
def decorator(func):
95+
@wraps(func)
96+
def wrapper():
97+
return func()
98+
wrapper.language = language
99+
return wrapper
100+
return decorator
101+
102+
def task(name=None, depends_on=None, language='Python', runtime_env=None, virtual_env=None, input_files=None, output_files=None, prescript=None, postscript=None):
84103
"""
85104
Decorator to define a ProActive task.
86105
@@ -98,6 +117,8 @@ def task(name=None, depends_on=None, language='Python', runtime_env=None, virtua
98117
- requirements_file (str): File containing the list of requirements.
99118
:param input_files: Optional list of files to transfer to the task environment.
100119
:param output_files: Optional list of output files to be retrieved after task execution.
120+
:param prescript: Optional pre-script function to execute before the task.
121+
:param postscript: Optional post-script function to execute after the task.
101122
"""
102123
def decorator(func):
103124
@wraps(func)
@@ -114,6 +135,8 @@ def wrapper(*args, **kwargs):
114135
'VirtualEnv': virtual_env,
115136
'InputFiles': input_files,
116137
'OutputFiles': output_files,
138+
'Prescript': prescript,
139+
'Postscript': postscript,
117140
'IsLoopStart': getattr(func, '_is_loop_start', False),
118141
'IsLoopEnd': getattr(func, '_is_loop_end', False),
119142
'LoopCriteria': getattr(func, '_loop_criteria', None),
@@ -150,6 +173,10 @@ def wrapper(*args, **kwargs):
150173
task.vbscript = TaskDecorator(language=ProactiveScriptLanguage().vbscript())
151174
task.jython = TaskDecorator(language=ProactiveScriptLanguage().jython())
152175

176+
# Define pre-script and post-script as part of the task module
177+
task.prescript = ScriptDecorator()
178+
task.postscript = ScriptDecorator()
179+
153180
def job(name, print_job_output=True):
154181
"""
155182
Decorator to define a ProActive job.
@@ -205,7 +232,7 @@ def wrapper(*args, **kwargs):
205232
print(f"Exception details: {e}")
206233
continue
207234

208-
# Set the runtime environment if provided
235+
# Set the runtime environment if provided
209236
# Parameters:
210237
# - type (str): Specifies the type of container technology to use for running the task.
211238
# Options include "docker", "podman", "singularity", or any other value to indicate a non-containerized execution.
@@ -264,6 +291,18 @@ def wrapper(*args, **kwargs):
264291
for file in task_def['OutputFiles']:
265292
task.addOutputFile(file)
266293

294+
# Set pre-script if provided
295+
if task_def['Prescript']:
296+
pre_script = gateway.createPreScript(getattr(ProactiveScriptLanguage(), task_def['Prescript'].language)())
297+
pre_script.setImplementation(task_def['Prescript']())
298+
task.setPreScript(pre_script)
299+
300+
# Set post-script if provided
301+
if task_def['Postscript']:
302+
post_script = gateway.createPostScript(getattr(ProactiveScriptLanguage(), task_def['Postscript'].language)())
303+
post_script.setImplementation(task_def['Postscript']())
304+
task.setPostScript(post_script)
305+
267306
job.addTask(task)
268307
task_objects[task_def['Name']] = task
269308

0 commit comments

Comments
 (0)