Skip to content

Commit 878ae73

Browse files
committed
add memray tracker
1 parent e6e3370 commit 878ae73

File tree

11 files changed

+380
-3
lines changed

11 files changed

+380
-3
lines changed

airflow-core/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ dependencies = [
171171
"requests-kerberos>=0.14.0",
172172
"thrift-sasl>=0.4.2",
173173
]
174+
"memray" = [
175+
"memray>=1.19.0",
176+
]
174177
"otel" = [
175178
"opentelemetry-exporter-prometheus>=0.47b0",
176179
]

airflow-core/src/airflow/cli/commands/api_server_command.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from airflow.exceptions import AirflowConfigException
3434
from airflow.typing_compat import ParamSpec
3535
from airflow.utils import cli as cli_utils
36+
from airflow.utils.memray_utils import MemrayTraceComponents, enable_memray_trace
3637
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
3738

3839
PS = ParamSpec("PS")
@@ -48,6 +49,7 @@
4849
# more info here: https://github.com/benoitc/gunicorn/issues/1877#issuecomment-1911136399
4950

5051

52+
@enable_memray_trace(component=MemrayTraceComponents.api)
5153
def _run_api_server(args, apps: str, num_workers: int, worker_timeout: int, proxy_headers: bool):
5254
"""Run the API server."""
5355
log.info(

airflow-core/src/airflow/cli/commands/dag_processor_command.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
2727
from airflow.jobs.job import Job, run_job
2828
from airflow.utils import cli as cli_utils
29+
from airflow.utils.memray_utils import MemrayTraceComponents, enable_memray_trace
2930
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
3031

3132
log = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
4445
)
4546

4647

48+
@enable_memray_trace(component=MemrayTraceComponents.dag_processor)
4749
@cli_utils.action_cli
4850
@providers_configuration_loaded
4951
def dag_processor(args):

airflow-core/src/airflow/cli/commands/scheduler_command.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
from airflow.jobs.job import Job, run_job
3131
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
3232
from airflow.utils import cli as cli_utils
33+
from airflow.utils.memray_utils import MemrayTraceComponents, enable_memray_trace
3334
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
3435
from airflow.utils.scheduler_health import serve_health_check
3536

3637
log = logging.getLogger(__name__)
3738

3839

40+
@enable_memray_trace(component=MemrayTraceComponents.scheduler)
3941
def _run_scheduler_job(args) -> None:
4042
job_runner = SchedulerJobRunner(job=Job(), num_runs=args.num_runs)
4143
enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")

airflow-core/src/airflow/config_templates/config.yml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,3 +2561,33 @@ dag_processor:
25612561
type: boolean
25622562
example: ~
25632563
default: "True"
2564+
profiling:
2565+
description: |
2566+
Configuration for memory profiling in Airflow component.
2567+
Currently, we provide profiling using Memray and additional tools may be added in the future
2568+
Also, see the guide in Link (TBD)
2569+
options:
2570+
memray_trace_enabled:
2571+
description: |
2572+
Whether to enable memory allocation tracing by memray in the scheduler. If enabled, Airflow will
2573+
start tracing memory allocation and store the metrics in "$AIRFLOW_HOME/dag_processor_memory.bin"
2574+
To generate analyzed view, run this command in base directory where the bin file is generated
2575+
```
2576+
# see also https://bloomberg.github.io/memray/run.html#aggregated-capture-files
2577+
memray flamegraph $AIRFLOW_HOME/<component>_memory.bin
2578+
```
2579+
This is an expensive operation and generally should not be used except for debugging purposes.
2580+
version_added: 3.1.1
2581+
type: boolean
2582+
example: ~
2583+
default: "False"
2584+
memray_trace_components:
2585+
description: |
2586+
Comma-separated list of Airflow components to profile with memray.
2587+
Valid components are: scheduler, api, dag_processor
2588+
2589+
This option only takes effect when memray_trace_enabled is set to True.
2590+
version_added: 3.1.2
2591+
type: string
2592+
example: "scheduler,api,dag_processor"
2593+
default: ""

