Skip to content

Commit 0716dab

Browse files
committed
fix(go-sdk): implement quote hex decoding and env var encryption
1 parent dc69453 commit 0716dab

File tree

3 files changed

+270
-39
lines changed

3 files changed

+270
-39
lines changed

sdk/go/dstack/client.go

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ package dstack
99
import (
1010
"bytes"
1111
"context"
12+
"crypto/ecdsa"
13+
"crypto/ed25519"
1214
"crypto/sha512"
15+
"crypto/x509"
1316
"encoding/hex"
1417
"encoding/json"
18+
"encoding/pem"
1519
"fmt"
1620
"io"
1721
"log/slog"
@@ -30,25 +34,30 @@ type GetTlsKeyResponse struct {
3034

3135
// AsUint8Array converts the private key to bytes, optionally limiting the length
3236
func (r *GetTlsKeyResponse) AsUint8Array(maxLength ...int) ([]byte, error) {
33-
content := r.Key
34-
content = strings.Replace(content, "-----BEGIN PRIVATE KEY-----", "", 1)
35-
content = strings.Replace(content, "-----END PRIVATE KEY-----", "", 1)
36-
content = strings.Replace(content, "\n", "", -1)
37-
content = strings.Replace(content, " ", "", -1)
38-
39-
// For now, assume base64 encoding - would need actual implementation
40-
// This is a placeholder that matches the JavaScript version behavior
41-
if len(maxLength) > 0 && maxLength[0] > 0 {
42-
result := make([]byte, maxLength[0])
43-
// For testing, return a fixed pattern
44-
for i := 0; i < maxLength[0] && i < len(content); i++ {
45-
result[i] = byte(i % 256)
46-
}
47-
return result, nil
37+
block, _ := pem.Decode([]byte(r.Key))
38+
if block == nil {
39+
return nil, fmt.Errorf("failed to decode pem private key")
4840
}
4941

50-
// Return content as bytes for testing
51-
return []byte(content), nil
42+
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to parse private key: %w", err)
45+
}
46+
47+
var keyBytes []byte
48+
switch k := key.(type) {
49+
case *ecdsa.PrivateKey:
50+
keyBytes = k.D.FillBytes(make([]byte, (k.Curve.Params().N.BitLen()+7)/8))
51+
case ed25519.PrivateKey:
52+
keyBytes = k.Seed()
53+
default:
54+
return nil, fmt.Errorf("unsupported key type: %T", key)
55+
}
56+
57+
if len(maxLength) > 0 && maxLength[0] > 0 && maxLength[0] < len(keyBytes) {
58+
return keyBytes[:maxLength[0]], nil
59+
}
60+
return keyBytes, nil
5261
}
5362

5463
// Represents the response from a key derivation request.
@@ -77,12 +86,22 @@ func (r *GetKeyResponse) DecodeSignatureChain() ([][]byte, error) {
7786

7887
// Represents the response from a quote request.
7988
type GetQuoteResponse struct {
80-
Quote []byte `json:"quote"`
89+
Quote string `json:"quote"`
8190
EventLog string `json:"event_log"`
82-
ReportData []byte `json:"report_data"`
91+
ReportData string `json:"report_data"`
8392
VmConfig string `json:"vm_config"`
8493
}
8594

95+
// DecodeQuote returns the quote bytes
96+
func (r *GetQuoteResponse) DecodeQuote() ([]byte, error) {
97+
return hex.DecodeString(r.Quote)
98+
}
99+
100+
// DecodeReportData returns the report data bytes
101+
func (r *GetQuoteResponse) DecodeReportData() ([]byte, error) {
102+
return hex.DecodeString(r.ReportData)
103+
}
104+
86105
// DecodeEventLog returns the event log as structured data
87106
func (r *GetQuoteResponse) DecodeEventLog() ([]EventLog, error) {
88107
var events []EventLog
@@ -106,17 +125,20 @@ type EventLog struct {
106125

107126
// Represents the TCB information
108127
type TcbInfo struct {
109-
Mrtd string `json:"mrtd"`
110-
Rtmr0 string `json:"rtmr0"`
111-
Rtmr1 string `json:"rtmr1"`
112-
Rtmr2 string `json:"rtmr2"`
113-
Rtmr3 string `json:"rtmr3"`
114-
// The hash of the OS image. This is empty if the OS image is not measured by KMS.
115-
OsImageHash string `json:"os_image_hash,omitempty"`
116-
ComposeHash string `json:"compose_hash"`
117-
DeviceID string `json:"device_id"`
118-
AppCompose string `json:"app_compose"`
119-
EventLog []EventLog `json:"event_log"`
128+
Mrtd string `json:"mrtd"`
129+
Rtmr0 string `json:"rtmr0"`
130+
Rtmr1 string `json:"rtmr1"`
131+
Rtmr2 string `json:"rtmr2"`
132+
Rtmr3 string `json:"rtmr3"`
133+
AppCompose string `json:"app_compose"`
134+
EventLog []EventLog `json:"event_log"`
135+
// V0.3.x fields
136+
RootfsHash string `json:"rootfs_hash,omitempty"`
137+
// V0.5.x fields
138+
MrAggregated string `json:"mr_aggregated,omitempty"`
139+
OsImageHash string `json:"os_image_hash,omitempty"`
140+
ComposeHash string `json:"compose_hash,omitempty"`
141+
DeviceID string `json:"device_id,omitempty"`
120142
}
121143

122144
// Represents the response from an info request
@@ -130,9 +152,11 @@ type InfoResponse struct {
130152
MrAggregated string `json:"mr_aggregated,omitempty"`
131153
KeyProviderInfo string `json:"key_provider_info"`
132154
// Optional: empty if OS image is not measured by KMS
133-
OsImageHash string `json:"os_image_hash,omitempty"`
134-
ComposeHash string `json:"compose_hash"`
135-
VmConfig string `json:"vm_config,omitempty"`
155+
OsImageHash string `json:"os_image_hash,omitempty"`
156+
ComposeHash string `json:"compose_hash"`
157+
VmConfig string `json:"vm_config,omitempty"`
158+
CloudVendor string `json:"cloud_vendor,omitempty"`
159+
CloudProduct string `json:"cloud_product,omitempty"`
136160
}
137161

138162
// DecodeTcbInfo decodes the TcbInfo string into a TcbInfo struct
@@ -347,6 +371,9 @@ type tlsKeyOptions struct {
347371
usageRaTls bool
348372
usageServerAuth bool
349373
usageClientAuth bool
374+
notBefore *uint64
375+
notAfter *uint64
376+
withAppInfo *bool
350377
}
351378

352379
// WithSubject sets the subject for the TLS key
@@ -384,6 +411,27 @@ func WithUsageClientAuth(usage bool) TlsKeyOption {
384411
}
385412
}
386413

414+
// WithNotBefore sets the not_before timestamp for the certificate
415+
func WithNotBefore(t uint64) TlsKeyOption {
416+
return func(opts *tlsKeyOptions) {
417+
opts.notBefore = &t
418+
}
419+
}
420+
421+
// WithNotAfter sets the not_after timestamp for the certificate
422+
func WithNotAfter(t uint64) TlsKeyOption {
423+
return func(opts *tlsKeyOptions) {
424+
opts.notAfter = &t
425+
}
426+
}
427+
428+
// WithAppInfo sets the with_app_info flag for the certificate
429+
func WithAppInfo(enabled bool) TlsKeyOption {
430+
return func(opts *tlsKeyOptions) {
431+
opts.withAppInfo = &enabled
432+
}
433+
}
434+
387435
// Gets a TLS key from the dstack service with optional parameters.
388436
func (c *DstackClient) GetTlsKey(
389437
ctx context.Context,
@@ -406,6 +454,15 @@ func (c *DstackClient) GetTlsKey(
406454
if len(opts.altNames) > 0 {
407455
payload["alt_names"] = opts.altNames
408456
}
457+
if opts.notBefore != nil {
458+
payload["not_before"] = *opts.notBefore
459+
}
460+
if opts.notAfter != nil {
461+
payload["not_after"] = *opts.notAfter
462+
}
463+
if opts.withAppInfo != nil {
464+
payload["with_app_info"] = *opts.withAppInfo
465+
}
409466

410467
data, err := c.sendRPCRequest(ctx, "/GetTlsKey", payload)
411468
if err != nil {
@@ -684,7 +741,7 @@ type TappdClient struct {
684741
func NewTappdClient(opts ...DstackClientOption) *TappdClient {
685742
// Create a modified option to use TAPPD_SIMULATOR_ENDPOINT
686743
tappdOpts := make([]DstackClientOption, 0, len(opts)+1)
687-
744+
688745
// Add default endpoint option that checks TAPPD_SIMULATOR_ENDPOINT
689746
tappdOpts = append(tappdOpts, func(c *DstackClient) {
690747
if c.endpoint == "" {
@@ -696,13 +753,13 @@ func NewTappdClient(opts ...DstackClientOption) *TappdClient {
696753
}
697754
}
698755
})
699-
756+
700757
// Add user-provided options
701758
tappdOpts = append(tappdOpts, opts...)
702-
759+
703760
client := NewDstackClient(tappdOpts...)
704761
client.logger.Warn("TappdClient is deprecated, please use DstackClient instead")
705-
762+
706763
return &TappdClient{
707764
DstackClient: client,
708765
}
@@ -714,7 +771,7 @@ func NewTappdClient(opts ...DstackClientOption) *TappdClient {
714771
// Deprecated: Use GetKey instead.
715772
func (tc *TappdClient) DeriveKey(ctx context.Context, path string, subject string, altNames []string) (*GetTlsKeyResponse, error) {
716773
tc.logger.Warn("deriveKey is deprecated, please use GetKey instead")
717-
774+
718775
if subject == "" {
719776
subject = path
720777
}
@@ -743,7 +800,7 @@ func (tc *TappdClient) DeriveKey(ctx context.Context, path string, subject strin
743800
// Deprecated: Use GetQuote instead.
744801
func (tc *TappdClient) TdxQuote(ctx context.Context, reportData []byte, hashAlgorithm string) (*GetQuoteResponse, error) {
745802
tc.logger.Warn("tdxQuote is deprecated, please use GetQuote instead")
746-
803+
747804
if hashAlgorithm == "raw" {
748805
if len(reportData) > 64 {
749806
return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is raw")

sdk/go/dstack/encrypt_env_vars.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SPDX-FileCopyrightText: © 2025 Phala Network <dstack@phala.network>
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
package dstack
6+
7+
import (
8+
"crypto/aes"
9+
"crypto/cipher"
10+
"crypto/rand"
11+
"encoding/hex"
12+
"encoding/json"
13+
"fmt"
14+
"strings"
15+
16+
"golang.org/x/crypto/curve25519"
17+
)
18+
19+
// EnvVar represents an environment variable key-value pair.
20+
type EnvVar struct {
21+
Key string `json:"key"`
22+
Value string `json:"value"`
23+
}
24+
25+
// EncryptEnvVars encrypts environment variables using X25519 ECDH + AES-256-GCM.
26+
//
27+
// publicKeyHex is the remote X25519 public key (hex-encoded, with or without 0x prefix).
28+
// Returns hex(ephemeral_pubkey || iv || ciphertext).
29+
func EncryptEnvVars(envs []EnvVar, publicKeyHex string) (string, error) {
30+
cleanHex := strings.TrimPrefix(publicKeyHex, "0x")
31+
remotePubKey, err := hex.DecodeString(cleanHex)
32+
if err != nil {
33+
return "", fmt.Errorf("failed to decode public key: %w", err)
34+
}
35+
if len(remotePubKey) != 32 {
36+
return "", fmt.Errorf("invalid public key length: expected 32 bytes, got %d", len(remotePubKey))
37+
}
38+
39+
envJSON, err := json.Marshal(struct {
40+
Env []EnvVar `json:"env"`
41+
}{Env: envs})
42+
if err != nil {
43+
return "", fmt.Errorf("failed to marshal env vars: %w", err)
44+
}
45+
46+
ephemeralPrivKey := make([]byte, 32)
47+
if _, err := rand.Read(ephemeralPrivKey); err != nil {
48+
return "", fmt.Errorf("failed to generate ephemeral private key: %w", err)
49+
}
50+
51+
ephemeralPubKey, err := curve25519.X25519(ephemeralPrivKey, curve25519.Basepoint)
52+
if err != nil {
53+
return "", fmt.Errorf("failed to derive ephemeral public key: %w", err)
54+
}
55+
56+
sharedSecret, err := curve25519.X25519(ephemeralPrivKey, remotePubKey)
57+
if err != nil {
58+
return "", fmt.Errorf("failed to derive shared secret: %w", err)
59+
}
60+
61+
block, err := aes.NewCipher(sharedSecret)
62+
if err != nil {
63+
return "", fmt.Errorf("failed to create aes cipher: %w", err)
64+
}
65+
gcm, err := cipher.NewGCM(block)
66+
if err != nil {
67+
return "", fmt.Errorf("failed to create aes-gcm: %w", err)
68+
}
69+
70+
iv := make([]byte, 12)
71+
if _, err := rand.Read(iv); err != nil {
72+
return "", fmt.Errorf("failed to generate iv: %w", err)
73+
}
74+
75+
ciphertext := gcm.Seal(nil, iv, envJSON, nil)
76+
result := make([]byte, 0, len(ephemeralPubKey)+len(iv)+len(ciphertext))
77+
result = append(result, ephemeralPubKey...)
78+
result = append(result, iv...)
79+
result = append(result, ciphertext...)
80+
81+
return hex.EncodeToString(result), nil
82+
}

0 commit comments

Comments
 (0)