Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion client/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ func init() {
rootCmd.TraverseChildren = true
rootCmd.Flags().String(RCFlagName, "", "path to rc script file")
rootCmd.PersistentFlags().Bool(disableWGFlag, false, "connect to multiplayer directly even if the operator config includes a WireGuard wrapper")
rootCmd.PersistentFlags().Bool(requireWGFlag, false, "require the operator config's WireGuard wrapper for multiplayer connections")

// Create the console client, without any RPC or commands bound to it yet.
// This created before anything so that multiple commands can make use of
Expand Down
177 changes: 177 additions & 0 deletions client/cli/connect_spinner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package cli

import (
"fmt"
"io"
"strings"
"time"
"unicode/utf8"

bspinner "charm.land/bubbles/v2/spinner"
"github.com/bishopfox/sliver/client/transport"
"github.com/bishopfox/sliver/protobuf/rpcpb"
"google.golang.org/grpc"
)

const (
minConnectSpinnerDuration = 900 * time.Millisecond
minConnectStatusDuration = 350 * time.Millisecond
)

type connectResult struct {
rpc rpcpb.SliverRPCClient
conn *grpc.ClientConn
err error
}

func connectWithSpinner(out io.Writer, target string, connect func(transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error)) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) {
statusCh := make(chan string, 8)
resultCh := make(chan connectResult, 1)

go func() {
rpc, conn, err := connect(func(status string) {
sendStatus(statusCh, status)
})
resultCh <- connectResult{rpc: rpc, conn: conn, err: err}
}()

currentStatus := ""
currentStatusSince := time.Time{}
queuedStatus := ""
lastWidth := 0
spinner := bspinner.New(bspinner.WithSpinner(bspinner.Line))
ticker := time.NewTicker(spinner.Spinner.FPS)
defer ticker.Stop()
startedAt := time.Now()
var pendingResult *connectResult

setStatus := func(status string) {
status = strings.TrimSpace(status)
if status == "" || status == currentStatus {
return
}
currentStatus = status
currentStatusSince = time.Now()
}

stageStatus := func(status string) {
status = strings.TrimSpace(status)
if status == "" {
return
}
if currentStatus == "" || time.Since(currentStatusSince) >= minConnectStatusDuration {
queuedStatus = ""
setStatus(status)
return
}
queuedStatus = status
}

flushQueuedStatus := func() {
if queuedStatus == "" || time.Since(currentStatusSince) < minConnectStatusDuration {
return
}
nextStatus := queuedStatus
queuedStatus = ""
setStatus(nextStatus)
}

canReturnPendingResult := func() bool {
if pendingResult == nil {
return false
}
if pendingResult.err != nil {
return true
}
if time.Since(startedAt) < minConnectSpinnerDuration {
return false
}
if queuedStatus != "" {
return false
}
if !currentStatusSince.IsZero() && time.Since(currentStatusSince) < minConnectStatusDuration {
return false
}
return true
}

render := func() {
line := fmt.Sprintf("%s %s", spinner.View(), formatConnectSpinnerMessage(target, currentStatus))
lastWidth = writeSpinnerLine(out, line, lastWidth)
}

render()
for {
select {
case status := <-statusCh:
stageStatus(status)
spinner, _ = spinner.Update(spinner.Tick())
render()

case result := <-resultCh:
pendingResult = &result

case <-ticker.C:
spinner, _ = spinner.Update(spinner.Tick())
flushQueuedStatus()
render()
if canReturnPendingResult() {
result := *pendingResult
pendingResult = nil
clearSpinnerLine(out, lastWidth)
return result.rpc, result.conn, result.err
}
}

if pendingResult != nil && pendingResult.err != nil {
result := *pendingResult
pendingResult = nil
clearSpinnerLine(out, lastWidth)
return result.rpc, result.conn, result.err
}
}
}

func sendStatus(statusCh chan string, status string) {
status = strings.TrimSpace(status)
if status == "" {
return
}
statusCh <- status
}

func formatConnectSpinnerMessage(target string, status string) string {
target = strings.TrimSpace(target)
status = strings.TrimSpace(status)

if target == "" {
if status == "" {
return "Connecting ..."
}
return fmt.Sprintf("Connecting (%s) ...", status)
}
if status == "" {
return fmt.Sprintf("Connecting to %s ...", target)
}
return fmt.Sprintf("Connecting to %s (%s) ...", target, status)
}

func writeSpinnerLine(out io.Writer, line string, lastWidth int) int {
width := utf8.RuneCountInString(line)
padding := ""
if lastWidth > width {
padding = strings.Repeat(" ", lastWidth-width)
}
fmt.Fprintf(out, "\r%s%s", line, padding)
if width > lastWidth {
return width
}
return lastWidth
}

func clearSpinnerLine(out io.Writer, lastWidth int) {
if lastWidth <= 0 {
return
}
fmt.Fprintf(out, "\r%s\r", strings.Repeat(" ", lastWidth))
}
132 changes: 132 additions & 0 deletions client/cli/connect_spinner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package cli

import (
"bytes"
"strings"
"testing"
"time"

"github.com/bishopfox/sliver/client/transport"
"github.com/bishopfox/sliver/protobuf/rpcpb"
"google.golang.org/grpc"
)

