43 changes: 39 additions & 4 deletions data_rentgen/consumer/extractors/batch_extraction_result.py
@@ -18,23 +18,27 @@
RunDTO,
SchemaDTO,
SQLQueryDTO,
TagDTO,
TagValueDTO,
UserDTO,
)

T = TypeVar(
"T",
ColumnLineageDTO,
DatasetDTO,
DatasetSymlinkDTO,
InputDTO,
JobDTO,
JobTypeDTO,
LocationDTO,
OperationDTO,
OutputDTO,
RunDTO,
SchemaDTO,
SQLQueryDTO,
TagDTO,
TagValueDTO,
UserDTO,
)

@@ -70,6 +74,8 @@ def __init__(self):
self._column_lineage: dict[tuple, ColumnLineageDTO] = {}
self._schemas: dict[tuple, SchemaDTO] = {}
self._sql_queries: dict[tuple, SQLQueryDTO] = {}
self._tags: dict[tuple, TagDTO] = {}
self._tag_values: dict[tuple, TagValueDTO] = {}
self._users: dict[tuple, UserDTO] = {}

def __repr__(self):
@@ -87,6 +93,8 @@ def __repr__(self):
f"column_lineage={len(self._column_lineage)}, "
f"schemas={len(self._schemas)}, "
f"sql_queries={len(self._sql_queries)}, "
f"tag_values={len(self._tags)}, "
f"tag_values={len(self._tag_values)}, "
f"users={len(self._users)}"
")"
)
@@ -125,6 +133,7 @@ def add_job(self, job: JobDTO):
job.location = self.add_location(job.location)
if job.type:
job.type = self.add_job_type(job.type)
job.tag_values = {self.add_tag_value(tag_value) for tag_value in job.tag_values}
return self._add(self._jobs, job)

def add_run(self, run: RunDTO):
@@ -167,6 +176,13 @@ def add_schema(self, schema: SchemaDTO):
def add_sql_query(self, sql_query: SQLQueryDTO):
return self._add(self._sql_queries, sql_query)

def add_tag(self, tag: TagDTO):
return self._add(self._tags, tag)

def add_tag_value(self, tag_value: TagValueDTO):
tag_value.tag = self.add_tag(tag_value.tag)
return self._add(self._tag_values, tag_value)

def add_user(self, user: UserDTO):
return self._add(self._users, user)

@@ -182,6 +198,12 @@ def get_sql_query(self, sql_query_key: tuple) -> SQLQueryDTO:
def get_user(self, user_key: tuple) -> UserDTO:
return self._users[user_key]

def get_tag(self, tag_key: tuple) -> TagDTO:
return self._tags[tag_key]

def get_tag_value(self, tag_value_key: tuple) -> TagValueDTO:
return self._tag_values[tag_value_key]

def get_dataset(self, dataset_key: tuple) -> DatasetDTO:
dataset = self._datasets[dataset_key]
dataset.location = self.get_location(dataset.location.unique_key)
@@ -201,6 +223,7 @@ def get_job(self, job_key: tuple) -> JobDTO:
job.location = self.get_location(job.location.unique_key)
if job.type:
job.type = self.get_job_type(job.type.unique_key)
job.tag_values = {self.get_tag_value(tag_value.unique_key) for tag_value in job.tag_values}
return job

def get_run(self, run_key: tuple) -> RunDTO:
@@ -282,6 +305,12 @@ def schemas(self) -> list[SchemaDTO]:
def sql_queries(self) -> list[SQLQueryDTO]:
return self._resolve(self.get_sql_query, self._sql_queries)

def tags(self) -> list[TagDTO]:
return self._resolve(self.get_tag, self._tags)

def tag_values(self) -> list[TagValueDTO]:
return self._resolve(self.get_tag_value, self._tag_values)

def users(self) -> list[UserDTO]:
return self._resolve(self.get_user, self._users)

@@ -322,6 +351,12 @@ def merge(self, other: BatchExtractionResult) -> BatchExtractionResult: # noqa:
for sql_query in other.sql_queries():
self.add_sql_query(sql_query)

for tag in other.tags():
self.add_tag(tag)

for tag_value in other.tag_values():
self.add_tag_value(tag_value)

for user in other.users():
self.add_user(user)

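For orientation, a minimal sketch of the deduplication behavior these methods add. It assumes _add() returns the already-stored instance when the unique key repeats (as the add_* call sites above suggest), and infers the BatchExtractionResult import path from the file layout:

from data_rentgen.consumer.extractors.batch_extraction_result import BatchExtractionResult
from data_rentgen.dto import TagDTO, TagValueDTO

result = BatchExtractionResult()
first = result.add_tag_value(TagValueDTO(tag=TagDTO(name="spark.version"), value="3.5.1"))
second = result.add_tag_value(TagValueDTO(tag=TagDTO(name="spark.version"), value="3.5.1"))

assert first is second            # repeated unique key resolves to the same stored DTO
assert len(result.tag_values()) == 1
assert len(result.tags()) == 1    # the nested TagDTO was deduplicated via add_tag()
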
31 changes: 31 additions & 0 deletions data_rentgen/consumer/extractors/generic/run.py
@@ -8,6 +8,8 @@
JobDTO,
RunDTO,
RunStatusDTO,
TagDTO,
TagValueDTO,
)
from data_rentgen.openlineage.job import OpenLineageJob
from data_rentgen.openlineage.run_event import (
@@ -39,6 +41,7 @@ def extract_run(self, event: OpenLineageRunEvent) -> RunDTO:
parent_run=self.extract_parent_run(event.run.facets.parent) if event.run.facets.parent else None,
)
self._enrich_run_status(run, event)
self._enrich_run_tags(run, event)
return run

def extract_parent_run(self, facet: OpenLineageParentRunFacet | OpenLineageRunEvent) -> RunDTO:
Expand Down Expand Up @@ -70,3 +73,31 @@ def _enrich_run_status(self, run: RunDTO, event: OpenLineageRunEvent) -> RunDTO:
# OTHER is used only to update run statistics
pass
return run

def _enrich_run_tags(self, run: RunDTO, event: OpenLineageRunEvent) -> RunDTO:
if event.run.facets.processing_engine:
client_tag_value = TagValueDTO(
tag=TagDTO(name=f"{event.run.facets.processing_engine.name.lower()}.version"),
value=str(event.run.facets.processing_engine.version),
)
adapter_tag_value = TagValueDTO(
tag=TagDTO(name="openlineage_adapter.version"),
value=str(event.run.facets.processing_engine.openlineageAdapterVersion),
)
# we don't store run tags; everything is merged into job tags
run.job.tag_values.add(client_tag_value)
run.job.tag_values.add(adapter_tag_value)

if not event.run.facets.tags:
return run

for raw_tag in event.run.facets.tags.tags:
key = raw_tag.key
if key == "openlineage_client_version":
# https://github.com/OpenLineage/OpenLineage/blob/1.42.1/client/python/src/openlineage/client/client.py#L460
tag_value = TagValueDTO(
tag=TagDTO(name="openlineage_client.version"),
value=raw_tag.value,
)
run.job.tag_values.add(tag_value)
return run
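
To make the naming convention concrete, here is the set of job-level tags this method would produce for a hypothetical Spark run (all values below are illustrative, not taken from the source):

# Given a processing_engine facet with name="Spark", version="3.5.1" and
# openlineageAdapterVersion="1.2.3", plus a run-level "openlineage_client_version"
# tag of "1.42.1", run.job.tag_values would carry:
expected_job_tags = {
    "spark.version": "3.5.1",                # f"{name.lower()}.version"
    "openlineage_adapter.version": "1.2.3",  # from openlineageAdapterVersion
    "openlineage_client.version": "1.42.1",  # from the run tags facet
}
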
4 changes: 3 additions & 1 deletion data_rentgen/consumer/extractors/impl/hive.py
@@ -55,7 +55,7 @@ def extract_run(self, event: OpenLineageRunEvent) -> RunDTO:
if hive_session.username not in ("anonymous", "hive"):
user = UserDTO(name=hive_session.username)

result = RunDTO(
id=run_id,
job=JobDTO(
name=job_name,
@@ -68,6 +68,8 @@ def extract_run(self, event: OpenLineageRunEvent) -> RunDTO:
external_id=hive_session.sessionId,
user=user,
)
self._enrich_run_tags(result, event)
return result

def extract_operation(self, event: OpenLineageRunEvent) -> OperationDTO:
run = self.extract_run(event)
1 change: 1 addition & 0 deletions data_rentgen/consumer/extractors/impl/spark.py
@@ -62,6 +62,7 @@ def extract_operation(self, event: OpenLineageRunEvent) -> OperationDTO:
run = self.extract_parent_run(event.run.facets.parent) # type: ignore[arg-type]
# Workaround for https://github.com/OpenLineage/OpenLineage/issues/3846
self._enrich_run_identifiers(run, event)
self._enrich_run_tags(run, event)
operation = super()._extract_operation(event, run)

# in some cases, operation name may contain raw SELECT query with newlines. use spaces instead.
24 changes: 22 additions & 2 deletions data_rentgen/consumer/saver.py
@@ -24,6 +24,8 @@ async def save(self, data: BatchExtractionResult):
self.logger.info("Saving to database")

await self.create_locations(data)
await self.create_tags(data)
await self.create_tag_values(data)
await self.create_datasets(data)
await self.create_dataset_symlinks(data)
await self.create_job_types(data)
@@ -51,7 +53,7 @@ async def create_locations(self, data: BatchExtractionResult):
for location_dto in data.locations():
async with self.unit_of_work:
location = await self.unit_of_work.location.create_or_update(location_dto)
location_dto.id = location.id

# To avoid deadlocks when parallel consumer instances insert/update the same row,
# commit changes for each row instead of committing the whole batch. Yes, this could be slow.
@@ -92,7 +94,7 @@ async def create_jobs(self, data: BatchExtractionResult):
job = await self.unit_of_work.job.create_or_update(job_dto) # noqa: PLW2901
else:
job = await self.unit_of_work.job.update(job, job_dto) # noqa: PLW2901
job_dto.id = job.id

async def create_users(self, data: BatchExtractionResult):
self.logger.debug("Creating users")
@@ -125,6 +127,24 @@ async def create_schemas(self, data: BatchExtractionResult):
else:
schema_dto.id = schema_id

async def create_tags(self, data: BatchExtractionResult):
self.logger.debug("Creating tags")
tag_pairs = await self.unit_of_work.tag.fetch_bulk(data.tags())
for tag_dto, tag in tag_pairs:
if not tag:
async with self.unit_of_work:
tag = await self.unit_of_work.tag.create(tag_dto) # noqa: PLW2901
tag_dto.id = tag.id

async def create_tag_values(self, data: BatchExtractionResult):
self.logger.debug("Creating tag values")
tag_value_pairs = await self.unit_of_work.tag_value.fetch_bulk(data.tag_values())
for tag_value_dto, tag_value in tag_value_pairs:
if not tag_value:
async with self.unit_of_work:
tag_value = await self.unit_of_work.tag_value.create(tag_value_dto) # noqa: PLW2901
tag_value_dto.id = tag_value.id
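
Note the ordering contract these two methods add to save(): tags are persisted before tag values (each TagValueDTO.tag needs its id assigned), and tag values before the jobs that link to them. An abridged sketch of save()'s call sequence, with method names as above and the elided middle steps kept as "...":

async def save(self, data: BatchExtractionResult):
    await self.create_tags(data)        # assigns TagDTO.id
    await self.create_tag_values(data)  # needs TagDTO.id, assigns TagValueDTO.id
    ...                                 # datasets, symlinks, job types elided
    await self.create_jobs(data)        # JobTagValue rows reference TagValueDTO.id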

# In most cases, the whole run tree created by some parent is sent into one
# Kafka partition, and thus handled by just one worker.
# Cross fingers and create all runs in one transaction.
53 changes: 47 additions & 6 deletions data_rentgen/db/repositories/job.py
@@ -21,10 +21,12 @@
select,
tuple_,
union,
update,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload

from data_rentgen.db.models import Address, Job, JobTagValue, Location, TagValue
from data_rentgen.db.repositories.base import Repository
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
from data_rentgen.dto import JobDTO, PaginationDTO
@@ -69,6 +71,19 @@
.group_by(Job.location_id)
)

update_job_type_query = update(Job).where(Job.id == bindparam("job_id")).values(type_id=bindparam("type_id"))

insert_tag_value_query = (
insert(JobTagValue)
.values(
{
"job_id": bindparam("job_id"),
"tag_value_id": bindparam("tag_value_id"),
}
)
.on_conflict_do_nothing(index_elements=["job_id", "tag_value_id"])
)
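
For reference, a sketch of the SQL this statement should compile to on the PostgreSQL dialect; the table name assumes the JobTagValue model maps to "job_tag_value":

from sqlalchemy.dialects import postgresql

print(insert_tag_value_query.compile(dialect=postgresql.dialect()))
# roughly: INSERT INTO job_tag_value (job_id, tag_value_id)
#          VALUES (%(job_id)s, %(tag_value_id)s)
#          ON CONFLICT (job_id, tag_value_id) DO NOTHING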


class JobRepository(Repository[Job]):
async def paginate(
@@ -175,11 +190,15 @@ async def fetch_bulk(self, jobs_dto: list[JobDTO]) -> list[tuple[JobDTO, Job | N
]

async def create_or_update(self, job: JobDTO) -> Job:
    result = await self._get(job)
    if not result:
        # try one more time, but with lock acquired.
        # if another worker already created the same row, just use it. if not, create while holding the lock.
        await self._lock(job.location.id, job.name.lower())
        result = await self._get(job)

    if not result:
        result = await self._create(job)
    return await self.update(result, job)
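
The lock-then-recheck pattern above depends on Repository._lock, whose implementation is not part of this diff. One plausible implementation on PostgreSQL (an assumption, not necessarily what the repository actually does) is a transaction-scoped advisory lock keyed by a stable digest of the arguments:

import hashlib

from sqlalchemy import func, select

async def _lock(self, *keys) -> None:
    # Hypothetical sketch: serialize workers contending on the same key tuple,
    # so only one runs the get-or-create critical section at a time. A stable
    # digest is required (Python's hash() is randomized per process), truncated
    # to the signed bigint accepted by pg_advisory_xact_lock().
    digest = hashlib.sha256(repr(keys).encode()).digest()
    lock_key = int.from_bytes(digest[:8], "big", signed=True)
    await self._session.execute(select(func.pg_advisory_xact_lock(lock_key)))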

async def _get(self, job: JobDTO) -> Job | None:
@@ -204,8 +223,30 @@ async def _create(self, job: JobDTO) -> Job:
async def update(self, existing: Job, new: JobDTO) -> Job:
# almost all fields are immutable, so we can avoid UPDATE statements if row is unchanged
if new.type and new.type.id and existing.type_id != new.type.id:
await self._session.execute(
update_job_type_query,
{
"job_id": existing.id,
"type_id": new.type.id,
},
)

if not new.tag_values:
# when the job has no tag values, we can avoid INSERT statements
return existing

# Lock to prevent inserting the same rows from multiple workers
await self._lock(existing.location_id, existing.name)
await self._session.execute(
insert_tag_value_query,
[
{
"job_id": existing.id,
"tag_value_id": tag_value_dto.id,
}
for tag_value_dto in new.tag_values
],
)
return existing

async def list_by_ids(self, job_ids: Collection[int]) -> list[Job]:
39 changes: 24 additions & 15 deletions data_rentgen/db/repositories/location.py
@@ -15,6 +15,7 @@
select,
union,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload

from data_rentgen.db.models import Address, Location
@@ -44,6 +45,17 @@
)
get_distinct_query = select(Location.type).distinct(Location.type).order_by(Location.type)

insert_address_query = (
insert(Address)
.values(
{
"location_id": bindparam("location_id"),
"url": bindparam("url"),
}
)
.on_conflict_do_nothing(index_elements=["location_id", "url"])
)


class LocationRepository(Repository[Location]):
async def paginate(
@@ -141,25 +153,22 @@ async def _create(self, location: LocationDTO) -> Location:
async def _update_addresses(self, existing: Location, new: LocationDTO) -> Location:
    existing_urls = {address.url for address in existing.addresses}
    new_urls = new.addresses - existing_urls
    # in most cases, Location is unchanged, so we can avoid INSERT statements
    if not new_urls:
        return existing

    # take a lock to avoid creating the same address from multiple workers
    await self._lock(existing.type, existing.name)
    await self._session.execute(
        insert_address_query,
        [
            {
                "location_id": existing.id,
                "url": url,
            }
            for url in new_urls
        ],
    )
    return existing

async def get_location_types(self):