diff --git a/client/cli/cli.go b/client/cli/cli.go index 1dd9761553..67e20c6fea 100644 --- a/client/cli/cli.go +++ b/client/cli/cli.go @@ -53,7 +53,6 @@ func init() { rootCmd.TraverseChildren = true rootCmd.Flags().String(RCFlagName, "", "path to rc script file") rootCmd.PersistentFlags().Bool(disableWGFlag, false, "connect to multiplayer directly even if the operator config includes a WireGuard wrapper") - rootCmd.PersistentFlags().Bool(requireWGFlag, false, "require the operator config's WireGuard wrapper for multiplayer connections") // Create the console client, without any RPC or commands bound to it yet. // This created before anything so that multiple commands can make use of diff --git a/client/cli/connect_spinner.go b/client/cli/connect_spinner.go new file mode 100644 index 0000000000..83e4cc51e8 --- /dev/null +++ b/client/cli/connect_spinner.go @@ -0,0 +1,177 @@ +package cli + +import ( + "fmt" + "io" + "strings" + "time" + "unicode/utf8" + + bspinner "charm.land/bubbles/v2/spinner" + "github.com/bishopfox/sliver/client/transport" + "github.com/bishopfox/sliver/protobuf/rpcpb" + "google.golang.org/grpc" +) + +const ( + minConnectSpinnerDuration = 900 * time.Millisecond + minConnectStatusDuration = 350 * time.Millisecond +) + +type connectResult struct { + rpc rpcpb.SliverRPCClient + conn *grpc.ClientConn + err error +} + +func connectWithSpinner(out io.Writer, target string, connect func(transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error)) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + statusCh := make(chan string, 8) + resultCh := make(chan connectResult, 1) + + go func() { + rpc, conn, err := connect(func(status string) { + sendStatus(statusCh, status) + }) + resultCh <- connectResult{rpc: rpc, conn: conn, err: err} + }() + + currentStatus := "" + currentStatusSince := time.Time{} + queuedStatus := "" + lastWidth := 0 + spinner := bspinner.New(bspinner.WithSpinner(bspinner.Line)) + ticker := 
time.NewTicker(spinner.Spinner.FPS) + defer ticker.Stop() + startedAt := time.Now() + var pendingResult *connectResult + + setStatus := func(status string) { + status = strings.TrimSpace(status) + if status == "" || status == currentStatus { + return + } + currentStatus = status + currentStatusSince = time.Now() + } + + stageStatus := func(status string) { + status = strings.TrimSpace(status) + if status == "" { + return + } + if currentStatus == "" || time.Since(currentStatusSince) >= minConnectStatusDuration { + queuedStatus = "" + setStatus(status) + return + } + queuedStatus = status + } + + flushQueuedStatus := func() { + if queuedStatus == "" || time.Since(currentStatusSince) < minConnectStatusDuration { + return + } + nextStatus := queuedStatus + queuedStatus = "" + setStatus(nextStatus) + } + + canReturnPendingResult := func() bool { + if pendingResult == nil { + return false + } + if pendingResult.err != nil { + return true + } + if time.Since(startedAt) < minConnectSpinnerDuration { + return false + } + if queuedStatus != "" { + return false + } + if !currentStatusSince.IsZero() && time.Since(currentStatusSince) < minConnectStatusDuration { + return false + } + return true + } + + render := func() { + line := fmt.Sprintf("%s %s", spinner.View(), formatConnectSpinnerMessage(target, currentStatus)) + lastWidth = writeSpinnerLine(out, line, lastWidth) + } + + render() + for { + select { + case status := <-statusCh: + stageStatus(status) + spinner, _ = spinner.Update(spinner.Tick()) + render() + + case result := <-resultCh: + pendingResult = &result + + case <-ticker.C: + spinner, _ = spinner.Update(spinner.Tick()) + flushQueuedStatus() + render() + if canReturnPendingResult() { + result := *pendingResult + pendingResult = nil + clearSpinnerLine(out, lastWidth) + return result.rpc, result.conn, result.err + } + } + + if pendingResult != nil && pendingResult.err != nil { + result := *pendingResult + pendingResult = nil + clearSpinnerLine(out, lastWidth) + 
return result.rpc, result.conn, result.err + } + } +} + +func sendStatus(statusCh chan string, status string) { + status = strings.TrimSpace(status) + if status == "" { + return + } + statusCh <- status +} + +func formatConnectSpinnerMessage(target string, status string) string { + target = strings.TrimSpace(target) + status = strings.TrimSpace(status) + + if target == "" { + if status == "" { + return "Connecting ..." + } + return fmt.Sprintf("Connecting (%s) ...", status) + } + if status == "" { + return fmt.Sprintf("Connecting to %s ...", target) + } + return fmt.Sprintf("Connecting to %s (%s) ...", target, status) +} + +func writeSpinnerLine(out io.Writer, line string, lastWidth int) int { + width := utf8.RuneCountInString(line) + padding := "" + if lastWidth > width { + padding = strings.Repeat(" ", lastWidth-width) + } + fmt.Fprintf(out, "\r%s%s", line, padding) + if width > lastWidth { + return width + } + return lastWidth +} + +func clearSpinnerLine(out io.Writer, lastWidth int) { + if lastWidth <= 0 { + return + } + fmt.Fprintf(out, "\r%s\r", strings.Repeat(" ", lastWidth)) +} diff --git a/client/cli/connect_spinner_test.go b/client/cli/connect_spinner_test.go new file mode 100644 index 0000000000..381434a248 --- /dev/null +++ b/client/cli/connect_spinner_test.go @@ -0,0 +1,132 @@ +package cli + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/bishopfox/sliver/client/transport" + "github.com/bishopfox/sliver/protobuf/rpcpb" + "google.golang.org/grpc" +) + +func TestFormatConnectSpinnerMessage(t *testing.T) { + tests := []struct { + name string + target string + status string + want string + }{ + { + name: "target only", + target: "127.0.0.1:31337", + want: "Connecting to 127.0.0.1:31337 ...", + }, + { + name: "target with status", + target: "127.0.0.1:31337", + status: "wireguard", + want: "Connecting to 127.0.0.1:31337 (wireguard) ...", + }, + { + name: "status without target", + status: "grpc/mtls", + want: "Connecting (grpc/mtls) 
...", + }, + { + name: "trim whitespace", + target: " 127.0.0.1:31337 ", + status: " grpc/mtls over wireguard ", + want: "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := formatConnectSpinnerMessage(test.target, test.status); got != test.want { + t.Fatalf("expected %q, got %q", test.want, got) + } + }) + } +} + +func TestConnectWithSpinnerOutput(t *testing.T) { + var out bytes.Buffer + spinnerDelay := 2 * (100 * time.Millisecond) + + _, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + statusFn("wireguard") + time.Sleep(spinnerDelay) + statusFn("grpc/mtls over wireguard") + time.Sleep(spinnerDelay) + return nil, nil, nil + }) + if err != nil { + t.Fatalf("connectWithSpinner returned error: %v", err) + } + + got := out.String() + if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (wireguard) ...") { + t.Fatalf("expected wireguard status in output, got %q", got) + } + if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...") { + t.Fatalf("expected grpc over wireguard status in output, got %q", got) + } + frameCount := 0 + for _, frame := range []string{"|", "/", "-", "\\"} { + if strings.Contains(got, frame+" Connecting to 127.0.0.1:31337") { + frameCount++ + } + } + if frameCount < 2 { + t.Fatalf("expected multiple spinner frames in output, got %q", got) + } +} + +func TestConnectWithSpinnerFastSuccessStillShowsMultipleFrames(t *testing.T) { + var out bytes.Buffer + + _, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + statusFn("grpc/mtls") + return nil, nil, nil + }) + if err != nil { + t.Fatalf("connectWithSpinner returned error: %v", err) + } + + got := out.String() + frameCount := 0 + for _, frame := range []string{"|", "/", 
"-", "\\"} { + if strings.Contains(got, frame+" Connecting to 127.0.0.1:31337 (grpc/mtls) ...") { + frameCount++ + } + } + if frameCount < 2 { + t.Fatalf("expected multiple spinner frames for fast success, got %q", got) + } +} + +func TestConnectWithSpinnerFastSuccessShowsEachStatus(t *testing.T) { + var out bytes.Buffer + + _, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + statusFn("wireguard") + statusFn("grpc/mtls over wireguard") + return nil, nil, nil + }) + if err != nil { + t.Fatalf("connectWithSpinner returned error: %v", err) + } + + got := out.String() + if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (wireguard) ...") { + t.Fatalf("expected fast success output to include wireguard status, got %q", got) + } + if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...") { + t.Fatalf("expected fast success output to include grpc over wireguard status, got %q", got) + } + if strings.Contains(got, "Connected to 127.0.0.1:31337") { + t.Fatalf("did not expect past-tense success output, got %q", got) + } +} diff --git a/client/cli/console.go b/client/cli/console.go index 614783944b..50ef9bd2c9 100644 --- a/client/cli/console.go +++ b/client/cli/console.go @@ -20,6 +20,7 @@ package cli import ( "fmt" + "os" "github.com/bishopfox/sliver/client/assets" "github.com/bishopfox/sliver/client/command" @@ -64,17 +65,20 @@ func consoleRunnerCmd(con *console.SliverClient, run bool) (pre, post func(cmd * return nil } - // Don't clobber output when simply running an implant command from system shell. - if run { - fmt.Printf("Connecting to %s:%d ...\n", config.LHost, config.LPort) - } - + target := fmt.Sprintf("%s:%d", config.LHost, config.LPort) var rpc rpcpb.SliverRPCClient var ln *grpc.ClientConn - rpc, ln, err = transport.MTLSConnect(config) + // Don't clobber output when simply running an implant command from system shell. 
+ if run { + rpc, ln, err = connectWithSpinner(os.Stdout, target, func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + return transport.MTLSConnectWithStatus(config, statusFn) + }) + } else { + rpc, ln, err = transport.MTLSConnect(config) + } if err != nil { - fmt.Printf("Connection to server failed %s", err) + fmt.Printf("Connection to server failed %s\n", err) return nil } return console.StartClient(con, rpc, ln, &console.ConnectionDetails{ConfigKey: configKey, Config: config}, command.ServerCommands(con, nil), command.SliverCommands(con), run, rcScript) diff --git a/client/cli/transport_mode.go b/client/cli/transport_mode.go index 7b917e5f30..bef1649528 100644 --- a/client/cli/transport_mode.go +++ b/client/cli/transport_mode.go @@ -1,14 +1,11 @@ package cli import ( - "fmt" - "github.com/bishopfox/sliver/client/transport" "github.com/spf13/cobra" ) const ( - requireWGFlag = "require-wg" disableWGFlag = "disable-wg" ) @@ -18,23 +15,13 @@ func applyMultiplayerConnectMode(cmd *cobra.Command) error { return nil } - requireWG, err := cmd.Flags().GetBool(requireWGFlag) - if err != nil { - return err - } disableWG, err := cmd.Flags().GetBool(disableWGFlag) if err != nil { return err } - if requireWG && disableWG { - return fmt.Errorf("--%s and --%s cannot be used together", requireWGFlag, disableWGFlag) - } mode := transport.MultiplayerConnectAuto - switch { - case requireWG: - mode = transport.MultiplayerConnectRequireWG - case disableWG: + if disableWG { mode = transport.MultiplayerConnectDisableWG } transport.SetMultiplayerConnectMode(mode) diff --git a/client/console/connection.go b/client/console/connection.go index 7b5b93aa4b..b17d9981d4 100644 --- a/client/console/connection.go +++ b/client/console/connection.go @@ -46,6 +46,9 @@ func (con *SliverClient) SetConnection(rpc rpcpb.SliverRPCClient, grpcConn *grpc con.Rpc = rpc con.grpcConn = grpcConn con.connDetails = details + con.backgroundRPC = nil + con.backgroundConn = 
nil + con.backgroundDedicated = false if con.Rpc == nil || con.grpcConn == nil { return nil @@ -58,31 +61,49 @@ func (con *SliverClient) SetConnection(rpc rpcpb.SliverRPCClient, grpcConn *grpc con.BeaconTaskCallbacksMutex.Unlock() con.ActiveTarget.Set(nil, nil) + if details != nil && details.Config != nil && details.Config.WG != nil { + con.backgroundDedicated = true + commandRPC, commandConn, err := transport.MTLSConnect(details.Config) + if err != nil { + log.Printf("Dedicated command WireGuard connection unavailable, async console streams disabled to preserve command reliability: %v", err) + } else { + con.backgroundRPC = con.Rpc + con.backgroundConn = con.grpcConn + con.Rpc = commandRPC + con.grpcConn = commandConn + } + } + ctx, cancel := context.WithCancel(context.Background()) con.connCancel = cancel wg := &sync.WaitGroup{} con.connWg = wg - wg.Add(1) - go func(wg *sync.WaitGroup) { - defer wg.Done() - con.startEventLoop(ctx) - }(wg) - - wg.Add(1) - go func(wg *sync.WaitGroup) { - defer wg.Done() - if err := core.TunnelLoop(ctx, con.Rpc); err != nil && !errors.Is(err, context.Canceled) { - log.Printf("TunnelLoop error: %v", err) - } - }(wg) + backgroundRPC := con.backgroundRPCClientLocked() + if backgroundRPC != nil { + wg.Add(1) + go func(wg *sync.WaitGroup, rpc rpcpb.SliverRPCClient) { + defer wg.Done() + con.startEventLoop(ctx, rpc) + }(wg, backgroundRPC) + + wg.Add(1) + go func(wg *sync.WaitGroup, rpc rpcpb.SliverRPCClient) { + defer wg.Done() + if err := core.TunnelLoop(ctx, rpc); err != nil && !errors.Is(err, context.Canceled) { + log.Printf("TunnelLoop error: %v", err) + } + }(wg, backgroundRPC) + } - wg.Add(1) - go func(wg *sync.WaitGroup, conn *grpc.ClientConn) { - defer wg.Done() - con.monitorConnectionLost(ctx, conn) - }(wg, con.grpcConn) + if !con.backgroundDedicated || con.backgroundConn == nil { + wg.Add(1) + go func(wg *sync.WaitGroup, conn *grpc.ClientConn) { + defer wg.Done() + con.monitorConnectionLost(ctx, conn) + }(wg, con.grpcConn) + } 
con.refreshRemoteLogStreamsLocked() @@ -127,13 +148,59 @@ func (con *SliverClient) detachConnectionLocked() { _ = transport.CloseGRPCConnection(con.grpcConn) con.grpcConn = nil } + if con.backgroundConn != nil { + _ = transport.CloseGRPCConnection(con.backgroundConn) + con.backgroundConn = nil + } con.Rpc = nil + con.backgroundRPC = nil + con.backgroundDedicated = false con.connDetails = nil // Tear down any singleton network tooling that was tied to the previous server. core.ResetClientState() } +func (con *SliverClient) backgroundRPCClientLocked() rpcpb.SliverRPCClient { + if con.backgroundRPC != nil { + return con.backgroundRPC + } + if con.backgroundDedicated { + return nil + } + return con.Rpc +} + +func (con *SliverClient) refreshDedicatedCommandConnectionHook(args []string) ([]string, error) { + if err := con.refreshDedicatedCommandConnection(); err != nil { + return args, err + } + return args, nil +} + +func (con *SliverClient) refreshDedicatedCommandConnection() error { + con.connMu.Lock() + defer con.connMu.Unlock() + + if !con.backgroundDedicated || con.backgroundConn == nil || con.connDetails == nil || con.connDetails.Config == nil { + return nil + } + + rpc, conn, err := transport.MTLSConnect(con.connDetails.Config) + if err != nil { + return fmt.Errorf("refresh command connection: %w", err) + } + + oldConn := con.grpcConn + con.Rpc = rpc + con.grpcConn = conn + + if oldConn != nil { + _ = transport.CloseGRPCConnection(oldConn) + } + return nil +} + func (con *SliverClient) monitorConnectionLost(ctx context.Context, conn *grpc.ClientConn) { if conn == nil { return diff --git a/client/console/console.go b/client/console/console.go index fd088d1309..280f29fbe4 100644 --- a/client/console/console.go +++ b/client/console/console.go @@ -95,6 +95,9 @@ type SliverClient struct { connMu sync.Mutex grpcConn *grpc.ClientConn + backgroundRPC rpcpb.SliverRPCClient + backgroundConn *grpc.ClientConn + backgroundDedicated bool connDetails *ConnectionDetails 
connCancel context.CancelFunc connWg *sync.WaitGroup @@ -326,6 +329,7 @@ func (con *SliverClient) applyConnectionHooksOnce() { con.App.PreReadlineHooks = append(con.App.PreReadlineHooks, con.syncOutputHook) con.App.PostCmdRunHooks = append(con.App.PostCmdRunHooks, con.syncOutputHook) + con.App.PreCmdRunLineHooks = append(con.App.PreCmdRunLineHooks, con.refreshDedicatedCommandConnectionHook) con.App.PreCmdRunLineHooks = append(con.App.PreCmdRunLineHooks, con.allowServerRootCommands) if shell := con.App.Shell(); shell != nil && shell.Completer != nil { baseCompleter := shell.Completer @@ -336,8 +340,12 @@ func (con *SliverClient) applyConnectionHooksOnce() { } } -func (con *SliverClient) startEventLoop(ctx context.Context) { - eventStream, err := con.Rpc.Events(ctx, &commonpb.Empty{}) +func (con *SliverClient) startEventLoop(ctx context.Context, rpc rpcpb.SliverRPCClient) { + if rpc == nil { + return + } + + eventStream, err := rpc.Events(ctx, &commonpb.Empty{}) if err != nil { fmt.Printf("%s%s\n", Warn, err) return diff --git a/client/console/log.go b/client/console/log.go index 76ca719b7a..6873d84001 100644 --- a/client/console/log.go +++ b/client/console/log.go @@ -107,7 +107,8 @@ func (con *SliverClient) refreshRemoteLogStreamsLocked() { if con.jsonRemoteWriter == nil && con.asciicastRemoteWriter == nil { return } - if con.Rpc == nil { + rpc := con.backgroundRPCClientLocked() + if rpc == nil { con.setRemoteLogStreamsLocked(nil, nil) return } @@ -116,7 +117,7 @@ func (con *SliverClient) refreshRemoteLogStreamsLocked() { var asciicastStream *ConsoleClientLogger if con.jsonRemoteWriter != nil { - s, err := con.ClientLogStream("json") + s, err := con.clientLogStream(rpc, "json") if err != nil { log.Printf("Could not get client json log stream: %s", err) } else { @@ -124,7 +125,7 @@ func (con *SliverClient) refreshRemoteLogStreamsLocked() { } } if con.asciicastRemoteWriter != nil { - s, err := con.ClientLogStream("asciicast") + s, err := con.clientLogStream(rpc, 
"asciicast") if err != nil { log.Printf("Could not get client asciicast log stream: %s", err) } else { @@ -166,7 +167,17 @@ func (con *SliverClient) setRemoteLogStreamsLocked(jsonStream, asciicastStream * // ClientLogStream requires a log stream name, used to save the logs // going through this stream in a specific log subdirectory/file. func (con *SliverClient) ClientLogStream(name string) (*ConsoleClientLogger, error) { - stream, err := con.Rpc.ClientLog(context.Background()) + con.connMu.Lock() + rpc := con.backgroundRPCClientLocked() + con.connMu.Unlock() + if rpc == nil { + return nil, fmt.Errorf("no RPC connection available for client log stream %q", name) + } + return con.clientLogStream(rpc, name) +} + +func (con *SliverClient) clientLogStream(rpc rpcpb.SliverRPCClient, name string) (*ConsoleClientLogger, error) { + stream, err := rpc.ClientLog(context.Background()) if err != nil { return nil, err } diff --git a/client/transport/connect_mode.go b/client/transport/connect_mode.go index da8a87a2ca..2b94ea071f 100644 --- a/client/transport/connect_mode.go +++ b/client/transport/connect_mode.go @@ -13,7 +13,6 @@ type MultiplayerConnectMode int const ( MultiplayerConnectAuto MultiplayerConnectMode = iota MultiplayerConnectDisableWG - MultiplayerConnectRequireWG ) type connectionCloser interface { @@ -65,9 +64,9 @@ func CloseGRPCConnection(conn *grpc.ClientConn) error { } var errs []error + errs = append(errs, conn.Close()) if closer := unregisterConnCloser(conn); closer != nil { errs = append(errs, closer.Close()) } - errs = append(errs, conn.Close()) return errors.Join(errs...) 
} diff --git a/client/transport/mtls.go b/client/transport/mtls.go index ca9174c791..2e43c02454 100644 --- a/client/transport/mtls.go +++ b/client/transport/mtls.go @@ -58,6 +58,15 @@ const ( multiplayerDialWireGuard ) +const ( + connectStatusGRPCMTLS = "grpc/mtls" + connectStatusWireGuard = "wireguard" + connectStatusGRPCMTLSOverWireGuard = "grpc/mtls over wireguard" +) + +// ConnectStatusFn receives human-readable connection phase updates. +type ConnectStatusFn func(string) + // Return value is mapped to request headers. func (t TokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) { return map[string]string{ @@ -71,6 +80,12 @@ func (TokenAuth) RequireTransportSecurity() bool { // MTLSConnect - Connect to the sliver server func MTLSConnect(config *assets.ClientConfig) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + return MTLSConnectWithStatus(config, nil) +} + +// MTLSConnectWithStatus connects to the sliver server and optionally reports +// transport phase updates, such as WireGuard setup and gRPC dialing. 
+func MTLSConnectWithStatus(config *assets.ClientConfig, statusFn ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { strategy, err := selectMultiplayerDialStrategy(config) if err != nil { return nil, nil, err @@ -78,9 +93,9 @@ func MTLSConnect(config *assets.ClientConfig) (rpcpb.SliverRPCClient, *grpc.Clie switch strategy { case multiplayerDialWireGuard: - return wireGuardMTLSConnect(config) + return wireGuardMTLSConnect(config, statusFn) default: - return directMTLSConnect(config) + return directMTLSConnect(config, statusFn) } } @@ -92,11 +107,6 @@ func selectMultiplayerDialStrategy(config *assets.ClientConfig) (multiplayerDial switch getMultiplayerConnectMode() { case MultiplayerConnectDisableWG: return multiplayerDialDirect, nil - case MultiplayerConnectRequireWG: - if err := validateWireGuardConfig(config); err != nil { - return multiplayerDialDirect, err - } - return multiplayerDialWireGuard, nil default: if config.WG == nil { return multiplayerDialDirect, nil @@ -108,7 +118,8 @@ func selectMultiplayerDialStrategy(config *assets.ClientConfig) (multiplayerDial } } -func directMTLSConnect(config *assets.ClientConfig) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { +func directMTLSConnect(config *assets.ClientConfig, statusFn ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + notifyConnectStatus(statusFn, connectStatusGRPCMTLS) options, err := newMTLSDialOptions(config) if err != nil { return nil, nil, err @@ -116,6 +127,12 @@ func directMTLSConnect(config *assets.ClientConfig) (rpcpb.SliverRPCClient, *grp return dialRPCClient(fmt.Sprintf("%s:%d", config.LHost, config.LPort), options, nil) } +func notifyConnectStatus(statusFn ConnectStatusFn, status string) { + if statusFn != nil { + statusFn(status) + } +} + func newMTLSDialOptions(config *assets.ClientConfig) ([]grpc.DialOption, error) { tlsConfig, err := getTLSConfig(config.CACertificate, config.Certificate, config.PrivateKey) if err != nil { diff --git 
a/client/transport/mtls_test.go b/client/transport/mtls_test.go index c5b0c46d1b..d5bea04968 100644 --- a/client/transport/mtls_test.go +++ b/client/transport/mtls_test.go @@ -5,6 +5,9 @@ import ( "testing" "github.com/bishopfox/sliver/client/assets" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" ) func TestSelectMultiplayerDialStrategyLegacyConfigUsesDirectMTLS(t *testing.T) { @@ -19,15 +22,6 @@ func TestSelectMultiplayerDialStrategyLegacyConfigUsesDirectMTLS(t *testing.T) { } } -func TestSelectMultiplayerDialStrategyRequireWGRejectsMissingWGConfig(t *testing.T) { - setTestMultiplayerConnectMode(t, MultiplayerConnectRequireWG) - - _, err := selectMultiplayerDialStrategy(&assets.ClientConfig{}) - if !errors.Is(err, ErrMissingWireGuardConfig) { - t.Fatalf("expected missing WG config error, got %v", err) - } -} - func TestSelectMultiplayerDialStrategyRejectsIncompleteWGConfig(t *testing.T) { setTestMultiplayerConnectMode(t, MultiplayerConnectAuto) @@ -87,3 +81,29 @@ func setTestMultiplayerConnectMode(t *testing.T, mode MultiplayerConnectMode) { SetMultiplayerConnectMode(previous) }) } + +func TestCloseGRPCConnectionClosesConnBeforeTransportCloser(t *testing.T) { + conn, err := grpc.NewClient("passthrough:///sliver-test", grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("create grpc client: %v", err) + } + + var stateSeen connectivity.State + registerConnCloser(conn, testConnectionCloser(func() error { + stateSeen = conn.GetState() + return nil + })) + + if err := CloseGRPCConnection(conn); err != nil { + t.Fatalf("close grpc connection: %v", err) + } + if stateSeen != connectivity.Shutdown { + t.Fatalf("expected transport closer to run after grpc shutdown, got state %v", stateSeen) + } +} + +type testConnectionCloser func() error + +func (fn testConnectionCloser) Close() error { + return fn() +} diff --git a/client/transport/wireguard.go 
b/client/transport/wireguard.go index 1e72d4c652..5662c4d39a 100644 --- a/client/transport/wireguard.go +++ b/client/transport/wireguard.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/bishopfox/sliver/client/assets" "github.com/bishopfox/sliver/protobuf/rpcpb" @@ -23,11 +24,17 @@ const ( multiplayerWireGuardDefaultServerIP = "100.65.0.1" multiplayerWireGuardMTU = 1420 multiplayerWireGuardKeepalive = 25 + multiplayerWireGuardDialTimeout = 30 * time.Second + multiplayerWireGuardRetryDelay = 250 * time.Millisecond ) var ( ErrMissingWireGuardConfig = errors.New("operator config has no wg block") ErrIncompleteWireGuardConfig = errors.New("operator config has incomplete wg block") + + multiplayerWireGuardIdleTimeout = 5 * time.Second + wireGuardTunnelCacheMu sync.Mutex + wireGuardTunnelCache = map[string]*cachedWireGuardTunnel{} ) type wireGuardTunnel struct { @@ -37,6 +44,19 @@ type wireGuardTunnel struct { closeOnce sync.Once } +type cachedWireGuardTunnel struct { + tunnel *wireGuardTunnel + target string + timer *time.Timer +} + +type idleWireGuardTunnelCloser struct { + key string + + tunnel *wireGuardTunnel + target string +} + func (t *wireGuardTunnel) Close() error { if t == nil { return nil @@ -44,6 +64,7 @@ func (t *wireGuardTunnel) Close() error { t.closeOnce.Do(func() { if t.dev != nil { t.dev.Close() + <-t.dev.Wait() } }) return nil @@ -56,22 +77,136 @@ func (t *wireGuardTunnel) DialContext(ctx context.Context, address string) (net. 
return t.net.DialContext(ctx, "tcp", address) } -func wireGuardMTLSConnect(config *assets.ClientConfig) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { - tunnel, target, err := newWireGuardTunnel(config) +func (c *idleWireGuardTunnelCloser) Close() error { + if c == nil || c.key == "" || c.tunnel == nil { + return nil + } + cacheIdleWireGuardTunnel(c.key, c.tunnel, c.target) + return nil +} + +func wireGuardMTLSConnect(config *assets.ClientConfig, statusFn ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) { + deadline := time.Now().Add(multiplayerWireGuardDialTimeout) + var lastErr error + attempts := 0 + + for { + attempts++ + + notifyConnectStatus(statusFn, connectStatusWireGuard) + cacheKey, tunnel, target, err := acquireWireGuardTunnel(config) + if err != nil { + return nil, nil, err + } + + options, err := newMTLSDialOptions(config) + if err != nil { + _ = tunnel.Close() + return nil, nil, err + } + options = append(options, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return tunnel.DialContext(ctx, addr) + })) + + notifyConnectStatus(statusFn, connectStatusGRPCMTLSOverWireGuard) + rpcClient, conn, err := dialRPCClient(target, options, nil) + if err == nil { + registerConnCloser(conn, &idleWireGuardTunnelCloser{ + key: cacheKey, + tunnel: tunnel, + target: target, + }) + return rpcClient, conn, nil + } + + _ = tunnel.Close() + lastErr = err + if !errors.Is(err, context.DeadlineExceeded) || !time.Now().Before(deadline) { + break + } + time.Sleep(multiplayerWireGuardRetryDelay) + } + + if attempts > 1 { + return nil, nil, fmt.Errorf("wireguard multiplayer connect failed after %d attempts: %w", attempts, lastErr) + } + return nil, nil, lastErr +} + +func acquireWireGuardTunnel(config *assets.ClientConfig) (string, *wireGuardTunnel, string, error) { + key, err := wireGuardTunnelCacheKey(config) if err != nil { - return nil, nil, err + return "", nil, "", err } - options, err := newMTLSDialOptions(config) + 
wireGuardTunnelCacheMu.Lock() + if cached := wireGuardTunnelCache[key]; cached != nil && cached.tunnel != nil { + delete(wireGuardTunnelCache, key) + if cached.timer != nil { + cached.timer.Stop() + cached.timer = nil + } + tunnel := cached.tunnel + target := cached.target + wireGuardTunnelCacheMu.Unlock() + return key, tunnel, target, nil + } + wireGuardTunnelCacheMu.Unlock() + + tunnel, target, err := newWireGuardTunnel(config) if err != nil { - _ = tunnel.Close() - return nil, nil, err + return "", nil, "", err + } + return key, tunnel, target, nil +} + +func cacheIdleWireGuardTunnel(key string, tunnel *wireGuardTunnel, target string) { + if key == "" || tunnel == nil { + return + } + wireGuardTunnelCacheMu.Lock() + previous := wireGuardTunnelCache[key] + cached := &cachedWireGuardTunnel{ + tunnel: tunnel, + target: target, } - options = append(options, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return tunnel.DialContext(ctx, addr) - })) + wireGuardTunnelCache[key] = cached + if previous != nil && previous.timer != nil { + previous.timer.Stop() + } + cached.timer = time.AfterFunc(multiplayerWireGuardIdleTimeout, func() { + wireGuardTunnelCacheMu.Lock() + current := wireGuardTunnelCache[key] + if current != cached || current == nil { + wireGuardTunnelCacheMu.Unlock() + return + } + delete(wireGuardTunnelCache, key) + current.timer = nil + tunnel := current.tunnel + wireGuardTunnelCacheMu.Unlock() + if tunnel != nil { + _ = tunnel.Close() + } + }) + wireGuardTunnelCacheMu.Unlock() + if previous != nil && previous.tunnel != nil { + _ = previous.tunnel.Close() + } +} - return dialRPCClient(target, options, tunnel) +func wireGuardTunnelCacheKey(config *assets.ClientConfig) (string, error) { + if err := validateWireGuardConfig(config); err != nil { + return "", err + } + return strings.Join([]string{ + strings.TrimSpace(config.LHost), + strconv.Itoa(config.LPort), + strings.TrimSpace(config.WG.ServerPubKey), + 
strings.TrimSpace(config.WG.ClientPrivateKey), + strings.TrimSpace(config.WG.ClientIP), + strings.TrimSpace(config.WG.ServerIP), + }, "\x00"), nil } func newWireGuardTunnel(config *assets.ClientConfig) (*wireGuardTunnel, string, error) { diff --git a/client/transport/wireguard_netstack.go b/client/transport/wireguard_netstack.go index babec20c3f..5c0afe9c17 100644 --- a/client/transport/wireguard_netstack.go +++ b/client/transport/wireguard_netstack.go @@ -38,6 +38,11 @@ type transportTun struct { type transportNet transportTun +const ( + transportTCPReceiveBufferMax = 8 << 20 + transportTCPSendBufferMax = 6 << 20 +) + func createTransportNetTUN(localAddresses []netip.Addr, mtu int) (tun.Device, *transportNet, error) { n, err := rand.Int(rand.Reader, big.NewInt(0xFFFFFFFF)) if err != nil { @@ -60,9 +65,8 @@ func createTransportNetTUN(localAddresses []netip.Addr, mtu int) (tun.Device, *t mtu: mtu, } - sackEnabledOpt := tcpip.TCPSACKEnabled(true) - if tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt); tcpipErr != nil { - return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + if err := configureTransportTCPStack(dev.stack); err != nil { + return nil, nil, err } dev.notifyHandle = dev.ep.AddNotify(dev) @@ -102,6 +106,43 @@ func createTransportNetTUN(localAddresses []netip.Addr, mtu int) (tun.Device, *t return dev, (*transportNet)(dev), nil } +func configureTransportTCPStack(ipstack *stack.Stack) error { + sackEnabledOpt := tcpip.TCPSACKEnabled(true) + if tcpipErr := ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt); tcpipErr != nil { + return fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + + tcpRecoveryOpt := tcpip.TCPRecovery(0) + if tcpipErr := ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt); tcpipErr != nil { + return fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) + } + + renoOpt := tcpip.CongestionControlOption("reno") + if tcpipErr := 
ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &renoOpt); tcpipErr != nil { + return fmt.Errorf("could not set TCP congestion control to reno: %v", tcpipErr) + } + + tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: tcp.MinBufferSize, + Default: tcp.DefaultReceiveBufferSize, + Max: transportTCPReceiveBufferMax, + } + if tcpipErr := ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt); tcpipErr != nil { + return fmt.Errorf("could not set TCP RX buffer size: %v", tcpipErr) + } + + tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{ + Min: tcp.MinBufferSize, + Default: tcp.DefaultSendBufferSize, + Max: transportTCPSendBufferMax, + } + if tcpipErr := ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt); tcpipErr != nil { + return fmt.Errorf("could not set TCP TX buffer size: %v", tcpipErr) + } + + return nil +} + func (tun *transportTun) Name() (string, error) { return "go", nil } diff --git a/client/transport/wireguard_test.go b/client/transport/wireguard_test.go new file mode 100644 index 0000000000..e85400570a --- /dev/null +++ b/client/transport/wireguard_test.go @@ -0,0 +1,86 @@ +package transport + +import ( + "testing" + "time" + + "github.com/bishopfox/sliver/client/assets" +) + +func TestWireGuardTunnelCacheKeyIncludesConfigMaterial(t *testing.T) { + config := &assets.ClientConfig{ + LHost: "127.0.0.1", + LPort: 31337, + WG: &assets.ClientWGConfig{ + ServerPubKey: "server-pub", + ClientPrivateKey: "client-priv", + ClientIP: "100.65.0.2", + ServerIP: "100.65.0.1", + }, + } + + key1, err := wireGuardTunnelCacheKey(config) + if err != nil { + t.Fatalf("cache key: %v", err) + } + + config.WG.ClientPrivateKey = "other-client-priv" + key2, err := wireGuardTunnelCacheKey(config) + if err != nil { + t.Fatalf("cache key after config change: %v", err) + } + + if key1 == key2 { + t.Fatal("expected cache key to change when wireguard config changes") + } +} + +func TestCacheIdleWireGuardTunnelRemovesIdleTunnel(t *testing.T) 
{ + resetWireGuardTunnelCacheForTest(t) + + previousIdleTimeout := multiplayerWireGuardIdleTimeout + multiplayerWireGuardIdleTimeout = 10 * time.Millisecond + t.Cleanup(func() { + multiplayerWireGuardIdleTimeout = previousIdleTimeout + }) + + const cacheKey = "test-cache-key" + cacheIdleWireGuardTunnel(cacheKey, &wireGuardTunnel{}, "100.65.0.1:31337") + + deadline := time.Now().Add(250 * time.Millisecond) + for time.Now().Before(deadline) { + wireGuardTunnelCacheMu.Lock() + _, exists := wireGuardTunnelCache[cacheKey] + wireGuardTunnelCacheMu.Unlock() + if !exists { + return + } + time.Sleep(5 * time.Millisecond) + } + + t.Fatal("expected idle tunnel cache entry to be removed") +} + +func resetWireGuardTunnelCacheForTest(t *testing.T) { + t.Helper() + + wireGuardTunnelCacheMu.Lock() + for key, shared := range wireGuardTunnelCache { + if shared != nil && shared.timer != nil { + shared.timer.Stop() + } + delete(wireGuardTunnelCache, key) + } + wireGuardTunnelCacheMu.Unlock() + + t.Cleanup(func() { + wireGuardTunnelCacheMu.Lock() + for key, shared := range wireGuardTunnelCache { + if shared != nil && shared.timer != nil { + shared.timer.Stop() + } + delete(wireGuardTunnelCache, key) + } + wireGuardTunnelCacheMu.Unlock() + }) +} diff --git a/docs/sliver-docs/pages/docs/md/Multi-player Mode.md b/docs/sliver-docs/pages/docs/md/Multi-player Mode.md index 3a1194d3ce..ffddd04063 100644 --- a/docs/sliver-docs/pages/docs/md/Multi-player Mode.md +++ b/docs/sliver-docs/pages/docs/md/Multi-player Mode.md @@ -86,7 +86,6 @@ In direct mode: - Multiplayer is exposed directly over TCP on `--lport` (default `31337`). - Generated operator configs omit the `wg` block. - `sliver-client --disable-wg` forces a direct connection even if the config includes a `wg` block. -- `sliver-client --require-wg` fails fast if the config does not contain a valid `wg` block. The listener mode and the operator config need to match. 
A WireGuard-enabled config cannot talk to a direct listener, and a direct-only client cannot talk to the default WireGuard-wrapped listener. diff --git a/docs/sliver-docs/public/install b/docs/sliver-docs/public/install index 1f8747abb1..f0efe2edda 100644 --- a/docs/sliver-docs/public/install +++ b/docs/sliver-docs/public/install @@ -42,6 +42,34 @@ echo "Running from $(pwd)" echo "Using Minisign public key..." +stop_existing_sliver_daemon() { + local existing_sliver_server="" + + if test -x /root/sliver-server; then + existing_sliver_server="/root/sliver-server" + elif command -v sliver-server &> /dev/null; then + existing_sliver_server="$(command -v sliver-server)" + fi + + if [[ -z "$existing_sliver_server" ]]; then + return 0 + fi + + echo "Existing Sliver server install detected at $existing_sliver_server" + + if command -v systemctl &> /dev/null && test -f /etc/systemd/system/sliver.service; then + if systemctl is-active --quiet sliver; then + echo "Stopping the Sliver systemd service before upgrade..." + systemctl stop sliver + fi + fi + + if pgrep -f "${existing_sliver_server} daemon" > /dev/null 2>&1; then + echo "Stopping the running Sliver daemon before upgrade..." + pkill -f "${existing_sliver_server} daemon" + fi +} + # Download and Unpack Sliver Server ARCH="$(uname -m)" case "$ARCH" in @@ -79,6 +107,8 @@ echo "Verifying signatures ..." minisign -Vm "/root/$SLIVER_SERVER" -x "/root/$SLIVER_SERVER.minisig" -P "$SLIVER_MINISIGN_PUB_KEY" minisign -Vm "/root/$SLIVER_CLIENT" -x "/root/$SLIVER_CLIENT.minisig" -P "$SLIVER_MINISIGN_PUB_KEY" +stop_existing_sliver_daemon + if test -f "/root/$SLIVER_SERVER"; then echo "Moving the Sliver server executable to /root/sliver-server..." mv "/root/$SLIVER_SERVER" /root/sliver-server @@ -135,7 +165,7 @@ echo "Generating local configs ..." # Generate local configs echo "Generating operator configs ..." 
mkdir -p /root/.sliver-client/configs -/root/sliver-server operator --name root --lhost localhost --permissions all --save /root/.sliver-client/configs +/root/sliver-server operator --name root --lhost 127.0.0.1 --permissions all --save /root/.sliver-client/configs chown -R root:root /root/.sliver-client/ USER_DIRS=(/home/*) @@ -144,7 +174,7 @@ for USER_DIR in "${USER_DIRS[@]}"; do if id -u "$USER" >/dev/null 2>&1; then echo "Generating operator configs for user $USER..." mkdir -p "$USER_DIR/.sliver-client/configs" - /root/sliver-server operator --name "$USER" --lhost localhost --permissions all --save "$USER_DIR/.sliver-client/configs" + /root/sliver-server operator --name "$USER" --lhost 127.0.0.1 --permissions all --save "$USER_DIR/.sliver-client/configs" chown -R "$USER":"$(id -gn "$USER")" "$USER_DIR/.sliver-client/" fi done diff --git a/server/console/console-admin_wireguard_test.go b/server/console/console-admin_wireguard_test.go index b96f1ab3bf..9f1be77fd8 100644 --- a/server/console/console-admin_wireguard_test.go +++ b/server/console/console-admin_wireguard_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "io" "net" "os" "os/exec" @@ -11,14 +12,17 @@ import ( "strconv" "strings" "testing" + "time" clientassets "github.com/bishopfox/sliver/client/assets" clienttransport "github.com/bishopfox/sliver/client/transport" "github.com/bishopfox/sliver/protobuf/commonpb" + "github.com/bishopfox/sliver/protobuf/rpcpb" "github.com/bishopfox/sliver/server/certs" "github.com/bishopfox/sliver/server/db" "github.com/bishopfox/sliver/server/db/models" servertransport "github.com/bishopfox/sliver/server/transport" + "google.golang.org/grpc" "gorm.io/gorm" ) @@ -120,6 +124,160 @@ func TestNewOperatorConfigWithWireGuardConnectsToWrappedMultiplayer(t *testing.T } } +func TestNewOperatorConfigWithWireGuardConnectsToWrappedMultiplayerRepeatedly(t *testing.T) { + certs.SetupCAs() + certs.SetupWGKeys() + certs.SetupMultiplayerWGKeys() + 
clienttransport.SetMultiplayerConnectMode(clienttransport.MultiplayerConnectAuto) + + operatorName := uniqueKickOperatorName(t) + t.Cleanup(func() { + _ = removeOperator(operatorName) + _ = revokeOperatorClientCertificate(operatorName) + closeOperatorStreams(operatorName) + }) + + port := freeUDPPort(t) + configJSON, err := NewOperatorConfig(operatorName, "127.0.0.1", uint16(port), []string{"all"}, true) + if err != nil { + t.Fatalf("generate wireguard operator config: %v", err) + } + + config := &clientassets.ClientConfig{} + if err := json.Unmarshal(configJSON, config); err != nil { + t.Fatalf("parse operator config: %v", err) + } + + grpcServer, ln, err := servertransport.StartWGWrappedMtlsClientListener("127.0.0.1", uint16(port)) + if err != nil { + if strings.Contains(err.Error(), "operation not permitted") { + t.Skipf("wireguard listener bind not permitted in this environment: %v", err) + } + t.Fatalf("start wrapped multiplayer listener: %v", err) + } + defer grpcServer.Stop() + defer ln.Close() + + for attempt := 1; attempt <= 5; attempt++ { + rpcClient, conn, err := clienttransport.MTLSConnect(config) + if err != nil { + t.Fatalf("attempt %d: connect operator through wireguard wrapper: %v", attempt, err) + } + if _, err := rpcClient.GetVersion(context.Background(), &commonpb.Empty{}); err != nil { + _ = clienttransport.CloseGRPCConnection(conn) + t.Fatalf("attempt %d: GetVersion over wrapped multiplayer failed: %v", attempt, err) + } + if _, err := rpcClient.GetOperators(context.Background(), &commonpb.Empty{}); err != nil { + _ = clienttransport.CloseGRPCConnection(conn) + t.Fatalf("attempt %d: GetOperators over wrapped multiplayer failed: %v", attempt, err) + } + if err := clienttransport.CloseGRPCConnection(conn); err != nil { + t.Fatalf("attempt %d: close wrapped multiplayer connection: %v", attempt, err) + } + } +} + +func TestWrappedMultiplayerWireGuardSupportsUnaryRPCsWithDedicatedCommandConnection(t *testing.T) { + t.Run("events-only", func(t 
*testing.T) { + runWrappedMultiplayerWireGuardUnaryWithBackgroundStreams(t, true, false) + }) + t.Run("tunnel-only", func(t *testing.T) { + runWrappedMultiplayerWireGuardUnaryWithBackgroundStreams(t, false, true) + }) + t.Run("events-and-tunnel", func(t *testing.T) { + runWrappedMultiplayerWireGuardUnaryWithBackgroundStreams(t, true, true) + }) +} + +func runWrappedMultiplayerWireGuardUnaryWithBackgroundStreams(t *testing.T, useEvents bool, useTunnel bool) { + certs.SetupCAs() + certs.SetupWGKeys() + certs.SetupMultiplayerWGKeys() + clienttransport.SetMultiplayerConnectMode(clienttransport.MultiplayerConnectAuto) + + operatorName := uniqueKickOperatorName(t) + t.Cleanup(func() { + _ = removeOperator(operatorName) + _ = revokeOperatorClientCertificate(operatorName) + closeOperatorStreams(operatorName) + }) + + port := freeUDPPort(t) + configJSON, err := NewOperatorConfig(operatorName, "127.0.0.1", uint16(port), []string{"all"}, true) + if err != nil { + t.Fatalf("generate wireguard operator config: %v", err) + } + + config := &clientassets.ClientConfig{} + if err := json.Unmarshal(configJSON, config); err != nil { + t.Fatalf("parse operator config: %v", err) + } + + grpcServer, ln, err := servertransport.StartWGWrappedMtlsClientListener("127.0.0.1", uint16(port)) + if err != nil { + if strings.Contains(err.Error(), "operation not permitted") { + t.Skipf("wireguard listener bind not permitted in this environment: %v", err) + } + t.Fatalf("start wrapped multiplayer listener: %v", err) + } + defer grpcServer.Stop() + defer ln.Close() + + var streamClient rpcpb.SliverRPCClient + var streamConn *grpc.ClientConn + if useEvents || useTunnel { + streamClient, streamConn, err = clienttransport.MTLSConnect(config) + if err != nil { + t.Fatalf("connect dedicated background operator through wireguard wrapper: %v", err) + } + defer clienttransport.CloseGRPCConnection(streamConn) + } + + streamCtx, streamCancel := context.WithCancel(context.Background()) + defer streamCancel() + 
+ if useEvents { + events, err := streamClient.Events(streamCtx, &commonpb.Empty{}) + if err != nil { + t.Fatalf("open events stream: %v", err) + } + go drainWGTestStream(t, "events", func() error { + _, err := events.Recv() + return err + }) + } + + if useTunnel { + tunnelData, err := streamClient.TunnelData(streamCtx) + if err != nil { + t.Fatalf("open tunnel data stream: %v", err) + } + go drainWGTestStream(t, "tunnel-data", func() error { + _, err := tunnelData.Recv() + return err + }) + } + + // Let the background streams settle before issuing the unary call. The + // interactive console opens these streams first and only later runs + // unary RPC-backed commands like "operators". + if useEvents || useTunnel { + time.Sleep(2 * time.Second) + } + + rpcClient, conn, err := clienttransport.MTLSConnect(config) + if err != nil { + t.Fatalf("connect operator command channel through wireguard wrapper: %v", err) + } + defer clienttransport.CloseGRPCConnection(conn) + + callCtx, callCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer callCancel() + if _, err := rpcClient.GetOperators(callCtx, &commonpb.Empty{}); err != nil { + t.Fatalf("GetOperators with dedicated command connection over wrapped multiplayer failed (events=%t tunnel=%t): %v", useEvents, useTunnel, err) + } +} + func TestOperatorCLIWireGuardConfigConnectsToWrappedMultiplayer(t *testing.T) { certs.SetupCAs() certs.SetupWGKeys() @@ -238,3 +396,16 @@ func freeUDPPort(t *testing.T) int { defer ln.Close() return ln.LocalAddr().(*net.UDPAddr).Port } + +func drainWGTestStream(t *testing.T, name string, recv func() error) { + t.Helper() + + err := recv() + if err == nil || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + return + } + if strings.Contains(err.Error(), "context canceled") { + return + } + t.Errorf("%s stream recv failed: %v", name, err) +} diff --git a/server/rpc/rpc-ai_test.go b/server/rpc/rpc-ai_test.go index 4e0f1d5a77..2d1f0b1199 100644 --- 
a/server/rpc/rpc-ai_test.go +++ b/server/rpc/rpc-ai_test.go @@ -17,6 +17,7 @@ import ( "github.com/bishopfox/sliver/protobuf/rpcpb" serverai "github.com/bishopfox/sliver/server/ai" "github.com/bishopfox/sliver/server/configs" + "github.com/bishopfox/sliver/server/core" "github.com/bishopfox/sliver/server/db" "github.com/bishopfox/sliver/server/db/models" "google.golang.org/grpc/codes" @@ -26,6 +27,8 @@ import ( "gorm.io/gorm" ) +const aiConversationEventStreamTimeout = 30 * time.Second + func TestSaveAIConversationMessageCompletesConversationAndPublishesEvents(t *testing.T) { setupAIRPCTestEnv(t) @@ -69,9 +72,9 @@ func TestSaveAIConversationMessageCompletesConversationAndPublishesEvents(t *tes client, cleanup := newBufnetRPCClient(t) defer cleanup() - streamCtx, cancelStream := context.WithTimeout(context.Background(), 10*time.Second) + streamCtx, cancelStream := context.WithTimeout(context.Background(), aiConversationEventStreamTimeout) defer cancelStream() - eventStream, err := client.Events(streamCtx, &commonpb.Empty{}) + eventStream, err := startAIEventStream(t, client, streamCtx) if err != nil { t.Fatalf("start events stream: %v", err) } @@ -242,9 +245,9 @@ func TestSaveAIConversationMessageCompletesOpenAIWithoutExplicitBaseURL(t *testi client, cleanup := newBufnetRPCClient(t) defer cleanup() - streamCtx, cancelStream := context.WithTimeout(context.Background(), 10*time.Second) + streamCtx, cancelStream := context.WithTimeout(context.Background(), aiConversationEventStreamTimeout) defer cancelStream() - eventStream, err := client.Events(streamCtx, &commonpb.Empty{}) + eventStream, err := startAIEventStream(t, client, streamCtx) if err != nil { t.Fatalf("start events stream: %v", err) } @@ -332,9 +335,9 @@ func TestSaveAIConversationMessagePublishesFailureMessageWhenProviderErrors(t *t client, cleanup := newBufnetRPCClient(t) defer cleanup() - streamCtx, cancelStream := context.WithTimeout(context.Background(), 10*time.Second) + streamCtx, cancelStream := 
context.WithTimeout(context.Background(), aiConversationEventStreamTimeout) defer cancelStream() - eventStream, err := client.Events(streamCtx, &commonpb.Empty{}) + eventStream, err := startAIEventStream(t, client, streamCtx) if err != nil { t.Fatalf("start events stream: %v", err) } @@ -622,9 +625,9 @@ func TestSaveAIConversationMessagePersistsReasoningAndToolBlocks(t *testing.T) { client, cleanup := newBufnetRPCClient(t) defer cleanup() - streamCtx, cancelStream := context.WithTimeout(context.Background(), 10*time.Second) + streamCtx, cancelStream := context.WithTimeout(context.Background(), aiConversationEventStreamTimeout) defer cancelStream() - eventStream, err := client.Events(streamCtx, &commonpb.Empty{}) + eventStream, err := startAIEventStream(t, client, streamCtx) if err != nil { t.Fatalf("start events stream: %v", err) } @@ -933,6 +936,27 @@ func saveOpenAICompletionConfig(t *testing.T, model string, thinkingLevel string } } +func startAIEventStream(t *testing.T, client rpcpb.SliverRPCClient, streamCtx context.Context) (rpcpb.SliverRPC_EventsClient, error) { + t.Helper() + + before := len(core.Clients.ActiveOperators()) + eventStream, err := client.Events(streamCtx, &commonpb.Empty{}) + if err != nil { + return nil, err + } + + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if len(core.Clients.ActiveOperators()) > before { + return eventStream, nil + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("event stream did not subscribe before deadline") + return nil, nil +} + func waitForAIConversationEvent(t *testing.T, eventStream rpcpb.SliverRPC_EventsClient, conversationID string) { t.Helper() diff --git a/server/transport/mtls.go b/server/transport/mtls.go index d5bd387c7c..49567b04eb 100644 --- a/server/transport/mtls.go +++ b/server/transport/mtls.go @@ -25,6 +25,7 @@ import ( "fmt" "net" "runtime/debug" + "strings" "github.com/bishopfox/sliver/protobuf/rpcpb" "github.com/bishopfox/sliver/server/certs" @@ -89,21 
+90,33 @@ func StartMtlsClientServer(ln net.Listener) (*grpc.Server, error) { grpcServer := grpc.NewServer(options...) rpcpb.RegisterSliverRPCServer(grpcServer, rpc.NewServer()) go func() { - panicked := true defer func() { - if panicked { - mtlsLog.Errorf("stacktrace from panic: %s", string(debug.Stack())) + if r := recover(); r != nil { + mtlsLog.Errorf("gRPC server panic: %v\n%s", r, string(debug.Stack())) } }() - if err := grpcServer.Serve(ln); err != nil { + if err := grpcServer.Serve(ln); err != nil && !isExpectedGRPCServerExit(err) { mtlsLog.Warnf("gRPC server exited with error: %v", err) - } else { - panicked = false } }() return grpcServer, nil } +func isExpectedGRPCServerExit(err error) bool { + if err == nil { + return true + } + if errors.Is(err, grpc.ErrServerStopped) || errors.Is(err, net.ErrClosed) { + return true + } + + // The gVisor-backed listener used by the WireGuard multiplayer transport can + // surface this accept error during normal shutdown on Windows. + errString := err.Error() + return strings.Contains(errString, "use of closed network connection") || + strings.Contains(errString, "endpoint is in invalid state") +} + // getOperatorServerTLSConfig - Generate the TLS configuration, we do now allow the end user // to specify any TLS parameters, we choose sensible defaults instead func getOperatorServerTLSConfig(host string) *tls.Config {