Skip to content

Commit 53e4b0b

Browse files
authored
chore(refactor): Refactoring tools tests (#22)
1 parent abad95f commit 53e4b0b

File tree

6 files changed

+398
-338
lines changed

6 files changed

+398
-338
lines changed

internal/toolsets/config/tools_test.go

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ import (
77
"testing"
88

99
"github.com/modelcontextprotocol/go-sdk/mcp"
10-
v1 "github.com/stackrox/rox/generated/api/v1"
1110
"github.com/stackrox/rox/generated/storage"
1211
"github.com/stackrox/stackrox-mcp/internal/client"
1312
"github.com/stackrox/stackrox-mcp/internal/config"
13+
"github.com/stackrox/stackrox-mcp/internal/toolsets/mock"
1414
"github.com/stretchr/testify/assert"
1515
"github.com/stretchr/testify/require"
1616
"google.golang.org/grpc"
@@ -60,44 +60,6 @@ func TestListClustersTool_RegisterWith(t *testing.T) {
6060
})
6161
}
6262

63-
// Mock infrastructure for gRPC testing.
64-
65-
// mockClustersService implements v1.ClustersServiceServer for testing.
66-
type mockClustersService struct {
67-
v1.UnimplementedClustersServiceServer
68-
69-
clusters []*storage.Cluster
70-
err error
71-
}
72-
73-
func (m *mockClustersService) GetClusters(
74-
_ context.Context,
75-
_ *v1.GetClustersRequest,
76-
) (*v1.ClustersList, error) {
77-
if m.err != nil {
78-
return nil, m.err
79-
}
80-
81-
return &v1.ClustersList{
82-
Clusters: m.clusters,
83-
}, nil
84-
}
85-
86-
// setupMockServer creates an in-memory gRPC server using bufconn.
87-
func setupMockServer(mockService *mockClustersService) (*grpc.Server, *bufconn.Listener) {
88-
buffer := 1024 * 1024
89-
listener := bufconn.Listen(buffer)
90-
91-
grpcServer := grpc.NewServer()
92-
v1.RegisterClustersServiceServer(grpcServer, mockService)
93-
94-
go func() {
95-
_ = grpcServer.Serve(listener)
96-
}()
97-
98-
return grpcServer, listener
99-
}
100-
10163
// bufDialer creates a dialer function for bufconn.
10264
func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) {
10365
return func(_ context.Context, _ string) (net.Conn, error) {
@@ -129,17 +91,18 @@ func createTestClient(t *testing.T, listener *bufconn.Listener) *client.Client {
12991
}
13092

13193
func TestHandle_DefaultLimit(t *testing.T) {
132-
mockService := &mockClustersService{
133-
clusters: []*storage.Cluster{
94+
mockService := mock.NewClustersServiceMock(
95+
[]*storage.Cluster{
13496
{Id: "c1", Name: "Cluster 1", Type: storage.ClusterType_KUBERNETES_CLUSTER},
13597
{Id: "c2", Name: "Cluster 2", Type: storage.ClusterType_KUBERNETES_CLUSTER},
13698
{Id: "c3", Name: "Cluster 3", Type: storage.ClusterType_KUBERNETES_CLUSTER},
13799
{Id: "c4", Name: "Cluster 4", Type: storage.ClusterType_KUBERNETES_CLUSTER},
138100
{Id: "c5", Name: "Cluster 5", Type: storage.ClusterType_KUBERNETES_CLUSTER},
139101
},
140-
}
102+
nil,
103+
)
141104

142-
grpcServer, listener := setupMockServer(mockService)
105+
grpcServer, listener := mock.SetupClusterServer(mockService)
143106
defer grpcServer.Stop()
144107

145108
testClient := createTestClient(t, listener)
@@ -180,11 +143,9 @@ func TestHandle_WithPagination(t *testing.T) {
180143
}
181144
}
182145

183-
mockService := &mockClustersService{
184-
clusters: clusters,
185-
}
146+
mockService := mock.NewClustersServiceMock(clusters, nil)
186147

187-
grpcServer, listener := setupMockServer(mockService)
148+
grpcServer, listener := mock.SetupClusterServer(mockService)
188149
defer grpcServer.Stop()
189150

190151
testClient := createTestClient(t, listener)
@@ -255,11 +216,12 @@ func TestHandle_WithPagination(t *testing.T) {
255216
}
256217

257218
func TestHandle_GetClustersError(t *testing.T) {
258-
mockService := &mockClustersService{
259-
err: status.Error(codes.Internal, "test"),
260-
}
219+
mockService := mock.NewClustersServiceMock(
220+
[]*storage.Cluster{},
221+
status.Error(codes.Internal, "test"),
222+
)
261223

262-
grpcServer, listener := setupMockServer(mockService)
224+
grpcServer, listener := mock.SetupClusterServer(mockService)
263225
defer grpcServer.Stop()
264226

265227
testClient := createTestClient(t, listener)
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
package mock
2+
3+
import (
4+
"context"
5+
"strings"
6+
"sync"
7+
8+
"github.com/pkg/errors"
9+
v1 "github.com/stackrox/rox/generated/api/v1"
10+
"github.com/stackrox/rox/generated/storage"
11+
"google.golang.org/grpc"
12+
"google.golang.org/grpc/test/bufconn"
13+
)
14+
15+
const bufferSize = 1024 * 1024
16+
17+
// SetupAPIServer creates an in-memory gRPC Central server.
18+
func SetupAPIServer(
19+
deploymentService v1.DeploymentServiceServer,
20+
imageService v1.ImageServiceServer,
21+
nodeService v1.NodeServiceServer,
22+
clusterService v1.ClustersServiceServer,
23+
) (*grpc.Server, *bufconn.Listener) {
24+
buffer := bufferSize
25+
listener := bufconn.Listen(buffer)
26+
27+
grpcServer := grpc.NewServer()
28+
v1.RegisterDeploymentServiceServer(grpcServer, deploymentService)
29+
v1.RegisterImageServiceServer(grpcServer, imageService)
30+
v1.RegisterNodeServiceServer(grpcServer, nodeService)
31+
v1.RegisterClustersServiceServer(grpcServer, clusterService)
32+
33+
go func() {
34+
_ = grpcServer.Serve(listener)
35+
}()
36+
37+
return grpcServer, listener
38+
}
39+
40+
// SetupNodeServer creates an in-memory gRPC server with node services.
41+
func SetupNodeServer(nodeService v1.NodeServiceServer) (*grpc.Server, *bufconn.Listener) {
42+
return SetupAPIServer(
43+
v1.UnimplementedDeploymentServiceServer{},
44+
v1.UnimplementedImageServiceServer{},
45+
nodeService,
46+
v1.UnimplementedClustersServiceServer{},
47+
)
48+
}
49+
50+
// SetupDeploymentServer creates an in-memory gRPC server with deployment services.
51+
func SetupDeploymentServer(mockService v1.DeploymentServiceServer) (*grpc.Server, *bufconn.Listener) {
52+
return SetupAPIServer(
53+
mockService,
54+
v1.UnimplementedImageServiceServer{},
55+
v1.UnimplementedNodeServiceServer{},
56+
v1.UnimplementedClustersServiceServer{},
57+
)
58+
}
59+
60+
// SetupClusterServer creates an in-memory gRPC server with cluster services.
61+
func SetupClusterServer(mockService v1.ClustersServiceServer) (*grpc.Server, *bufconn.Listener) {
62+
return SetupAPIServer(
63+
v1.UnimplementedDeploymentServiceServer{},
64+
v1.UnimplementedImageServiceServer{},
65+
v1.UnimplementedNodeServiceServer{},
66+
mockService,
67+
)
68+
}
69+
70+
// ClustersService implements v1.ClustersServiceServer for testing.
71+
type ClustersService struct {
72+
v1.UnimplementedClustersServiceServer
73+
74+
clusters []*storage.Cluster
75+
err error
76+
77+
lastCallQuery string
78+
}
79+
80+
// NewClustersServiceMock return mock to cluster service.
81+
func NewClustersServiceMock(clusters []*storage.Cluster, err error) *ClustersService {
82+
return &ClustersService{clusters: clusters, err: err}
83+
}
84+
85+
// GetLastCallQuery returns query used for the last call.
86+
func (cs *ClustersService) GetLastCallQuery() string {
87+
return cs.lastCallQuery
88+
}
89+
90+
// GetClusters implements v1.ClustersServiceServer.GetClusters for testing.
91+
func (cs *ClustersService) GetClusters(
92+
_ context.Context,
93+
req *v1.GetClustersRequest,
94+
) (*v1.ClustersList, error) {
95+
cs.lastCallQuery = req.GetQuery()
96+
97+
if cs.err != nil {
98+
return nil, cs.err
99+
}
100+
101+
return &v1.ClustersList{
102+
Clusters: cs.clusters,
103+
}, nil
104+
}
105+
106+
// NodeService implements v1.NodeServiceServer for testing.
107+
type NodeService struct {
108+
v1.UnimplementedNodeServiceServer
109+
110+
nodes []*storage.Node
111+
err error
112+
113+
lastCallQuery string
114+
}
115+
116+
// NewNodeServiceMock return mock to node service.
117+
func NewNodeServiceMock(nodes []*storage.Node, err error) *NodeService {
118+
return &NodeService{nodes: nodes, err: err}
119+
}
120+
121+
// GetLastCallQuery returns query used for the last call.
122+
func (ns *NodeService) GetLastCallQuery() string {
123+
return ns.lastCallQuery
124+
}
125+
126+
// ExportNodes implements v1.NodeServiceServer.ExportNodes for testing.
127+
func (ns *NodeService) ExportNodes(
128+
req *v1.ExportNodeRequest,
129+
stream grpc.ServerStreamingServer[v1.ExportNodeResponse],
130+
) error {
131+
ns.lastCallQuery = req.GetQuery()
132+
133+
if ns.err != nil {
134+
return ns.err
135+
}
136+
137+
// Send all nodes through the stream.
138+
for _, node := range ns.nodes {
139+
resp := &v1.ExportNodeResponse{Node: node}
140+
if err := stream.Send(resp); err != nil {
141+
return errors.Wrap(err, "sending node over stream failed")
142+
}
143+
}
144+
145+
return nil
146+
}
147+
148+
// DeploymentService implements v1.DeploymentServiceServer for testing.
149+
type DeploymentService struct {
150+
v1.UnimplementedDeploymentServiceServer
151+
152+
deployments []*storage.ListDeployment
153+
err error
154+
155+
// Mock call information.
156+
lastCallQuery string
157+
lastCallLimit int32
158+
lastCallOffset int32
159+
}
160+
161+
// NewDeploymentServiceMock returns mock for deployment service.
162+
func NewDeploymentServiceMock(deployments []*storage.ListDeployment, err error) *DeploymentService {
163+
return &DeploymentService{
164+
deployments: deployments,
165+
err: err,
166+
}
167+
}
168+
169+
// GetLastCallQuery returns query used for the last call.
170+
func (ds *DeploymentService) GetLastCallQuery() string {
171+
return ds.lastCallQuery
172+
}
173+
174+
// GetLastCallLimit returns limit used for the last call.
175+
func (ds *DeploymentService) GetLastCallLimit() int32 {
176+
return ds.lastCallLimit
177+
}
178+
179+
// GetLastCallOffset returns offset used for the last call.
180+
func (ds *DeploymentService) GetLastCallOffset() int32 {
181+
return ds.lastCallOffset
182+
}
183+
184+
// ListDeployments implements v1.DeploymentServiceServer.ListDeployments for testing.
185+
func (ds *DeploymentService) ListDeployments(
186+
_ context.Context,
187+
query *v1.RawQuery,
188+
) (*v1.ListDeploymentsResponse, error) {
189+
ds.lastCallQuery = query.GetQuery()
190+
ds.lastCallLimit = query.GetPagination().GetLimit()
191+
ds.lastCallOffset = query.GetPagination().GetOffset()
192+
193+
if ds.err != nil {
194+
return nil, ds.err
195+
}
196+
197+
return &v1.ListDeploymentsResponse{
198+
Deployments: ds.deployments,
199+
}, nil
200+
}
201+
202+
// ImageService implements v1.ImageServiceServer for testing.
203+
type ImageService struct {
204+
v1.UnimplementedImageServiceServer
205+
206+
images map[string][]*storage.ListImage // keyed by deploymentID
207+
err error
208+
209+
// We are requesting images in parallel requests.
210+
lock sync.Mutex
211+
212+
// Mock call information.
213+
lastCallQuery string
214+
lastCallLimit int32
215+
callCount int
216+
}
217+
218+
// NewImageServiceMock returns mock for image service.
219+
func NewImageServiceMock(images map[string][]*storage.ListImage, err error) *ImageService {
220+
return &ImageService{
221+
images: images,
222+
err: err,
223+
}
224+
}
225+
226+
// GetLastCallQuery returns query used for the last call.
227+
func (is *ImageService) GetLastCallQuery() string {
228+
return is.lastCallQuery
229+
}
230+
231+
// GetLastCallLimit returns limit used for the last call.
232+
func (is *ImageService) GetLastCallLimit() int32 {
233+
return is.lastCallLimit
234+
}
235+
236+
// GetCallCount returns count off all calls.
237+
func (is *ImageService) GetCallCount() int {
238+
return is.callCount
239+
}
240+
241+
// ListImages implements v1.ImageServiceServer.ListImages for testing.
242+
func (is *ImageService) ListImages(
243+
_ context.Context,
244+
query *v1.RawQuery,
245+
) (*v1.ListImagesResponse, error) {
246+
is.lock.Lock()
247+
defer is.lock.Unlock()
248+
249+
is.callCount++
250+
is.lastCallQuery = query.GetQuery()
251+
is.lastCallLimit = query.GetPagination().GetLimit()
252+
253+
if is.err != nil {
254+
return nil, is.err
255+
}
256+
257+
// Extract deployment ID from query.
258+
// Query format: CVE:"CVE-2021-44228"+Deployment ID:"dep-1"
259+
deploymentID := extractDeploymentIDFromQuery(query.GetQuery())
260+
261+
return &v1.ListImagesResponse{
262+
Images: is.images[deploymentID],
263+
}, nil
264+
}
265+
266+
// extractDeploymentIDFromQuery extracts deployment ID from the query string.
267+
func extractDeploymentIDFromQuery(query string) string {
268+
const prefix = "Deployment ID:\""
269+
270+
start := strings.Index(query, prefix)
271+
if start == -1 {
272+
return ""
273+
}
274+
275+
start += len(prefix)
276+
277+
end := strings.Index(query[start:], "\"")
278+
if end == -1 {
279+
return ""
280+
}
281+
282+
return query[start : start+end]
283+
}

0 commit comments

Comments
 (0)