Skip to content

Commit 0c77778

Browse files
refactor: make progress token dependent on request ID
1 parent 536c6c0 commit 0c77778

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

src/Client/Client.php

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class Client
5454
private Protocol $protocol;
5555
private ClientSessionInterface $session;
5656
private ?ClientTransportInterface $transport = null;
57-
private int $progressTokenCounter = 0;
5857

5958
/**
6059
* @param NotificationHandlerInterface[] $notificationHandlers
@@ -143,11 +142,8 @@ public function listTools(?string $cursor = null): ListToolsResult
143142
* Optional callback for progress updates. If provided, a progress token
144143
* is automatically generated and attached to the request.
145144
*/
146-
public function callTool(
147-
string $name,
148-
array $arguments = [],
149-
?callable $onProgress = null,
150-
): CallToolResult {
145+
public function callTool(string $name, array $arguments = [], ?callable $onProgress = null): CallToolResult
146+
{
151147
$this->ensureConnected();
152148

153149
$request = new CallToolRequest($name, $arguments);
@@ -209,11 +205,8 @@ public function listPrompts(?string $cursor = null): ListPromptsResult
209205
* @param (callable(float $progress, ?float $total, ?string $message): void)|null $onProgress
210206
* Optional callback for progress updates.
211207
*/
212-
public function getPrompt(
213-
string $name,
214-
array $arguments = [],
215-
?callable $onProgress = null,
216-
): GetPromptResult {
208+
public function getPrompt(string $name, array $arguments = [], ?callable $onProgress = null): GetPromptResult
209+
{
217210
$this->ensureConnected();
218211

219212
$request = new GetPromptRequest($name, $arguments);
@@ -261,10 +254,13 @@ public function disconnect(): void
261254
*
262255
* @throws RequestException
263256
*/
264-
private function doRequest(object $request, ?string $resultClass = null, ?callable $onProgress = null): mixed
257+
private function doRequest(Request $request, ?string $resultClass = null, ?callable $onProgress = null): mixed
265258
{
266-
if (null !== $onProgress && $request instanceof Request) {
267-
$progressToken = $this->generateProgressToken();
259+
$requestId = $this->session->nextRequestId();
260+
$request = $request->withId($requestId);
261+
262+
if (null !== $onProgress) {
263+
$progressToken = 'prog-' . $requestId;
268264
$request = $request->withMeta(['progressToken' => $progressToken]);
269265
}
270266

@@ -283,14 +279,6 @@ private function doRequest(object $request, ?string $resultClass = null, ?callab
283279
return $resultClass::fromArray($response->result);
284280
}
285281

286-
/**
287-
* Generate a unique progress token for a request.
288-
*/
289-
private function generateProgressToken(): string
290-
{
291-
return 'prog-' . (++$this->progressTokenCounter);
292-
}
293-
294282
private function ensureConnected(): void
295283
{
296284
if (!$this->isConnected()) {

src/Client/Protocol.php

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public function connect(ClientTransportInterface $transport): void
6666
{
6767
$this->transport = $transport;
6868
$transport->setSession($this->session);
69-
$transport->onInitialize(fn() => $this->performInitialize());
69+
$transport->onInitialize($this->initialize(...));
7070
$transport->onMessage($this->processMessage(...));
7171
$transport->onError(fn(\Throwable $e) => $this->logger->error('Transport error', ['exception' => $e]));
7272

@@ -80,14 +80,17 @@ public function connect(ClientTransportInterface $transport): void
8080
*
8181
* @return Response<array<string, mixed>>|Error
8282
*/
83-
public function performInitialize(): Response|Error
83+
public function initialize(): Response|Error
8484
{
8585
$request = new InitializeRequest(
8686
$this->config->protocolVersion,
8787
$this->config->capabilities,
8888
$this->config->clientInfo,
8989
);
9090

91+
$requestId = $this->session->nextRequestId();
92+
$request = $request->withId($requestId);
93+
9194
$response = $this->request($request, $this->config->initTimeout);
9295

9396
if ($response instanceof Response) {
@@ -114,15 +117,14 @@ public function performInitialize(): Response|Error
114117
*/
115118
public function request(Request $request, int $timeout): Response|Error
116119
{
117-
$requestId = $this->session->nextRequestId();
118-
$requestWithId = $request->withId($requestId);
120+
$requestId = $request->getId();
119121

120122
$this->logger->debug('Sending request', [
121123
'id' => $requestId,
122124
'method' => $request::getMethod(),
123125
]);
124126

125-
$encoded = json_encode($requestWithId, \JSON_THROW_ON_ERROR);
127+
$encoded = json_encode($request, \JSON_THROW_ON_ERROR);
126128
$this->session->queueOutgoing($encoded, ['type' => 'request']);
127129
$this->session->addPendingRequest($requestId, $timeout);
128130

0 commit comments

Comments
 (0)