Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions ans/transparency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1225,3 +1225,104 @@ func TestTransparencyClient_ParameterValidation(t *testing.T) {
func startsWith(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}

func TestDoRequestWithSchemaVersion_ErrorPaths(t *testing.T) {
tests := []struct {
name string
handler http.HandlerFunc
wantErr string
}{
{
name: "invalid JSON response",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("not-json"))
},
wantErr: "failed to parse response",
},
{
name: "schema version from header when missing in body",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Schema-Version", "V1")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"status": "ACTIVE",
"payload": map[string]any{
"ansId": "agent-1",
"ansName": "ans://v1.0.0.host.com",
},
})
},
},
{
name: "response with empty payload",
handler: func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"status": "ACTIVE",
"schemaVersion": "V1",
})
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(tt.handler)
defer server.Close()

client, err := NewTransparencyClient(WithBaseURL(server.URL))
if err != nil {
t.Fatalf("NewTransparencyClient() error = %v", err)
}

ctx := context.Background()
result, err := client.GetAgentTransparencyLog(ctx, "test-agent")
if tt.wantErr != "" {
if err == nil {
t.Fatal("expected error")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == nil {
t.Fatal("result is nil")
}
}
})
}
}

func TestDoRequestWithSchemaVersion_HTTPClientFailure(t *testing.T) {
tests := []struct {
name string
}{
{name: "connection refused"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a server that's been shut down to simulate connection failure
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
serverURL := server.URL
server.Close() // Close immediately so connections fail

client, err := NewTransparencyClient(WithBaseURL(serverURL))
if err != nil {
t.Fatalf("NewTransparencyClient() error = %v", err)
}

ctx := context.Background()
_, err = client.GetAgentTransparencyLog(ctx, "test-agent")
if err == nil {
t.Fatal("expected error for connection failure")
}
})
}
}
82 changes: 73 additions & 9 deletions cmd/ans-cli/cmd/badge.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"sort"
"strings"

"github.com/godaddy/ans-sdk-go/ans"
Expand All @@ -13,7 +14,10 @@ import (
"github.com/spf13/cobra"
)

const defaultBadgeAuditLimit = 10
const (
defaultBadgeAuditLimit = 10
maxMetadataHashesDisplay = 20
)

