Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .changeset/fix-rate-limiter-abort-signal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
"@smithy/util-retry": minor
"@smithy/middleware-retry": minor
"@smithy/smithy-client": minor
---

feat(util-retry): support AbortSignal in DefaultRateLimiter.getSendToken

Thread AbortSignal from command options through the middleware stack so that
retry delays (both V1 StandardRetryStrategy/AdaptiveRetryStrategy and V2
RetryStrategyV2) can be cancelled early. This enables graceful abort of
retry waits in environments like AWS Lambda.
25 changes: 25 additions & 0 deletions packages/middleware-retry/src/AdaptiveRetryStrategy.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,30 @@ describe(AdaptiveRetryStrategy.name, () => {
await mockedSuperRetry.mock.calls[0][2]!.afterRequest();
expect(mockDefaultRateLimiter.updateClientSendingRate).toHaveBeenCalledTimes(1);
});

it("passes abortSignal to rateLimiter.getSendToken", async () => {
vi.clearAllMocks();
const next = vi.fn();
const retryStrategy = new AdaptiveRetryStrategy(maxAttemptsProvider);
const abortController = new AbortController();
await retryStrategy.retry(next, { request: { headers: {} } } as any, {
abortSignal: abortController.signal,
});
expect(mockedSuperRetry).toHaveBeenCalledTimes(1);
await mockedSuperRetry.mock.calls[0][2]!.beforeRequest();
expect(mockDefaultRateLimiter.getSendToken).toHaveBeenCalledWith(abortController.signal);
});

it("passes abortSignal through to super.retry options", async () => {
vi.clearAllMocks();
const next = vi.fn();
const retryStrategy = new AdaptiveRetryStrategy(maxAttemptsProvider);
const abortController = new AbortController();
await retryStrategy.retry(next, { request: { headers: {} } } as any, {
abortSignal: abortController.signal,
});
expect(mockedSuperRetry).toHaveBeenCalledTimes(1);
expect(mockedSuperRetry.mock.calls[0][2]!.abortSignal).toBe(abortController.signal);
});
});
});
8 changes: 6 additions & 2 deletions packages/middleware-retry/src/AdaptiveRetryStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ export class AdaptiveRetryStrategy extends StandardRetryStrategy {

async retry<Input extends object, Ouput extends MetadataBearer>(
next: FinalizeHandler<Input, Ouput>,
args: FinalizeHandlerArguments<Input>
args: FinalizeHandlerArguments<Input>,
options?: {
abortSignal?: AbortSignal;
}
) {
return super.retry(next, args, {
beforeRequest: async () => {
return this.rateLimiter.getSendToken();
return this.rateLimiter.getSendToken(options?.abortSignal);
},
afterRequest: (response: any) => {
this.rateLimiter.updateClientSendingRate(response);
},
abortSignal: options?.abortSignal,
});
}
}
27 changes: 26 additions & 1 deletion packages/middleware-retry/src/StandardRetryStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export class StandardRetryStrategy implements RetryStrategy {
options?: {
beforeRequest: Function;
afterRequest: Function;
abortSignal?: AbortSignal;
}
) {
let retryTokenAmount;
Expand Down Expand Up @@ -116,7 +117,7 @@ export class StandardRetryStrategy implements RetryStrategy {

totalDelay += delay;

await new Promise((resolve) => setTimeout(resolve, delay));
await abortableDelay(delay, options?.abortSignal);
continue;
}

Expand Down Expand Up @@ -148,3 +149,27 @@ const getDelayFromRetryAfterHeader = (response: unknown): number | undefined =>
const retryAfterDate = new Date(retryAfter);
return retryAfterDate.getTime() - Date.now();
};