func TestFormatConnectSpinnerMessage(t *testing.T) {
tests := []struct {
name string
target string
status string
want string
}{
{
name: "target only",
target: "127.0.0.1:31337",
want: "Connecting to 127.0.0.1:31337 ...",
},
{
name: "target with status",
target: "127.0.0.1:31337",
status: "wireguard",
want: "Connecting to 127.0.0.1:31337 (wireguard) ...",
},
{
name: "status without target",
status: "grpc/mtls",
want: "Connecting (grpc/mtls) ...",
},
{
name: "trim whitespace",
target: " 127.0.0.1:31337 ",
status: " grpc/mtls over wireguard ",
want: "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if got := formatConnectSpinnerMessage(test.target, test.status); got != test.want {
t.Fatalf("expected %q, got %q", test.want, got)
}
})
}
}

func TestConnectWithSpinnerOutput(t *testing.T) {
var out bytes.Buffer
spinnerDelay := 2 * (100 * time.Millisecond)

_, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) {
statusFn("wireguard")
time.Sleep(spinnerDelay)
statusFn("grpc/mtls over wireguard")
time.Sleep(spinnerDelay)
return nil, nil, nil
})
if err != nil {
t.Fatalf("connectWithSpinner returned error: %v", err)
}

got := out.String()
if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (wireguard) ...") {
t.Fatalf("expected wireguard status in output, got %q", got)
}
if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...") {
t.Fatalf("expected grpc over wireguard status in output, got %q", got)
}
frameCount := 0
for _, frame := range []string{"|", "/", "-", "\\"} {
if strings.Contains(got, frame+" Connecting to 127.0.0.1:31337") {
frameCount++
}
}
if frameCount < 2 {
t.Fatalf("expected multiple spinner frames in output, got %q", got)
}
}

func TestConnectWithSpinnerFastSuccessStillShowsMultipleFrames(t *testing.T) {
var out bytes.Buffer

_, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) {
statusFn("grpc/mtls")
return nil, nil, nil
})
if err != nil {
t.Fatalf("connectWithSpinner returned error: %v", err)
}

got := out.String()
frameCount := 0
for _, frame := range []string{"|", "/", "-", "\\"} {
if strings.Contains(got, frame+" Connecting to 127.0.0.1:31337 (grpc/mtls) ...") {
frameCount++
}
}
if frameCount < 2 {
t.Fatalf("expected multiple spinner frames for fast success, got %q", got)
}
}

func TestConnectWithSpinnerFastSuccessShowsEachStatus(t *testing.T) {
var out bytes.Buffer

_, _, err := connectWithSpinner(&out, "127.0.0.1:31337", func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) {
statusFn("wireguard")
statusFn("grpc/mtls over wireguard")
return nil, nil, nil
})
if err != nil {
t.Fatalf("connectWithSpinner returned error: %v", err)
}

got := out.String()
if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (wireguard) ...") {
t.Fatalf("expected fast success output to include wireguard status, got %q", got)
}
if !strings.Contains(got, "Connecting to 127.0.0.1:31337 (grpc/mtls over wireguard) ...") {
t.Fatalf("expected fast success output to include grpc over wireguard status, got %q", got)
}
if strings.Contains(got, "Connected to 127.0.0.1:31337") {
t.Fatalf("did not expect past-tense success output, got %q", got)
}
}
18 changes: 11 additions & 7 deletions client/cli/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package cli

import (
"fmt"
"os"

"github.com/bishopfox/sliver/client/assets"
"github.com/bishopfox/sliver/client/command"
Expand Down Expand Up @@ -64,17 +65,20 @@ func consoleRunnerCmd(con *console.SliverClient, run bool) (pre, post func(cmd *
return nil
}

// Don't clobber output when simply running an implant command from system shell.
if run {
fmt.Printf("Connecting to %s:%d ...\n", config.LHost, config.LPort)
}

target := fmt.Sprintf("%s:%d", config.LHost, config.LPort)
var rpc rpcpb.SliverRPCClient
var ln *grpc.ClientConn

rpc, ln, err = transport.MTLSConnect(config)
// Don't clobber output when simply running an implant command from system shell.
if run {
rpc, ln, err = connectWithSpinner(os.Stdout, target, func(statusFn transport.ConnectStatusFn) (rpcpb.SliverRPCClient, *grpc.ClientConn, error) {
return transport.MTLSConnectWithStatus(config, statusFn)
})
} else {
rpc, ln, err = transport.MTLSConnect(config)
}
if err != nil {
fmt.Printf("Connection to server failed %s", err)
fmt.Printf("Connection to server failed %s\n", err)
return nil
}
return console.StartClient(con, rpc, ln, &console.ConnectionDetails{ConfigKey: configKey, Config: config}, command.ServerCommands(con, nil), command.SliverCommands(con), run, rcScript)
Expand Down
15 changes: 1 addition & 14 deletions client/cli/transport_mode.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package cli

import (
"fmt"

"github.com/bishopfox/sliver/client/transport"
"github.com/spf13/cobra"
)

const (
requireWGFlag = "require-wg"
disableWGFlag = "disable-wg"
)

Expand All @@ -18,23 +15,13 @@ func applyMultiplayerConnectMode(cmd *cobra.Command) error {
return nil
}

requireWG, err := cmd.Flags().GetBool(requireWGFlag)
if err != nil {
return err
}
disableWG, err := cmd.Flags().GetBool(disableWGFlag)
if err != nil {
return err
}
if requireWG && disableWG {
return fmt.Errorf("--%s and --%s cannot be used together", requireWGFlag, disableWGFlag)
}

mode := transport.MultiplayerConnectAuto
switch {
case requireWG:
mode = transport.MultiplayerConnectRequireWG
case disableWG:
if disableWG {
mode = transport.MultiplayerConnectDisableWG
}
transport.SetMultiplayerConnectMode(mode)
Expand Down
Loading
Loading