Skip to content

Commit 0b6b71c

Browse files
committed
feat(validation): add firewall validation for iptables/ufw rules
- Add networking_validation.go with ValidateFirewallBlocksPort and ValidateDockerFirewallBlocksPort to verify servers on 0.0.0.0 are not accessible from outside - Extract instance validation functions to instance_validation.go - Add RunFirewallValidation to validation suite - Integrate ValidateFirewallBlocksPort into RunInstanceLifecycleValidation
1 parent 4570d27 commit 0b6b71c

File tree

4 files changed

+588
-175
lines changed

4 files changed

+588
-175
lines changed

internal/validation/suite.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) {
119119
require.NoError(t, err, "ValidateInstanceImage should pass")
120120
})
121121

122+
t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
123+
err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
124+
require.NoError(t, err, "ValidateFirewallBlocksPort should pass - non-allowed port should be blocked")
125+
})
126+
127+
t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
128+
err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), v1.DefaultFirewallTestPort)
129+
require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked by iptables")
130+
})
131+
122132
if capabilities.IsCapable(v1.CapabilityStopStartInstance) && instance.Stoppable {
123133
t.Run("ValidateStopStartInstance", func(t *testing.T) {
124134
err := v1.ValidateStopStartInstance(ctx, client, instance)
@@ -235,6 +245,101 @@ func RunNetworkValidation(t *testing.T, config ProviderConfig, opts NetworkValid
235245
})
236246
}
237247

248+
type FirewallValidationOpts struct {
249+
// TestPort is the port to test firewall blocking on (should NOT be in allowed ingress)
250+
TestPort int
251+
// TestDockerFirewall enables docker firewall validation (requires Docker on instance)
252+
TestDockerFirewall bool
253+
}
254+
255+
func RunFirewallValidation(t *testing.T, config ProviderConfig, opts FirewallValidationOpts) {
256+
if testing.Short() {
257+
t.Skip("Skipping validation tests in short mode")
258+
}
259+
260+
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
261+
defer cancel()
262+
263+
client, err := config.Credential.MakeClient(ctx, config.Location)
264+
if err != nil {
265+
t.Fatalf("Failed to create client for %s: %v", config.Credential.GetCloudProviderID(), err)
266+
}
267+
268+
types, err := client.GetInstanceTypes(ctx, v1.GetInstanceTypeArgs{
269+
ArchitectureFilter: &v1.ArchitectureFilter{
270+
IncludeArchitectures: []v1.Architecture{v1.ArchitectureX86_64},
271+
},
272+
})
273+
require.NoError(t, err)
274+
require.NotEmpty(t, types, "Should have instance types")
275+
276+
// Find an available instance type
277+
attrs := v1.CreateInstanceAttrs{}
278+
selectedType := v1.InstanceType{}
279+
for _, typ := range types {
280+
if typ.IsAvailable {
281+
attrs.InstanceType = typ.Type
282+
attrs.Location = typ.Location
283+
attrs.PublicKey = ssh.GetTestPublicKey()
284+
selectedType = typ
285+
break
286+
}
287+
}
288+
require.NotEmpty(t, attrs.InstanceType, "Should find available instance type")
289+
290+
// Create instance for firewall testing
291+
instance, err := v1.ValidateCreateInstance(ctx, client, attrs, selectedType)
292+
require.NoError(t, err, "ValidateCreateInstance should pass")
293+
require.NotNil(t, instance)
294+
295+
defer func() {
296+
if instance != nil {
297+
_ = client.TerminateInstance(ctx, instance.CloudID)
298+
}
299+
}()
300+
301+
// Wait for instance to be running and SSH accessible
302+
t.Run("ValidateSSHAccessible", func(t *testing.T) {
303+
err := v1.ValidateInstanceSSHAccessible(ctx, client, instance, ssh.GetTestPrivateKey())
304+
require.NoError(t, err, "ValidateSSHAccessible should pass")
305+
})
306+
307+
// Refresh instance data
308+
instance, err = client.GetInstance(ctx, instance.CloudID)
309+
require.NoError(t, err)
310+
311+
testPort := opts.TestPort
312+
if testPort == 0 {
313+
testPort = v1.DefaultFirewallTestPort
314+
}
315+
316+
// Test that regular server on 0.0.0.0 is blocked
317+
t.Run("ValidateFirewallBlocksPort", func(t *testing.T) {
318+
err := v1.ValidateFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
319+
require.NoError(t, err, "ValidateFirewallBlocksPort should pass - port should be blocked")
320+
})
321+
322+
// Test that Docker container on 0.0.0.0 is blocked (if enabled)
323+
if opts.TestDockerFirewall {
324+
t.Run("ValidateDockerFirewallBlocksPort", func(t *testing.T) {
325+
err := v1.ValidateDockerFirewallBlocksPort(ctx, client, instance, ssh.GetTestPrivateKey(), testPort)
326+
require.NoError(t, err, "ValidateDockerFirewallBlocksPort should pass - docker port should be blocked")
327+
})
328+
}
329+
330+
// Test that SSH port is accessible (sanity check)
331+
t.Run("ValidateSSHPortAccessible", func(t *testing.T) {
332+
err := v1.ValidateFirewallAllowsPort(ctx, client, instance, ssh.GetTestPrivateKey(), instance.SSHPort)
333+
require.NoError(t, err, "ValidateFirewallAllowsPort should pass for SSH port")
334+
})
335+
336+
// Terminate instance
337+
t.Run("ValidateTerminateInstance", func(t *testing.T) {
338+
err := v1.ValidateTerminateInstance(ctx, client, instance)
339+
require.NoError(t, err, "ValidateTerminateInstance should pass")
340+
})
341+
}
342+
238343
type KubernetesValidationOpts struct {
239344
Name string
240345
RefID string

v1/instance.go

Lines changed: 0 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@ package v1
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"time"
87

98
"github.com/alecthomas/units"
10-
"github.com/brevdev/cloud/internal/collections"
11-
"github.com/brevdev/cloud/internal/ssh"
12-
"github.com/google/uuid"
139
)
1410

1511
type CloudInstanceReader interface {
@@ -28,112 +24,11 @@ type CloudCreateTerminateInstance interface {
2824
CloudInstanceReader
2925
}
3026

31-
func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs, selectedType InstanceType) (*Instance, error) { //nolint:gocyclo // ok
32-
t0 := time.Now().Add(-time.Minute)
33-
attrs.RefID = uuid.New().String()
34-
name, err := makeDebuggableName(attrs.Name)
35-
if err != nil {
36-
return nil, err
37-
}
38-
attrs.Name = name
39-
i, err := client.CreateInstance(ctx, attrs)
40-
if err != nil {
41-
return nil, err
42-
}
43-
var validationErr error
44-
t1 := time.Now().Add(1 * time.Minute)
45-
diff := t1.Sub(t0)
46-
if diff > 3*time.Minute {
47-
validationErr = errors.Join(validationErr, fmt.Errorf("create instance took too long: %s", diff))
48-
}
49-
if i.CreatedAt.Before(t0) {
50-
validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is before t0: %s", i.CreatedAt))
51-
}
52-
if i.CreatedAt.After(t1) {
53-
validationErr = errors.Join(validationErr, fmt.Errorf("createdAt is after t1: %s", i.CreatedAt))
54-
}
55-
if i.Name != name {
56-
fmt.Printf("name mismatch: %s != %s, input name does not mean return name will be stable\n", i.Name, name)
57-
}
58-
if i.RefID != attrs.RefID {
59-
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", i.RefID, attrs.RefID))
60-
}
61-
if attrs.Location != "" && attrs.Location != i.Location {
62-
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", attrs.Location, i.Location))
63-
}
64-
if attrs.SubLocation != "" && attrs.SubLocation != i.SubLocation {
65-
validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", attrs.SubLocation, i.SubLocation))
66-
}
67-
if attrs.InstanceType != "" && attrs.InstanceType != i.InstanceType {
68-
validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", attrs.InstanceType, i.InstanceType))
69-
}
70-
if selectedType.ID != "" && selectedType.ID != i.InstanceTypeID {
71-
validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", selectedType.ID, i.InstanceTypeID))
72-
}
73-
74-
return i, validationErr
75-
}
76-
77-
func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminateInstance, i *Instance) error {
78-
ins, err := client.ListInstances(ctx, ListInstancesArgs{
79-
Locations: []string{i.Location},
80-
})
81-
if err != nil {
82-
return err
83-
}
84-
var validationErr error
85-
if len(ins) == 0 {
86-
validationErr = errors.Join(validationErr, fmt.Errorf("no instances found"))
87-
}
88-
foundInstance := collections.Find(ins, func(inst Instance) bool {
89-
return inst.CloudID == i.CloudID
90-
})
91-
if foundInstance == nil {
92-
validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID))
93-
return validationErr
94-
}
95-
if foundInstance.Location != i.Location { //nolint:gocritic // fine
96-
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", foundInstance.Location, i.Location))
97-
} else if foundInstance.RefID == "" {
98-
validationErr = errors.Join(validationErr, fmt.Errorf("refID is empty"))
99-
} else if foundInstance.RefID != i.RefID {
100-
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", foundInstance.RefID, i.RefID))
101-
} else if foundInstance.CloudCredRefID == "" {
102-
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID is empty"))
103-
} else if foundInstance.CloudCredRefID != i.CloudCredRefID {
104-
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", foundInstance.CloudCredRefID, i.CloudCredRefID))
105-
}
106-
return validationErr
107-
}
108-
109-
func ValidateTerminateInstance(ctx context.Context, client CloudCreateTerminateInstance, instance *Instance) error {
110-
err := client.TerminateInstance(ctx, instance.CloudID)
111-
if err != nil {
112-
return err
113-
}
114-
// TODO wait for instance to go into terminating state
115-
return nil
116-
}
117-
11827
type CloudStopStartInstance interface {
11928
StopInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
12029
StartInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
12130
}
12231