/**
* Returns a promise that resolves after the given delay, but rejects
* immediately if the optional AbortSignal is triggered.
*/
const abortableDelay = (delay: number, abortSignal?: AbortSignal): Promise<void> => {
if (!abortSignal) {
return new Promise((resolve) => setTimeout(resolve, delay));
}
if (abortSignal.aborted) {
return Promise.reject(abortSignal.reason ?? new Error("Request aborted"));
}
return new Promise<void>((resolve, reject) => {
const onAbort = () => {
clearTimeout(timer);
reject(abortSignal.reason ?? new Error("Request aborted"));
};
const timer = setTimeout(() => {
abortSignal.removeEventListener("abort", onAbort);
resolve();
}, delay);
abortSignal.addEventListener("abort", onAbort, { once: true });
});
};
78 changes: 77 additions & 1 deletion packages/middleware-retry/src/retryMiddleware.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,29 @@ describe(retryMiddleware.name, () => {
context
)(args as FinalizeHandlerArguments<any>);
expect(mockRetryStrategy.retry).toHaveBeenCalledTimes(1);
expect(mockRetryStrategy.retry).toHaveBeenCalledWith(next, args);
expect(mockRetryStrategy.retry).toHaveBeenCalledWith(next, args, { abortSignal: undefined });
expect(context.userAgent).toContainEqual(["cfg/retry-mode", mockRetryStrategy.mode]);
});

it("passes context.abortSignal to retryStrategy.retry", async () => {
const next = vi.fn();
const args = {
request: { headers: {} },
};
const abortController = new AbortController();
const context: HandlerExecutionContext = {
abortSignal: abortController.signal,
};

await retryMiddleware({
maxAttempts: () => Promise.resolve(maxAttempts),
retryStrategy: vi.fn().mockResolvedValue({ ...mockRetryStrategy, maxAttempts }),
})(
next,
context
)(args as FinalizeHandlerArguments<any>);
expect(mockRetryStrategy.retry).toHaveBeenCalledWith(next, args, { abortSignal: abortController.signal });
});
});

describe("RetryStrategyV2", () => {
Expand Down Expand Up @@ -355,6 +375,62 @@ describe(retryMiddleware.name, () => {
});
});

describe("abortSignal support", () => {
it("rejects retry delay when abortSignal is triggered", async () => {
vi.mocked(isThrottlingError).mockReturnValue(true);
const mockError = Object.assign(new Error("mockError"), {
$response: { headers: {} },
$metadata: {},
});
const next = vi.fn().mockRejectedValue(mockError);
const abortController = new AbortController();
const abortReason = new Error("Lambda timeout approaching");
const contextWithAbort: HandlerExecutionContext = {
partition_id: partitionId,
abortSignal: abortController.signal,
};

const promise = retryMiddleware({
maxAttempts: () => Promise.resolve(maxAttempts),
retryStrategy: vi.fn().mockResolvedValue({ ...mockRetryStrategy, maxAttempts }),
})(
next,
contextWithAbort
)(args as FinalizeHandlerArguments<any>);

// Abort after the first failure triggers a retry delay
abortController.abort(abortReason);

await expect(promise).rejects.toBe(abortReason);
});

it("rejects immediately when abortSignal is already aborted", async () => {
vi.mocked(isThrottlingError).mockReturnValue(true);
const mockError = Object.assign(new Error("mockError"), {
$response: { headers: {} },
$metadata: {},
});
const next = vi.fn().mockRejectedValue(mockError);
const abortController = new AbortController();
const abortReason = new Error("Already aborted");
abortController.abort(abortReason);
const contextWithAbort: HandlerExecutionContext = {
partition_id: partitionId,
abortSignal: abortController.signal,
};

const promise = retryMiddleware({
maxAttempts: () => Promise.resolve(maxAttempts),
retryStrategy: vi.fn().mockResolvedValue({ ...mockRetryStrategy, maxAttempts }),
})(
next,
contextWithAbort
)(args as FinalizeHandlerArguments<any>);

await expect(promise).rejects.toBe(abortReason);
});
});

