Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import itertools
import logging
from collections import OrderedDict, defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -479,33 +480,57 @@ def online_read(
online_config.endpoint_url,
online_config.session_based_auth,
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)
table_name = _get_table_name(online_config, config, table)
table_instance = dynamodb_resource.Table(table_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

table_instance is only used in the single-batch path, and seems unnecessary since table_name can be used instead of table_instance.name


batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

# Split entity_ids into batches upfront
batches: List[List[str]] = []
entity_ids_iter = iter(entity_ids)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
if not batch:
break
batches.append(batch)

if not batches:
return []

# For single batch, no parallelization overhead needed
if len(batches) == 1:
batch_entity_ids = self._to_resource_batch_get_payload(
online_config, table_instance.name, batch
online_config, table_instance.name, batches[0]
)
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
response = dynamodb_resource.batch_get_item(RequestItems=batch_entity_ids)
return self._process_batch_get_response(table_name, response, batches[0])

# Execute batch requests in parallel for multiple batches
# Note: boto3 resources are NOT thread-safe, so we create a new resource per thread
def fetch_batch(batch: List[str]) -> Dict[str, Any]:
thread_resource = _initialize_dynamodb_resource(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we are creating a new session on each thread. Instead we can share a single _dynamodb_client across all threads.

def fetch_batch(batch: List[str]) -> Dict[str, Any]:
    batch_entity_ids = self._to_client_batch_get_payload(
        online_config, table_name, batch
    )
    return dynamodb_client.batch_get_item(RequestItems=batch_entity_ids)

https://docs.aws.amazon.com/boto3/latest/guide/clients.html#multithreading-or-multiprocessing-with-clients

online_config.region,
online_config.endpoint_url,
online_config.session_based_auth,
)
batch_result = self._process_batch_get_response(
table_instance.name,
response,
batch,
batch_entity_ids = self._to_resource_batch_get_payload(
online_config, table_name, batch
)
return thread_resource.batch_get_item(RequestItems=batch_entity_ids)

# Use ThreadPoolExecutor for parallel I/O
# Cap at 10 workers to avoid excessive thread creation
max_workers = min(len(batches), 10)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be configurable (max_read_workers) in online store configs

with ThreadPoolExecutor(max_workers=max_workers) as executor:
responses = list(executor.map(fetch_batch, batches))

# Process responses and merge results in order
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for batch, response in zip(batches, responses):
batch_result = self._process_batch_get_response(table_name, response, batch)
result.extend(batch_result)

return result

async def online_read_async(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,280 @@ def test_dynamodb_update_online_store_int_list(repo_config, dynamodb_online_stor
assert len(result) == 1
scores = result[0][1]["scores"]
assert _extract_int32_list(scores) == [10, 20, 30]


@mock_dynamodb
def test_dynamodb_online_store_online_read_empty_entities(
    repo_config, dynamodb_online_store
):
    """online_read must return an empty list when given no entity keys."""
    table_name = f"{TABLE_NAME}_empty_entities"
    create_test_table(PROJECT, table_name, REGION)

    result = dynamodb_online_store.online_read(
        config=repo_config,
        table=MockFeatureView(name=table_name),
        entity_keys=[],
    )
    assert result == []


@mock_dynamodb
def test_dynamodb_online_store_online_read_parallel_batches(
    repo_config, dynamodb_online_store
):
    """online_read across multiple batches (parallel execution).

    With batch_size=100 (default), 250 entities split into 3 batches,
    which online_read fans out through a ThreadPoolExecutor.
    """
    sample_count = 250
    table_name = f"{TABLE_NAME}_parallel_batches"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, features, *_ = zip(*samples)
    result = dynamodb_online_store.online_read(
        config=repo_config,
        table=MockFeatureView(name=table_name),
        entity_keys=entity_keys,
    )

    # Every requested entity comes back...
    assert len(result) == sample_count
    # ...and in the same order the entity keys were passed in.
    assert [feature_dict for _, feature_dict in result] == list(features)


@mock_dynamodb
def test_dynamodb_online_store_online_read_single_batch_no_parallel(
    repo_config, dynamodb_online_store
):
    """online_read with a single batch (no parallelization).

    With batch_size=100, 50 entities fit in one batch, so online_read takes
    the single-batch path without ThreadPoolExecutor overhead.
    """
    sample_count = 50
    table_name = f"{TABLE_NAME}_single_batch"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, features, *_ = zip(*samples)
    result = dynamodb_online_store.online_read(
        config=repo_config,
        table=MockFeatureView(name=table_name),
        entity_keys=entity_keys,
    )

    assert len(result) == sample_count
    assert [feature_dict for _, feature_dict in result] == list(features)


@mock_dynamodb
def test_dynamodb_online_store_online_read_order_preservation_across_batches(
    repo_config, dynamodb_online_store
):
    """Entity order must be preserved across parallel batch reads.

    This is critical: fanning batches out to threads must not change the
    order of the merged results.
    """
    sample_count = 150  # 2 batches with batch_size=100
    table_name = f"{TABLE_NAME}_order_preservation"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, features, *_ = zip(*samples)

    # Repeat the read to make sure ordering is consistent run-to-run.
    for _ in range(3):
        result = dynamodb_online_store.online_read(
            config=repo_config,
            table=MockFeatureView(name=table_name),
            entity_keys=entity_keys,
        )
        assert len(result) == sample_count
        # Compare element-by-element so a mismatch reports its position.
        for i, (returned, expected) in enumerate(zip(result, features)):
            assert returned[1] == expected, f"Mismatch at index {i}"


@mock_dynamodb
def test_dynamodb_online_store_online_read_small_batch_size(dynamodb_online_store):
    """Parallel reads stay correct with a small batch_size.

    batch_size=5 over 25 entities produces 5 batches, exercising the
    multi-batch code path with many small requests.
    """
    small_batch_config = RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=DynamoDBOnlineStoreConfig(region=REGION, batch_size=5),
        offline_store=DaskOfflineStoreConfig(),
        entity_key_serialization_version=3,
    )

    sample_count = 25  # 5 batches with batch_size=5
    table_name = f"{TABLE_NAME}_small_batch"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, features, *_ = zip(*samples)
    result = dynamodb_online_store.online_read(
        config=small_batch_config,
        table=MockFeatureView(name=table_name),
        entity_keys=entity_keys,
    )

    assert len(result) == sample_count
    assert [feature_dict for _, feature_dict in result] == list(features)


@mock_dynamodb
def test_dynamodb_online_store_online_read_many_batches(dynamodb_online_store):
    """Parallel reads with more batches than the worker cap (>10).

    batch_size=10 over 150 entities yields 15 batches, so some batches must
    queue behind the capped worker pool and still come back correctly.
    """
    many_batch_config = RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=DynamoDBOnlineStoreConfig(region=REGION, batch_size=10),
        offline_store=DaskOfflineStoreConfig(),
        entity_key_serialization_version=3,
    )

    sample_count = 150  # 15 batches with batch_size=10
    table_name = f"{TABLE_NAME}_many_batches"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, features, *_ = zip(*samples)
    result = dynamodb_online_store.online_read(
        config=many_batch_config,
        table=MockFeatureView(name=table_name),
        entity_keys=entity_keys,
    )

    assert len(result) == sample_count
    assert [feature_dict for _, feature_dict in result] == list(features)


@mock_dynamodb
def test_dynamodb_online_store_max_workers_capped_at_10(dynamodb_online_store):
    """ThreadPoolExecutor max_workers is capped at 10, not at batch_size.

    Bug: old code used min(len(batches), batch_size), which under-parallelizes
    for small batch sizes. Fixed code uses min(len(batches), 10).

    batch_size=5 with 15 batches exposes the difference:
    - OLD (buggy): max_workers = min(15, 5) = 5 (insufficient parallelism)
    - NEW (fixed): max_workers = min(15, 10) = 10 (correct cap)
    """
    # A small batch_size is what made the old computation go wrong.
    small_batch_config = RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=DynamoDBOnlineStoreConfig(region=REGION, batch_size=5),
        offline_store=DaskOfflineStoreConfig(),
        entity_key_serialization_version=3,
    )

    sample_count = 75  # 15 batches with batch_size=5
    table_name = f"{TABLE_NAME}_max_workers_cap"
    create_test_table(PROJECT, table_name, REGION)
    samples = create_n_customer_test_samples(n=sample_count)
    insert_data_test_table(samples, PROJECT, table_name, REGION)

    entity_keys, *_ = zip(*samples)

    with patch(
        "feast.infra.online_stores.dynamodb.ThreadPoolExecutor"
    ) as executor_mock:
        # Make the mocked pool behave like the real one: map() yields one
        # (empty) response per batch.
        pool = executor_mock.return_value.__enter__.return_value
        pool.map.return_value = iter([{"Responses": {}} for _ in range(15)])

        dynamodb_online_store.online_read(
            config=small_batch_config,
            table=MockFeatureView(name=table_name),
            entity_keys=entity_keys,
        )

    # The executor must have been constructed with max_workers=10 (the cap),
    # not with batch_size=5.
    executor_mock.assert_called_once()
    got_workers = executor_mock.call_args[1]["max_workers"]
    assert got_workers == 10, (
        f"Expected max_workers=10 (capped), got {got_workers}. "
        f"If got 5, the bug is using batch_size instead of 10 as cap."
    )


@mock_dynamodb
def test_dynamodb_online_store_thread_safety_new_resource_per_thread(
    dynamodb_online_store,
):
    """Verify each thread creates its own boto3 resource for thread-safety.

    boto3 resources are NOT thread-safe, so online_read must create a new
    resource per thread when fanning batches out to a ThreadPoolExecutor.
    """
    config = RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=DynamoDBOnlineStoreConfig(region=REGION, batch_size=50),
        offline_store=DaskOfflineStoreConfig(),
        entity_key_serialization_version=3,
    )

    n_samples = 150  # 3 batches
    db_table_name = f"{TABLE_NAME}_thread_safety"
    create_test_table(PROJECT, db_table_name, REGION)
    data = create_n_customer_test_samples(n=n_samples)
    insert_data_test_table(data, PROJECT, db_table_name, REGION)

    entity_keys, features, *rest = zip(*data)

    # Track the resources created to verify thread-safety.
    # NOTE: keep strong references to the resource objects themselves, not
    # just their id()s -- CPython recycles ids of garbage-collected objects,
    # so storing bare ids could make two distinct per-thread resources appear
    # identical and turn the uniqueness assertion below flaky.
    resources_created = []
    original_init = boto3.resource

    def tracking_resource(*args, **kwargs):
        # Delegate to the real factory, record the result, pass it through.
        resource = original_init(*args, **kwargs)
        resources_created.append(resource)
        return resource

    with patch.object(boto3, "resource", side_effect=tracking_resource):
        returned_items = dynamodb_online_store.online_read(
            config=config,
            table=MockFeatureView(name=db_table_name),
            entity_keys=entity_keys,
        )

    # Verify results are correct (functional correctness)
    assert len(returned_items) == n_samples

    # Verify multiple resources were created (thread-safety)
    # Each of the 3 batches should create its own resource
    # (plus potentially 1 for _get_dynamodb_resource cache initialization)
    assert len(resources_created) >= 3, (
        f"Expected at least 3 unique resources for 3 batches, "
        f"got {len(resources_created)}"
    )

    # Verify the resources are actually different objects (not reused).
    # All tracked resources are still alive at this point, so their ids are
    # unique per object and safe to compare.
    unique_resources = {id(resource) for resource in resources_created}
    assert len(unique_resources) >= 3, (
        f"Expected at least 3 unique resource IDs, "
        f"got {len(unique_resources)} unique out of {len(resources_created)}"
    )
Loading