Skip to content

Commit a6ebdd3

Browse files
committed
Refactor task update methods in TaskRepository
- Consolidated multiple task update methods (status, inputs, dependencies, name, priority, params, schemas) into a single method `update_task` that accepts arbitrary fields. - Updated related type hints and documentation to reflect changes. - Modified all references to the old update methods in tests and other parts of the codebase to use the new `update_task` method. - Removed unused JSON extraction functions from helpers and replaced them with a more efficient JSON finder. - Updated tests to ensure compatibility with the new task update method and removed redundant tests for the deleted JSON extraction functions.
1 parent 6527447 commit a6ebdd3

27 files changed

+180
-626
lines changed

docs/api/python.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,7 @@ Database operations for tasks.
306306
- `get_task_by_id(task_id)`: Get task by ID
307307
- `get_root_task(task)`: Get root task
308308
- `build_task_tree(task)`: Build task tree from task
309-
- `update_task_status(task_id, status, ...)`: Update task status and related fields (error, result, progress, timestamps)
310-
- `update_task_inputs(task_id, inputs)`: Update task inputs
311-
- `update_task_dependencies(task_id, dependencies)`: Update task dependencies (with validation)
312-
- `update_task_name(task_id, name)`: Update task name
313-
- `update_task_priority(task_id, priority)`: Update task priority
314-
- `update_task_params(task_id, params)`: Update executor parameters
315-
- `update_task_schemas(task_id, schemas)`: Update validation schemas
309+
- `update_task(task_id, **kwarg)` Update task fields
316310
- `delete_task(task_id)`: Physically delete a task from the database
317311
- `get_all_children_recursive(task_id)`: Recursively get all child tasks (including grandchildren)
318312
- `find_dependent_tasks(task_id)`: Find all tasks that depend on a given task (reverse dependencies)
@@ -513,7 +507,7 @@ async def my_hook(task):
513507
repo = get_hook_repository()
514508
if repo:
515509
# Modify task fields
516-
await repo.update_task_priority(task.id, 10)
510+
await repo.update_task(task.id, priority=10)
517511
# Query other tasks
518512
pending = await repo.get_tasks_by_status("pending")
519513
```

docs/api/quick-reference.md

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -113,33 +113,12 @@ task = await task_manager.task_repository.get_task_by_id(task_id)
113113

114114
```python
115115
# Update status and related fields
116-
await task_repository.update_task_status(
116+
await task_repository.update_task(
117117
task_id,
118118
status="completed",
119119
result={"data": "result"},
120120
progress=1.0
121121
)
122-
123-
# Update inputs
124-
await task_repository.update_task_inputs(task_id, {"key": "new_value"})
125-
126-
# Update name
127-
await task_repository.update_task_name(task_id, "New Task Name")
128-
129-
# Update priority
130-
await task_repository.update_task_priority(task_id, 2)
131-
132-
# Update params
133-
await task_repository.update_task_params(task_id, {"executor_id": "new_executor"})
134-
135-
# Update schemas
136-
await task_repository.update_task_schemas(task_id, {"input_schema": {...}})
137-
138-
# Update dependencies (only for pending tasks, with validation)
139-
await task_repository.update_task_dependencies(
140-
task_id,
141-
[{"id": "dep-task-id", "required": True}]
142-
)
143122
```
144123

145124
**Critical Field Validation:**
@@ -389,8 +368,7 @@ async def modify_task_with_db(task):
389368
repo = get_hook_repository()
390369
if repo:
391370
# Update task fields
392-
await repo.update_task_name(task.id, "New Name")
393-
await repo.update_task_priority(task.id, 10)
371+
await repo.update_task(task.id, name="New Name")
394372

395373
# Query other tasks
396374
pending_tasks = await repo.get_tasks_by_status("pending")

docs/architecture/task-tree-lifecycle.md

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ async def execute_task_tree(
184184
task = await self.task_repository.get_task_by_id(task_id)
185185

186186
# Update to in_progress
187-
await self.task_repository.update_task_status(
187+
await self.task_repository.update_task(
188188
task_id=task_id,
189189
status="in_progress",
190190
started_at=datetime.now(timezone.utc)
@@ -197,7 +197,7 @@ async def execute_task_tree(
197197
resolved_inputs = await resolve_task_dependencies(task, self.task_repository)
198198

199199
if resolved_inputs != (task.inputs or {}):
200-
await self.task_repository.update_task_inputs(task_id, resolved_inputs)
200+
await self.task_repository.update_task(task_id, inputs=resolved_inputs)
201201
```
202202

203203
3. **Pre-Hook Execution** (lines 761-794)
@@ -210,7 +210,7 @@ async def execute_task_tree(
210210

211211
# Auto-persist if inputs changed
212212
if inputs_after_pre_hooks != inputs_before_pre_hooks:
213-
await self.task_repository.update_task_inputs(task_id, inputs_to_save)
213+
await self.task_repository.update_task(task_id, inputs=inputs_to_save)
214214
task = await self.task_repository.get_task_by_id(task_id) # Refresh
215215
```
216216

@@ -223,7 +223,7 @@ async def execute_task_tree(
223223
5. **Status Update and Cleanup** (lines 832-856)
224224
```python
225225
# Update task status
226-
await self.task_repository.update_task_status(
226+
await self.task_repository.update_task(
227227
task_id=task_id,
228228
status="completed",
229229
progress=1.0,
@@ -267,11 +267,11 @@ TaskExecutor.execute_task_tree (session created)
267267
│ ├── on_tree_started hooks
268268
│ ├── _execute_task_tree_recursive
269269
│ │ ├── _execute_single_task (task 1)
270-
│ │ │ ├── update_task_status
270+
│ │ │ ├── update_task
271271
│ │ │ ├── resolve_task_dependencies
272272
│ │ │ ├── pre-hooks (can modify task.inputs)
273273
│ │ │ ├── execute task
274-
│ │ │ ├── update_task_status
274+
│ │ │ ├── update_task
275275
│ │ │ └── post-hooks
276276
│ │ ├── _execute_single_task (task 2)
277277
│ │ └── ...
@@ -285,12 +285,6 @@ TaskExecutor.execute_task_tree (session created)
285285

286286
**Per-Operation Commits**:
287287
- Each `TaskRepository` method commits its own transaction
288-
- Example in `update_task_status()`:
289-
```python
290-
await self.db.commit()
291-
flag_modified(task, "result") # For JSON fields
292-
await self.db.refresh(task) # Ensure fresh data
293-
```
294288

295289
**Benefits**:
296290
- No cascading rollbacks across the entire task tree
@@ -439,7 +433,7 @@ async def update_metadata(task, inputs, result):
439433
repo = get_hook_repository()
440434

441435
# Explicit repository call required for non-inputs fields
442-
await repo.update_task_params(
436+
await repo.update_task(
443437
task_id=task.id,
444438
params={"processed_at": datetime.now().isoformat()}
445439
)
@@ -458,7 +452,7 @@ async def aggregate_results(root_task, status):
458452

459453
# Aggregate and update
460454
total_tokens = sum(t.result.get("token_usage", 0) for t in all_tasks if t.result)
461-
await repo.update_task_result(
455+
await repo.update_task(
462456
task_id=root_task.id,
463457
result={"total_tokens": total_tokens}
464458
)
@@ -474,7 +468,7 @@ try:
474468
task_result = await self._execute_task_with_schemas(task, final_inputs)
475469
except Exception as e:
476470
# Update task status to failed
477-
await self.task_repository.update_task_status(
471+
await self.task_repository.update_task(
478472
task_id=task_id,
479473
status="failed",
480474
error=str(e),

docs/development/design/cli-design.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ Return: {
239239
```
240240
1. CLI/API: tasks cancel task-123
241241
242-
2. TaskRepository.update_task_status(task_id, status="cancelled")
242+
2. TaskRepository.update_task(task_id, status="cancelled")
243243
244244
3. TaskManager._execute_single_task() checks status at multiple points:
245245
- Before starting execution

docs/development/extending.md

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,11 @@ async def modify_task_fields(task):
261261
return # Not in hook context
262262

263263
# Modify task fields explicitly
264-
await repo.update_task_name(task.id, "Modified Name")
265-
await repo.update_task_priority(task.id, 10)
264+
await repo.update_task(task.id, name="Modified Name", priority=10)
266265

267266
# Query other tasks
268267
pending_tasks = await repo.get_tasks_by_status("pending")
269268
print(f"Found {len(pending_tasks)} pending tasks")
270-
271-
# Modify dependency tasks
272-
if task.dependencies:
273-
dep_id = task.dependencies[0]["id"]
274-
await repo.update_task_priority(dep_id, 100)
275269
```
276270

277271
**Key Points:**
@@ -292,11 +286,7 @@ Hook database access is managed through a context that spans the entire task tre
292286
For detailed lifecycle information, see [Task Tree Execution Lifecycle](../architecture/task-tree-lifecycle.md).
293287

294288
**Available Hook Repository Methods:**
295-
- `update_task_name(task_id, name)` - Update task name
296-
- `update_task_priority(task_id, priority)` - Update task priority
297-
- `update_task_status(task_id, status)` - Update task status
298-
- `update_task_params(task_id, params)` - Update task params
299-
- `update_task_inputs(task_id, inputs)` - Update task inputs (usually not needed, direct modification is auto-saved)
289+
- `update_task(task_id, **kwarg)` - Update task (usually not needed, direct modification is auto-saved)
300290
- `get_task_by_id(task_id)` - Query task by ID
301291
- `get_tasks_by_status(status)` - Query tasks by status
302292
- And all other TaskRepository methods...

docs/guides/best-practices.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ async def adjust_priority_by_load(task):
384384

385385
# Adjust priority if system is overloaded
386386
if pending_count > 100:
387-
await repo.update_task_priority(task.id, task.priority + 1)
387+
await repo.update_task(task.id, priority=task.priority + 1)
388388
```
389389

390390
**Remember:**

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ dev = [
135135
"mypy>=1.5.0",
136136
"build>=1.0.0",
137137
"twine>=4.0.0",
138+
"jsonfinder>=0.4.0",
138139
"pre-commit>=3.0.0",
139140
]
140141

src/apflow/api/routes/tasks.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,7 @@ async def handle_task_update(self, params: dict, request: Request, request_id: s
12881288
# Update status and related fields if provided
12891289
status = params.get("status")
12901290
if status is not None:
1291-
await task_repository.update_task_status(
1291+
await task_repository.update_task(
12921292
task_id=task_id,
12931293
status=status,
12941294
error=params.get("error"),
@@ -1300,44 +1300,44 @@ async def handle_task_update(self, params: dict, request: Request, request_id: s
13001300
else:
13011301
# Update individual status-related fields if status is not provided
13021302
if "error" in params:
1303-
await task_repository.update_task_status(
1303+
await task_repository.update_task(
13041304
task_id=task_id, status=task.status, error=params.get("error")
13051305
)
13061306
if "result" in params:
1307-
await task_repository.update_task_status(
1307+
await task_repository.update_task(
13081308
task_id=task_id, status=task.status, result=params.get("result")
13091309
)
13101310
if "progress" in params:
1311-
await task_repository.update_task_status(
1311+
await task_repository.update_task(
13121312
task_id=task_id, status=task.status, progress=params.get("progress")
13131313
)
13141314
if "started_at" in params:
1315-
await task_repository.update_task_status(
1315+
await task_repository.update_task(
13161316
task_id=task_id, status=task.status, started_at=params.get("started_at")
13171317
)
13181318
if "completed_at" in params:
1319-
await task_repository.update_task_status(
1319+
await task_repository.update_task(
13201320
task_id=task_id, status=task.status, completed_at=params.get("completed_at")
13211321
)
13221322

13231323
# Update other fields
13241324
if "inputs" in params:
1325-
await task_repository.update_task_inputs(task_id, params["inputs"])
1325+
await task_repository.update_task(task_id, inputs=params["inputs"])
13261326

13271327
if "dependencies" in params:
1328-
await task_repository.update_task_dependencies(task_id, params["dependencies"])
1328+
await task_repository.update_task(task_id, dependencies=params["dependencies"])
13291329

13301330
if "name" in params:
1331-
await task_repository.update_task_name(task_id, params["name"])
1331+
await task_repository.update_task(task_id, name=params["name"])
13321332

13331333
if "priority" in params:
1334-
await task_repository.update_task_priority(task_id, params["priority"])
1334+
await task_repository.update_task(task_id, priority=params["priority"])
13351335

13361336
if "params" in params:
1337-
await task_repository.update_task_params(task_id, params["params"])
1337+
await task_repository.update_task(task_id, params=params["params"])
13381338

13391339
if "schemas" in params:
1340-
await task_repository.update_task_schemas(task_id, params["schemas"])
1340+
await task_repository.update_task(task_id, schemas=params["schemas"])
13411341

13421342
# Refresh task to get updated values
13431343
updated_task = await task_repository.get_task_by_id(task_id)

src/apflow/cli/commands/tasks.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def create(
694694
Args:
695695
file: JSON file containing task(s) definition
696696
stdin: Read from stdin instead of file
697-
697+
698698
"""
699699
try:
700700
import sys
@@ -841,7 +841,7 @@ async def update_task():
841841

842842
# Update status-related fields if status is provided
843843
if "status" in update_params:
844-
await task_repository.update_task_status(
844+
await task_repository.update_task(
845845
task_id=task_id,
846846
status=update_params["status"],
847847
error=update_params.get("error"),
@@ -851,35 +851,35 @@ async def update_task():
851851
else:
852852
# Update individual fields
853853
if "error" in update_params:
854-
await task_repository.update_task_status(
854+
await task_repository.update_task(
855855
task_id=task_id,
856856
status=task.status,
857857
error=update_params["error"]
858858
)
859859
if "result" in update_params:
860-
await task_repository.update_task_status(
860+
await task_repository.update_task(
861861
task_id=task_id,
862862
status=task.status,
863863
result=update_params["result"]
864864
)
865865
if "progress" in update_params:
866-
await task_repository.update_task_status(
866+
await task_repository.update_task(
867867
task_id=task_id,
868868
status=task.status,
869869
progress=update_params["progress"]
870870
)
871871

872872
# Update other fields
873873
if "name" in update_params:
874-
await task_repository.update_task_name(task_id, update_params["name"])
874+
await task_repository.update_task(task_id, name=update_params["name"])
875875
if "priority" in update_params:
876-
await task_repository.update_task_priority(task_id, update_params["priority"])
876+
await task_repository.update_task(task_id, priority=update_params["priority"])
877877
if "inputs" in update_params:
878-
await task_repository.update_task_inputs(task_id, update_params["inputs"])
878+
await task_repository.update_task(task_id, inputs=update_params["inputs"])
879879
if "params" in update_params:
880-
await task_repository.update_task_params(task_id, update_params["params"])
880+
await task_repository.update_task(task_id, params=update_params["params"])
881881
if "schemas" in update_params:
882-
await task_repository.update_task_schemas(task_id, update_params["schemas"])
882+
await task_repository.update_task(task_id, schemas=update_params["schemas"])
883883

884884
# Get updated task
885885
updated_task = await task_repository.get_task_by_id(task_id)

0 commit comments

Comments
 (0)