Skip to content

Commit f082621

Browse files
committed
fix(mcptoolset): validate cached MCP session with Ping before reuse
Add session validation using Ping health check before reusing cached MCP sessions. If the session is stale or closed, it will be automatically recreated to prevent 'connection closed' errors. Also adds TestStaleSessionIsRecreated test to verify the fix behavior.
1 parent 415e398 commit f082621

File tree

2 files changed

+147
-1
lines changed

2 files changed

+147
-1
lines changed

tool/mcptoolset/set.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,15 @@ func (s *set) getSession(ctx context.Context) (*mcp.ClientSession, error) {
139139
defer s.mu.Unlock()
140140

141141
if s.session != nil {
142-
return s.session, nil
142+
// Validate cached session with a Ping health check.
143+
// If the connection was closed, clear the stale session and reconnect.
144+
if err := s.session.Ping(ctx, nil); err != nil {
145+
// Session is stale, close and clear it.
146+
_ = s.session.Close() // Best effort close
147+
s.session = nil
148+
} else {
149+
return s.session, nil
150+
}
143151
}
144152

145153
session, err := s.client.Connect(ctx, s.transport, nil)

tool/mcptoolset/set_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"net/http"
2323
"path/filepath"
2424
"strings"
25+
"sync"
2526
"testing"
2627

2728
"github.com/google/go-cmp/cmp"
@@ -307,3 +308,140 @@ func TestToolFilter(t *testing.T) {
307308
t.Errorf("tools mismatch (-want +got):\n%s", diff)
308309
}
309310
}
311+
312+
// TestSessionValidationOnReuse verifies that the MCP toolset validates
313+
// cached sessions before reuse using Ping health check.
314+
func TestSessionValidationOnReuse(t *testing.T) {
315+
const toolDescription = "returns weather in the given city"
316+
317+
clientTransport, serverTransport := mcp.NewInMemoryTransports()
318+
319+
// Create server instance.
320+
server := mcp.NewServer(&mcp.Implementation{Name: "weather_server", Version: "v1.0.0"}, nil)
321+
mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: toolDescription}, weatherFunc)
322+
_, err := server.Connect(t.Context(), serverTransport, nil)
323+
if err != nil {
324+
t.Fatal(err)
325+
}
326+
327+
ts, err := mcptoolset.New(mcptoolset.Config{
328+
Transport: clientTransport,
329+
})
330+
if err != nil {
331+
t.Fatalf("Failed to create MCP tool set: %v", err)
332+
}
333+
334+
readonlyCtx := icontext.NewReadonlyContext(
335+
icontext.NewInvocationContext(
336+
t.Context(),
337+
icontext.InvocationContextParams{},
338+
),
339+
)
340+
341+
// First call should establish a session and return tools.
342+
tools, err := ts.Tools(readonlyCtx)
343+
if err != nil {
344+
t.Fatalf("First Tools() call failed: %v", err)
345+
}
346+
if len(tools) != 1 || tools[0].Name() != "get_weather" {
347+
t.Fatalf("Expected 1 tool named 'get_weather', got %d tools", len(tools))
348+
}
349+
350+
// Second call should reuse the cached session (validated via Ping).
351+
tools2, err := ts.Tools(readonlyCtx)
352+
if err != nil {
353+
t.Fatalf("Second Tools() call failed: %v", err)
354+
}
355+
if len(tools2) != 1 || tools2[0].Name() != "get_weather" {
356+
t.Fatalf("Expected 1 tool named 'get_weather', got %d tools", len(tools2))
357+
}
358+
}
359+
360+
// reconnectableTransport wraps another transport and allows swapping it out
361+
// to simulate reconnection scenarios.
362+
type reconnectableTransport struct {
363+
mu sync.Mutex
364+
transport mcp.Transport
365+
}
366+
367+
func (r *reconnectableTransport) setTransport(t mcp.Transport) {
368+
r.mu.Lock()
369+
defer r.mu.Unlock()
370+
r.transport = t
371+
}
372+
373+
func (r *reconnectableTransport) Connect(ctx context.Context) (mcp.Connection, error) {
374+
r.mu.Lock()
375+
defer r.mu.Unlock()
376+
return r.transport.Connect(ctx)
377+
}
378+
379+
// TestStaleSessionIsRecreated verifies that when a cached session becomes stale
380+
// (e.g., server closes the connection), the MCP toolset detects this via Ping
381+
// and automatically creates a new session.
382+
func TestStaleSessionIsRecreated(t *testing.T) {
383+
const toolDescription = "returns weather in the given city"
384+
385+
// Create first transport pair.
386+
clientTransport1, serverTransport1 := mcp.NewInMemoryTransports()
387+
388+
// Create server instance.
389+
server := mcp.NewServer(&mcp.Implementation{Name: "weather_server", Version: "v1.0.0"}, nil)
390+
mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: toolDescription}, weatherFunc)
391+
392+
// Capture the server-side session to be able to close it later.
393+
serverSession, err := server.Connect(t.Context(), serverTransport1, nil)
394+
if err != nil {
395+
t.Fatal(err)
396+
}
397+
398+
// Use a reconnectable transport that we can swap.
399+
reconnectable := &reconnectableTransport{transport: clientTransport1}
400+
401+
ts, err := mcptoolset.New(mcptoolset.Config{
402+
Transport: reconnectable,
403+
})
404+
if err != nil {
405+
t.Fatalf("Failed to create MCP tool set: %v", err)
406+
}
407+
408+
readonlyCtx := icontext.NewReadonlyContext(
409+
icontext.NewInvocationContext(
410+
t.Context(),
411+
icontext.InvocationContextParams{},
412+
),
413+
)
414+
415+
// First call should establish a session and return tools.
416+
tools, err := ts.Tools(readonlyCtx)
417+
if err != nil {
418+
t.Fatalf("First Tools() call failed: %v", err)
419+
}
420+
if len(tools) != 1 || tools[0].Name() != "get_weather" {
421+
t.Fatalf("Expected 1 tool named 'get_weather', got %d tools", len(tools))
422+
}
423+
424+
// Simulate connection drop by closing the server-side session.
425+
if err := serverSession.Close(); err != nil {
426+
t.Fatalf("Failed to close server session: %v", err)
427+
}
428+
429+
// Create a new transport pair for reconnection.
430+
clientTransport2, serverTransport2 := mcp.NewInMemoryTransports()
431+
reconnectable.setTransport(clientTransport2)
432+
433+
// Create a new server session for the new transport.
434+
_, err = server.Connect(t.Context(), serverTransport2, nil)
435+
if err != nil {
436+
t.Fatalf("Failed to create new server session: %v", err)
437+
}
438+
439+
// Second call should detect the stale session via Ping, reconnect, and succeed.
440+
tools2, err := ts.Tools(readonlyCtx)
441+
if err != nil {
442+
t.Fatalf("Second Tools() call failed after session drop: %v", err)
443+
}
444+
if len(tools2) != 1 || tools2[0].Name() != "get_weather" {
445+
t.Fatalf("Expected 1 tool named 'get_weather' after reconnect, got %d tools", len(tools2))
446+
}
447+
}

0 commit comments

Comments
 (0)