Skip to content

Commit ff72b46

Browse files
committed
fix logic
1 parent bea4faa commit ff72b46

File tree

10 files changed

+193
-14
lines changed

10 files changed

+193
-14
lines changed

providers/edge3/src/airflow/providers/edge3/cli/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@
5151
status_file_path,
5252
write_pid_to_pidfile,
5353
)
54-
from airflow.providers.edge3.utils import is_callback_execute
5554
from airflow.providers.edge3.models.edge_worker import (
5655
EdgeWorkerDuplicateException,
5756
EdgeWorkerState,
5857
EdgeWorkerVersionException,
5958
)
59+
from airflow.providers.edge3.utils.types import is_callback_execute
6060
from airflow.utils.net import getfqdn
6161
from airflow.utils.state import TaskInstanceState
6262

providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
from airflow.executors.base_executor import BaseExecutor
2929
from airflow.models.taskinstance import TaskInstance
3030
from airflow.providers.common.compat.sdk import Stats, conf, timezone
31-
from airflow.providers.edge3.utils import is_callback_execute
3231
from airflow.providers.edge3.models.db import EdgeDBManager, check_db_manager_config
3332
from airflow.providers.edge3.models.edge_job import EdgeJobModel
3433
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
3534
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, reset_metrics
35+
from airflow.providers.edge3.utils.types import is_callback_execute
3636
from airflow.utils.db import DBLocks, create_global_lock
3737
from airflow.utils.session import NEW_SESSION, provide_session
3838
from airflow.utils.state import TaskInstanceState
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.

providers/edge3/src/airflow/providers/edge3/utils.py renamed to providers/edge3/src/airflow/providers/edge3/utils/types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,24 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import Annotated, TypeAlias, TypeGuard, Union
19+
from typing import TYPE_CHECKING, Annotated, TypeAlias, TypeGuard
2020

2121
from pydantic import Discriminator, Tag
2222

23-
from airflow.executors import workloads
2423
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
2524

25+
if TYPE_CHECKING:
26+
from airflow.executors import workloads
27+
2628
if not AIRFLOW_V_3_2_PLUS:
2729
from airflow.executors.workloads import ExecuteTask
2830

2931
ExecuteTypeBody: TypeAlias = ExecuteTask
3032
else:
3133
from airflow.executors.workloads import ExecuteCallback, ExecuteTask
3234

33-
ExecuteTypeBody: TypeAlias = Annotated[
34-
Union[Annotated[ExecuteTask, Tag("ExecuteTask")], Annotated[ExecuteCallback, Tag("ExecuteCallback")]],
35+
ExecuteTypeBody: TypeAlias = Annotated[ # type: ignore[no-redef,misc]
36+
Annotated[ExecuteTask, Tag("ExecuteTask")] | Annotated[ExecuteCallback, Tag("ExecuteCallback")],
3537
Discriminator("type"),
3638
]
3739

providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from airflow.providers.common.compat.sdk import TaskInstanceKey
2929
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState # noqa: TCH001
30-
from airflow.providers.edge3.utils import ExecuteTypeBody # noqa: TCH001
30+
from airflow.providers.edge3.utils.types import ExecuteTypeBody # noqa: TCH001
3131

3232

3333
class WorkerApiDocs:

providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Annotated
20+
from typing import TYPE_CHECKING, Annotated
2121

2222
from fastapi import Body, Depends, status
2323
from sqlalchemy import select, update
@@ -27,7 +27,6 @@
2727
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
2828
from airflow.executors.workloads import ExecuteTask
2929
from airflow.providers.common.compat.sdk import Stats, timezone
30-
from airflow.providers.edge3.utils import ExecuteTypeBody
3130
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
3231

3332
try:
@@ -43,6 +42,9 @@
4342
)
4443
from airflow.utils.state import TaskInstanceState
4544

45+
if TYPE_CHECKING:
46+
from airflow.providers.edge3.utils.types import ExecuteTypeBody
47+
4648
jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
4749

4850

@@ -51,7 +53,7 @@ def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTypeBody:
5153
from airflow.executors.workloads import ExecuteCallback
5254

5355
if dag_id == ExecuteCallback.TYPE and run_id.startswith(ExecuteCallback.TYPE):
54-
return ExecuteCallback.model_validate_json(command)
56+
return ExecuteCallback.model_validate_json(command) # type: ignore[return-value]
5557

5658
return ExecuteTask.model_validate_json(command)
5759

providers/edge3/tests/unit/edge3/executors/test_edge_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sqlalchemy import delete, select
2727

2828
from airflow.configuration import conf
29-
from airflow.executors.workloads import ExecuteTask, TaskInstanceDTO
29+
from airflow.executors.workloads import ExecuteTask
3030
from airflow.executors.workloads.base import BundleInfo
3131
from airflow.models.taskinstancekey import TaskInstanceKey
3232
from airflow.providers.common.compat.sdk import Stats, timezone
@@ -40,6 +40,7 @@
4040
from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
4141

4242
if AIRFLOW_V_3_2_PLUS:
43+
from airflow.executors.workloads import TaskInstanceDTO
4344
from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod, ExecuteCallback
4445