123-
func ValidateStopStartInstance(ctx context.Context, client CloudStopStartInstance, instance *Instance) error {
124-
err := client.StopInstance(ctx, instance.CloudID)
125-
if err != nil {
126-
return err
127-
}
128-
// TODO wait for stopped
129-
err = client.StartInstance(ctx, instance.CloudID)
130-
if err != nil {
131-
return err
132-
}
133-
// TODO wait for running
134-
return nil
135-
}
136-
13732
type CloudRebootInstance interface {
13833
RebootInstance(ctx context.Context, instanceID CloudProviderInstanceID) error
13934
}
@@ -152,40 +47,6 @@ type UpdateHandler interface {
15247
MergeInstanceTypeForUpdate(currIt InstanceType, newIt InstanceType) InstanceType
15348
}
15449

155-
func ValidateMergeInstanceForUpdate(client UpdateHandler, currInst Instance, newInst Instance) error {
156-
mergedInst := client.MergeInstanceForUpdate(currInst, newInst)
157-
158-
var validationErr error
159-
if currInst.Name != mergedInst.Name {
160-
validationErr = errors.Join(validationErr, fmt.Errorf("name mismatch: %s != %s", currInst.Name, mergedInst.Name))
161-
}
162-
if currInst.RefID != mergedInst.RefID {
163-
validationErr = errors.Join(validationErr, fmt.Errorf("refID mismatch: %s != %s", currInst.RefID, mergedInst.RefID))
164-
}
165-
if currInst.Location != mergedInst.Location {
166-
validationErr = errors.Join(validationErr, fmt.Errorf("location mismatch: %s != %s", currInst.Location, newInst.Location))
167-
}
168-
if currInst.SubLocation != mergedInst.SubLocation {
169-
validationErr = errors.Join(validationErr, fmt.Errorf("subLocation mismatch: %s != %s", currInst.SubLocation, mergedInst.SubLocation))
170-
}
171-
if currInst.InstanceType != "" && currInst.InstanceType != mergedInst.InstanceType {
172-
validationErr = errors.Join(validationErr, fmt.Errorf("instanceType mismatch: %s != %s", currInst.InstanceType, mergedInst.InstanceType))
173-
}
174-
if currInst.InstanceTypeID != "" && currInst.InstanceTypeID != mergedInst.InstanceTypeID {
175-
validationErr = errors.Join(validationErr, fmt.Errorf("instanceTypeID mismatch: %s != %s", currInst.InstanceTypeID, mergedInst.InstanceTypeID))
176-
}
177-
if currInst.CloudCredRefID != mergedInst.CloudCredRefID {
178-
validationErr = errors.Join(validationErr, fmt.Errorf("cloudCredRefID mismatch: %s != %s", currInst.CloudCredRefID, mergedInst.CloudCredRefID))
179-
}
180-
if currInst.VolumeType != "" && currInst.VolumeType != mergedInst.VolumeType {
181-
validationErr = errors.Join(validationErr, fmt.Errorf("volumeType mismatch: %s != %s", currInst.VolumeType, mergedInst.VolumeType))
182-
}
183-
if currInst.Spot != mergedInst.Spot {
184-
validationErr = errors.Join(validationErr, fmt.Errorf("spot mismatch: %v != %v", currInst.Spot, mergedInst.Spot))
185-
}
186-
return validationErr
187-
}
188-
18950
type Instance struct {
19051
Name string
19152
RefID string
@@ -308,39 +169,3 @@ func makeDebuggableName(name string) (string, error) {
308169
}
309170

310171
const RunningSSHTimeout = 10 * time.Minute
311-
312-
func ValidateInstanceSSHAccessible(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
313-
var err error
314-
instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
315-
if err != nil {
316-
return err
317-
}
318-
sshUser := instance.SSHUser
319-
sshPort := instance.SSHPort
320-
publicIP := instance.PublicIP
321-
// Validate that we have the required SSH connection details
322-
if sshUser == "" {
323-
return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID)
324-
}
325-
if sshPort == 0 {
326-
return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID)
327-
}
328-
if publicIP == "" {
329-
return fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
330-
}
331-
332-
err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
333-
User: sshUser,
334-
HostPort: fmt.Sprintf("%s:%d", publicIP, sshPort),
335-
PrivKey: privateKey,
336-
}, ssh.WaitForSSHOptions{
337-
Timeout: RunningSSHTimeout,
338-
})
339-
if err != nil {
340-
return err
341-
}
342-
343-
fmt.Printf("SSH connection validated successfully for %s@%s:%d\n", sshUser, publicIP, sshPort)
344-
345-
return nil
346-
}

0 commit comments

Comments
 (0)