diff --git a/src/control/drpc/drpc_server_test.go b/src/control/drpc/drpc_server_test.go index f79b2035852..5d9ecae2e04 100644 --- a/src/control/drpc/drpc_server_test.go +++ b/src/control/drpc/drpc_server_test.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" @@ -283,38 +284,51 @@ func TestServer_RegisterModule(t *testing.T) { } } -func TestServer_Listen_AcceptError(t *testing.T) { - log, buf := logging.NewTestLogger(t.Name()) - defer test.ShowBufferOnFailure(t, buf) - - lis := newMockListener() - lis.acceptErr = errors.New("mock accept error") - dss, _ := NewDomainSocketServer(log, "dontcare.sock", testFileMode) - dss.listener = lis +func TestDrpc_DomainSocketServer_Listen(t *testing.T) { + for name, tc := range map[string]struct { + newMockListener func(t *testing.T) *mockListener + expAcceptCallCount int + expNumSessions int + }{ + "accept error": { + newMockListener: func(t *testing.T) *mockListener { + lis := newMockListener(t) + lis.acceptErr = errors.New("mock accept error") + return lis + }, + expAcceptCallCount: 1, // return after first error + }, + "accept multiple": { + newMockListener: func(t *testing.T) *mockListener { + lis := newMockListener(t) + lis.setNumConnsToAccept(3) + return lis + }, + expAcceptCallCount: 4, // accepted connections + final error to exit listener + expNumSessions: 3, + }, + } { + t.Run(name, func(t *testing.T) { + ctx := test.MustLogContext(t) + log := logging.FromContext(ctx) - dss.Listen(test.Context(t)) // should return instantly + dss, err := NewDomainSocketServer(log, "dontcare.sock", testFileMode) + if err != nil { + t.Fatal(err) + } + lis := tc.newMockListener(t) + dss.listener = lis - test.AssertEqual(t, lis.acceptCallCount, 1, "should have returned after first error") -} + dss.Listen(ctx) -func TestServer_Listen_AcceptConnection(t *testing.T) { - log, buf := logging.NewTestLogger(t.Name()) - defer test.ShowBufferOnFailure(t, buf) + // Time for session goroutines to settle + time.Sleep(500 * time.Microsecond) - lis := newMockListener() - lis.setNumConnsToAccept(3) - dss, err := NewDomainSocketServer(log, "dontcare.sock", testFileMode) - if err != nil { - t.Fatal(err) + test.AssertEqual(t, tc.expAcceptCallCount, lis.acceptCallCount, "") + test.AssertEqual(t, tc.expNumSessions, dss.GetNumSessions(), + "server should have made connections into sessions") + }) } - dss.listener = lis - - dss.Listen(test.Context(t)) // will return when error is sent - - test.AssertEqual(t, lis.acceptCallCount, lis.acceptNumConns+1, - "should have returned after listener errored") - test.AssertEqual(t, dss.GetNumSessions(), lis.acceptNumConns, - "server should have made connections into sessions") } func TestServer_ListenSession_Error(t *testing.T) { diff --git a/src/control/drpc/mocks_test.go b/src/control/drpc/mocks_test.go index a2c04c6ea10..1d65de5f13f 100644 --- a/src/control/drpc/mocks_test.go +++ b/src/control/drpc/mocks_test.go @@ -1,6 +1,6 @@ // // (C) Copyright 2019-2022 Intel Corporation. -// (C) Copyright 2025 Hewlett Packard Enterprise Development LP +// (C) Copyright 2025-2026 Hewlett Packard Enterprise Development LP // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -167,6 +167,7 @@ func (m *mockConn) SetReadOutputBytesToResponse(t *testing.T, resp *Response) { // mockListener is a mock of the net.Listener interface type mockListener struct { + t *testing.T acceptNumConns int // accept a certain number of connections before failing acceptErr error acceptCallCount int @@ -179,7 +180,9 @@ func (l *mockListener) Accept() (net.Conn, error) { if l.acceptCallCount > l.acceptNumConns { return nil, l.acceptErr } - return newMockConn(), nil + c := newMockConn() + c.SetReadOutputBytesToResponse(l.t, &Response{}) + return c, nil } func (l *mockListener) Close() error { @@ -197,10 +200,8 @@ func (l *mockListener) setNumConnsToAccept(n int) { l.acceptErr = errors.New("mock done accepting connections") } -func newMockListener() *mockListener { - return &mockListener{ - acceptNumConns: -1, - } +func newMockListener(t *testing.T) *mockListener { + return &mockListener{t: t} } // ctxMockListener is a mock of the net.Listener interface that blocks