4546
pytestmark = pytest.mark.db_test
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from pathlib import Path
20+
from uuid import uuid4
21+
22+
import pytest
23+
from pydantic import TypeAdapter
24+
25+
from airflow.executors.workloads import ExecuteTask
26+
from airflow.executors.workloads.base import BundleInfo
27+
from airflow.providers.edge3.utils.types import ExecuteTypeBody, is_callback_execute
28+
29+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
30+
31+
if AIRFLOW_V_3_2_PLUS:
32+
from airflow.executors.workloads import TaskInstanceDTO
33+
from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod, ExecuteCallback
34+
35+
36+
def _make_execute_task() -> ExecuteTask:
37+
ti = TaskInstanceDTO(
38+
id=uuid4(),
39+
dag_version_id=uuid4(),
40+
task_id="test_task",
41+
dag_id="test_dag",
42+
run_id="test_run",
43+
try_number=1,
44+
map_index=-1,
45+
pool_slots=1,
46+
queue="default",
47+
priority_weight=1,
48+
)
49+
return ExecuteTask(
50+
ti=ti,
51+
dag_rel_path=Path("test_dag.py"),
52+
token="test_token",
53+
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
54+
log_path="test.log",
55+
)
56+
57+
58+
def _make_execute_callback() -> ExecuteCallback:
59+
callback_data = CallbackDTO(
60+
id=str(uuid4()),
61+
fetch_method=CallbackFetchMethod.IMPORT_PATH,
62+
data={
63+
"path": "builtins.dict",
64+
"kwargs": {"a": 1, "b": 2, "c": 3},
65+
},
66+
)
67+
return ExecuteCallback(
68+
callback=callback_data,
69+
dag_rel_path=Path("test.py"),
70+
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
71+
token="test_token",
72+
log_path="test.log",
73+
)
74+
75+
76+
class TestIsCallbackExecute:
77+
def test_returns_false_for_execute_task(self):
78+
workload = _make_execute_task()
79+
assert is_callback_execute(workload) is False
80+
81+
@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteCallback requires Airflow 3.2+")
82+
def test_returns_true_for_execute_callback(self):
83+
workload = _make_execute_callback()
84+
assert is_callback_execute(workload) is True
85+
86+
87+
@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteTypeBody union requires Airflow 3.2+")
88+
class TestExecuteTypeBody:
89+
def setup_method(self):
90+
self.adapter: TypeAdapter = TypeAdapter(ExecuteTypeBody)
91+
92+
def test_validate_execute_task_json(self):
93+
workload = _make_execute_task()
94+
json_str = workload.model_dump_json()
95+
96+
result = self.adapter.validate_json(json_str)
97+
98+
assert isinstance(result, ExecuteTask)
99+
assert result.ti.dag_id == "test_dag"
100+
101+
def test_validate_execute_callback_json(self):
102+
workload = _make_execute_callback()
103+
json_str = workload.model_dump_json()
104+
105+
result = self.adapter.validate_json(json_str)
106+
107+
assert isinstance(result, ExecuteCallback)
108+
assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
109+
110+
def test_validate_execute_task_dict(self):
111+
workload = _make_execute_task()
112+
data = workload.model_dump()
113+
114+
result = self.adapter.validate_python(data)
115+
116+
assert isinstance(result, ExecuteTask)
117+
118+
def test_validate_execute_callback_dict(self):
119+
workload = _make_execute_callback()
120+
data = workload.model_dump()
121+
122+
result = self.adapter.validate_python(data)
123+
124+
assert isinstance(result, ExecuteCallback)
125+
126+
def test_roundtrip_execute_task(self):
127+
original = _make_execute_task()
128+
json_str = self.adapter.dump_json(original)
129+
restored = self.adapter.validate_json(json_str)
130+
131+
assert isinstance(restored, ExecuteTask)
132+
assert restored.ti.task_id == original.ti.task_id
133+
assert restored.ti.dag_id == original.ti.dag_id
134+
135+
def test_roundtrip_execute_callback(self):
136+
original = _make_execute_callback()
137+
json_str = self.adapter.dump_json(original)
138+
restored = self.adapter.validate_json(json_str)
139+
140+
assert isinstance(restored, ExecuteCallback)
141+
assert restored.callback.id == original.callback.id

providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import pytest
2525
from sqlalchemy import delete, select
2626

27-
from airflow.executors.workloads import ExecuteTask, TaskInstanceDTO
27+
from airflow.executors.workloads import ExecuteTask
2828
from airflow.executors.workloads.base import BundleInfo
2929
from airflow.executors.workloads.callback import CallbackDTO
3030
from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -37,10 +37,11 @@
3737
if TYPE_CHECKING:
3838
from sqlalchemy.orm import Session
3939

40-
from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback
40+
from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback, TaskInstanceDTO
4141

4242
if AIRFLOW_V_3_2_PLUS:
43-
from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback
43+
from airflow.executors.workloads import CallbackFetchMethod, ExecuteCallback, TaskInstanceDTO
44+
4445

4546
try:
4647
from airflow.sdk._shared.observability.metrics.dual_stats_manager import DualStatsManager # noqa: F401

0 commit comments

Comments
 (0)