func buildBadgeCmd() *cobra.Command {
var (
Expand Down Expand Up @@ -87,24 +91,33 @@ func runBadgeWithParams(agentID string, auditTrail, checkpoint bool, transparenc

func outputBadgeJSON(ctx context.Context, c *ans.TransparencyClient, agentID string, logEntry *models.TransparencyLog, auditTrail, checkpoint bool) {
result := map[string]any{"transparencyLog": logEntry}
var warnings []string

if auditTrail {
params := &models.AgentAuditParams{Limit: defaultBadgeAuditLimit, Offset: 0}
if audit, auditErr := c.GetAgentTransparencyLogAudit(ctx, agentID, params); auditErr != nil {
fmt.Fprintf(os.Stdout, "Warning: failed to retrieve audit trail: %v\n", auditErr)
warning := fmt.Sprintf("failed to retrieve audit trail: %v", auditErr)
fmt.Fprintln(os.Stderr, "Warning: "+warning)
warnings = append(warnings, warning)
} else {
result["audit"] = audit
}
}

if checkpoint {
if checkpointData, checkpointErr := c.GetCheckpoint(ctx); checkpointErr != nil {
fmt.Fprintf(os.Stdout, "Warning: failed to retrieve checkpoint: %v\n", checkpointErr)
warning := fmt.Sprintf("failed to retrieve checkpoint: %v", checkpointErr)
fmt.Fprintln(os.Stderr, "Warning: "+warning)
warnings = append(warnings, warning)
} else {
result["checkpoint"] = checkpointData
}
}

if len(warnings) > 0 {
result["warnings"] = warnings
}

jsonData, _ := json.MarshalIndent(result, "", " ")
fmt.Fprintln(os.Stdout, string(jsonData))
}
Expand Down Expand Up @@ -324,8 +337,49 @@ func printV1Attestations(att *models.AttestationsV1) {

if len(att.DNSRecordsProvisioned) > 0 {
fmt.Fprintln(os.Stdout, " DNS Records Provisioned:")
for key, value := range att.DNSRecordsProvisioned {
fmt.Fprintf(os.Stdout, " %s: %s\n", key, value)
keys := make([]string, 0, len(att.DNSRecordsProvisioned))
for k := range att.DNSRecordsProvisioned {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(os.Stdout, " %s: %s\n", k, att.DNSRecordsProvisioned[k])
}
}

if len(att.MetadataHashes) > 0 {
fmt.Fprintln(os.Stdout, " Metadata Hashes:")
keys := make([]string, 0, len(att.MetadataHashes))
for k := range att.MetadataHashes {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
if i >= maxMetadataHashesDisplay {
fmt.Fprintf(os.Stdout, " ... and %d more\n", len(keys)-maxMetadataHashesDisplay)
break
}
fmt.Fprintf(os.Stdout, " %s: %s\n", k, att.MetadataHashes[k])
}
}

if len(att.ValidIdentityCerts) > 0 {
fmt.Fprintf(os.Stdout, " Valid Identity Certs: %d\n", len(att.ValidIdentityCerts))
for _, cert := range att.ValidIdentityCerts {
fmt.Fprintf(os.Stdout, " %s (%s)\n", cert.Fingerprint, cert.Type)
if cert.NotAfter != nil {
fmt.Fprintf(os.Stdout, " Not After: %s\n", cert.NotAfter.Format("2006-01-02 15:04:05 MST"))
}
}
}

if len(att.ValidServerCerts) > 0 {
fmt.Fprintf(os.Stdout, " Valid Server Certs: %d\n", len(att.ValidServerCerts))
for _, cert := range att.ValidServerCerts {
fmt.Fprintf(os.Stdout, " %s (%s)\n", cert.Fingerprint, cert.Type)
if cert.NotAfter != nil {
fmt.Fprintf(os.Stdout, " Not After: %s\n", cert.NotAfter.Format("2006-01-02 15:04:05 MST"))
}
}
}
}
Expand Down Expand Up @@ -363,8 +417,13 @@ func printV0Attestations(att *models.AttestationsV0) {

if len(att.DNSRecordsProvisioned) > 0 {
fmt.Fprintln(os.Stdout, " DNS Records Provisioned:")
for key, value := range att.DNSRecordsProvisioned {
fmt.Fprintf(os.Stdout, " %s: %s\n", key, value)
keys := make([]string, 0, len(att.DNSRecordsProvisioned))
for k := range att.DNSRecordsProvisioned {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(os.Stdout, " %s: %s\n", k, att.DNSRecordsProvisioned[k])
}
}
}
Expand Down Expand Up @@ -440,8 +499,13 @@ func printAttestationsFromPayload(att map[string]any) {

if dnsRecords, dnsOk := att["dnsRecordsProvisioned"].(map[string]any); dnsOk && len(dnsRecords) > 0 {
fmt.Fprintln(os.Stdout, " DNS Records Provisioned:")
for key, value := range dnsRecords {
fmt.Fprintf(os.Stdout, " %s: %v\n", key, value)
keys := make([]string, 0, len(dnsRecords))
for k := range dnsRecords {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(os.Stdout, " %s: %v\n", k, dnsRecords[k])
}
}
}
Expand Down
23 changes: 23 additions & 0 deletions cmd/ans-cli/cmd/badge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,26 @@ func TestPrintV1Payload(t *testing.T) {
DNSRecordsProvisioned: map[string]string{
"_ans-badge": "v=ans-badge1",
},
MetadataHashes: map[string]string{
"sha256": "deadbeef1234",
},
ValidIdentityCerts: []models.CertificateV1Extended{
{
CertificateV1: models.CertificateV1{
Fingerprint: "SHA256:idcert1",
Type: models.CertTypeX509OVClient,
},
NotAfter: &expiresAt,
},
},
ValidServerCerts: []models.CertificateV1Extended{
{
CertificateV1: models.CertificateV1{
Fingerprint: "SHA256:srvcert1",
Type: models.CertTypeX509DVServer,
},
},
},
},
},
},
Expand All @@ -346,6 +366,9 @@ func TestPrintV1Payload(t *testing.T) {
"Agent Info", "example.com", "v1.0.0", "Test Agent", "provider-123",
"Attestations", "ACME-DNS-01", "Identity Certificate", "Server Certificate",
"DNS Records", "Signature Info", "key-123",
"Metadata Hashes", "deadbeef1234",
"Valid Identity Certs", "SHA256:idcert1", "Not After",
"Valid Server Certs", "SHA256:srvcert1",
},
},
}
Expand Down
52 changes: 52 additions & 0 deletions cmd/ans-cli/cmd/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -671,6 +672,57 @@ func TestRunRegisterWithParams(t *testing.T) {
}
}

func TestRunBadgeWithParams_JSONOutput(t *testing.T) {
tests := []struct {
name string
audit bool
checkpoint bool
}{
{
name: "JSON with audit and checkpoint success",
audit: true,
checkpoint: true,
},
{
name: "JSON no extras",
audit: false,
checkpoint: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch {
case r.URL.Path == "/v1/agents/agent-123" && r.Method == http.MethodGet && r.URL.RawQuery == "":
json.NewEncoder(w).Encode(models.TransparencyLog{
Status: "ACTIVE",
SchemaVersion: "V1",
Payload: map[string]any{"logId": "test"},
})
case strings.Contains(r.URL.Path, "/audit"):
json.NewEncoder(w).Encode(models.TransparencyLogAudit{Records: []models.TransparencyLog{}})
case strings.Contains(r.URL.Path, "/checkpoint"):
json.NewEncoder(w).Encode(models.CheckpointResponse{LogSize: 100})
default:
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{})
}
}))
defer server.Close()

setupViperForTest(t, server.URL)
viper.Set("json", true)

err := runBadgeWithParams("agent-123", tt.audit, tt.checkpoint, server.URL)
if err != nil {
t.Errorf("runBadgeWithParams() error = %v", err)
}
})
}
}

func TestRunBadgeWithParams_ServerErrors(t *testing.T) {
tests := []struct {
name string
Expand Down
11 changes: 8 additions & 3 deletions cmd/ans-cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ import (
"github.com/spf13/viper"
)

// Execute runs the root command
func Execute() {
// Run builds and executes the root command, returning any error.
func Run() error {
rootCmd := buildRootCmd()
if err := rootCmd.Execute(); err != nil {
return rootCmd.Execute()
}

// Execute runs the root command and exits on error.
func Execute() {
if err := Run(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
Expand Down
Loading