describe("retry headers", () => {
describe("not added if HttpRequest.isInstance returns false", () => {
it(`retry informational header: ${INVOCATION_ID_HEADER}`, async () => {
Expand Down
28 changes: 26 additions & 2 deletions packages/middleware-retry/src/retryMiddleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ export const retryMiddleware =
attempts = retryToken.getRetryCount();
const delay = retryToken.getRetryDelay();
totalRetryDelay += delay;
await new Promise((resolve) => setTimeout(resolve, delay));
await abortableDelay(delay, context.abortSignal as AbortSignal | undefined);
}
}
} else {
retryStrategy = retryStrategy as RetryStrategy;
if (retryStrategy?.mode)
context.userAgent = [...(context.userAgent || []), ["cfg/retry-mode", retryStrategy.mode]];

return retryStrategy.retry(next, args);
return retryStrategy.retry(next, args, { abortSignal: context.abortSignal as AbortSignal | undefined });
}
};

Expand All @@ -100,6 +100,30 @@ const isRetryStrategyV2 = (retryStrategy: RetryStrategy | RetryStrategyV2) =>
typeof (retryStrategy as RetryStrategyV2).refreshRetryTokenForRetry !== "undefined" &&
typeof (retryStrategy as RetryStrategyV2).recordSuccess !== "undefined";

/**
* Returns a promise that resolves after the given delay, but rejects
* immediately if the optional AbortSignal is triggered.
*/
const abortableDelay = (delay: number, abortSignal?: AbortSignal): Promise<void> => {
if (!abortSignal) {
return new Promise((resolve) => setTimeout(resolve, delay));
}
if (abortSignal.aborted) {
return Promise.reject(abortSignal.reason ?? new Error("Request aborted"));
}
return new Promise<void>((resolve, reject) => {
const onAbort = () => {
clearTimeout(timer);
reject(abortSignal.reason ?? new Error("Request aborted"));
};
const timer = setTimeout(() => {
abortSignal.removeEventListener("abort", onAbort);
resolve();
}, delay);
abortSignal.addEventListener("abort", onAbort, { once: true });
});
};

const getRetryErrorInfo = (error: SdkError): RetryErrorInfo => {
const errorInfo: RetryErrorInfo = {
error,
Expand Down
4 changes: 3 additions & 1 deletion packages/middleware-retry/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ export interface RateLimiter {
* If there is not sufficient capacity, it will either sleep a certain amount
* of time until the rate limiter can retrieve a token from its token bucket
* or raise an exception indicating there is insufficient capacity.
*
* @param abortSignal - optional signal to abort the token wait early.
*/
getSendToken: () => Promise<void>;
getSendToken: (abortSignal?: AbortSignal) => Promise<void>;

/**
* Updates the client sending rate based on response.
Expand Down
3 changes: 3 additions & 0 deletions packages/smithy-client/src/command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ export abstract class Command<
},
...additionalContext,
};
if (options?.abortSignal) {
handlerExecutionContext.abortSignal = options.abortSignal;
}
const { requestHandler } = configuration;
return stack.resolve(
(request: FinalizeHandlerArguments<any>) => requestHandler.handle(request.request as HttpRequest, options || {}),
Expand Down
69 changes: 69 additions & 0 deletions packages/util-retry/src/DefaultRateLimiter.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,75 @@ describe(DefaultRateLimiter.name, () => {
vi.runAllTimers();
expect(spy).toHaveBeenLastCalledWith(expect.any(Function), delay);
});

it("rejects when abortSignal is already aborted", async () => {
vi.spyOn(Date, "now").mockImplementation(() => 0);
const rateLimiter = new DefaultRateLimiter();

vi.mocked(isThrottlingError).mockReturnValueOnce(true);
vi.spyOn(Date, "now").mockImplementation(() => 500);
rateLimiter.updateClientSendingRate({});

const abortController = new AbortController();
const reason = new Error("Lambda timeout approaching");
abortController.abort(reason);

await expect(rateLimiter.getSendToken(abortController.signal)).rejects.toBe(reason);
});

it("rejects when abortSignal fires during wait", async () => {
vi.spyOn(Date, "now").mockImplementation(() => 0);
const rateLimiter = new DefaultRateLimiter();

vi.mocked(isThrottlingError).mockReturnValueOnce(true);
vi.spyOn(Date, "now").mockImplementation(() => 500);
rateLimiter.updateClientSendingRate({});

const abortController = new AbortController();
const reason = new Error("Lambda timeout approaching");

const promise = rateLimiter.getSendToken(abortController.signal);
abortController.abort(reason);

await expect(promise).rejects.toBe(reason);
});

it("resolves normally when abortSignal is not aborted", async () => {
vi.spyOn(Date, "now").mockImplementation(() => 0);
const rateLimiter = new DefaultRateLimiter();

// Use a spy to immediately resolve the setTimeout callback
vi.spyOn(DefaultRateLimiter as any, "setTimeoutFn").mockImplementation((cb: () => void) => {
cb();
return 0;
});

vi.mocked(isThrottlingError).mockReturnValueOnce(true);
vi.spyOn(Date, "now").mockImplementation(() => 500);
rateLimiter.updateClientSendingRate({});

const abortController = new AbortController();
await expect(rateLimiter.getSendToken(abortController.signal)).resolves.toBeUndefined();
});

it("removes abort listener after successful delay", async () => {
vi.spyOn(Date, "now").mockImplementation(() => 0);
const rateLimiter = new DefaultRateLimiter();

vi.spyOn(DefaultRateLimiter as any, "setTimeoutFn").mockImplementation((cb: () => void) => {
cb();
return 0;
});

vi.mocked(isThrottlingError).mockReturnValueOnce(true);
vi.spyOn(Date, "now").mockImplementation(() => 500);
rateLimiter.updateClientSendingRate({});

const abortController = new AbortController();
const removeEventListenerSpy = vi.spyOn(abortController.signal, "removeEventListener");
await rateLimiter.getSendToken(abortController.signal);
expect(removeEventListenerSpy).toHaveBeenCalledWith("abort", expect.any(Function));
});
});

describe("cubicSuccess", () => {
Expand Down
28 changes: 24 additions & 4 deletions packages/util-retry/src/DefaultRateLimiter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ export class DefaultRateLimiter implements RateLimiter {
return Date.now() / 1000;
}

public async getSendToken() {
return this.acquireTokenBucket(1);
public async getSendToken(abortSignal?: AbortSignal) {
return this.acquireTokenBucket(1, abortSignal);
}

private async acquireTokenBucket(amount: number) {
private async acquireTokenBucket(amount: number, abortSignal?: AbortSignal) {
// Client side throttling is not enabled until we see a throttling error.
if (!this.enabled) {
return;
Expand All @@ -76,7 +76,27 @@ export class DefaultRateLimiter implements RateLimiter {
this.refillTokenBucket();
if (amount > this.currentCapacity) {
const delay = ((amount - this.currentCapacity) / this.fillRate) * 1000;
await new Promise((resolve) => DefaultRateLimiter.setTimeoutFn(resolve, delay));
await new Promise<void>((resolve, reject) => {
const onAbort = () => {
clearTimeout(timer);
reject(abortSignal?.reason ?? new Error("Request aborted"));
};
const timer = DefaultRateLimiter.setTimeoutFn(() => {
if (abortSignal) {
abortSignal.removeEventListener("abort", onAbort);
}
resolve();
}, delay);

if (abortSignal) {
if (abortSignal.aborted) {
clearTimeout(timer);
reject(abortSignal.reason ?? new Error("Request aborted"));
return;
}
abortSignal.addEventListener("abort", onAbort, { once: true });
}
});
}
this.currentCapacity = this.currentCapacity - amount;
}
Expand Down
Loading