Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions pubsub/gcloud/aio/pubsub/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import time
from collections.abc import Awaitable
from collections.abc import Callable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Optional
from typing import TypeVar

from . import metrics
Expand Down Expand Up @@ -286,11 +286,26 @@ async def maybe_nack(ack_id: str) -> None:

ack_ids = []

async def ack_or_nack(
message: SubscriberMessage,
ack_queue: 'asyncio.Queue[str]',
nack_queue: Optional['asyncio.Queue[str]'],
ack: bool = False,
) -> None:
if message.force_ack_nack is None:
# if we've not forced the ack status, set it here
message.force_ack_nack = ack

if message.force_ack_nack:
await ack_queue.put(message.ack_id)
elif nack_queue:
await nack_queue.put(message.ack_id)

async def _execute_callback(
message: SubscriberMessage,
callback: ApplicationHandler,
ack_queue: 'asyncio.Queue[str]',
nack_queue: 'Optional[asyncio.Queue[str]]',
nack_queue: Optional['asyncio.Queue[str]'],
insertion_time: float,
) -> None:
try:
Expand All @@ -300,18 +315,15 @@ async def _execute_callback(
)
with metrics.CONSUME_LATENCY.labels(phase='runtime').time():
await callback(message)
await ack_queue.put(message.ack_id)
await ack_or_nack(message, ack_queue, nack_queue, ack=True)
metrics.CONSUME.labels(outcome='succeeded').inc()

except asyncio.CancelledError:
if nack_queue:
await nack_queue.put(message.ack_id)
await ack_or_nack(message, ack_queue, nack_queue, ack=False)

log.warning('application callback was cancelled')
metrics.CONSUME.labels(outcome='cancelled').inc()
except Exception as e:
if nack_queue:
await nack_queue.put(message.ack_id)
await ack_or_nack(message, ack_queue, nack_queue, ack=False)

log.warning(
'application callback raised an exception',
Expand All @@ -326,7 +338,7 @@ async def consumer( # pylint: disable=too-many-locals
ack_queue: 'asyncio.Queue[str]',
ack_deadline_cache: AckDeadlineCache,
max_tasks: int,
nack_queue: 'Optional[asyncio.Queue[str]]',
nack_queue: Optional['asyncio.Queue[str]'],
) -> None:
try:
semaphore = asyncio.Semaphore(max_tasks)
Expand Down Expand Up @@ -450,7 +462,7 @@ async def subscribe(
ack_queue: 'asyncio.Queue[str]' = asyncio.Queue(
maxsize=(max_messages_per_producer * num_producers),
)
nack_queue: 'Optional[asyncio.Queue[str]]' = None
nack_queue: Optional['asyncio.Queue[str]'] = None
ack_deadline_cache = AckDeadlineCache(
subscriber_client,
subscription,
Expand Down
23 changes: 23 additions & 0 deletions pubsub/gcloud/aio/pubsub/subscriber_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def __init__(
self.attributes = attributes
self.delivery_attempt = delivery_attempt

self.force_ack_nack: bool | None = None

@staticmethod
def from_repr(
received_message: dict[str, Any],
Expand Down Expand Up @@ -66,3 +68,24 @@ def to_repr(self) -> dict[str, Any]:
if self.delivery_attempt is not None:
r['deliveryAttempt'] = self.delivery_attempt
return r

def ack(self) -> None:
"""
Forcibly mark a message as acked.

By default, we only ack a message if the callback returns without
raising an exception. If this method has been called on the Message, we
will instead ack it regardless of exception status.
"""
self.force_ack_nack = True

def nack(self) -> None:
"""
Forcibly mark a message as nacked.

By default, we only nack a message if the callback raises an exception.
If this method has been called on the Message, we will instead nack it
regardless of exception status, ie. including if it completes
successfully.
"""
self.force_ack_nack = False
1 change: 1 addition & 0 deletions pubsub/tests/unit/subscriber_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def make_message_mock():
mock = MagicMock()
mock.ack_id = 'ack_id'
mock.publish_time.timestamp = MagicMock(return_value=time.time())
mock.force_ack_nack = None
return mock

@pytest.fixture(scope='function')
Expand Down