airflow-core/src/airflow/configuration.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@
3636
from configparser import ConfigParser, NoOptionError, NoSectionError
3737
from contextlib import contextmanager
3838
from copy import deepcopy
39+
from enum import Enum
3940
from io import StringIO
4041
from json.decoder import JSONDecodeError
4142
from re import Pattern
42-
from typing import IO, TYPE_CHECKING, Any
43+
from typing import IO, TYPE_CHECKING, Any, TypeVar
4344
from urllib.parse import urlsplit
4445

4546
from packaging.version import parse as parse_version
@@ -1244,10 +1245,52 @@ def getlist(self, section: str, key: str, delimiter=",", **kwargs):
12441245
return [item.strip() for item in val.split(delimiter)]
12451246
except Exception:
12461247
raise AirflowConfigException(
1247-
f'Failed to parse value to a list. Please check "{key}" key in "{section}" section. '
1248+
f'Failed to parse value to list. Please check "{key}" key in "{section}" section. '
12481249
f'Current value: "{val}".'
12491250
)
12501251

1252+
E = TypeVar("E", bound=Enum)
1253+
1254+
def getenum(self, section: str, key: str, enum_class: type[E], **kwargs) -> E:
1255+
val = self.get(section, key, **kwargs)
1256+
enum_names = [enum_item.name for enum_item in enum_class]
1257+
1258+
if val is None:
1259+
raise AirflowConfigException(
1260+
f'Failed to convert value. Please check "{key}" key in "{section}" section. '
1261+
f'Current value: "{val}" and it must be one of {", ".join(enum_names)}'
1262+
)
1263+
1264+
try:
1265+
return enum_class[val]
1266+
except KeyError:
1267+
if "fallback" in kwargs and kwargs["fallback"] in enum_names:
1268+
return enum_class[kwargs["fallback"]]
1269+
raise AirflowConfigException(
1270+
f'Failed to convert value. Please check "{key}" key in "{section}" section. '
1271+
f'Current value: "{val}" and it must be one of {", ".join(enum_names)}'
1272+
)
1273+
1274+
def getenumlist(self, section: str, key: str, enum_class: type[E], delimiter=",", **kwargs) -> list[E]:
1275+
string_list = self.getlist(section, key, delimiter, **kwargs)
1276+
enum_names = [enum_item.name for enum_item in enum_class]
1277+
enum_list = []
1278+
1279+
for val in string_list:
1280+
try:
1281+
enum_list.append(enum_class[val])
1282+
except KeyError:
1283+
log.warning(
1284+
"Failed to convert value. Please check %s key in %s section. "
1285+
"Current value: %s and it must be one of %s",
1286+
key,
1287+
section,
1288+
val,
1289+
", ".join(enum_names),
1290+
)
1291+
1292+
return enum_list
1293+
12511294
def getimport(self, section: str, key: str, **kwargs) -> Any:
12521295
"""
12531296
Read options, import the full qualified name, and return the object.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import logging
21+
from collections.abc import Callable
22+
from enum import Enum
23+
from functools import wraps
24+
from typing import ParamSpec, TypeVar
25+
26+
from airflow.configuration import AIRFLOW_HOME, conf
27+
28+
# Type variables for preserving function signatures
29+
PS = ParamSpec("PS")
30+
RT = TypeVar("RT")
31+
32+
log = logging.getLogger(__name__)
33+
34+
35+
class MemrayTraceComponents(Enum):
36+
"""Possible airflow components can apply memray trace."""
37+
38+
scheduler = "scheduler"
39+
dag_processor = "dag_processor"
40+
api = "api"
41+
42+
43+
def enable_memray_trace(component: MemrayTraceComponents) -> Callable[[Callable[PS, RT]], Callable[PS, RT]]:
44+
"""
45+
Conditionally track memory using memray based on configuration.
46+
47+
Args:
48+
component: Enum value of the component for configuration lookup
49+
"""
50+
51+
def decorator(func: Callable[PS, RT]) -> Callable[PS, RT]:
52+
@wraps(func)
53+
def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> RT:
54+
_enable_memray_trace = conf.getboolean("profiling", "memray_trace_enabled")
55+
if not _enable_memray_trace:
56+
return func(*args, **kwargs)
57+
58+
_memray_trace_components = conf.getenumlist(
59+
"profiling", "memray_trace_components", MemrayTraceComponents
60+
)
61+
if component not in _memray_trace_components:
62+
return func(*args, **kwargs)
63+
64+
try:
65+
import memray
66+
67+
airflow_component_name = component.value
68+
profile_path = f"{AIRFLOW_HOME}/{airflow_component_name}_memory.bin"
69+
log.info("enable_memray_trace is on. so memory state is tracked by memray")
70+
with memray.Tracker(
71+
profile_path,
72+
):
73+
log.info(
74+
"Memray tracing enabled for %s. Output: %s", airflow_component_name, profile_path
75+
)
76+
return func(*args, **kwargs)
77+
except ImportError as error:
78+
# Silently fall back to running without tracking
79+
log.warning("ImportError memray.Tracker: %s", error.msg)
80+
return func(*args, **kwargs)
81+
82+
return wrapper
83+
84+
return decorator

airflow-core/tests/unit/core/test_configuration.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import re
2424
import textwrap
2525
import warnings
26+
from enum import Enum
2627
from io import StringIO
2728
from unittest import mock
2829
from unittest.mock import patch
@@ -529,6 +530,62 @@ def test_getjson(self, config_str, expected):
529530

530531
assert test_conf.getjson("test", "json") == expected
531532

533+
def test_getenum(self):
534+
class TestEnum(Enum):
535+
option1 = 1
536+
option2 = 2
537+
option3 = 3
538+
fallback = 4
539+
540+
config = """
541+
[test1]
542+
option = option1
543+
[test2]
544+
option = option2
545+
[test3]
546+
option = option3
547+
[test4]
548+
option = option4
549+
"""
550+
test_conf = AirflowConfigParser()
551+
test_conf.read_string(config)
552+
553+
assert test_conf.getenum("test1", "option", TestEnum) == TestEnum.option1
554+
assert test_conf.getenum("test2", "option", TestEnum) == TestEnum.option2
555+
assert test_conf.getenum("test3", "option", TestEnum) == TestEnum.option3
556+
assert test_conf.getenum("test4", "option", TestEnum, fallback="fallback") == TestEnum.fallback
557+
with pytest.raises(AirflowConfigException, match=re.escape("option1, option2, option3, fallback")):
558+
test_conf.getenum("test4", "option", TestEnum)
559+
560+
def test_getenumlist(self):
561+
class TestEnum(Enum):
562+
option1 = 1
563+
option2 = 2
564+
option3 = 3
565+
fallback = 4
566+
567+
config = """
568+
[test1]
569+
option = option1,option2,option3
570+
[test2]
571+
option = option1,option3
572+
[test3]
573+
option = option1,option4
574+
[test4]
575+
option =
576+
"""
577+
test_conf = AirflowConfigParser()
578+
test_conf.read_string(config)
579+
580+
assert test_conf.getenumlist("test1", "option", TestEnum) == [
581+
TestEnum.option1,
582+
TestEnum.option2,
583+
TestEnum.option3,
584+
]
585+
assert test_conf.getenumlist("test2", "option", TestEnum) == [TestEnum.option1, TestEnum.option3]
586+
assert test_conf.getenumlist("test3", "option", TestEnum) == [TestEnum.option1]
587+
assert test_conf.getenumlist("test4", "option", TestEnum) == []
588+
532589
def test_getjson_empty_with_fallback(self):
533590
config = textwrap.dedent(
534591
"""

0 commit comments

Comments
 (0)