diff --git a/Cargo.lock b/Cargo.lock index e9083251..d58bf02b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,6 +544,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arc-swap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] + [[package]] name = "ark-ff" version = "0.3.0" @@ -984,6 +993,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bindgen" version = "0.71.1" @@ -1320,9 +1349,12 @@ version = "0.5.5" dependencies = [ "anyhow", "bon", + "bytes", "enum_dispatch", "fs-err", "hickory-resolver 0.24.4", + "http", + "http-body-util", "instant-acme", "path-absolutize", "rand 0.8.5", @@ -2196,6 +2228,7 @@ name = "dstack-gateway" version = "0.5.5" dependencies = [ "anyhow", + "arc-swap", "bytes", "certbot", "clap", @@ -2204,13 +2237,16 @@ dependencies = [ "dstack-guest-agent-rpc", "dstack-kms-rpc", "dstack-types", + "flate2", "fs-err", "futures", "git-version", "hex", "hickory-resolver 0.24.4", + "http-body-util", "http-client", "hyper", + "hyper-rustls", "hyper-util", "insta", "ipnet", @@ -2225,6 +2261,7 @@ dependencies = [ "rand 0.8.5", "reqwest", "rinja", + "rmp-serde", "rocket", "rustls", "safe-write", @@ -2234,10 +2271,15 @@ dependencies = [ "sha2 0.10.9", "shared_child", "smallvec", + 
"tdx-attest", + "tempfile", "tokio", "tokio-rustls", "tracing", "tracing-subscriber", + "uuid", + "wavekv", + "x509-parser", ] [[package]] @@ -4292,7 +4334,7 @@ checksum = "2044d8bd5489b199890c3dbf38d4c8f50f3a5a38833986808b14e2367fe267fa" dependencies = [ "aes 0.7.5", "base64 0.13.1", - "bincode", + "bincode 1.3.3", "crossterm", "hmac 0.11.0", "pbkdf2", @@ -5942,6 +5984,28 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rocket" version = "0.6.0-dev" @@ -7807,6 +7871,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "url" version = "2.5.7" @@ -7860,6 +7930,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "void" version = "1.0.2" @@ -7992,6 +8068,29 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wavekv" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "cf9b73bc556dfdb7ef33617a9d477b803198db43ea3df25463efaf43d4986fe8" +dependencies = [ + "anyhow", + "bincode 2.0.1", + "chrono", + "crc32fast", + "dashmap", + "fs-err", + "futures", + "hex", + "rmp-serde", + "serde", + "serde-human-bytes", + "serde_json", + "sha2 0.10.9", + "tokio", + "tracing", +] + [[package]] name = "web-sys" version = "0.3.83" diff --git a/Cargo.toml b/Cargo.toml index d0c07582..a8579b84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,9 +86,11 @@ serde-duration = { path = "serde-duration" } dstack-mr = { path = "dstack-mr" } dstack-verifier = { path = "verifier", default-features = false } size-parser = { path = "size-parser" } +wavekv = "1.0.0" # Core dependencies anyhow = { version = "1.0.97", default-features = false } +arc-swap = "1" or-panic = { version = "1.0", default-features = false } chrono = "0.4.40" clap = { version = "4.5.32", features = ["derive", "string"] } @@ -109,6 +111,7 @@ sd-notify = "0.4.5" jemallocator = "0.5.4" # Serialization/Parsing +flate2 = "1.1" borsh = { version = "1.5.7", default-features = false, features = ["derive"] } bon = { version = "3.4.0", default-features = false } base64 = "0.22.1" @@ -122,6 +125,7 @@ scale = { version = "3.7.4", package = "parity-scale-codec", features = [ ] } serde = { version = "1.0.228", features = ["derive"], default-features = false } serde-human-bytes = "0.1.2" +rmp-serde = "1.3.0" serde_json = { version = "1.0.140", default-features = false } serde_ini = "0.2.0" toml = "0.8.20" @@ -145,6 +149,11 @@ hyper-util = { version = "0.1.10", features = [ "client-legacy", "http1", ] } +hyper-rustls = { version = "0.27", default-features = false, features = [ + "ring", + "http1", + "tls12", +] } hyperlocal = "0.9.1" ipnet = { version = "2.11.0", features = ["serde"] } reqwest = { version = "0.12.14", default-features = false, features = [ @@ -234,7 +243,6 @@ yaml-rust2 = "0.10.4" luks2 = "0.5.0" scopeguard = "1.2.0" -flate2 = "1.1" tar = "0.4" [profile.release] diff --git 
a/REUSE.toml b/REUSE.toml index 18aced85..f2c711f1 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -191,3 +191,12 @@ SPDX-License-Identifier = "CC0-1.0" path = "guest-agent/fixtures/*" SPDX-FileCopyrightText = "NONE" SPDX-License-Identifier = "CC0-1.0" + +[[annotations]] +path = [ + "gateway/test-run/e2e/certs/*", + "gateway/test-run/e2e/configs/*", + "gateway/test-run/e2e/pebble-config.json", +] +SPDX-FileCopyrightText = "NONE" +SPDX-License-Identifier = "CC0-1.0" diff --git a/certbot/Cargo.toml b/certbot/Cargo.toml index 52c49bea..eaf594cf 100644 --- a/certbot/Cargo.toml +++ b/certbot/Cargo.toml @@ -12,9 +12,12 @@ license.workspace = true [dependencies] anyhow.workspace = true bon.workspace = true +bytes.workspace = true enum_dispatch.workspace = true fs-err.workspace = true hickory-resolver.workspace = true +http.workspace = true +http-body-util.workspace = true instant-acme.workspace = true path-absolutize.workspace = true rcgen.workspace = true diff --git a/certbot/cli/src/main.rs b/certbot/cli/src/main.rs index b22d3246..b30eade3 100644 --- a/certbot/cli/src/main.rs +++ b/certbot/cli/src/main.rs @@ -164,7 +164,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } rustls::crypto::ring::default_provider() .install_default() diff --git a/certbot/src/acme_client.rs b/certbot/src/acme_client.rs index 50ec4589..fd32c2ad 100644 --- a/certbot/src/acme_client.rs +++ b/certbot/src/acme_client.rs @@ -21,6 +21,7 @@ use tracing::{debug, error, info}; use x509_parser::prelude::{GeneralName, Pem}; use super::dns01_client::{Dns01Api, Dns01Client}; +use super::http_client::ReqwestHttpClient; /// A AcmeClient instance. 
pub struct AcmeClient { @@ -63,7 +64,9 @@ impl AcmeClient { dns_txt_ttl: u32, ) -> Result { let credentials: Credentials = serde_json::from_str(encoded_credentials)?; - let account = Account::from_credentials(credentials.credentials).await?; + let http_client = Box::new(ReqwestHttpClient::new()?); + let account = + Account::from_credentials_and_http(credentials.credentials, http_client).await?; let credentials: Credentials = serde_json::from_str(encoded_credentials)?; Ok(Self { account, @@ -81,7 +84,8 @@ impl AcmeClient { max_dns_wait: Duration, dns_txt_ttl: u32, ) -> Result { - let (account, credentials) = Account::create( + let http_client = Box::new(ReqwestHttpClient::new()?); + let (account, credentials) = Account::create_with_http( &NewAccount { contact: &[], terms_of_service_agreed: true, @@ -89,6 +93,7 @@ impl AcmeClient { }, acme_url, None, + http_client, ) .await .with_context(|| format!("failed to create ACME account for {acme_url}"))?; diff --git a/certbot/src/acme_client/tests.rs b/certbot/src/acme_client/tests.rs index d77504a6..1481538c 100644 --- a/certbot/src/acme_client/tests.rs +++ b/certbot/src/acme_client/tests.rs @@ -10,6 +10,7 @@ async fn new_acme_client() -> Result { let dns01_client = Dns01Client::new_cloudflare( std::env::var("CLOUDFLARE_ZONE_ID").expect("CLOUDFLARE_ZONE_ID not set"), std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), + std::env::var("CLOUDFLARE_API_URL").ok(), ); let credentials = std::env::var("LETSENCRYPT_CREDENTIAL").expect("LETSENCRYPT_CREDENTIAL not set"); diff --git a/certbot/src/bot.rs b/certbot/src/bot.rs index 1b59b767..76cd8491 100644 --- a/certbot/src/bot.rs +++ b/certbot/src/bot.rs @@ -28,6 +28,7 @@ pub struct CertBotConfig { credentials_file: PathBuf, auto_create_account: bool, cf_api_token: String, + cf_api_url: Option, cert_file: PathBuf, key_file: PathBuf, cert_dir: PathBuf, @@ -94,8 +95,12 @@ impl CertBot { .trim_start_matches("*.") .trim_end_matches('.') .to_string(); - let 
dns01_client = - Dns01Client::new_cloudflare(config.cf_api_token.clone(), base_domain).await?; + let dns01_client = Dns01Client::new_cloudflare( + base_domain, + config.cf_api_token.clone(), + config.cf_api_url.clone(), + ) + .await?; let acme_client = match fs::read_to_string(&config.credentials_file) { Ok(credentials) => { if acme_matches(&credentials, &config.acme_url) { diff --git a/certbot/src/dns01_client.rs b/certbot/src/dns01_client.rs index b4d4aeaa..88fbf91a 100644 --- a/certbot/src/dns01_client.rs +++ b/certbot/src/dns01_client.rs @@ -72,9 +72,12 @@ pub enum Dns01Client { } impl Dns01Client { - pub async fn new_cloudflare(api_token: String, base_domain: String) -> Result { - Ok(Self::Cloudflare( - CloudflareClient::new(api_token, base_domain).await?, - )) + pub async fn new_cloudflare( + base_domain: String, + api_token: String, + api_url: Option, + ) -> Result { + let client = CloudflareClient::new(base_domain, api_token, api_url).await?; + Ok(Self::Cloudflare(client)) } } diff --git a/certbot/src/dns01_client/cloudflare.rs b/certbot/src/dns01_client/cloudflare.rs index d7a6b1f5..620defb0 100644 --- a/certbot/src/dns01_client/cloudflare.rs +++ b/certbot/src/dns01_client/cloudflare.rs @@ -14,12 +14,18 @@ use crate::dns01_client::Record; use super::Dns01Api; -const CLOUDFLARE_API_URL: &str = "https://api.cloudflare.com/client/v4"; +const DEFAULT_CLOUDFLARE_API_URL: &str = "https://api.cloudflare.com/client/v4"; #[derive(Debug, Serialize, Deserialize)] pub struct CloudflareClient { zone_id: String, api_token: String, + #[serde(default = "default_api_url")] + api_url: String, +} + +fn default_api_url() -> String { + DEFAULT_CLOUDFLARE_API_URL.to_string() } #[derive(Deserialize)] @@ -59,12 +65,21 @@ struct ZonesResultInfo { } impl CloudflareClient { - pub async fn new(api_token: String, base_domain: String) -> Result { - let zone_id = Self::resolve_zone_id(&api_token, &base_domain).await?; - Ok(Self { api_token, zone_id }) + pub async fn new( + base_domain: 
String, + api_token: String, + api_url: Option, + ) -> Result { + let api_url = api_url.unwrap_or_else(|| DEFAULT_CLOUDFLARE_API_URL.to_string()); + let zone_id = Self::resolve_zone_id(&api_token, &base_domain, &api_url).await?; + Ok(Self { + zone_id, + api_token, + api_url, + }) } - async fn resolve_zone_id(api_token: &str, base_domain: &str) -> Result { + async fn resolve_zone_id(api_token: &str, base_domain: &str, api_url: &str) -> Result { let base = base_domain .trim() .trim_start_matches("*.") @@ -72,7 +87,7 @@ impl CloudflareClient { .to_lowercase(); let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones"); + let url = format!("{api_url}/zones"); let per_page = 50u32; let mut page = 1u32; @@ -150,8 +165,7 @@ impl CloudflareClient { async fn add_record(&self, record: &impl Serialize) -> Result { let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones/{}/dns_records", self.zone_id); - + let url = format!("{}/zones/{}/dns_records", self.api_url, self.zone_id); let response = client .post(&url) .header("Authorization", format!("Bearer {}", self.api_token)) @@ -176,8 +190,8 @@ impl CloudflareClient { async fn remove_record_inner(&self, record_id: &str) -> Result<()> { let client = Client::new(); let url = format!( - "{CLOUDFLARE_API_URL}/zones/{zone_id}/dns_records/{record_id}", - zone_id = self.zone_id + "{}/zones/{}/dns_records/{}", + self.api_url, self.zone_id, record_id ); debug!(url = %url, "cloudflare remove_record request"); @@ -201,7 +215,7 @@ impl CloudflareClient { async fn get_records_inner(&self, domain: &str) -> Result> { let client = Client::new(); - let url = format!("{CLOUDFLARE_API_URL}/zones/{}/dns_records", self.zone_id); + let url = format!("{}/zones/{}/dns_records", self.api_url, self.zone_id); let per_page = 100u32; let mut records = Vec::new(); @@ -338,8 +352,9 @@ mod tests { async fn create_client() -> CloudflareClient { CloudflareClient::new( - 
std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), std::env::var("TEST_DOMAIN").expect("TEST_DOMAIN not set"), + std::env::var("CLOUDFLARE_API_TOKEN").expect("CLOUDFLARE_API_TOKEN not set"), + std::env::var("CLOUDFLARE_API_URL").ok(), ) .await .unwrap() diff --git a/certbot/src/http_client.rs b/certbot/src/http_client.rs new file mode 100644 index 00000000..2de8f823 --- /dev/null +++ b/certbot/src/http_client.rs @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Custom HTTP client for instant_acme that supports both HTTP and HTTPS. + +use anyhow::{Context, Result}; +use bytes::Bytes; +use http::Request; +use http_body_util::{BodyExt, Full}; +use instant_acme::{BytesResponse, HttpClient}; +use reqwest::Client; +use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; + +/// A HTTP client that supports both HTTP and HTTPS connections. +/// This is needed because the default instant_acme client only supports HTTPS. +#[derive(Clone)] +pub struct ReqwestHttpClient { + client: Client, +} + +impl ReqwestHttpClient { + /// Create a new HTTP client. + pub fn new() -> Result { + let client = Client::builder() + .user_agent("dstack-certbot/0.1") + .build() + .context("failed to build reqwest client")?; + Ok(Self { client }) + } +} + +impl HttpClient for ReqwestHttpClient { + fn request( + &self, + req: Request>, + ) -> Pin> + Send>> { + let client = self.client.clone(); + Box::pin(async move { + let (parts, body) = req.into_parts(); + let uri = parts.uri.to_string(); + let method = parts.method.clone(); + let body_bytes = body + .collect() + .await + .map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })? 
+ .to_bytes(); + + tracing::debug!( + target: "certbot::http_client", + %uri, + %method, + request_body_len = body_bytes.len(), + "sending ACME request" + ); + + let mut builder = client.request(parts.method, uri.clone()); + for (name, value) in &parts.headers { + builder = builder.header(name, value); + } + + let response = builder + .body(body_bytes.to_vec()) + .send() + .await + .map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })?; + + let status = response.status(); + let headers = response.headers().clone(); + let body = response.bytes().await.map_err(|e| { + instant_acme::Error::Other(Box::new(e) as Box) + })?; + + tracing::debug!( + target: "certbot::http_client", + %uri, + %status, + response_body = %String::from_utf8_lossy(&body), + "received ACME response" + ); + + let mut http_response = http::Response::builder().status(status); + for (name, value) in headers { + if let Some(name) = name { + http_response = http_response.header(name, value); + } + } + let http_response = http_response + .body(Full::new(body)) + .map_err(|e| instant_acme::Error::Other(Box::new(e)))?; + + Ok(BytesResponse::from(http_response)) + }) + } +} diff --git a/certbot/src/lib.rs b/certbot/src/lib.rs index 20cf8ed1..df71b993 100644 --- a/certbot/src/lib.rs +++ b/certbot/src/lib.rs @@ -24,4 +24,5 @@ pub use workdir::WorkDir; mod acme_client; mod bot; mod dns01_client; +mod http_client; mod workdir; diff --git a/ct_monitor/src/main.rs b/ct_monitor/src/main.rs index bfa0565c..dd5d9550 100644 --- a/ct_monitor/src/main.rs +++ b/ct_monitor/src/main.rs @@ -413,7 +413,7 @@ async fn main() -> anyhow::Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); let mut monitor = Monitor::new(args.gateway, args.verifier_url, args.pccs_url)?; diff --git 
a/dstack-util/src/main.rs b/dstack-util/src/main.rs index a5cd6d8d..1c552936 100644 --- a/dstack-util/src/main.rs +++ b/dstack-util/src/main.rs @@ -453,7 +453,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let cli = Cli::parse(); diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 8a30c6da..87b2b1e0 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -10,7 +10,8 @@ edition.workspace = true license.workspace = true [dependencies] -rocket = { workspace = true, features = ["mtls"] } +arc-swap.workspace = true +rocket = { workspace = true, features = ["mtls", "json"] } tracing.workspace = true tracing-subscriber.workspace = true anyhow.workspace = true @@ -48,8 +49,16 @@ dstack-types.workspace = true serde-duration.workspace = true reqwest = { workspace = true, features = ["json"] } hyper = { workspace = true, features = ["server", "http1"] } -hyper-util = { version = "0.1", features = ["tokio"] } +hyper-util = { workspace = true, features = ["tokio"] } +hyper-rustls.workspace = true +http-body-util.workspace = true +x509-parser.workspace = true jemallocator.workspace = true +wavekv.workspace = true +tdx-attest.workspace = true +flate2.workspace = true +uuid = { workspace = true, features = ["v4"] } +rmp-serde.workspace = true or-panic.workspace = true [target.'cfg(unix)'.dependencies] @@ -57,3 +66,4 @@ nix = { workspace = true, features = ["resource"] } [dev-dependencies] insta.workspace = true +tempfile.workspace = true diff --git a/gateway/dstack-app/builder/entrypoint.sh b/gateway/dstack-app/builder/entrypoint.sh index 9cd46755..61a662fa 100755 --- a/gateway/dstack-app/builder/entrypoint.sh +++ b/gateway/dstack-app/builder/entrypoint.sh @@ -50,6 +50,13 @@ validate_env "$BOOTNODE_URL" validate_env "$CF_API_TOKEN" validate_env 
"$SRV_DOMAIN" validate_env "$WG_ENDPOINT" +validate_env "$NODE_ID" + +# Validate $NODE_ID, must be a number +if [[ ! "$NODE_ID" =~ ^[0-9]+$ ]]; then + echo "Invalid NODE_ID: $NODE_ID" + exit 1 +fi # Validate $SUBNET_INDEX, valid range is 0-15 if [[ ! "$SUBNET_INDEX" =~ ^[0-9]+$ ]] || [ "$SUBNET_INDEX" -lt 0 ] || [ "$SUBNET_INDEX" -gt 15 ]; then @@ -79,8 +86,7 @@ echo "RPC_DOMAIN: $RPC_DOMAIN" cat >$CONFIG_PATH < node_connections = 2; +} + +// Node status entry +message NodeStatusEntry { + uint32 node_id = 1; + string status = 2; // "up" or "down" +} + +// Get node statuses response +message GetNodeStatusesResponse { + repeated NodeStatusEntry statuses = 1; +} + service Admin { // Get the status of the gateway. rpc Status(google.protobuf.Empty) returns (StatusResponse) {} @@ -192,4 +362,253 @@ service Admin { rpc SetCaa(google.protobuf.Empty) returns (google.protobuf.Empty) {} // Summary API for inspect. rpc GetMeta(google.protobuf.Empty) returns (GetMetaResponse) {} + // Set a node's sync URL - used for dynamic peer management + rpc SetNodeUrl(SetNodeUrlRequest) returns (google.protobuf.Empty) {} + // Set a node's status (up/down) + rpc SetNodeStatus(SetNodeStatusRequest) returns (google.protobuf.Empty) {} + // Get WaveKV sync status + rpc WaveKvStatus(google.protobuf.Empty) returns (WaveKvStatusResponse) {} + // Get instance handshakes from all nodes + rpc GetInstanceHandshakes(GetInstanceHandshakesRequest) returns (GetInstanceHandshakesResponse) {} + // Get global connections statistics + rpc GetGlobalConnections(google.protobuf.Empty) returns (GlobalConnectionsStats) {} + // Get all node statuses + rpc GetNodeStatuses(google.protobuf.Empty) returns (GetNodeStatusesResponse) {} + + // ==================== DNS Credential Management ==================== + // List all DNS credentials + rpc ListDnsCredentials(google.protobuf.Empty) returns (ListDnsCredentialsResponse) {} + // Get a DNS credential by ID + rpc GetDnsCredential(GetDnsCredentialRequest) returns 
(DnsCredentialInfo) {} + // Create a new DNS credential + rpc CreateDnsCredential(CreateDnsCredentialRequest) returns (DnsCredentialInfo) {} + // Update a DNS credential + rpc UpdateDnsCredential(UpdateDnsCredentialRequest) returns (DnsCredentialInfo) {} + // Delete a DNS credential + rpc DeleteDnsCredential(DeleteDnsCredentialRequest) returns (google.protobuf.Empty) {} + // Get the default DNS credential ID + rpc GetDefaultDnsCredential(google.protobuf.Empty) returns (GetDefaultDnsCredentialResponse) {} + // Set the default DNS credential ID + rpc SetDefaultDnsCredential(SetDefaultDnsCredentialRequest) returns (google.protobuf.Empty) {} + + // ==================== ZT-Domain Management ==================== + // List all ZT-Domain configurations + rpc ListZtDomains(google.protobuf.Empty) returns (ListZtDomainsResponse) {} + // Get a ZT-Domain configuration and status + rpc GetZtDomain(GetZtDomainRequest) returns (ZtDomainInfo) {} + // Add a new ZT-Domain (config.domain must not exist) + rpc AddZtDomain(ZtDomainConfig) returns (ZtDomainInfo) {} + // Update a ZT-Domain configuration (config.domain must exist) + rpc UpdateZtDomain(ZtDomainConfig) returns (ZtDomainInfo) {} + // Delete a ZT-Domain configuration + rpc DeleteZtDomain(DeleteZtDomainRequest) returns (google.protobuf.Empty) {} + // Manually trigger certificate renewal for a ZT-Domain + rpc RenewZtDomainCert(RenewZtDomainCertRequest) returns (RenewZtDomainCertResponse) {} + // List certificate attestations for a domain + rpc ListCertAttestations(ListCertAttestationsRequest) returns (ListCertAttestationsResponse) {} + + // ==================== Global Certbot Configuration ==================== + // Get global certbot configuration (includes ACME URL) + rpc GetCertbotConfig(google.protobuf.Empty) returns (CertbotConfigResponse) {} + // Set global certbot configuration (includes ACME URL) + rpc SetCertbotConfig(SetCertbotConfigRequest) returns (google.protobuf.Empty) {} +} + +// ==================== DNS Credential 
Messages ==================== + +// DNS credential information +message DnsCredentialInfo { + string id = 1; + string name = 2; + // Provider type: "cloudflare" + string provider_type = 3; + // Cloudflare-specific fields (when provider_type = "cloudflare") + string cf_api_token = 4; + // Cloudflare API URL (empty means default) + string cf_api_url = 5; + // Timestamps + uint64 created_at = 6; + uint64 updated_at = 7; +} + +// List DNS credentials response +message ListDnsCredentialsResponse { + repeated DnsCredentialInfo credentials = 1; + // The default credential ID (if set) + optional string default_id = 2; +} + +// Get DNS credential request +message GetDnsCredentialRequest { + string id = 1; +} + +// Create DNS credential request +message CreateDnsCredentialRequest { + string name = 1; + // Provider type: "cloudflare" + string provider_type = 2; + // Cloudflare-specific fields (when provider_type = "cloudflare") + string cf_api_token = 3; + string cf_zone_id = 4; + // If true, set this as the default credential + bool set_as_default = 5; + // Optional Cloudflare API URL (defaults to https://api.cloudflare.com/client/v4) + optional string cf_api_url = 6; + // Optional Cloudflare DNS TXT record TTL (defaults to 60) + optional uint32 dns_txt_ttl = 7; + // Optional Cloudflare maximum DNS wait time (defaults to 60) + optional uint32 max_dns_wait = 8; +} + +// Update DNS credential request +message UpdateDnsCredentialRequest { + string id = 1; + // Optional new name + optional string name = 2; + // Optional new Cloudflare api token + optional string cf_api_token = 3; + // Optional new Cloudflare zone id + optional string cf_zone_id = 4; + // Optional new Cloudflare API URL + optional string cf_api_url = 5; +} + +// Delete DNS credential request +message DeleteDnsCredentialRequest { + string id = 1; +} + +// Get default DNS credential response +message GetDefaultDnsCredentialResponse { + // The default credential ID (empty if not set) + string default_id = 1; + // 
The default credential info (if exists) + optional DnsCredentialInfo credential = 2; +} + +// Set default DNS credential request +message SetDefaultDnsCredentialRequest { + string id = 1; +} + +// ==================== ZT-Domain Messages ==================== + +// ZT-Domain configuration (shared by Add/Update/Info) +message ZtDomainConfig { + // Base domain name (e.g., "example.com", certificate will be issued for "*.example.com") + string domain = 1; + // DNS credential ID (None = use default) + optional string dns_cred_id = 2; + // Port this domain serves on (e.g., 443) + uint32 port = 3; + // Node binding (None = any node can serve this domain) + optional uint32 node = 4; + // Priority for default base_domain selection (higher = preferred) + int32 priority = 5; +} + +// ZT-Domain information (config + certificate status) +message ZtDomainInfo { + // Domain configuration + ZtDomainConfig config = 1; + // Certificate status + ZtDomainCertStatus cert_status = 2; +} + +// ZT-Domain certificate status +message ZtDomainCertStatus { + // Whether a certificate is currently loaded + bool has_cert = 1; + // Certificate expiry timestamp (0 if no cert) + uint64 not_after = 2; + // Node that issued the current certificate + uint32 issued_by = 3; + // When the certificate was issued + uint64 issued_at = 4; + // Whether the certificate is loaded in memory + bool loaded_in_memory = 5; +} + +// List ZT-Domains response +message ListZtDomainsResponse { + repeated ZtDomainInfo domains = 1; +} + +// Get ZT-Domain request +message GetZtDomainRequest { + string domain = 1; +} + +// Delete ZT-Domain request +message DeleteZtDomainRequest { + string domain = 1; +} + +// Renew ZT-Domain certificate request +message RenewZtDomainCertRequest { + string domain = 1; + // Force renewal even if not near expiry + bool force = 2; +} + +// Renew ZT-Domain certificate response +message RenewZtDomainCertResponse { + // True if renewal was performed + bool renewed = 1; + // New certificate expiry 
(if renewed) + uint64 not_after = 2; +} + +// Certificate attestation info +message CertAttestationInfo { + // Certificate public key (DER encoded) + bytes public_key = 1; + // TDX Quote (JSON serialized) + string quote = 2; + // Node that generated this attestation + uint32 generated_by = 3; + // Timestamp when this attestation was generated + uint64 generated_at = 4; +} + +// List certificate attestations request +message ListCertAttestationsRequest { + string domain = 1; + // Maximum number of attestations to return (0 = all) + uint32 limit = 2; +} + +// List certificate attestations response +message ListCertAttestationsResponse { + // Latest attestation (if exists) + optional CertAttestationInfo latest = 1; + // Historical attestations (sorted by generated_at descending) + repeated CertAttestationInfo history = 2; +} + +// ==================== Global Certbot Configuration Messages ==================== + +// Certbot configuration response +message CertbotConfigResponse { + // Interval between renewal checks (in seconds) + uint64 renew_interval_secs = 1; + // Time before expiration to trigger renewal (in seconds) + uint64 renew_before_expiration_secs = 2; + // Timeout for certificate renewal operations (in seconds) + uint64 renew_timeout_secs = 3; + // ACME server URL (empty means default Let's Encrypt production) + string acme_url = 4; +} + +// Set certbot configuration request +message SetCertbotConfigRequest { + // Interval between renewal checks (in seconds) + optional uint64 renew_interval_secs = 1; + // Time before expiration to trigger renewal (in seconds) + optional uint64 renew_before_expiration_secs = 2; + // Timeout for certificate renewal operations (in seconds) + optional uint64 renew_timeout_secs = 3; + // ACME server URL (empty means use default Let's Encrypt production) + optional string acme_url = 4; } diff --git a/gateway/src/admin_service.rs b/gateway/src/admin_service.rs index 541dee0d..e9467f63 100644 --- a/gateway/src/admin_service.rs +++ 
b/gateway/src/admin_service.rs @@ -3,17 +3,31 @@ // SPDX-License-Identifier: Apache-2.0 use std::sync::atomic::Ordering; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use dstack_gateway_rpc::{ admin_server::{AdminRpc, AdminServer}, - GetInfoRequest, GetInfoResponse, GetMetaResponse, HostInfo, RenewCertResponse, StatusResponse, + CertAttestationInfo, CertbotConfigResponse, CreateDnsCredentialRequest, + DeleteDnsCredentialRequest, DeleteZtDomainRequest, DnsCredentialInfo, + GetDefaultDnsCredentialResponse, GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse, + GetInstanceHandshakesRequest, GetInstanceHandshakesResponse, GetMetaResponse, + GetNodeStatusesResponse, GetZtDomainRequest, GlobalConnectionsStats, HandshakeEntry, HostInfo, + LastSeenEntry, ListCertAttestationsRequest, ListCertAttestationsResponse, + ListDnsCredentialsResponse, ListZtDomainsResponse, NodeStatusEntry, + PeerSyncStatus as ProtoPeerSyncStatus, RenewCertResponse, RenewZtDomainCertRequest, + RenewZtDomainCertResponse, SetCertbotConfigRequest, SetDefaultDnsCredentialRequest, + SetNodeStatusRequest, SetNodeUrlRequest, StatusResponse, StoreSyncStatus, + UpdateDnsCredentialRequest, WaveKvStatusResponse, ZtDomainCertStatus, + ZtDomainConfig as ProtoZtDomainConfig, ZtDomainInfo, }; use ra_rpc::{CallContext, RpcCall}; +use tracing::info; +use wavekv::node::NodeStatus as WaveKvNodeStatus; use crate::{ - main_service::{encode_ts, Proxy}, + kv::{DnsCredential, DnsProvider, NodeStatus, ZtDomainConfig}, + main_service::Proxy, proxy::NUM_CONNECTIONS, }; @@ -23,35 +37,39 @@ pub struct AdminRpcHandler { impl AdminRpcHandler { pub(crate) async fn status(self) -> Result { + let (base_domain, port) = self + .state + .kv_store() + .get_best_zt_domain() + .unwrap_or_default(); let mut state = self.state.lock(); state.refresh_state()?; - let base_domain = &state.config.proxy.base_domain; let hosts 
= state .state .instances .values() - .map(|instance| HostInfo { - instance_id: instance.id.clone(), - ip: instance.ip.to_string(), - app_id: instance.app_id.clone(), - base_domain: base_domain.clone(), - port: state.config.proxy.listen_port as u32, - latest_handshake: encode_ts(instance.last_seen), - num_connections: instance.num_connections(), + .map(|instance| { + // Get global latest_handshake from KvStore (max across all nodes) + let latest_handshake = state + .get_instance_latest_handshake(&instance.id) + .unwrap_or(0); + HostInfo { + instance_id: instance.id.clone(), + ip: instance.ip.to_string(), + app_id: instance.app_id.clone(), + base_domain: base_domain.clone(), + port: port.into(), + latest_handshake, + num_connections: instance.num_connections(), + } }) .collect::>(); - let nodes = state - .state - .nodes - .values() - .cloned() - .map(Into::into) - .collect::>(); Ok(StatusResponse { + id: state.config.sync.node_id, url: state.config.sync.my_url.clone(), - id: state.config.id(), + uuid: state.config.uuid(), bootnode_url: state.config.sync.bootnode.clone(), - nodes, + nodes: state.get_all_nodes(), hosts, num_connections: NUM_CONNECTIONS.load(Ordering::Relaxed), }) @@ -64,22 +82,19 @@ impl AdminRpc for AdminRpcHandler { } async fn renew_cert(self) -> Result { - let renewed = self.state.renew_cert(true).await?; + // Renew all domains with force=true + let renewed = self.state.renew_cert(None, true).await?; Ok(RenewCertResponse { renewed }) } async fn set_caa(self) -> Result<()> { - self.state - .certbot - .as_ref() - .context("Certbot is not enabled")? 
- .set_caa() - .await?; - Ok(()) + // TODO: Implement CAA setting for multi-domain certificates + // This requires iterating over all domain configurations and setting CAA records + bail!("set_caa is not implemented for multi-domain certificates yet"); } async fn reload_cert(self) -> Result<()> { - self.state.reload_certificates() + self.state.reload_all_certs_from_kvstore() } async fn status(self) -> Result { @@ -87,8 +102,12 @@ impl AdminRpc for AdminRpcHandler { } async fn get_info(self, request: GetInfoRequest) -> Result { + let (base_domain, port) = self + .state + .kv_store() + .get_best_zt_domain() + .unwrap_or_default(); let state = self.state.lock(); - let base_domain = &state.config.proxy.base_domain; let handshakes = state.latest_handshakes(None)?; if let Some(instance) = state.state.instances.get(&request.id) { @@ -96,8 +115,8 @@ impl AdminRpc for AdminRpcHandler { instance_id: instance.id.clone(), ip: instance.ip.to_string(), app_id: instance.app_id.clone(), - base_domain: base_domain.clone(), - port: state.config.proxy.listen_port as u32, + base_domain, + port: port.into(), latest_handshake: { let (ts, _) = handshakes .get(&instance.public_key) @@ -146,6 +165,490 @@ impl AdminRpc for AdminRpcHandler { online: online as u32, }) } + + async fn set_node_url(self, request: SetNodeUrlRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + kv_store.register_peer_url(request.id, &request.url)?; + info!("Updated peer URL: node {} -> {}", request.id, request.url); + Ok(()) + } + + async fn set_node_status(self, request: SetNodeStatusRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + let status = match request.status.as_str() { + "up" => NodeStatus::Up, + "down" => NodeStatus::Down, + _ => anyhow::bail!("invalid status: expected 'up' or 'down'"), + }; + kv_store.set_node_status(request.id, status)?; + info!("Updated node status: node {} -> {:?}", request.id, status); + Ok(()) + } + + async fn wave_kv_status(self) -> Result { + let 
kv_store = self.state.kv_store(); + + let persistent_status = kv_store.persistent().read().status(); + let ephemeral_status = kv_store.ephemeral().read().status(); + + let get_peer_last_seen = |peer_id: u32| -> Vec<(u32, u64)> { + kv_store + .get_node_last_seen_by_all(peer_id) + .into_iter() + .collect() + }; + + Ok(WaveKvStatusResponse { + enabled: self.state.config.sync.enabled, + persistent: Some(build_store_status( + "persistent", + persistent_status, + &get_peer_last_seen, + )), + ephemeral: Some(build_store_status( + "ephemeral", + ephemeral_status, + &get_peer_last_seen, + )), + }) + } + + async fn get_instance_handshakes( + self, + request: GetInstanceHandshakesRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + let handshakes = kv_store.get_instance_handshakes(&request.instance_id); + + let entries = handshakes + .into_iter() + .map(|(observer_node_id, timestamp)| HandshakeEntry { + observer_node_id, + timestamp, + }) + .collect(); + + Ok(GetInstanceHandshakesResponse { + handshakes: entries, + }) + } + + async fn get_global_connections(self) -> Result { + let state = self.state.lock(); + let kv_store = self.state.kv_store(); + + let mut node_connections = std::collections::HashMap::new(); + let mut total_connections = 0u64; + + // Iterate through all instances and sum up connections per node + for instance_id in state.state.instances.keys() { + // Get connection counts from ephemeral KV for this instance + let conn_prefix = format!("conn/{}/", instance_id); + for (key, count) in kv_store + .ephemeral() + .read() + .iter_by_prefix(&conn_prefix) + .filter_map(|(k, entry)| { + let value = entry.value.as_ref()?; + let count: u64 = rmp_serde::decode::from_slice(value).ok()?; + Some((k.to_string(), count)) + }) + { + // Parse node_id from key: "conn/{instance_id}/{node_id}" + if let Some(node_id_str) = key.strip_prefix(&conn_prefix) { + if let Ok(node_id) = node_id_str.parse::() { + *node_connections.entry(node_id).or_insert(0) += count; + 
total_connections += count; + } + } + } + } + + Ok(GlobalConnectionsStats { + total_connections, + node_connections, + }) + } + + async fn get_node_statuses(self) -> Result { + let kv_store = self.state.kv_store(); + let statuses = kv_store.load_all_node_statuses(); + + let entries = statuses + .into_iter() + .map(|(node_id, status)| { + let status_str = match status { + NodeStatus::Up => "up", + NodeStatus::Down => "down", + }; + NodeStatusEntry { + node_id, + status: status_str.to_string(), + } + }) + .collect(); + + Ok(GetNodeStatusesResponse { statuses: entries }) + } + + // ==================== DNS Credential Management ==================== + + async fn list_dns_credentials(self) -> Result { + let kv_store = self.state.kv_store(); + let credentials = kv_store + .list_dns_credentials() + .into_iter() + .map(dns_cred_to_proto) + .collect(); + let default_id = kv_store.get_default_dns_credential_id(); + Ok(ListDnsCredentialsResponse { + credentials, + default_id, + }) + } + + async fn get_dns_credential( + self, + request: GetDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + let cred = kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + Ok(dns_cred_to_proto(cred)) + } + + async fn create_dns_credential( + self, + request: CreateDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + // Validate provider type + let provider = match request.provider_type.as_str() { + "cloudflare" => DnsProvider::Cloudflare { + api_token: request.cf_api_token, + api_url: request.cf_api_url, + }, + _ => bail!("unsupported provider type: {}", request.provider_type), + }; + + let now = now_secs(); + let id = generate_cred_id(); + let dns_txt_ttl = request.dns_txt_ttl.unwrap_or(60); + let max_dns_wait = Duration::from_secs(request.max_dns_wait.unwrap_or(60 * 5).into()); + let cred = DnsCredential { + id: id.clone(), + name: request.name, + provider, + created_at: now, + updated_at: now, + 
dns_txt_ttl, + max_dns_wait, + }; + + kv_store.save_dns_credential(&cred)?; + info!("Created DNS credential: {} ({})", cred.name, cred.id); + + // Set as default if requested + if request.set_as_default { + kv_store.set_default_dns_credential_id(&id)?; + info!("Set DNS credential {} as default", id); + } + + Ok(dns_cred_to_proto(cred)) + } + + async fn update_dns_credential( + self, + request: UpdateDnsCredentialRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + let mut cred = kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + + // Update name if provided + if let Some(name) = request.name { + cred.name = name; + } + + // Update provider fields if provided + match &mut cred.provider { + DnsProvider::Cloudflare { api_token, api_url } => { + if let Some(new_token) = request.cf_api_token { + *api_token = new_token; + } + if let Some(new_url) = request.cf_api_url { + *api_url = Some(new_url); + } + } + } + + cred.updated_at = now_secs(); + kv_store.save_dns_credential(&cred)?; + info!("Updated DNS credential: {} ({})", cred.name, cred.id); + + Ok(dns_cred_to_proto(cred)) + } + + async fn delete_dns_credential(self, request: DeleteDnsCredentialRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Check if this is the default credential + if let Some(default_id) = kv_store.get_default_dns_credential_id() { + if default_id == request.id { + bail!("cannot delete the default DNS credential; set a different default first"); + } + } + + // Check if any ZT-Domain configs reference this credential + let configs = kv_store.list_zt_domain_configs(); + for config in configs { + if config.dns_cred_id.as_deref() == Some(&request.id) { + bail!( + "cannot delete DNS credential: domain {} uses it", + config.domain + ); + } + } + + kv_store.delete_dns_credential(&request.id)?; + info!("Deleted DNS credential: {}", request.id); + Ok(()) + } + + async fn get_default_dns_credential(self) -> Result { + let kv_store = 
self.state.kv_store(); + let default_id = kv_store.get_default_dns_credential_id().unwrap_or_default(); + let credential = kv_store.get_default_dns_credential().map(dns_cred_to_proto); + Ok(GetDefaultDnsCredentialResponse { + default_id, + credential, + }) + } + + async fn set_default_dns_credential( + self, + request: SetDefaultDnsCredentialRequest, + ) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Verify the credential exists + kv_store + .get_dns_credential(&request.id) + .context("dns credential not found")?; + + kv_store.set_default_dns_credential_id(&request.id)?; + info!("Set default DNS credential: {}", request.id); + Ok(()) + } + + // ==================== ZT-Domain Management ==================== + + async fn list_zt_domains(self) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + let domains = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|config| zt_domain_to_proto(config, kv_store, cert_resolver)) + .collect(); + + Ok(ListZtDomainsResponse { domains }) + } + + async fn get_zt_domain(self, request: GetZtDomainRequest) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + let config = kv_store + .get_zt_domain_config(&request.domain) + .context("ZT-Domain config not found")?; + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn add_zt_domain(self, request: ProtoZtDomainConfig) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + // Check if domain already exists + if kv_store.get_zt_domain_config(&request.domain).is_some() { + bail!("ZT-Domain config already exists: {}", request.domain); + } + + let config = proto_to_zt_domain_config(&request, kv_store)?; + + kv_store.save_zt_domain_config(&config)?; + info!("Added ZT-Domain config: {}", config.domain); + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn 
update_zt_domain(self, request: ProtoZtDomainConfig) -> Result { + let kv_store = self.state.kv_store(); + let cert_resolver = &self.state.cert_resolver; + + // Check if config exists + kv_store + .get_zt_domain_config(&request.domain) + .context("ZT-Domain config not found")?; + + let config = proto_to_zt_domain_config(&request, kv_store)?; + + kv_store.save_zt_domain_config(&config)?; + info!("Updated ZT-Domain config: {}", config.domain); + + Ok(zt_domain_to_proto(config, kv_store, cert_resolver)) + } + + async fn delete_zt_domain(self, request: DeleteZtDomainRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + + // Check if config exists + kv_store + .get_zt_domain_config(&request.domain) + .context("ZT-Domain config not found")?; + + // Delete config (cert data, acme, attestations are kept for historical purposes) + kv_store.delete_zt_domain_config(&request.domain)?; + info!("Deleted ZT-Domain config: {}", request.domain); + Ok(()) + } + + async fn renew_zt_domain_cert( + self, + request: RenewZtDomainCertRequest, + ) -> Result { + let certbot = &self.state.certbot; + let renewed = certbot + .try_renew(&request.domain, request.force) + .await + .context("certificate renewal failed")?; + + if renewed { + // Get the new certificate data for response + let kv_store = self.state.kv_store(); + let cert_data = kv_store.get_cert_data(&request.domain); + let not_after = cert_data.map(|d| d.not_after).unwrap_or(0); + Ok(RenewZtDomainCertResponse { renewed, not_after }) + } else { + Ok(RenewZtDomainCertResponse { + renewed: false, + not_after: 0, + }) + } + } + + async fn list_cert_attestations( + self, + request: ListCertAttestationsRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + + let latest = kv_store + .get_cert_attestation_latest(&request.domain) + .map(|att| CertAttestationInfo { + public_key: att.public_key, + quote: att.quote, + generated_by: att.generated_by, + generated_at: att.generated_at, + }); + + let mut history: Vec = 
kv_store + .list_cert_attestations(&request.domain) + .into_iter() + .map(|att| CertAttestationInfo { + public_key: att.public_key, + quote: att.quote, + generated_by: att.generated_by, + generated_at: att.generated_at, + }) + .collect(); + + // Apply limit if specified + if request.limit > 0 { + history.truncate(request.limit as usize); + } + + Ok(ListCertAttestationsResponse { latest, history }) + } + + // ==================== Global Certbot Configuration ==================== + + async fn get_certbot_config(self) -> Result { + let config = self.state.kv_store().get_certbot_config(); + Ok(CertbotConfigResponse { + renew_interval_secs: config.renew_interval.as_secs(), + renew_before_expiration_secs: config.renew_before_expiration.as_secs(), + renew_timeout_secs: config.renew_timeout.as_secs(), + acme_url: config.acme_url, + }) + } + + async fn set_certbot_config(self, request: SetCertbotConfigRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + let mut config = kv_store.get_certbot_config(); + + // Update only the fields that are specified + if let Some(secs) = request.renew_interval_secs { + config.renew_interval = Duration::from_secs(secs); + } + if let Some(secs) = request.renew_before_expiration_secs { + config.renew_before_expiration = Duration::from_secs(secs); + } + if let Some(secs) = request.renew_timeout_secs { + config.renew_timeout = Duration::from_secs(secs); + } + if let Some(url) = request.acme_url { + config.acme_url = url; + } + + kv_store.set_certbot_config(&config)?; + info!( + "Updated certbot config: renew_interval={:?}, renew_before_expiration={:?}, renew_timeout={:?}, acme_url={:?}", + config.renew_interval, + config.renew_before_expiration, + config.renew_timeout, + config.acme_url + ); + Ok(()) + } +} + +fn build_store_status( + name: &str, + status: WaveKvNodeStatus, + get_peer_last_seen: &impl Fn(u32) -> Vec<(u32, u64)>, +) -> StoreSyncStatus { + StoreSyncStatus { + name: name.to_string(), + node_id: status.id, + n_keys: 
status.n_kvs as u64, + next_seq: status.next_seq, + dirty: status.dirty, + wal_enabled: status.wal, + peers: status + .peers + .into_iter() + .map(|p| { + let last_seen = get_peer_last_seen(p.id) + .into_iter() + .map(|(node_id, timestamp)| LastSeenEntry { node_id, timestamp }) + .collect(); + ProtoPeerSyncStatus { + id: p.id, + local_ack: p.ack, + peer_ack: p.pack, + buffered_logs: p.logs as u64, + last_seen, + } + }) + .collect(), + } } impl RpcCall for AdminRpcHandler { @@ -157,3 +660,93 @@ impl RpcCall for AdminRpcHandler { }) } } + +// ==================== Helper Functions ==================== + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn generate_cred_id() -> String { + use std::time::SystemTime; + let ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + // Simple ID: timestamp + random suffix + let random: u32 = rand::random(); + format!("{:x}{:08x}", ts, random) +} + +fn dns_cred_to_proto(cred: DnsCredential) -> DnsCredentialInfo { + let (provider_type, cf_api_token, cf_api_url) = match &cred.provider { + DnsProvider::Cloudflare { api_token, api_url } => ( + "cloudflare".to_string(), + api_token.clone(), + api_url.clone().unwrap_or_default(), + ), + }; + DnsCredentialInfo { + id: cred.id, + name: cred.name, + provider_type, + cf_api_token, + cf_api_url, + created_at: cred.created_at, + updated_at: cred.updated_at, + } +} + +/// Convert proto ZtDomainConfig to internal ZtDomainConfig +fn proto_to_zt_domain_config( + proto: &ProtoZtDomainConfig, + kv_store: &crate::kv::KvStore, +) -> Result { + // Validate DNS credential if specified + if let Some(ref cred_id) = proto.dns_cred_id { + kv_store + .get_dns_credential(cred_id) + .context("specified dns credential not found")?; + } + + Ok(ZtDomainConfig { + domain: proto.domain.clone(), + dns_cred_id: proto.dns_cred_id.clone(), + port: proto.port.try_into().context("port out of range")?, + node: 
proto.node, + priority: proto.priority, + }) +} + +/// Convert internal ZtDomainConfig to proto ZtDomainInfo (with cert status) +fn zt_domain_to_proto( + config: ZtDomainConfig, + kv_store: &crate::kv::KvStore, + cert_resolver: &crate::cert_store::CertResolver, +) -> ZtDomainInfo { + // Get certificate data for status + let cert_data = kv_store.get_cert_data(&config.domain); + let loaded_in_memory = cert_resolver.has_cert(&config.domain); + + let cert_status = Some(ZtDomainCertStatus { + has_cert: cert_data.is_some(), + not_after: cert_data.as_ref().map(|d| d.not_after).unwrap_or(0), + issued_by: cert_data.as_ref().map(|d| d.issued_by).unwrap_or(0), + issued_at: cert_data.as_ref().map(|d| d.issued_at).unwrap_or(0), + loaded_in_memory, + }); + + ZtDomainInfo { + config: Some(ProtoZtDomainConfig { + domain: config.domain, + dns_cred_id: config.dns_cred_id, + port: config.port.into(), + node: config.node, + priority: config.priority, + }), + cert_status, + } +} diff --git a/gateway/src/cert_store.rs b/gateway/src/cert_store.rs new file mode 100644 index 00000000..e65dd2de --- /dev/null +++ b/gateway/src/cert_store.rs @@ -0,0 +1,443 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! In-memory certificate store with SNI-based certificate resolution. +//! +//! This module provides a lock-free certificate store that supports: +//! - Multiple certificates for different domains +//! - Wildcard certificate matching +//! - Dynamic certificate updates via atomic replacement +//! - SNI-based certificate selection for TLS connections +//! +//! Architecture: `CertStore` is immutable after construction for lock-free reads. +//! Updates are done by building a new `CertStore` and atomically swapping the `Arc` +//! in the outer `RwLock>`. 
+ +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use arc_swap::{ArcSwap, Guard}; +use or_panic::ResultOrPanic; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use tracing::{debug, info}; + +use crate::kv::CertData; + +/// Immutable, lock-free certificate store. +/// +/// This struct is designed for maximum read performance - no locks required for lookups. +/// Updates are done by creating a new instance and atomically swapping via outer RwLock>. +pub struct CertStore { + /// Exact domain -> CertifiedKey + exact_certs: HashMap>, + /// Parent domain -> CertifiedKey (for wildcard certs) + /// e.g., "example.com" -> cert for "*.example.com" + wildcard_certs: HashMap>, + /// Domain -> CertData (for metadata like expiry) + cert_data: HashMap, +} + +impl CertStore { + /// Create a new empty certificate store + pub fn new() -> Self { + Self { + exact_certs: HashMap::new(), + wildcard_certs: HashMap::new(), + cert_data: HashMap::new(), + } + } + + /// Resolve certificate for a given SNI hostname (lock-free) + fn resolve_cert(&self, sni: &str) -> Option> { + // 1. Try exact match first + if let Some(cert) = self.exact_certs.get(sni) { + debug!("exact match for {sni}"); + return Some(cert.clone()); + } + + // 2. 
Try wildcard match (only one level deep per TLS spec) + // For "foo.bar.example.com", only try "bar.example.com" + if let Some((_, parent)) = sni.split_once('.') { + if let Some(cert) = self.wildcard_certs.get(parent) { + debug!("wildcard match *.{parent} for {sni}"); + return Some(cert.clone()); + } + } + + debug!("no certificate found for {sni}"); + None + } + + /// Check if a certificate exists for a domain + pub fn has_cert(&self, domain: &str) -> bool { + self.cert_data.contains_key(domain) + } + + /// Get certificate data for a domain + pub fn get_cert_data(&self, domain: &str) -> Option<&CertData> { + self.cert_data.get(domain) + } + + /// List all loaded domains + pub fn list_domains(&self) -> Vec { + self.cert_data.keys().cloned().collect() + } + + /// Check if a wildcard certificate exists for a domain + pub fn contains_wildcard(&self, base_domain: &str) -> bool { + self.wildcard_certs.contains_key(base_domain) + } +} + +impl fmt::Debug for CertStore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let exact_domains: Vec<_> = self.exact_certs.keys().cloned().collect(); + let wildcard_domains: Vec<_> = self + .wildcard_certs + .keys() + .map(|k| format!("*.{}", k)) + .collect(); + + f.debug_struct("CertStore") + .field("exact_domains", &exact_domains) + .field("wildcard_domains", &wildcard_domains) + .finish() + } +} + +impl Default for CertStore { + fn default() -> Self { + Self::new() + } +} + +impl ResolvesServerCert for CertStore { + fn resolve(&self, client_hello: ClientHello) -> Option> { + let sni = client_hello.server_name()?; + self.resolve_cert(sni) + } +} + +/// Certificate resolver that wraps `ArcSwap` for lock-free reads. +/// +/// This allows TLS acceptors to be created once and certificates to be updated +/// without recreating the acceptor. The read path (TLS handshake) is completely +/// lock-free via `ArcSwap`. Write operations are serialized via a `Mutex` to +/// prevent lost updates during concurrent certificate changes. 
+pub struct CertResolver { + store: ArcSwap, + /// Mutex to serialize write operations (reads are still lock-free) + write_lock: std::sync::Mutex<()>, +} + +impl CertResolver { + /// Create a new resolver with an empty CertStore + pub fn new() -> Self { + Self { + store: ArcSwap::from_pointee(CertStore::new()), + write_lock: std::sync::Mutex::new(()), + } + } + + /// Get the current CertStore (lock-free) + pub fn get(&self) -> Guard> { + self.store.load() + } + + /// Replace the CertStore atomically (lock-free) + pub fn set(&self, new_store: Arc) { + self.store.store(new_store); + } + + /// List all domains + pub fn list_domains(&self) -> Vec { + self.get().list_domains() + } + + /// Check if a certificate exists for a domain + pub fn has_cert(&self, domain: &str) -> bool { + self.get().has_cert(domain) + } + + /// Update a single certificate (creates new store with updated cert) + /// + /// This is an incremental update that preserves all existing certificates. + /// Write operations are serialized to prevent lost updates. 
+ pub fn update_cert(&self, domain: &str, data: &CertData) -> Result<()> { + let _guard = self + .write_lock + .lock() + .or_panic("failed to acquire write lock"); + + let old_store = self.get(); + + // Build new store with all existing certs plus the new/updated one + let mut builder = CertStoreBuilder::new(); + + // Copy existing certs (except the one we're replacing) + for existing_domain in old_store.list_domains() { + if existing_domain != domain { + if let Some(existing_data) = old_store.get_cert_data(&existing_domain) { + builder.add_cert(&existing_domain, existing_data)?; + } + } + } + + // Add the new/updated cert + builder.add_cert(domain, data)?; + + // Atomically swap + self.set(Arc::new(builder.build())); + Ok(()) + } +} + +impl Default for CertResolver { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for CertResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl ResolvesServerCert for CertResolver { + fn resolve(&self, client_hello: ClientHello) -> Option> { + // Lock-free load via ArcSwap + let store = self.store.load(); + let sni = client_hello.server_name()?; + store.resolve_cert(sni) + } +} + +/// Builder for constructing a new CertStore. +/// +/// Use this to build a complete certificate store, then call `build()` to get the immutable CertStore. +pub struct CertStoreBuilder { + exact_certs: HashMap>, + wildcard_certs: HashMap>, + cert_data: HashMap, +} + +impl CertStoreBuilder { + /// Create a new empty builder + pub fn new() -> Self { + Self { + exact_certs: HashMap::new(), + wildcard_certs: HashMap::new(), + cert_data: HashMap::new(), + } + } + + /// Add a certificate to the builder + /// + /// The domain is the base domain (e.g., "example.com"). + /// All gateway certificates are wildcard certs for "*.{domain}". 
+ pub fn add_cert(&mut self, domain: &str, data: &CertData) -> Result<()> { + let certified_key = parse_certified_key(&data.cert_pem, &data.key_pem) + .with_context(|| format!("failed to parse certificate for {}", domain))?; + + let certified_key = Arc::new(certified_key); + + // Gateway certificates are always wildcard certs + // domain is the base domain (e.g., "example.com"), cert is for "*.example.com" + self.wildcard_certs + .insert(domain.to_string(), certified_key); + info!( + "cert_store: prepared wildcard certificate for *.{} (expires: {})", + domain, + format_expiry(data.not_after) + ); + + // Store metadata + self.cert_data.insert(domain.to_string(), data.clone()); + + Ok(()) + } + + /// Build the immutable CertStore + pub fn build(self) -> CertStore { + CertStore { + exact_certs: self.exact_certs, + wildcard_certs: self.wildcard_certs, + cert_data: self.cert_data, + } + } +} + +impl Default for CertStoreBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Parse certificate and private key PEM strings into a CertifiedKey +fn parse_certified_key(cert_pem: &str, key_pem: &str) -> Result { + let certs = CertificateDer::pem_slice_iter(cert_pem.as_bytes()) + .collect::, _>>() + .context("failed to parse certificate chain")?; + + if certs.is_empty() { + anyhow::bail!("no certificates found in PEM"); + } + + let key = + PrivateKeyDer::from_pem_slice(key_pem.as_bytes()).context("failed to parse private key")?; + + let signing_key = rustls::crypto::aws_lc_rs::sign::any_supported_type(&key) + .map_err(|e| anyhow::anyhow!("failed to create signing key: {:?}", e))?; + + Ok(CertifiedKey::new(certs, signing_key)) +} + +/// Format expiry timestamp as human-readable string +fn format_expiry(not_after: u64) -> String { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + let expiry = UNIX_EPOCH + Duration::from_secs(not_after); + let now = SystemTime::now(); + + match expiry.duration_since(now) { + Ok(remaining) => { + let days = remaining.as_secs() / 
86400; + format!("{} days remaining", days) + } + Err(_) => "expired".to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl CertStore { + /// Check if a certificate can be resolved for a given SNI hostname + pub fn has_cert_for_sni(&self, sni: &str) -> bool { + self.resolve_cert(sni).is_some() + } + } + + fn make_test_cert_data() -> CertData { + // Generate a self-signed test certificate using rcgen + use ra_tls::rcgen::{self, CertificateParams, KeyPair}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let mut params = CertificateParams::new(vec!["test.example.com".to_string()]) + .expect("failed to create cert params"); + params.not_after = rcgen::date_time_ymd(2030, 1, 1); + let cert = params + .self_signed(&key_pair) + .expect("failed to generate self-signed cert"); + + let not_after = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + Duration::from_secs(365 * 24 * 3600).as_secs(); + + CertData { + cert_pem: cert.pem(), + key_pem: key_pair.serialize_pem(), + not_after, + issued_by: 1, + issued_at: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + } + } + + #[test] + fn test_cert_store_basic() { + let store = CertStore::new(); + assert!(store.list_domains().is_empty()); + } + + #[test] + fn test_cert_store_builder() { + let data = make_test_cert_data(); + + // Use builder - domain is base domain (e.g., "example.com") + // All gateway certs are wildcard certs + let mut builder = CertStoreBuilder::new(); + builder + .add_cert("example.com", &data) + .expect("failed to add cert"); + + let store = builder.build(); + + // Check it's loaded (stored by base domain) + assert!(store.has_cert("example.com")); + assert_eq!(store.list_domains().len(), 1); + + // Should resolve any subdomain via wildcard matching + assert!(store.has_cert_for_sni("test.example.com")); + 
assert!(store.has_cert_for_sni("foo.example.com")); + + // Should not resolve exact base domain (wildcard doesn't match base) + assert!(!store.has_cert_for_sni("example.com")); + + // Should not resolve different domain + assert!(!store.has_cert_for_sni("example.org")); + } + + #[test] + fn test_cert_store_wildcard() { + // Generate wildcard cert + use ra_tls::rcgen::{self, CertificateParams, KeyPair}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let mut params = CertificateParams::new(vec!["*.example.com".to_string()]) + .expect("failed to create cert params"); + params.not_after = rcgen::date_time_ymd(2030, 1, 1); + let cert = params + .self_signed(&key_pair) + .expect("failed to generate self-signed cert"); + + let not_after = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + + Duration::from_secs(365 * 24 * 3600).as_secs(); + + let data = CertData { + cert_pem: cert.pem(), + key_pem: key_pair.serialize_pem(), + not_after, + issued_by: 1, + issued_at: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + let mut builder = CertStoreBuilder::new(); + // Now we use base domain format (without *. 
prefix) + builder + .add_cert("example.com", &data) + .expect("failed to add wildcard cert"); + + let store = builder.build(); + + // Should resolve any subdomain + assert!(store.has_cert_for_sni("foo.example.com")); + assert!(store.has_cert_for_sni("bar.example.com")); + + // Wildcard certs do not match nested subdomains + assert!(!store.has_cert_for_sni("sub.foo.example.com")); + + // Should not resolve different domain + assert!(!store.has_cert_for_sni("example.org")); + } +} diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 3b990795..e7d1918e 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -69,15 +69,10 @@ pub enum TlsVersion { #[derive(Debug, Clone, Deserialize)] pub struct ProxyConfig { - pub cert_chain: String, - pub cert_key: String, pub tls_crypto_provider: CryptoProvider, pub tls_versions: Vec, - pub base_domain: String, - pub external_port: u16, pub listen_addr: Ipv4Addr, pub listen_port: u16, - pub agent_port: u16, pub timeouts: Timeouts, pub buffer_size: usize, pub connect_top_n: usize, @@ -125,30 +120,49 @@ pub struct SyncConfig { #[serde(with = "serde_duration")] pub interval: Duration, #[serde(with = "serde_duration")] - pub broadcast_interval: Duration, - #[serde(with = "serde_duration")] pub timeout: Duration, pub my_url: String, + /// The URL of the bootnode used to fetch initial peer list when joining the network pub bootnode: String, + /// WaveKV node ID for this gateway (must be unique across cluster) + pub node_id: u32, + /// Data directory for WaveKV persistence + pub data_dir: String, + /// Interval for periodic WAL persistence (default: 10s) + #[serde(with = "serde_duration")] + pub persist_interval: Duration, + /// Enable periodic sync of instance connections to KV store + pub sync_connections_enabled: bool, + /// Interval for syncing instance connections to KV store + #[serde(with = "serde_duration")] + pub sync_connections_interval: Duration, } #[derive(Debug, Clone, Deserialize)] pub struct Config { pub 
wg: WgConfig, pub proxy: ProxyConfig, - pub certbot: CertbotConfig, pub pccs_url: Option, pub recycle: RecycleConfig, - pub state_path: String, pub set_ulimit: bool, pub rpc_domain: String, pub kms_url: String, pub admin: AdminConfig, - pub run_in_dstack: bool, + /// Debug server configuration (separate port for debug RPCs) + pub debug: DebugConfig, pub sync: SyncConfig, pub auth: AuthConfig, } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct DebugConfig { + /// Enable debug server + #[serde(default)] + pub insecure_enable_debug_rpc: bool, + #[serde(default)] + pub insecure_skip_attestation: bool, +} + #[derive(Debug, Clone, Deserialize)] pub struct AuthConfig { pub enabled: bool, @@ -158,11 +172,41 @@ pub struct AuthConfig { } impl Config { - pub fn id(&self) -> Vec { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(self.wg.public_key.as_bytes()); - hasher.finalize()[..20].to_vec() + /// Get or generate a unique node UUID. + /// The UUID is stored in `{data_dir}/node_uuid` and persisted across restarts. 
+ pub fn uuid(&self) -> Vec { + use std::fs; + use std::path::Path; + + let uuid_path = Path::new(&self.sync.data_dir).join("node_uuid"); + + // Try to read existing UUID + if let Ok(content) = fs::read_to_string(&uuid_path) { + if let Ok(uuid) = uuid::Uuid::parse_str(content.trim()) { + return uuid.as_bytes().to_vec(); + } + } + + // Generate new UUID + let uuid = uuid::Uuid::new_v4(); + + // Ensure directory exists + if let Some(parent) = uuid_path.parent() { + let _ = fs::create_dir_all(parent); + } + + // Save UUID to file + if let Err(err) = fs::write(&uuid_path, uuid.to_string()) { + tracing::warn!( + "failed to save node UUID to {}: {}", + uuid_path.display(), + err + ); + } else { + tracing::info!("generated new node UUID: {}", uuid); + } + + uuid.as_bytes().to_vec() } } @@ -183,68 +227,6 @@ pub struct MutualConfig { pub ca_certs: String, } -#[derive(Debug, Clone, Deserialize)] -pub struct CertbotConfig { - /// Enable certbot - pub enabled: bool, - /// Path to the working directory - pub workdir: String, - /// ACME server URL - pub acme_url: String, - /// Cloudflare API token - pub cf_api_token: String, - /// Auto set CAA record - pub auto_set_caa: bool, - /// Domain to issue certificates for - pub domain: String, - /// Renew interval - #[serde(with = "serde_duration")] - pub renew_interval: Duration, - /// Time gap before expiration to trigger renewal - #[serde(with = "serde_duration")] - pub renew_before_expiration: Duration, - /// Renew timeout - #[serde(with = "serde_duration")] - pub renew_timeout: Duration, - /// Maximum time to wait for DNS propagation - #[serde(with = "serde_duration")] - pub max_dns_wait: Duration, - /// TTL for DNS TXT records used in ACME challenges (in seconds). - /// Minimum is 60 for Cloudflare. Lower TTL means faster DNS propagation. 
- #[serde(default = "default_dns_txt_ttl")] - pub dns_txt_ttl: u32, -} - -fn default_dns_txt_ttl() -> u32 { - 60 -} - -impl CertbotConfig { - fn to_bot_config(&self) -> certbot::CertBotConfig { - let workdir = certbot::WorkDir::new(&self.workdir); - certbot::CertBotConfig::builder() - .auto_create_account(true) - .cert_dir(workdir.backup_dir()) - .cert_file(workdir.cert_path()) - .key_file(workdir.key_path()) - .credentials_file(workdir.account_credentials_path()) - .acme_url(self.acme_url.clone()) - .cert_subject_alt_names(vec![self.domain.clone()]) - .cf_api_token(self.cf_api_token.clone()) - .renew_interval(self.renew_interval) - .renew_timeout(self.renew_timeout) - .renew_expires_in(self.renew_before_expiration) - .auto_set_caa(self.auto_set_caa) - .max_dns_wait(self.max_dns_wait) - .dns_txt_ttl(self.dns_txt_ttl) - .build() - } - - pub async fn build_bot(&self) -> Result { - self.to_bot_config().build_bot().await - } -} - pub const DEFAULT_CONFIG: &str = include_str!("../gateway.toml"); pub fn load_config_figment(config_file: Option<&str>) -> Figment { load_config("gateway", DEFAULT_CONFIG, config_file, false) diff --git a/gateway/src/debug_service.rs b/gateway/src/debug_service.rs new file mode 100644 index 00000000..e53b4ed2 --- /dev/null +++ b/gateway/src/debug_service.rs @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! 
Debug service for testing - runs on a separate port when debug.enabled=true + +use anyhow::Result; +use dstack_gateway_rpc::{ + debug_server::{DebugRpc, DebugServer}, + DebugProxyStateResponse, DebugRegisterCvmRequest, DebugSyncDataResponse, InfoResponse, + InstanceEntry, NodeInfoEntry, PeerAddrEntry, ProxyStateInstance, RegisterCvmResponse, +}; +use ra_rpc::{CallContext, RpcCall}; +use tracing::warn; + +use crate::main_service::Proxy; + +pub struct DebugRpcHandler { + state: Proxy, +} + +impl DebugRpcHandler { + pub fn new(state: Proxy) -> Self { + Self { state } + } +} + +impl DebugRpc for DebugRpcHandler { + async fn register_cvm(self, request: DebugRegisterCvmRequest) -> Result { + warn!( + "Debug register CVM: app_id={}, instance_id={}", + request.app_id, request.instance_id + ); + self.state.do_register_cvm( + &request.app_id, + &request.instance_id, + &request.client_public_key, + ) + } + + async fn info(self) -> Result { + let config = &self.state.config; + let (base_domain, port) = self + .state + .kv_store() + .get_best_zt_domain() + .unwrap_or_default(); + Ok(InfoResponse { + base_domain, + external_port: port.into(), + app_address_ns_prefix: config.proxy.app_address_ns_prefix.clone(), + }) + } + + async fn get_sync_data(self) -> Result { + let kv_store = self.state.kv_store(); + let my_node_id = kv_store.my_node_id(); + + // Get all peer addresses + let peer_addrs: Vec = kv_store + .get_all_peer_addrs() + .into_iter() + .map(|(node_id, url)| PeerAddrEntry { + node_id: node_id as u64, + url, + }) + .collect(); + + // Get all node info + let nodes: Vec = kv_store + .load_all_nodes() + .into_iter() + .map(|(node_id, data)| NodeInfoEntry { + node_id: node_id as u64, + url: data.url, + wg_public_key: data.wg_public_key, + wg_endpoint: data.wg_endpoint, + wg_ip: data.wg_ip, + }) + .collect(); + + // Get all instances + let instances: Vec = kv_store + .load_all_instances() + .into_iter() + .map(|(instance_id, data)| InstanceEntry { + instance_id, + app_id: 
data.app_id, + ip: data.ip.to_string(), + public_key: data.public_key, + }) + .collect(); + + // Get key counts + let persistent_keys = kv_store.persistent().read().status().n_kvs as u64; + let ephemeral_keys = kv_store.ephemeral().read().status().n_kvs as u64; + + Ok(DebugSyncDataResponse { + my_node_id: my_node_id as u64, + peer_addrs, + nodes, + instances, + persistent_keys, + ephemeral_keys, + }) + } + + async fn get_proxy_state(self) -> Result { + let state = self.state.lock(); + + // Get all instances from ProxyState + let instances: Vec = state + .state + .instances + .values() + .map(|inst| { + let reg_time = inst + .reg_time + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + ProxyStateInstance { + instance_id: inst.id.clone(), + app_id: inst.app_id.clone(), + ip: inst.ip.to_string(), + public_key: inst.public_key.clone(), + reg_time, + } + }) + .collect(); + + // Get all allocated addresses + let allocated_addresses: Vec = state + .state + .allocated_addresses + .iter() + .map(|ip| ip.to_string()) + .collect(); + + Ok(DebugProxyStateResponse { + instances, + allocated_addresses, + }) + } +} + +impl RpcCall for DebugRpcHandler { + type PrpcService = DebugServer; + + fn construct(context: CallContext<'_, Proxy>) -> Result { + Ok(DebugRpcHandler::new(context.state.clone())) + } +} diff --git a/gateway/src/distributed_certbot.rs b/gateway/src/distributed_certbot.rs new file mode 100644 index 00000000..2d091fc5 --- /dev/null +++ b/gateway/src/distributed_certbot.rs @@ -0,0 +1,520 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! Multi-domain certificate management using WaveKV for synchronization. +//! +//! This module provides distributed certificate management for multiple domains +//! with dynamic DNS credential configuration and attestation storage. 
+ +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use anyhow::{bail, Context, Result}; +use certbot::{AcmeClient, Dns01Client}; +use dstack_guest_agent_rpc::RawQuoteArgs; +use ra_tls::attestation::QuoteContentType; +use ra_tls::rcgen::KeyPair; +use tracing::{error, info, warn}; + +use crate::cert_store::CertResolver; +use crate::kv::{ + AcmeAttestation, CertAttestation, CertCredentials, CertData, DnsProvider, KvStore, + ZtDomainConfig, +}; + +/// Lock timeout for certificate renewal (10 minutes) +const RENEW_LOCK_TIMEOUT_SECS: u64 = 600; + +/// Default ACME URL (Let's Encrypt production) +const DEFAULT_ACME_URL: &str = "https://acme-v02.api.letsencrypt.org/directory"; + +/// Multi-domain certificate manager +pub struct DistributedCertBot { + kv_store: Arc, + cert_resolver: Arc, +} + +impl DistributedCertBot { + pub fn new(kv_store: Arc, cert_resolver: Arc) -> Self { + Self { + kv_store, + cert_resolver, + } + } + + /// Get the current certbot configuration from KV store + fn config(&self) -> crate::kv::GlobalCertbotConfig { + self.kv_store.get_certbot_config() + } + + /// Initialize all ZT-Domain certificates + pub async fn init_all(&self) -> Result<()> { + let configs = self.kv_store.list_zt_domain_configs(); + for config in configs { + if let Err(e) = self.init_domain(&config.domain).await { + error!("cert[{}]: failed to initialize: {}", config.domain, e); + } + } + Ok(()) + } + + /// Initialize certificate for a specific domain + pub async fn init_domain(&self, domain: &str) -> Result<()> { + // First, try to load from KvStore (synced from other nodes) + if let Some(cert_data) = self.kv_store.get_cert_data(domain) { + let now = now_secs(); + if cert_data.not_after > now { + info!( + domain, + "loaded from KvStore (issued by node {}, expires in {} days)", + cert_data.issued_by, + (cert_data.not_after - now) / 86400 + ); + self.cert_resolver.update_cert(domain, &cert_data)?; + return Ok(()); + } + info!(domain, "KvStore certificate 
expired, will request new one"); + } + + // No valid cert, need to request new one + info!(domain, "no valid certificate found, requesting from ACME"); + self.request_new_cert(domain).await + } + + /// Try to renew all ZT-Domain certificates + pub async fn try_renew_all(&self) -> Result<()> { + let configs = self.kv_store.list_zt_domain_configs(); + for config in configs { + if let Err(e) = self.try_renew(&config.domain, false).await { + error!("cert[{}]: failed to renew: {}", config.domain, e); + } + } + Ok(()) + } + + /// Try to renew certificate for a specific domain if needed + #[tracing::instrument(skip(self))] + pub async fn try_renew(&self, domain: &str, force: bool) -> Result { + // Check if config exists + let config = self + .kv_store + .get_zt_domain_config(domain) + .context("ZT-Domain config not found")?; + + // Check if renewal is needed + let cert_data = self.kv_store.get_cert_data(domain); + let needs_renew = if force { + true + } else if let Some(ref data) = cert_data { + let now = now_secs(); + let expires_in = data.not_after.saturating_sub(now); + expires_in < self.config().renew_before_expiration.as_secs() + } else { + true + }; + + if !needs_renew { + info!("does not need renewal"); + return Ok(false); + } + + // Try to acquire lock + if !self + .kv_store + .try_acquire_cert_lock(domain, RENEW_LOCK_TIMEOUT_SECS) + { + info!("another node is renewing, skipping"); + return Ok(false); + } + + info!("acquired renew lock, starting renewal"); + + // Perform renewal or initial issuance + let result = if cert_data.is_some() { + self.do_renew(domain, &config).await + } else { + // No existing certificate, request new one + info!("no existing certificate, requesting new one"); + self.do_request_new(domain, &config).await.map(|_| true) + }; + + // Release lock regardless of result + if let Err(e) = self.kv_store.release_cert_lock(domain) { + error!("failed to release lock: {e}"); + } + + result + } + + /// Request new certificate for a domain + 
#[tracing::instrument(skip(self))] + async fn request_new_cert(&self, domain: &str) -> Result<()> { + let config = self + .kv_store + .get_zt_domain_config(domain) + .context("ZT-Domain config not found")?; + + // Try to acquire lock first + if !self + .kv_store + .try_acquire_cert_lock(domain, RENEW_LOCK_TIMEOUT_SECS) + { + // Another node is requesting, wait for it + info!("another node is requesting, waiting..."); + tokio::time::sleep(Duration::from_secs(30)).await; + if let Some(cert_data) = self.kv_store.get_cert_data(domain) { + self.cert_resolver.update_cert(domain, &cert_data)?; + return Ok(()); + } + bail!("failed to get certificate from KvStore after waiting"); + } + + let result = self.do_request_new(domain, &config).await; + + if let Err(e) = self.kv_store.release_cert_lock(domain) { + error!("failed to release lock: {e}"); + } + + result + } + + async fn do_request_new(&self, domain: &str, config: &ZtDomainConfig) -> Result<()> { + let acme_client = self.get_or_create_acme_client(domain, config).await?; + + // Generate new key pair (always use new key for security) + let key = KeyPair::generate().context("failed to generate key")?; + let key_pem = key.serialize_pem(); + let public_key_der = key.public_key_der(); + + // Request wildcard certificate (domain in config is base domain, cert is *.domain) + let wildcard_domain = format!("*.{}", domain); + info!( + "requesting new certificate from ACME for {}...", + wildcard_domain + ); + let cert_pem = tokio::time::timeout( + self.config().renew_timeout, + acme_client.request_new_certificate(&key_pem, &[wildcard_domain]), + ) + .await + .context("certificate request timed out")? 
+ .context("failed to request new certificate")?; + + let not_after = get_cert_expiry(&cert_pem).context("failed to parse certificate expiry")?; + + // Save certificate to KvStore + self.save_cert_to_kvstore(domain, &cert_pem, &key_pem, not_after)?; + info!("new certificate obtained from ACME, saved to KvStore"); + + // Generate and save attestation + self.generate_and_save_attestation(domain, &public_key_der) + .await?; + + // Load into memory cert store + let cert_data = CertData { + cert_pem, + key_pem, + not_after, + issued_by: self.kv_store.my_node_id(), + issued_at: now_secs(), + }; + self.cert_resolver.update_cert(domain, &cert_data)?; + + info!( + "new certificate loaded (expires in {} days)", + (not_after - now_secs()) / 86400 + ); + Ok(()) + } + + async fn do_renew(&self, domain: &str, config: &ZtDomainConfig) -> Result { + let acme_client = self.get_or_create_acme_client(domain, config).await?; + + // Generate new key pair (always use new key for each renewal) + let key = KeyPair::generate().context("failed to generate key")?; + let key_pem = key.serialize_pem(); + let public_key_der = key.public_key_der(); + + // Verify there's a current cert (for audit trail, even though we don't use its key) + if self.kv_store.get_cert_data(domain).is_none() { + bail!("no current certificate to renew"); + } + + // Renew with new key (request wildcard certificate) + let wildcard_domain = format!("*.{}", domain); + info!( + "renewing certificate with new key from ACME for {}...", + wildcard_domain + ); + let new_cert_pem = tokio::time::timeout( + self.config().renew_timeout, + // Note: we request a new cert rather than renew, since we have a new key + acme_client.request_new_certificate(&key_pem, &[wildcard_domain]), + ) + .await + .context("certificate renewal timed out")? 
+ .context("failed to renew certificate")?; + + let not_after = + get_cert_expiry(&new_cert_pem).context("failed to parse certificate expiry")?; + + // Save to KvStore + self.save_cert_to_kvstore(domain, &new_cert_pem, &key_pem, not_after)?; + info!("renewed certificate saved to KvStore"); + + // Generate and save attestation + self.generate_and_save_attestation(domain, &public_key_der) + .await?; + + // Load into memory cert store + let cert_data = CertData { + cert_pem: new_cert_pem, + key_pem, + not_after, + issued_by: self.kv_store.my_node_id(), + issued_at: now_secs(), + }; + self.cert_resolver.update_cert(domain, &cert_data)?; + + info!( + "renewed certificate loaded (expires in {} days)", + (not_after - now_secs()) / 86400 + ); + Ok(true) + } + + async fn get_or_create_acme_client( + &self, + domain: &str, + config: &ZtDomainConfig, + ) -> Result { + // Get DNS credential (from config or default) + let dns_cred = if let Some(ref cred_id) = config.dns_cred_id { + self.kv_store + .get_dns_credential(cred_id) + .context("specified DNS credential not found")? + } else { + self.kv_store + .get_default_dns_credential() + .context("no default DNS credential configured")? + }; + + // Create DNS client based on provider + let dns01_client = match &dns_cred.provider { + DnsProvider::Cloudflare { api_token, api_url } => { + Dns01Client::new_cloudflare(domain.to_string(), api_token.clone(), api_url.clone()) + .await? 
+ } + }; + + // Use ACME URL from certbot config, fall back to default if not set + let config = self.config(); + let acme_url = if config.acme_url.is_empty() { + DEFAULT_ACME_URL + } else { + &config.acme_url + }; + + // Try to load global ACME credentials from KvStore + if let Some(creds) = self.kv_store.get_acme_credentials() { + if acme_url_matches(&creds.acme_credentials, acme_url) { + info!("loaded global ACME account credentials from KvStore"); + return AcmeClient::load( + dns01_client, + &creds.acme_credentials, + dns_cred.max_dns_wait, + dns_cred.dns_txt_ttl, + ) + .await + .context("failed to load ACME client from KvStore credentials"); + } + warn!("ACME URL mismatch in KvStore credentials, will create new account"); + } + + // Create new global ACME account + info!("creating new global ACME account at {acme_url}"); + let client = AcmeClient::new_account( + acme_url, + dns01_client, + dns_cred.max_dns_wait, + dns_cred.dns_txt_ttl, + ) + .await + .context("failed to create new ACME account")?; + + let creds_json = client + .dump_credentials() + .context("failed to dump ACME credentials")?; + + // Save global ACME credentials to KvStore + self.kv_store.save_acme_credentials(&CertCredentials { + acme_credentials: creds_json.clone(), + })?; + + // Generate and save ACME account attestation + if let Some(account_uri) = extract_account_uri(&creds_json) { + self.generate_and_save_acme_attestation(&account_uri) + .await?; + } + + Ok(client) + } + + async fn generate_and_save_acme_attestation(&self, account_uri: &str) -> Result<()> { + let agent = match crate::dstack_agent() { + Ok(a) => a, + Err(e) => { + warn!("failed to create dstack agent: {e}"); + return Ok(()); + } + }; + + let report_data = QuoteContentType::Custom("acme-account") + .to_report_data(account_uri.as_bytes()) + .to_vec(); + + // Get quote + let quote = match agent + .get_quote(RawQuoteArgs { + report_data: report_data.clone(), + }) + .await + { + Ok(resp) => 
serde_json::to_string(&resp).unwrap_or_default(), + Err(e) => { + warn!("failed to get TDX quote for ACME account: {e}"); + return Ok(()); + } + }; + + // Get attestation + let attestation_str = match agent.attest(RawQuoteArgs { report_data }).await { + Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(), + Err(e) => { + warn!("failed to get attestation for ACME account: {e}"); + String::new() + } + }; + + let attestation = AcmeAttestation { + account_uri: account_uri.to_string(), + quote, + attestation: attestation_str, + generated_by: self.kv_store.my_node_id(), + generated_at: now_secs(), + }; + + self.kv_store.save_acme_attestation(&attestation)?; + info!("ACME account attestation saved to KvStore"); + Ok(()) + } + + fn save_cert_to_kvstore( + &self, + domain: &str, + cert_pem: &str, + key_pem: &str, + not_after: u64, + ) -> Result<()> { + let cert_data = CertData { + cert_pem: cert_pem.to_string(), + key_pem: key_pem.to_string(), + not_after, + issued_by: self.kv_store.my_node_id(), + issued_at: now_secs(), + }; + self.kv_store.save_cert_data(domain, &cert_data) + } + + async fn generate_and_save_attestation( + &self, + domain: &str, + public_key_der: &[u8], + ) -> Result<()> { + let agent = match crate::dstack_agent() { + Ok(a) => a, + Err(e) => { + warn!(domain, "failed to create dstack agent: {e}"); + return Ok(()); + } + }; + + let report_data = QuoteContentType::Custom("zt-cert") + .to_report_data(public_key_der) + .to_vec(); + + // Get quote + let quote = match agent + .get_quote(RawQuoteArgs { + report_data: report_data.clone(), + }) + .await + { + Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(), + Err(e) => { + warn!(domain, "failed to generate TDX quote: {e}"); + return Ok(()); + } + }; + + // Get attestation + let attestation = match agent.attest(RawQuoteArgs { report_data }).await { + Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(), + Err(e) => { + warn!(domain, "failed to get attestation: {e}"); + String::new() + 
} + }; + + let attestation = CertAttestation { + public_key: public_key_der.to_vec(), + quote, + attestation, + generated_by: self.kv_store.my_node_id(), + generated_at: now_secs(), + }; + + self.kv_store.save_cert_attestation(domain, &attestation)?; + info!(domain, "attestation saved to KvStore"); + Ok(()) + } +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn get_cert_expiry(cert_pem: &str) -> Option { + use x509_parser::prelude::*; + let pem = Pem::iter_from_buffer(cert_pem.as_bytes()).next()?.ok()?; + let cert = pem.parse_x509().ok()?; + Some(cert.validity().not_after.timestamp() as u64) +} + +fn acme_url_matches(credentials_json: &str, expected_url: &str) -> bool { + #[derive(serde::Deserialize)] + struct Creds { + #[serde(default)] + acme_url: String, + } + serde_json::from_str::(credentials_json) + .map(|c| c.acme_url == expected_url) + .unwrap_or(false) +} + +/// Extract account_id (URI) from ACME credentials JSON +fn extract_account_uri(credentials_json: &str) -> Option { + #[derive(serde::Deserialize)] + struct Creds { + #[serde(default)] + account_id: String, + } + serde_json::from_str::(credentials_json) + .ok() + .filter(|c| !c.account_id.is_empty()) + .map(|c| c.account_id) +} diff --git a/gateway/src/kv/https_client.rs b/gateway/src/kv/https_client.rs new file mode 100644 index 00000000..d0d034a9 --- /dev/null +++ b/gateway/src/kv/https_client.rs @@ -0,0 +1,322 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! HTTPS client with mTLS and custom certificate verification during TLS handshake. 
+ +use std::fmt::Debug; +use std::io::{Read, Write}; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http_body_util::{BodyExt, Full}; +use hyper::body::Bytes; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; +use rustls::{DigitallySignedStruct, SignatureScheme}; +use serde::{de::DeserializeOwned, Serialize}; + +use super::{decode, encode}; + +/// Custom certificate validator trait for TLS handshake verification. +/// +/// Implementations can perform additional validation on the peer certificate +/// during the TLS handshake, before any application data is sent. +pub trait CertValidator: Debug + Send + Sync + 'static { + /// Validate the peer certificate. + /// + /// Called after standard X.509 chain verification succeeds. + /// Return `Ok(())` to accept the certificate, or `Err` to reject. 
+ fn validate(&self, cert_der: &[u8]) -> Result<(), String>; +} + +/// TLS configuration for mTLS with optional custom certificate validation +#[derive(Clone)] +pub struct HttpsClientConfig { + pub cert_path: String, + pub key_path: String, + pub ca_cert_path: String, + /// Optional custom certificate validator (checked during TLS handshake) + pub cert_validator: Option>, +} + +/// Wrapper that adapts a CertValidator to rustls ServerCertVerifier +#[derive(Debug)] +struct CustomCertVerifier { + validator: Arc, + root_store: Arc, +} + +impl CustomCertVerifier { + fn new( + validator: Arc, + ca_cert_der: CertificateDer<'static>, + ) -> Result { + let mut root_store = rustls::RootCertStore::empty(); + root_store + .add(ca_cert_der) + .context("failed to add CA cert to root store")?; + Ok(Self { + validator, + root_store: Arc::new(root_store), + }) + } +} + +impl ServerCertVerifier for CustomCertVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + _ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // First, do standard certificate verification + let verifier = rustls::client::WebPkiServerVerifier::builder(self.root_store.clone()) + .build() + .map_err(|e| rustls::Error::General(format!("failed to build verifier: {e}")))?; + + verifier.verify_server_cert(end_entity, intermediates, server_name, &[], now)?; + + // Then run custom validation + self.validator + .validate(end_entity.as_ref()) + .map_err(rustls::Error::General)?; + + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: 
&DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +type HyperClient = Client, Full>; + +/// HTTPS client with mTLS and optional custom certificate validation. +/// +/// When a `cert_validator` is set in `TlsConfig`, the client runs the validator +/// during the TLS handshake, before any application data is sent. +#[derive(Clone)] +pub struct HttpsClient { + client: HyperClient, +} + +impl HttpsClient { + /// Create a new HTTPS client with mTLS configuration + pub fn new(tls: &HttpsClientConfig) -> Result { + // Load client certificate and key + let cert_pem = std::fs::read(&tls.cert_path) + .with_context(|| format!("failed to read TLS cert from {}", tls.cert_path))?; + let key_pem = std::fs::read(&tls.key_path) + .with_context(|| format!("failed to read TLS key from {}", tls.key_path))?; + + let certs: Vec> = CertificateDer::pem_slice_iter(&cert_pem) + .collect::>() + .context("failed to parse client certs")?; + + let key = PrivateKeyDer::from_pem_slice(&key_pem).context("failed to parse private key")?; + + // Load CA certificate + let ca_cert_pem = std::fs::read(&tls.ca_cert_path) + .with_context(|| format!("failed to read CA cert from {}", tls.ca_cert_path))?; + let ca_certs: Vec> = CertificateDer::pem_slice_iter(&ca_cert_pem) + .collect::>() + .context("failed to parse CA certs")?; + let ca_cert = ca_certs + .into_iter() + .next() + .context("no CA certificate found")?; + + // Build rustls config with custom verifier if validator is provided + let tls_config_builder = rustls::ClientConfig::builder(); + + let tls_config = if let Some(ref validator) = tls.cert_validator { + let verifier = CustomCertVerifier::new(validator.clone(), ca_cert)?; + tls_config_builder + 
.dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + } else { + // Standard verification without custom validator + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(ca_cert).context("failed to add CA cert")?; + tls_config_builder.with_root_certificates(root_store) + } + .with_client_auth_cert(certs, key) + .context("failed to set client auth cert")?; + + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http1() + .build(); + + let client = Client::builder(TokioExecutor::new()).build(https); + Ok(Self { client }) + } + + /// Send a POST request with JSON body and receive JSON response + pub async fn post_json( + &self, + url: &str, + body: &T, + ) -> Result { + let body = serde_json::to_vec(body).context("failed to serialize request body")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/json") + .body(Full::new(Bytes::from(body))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? 
+ .to_bytes(); + + serde_json::from_slice(&body).context("failed to parse response") + } + + /// Send a POST request with msgpack + gzip encoded body and receive msgpack + gzip response + pub async fn post_compressed_msg( + &self, + url: &str, + body: &T, + ) -> Result { + let encoded = encode(body).context("failed to encode request body")?; + + // Compress with gzip + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder + .write_all(&encoded) + .context("failed to compress request")?; + let compressed = encoder.finish().context("failed to finish compression")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/x-msgpack-gz") + .body(Full::new(Bytes::from(compressed))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? + .to_bytes(); + + // Decompress + let mut decoder = GzDecoder::new(body.as_ref()); + let mut decompressed = Vec::new(); + decoder + .read_to_end(&mut decompressed) + .context("failed to decompress response")?; + + decode(&decompressed).context("failed to decode response") + } +} + +// ============================================================================ +// Built-in validators +// ============================================================================ + +/// Validator that checks the peer certificate contains a specific app_id. 
+#[derive(Debug)] +pub struct AppIdValidator { + expected_app_id: Vec, +} + +impl AppIdValidator { + pub fn new(expected_app_id: Vec) -> Self { + Self { expected_app_id } + } +} + +impl CertValidator for AppIdValidator { + fn validate(&self, cert_der: &[u8]) -> Result<(), String> { + use ra_tls::traits::CertExt; + + let (_, cert) = x509_parser::parse_x509_certificate(cert_der) + .map_err(|e| format!("failed to parse certificate: {e}"))?; + + let peer_app_id = cert + .get_app_id() + .map_err(|e| format!("failed to get app_id: {e}"))?; + + let Some(peer_app_id) = peer_app_id else { + return Err("peer certificate does not contain app_id".into()); + }; + + if peer_app_id != self.expected_app_id { + return Err(format!( + "app_id mismatch: expected {}, got {}", + hex::encode(&self.expected_app_id), + hex::encode(&peer_app_id) + )); + } + + Ok(()) + } +} diff --git a/gateway/src/kv/mod.rs b/gateway/src/kv/mod.rs new file mode 100644 index 00000000..80bf168d --- /dev/null +++ b/gateway/src/kv/mod.rs @@ -0,0 +1,987 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV-based sync layer for dstack-gateway. +//! +//! This module provides synchronization between gateway nodes. The local ProxyState +//! remains the primary data store for fast reads, while WaveKV handles cross-node sync. +//! +//! Key schema: +//! +//! # Persistent WaveKV (needs persistence + sync) +//! - `inst/{instance_id}` → InstanceData +//! - `node/{node_id}` → NodeData +//! - `dns_cred/{cred_id}` → DnsCredential +//! - `dns_cred_default` → cred_id (default credential ID) +//! - `global/certbot_config` → GlobalCertbotConfig +//! - `cert/{domain}/config` → ZtDomainConfig +//! - `cert/{domain}/data` → CertData +//! - `global/acme_credentials` → CertCredentials (shared ACME account) +//! - `global/acme_attestation` → AcmeAttestation (TDX quote of ACME account URI) +//! - `cert/{domain}/lock` → CertRenewLock +//! 
- `cert/{domain}/attestation/latest` → CertAttestation +//! - `cert/{domain}/attestation/{timestamp}` → CertAttestation (history) +//! +//! # Ephemeral WaveKV (no persistence, sync only) +//! - `conn/{instance_id}/{node_id}` → u64 (connection count) +//! - `last_seen/inst/{instance_id}` → u64 (timestamp) +//! - `last_seen/node/{node_id}/{seen_by_node_id}` → u64 (timestamp) + +mod https_client; +mod sync_service; + +pub use https_client::{AppIdValidator, HttpsClientConfig}; +pub use sync_service::{fetch_peers_from_bootnode, WaveKvSyncService}; +use tracing::warn; + +use std::{collections::BTreeMap, net::Ipv4Addr, path::Path, time::Duration}; + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use tokio::sync::watch; +use wavekv::{node::NodeState, types::NodeId, Node}; + +/// Instance core data (persistent) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct InstanceData { + pub app_id: String, + pub ip: Ipv4Addr, + pub public_key: String, + pub reg_time: u64, +} + +/// Gateway node status (stored separately for independent updates) +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum NodeStatus { + #[default] + Up, + Down, +} + +/// Gateway node data (persistent, rarely changes) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct NodeData { + pub uuid: Vec, + pub url: String, + pub wg_public_key: String, + pub wg_endpoint: String, + pub wg_ip: String, +} + +/// Certificate credentials (ACME account) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertCredentials { + pub acme_credentials: String, +} + +/// ACME account attestation (TDX Quote of account URI) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct AcmeAttestation { + /// ACME account URI + pub account_uri: String, + /// TDX Quote (JSON serialized) + #[serde(default)] + pub quote: String, + /// Full attestation (JSON serialized) + 
#[serde(default)] + pub attestation: String, + /// Node that generated this attestation + #[serde(default)] + pub generated_by: NodeId, + /// Timestamp when this attestation was generated + #[serde(default)] + pub generated_at: u64, +} + +/// Certificate data (cert + key) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertData { + pub cert_pem: String, + pub key_pem: String, + pub not_after: u64, + pub issued_by: NodeId, + pub issued_at: u64, +} + +/// Certificate renew lock +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertRenewLock { + pub started_at: u64, + pub started_by: NodeId, +} + +/// Certificate attestation (TDX Quote of certificate public key) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CertAttestation { + /// Certificate public key (DER encoded) + pub public_key: Vec, + /// TDX Quote (JSON serialized) + #[serde(default)] + pub quote: String, + /// Full attestation (JSON serialized) + #[serde(default)] + pub attestation: String, + /// Node that generated this attestation + #[serde(default)] + pub generated_by: NodeId, + /// Timestamp when this attestation was generated + #[serde(default)] + pub generated_at: u64, +} + +/// DNS credential configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DnsCredential { + /// Unique identifier + pub id: String, + /// Display name + pub name: String, + /// DNS provider configuration + pub provider: DnsProvider, + /// Maximum DNS wait time + #[serde(with = "serde_duration")] + pub max_dns_wait: Duration, + /// DNS TXT record TTL + pub dns_txt_ttl: u32, + /// Creation timestamp + pub created_at: u64, + /// Last update timestamp + pub updated_at: u64, +} + +/// DNS provider configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DnsProvider { + Cloudflare { + api_token: String, + /// Cloudflare API URL (defaults to https://api.cloudflare.com/client/v4 if not set) + 
#[serde(default, skip_serializing_if = "Option::is_none")] + api_url: Option, + }, + // Future providers can be added here +} + +/// ZT-Domain configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ZtDomainConfig { + /// Base domain name (e.g., "app.example.com") + /// Certificate will be issued for "*.{domain}" automatically + pub domain: String, + /// DNS credential ID to use (None = use default) + pub dns_cred_id: Option, + /// Port this domain serves on (e.g., 443) + #[serde(default)] + pub port: u16, + /// Node binding (None = any node can serve this domain) + /// If set, only this node will serve this domain + #[serde(default)] + pub node: Option, + /// Priority for default base_domain selection (higher = preferred) + /// The domain with highest priority is returned as the default base_domain in APIs + #[serde(default)] + pub priority: i32, +} + +/// Global certbot configuration (stored in KV, synced across nodes) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GlobalCertbotConfig { + /// Interval between renewal checks + #[serde(with = "serde_duration")] + pub renew_interval: Duration, + /// Time before expiration to trigger renewal (e.g., 30 days) + #[serde(with = "serde_duration")] + pub renew_before_expiration: Duration, + /// Timeout for certificate renewal operations + #[serde(with = "serde_duration")] + pub renew_timeout: Duration, + /// ACME server URL (None means use default Let's Encrypt production) + pub acme_url: String, +} + +impl Default for GlobalCertbotConfig { + fn default() -> Self { + Self { + renew_interval: Duration::from_secs(12 * 3600), // 12 hours + renew_before_expiration: Duration::from_secs(30 * 86400), // 30 days + renew_timeout: Duration::from_secs(300), // 5 minutes + acme_url: Default::default(), // default Let's Encrypt + } + } +} + +// Key prefixes and builders +pub mod keys { + use super::NodeId; + + pub const INST_PREFIX: &str = "inst/"; + pub const NODE_PREFIX: &str = "node/"; + pub const 
NODE_INFO_PREFIX: &str = "node/info/"; + pub const NODE_STATUS_PREFIX: &str = "node/status/"; + pub const CONN_PREFIX: &str = "conn/"; + pub const HANDSHAKE_PREFIX: &str = "handshake/"; + pub const LAST_SEEN_NODE_PREFIX: &str = "last_seen/node/"; + pub const PEER_ADDR_PREFIX: &str = "__peer_addr/"; + pub const CERT_PREFIX: &str = "cert/"; + pub const DNS_CRED_PREFIX: &str = "dns_cred/"; + pub const DNS_CRED_DEFAULT: &str = "dns_cred_default"; + pub const GLOBAL_CERTBOT_CONFIG: &str = "global/certbot_config"; + pub const GLOBAL_ACME_CREDENTIALS: &str = "global/acme_credentials"; + pub const GLOBAL_ACME_ATTESTATION: &str = "global/acme_attestation"; + + pub fn inst(instance_id: &str) -> String { + format!("{INST_PREFIX}{instance_id}") + } + + pub fn node_info(node_id: NodeId) -> String { + format!("{NODE_INFO_PREFIX}{node_id}") + } + + pub fn node_status(node_id: NodeId) -> String { + format!("{NODE_STATUS_PREFIX}{node_id}") + } + + pub fn conn(instance_id: &str, node_id: NodeId) -> String { + format!("{CONN_PREFIX}{instance_id}/{node_id}") + } + + /// Key for instance handshake timestamp observed by a specific node + /// Format: handshake/{instance_id}/{observer_node_id} + pub fn handshake(instance_id: &str, observer_node_id: NodeId) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/{observer_node_id}") + } + + /// Prefix to iterate all handshake observations for an instance + pub fn handshake_prefix(instance_id: &str) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/") + } + + pub fn last_seen_node(node_id: NodeId, seen_by: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/{seen_by}") + } + + pub fn last_seen_node_prefix(node_id: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/") + } + + pub fn peer_addr(node_id: NodeId) -> String { + format!("{PEER_ADDR_PREFIX}{node_id}") + } + + // ==================== DNS Credential keys ==================== + + /// Key for a DNS credential + pub fn dns_cred(cred_id: &str) -> String { + 
format!("{DNS_CRED_PREFIX}{cred_id}") + } + + // ==================== Certificate keys (per domain) ==================== + + /// Key for ZT-Domain configuration + pub fn zt_domain_config(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/config") + } + + /// Key for domain certificate data (cert + key) + pub fn cert_data(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/data") + } + + /// Key for domain certificate renew lock + pub fn cert_lock(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/lock") + } + + /// Key for latest attestation of a domain + pub fn cert_attestation_latest(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/attestation/latest") + } + + /// Key for historical attestation of a domain + pub fn cert_attestation_history(domain: &str, timestamp: u64) -> String { + format!("{CERT_PREFIX}{domain}/attestation/{timestamp}") + } + + /// Prefix for all attestations of a domain (for iteration) + pub fn cert_attestation_prefix(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/attestation/") + } + + /// Parse domain from cert/{domain}/... 
key + pub fn parse_cert_domain(key: &str) -> Option<&str> { + let rest = key.strip_prefix(CERT_PREFIX)?; + rest.split('/').next() + } + + // ==================== Parse helpers ==================== + + /// Parse instance_id from key + pub fn parse_inst_key(key: &str) -> Option<&str> { + key.strip_prefix(INST_PREFIX) + } + + /// Parse node_id from node/info/{node_id} key + pub fn parse_node_info_key(key: &str) -> Option { + key.strip_prefix(NODE_INFO_PREFIX)?.parse().ok() + } +} + +pub fn encode(value: &T) -> Result> { + rmp_serde::encode::to_vec(value).context("failed to encode value") +} + +pub fn decode Deserialize<'de>>(bytes: &[u8]) -> Result { + rmp_serde::decode::from_slice(bytes).context("failed to decode value") +} + +trait GetPutCodec { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option; + fn put_encoded(&mut self, key: String, value: &T) -> Result<()>; + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; +} + +impl GetPutCodec for NodeState { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option { + self.get(key) + .and_then(|entry| match decode(entry.value.as_ref()?) { + Ok(value) => Some(value), + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + None + } + }) + } + + fn put_encoded(&mut self, key: String, value: &T) -> Result<()> { + self.put(key.clone(), encode(value)?) + .with_context(|| format!("failed to put key {key}"))?; + Ok(()) + } + + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) 
{ + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some((key.to_string(), value)) + }) + } + + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) { + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some(value) + }) + } +} + +/// Sync store wrapping two WaveKV Nodes (persistent and ephemeral). +/// +/// This is the sync layer - not the primary data store. +/// ProxyState remains in memory for fast reads. +#[derive(Clone)] +pub struct KvStore { + /// Persistent WaveKV Node (with WAL) + persistent: Node, + /// Ephemeral WaveKV Node (in-memory only) + ephemeral: Node, + /// This gateway's node ID + my_node_id: NodeId, +} + +impl KvStore { + /// Create a new sync store + pub fn new( + my_node_id: NodeId, + peer_ids: Vec, + data_dir: impl AsRef, + ) -> Result { + let persistent = + Node::new_with_persistence(my_node_id, peer_ids.clone(), data_dir.as_ref()) + .context("failed to create persistent wavekv node")?; + + let ephemeral = Node::new(my_node_id, peer_ids); + + Ok(Self { + persistent, + ephemeral, + my_node_id, + }) + } + + pub fn my_node_id(&self) -> NodeId { + self.my_node_id + } + + pub fn persistent(&self) -> &Node { + &self.persistent + } + + pub fn ephemeral(&self) -> &Node { + &self.ephemeral + } + + // ==================== Instance Sync ==================== + + /// Sync instance data to other nodes + pub fn sync_instance(&self, instance_id: &str, data: &InstanceData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::inst(instance_id), data) + } + + /// Sync instance deletion to other nodes + pub fn sync_delete_instance(&self, instance_id: &str) -> Result<()> { + self.persistent.write().delete(keys::inst(instance_id))?; + self.ephemeral + .write() + 
.delete(keys::conn(instance_id, self.my_node_id))?; + // Delete this node's handshake record + self.ephemeral + .write() + .delete(keys::handshake(instance_id, self.my_node_id))?; + Ok(()) + } + + /// Load all instances from sync store (for initial sync on startup) + pub fn load_all_instances(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::INST_PREFIX) + .filter_map(|(key, data)| { + let instance_id = keys::parse_inst_key(&key)?; + Some((instance_id.into(), data)) + }) + .collect() + } + + // ==================== Node Sync ==================== + + /// Sync node data to other nodes + pub fn sync_node(&self, node_id: NodeId, data: &NodeData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_info(node_id), data) + } + + /// Load all nodes from sync store + pub fn load_all_nodes(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::NODE_INFO_PREFIX) + .filter_map(|(key, data)| { + let node_id = keys::parse_node_info_key(&key)?; + Some((node_id, data)) + }) + .collect() + } + + // ==================== Node Status Sync ==================== + + /// Set node status (stored separately from NodeData) + pub fn set_node_status(&self, node_id: NodeId, status: NodeStatus) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_status(node_id), &status)?; + Ok(()) + } + + /// Get node status + pub fn get_node_status(&self, node_id: NodeId) -> NodeStatus { + self.persistent + .read() + .decode(&keys::node_status(node_id)) + .unwrap_or_default() + } + + /// Load all node statuses + pub fn load_all_node_statuses(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::NODE_STATUS_PREFIX) + .filter_map(|(key, status)| { + let node_id: NodeId = key.strip_prefix(keys::NODE_STATUS_PREFIX)?.parse().ok()?; + Some((node_id, status)) + }) + .collect() + } + + // ==================== Connection Count Sync ==================== + + /// Sync connection count for an instance (from this node) + pub fn 
sync_connections(&self, instance_id: &str, count: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::conn(instance_id, self.my_node_id), &count)?; + Ok(()) + } + + // ==================== Handshake Sync ==================== + + /// Sync handshake timestamp for an instance (as observed by this node) + pub fn sync_instance_handshake(&self, instance_id: &str, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::handshake(instance_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all handshake observations for an instance (from all nodes) + pub fn get_instance_handshakes(&self, instance_id: &str) -> BTreeMap { + self.ephemeral + .read() + .iter_decoded(&keys::handshake_prefix(instance_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::handshake_prefix(instance_id))?; + let observer: NodeId = suffix.parse().ok()?; + Some((observer, ts)) + }) + .collect() + } + + /// Get the latest handshake timestamp for an instance (max across all nodes) + pub fn get_instance_latest_handshake(&self, instance_id: &str) -> Option { + self.ephemeral + .read() + .iter_decoded_values(&keys::handshake_prefix(instance_id)) + .max() + } + + /// Sync node last_seen (as observed by this node) + pub fn sync_node_last_seen(&self, node_id: NodeId, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::last_seen_node(node_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all observations of a node's last_seen + pub fn get_node_last_seen_by_all(&self, node_id: NodeId) -> BTreeMap { + self.ephemeral + .read() + .iter_decoded(&keys::last_seen_node_prefix(node_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::last_seen_node_prefix(node_id))?; + let seen_by: NodeId = suffix.parse().ok()?; + Some((seen_by, ts)) + }) + .collect() + } + + /// Get the latest last_seen timestamp for a node (max across all observers) + pub fn get_node_latest_last_seen(&self, node_id: NodeId) -> Option 
{ + self.ephemeral + .read() + .iter_decoded_values(&keys::last_seen_node_prefix(node_id)) + .max() + } + + // ==================== Watch for Remote Changes ==================== + + /// Watch for remote instance changes (for updating local ProxyState) + pub fn watch_instances(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::INST_PREFIX) + } + + /// Watch for remote node changes + pub fn watch_nodes(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::NODE_PREFIX) + } + + // ==================== Persistence ==================== + + pub fn persist_if_dirty(&self) -> Result { + self.persistent.persist_if_dirty() + } + + // ==================== Peer Management ==================== + + pub fn add_peer(&self, peer_id: NodeId) -> Result<()> { + self.persistent.write().add_peer(peer_id)?; + self.ephemeral.write().add_peer(peer_id)?; + Ok(()) + } + + // ==================== Peer Address (in DB) ==================== + + /// Register a node's sync URL in DB and add to peer list for sync + /// + /// This stores the URL in KvStore (for address lookup) and also adds the node + /// to the wavekv peer list (so SyncManager knows to sync with it). 
+ pub fn register_peer_url(&self, node_id: NodeId, url: &str) -> Result<()> { + // Store URL in persistent KvStore + self.persistent + .write() + .put_encoded(keys::peer_addr(node_id), &url)?; + + let _ = self.add_peer(node_id); + Ok(()) + } + + /// Get a peer's sync URL from DB + pub fn get_peer_url(&self, node_id: NodeId) -> Option { + self.persistent.read().decode(&keys::peer_addr(node_id)) + } + + /// Query the UUID for a given node ID from KvStore + pub fn get_peer_uuid(&self, peer_id: NodeId) -> Option> { + let node_data: NodeData = self.persistent.read().decode(&keys::node_info(peer_id))?; + Some(node_data.uuid) + } + + pub fn update_peer_last_seen(&self, peer_id: NodeId) { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + let key = keys::last_seen_node(peer_id, self.my_node_id); + if let Err(e) = self.ephemeral.write().put_encoded(key, &ts) { + warn!("failed to update peer {peer_id} last_seen: {e}"); + } + } + + /// Get all peer addresses from DB (for debugging/testing) + pub fn get_all_peer_addrs(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::PEER_ADDR_PREFIX) + .filter_map(|(key, url)| { + let node_id: NodeId = key.strip_prefix(keys::PEER_ADDR_PREFIX)?.parse().ok()?; + Some((node_id, url)) + }) + .collect() + } + + // ==================== DNS Credential Management ==================== + + /// Get a DNS credential by ID + pub fn get_dns_credential(&self, cred_id: &str) -> Option { + self.persistent.read().decode(&keys::dns_cred(cred_id)) + } + + /// Save a DNS credential + pub fn save_dns_credential(&self, cred: &DnsCredential) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::dns_cred(&cred.id), cred)?; + Ok(()) + } + + /// Delete a DNS credential + pub fn delete_dns_credential(&self, cred_id: &str) -> Result<()> { + self.persistent.write().delete(keys::dns_cred(cred_id))?; + Ok(()) + } + + /// List all DNS credentials + pub fn 
list_dns_credentials(&self) -> Vec { + self.persistent + .read() + .iter_decoded_values(keys::DNS_CRED_PREFIX) + .collect() + } + + /// Get the default DNS credential ID + pub fn get_default_dns_credential_id(&self) -> Option { + self.persistent.read().decode(keys::DNS_CRED_DEFAULT) + } + + /// Set the default DNS credential ID + pub fn set_default_dns_credential_id(&self, cred_id: &str) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::DNS_CRED_DEFAULT.to_string(), &cred_id)?; + Ok(()) + } + + /// Get the default DNS credential (resolves the ID to the actual credential) + pub fn get_default_dns_credential(&self) -> Option { + let cred_id = self.get_default_dns_credential_id()?; + self.get_dns_credential(&cred_id) + } + + // ==================== Global Certbot Config ==================== + + /// Get global certbot configuration (returns default if not set) + pub fn get_certbot_config(&self) -> GlobalCertbotConfig { + self.persistent + .read() + .decode(keys::GLOBAL_CERTBOT_CONFIG) + .unwrap_or_default() + } + + /// Set global certbot configuration + pub fn set_certbot_config(&self, config: &GlobalCertbotConfig) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::GLOBAL_CERTBOT_CONFIG.to_string(), config)?; + Ok(()) + } + + // ==================== ZT-Domain Config ==================== + + /// Get ZT-Domain configuration + pub fn get_zt_domain_config(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::zt_domain_config(domain)) + } + + /// Save ZT-Domain configuration + pub fn save_zt_domain_config(&self, config: &ZtDomainConfig) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::zt_domain_config(&config.domain), config)?; + Ok(()) + } + + /// Delete ZT-Domain configuration + pub fn delete_zt_domain_config(&self, domain: &str) -> Result<()> { + self.persistent + .write() + .delete(keys::zt_domain_config(domain))?; + Ok(()) + } + + /// List all ZT-Domain configurations + pub fn 
list_zt_domain_configs(&self) -> Vec { + let state = self.persistent.read(); + state + .iter_by_prefix(keys::CERT_PREFIX) + .filter_map(|(key, entry)| { + // Only decode config entries (not data/acme/lock/attestation) + if !key.ends_with("/config") { + return None; + } + let value = entry.value.as_ref()?; + match decode(value) { + Ok(config) => Some(config), + Err(e) => { + warn!("failed to decode cert config for key {key}: {e:?}"); + None + } + } + }) + .collect() + } + + /// Watch for ZT-Domain config changes + pub fn watch_zt_domain_configs(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::CERT_PREFIX) + } + + /// Get the best ZT-Domain config for this node. + /// + /// Selection rules: + /// 1. Only considers domains where node == None or node == my_node_id + /// 2. Higher priority wins + /// 3. If priority is equal, node == None wins (global domains preferred over node-specific) + /// + /// Returns (domain, port) of the best match, or None if no domains configured. 
+ pub fn get_best_zt_domain(&self) -> Option<(String, u16)> { + let my_node_id = self.my_node_id; + let configs = self.list_zt_domain_configs(); + + configs + .into_iter() + .filter(|c| c.node.is_none() || c.node == Some(my_node_id)) + .max_by(|a, b| { + // Compare by priority first (higher wins) + match a.priority.cmp(&b.priority) { + std::cmp::Ordering::Equal => { + // If priority equal, None (global) wins over Some (node-specific) + // None < Some in Option ordering, so we reverse + b.node.cmp(&a.node) + } + other => other, + } + }) + .map(|c| (c.domain, c.port)) + } + + // ==================== Certificate Data ==================== + + /// Get certificate data for a domain + pub fn get_cert_data(&self, domain: &str) -> Option { + self.persistent.read().decode(&keys::cert_data(domain)) + } + + /// Save certificate data for a domain + pub fn save_cert_data(&self, domain: &str, data: &CertData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::cert_data(domain), data)?; + Ok(()) + } + + /// Load all certificate data (for startup) + pub fn load_all_cert_data(&self) -> BTreeMap { + let state = self.persistent.read(); + state + .iter_by_prefix(keys::CERT_PREFIX) + .filter_map(|(key, entry)| { + // Only decode data entries (not config/acme/lock/attestation) + if !key.ends_with("/data") { + return None; + } + let domain = keys::parse_cert_domain(key)?; + let value = entry.value.as_ref()?; + match decode(value) { + Ok(data) => Some((domain.to_string(), data)), + Err(e) => { + warn!("failed to decode cert data for key {key}: {e:?}"); + None + } + } + }) + .collect() + } + + // ==================== Global ACME Credentials ==================== + + /// Get global ACME credentials (shared across all domains) + pub fn get_acme_credentials(&self) -> Option { + self.persistent.read().decode(keys::GLOBAL_ACME_CREDENTIALS) + } + + /// Save global ACME credentials + pub fn save_acme_credentials(&self, creds: &CertCredentials) -> Result<()> { + self.persistent + 
.write() + .put_encoded(keys::GLOBAL_ACME_CREDENTIALS.to_string(), creds)?; + Ok(()) + } + + /// Get global ACME attestation (TDX quote of account URI) + pub fn get_acme_attestation(&self) -> Option { + self.persistent.read().decode(keys::GLOBAL_ACME_ATTESTATION) + } + + /// Save global ACME attestation + pub fn save_acme_attestation(&self, attestation: &AcmeAttestation) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::GLOBAL_ACME_ATTESTATION.to_string(), attestation)?; + Ok(()) + } + + // ==================== Certificate Renew Lock ==================== + + /// Get certificate renew lock for a domain + pub fn get_cert_lock(&self, domain: &str) -> Option { + self.persistent.read().decode(&keys::cert_lock(domain)) + } + + /// Try to acquire certificate renew lock + /// Returns true if lock acquired, false if already locked by another node + pub fn try_acquire_cert_lock(&self, domain: &str, lock_timeout_secs: u64) -> bool { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if let Some(existing) = self.get_cert_lock(domain) { + // Check if lock is still valid (not expired) + if now < existing.started_at + lock_timeout_secs { + return false; + } + } + + // Acquire the lock + let lock = CertRenewLock { + started_at: now, + started_by: self.my_node_id, + }; + self.persistent + .write() + .put_encoded(keys::cert_lock(domain), &lock) + .is_ok() + } + + /// Release certificate renew lock + pub fn release_cert_lock(&self, domain: &str) -> Result<()> { + self.persistent.write().delete(keys::cert_lock(domain))?; + Ok(()) + } + + // ==================== Certificate Attestation ==================== + + /// Get the latest attestation for a domain + pub fn get_cert_attestation_latest(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::cert_attestation_latest(domain)) + } + + /// Save attestation for a domain (saves both latest and history) + pub fn 
save_cert_attestation(&self, domain: &str, attestation: &CertAttestation) -> Result<()> { + let mut state = self.persistent.write(); + // Save to history + state.put_encoded( + keys::cert_attestation_history(domain, attestation.generated_at), + attestation, + )?; + // Update latest + state.put_encoded(keys::cert_attestation_latest(domain), attestation)?; + Ok(()) + } + + /// List all attestation history for a domain (sorted by timestamp descending) + pub fn list_cert_attestations(&self, domain: &str) -> Vec { + let prefix = keys::cert_attestation_prefix(domain); + let latest_key = keys::cert_attestation_latest(domain); + let state = self.persistent.read(); + let mut attestations: Vec = state + .iter_by_prefix(&prefix) + .filter_map(|(key, entry)| { + // Skip the "latest" entry + if key == &latest_key { + return None; + } + let value = entry.value.as_ref()?; + match decode(value) { + Ok(att) => Some(att), + Err(e) => { + warn!("failed to decode attestation for key {key}: {e:?}"); + None + } + } + }) + .collect(); + // Sort by generated_at descending (newest first) + attestations.sort_by(|a, b| b.generated_at.cmp(&a.generated_at)); + attestations + } + + // ==================== Watch helpers ==================== + + /// Watch for certificate data changes (any domain) + pub fn watch_all_certs(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::CERT_PREFIX) + } +} diff --git a/gateway/src/kv/sync_service.rs b/gateway/src/kv/sync_service.rs new file mode 100644 index 00000000..f691595a --- /dev/null +++ b/gateway/src/kv/sync_service.rs @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync service - implements network transport for wavekv synchronization. +//! +//! Peer URLs are stored in the persistent KV store under `__peer_addr/{node_id}` keys. +//! This allows peer addresses to be automatically synced across nodes. 
+ +use std::sync::Arc; + +use anyhow::{Context, Result}; +use dstack_gateway_rpc::GetPeersResponse; +use tracing::{info, warn}; +use wavekv::{ + sync::{ExchangeInterface, SyncConfig as KvSyncConfig, SyncManager, SyncMessage, SyncResponse}, + types::NodeId, + Node, +}; + +use crate::config::SyncConfig as GwSyncConfig; + +use super::https_client::{HttpsClient, HttpsClientConfig}; +use super::KvStore; + +/// HTTP-based network transport for WaveKV sync. +/// Holds a reference to the persistent node for reading peer URLs. +#[derive(Clone)] +pub struct HttpSyncNetwork { + client: HttpsClient, + /// Reference to persistent node for reading peer URLs + kv_store: KvStore, + /// This node's UUID (for node ID reuse detection) + my_uuid: Vec, + /// URL path suffix for this store (e.g., "persistent" or "ephemeral") + store_path: &'static str, +} + +impl HttpSyncNetwork { + pub fn new( + kv_store: KvStore, + store_path: &'static str, + tls_config: &HttpsClientConfig, + ) -> Result { + let client = HttpsClient::new(tls_config)?; + let my_uuid = kv_store + .get_peer_uuid(kv_store.my_node_id) + .context("failed to get my UUID")?; + Ok(Self { + client, + kv_store, + my_uuid, + store_path, + }) + } + + /// Get peer URL from persistent node + fn get_peer_url(&self, peer_id: NodeId) -> Option { + self.kv_store.get_peer_url(peer_id) + } +} + +impl ExchangeInterface for HttpSyncNetwork { + fn uuid(&self) -> Vec { + self.my_uuid.clone() + } + + fn query_uuid(&self, node_id: NodeId) -> Option> { + self.kv_store.get_peer_uuid(node_id) + } + + async fn sync_to(&self, _node: &Node, peer: NodeId, msg: SyncMessage) -> Result { + let url = self + .get_peer_url(peer) + .ok_or_else(|| anyhow::anyhow!("peer {} address not found in DB", peer))?; + + let sync_url = format!( + "{}/wavekv/sync/{}", + url.trim_end_matches('/'), + self.store_path + ); + + // Send request with msgpack + gzip encoding + // app_id verification happens during TLS handshake via AppIdVerifier + let sync_response: SyncResponse 
= self + .client + .post_compressed_msg(&sync_url, &msg) + .await + .with_context(|| format!("failed to sync to peer {peer} at {sync_url}"))?; + + // Update peer last_seen on successful sync + self.kv_store.update_peer_last_seen(peer); + + Ok(sync_response) + } +} + +/// WaveKV sync service that manages synchronization for both persistent and ephemeral stores +pub struct WaveKvSyncService { + pub persistent_manager: Arc>, + pub ephemeral_manager: Arc>, +} + +impl WaveKvSyncService { + /// Create a new WaveKV sync service + /// + /// # Arguments + /// * `kv_store` - The sync store containing persistent and ephemeral nodes + /// * `sync_config` - Sync configuration + /// * `tls_config` - TLS configuration for mTLS peer authentication + pub fn new( + kv_store: &KvStore, + sync_config: &GwSyncConfig, + tls_config: HttpsClientConfig, + ) -> Result { + let sync_config = KvSyncConfig { + interval: sync_config.interval, + timeout: sync_config.timeout, + }; + + // Both networks use the same persistent node for URL lookup, but different paths + let persistent_network = HttpSyncNetwork::new(kv_store.clone(), "persistent", &tls_config)?; + let ephemeral_network = HttpSyncNetwork::new(kv_store.clone(), "ephemeral", &tls_config)?; + + let persistent_manager = Arc::new(SyncManager::with_config( + kv_store.persistent().clone(), + persistent_network, + sync_config.clone(), + )); + let ephemeral_manager = Arc::new(SyncManager::with_config( + kv_store.ephemeral().clone(), + ephemeral_network, + sync_config, + )); + + Ok(Self { + persistent_manager, + ephemeral_manager, + }) + } + + /// Bootstrap from peers + pub async fn bootstrap(&self) -> Result<()> { + info!("bootstrapping persistent store..."); + if let Err(e) = self.persistent_manager.bootstrap().await { + warn!("failed to bootstrap persistent store: {e}"); + } + + info!("bootstrapping ephemeral store..."); + if let Err(e) = self.ephemeral_manager.bootstrap().await { + warn!("failed to bootstrap ephemeral store: {e}"); + } + + 
Ok(()) + } + + /// Start background sync tasks + pub async fn start_sync_tasks(&self) { + let persistent = self.persistent_manager.clone(); + let ephemeral = self.ephemeral_manager.clone(); + + tokio::join!(persistent.start_sync_tasks(), ephemeral.start_sync_tasks(),); + + info!("WaveKV sync tasks started"); + } + + /// Handle incoming sync request for persistent store + pub fn handle_persistent_sync(&self, msg: SyncMessage) -> Result { + self.persistent_manager.handle_sync(msg) + } + + /// Handle incoming sync request for ephemeral store + pub fn handle_ephemeral_sync(&self, msg: SyncMessage) -> Result { + self.ephemeral_manager.handle_sync(msg) + } +} + +/// Fetch peer list from bootnode and register them in KvStore. +/// +/// This is called during startup to bootstrap the peer list from a known bootnode. +/// Uses Gateway.GetPeers RPC which requires mTLS gateway authentication. +pub async fn fetch_peers_from_bootnode( + bootnode_url: &str, + kv_store: &KvStore, + my_node_id: NodeId, + tls_config: &HttpsClientConfig, +) -> Result<()> { + if bootnode_url.is_empty() { + info!("no bootnode configured, skipping peer fetch"); + return Ok(()); + } + + info!("fetching peers from bootnode: {}", bootnode_url); + + // Create HTTPS client for bootnode communication (with mTLS) + let client = HttpsClient::new(tls_config).context("failed to create HTTPS client")?; + + // Call Gateway.GetPeers RPC on bootnode (requires mTLS gateway auth) + let peers_url = format!("{}/prpc/GetPeers", bootnode_url.trim_end_matches('/')); + + let response: GetPeersResponse = client + .post_json(&peers_url, &()) + .await + .with_context(|| format!("failed to fetch peers from bootnode {bootnode_url}"))?; + + info!( + "bootnode returned {} peers (bootnode_id={})", + response.peers.len(), + response.my_id + ); + + // Register each peer + for peer in &response.peers { + if peer.id == my_node_id { + continue; // Skip self + } + + // Add peer to WaveKV + if let Err(e) = kv_store.add_peer(peer.id) { + 
warn!("failed to add peer {}: {}", peer.id, e); + continue; + } + + // Register peer URL + if !peer.url.is_empty() { + if let Err(e) = kv_store.register_peer_url(peer.id, &peer.url) { + warn!("failed to register peer URL for node {}: {}", peer.id, e); + } else { + info!( + "registered peer from bootnode: node {} -> {}", + peer.id, peer.url + ); + } + } + } + + Ok(()) +} diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 5d86e84f..1c6bebba 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -7,7 +7,7 @@ use clap::Parser; use config::{Config, TlsConfig}; use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, GetTlsKeyArgs}; use http_client::prpc::PrpcClient; -use ra_rpc::{client::RaClient, rocket_helper::QuoteVerifier}; +use ra_rpc::{client::RaClient, prpc_routes as prpc, rocket_helper::QuoteVerifier}; use rocket::{ fairing::AdHoc, figment::{providers::Serialized, Figment}, @@ -15,10 +15,16 @@ use rocket::{ use tracing::info; use admin_service::AdminRpcHandler; -use main_service::{Proxy, RpcHandler}; +use main_service::{Proxy, ProxyOptions, RpcHandler}; + +use crate::debug_service::DebugRpcHandler; mod admin_service; +mod cert_store; mod config; +mod debug_service; +mod distributed_certbot; +mod kv; mod main_service; mod models; mod proxy; @@ -67,7 +73,7 @@ async fn maybe_gen_certs(config: &Config, tls_config: &TlsConfig) -> Result<()> return Ok(()); } - if config.run_in_dstack { + if !config.debug.insecure_skip_attestation { info!("Using dstack guest agent for certificate generation"); let agent_client = dstack_agent().context("Failed to create dstack client")?; let response = agent_client @@ -129,7 +135,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let _ = 
rustls::crypto::ring::default_provider().install_default(); @@ -138,9 +144,18 @@ async fn main() -> Result<()> { let figment = config::load_config_figment(args.config.as_deref()); let config = figment.focus("core").extract::()?; + + // Validate node_id + if config.sync.enabled && config.sync.node_id == 0 { + anyhow::bail!("node_id must be greater than 0"); + } + config::setup_wireguard(&config.wg)?; - let tls_config = figment.focus("tls").extract::()?; + let tls_config = figment + .focus("tls") + .extract::() + .context("Failed to extract tls config")?; maybe_gen_certs(&config, &tls_config) .await .context("Failed to generate certs")?; @@ -150,40 +165,51 @@ async fn main() -> Result<()> { set_max_ulimit()?; } - let my_app_id = if config.run_in_dstack { + let my_app_id = if config.debug.insecure_skip_attestation { + None + } else { let dstack_client = dstack_agent().context("Failed to create dstack client")?; let info = dstack_client .info() .await .context("Failed to get app info")?; Some(info.app_id) - } else { - None }; let proxy_config = config.proxy.clone(); let pccs_url = config.pccs_url.clone(); let admin_enabled = config.admin.enabled; - let state = main_service::Proxy::new(config, my_app_id).await?; + let debug_config = config.debug.clone(); + let state = Proxy::new(ProxyOptions { + config, + my_app_id, + tls_config, + }) + .await?; info!("Starting background tasks"); state.start_bg_tasks().await?; state.lock().reconfigure()?; proxy::start(proxy_config, state.clone()).context("failed to start the proxy")?; - let admin_figment = - Figment::new() - .merge(rocket::Config::default()) - .merge(Serialized::defaults( - figment - .find_value("core.admin") - .context("admin section not found")?, - )); + let admin_value = figment + .find_value("core.admin") + .context("admin section not found")?; + let debug_value = figment + .find_value("core.debug") + .context("debug section not found")?; + + let admin_figment = Figment::new() + .merge(rocket::Config::default()) + 
.merge(Serialized::defaults(admin_value)); + + let debug_figment = Figment::new() + .merge(rocket::Config::default()) + .merge(Serialized::defaults(debug_value)); let mut rocket = rocket::custom(figment) - .mount( - "/prpc", - ra_rpc::prpc_routes!(Proxy, RpcHandler, trim: "Tproxy."), - ) + .mount("/prpc", prpc!(Proxy, RpcHandler, trim: "Tproxy.")) + // Mount WaveKV sync endpoint (requires mTLS gateway auth) + .mount("/", web_routes::wavekv_sync_routes()) .attach(AdHoc::on_response("Add app version header", |_req, res| { Box::pin(async move { res.set_raw_header("X-App-Version", app_version()); @@ -193,12 +219,27 @@ async fn main() -> Result<()> { let verifier = QuoteVerifier::new(pccs_url); rocket = rocket.manage(verifier); let main_srv = rocket.launch(); + let admin_state = state.clone(); + let debug_state = state; let admin_srv = async move { if admin_enabled { rocket::custom(admin_figment) .mount("/", web_routes::routes()) - .mount("/", ra_rpc::prpc_routes!(Proxy, AdminRpcHandler)) - .manage(state) + .mount("/", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .mount("/prpc", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .manage(admin_state) + .launch() + .await + } else { + std::future::pending().await + } + }; + let debug_srv = async move { + if debug_config.insecure_enable_debug_rpc { + rocket::custom(debug_figment) + .mount("/prpc", prpc!(Proxy, DebugRpcHandler, trim: "Debug.")) + .mount("/", web_routes::health_routes()) + .manage(debug_state) .launch() .await } else { @@ -212,6 +253,9 @@ async fn main() -> Result<()> { result = admin_srv => { result.map_err(|err| anyhow!("Failed to start admin server: {err:?}"))?; } + result = debug_srv => { + result.map_err(|err| anyhow!("Failed to start debug server: {err:?}"))?; + } } Ok(()) } diff --git a/gateway/src/main_service.rs b/gateway/src/main_service.rs index e6bc6775..5b440203 100644 --- a/gateway/src/main_service.rs +++ b/gateway/src/main_service.rs @@ -3,29 +3,25 @@ // SPDX-License-Identifier: Apache-2.0 
use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, net::Ipv4Addr, ops::Deref, - path::Path, - sync::{Arc, Mutex, MutexGuard, RwLock}, + sync::{Arc, Mutex, MutexGuard}, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use anyhow::{bail, Context, Result}; use auth_client::AuthClient; -use certbot::{CertBot, WorkDir}; + +use crate::distributed_certbot::DistributedCertBot; use cmd_lib::run_cmd as cmd; use dstack_gateway_rpc::{ gateway_server::{GatewayRpc, GatewayServer}, - AcmeInfoResponse, GatewayState, GuestAgentConfig, InfoResponse, QuotedPublicKey, - RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, + AcmeInfoResponse, GatewayNodeInfo, GetPeersResponse, GuestAgentConfig, InfoResponse, PeerInfo, + QuotedPublicKey, RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, }; -use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, RawQuoteArgs}; -use fs_err as fs; -use http_client::prpc::PrpcClient; use or_panic::ResultOrPanic; use ra_rpc::{CallContext, RpcCall, VerifiedAttestation}; -use ra_tls::attestation::QuoteContentType; use rand::seq::IteratorRandom; use rinja::Template as _; use safe_write::safe_write; @@ -36,13 +32,16 @@ use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, warn}; use crate::{ - config::Config, + cert_store::{CertResolver, CertStoreBuilder}, + config::{Config, TlsConfig}, + kv::{ + fetch_peers_from_bootnode, AppIdValidator, HttpsClientConfig, InstanceData, KvStore, + NodeData, NodeStatus, WaveKvSyncService, + }, models::{InstanceInfo, WgConf}, - proxy::{create_acceptor, AddressGroup, AddressInfo}, + proxy::{create_acceptor_with_cert_resolver, AddressGroup, AddressInfo}, }; -mod sync_client; - mod auth_client; #[derive(Clone)] @@ -59,26 +58,24 @@ impl Deref for Proxy { pub struct ProxyInner { pub(crate) config: Arc, - pub(crate) certbot: Option>, + /// Multi-domain certbot (from KvStore DNS credentials and domain configs) + pub(crate) 
certbot: Arc, my_app_id: Option>, state: Mutex, - notify_state_updated: Notify, + pub(crate) notify_state_updated: Notify, auth_client: AuthClient, - pub(crate) acceptor: RwLock, - pub(crate) h2_acceptor: RwLock, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct GatewayNodeInfo { - pub id: Vec, - pub url: String, - pub wg_peer: WireGuardPeer, - pub last_seen: SystemTime, + pub(crate) acceptor: TlsAcceptor, + pub(crate) h2_acceptor: TlsAcceptor, + /// Certificate resolver for SNI-based resolution (supports atomic updates) + pub(crate) cert_resolver: Arc, + /// WaveKV-based store for persistence (and cross-node sync when enabled) + kv_store: Arc, + /// WaveKV sync service for network synchronization + pub(crate) wavekv_sync: Option>, } #[derive(Debug, Serialize, Deserialize, Default)] pub(crate) struct ProxyStateMut { - pub(crate) nodes: BTreeMap, pub(crate) apps: BTreeMap>, pub(crate) instances: BTreeMap, pub(crate) allocated_addresses: BTreeSet, @@ -89,12 +86,22 @@ pub(crate) struct ProxyStateMut { pub(crate) struct ProxyState { pub(crate) config: Arc, pub(crate) state: ProxyStateMut, + /// Reference to KvStore for syncing changes + kv_store: Arc, +} + +/// Options for creating a Proxy instance +pub struct ProxyOptions { + pub config: Config, + pub my_app_id: Option>, + /// TLS configuration (from Rocket's tls config) + pub tls_config: TlsConfig, } impl Proxy { - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { Ok(Self { - _inner: Arc::new(ProxyInner::new(config, my_app_id).await?), + _inner: Arc::new(ProxyInner::new(options).await?), }) } } @@ -104,57 +111,149 @@ impl ProxyInner { self.state.lock().or_panic("Failed to lock AppState") } - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { + let ProxyOptions { + config, + my_app_id, + tls_config, + } = options; let config = Arc::new(config); - let mut state 
= fs::metadata(&config.state_path) - .is_ok() - .then(|| load_state(&config.state_path)) - .transpose() - .unwrap_or_else(|err| { - error!("Failed to load state: {err}"); - None - }) - .unwrap_or_default(); - state - .nodes - .retain(|_, info| info.wg_peer.ip != config.wg.ip.to_string()); - state.nodes.insert( - config.wg.public_key.clone(), - GatewayNodeInfo { - id: config.id(), - url: config.sync.my_url.clone(), - wg_peer: WireGuardPeer { - pk: config.wg.public_key.clone(), - ip: config.wg.ip.to_string(), - endpoint: config.wg.endpoint.clone(), - }, - last_seen: SystemTime::now(), - }, + + // Initialize WaveKV store without peers (peers will be added dynamically from bootnode) + let kv_store = Arc::new( + KvStore::new(config.sync.node_id, vec![], &config.sync.data_dir) + .context("failed to initialize WaveKV store")?, + ); + info!( + "WaveKV store initialized: node_id={}, sync_enabled={}", + config.sync.node_id, config.sync.enabled + ); + + // Load state from WaveKV + let instances = kv_store.load_all_instances(); + let nodes = kv_store.load_all_nodes(); + info!( + "Loaded state from WaveKV: {} instances, {} nodes", + instances.len(), + nodes.len() ); + let state = build_state_from_kv_store(instances); + + // Sync this node to KvStore + let node_data = NodeData { + uuid: config.uuid(), + url: config.sync.my_url.clone(), + wg_public_key: config.wg.public_key.clone(), + wg_endpoint: config.wg.endpoint.clone(), + wg_ip: config.wg.ip.to_string(), + }; + if let Err(err) = kv_store.sync_node(config.sync.node_id, &node_data) { + error!("Failed to sync this node to KvStore: {err}"); + } + // Set this node's status to Online + if let Err(err) = kv_store.set_node_status(config.sync.node_id, NodeStatus::Up) { + error!("Failed to set node status: {err}"); + } + // Register this node's sync URL in DB (for peer discovery) + if let Err(err) = kv_store.register_peer_url(config.sync.node_id, &config.sync.my_url) { + error!("Failed to register peer URL: {err}"); + } + + // Build 
HttpsClientConfig for mTLS communication + let https_config = { + let tls = &tls_config; + let cert_validator = my_app_id + .clone() + .map(|app_id| Arc::new(AppIdValidator::new(app_id)) as _); + HttpsClientConfig { + cert_path: tls.certs.clone(), + key_path: tls.key.clone(), + ca_cert_path: tls.mutual.ca_certs.clone(), + cert_validator, + } + }; + + // Fetch peers from bootnode if configured (only when sync is enabled) + if config.sync.enabled && !config.sync.bootnode.is_empty() { + if let Err(err) = fetch_peers_from_bootnode( + &config.sync.bootnode, + &kv_store, + config.sync.node_id, + &https_config, + ) + .await + { + warn!("Failed to fetch peers from bootnode: {err}"); + } + } + + // Create WaveKV sync service (only if sync is enabled) + let wavekv_sync = if config.sync.enabled { + match WaveKvSyncService::new(&kv_store, &config.sync, https_config) { + Ok(sync_service) => Some(Arc::new(sync_service)), + Err(err) => { + error!("Failed to create WaveKV sync service: {err}"); + None + } + } + } else { + None + }; + let state = Mutex::new(ProxyState { config: config.clone(), state, + kv_store: kv_store.clone(), }); let auth_client = AuthClient::new(config.auth.clone()); - let certbot = match config.certbot.enabled { - true => { - let certbot = config - .certbot - .build_bot() - .await - .context("Failed to build certbot")?; - info!("Certbot built, renewing..."); - // Try first renewal for the acceptor creation - certbot.renew(false).await.context("Failed to renew cert")?; - Some(Arc::new(certbot)) + // Bootstrap WaveKV first if sync is enabled, so certbot can load certs from peers + if let Some(ref wavekv_sync) = wavekv_sync { + info!("WaveKV: bootstrapping from peers..."); + if let Err(err) = wavekv_sync.bootstrap().await { + warn!("WaveKV bootstrap failed: {err}"); } - false => None, - }; - let acceptor = RwLock::new( - create_acceptor(&config.proxy, false).context("Failed to create acceptor")?, + } + + // Create CertResolver and load certificates from KvStore 
+ let cert_resolver = Arc::new(CertResolver::new()); + let all_cert_data = kv_store.load_all_cert_data(); + if !all_cert_data.is_empty() { + let mut builder = CertStoreBuilder::new(); + for (domain, data) in &all_cert_data { + if let Err(err) = builder.add_cert(domain, data) { + warn!("failed to load certificate for {}: {}", domain, err); + } + } + cert_resolver.set(Arc::new(builder.build())); + info!( + "CertStore: loaded {} certificates from KvStore", + all_cert_data.len() + ); + } + + // Create multi-domain certbot (uses KvStore configs for DNS credentials and domains) + let certbot = Arc::new(DistributedCertBot::new( + kv_store.clone(), + cert_resolver.clone(), + )); + // Initialize any configured domains + if let Err(err) = certbot.init_all().await { + warn!("Failed to initialize multi-domain certbot: {}", err); + } + + // Create TLS acceptors with CertResolver for SNI-based resolution + // CertResolver allows atomic certificate updates without recreating acceptors + info!( + "CertResolver initialized with {} domains", + cert_resolver.list_domains().len() ); + let acceptor = + create_acceptor_with_cert_resolver(&config.proxy, cert_resolver.clone(), false) + .context("failed to create acceptor with cert resolver")?; let h2_acceptor = - RwLock::new(create_acceptor(&config.proxy, true).context("Failed to create acceptor")?); + create_acceptor_with_cert_resolver(&config.proxy, cert_resolver.clone(), true) + .context("failed to create h2 acceptor with cert resolver")?; + Ok(Self { config, state, @@ -163,98 +262,200 @@ impl ProxyInner { auth_client, acceptor, h2_acceptor, + cert_resolver, certbot, + kv_store, + wavekv_sync, }) } + + pub(crate) fn kv_store(&self) -> &Arc { + &self.kv_store + } + + pub(crate) fn my_app_id(&self) -> Option<&[u8]> { + self.my_app_id.as_deref() + } } impl Proxy { pub(crate) async fn start_bg_tasks(&self) -> Result<()> { start_recycle_thread(self.clone()); - start_sync_task(self.clone()); - start_certbot_task(self.clone()).await?; + // 
Start WaveKV periodic sync (bootstrap already done in new()) + if let Some(ref wavekv_sync) = self.wavekv_sync { + start_wavekv_sync_task(self.clone(), wavekv_sync.clone()).await; + } + start_wavekv_watch_task(self.clone()).context("Failed to start WaveKV watch task")?; + start_certbot_task(self.clone()).await; + start_cert_store_watch_task(self.clone()); + start_zt_domain_watch_task(self.clone()); Ok(()) } - pub(crate) async fn renew_cert(&self, force: bool) -> Result { - let Some(certbot) = &self.certbot else { - return Ok(false); - }; - let renewed = certbot.renew(force).await.context("Failed to renew cert")?; - if renewed { - self.reload_certificates() - .context("Failed to reload certificates")?; - } - Ok(renewed) - } - - pub(crate) async fn acme_info(&self) -> Result { - let config = self.lock().config.clone(); - let workdir = WorkDir::new(&config.certbot.workdir); - let account_uri = workdir.acme_account_uri().unwrap_or_default(); - let keys = workdir.list_cert_public_keys().unwrap_or_default(); - let agent = crate::dstack_agent().context("Failed to get dstack agent")?; - let account_quote = get_or_generate_quote( - &agent, - QuoteContentType::Custom("acme-account"), - account_uri.as_bytes(), - workdir.acme_account_quote_path(), - ) - .await - .unwrap_or_default(); - let account_attestation = get_or_generate_attestation( - &agent, - QuoteContentType::Custom("acme-account"), - account_uri.as_bytes(), - workdir.acme_account_quote_path(), - ) - .await - .unwrap_or_default(); + /// Reload all certificates from KvStore into CertStore (atomic replacement) + pub(crate) fn reload_all_certs_from_kvstore(&self) -> Result<()> { + let all_cert_data = self.kv_store.load_all_cert_data(); + + // Build new CertStore from scratch + let mut builder = CertStoreBuilder::new(); + let mut loaded = 0; + for (domain, data) in &all_cert_data { + if let Err(err) = builder.add_cert(domain, data) { + warn!("failed to reload certificate for {}: {}", domain, err); + } else { + loaded += 
1; + } + } + + // Atomically replace the CertStore (no need to recreate acceptors) + self.cert_resolver.set(Arc::new(builder.build())); + info!("CertStore: reloaded {} certificates from KvStore", loaded); + Ok(()) + } + + /// Renew a specific domain certificate or all domains + pub(crate) async fn renew_cert(&self, domain: Option<&str>, force: bool) -> Result { + match domain { + Some(domain) => self + .certbot + .try_renew(domain, force) + .await + .context("failed to renew cert"), + None => { + // Renew all domains + self.certbot + .try_renew_all() + .await + .context("failed to renew all certs")?; + Ok(true) + } + } + } + + /// Get ACME info for all managed domains (or a specific domain) + pub(crate) fn acme_info(&self, domain: Option<&str>) -> Result { + let kv_store = self.kv_store.clone(); let mut quoted_hist_keys = vec![]; - for cert_path in workdir.list_certs().unwrap_or_default() { - let cert_pem = fs::read_to_string(&cert_path).context("Failed to read key")?; - let pubkey = certbot::read_pubkey(&cert_pem).context("Failed to read pubkey")?; - let quote = get_or_generate_quote( - &agent, - QuoteContentType::Custom("zt-cert"), - &pubkey, - cert_path.display().to_string() + ".quote", - ) - .await - .unwrap_or_default(); - let attestation = get_or_generate_attestation( - &agent, - QuoteContentType::Custom("zt-cert"), - &pubkey, - cert_path.display().to_string() + ".quote", - ) - .await + + // Get domains to query + let domains: Vec = match domain { + Some(d) => vec![d.to_string()], + None => kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect(), + }; + + // Get account_uri, account_quote and account_attestation from global ACME attestation + let (account_uri, account_quote, account_attestation) = kv_store + .get_acme_attestation() + .map(|att| (att.account_uri, att.quote, att.attestation)) .unwrap_or_default(); - quoted_hist_keys.push(QuotedPublicKey { - public_key: pubkey, - quote, - attestation, - }); - } - let active_cert = - 
fs::read_to_string(workdir.cert_path()).context("Failed to read active cert")?; + for domain in &domains { + // Get all attestations for this domain + let attestations = kv_store.list_cert_attestations(domain); + for att in attestations { + quoted_hist_keys.push(QuotedPublicKey { + public_key: att.public_key, + quote: att.quote, + attestation: att.attestation, + }); + } + } Ok(AcmeInfoResponse { account_uri, - hist_keys: keys.into_iter().collect(), account_quote, account_attestation, quoted_hist_keys, - active_cert, - base_domain: config.proxy.base_domain.clone(), }) } + + /// Register a CVM with the given app_id, instance_id and client_public_key + pub fn do_register_cvm( + &self, + app_id: &str, + instance_id: &str, + client_public_key: &str, + ) -> Result { + let mut state = self.lock(); + + // Check if this node is marked as down + let my_status = state.kv_store.get_node_status(state.config.sync.node_id); + if matches!(my_status, NodeStatus::Down) { + bail!("this gateway node is marked as down and cannot accept new registrations"); + } + + if app_id.is_empty() { + bail!("[{instance_id}] app id is empty"); + } + if instance_id.is_empty() { + bail!("[{instance_id}] instance id is empty"); + } + if client_public_key.is_empty() { + bail!("[{instance_id}] client public key is empty"); + } + let client_info = state + .new_client_by_id(instance_id, app_id, client_public_key) + .context("failed to allocate IP address for client")?; + if let Err(err) = state.reconfigure() { + error!("failed to reconfigure: {}", err); + } + let gateways = state.get_active_nodes(); + let servers = gateways + .iter() + .map(|n| WireGuardPeer { + pk: n.wg_public_key.clone(), + ip: n.wg_ip.clone(), + endpoint: n.wg_endpoint.clone(), + }) + .collect::>(); + let (base_domain, port) = state.kv_store.get_best_zt_domain().unwrap_or_default(); + let response = RegisterCvmResponse { + wg: Some(WireGuardConfig { + client_ip: client_info.ip.to_string(), + servers, + }), + agent: Some(GuestAgentConfig 
{ + external_port: port.into(), + internal_port: 8090, + domain: base_domain, + app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), + }), + gateways, + }; + self.notify_state_updated.notify_one(); + Ok(response) + } } -fn load_state(state_path: &str) -> Result { - let state_str = fs::read_to_string(state_path).context("Failed to read state")?; - serde_json::from_str(&state_str).context("Failed to load state") +fn build_state_from_kv_store(instances: BTreeMap) -> ProxyStateMut { + let mut state = ProxyStateMut::default(); + + // Build instances + for (instance_id, data) in instances { + let info = InstanceInfo { + id: instance_id.clone(), + app_id: data.app_id.clone(), + ip: data.ip, + public_key: data.public_key, + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + state.allocated_addresses.insert(data.ip); + state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.instances.insert(instance_id, info); + } + + state } fn start_recycle_thread(proxy: Proxy) { @@ -270,33 +471,260 @@ fn start_recycle_thread(proxy: Proxy) { }); } -async fn start_certbot_task(proxy: Proxy) -> Result<()> { - let Some(certbot) = proxy.certbot.clone() else { - info!("Certbot is not enabled"); - return Ok(()); - }; +/// Start periodic certificate renewal task for multi-domain certbot +async fn start_certbot_task(proxy: Proxy) { + info!("starting certificate renewal task"); + + // Periodic renewal task for all domains tokio::spawn(async move { + // Run once at startup to check for any pending renewals + info!("running initial certificate renewal check"); + if let Err(err) = proxy.renew_cert(None, false).await { + error!("failed initial certificate renewal: {err}"); + } + loop { - tokio::time::sleep(certbot.renew_interval()).await; - if let Err(err) = proxy.renew_cert(false).await { - error!("Failed to renew cert: {err}"); + // Get current config from 
KV store (allows dynamic updates) + let renew_interval = proxy.kv_store.get_certbot_config().renew_interval; + if renew_interval.is_zero() { + // Check again later if disabled + tokio::time::sleep(Duration::from_secs(60)).await; + continue; + } + + // Wait for the interval + tokio::time::sleep(renew_interval).await; + + // Renew certificates + if let Err(err) = proxy.renew_cert(None, false).await { + error!("failed to renew certificates: {err}"); } } }); - Ok(()) } -fn start_sync_task(proxy: Proxy) { +/// Watch for certificate changes from KvStore and update CertStore +fn start_cert_store_watch_task(proxy: Proxy) { + let kv_store = proxy.kv_store.clone(); + + // Watch for any certificate changes (all domains) + let mut rx = kv_store.watch_all_certs(); + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected certificate changes, reloading CertStore..."); + if let Err(err) = proxy.reload_all_certs_from_kvstore() { + error!("Failed to reload certificates from KvStore: {err}"); + } + } + }); + info!("CertStore watch task started"); +} + +/// Watch for ZT-Domain config changes and auto-renew certificates +fn start_zt_domain_watch_task(proxy: Proxy) { + let kv_store = proxy.kv_store.clone(); + let certbot = proxy.certbot.clone(); + + let mut rx = kv_store.watch_zt_domain_configs(); + tokio::spawn(async move { + // Track known domains to detect additions + let mut known_domains = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect::>(); + + loop { + if rx.changed().await.is_err() { + break; + } + + // Get current domains + let current_domains: HashSet = kv_store + .list_zt_domain_configs() + .into_iter() + .map(|c| c.domain) + .collect(); + + // Find newly added domains + let new_domains: Vec = current_domains + .iter() + .filter(|d| !known_domains.contains(*d)) + .cloned() + .collect(); + + // Update known domains + known_domains = current_domains; + + // Trigger renewal for new domains 
+ for domain in new_domains { + info!("ZT-Domain added: {domain}, attempting certificate request..."); + let certbot = certbot.clone(); + tokio::spawn(async move { + match certbot.try_renew(&domain, false).await { + Ok(renewed) => { + if renewed { + info!("cert[{domain}]: successfully issued/renewed"); + } else { + info!("cert[{domain}]: renewal not needed or another node is handling it"); + } + } + Err(e) => { + warn!("cert[{domain}]: auto-renewal failed: {e}"); + } + } + }); + } + } + }); + info!("ZT-Domain watch task started"); +} + +async fn start_wavekv_sync_task(proxy: Proxy, wavekv_sync: Arc) { if !proxy.config.sync.enabled { - info!("sync is disabled"); + info!("WaveKV sync is disabled"); return; } + + // Bootstrap already done in ProxyInner::new() before certbot init + // Peers are discovered from bootnode or via Admin.SetNodeInfo RPC + + // Start periodic sync tasks (runs forever in background) + tokio::spawn(async move { + wavekv_sync.start_sync_tasks().await; + }); + info!("WaveKV sync tasks started"); +} + +fn start_wavekv_watch_task(proxy: Proxy) -> Result<()> { + let kv_store = proxy.kv_store.clone(); + + // Watch for instance changes + let proxy_clone = proxy.clone(); + let store_clone = kv_store.clone(); + // Register watcher first, then do initial load to avoid race condition + let mut rx = store_clone.watch_instances(); + reload_instances_from_kv_store(&proxy_clone, &store_clone) + .context("Failed to initial load instances from KvStore")?; + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote instance changes, reloading..."); + if let Err(err) = reload_instances_from_kv_store(&proxy_clone, &store_clone) { + error!("Failed to reload instances from KvStore: {err}"); + } + } + }); + + // Initial WireGuard configuration + proxy.lock().reconfigure()?; + + // Watch for node changes and reconfigure WireGuard + let mut rx = kv_store.watch_nodes(); + let proxy_for_nodes = proxy.clone(); 
tokio::spawn(async move { - match sync_client::sync_task(proxy).await { - Ok(_) => info!("Sync task exited"), - Err(err) => error!("Failed to run sync task: {err}"), + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote node changes, reconfiguring WireGuard..."); + if let Err(err) = proxy_for_nodes.lock().reconfigure() { + error!("Failed to reconfigure WireGuard: {err}"); + } } }); + + // Start periodic persistence task + let persist_interval = proxy.config.sync.persist_interval; + if !persist_interval.is_zero() { + let kv_store_for_persist = kv_store.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(persist_interval); + loop { + ticker.tick().await; + match kv_store_for_persist.persist_if_dirty() { + Ok(true) => info!("WaveKV: periodic persist completed"), + Ok(false) => {} // No changes to persist + Err(err) => error!("WaveKV: periodic persist failed: {err}"), + } + } + }); + info!("WaveKV: periodic persistence enabled (interval: {persist_interval:?})"); + } + + // Start periodic connection sync task + if proxy.config.sync.sync_connections_enabled { + let sync_interval = proxy.config.sync.sync_connections_interval; + let proxy_for_sync = proxy.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(sync_interval); + loop { + ticker.tick().await; + let state = proxy_for_sync.lock(); + for (instance_id, instance) in &state.state.instances { + let count = instance.num_connections(); + state.sync_connections(instance_id, count); + } + } + }); + info!( + "WaveKV: periodic connection sync enabled (interval: {:?})", + proxy.config.sync.sync_connections_interval + ); + } + + Ok(()) +} + +fn reload_instances_from_kv_store(proxy: &Proxy, store: &KvStore) -> Result<()> { + let instances = store.load_all_instances(); + let mut state = proxy.lock(); + let mut wg_changed = false; + + for (instance_id, data) in instances { + let new_info = InstanceInfo { + id: instance_id.clone(), + app_id: 
data.app_id.clone(), + ip: data.ip, + public_key: data.public_key.clone(), + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + + if let Some(existing) = state.state.instances.get(&instance_id) { + // Check if wg config needs update + if existing.public_key != data.public_key || existing.ip != data.ip { + wg_changed = true; + } + // Only update if remote is newer (based on reg_time) + if data.reg_time <= encode_ts(existing.reg_time) { + continue; + } + } else { + wg_changed = true; + } + + state.state.allocated_addresses.insert(data.ip); + state + .state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.state.instances.insert(instance_id, new_info); + } + + if wg_changed { + state.reconfigure()?; + } + Ok(()) } impl ProxyState { @@ -348,6 +776,16 @@ impl ProxyState { } let existing = existing.clone(); if self.valid_ip(existing.ip) { + // Sync existing instance to KvStore (might be from legacy state) + let data = InstanceData { + app_id: existing.app_id.clone(), + ip: existing.ip, + public_key: existing.public_key.clone(), + reg_time: encode_ts(existing.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&existing.id, &data) { + error!("failed to sync existing instance to KvStore: {err}"); + } return Some(existing); } info!("ip {} is invalid, removing", existing.ip); @@ -360,7 +798,6 @@ impl ProxyState { ip, public_key: public_key.to_string(), reg_time: SystemTime::now(), - last_seen: SystemTime::now(), connections: Default::default(), }; self.add_instance(host_info.clone()); @@ -368,6 +805,17 @@ impl ProxyState { } fn add_instance(&mut self, info: InstanceInfo) { + // Sync to KvStore + let data = InstanceData { + app_id: info.app_id.clone(), + ip: info.ip, + public_key: info.public_key.clone(), + reg_time: encode_ts(info.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&info.id, &data) { + error!("failed to sync 
instance to KvStore: {err}"); + } + self.state .apps .entry(info.app_id.clone()) @@ -396,13 +844,6 @@ impl ProxyState { Ok(_) => info!("wg config updated"), Err(e) => error!("failed to set wg config: {e}"), } - self.save_state()?; - Ok(()) - } - - fn save_state(&self) -> Result<()> { - let state_str = serde_json::to_string(&self.state).context("Failed to serialize state")?; - safe_write(&self.config.state_path, state_str).context("Failed to write state")?; Ok(()) } @@ -549,6 +990,12 @@ impl ProxyState { .instances .remove(id) .context("instance not found")?; + + // Sync deletion to KvStore + if let Err(err) = self.kv_store.sync_delete_instance(id) { + error!("Failed to sync instance deletion to KvStore: {err}"); + } + self.state.allocated_addresses.remove(&info.ip); if let Some(app_instances) = self.state.apps.get_mut(&info.app_id) { app_instances.remove(id); @@ -560,48 +1007,50 @@ impl ProxyState { } fn recycle(&mut self) -> Result<()> { - // Recycle stale Gateway nodes - let mut staled_nodes = vec![]; - for node in self.state.nodes.values() { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.last_seen.elapsed().unwrap_or_default() > self.config.recycle.node_timeout { - staled_nodes.push(node.wg_peer.pk.clone()); - } - } - for id in staled_nodes { - self.state.nodes.remove(&id); + // Refresh state: sync local handshakes to KvStore, update local last_seen from global + if let Err(err) = self.refresh_state() { + warn!("failed to refresh state: {err}"); } - // Recycle stale CVM instances + // Note: Gateway nodes are not removed from KvStore, only marked offline/retired + + // Recycle stale CVM instances based on global last_seen (max across all nodes) let stale_timeout = self.config.recycle.timeout; - let stale_handshakes = self.latest_handshakes(Some(stale_timeout))?; - if tracing::enabled!(tracing::Level::DEBUG) { - for (pubkey, (ts, elapsed)) in &stale_handshakes { - debug!("stale instance: {pubkey} recent={ts} ({elapsed:?} ago)"); - } 
- } - // Find and remove instances with matching public keys + let now = SystemTime::now(); + let stale_instances: Vec<_> = self .state .instances .iter() - .filter(|(_, info)| { - stale_handshakes.contains_key(&info.public_key) && { - info.reg_time.elapsed().unwrap_or_default() > stale_timeout + .filter(|(id, info)| { + // Skip if instance was registered recently + if info.reg_time.elapsed().unwrap_or_default() <= stale_timeout { + return false; + } + // Check global last_seen from KvStore (max across all nodes) + let global_ts = self.kv_store.get_instance_latest_handshake(id); + let last_seen = global_ts.map(decode_ts).unwrap_or(info.reg_time); + let elapsed = now.duration_since(last_seen).unwrap_or_default(); + if elapsed > stale_timeout { + debug!( + "stale instance: {} last_seen={:?} ({:?} ago)", + id, last_seen, elapsed + ); + true + } else { + false } }) - .map(|(id, _info)| id.clone()) + .map(|(id, _)| id.clone()) .collect(); - debug!("stale instances: {:#?}", stale_instances); + let num_recycled = stale_instances.len(); for id in stale_instances { self.remove_instance(&id)?; } - info!("recycled {num_recycled} stale instances"); - // Reconfigure WireGuard with updated peers + if num_recycled > 0 { + info!("recycled {num_recycled} stale instances"); self.reconfigure()?; } Ok(()) @@ -611,89 +1060,94 @@ impl ProxyState { std::process::exit(0); } - fn dedup_nodes(&mut self) { - // Dedup nodes by URL, keeping the latest one - let mut node_map = BTreeMap::::new(); + pub(crate) fn refresh_state(&mut self) -> Result<()> { + // Get local WG handshakes and sync to KvStore + let handshakes = self.latest_handshakes(None)?; + + // Build a map from public_key to instance_id for lookup + let pk_to_id: BTreeMap<&str, &str> = self + .state + .instances + .iter() + .map(|(id, info)| (info.public_key.as_str(), id.as_str())) + .collect(); - for node in std::mem::take(&mut self.state.nodes).into_values() { - match node_map.get(&node.wg_peer.endpoint) { - Some(existing) if 
existing.last_seen >= node.last_seen => {} - _ => { - node_map.insert(node.wg_peer.endpoint.clone(), node); + // Sync local handshake observations to KvStore + for (pk, (ts, _)) in &handshakes { + if let Some(&instance_id) = pk_to_id.get(pk.as_str()) { + if let Err(err) = self.kv_store.sync_instance_handshake(instance_id, *ts) { + debug!("failed to sync instance handshake: {err}"); } } } - for node in node_map.into_values() { - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + + // Update this node's last_seen in KvStore + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + if let Err(err) = self + .kv_store + .sync_node_last_seen(self.config.sync.node_id, now) + { + debug!("failed to sync node last_seen: {err}"); } + Ok(()) } - fn update_state( - &mut self, - proxy_nodes: Vec, - apps: Vec, - ) -> Result<()> { - for node in proxy_nodes { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.url == self.config.sync.my_url { - continue; - } - if let Some(existing) = self.state.nodes.get(&node.wg_peer.pk) { - if node.last_seen <= existing.last_seen { - continue; - } - } - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + /// Sync connection count for an instance to KvStore + pub(crate) fn sync_connections(&self, instance_id: &str, count: u64) { + if let Err(err) = self.kv_store.sync_connections(instance_id, count) { + debug!("Failed to sync connections: {err}"); } - self.dedup_nodes(); + } - let mut wg_changed = false; - for app in apps { - if let Some(existing) = self.state.instances.get(&app.id) { - let existing_ts = (existing.reg_time, existing.last_seen); - let update_ts = (app.reg_time, app.last_seen); - if update_ts <= existing_ts { - continue; - } - if !wg_changed { - wg_changed = existing.public_key != app.public_key || existing.ip != app.ip; - } - } else { - wg_changed = true; - } - self.add_instance(app); - } - info!("updated, wg_changed: {wg_changed}"); - if wg_changed 
{ - self.reconfigure()?; - } else { - self.save_state()?; - } - Ok(()) + /// Get latest handshake for an instance from KvStore (max across all nodes) + pub(crate) fn get_instance_latest_handshake(&self, instance_id: &str) -> Option { + self.kv_store.get_instance_latest_handshake(instance_id) } - fn dump_state(&mut self) -> (Vec, Vec) { - self.refresh_state().ok(); - ( - self.state.nodes.values().cloned().collect(), - self.state.instances.values().cloned().collect(), - ) + /// Get all nodes from KvStore (for admin API - includes all nodes) + pub(crate) fn get_all_nodes(&self) -> Vec { + self.get_all_nodes_filtered(false) } - pub(crate) fn refresh_state(&mut self) -> Result<()> { - let handshakes = self.latest_handshakes(None)?; - for instance in self.state.instances.values_mut() { - let Some((ts, _)) = handshakes.get(&instance.public_key).copied() else { - continue; - }; - instance.last_seen = decode_ts(ts); - } - if let Some(node) = self.state.nodes.get_mut(&self.config.wg.public_key) { - node.last_seen = SystemTime::now(); - } - Ok(()) + /// Get nodes for CVM registration (excludes nodes with status "down") + pub(crate) fn get_active_nodes(&self) -> Vec { + self.get_all_nodes_filtered(true) + } + + /// Get all nodes from KvStore with optional filtering + fn get_all_nodes_filtered(&self, exclude_down: bool) -> Vec { + let node_statuses = if exclude_down { + self.kv_store.load_all_node_statuses() + } else { + Default::default() + }; + + self.kv_store + .load_all_nodes() + .into_iter() + .filter(|(id, _)| { + if !exclude_down { + return true; + } + // Exclude nodes with status "down" + match node_statuses.get(id) { + Some(NodeStatus::Down) => false, + _ => true, // Include Up or nodes without explicit status + } + }) + .map(|(id, node)| GatewayNodeInfo { + id, + uuid: node.uuid, + wg_public_key: node.wg_public_key, + wg_ip: node.wg_ip, + wg_endpoint: node.wg_endpoint, + url: node.url, + last_seen: self.kv_store.get_node_latest_last_seen(id).unwrap_or(0), + }) + 
.collect() } } @@ -715,7 +1169,7 @@ pub struct RpcHandler { impl RpcHandler { fn ensure_from_gateway(&self) -> Result<()> { - if !self.state.config.run_in_dstack { + if self.state.config.debug.insecure_skip_attestation { return Ok(()); } if self.remote_app_id.is_none() { @@ -743,124 +1197,44 @@ impl GatewayRpc for RpcHandler { .context("App authorization failed")?; let app_id = hex::encode(&app_info.app_id); let instance_id = hex::encode(&app_info.instance_id); - - let mut state = self.state.lock(); - if request.client_public_key.is_empty() { - bail!("[{instance_id}] client public key is empty"); - } - let client_info = state - .new_client_by_id(&instance_id, &app_id, &request.client_public_key) - .context("failed to allocate IP address for client")?; - if let Err(err) = state.reconfigure() { - error!("failed to reconfigure: {}", err); - } - let servers = state - .state - .nodes - .values() - .map(|n| n.wg_peer.clone()) - .collect::>(); - let response = RegisterCvmResponse { - wg: Some(WireGuardConfig { - client_ip: client_info.ip.to_string(), - servers, - }), - agent: Some(GuestAgentConfig { - external_port: state.config.proxy.external_port as u32, - internal_port: state.config.proxy.agent_port as u32, - domain: state.config.proxy.base_domain.clone(), - app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), - }), - }; - self.state.notify_state_updated.notify_one(); - Ok(response) + self.state + .do_register_cvm(&app_id, &instance_id, &request.client_public_key) } async fn acme_info(self) -> Result { - self.state.acme_info().await - } - - async fn update_state(self, request: GatewayState) -> Result<()> { - self.ensure_from_gateway()?; - let mut nodes = vec![]; - let mut apps = vec![]; - - for node in request.nodes { - nodes.push(GatewayNodeInfo { - id: node.id, - wg_peer: node.wg_peer.context("wg_peer is missing")?, - last_seen: decode_ts(node.last_seen), - url: node.url, - }); - } - - for app in request.apps { - apps.push(InstanceInfo { - id: 
app.instance_id, - app_id: app.app_id, - ip: app.ip.parse().context("Invalid IP address")?, - public_key: app.public_key, - reg_time: decode_ts(app.reg_time), - last_seen: decode_ts(app.last_seen), - connections: Default::default(), - }); - } - - self.state - .lock() - .update_state(nodes, apps) - .context("failed to update state")?; - Ok(()) + self.state.acme_info(None) } async fn info(self) -> Result { let state = self.state.lock(); + let (base_domain, port) = state.kv_store.get_best_zt_domain().unwrap_or_default(); Ok(InfoResponse { - base_domain: state.config.proxy.base_domain.clone(), - external_port: state.config.proxy.external_port as u32, + base_domain, + external_port: port.into(), app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), }) } -} -async fn get_or_generate_quote( - agent: &DstackGuestClient, - content_type: QuoteContentType<'_>, - payload: &[u8], - quote_path: impl AsRef, -) -> Result { - let quote_path = quote_path.as_ref(); - if fs::metadata(quote_path).is_ok() { - return fs::read_to_string(quote_path).context("Failed to read quote"); - } - let report_data = content_type.to_report_data(payload).to_vec(); - let response = agent - .get_quote(RawQuoteArgs { report_data }) - .await - .context("Failed to get quote")?; - let quote = serde_json::to_string(&response).context("Failed to serialize quote")?; - safe_write(quote_path, "e).context("Failed to write quote")?; - Ok(quote) -} + async fn get_peers(self) -> Result { + self.ensure_from_gateway()?; + + let kv_store = self.state.kv_store(); + let config = &self.state.config; -async fn get_or_generate_attestation( - agent: &DstackGuestClient, - content_type: QuoteContentType<'_>, - payload: &[u8], - quote_path: impl AsRef, -) -> Result { - let quote_path = quote_path.as_ref(); - if fs::metadata(quote_path).is_ok() { - return fs::read_to_string(quote_path).context("Failed to read quote"); - } - let report_data = content_type.to_report_data(payload).to_vec(); - let response = agent 
- .attest(RawQuoteArgs { report_data }) - .await - .context("Failed to get quote")?; - let attestation = serde_json::to_string(&response).context("Failed to serialize quote")?; - safe_write(quote_path, &attestation).context("Failed to write quote")?; - Ok(attestation) + // Get all peer addresses from KvStore + let peer_addrs = kv_store.get_all_peer_addrs(); + + let peers: Vec = peer_addrs + .into_iter() + .map(|(id, url)| PeerInfo { id, url }) + .collect(); + + Ok(GetPeersResponse { + my_id: config.sync.node_id, + my_url: config.sync.my_url.clone(), + peers, + }) + } } impl RpcCall for RpcHandler { @@ -875,30 +1249,5 @@ impl RpcCall for RpcHandler { } } -impl From for dstack_gateway_rpc::GatewayNodeInfo { - fn from(node: GatewayNodeInfo) -> Self { - Self { - id: node.id, - wg_peer: Some(node.wg_peer), - last_seen: encode_ts(node.last_seen), - url: node.url, - } - } -} - -impl From for dstack_gateway_rpc::AppInstanceInfo { - fn from(app: InstanceInfo) -> Self { - Self { - num_connections: app.num_connections(), - instance_id: app.id, - app_id: app.app_id, - ip: app.ip.to_string(), - public_key: app.public_key, - reg_time: encode_ts(app.reg_time), - last_seen: encode_ts(app.last_seen), - } - } -} - #[cfg(test)] mod tests; diff --git a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap index 67b4180d..f211b458 100644 --- a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config-2.snap @@ -1,6 +1,6 @@ --- source: gateway/src/main_service/tests.rs -assertion_line: 36 +assertion_line: 71 expression: info1 --- InstanceInfo { @@ -12,9 +12,5 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, - last_seen: SystemTime { - tv_sec: 0, - tv_nsec: 0, - }, connections: 0, } diff --git 
a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap index ef5978f3..5b07304c 100644 --- a/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap +++ b/gateway/src/main_service/snapshots/dstack_gateway__main_service__tests__config.snap @@ -1,6 +1,6 @@ --- source: gateway/src/main_service/tests.rs -assertion_line: 29 +assertion_line: 65 expression: info --- InstanceInfo { @@ -12,9 +12,5 @@ InstanceInfo { tv_sec: 0, tv_nsec: 0, }, - last_seen: SystemTime { - tv_sec: 0, - tv_nsec: 0, - }, connections: 0, } diff --git a/gateway/src/main_service/sync_client.rs b/gateway/src/main_service/sync_client.rs deleted file mode 100644 index 7feba2a0..00000000 --- a/gateway/src/main_service/sync_client.rs +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-FileCopyrightText: © 2025 Phala Network -// -// SPDX-License-Identifier: Apache-2.0 - -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use dstack_gateway_rpc::{gateway_client::GatewayClient, GatewayState}; -use dstack_guest_agent_rpc::GetTlsKeyArgs; -use ra_rpc::client::{RaClient, RaClientConfig}; -use tracing::{error, info}; - -use crate::{dstack_agent, main_service::Proxy}; - -struct SyncClient { - in_dstack: bool, - cert_pem: String, - key_pem: String, - ca_cert_pem: String, - app_id: Vec, - timeout: Duration, - pccs_url: Option, -} - -impl SyncClient { - fn create_rpc_client(&self, url: &str) -> Result> { - let app_id = self.app_id.clone(); - let url = format!("{}/prpc", url.trim_end_matches('/')); - let client = if self.in_dstack { - RaClientConfig::builder() - .remote_uri(url) - // Don't verify server RA because we use the CA cert from KMS to verify - // the server cert. 
- .verify_server_attestation(false) - .tls_no_check(true) - .tls_no_check_hostname(false) - .tls_client_cert(self.cert_pem.clone()) - .tls_client_key(self.key_pem.clone()) - .tls_ca_cert(self.ca_cert_pem.clone()) - .tls_built_in_root_certs(false) - .maybe_pccs_url(self.pccs_url.clone()) - .cert_validator(Box::new(move |cert| { - let cert = cert.context("TLS cert not found")?; - let remote_app_id = cert.app_id.context("App id not found")?; - if remote_app_id != app_id { - return Err(anyhow::anyhow!("Remote app id mismatch")); - } - Ok(()) - })) - .build() - .into_client() - .context("failed to create client")? - } else { - RaClient::new(url, true)? - }; - Ok(GatewayClient::new(client)) - } - - async fn sync_state(&self, url: &str, state: &GatewayState) -> Result<()> { - info!("Trying to sync state to {url}"); - let rpc = self.create_rpc_client(url)?; - tokio::time::timeout(self.timeout, rpc.update_state(state.clone())) - .await - .ok() - .context("Timeout while syncing state")? - .context("Failed to sync state")?; - info!("Synced state to {url}"); - Ok(()) - } - - async fn sync_state_ignore_error(&self, url: &str, state: &GatewayState) -> bool { - match self.sync_state(url, state).await { - Ok(_) => true, - Err(e) => { - error!("Failed to sync state to {url}: {e:?}"); - false - } - } - } -} - -pub(crate) async fn sync_task(proxy: Proxy) -> Result<()> { - let config = proxy.config.clone(); - let sync_client = if config.run_in_dstack { - let agent = dstack_agent().context("Failed to create dstack agent client")?; - let keys = agent - .get_tls_key(GetTlsKeyArgs { - subject: "dstack-gateway-sync-client".into(), - alt_names: vec![], - usage_ra_tls: false, - usage_server_auth: false, - usage_client_auth: true, - }) - .await - .context("Failed to get sync-client keys")?; - let my_app_id = agent - .info() - .await - .context("Failed to get guest info")? 
- .app_id; - SyncClient { - in_dstack: true, - cert_pem: keys.certificate_chain.join("\n"), - key_pem: keys.key, - ca_cert_pem: keys.certificate_chain.last().cloned().unwrap_or_default(), - app_id: my_app_id, - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - } else { - SyncClient { - in_dstack: false, - cert_pem: "".into(), - key_pem: "".into(), - ca_cert_pem: "".into(), - app_id: vec![], - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - }; - - let mut last_broadcast_time = Instant::now(); - let mut broadcast = false; - loop { - if broadcast { - last_broadcast_time = Instant::now(); - } - - let (mut nodes, apps) = proxy.lock().dump_state(); - // Sort nodes by pubkey - nodes.sort_by(|a, b| a.id.cmp(&b.id)); - - let self_idx = nodes - .iter() - .position(|n| n.wg_peer.pk == config.wg.public_key) - .unwrap_or(0); - - let state = GatewayState { - nodes: nodes.into_iter().map(|n| n.into()).collect(), - apps: apps.into_iter().map(|a| a.into()).collect(), - }; - - if state.nodes.is_empty() { - // If no nodes exist yet, sync with bootnode - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } else { - let nodes = &state.nodes; - // Try nodes after self, wrapping around to beginning - let mut success = false; - for i in 1..nodes.len() { - let idx = (self_idx + i) % nodes.len(); - if sync_client - .sync_state_ignore_error(&nodes[idx].url, &state) - .await - { - success = true; - if !broadcast { - break; - } - } - } - - // If no node succeeded, try bootnode as fallback - if !success { - info!("Fallback to sync with bootnode"); - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } - } - - tokio::select! 
{ - _ = proxy.notify_state_updated.notified() => { - broadcast = true; - } - _ = tokio::time::sleep(config.sync.interval) => { - broadcast = last_broadcast_time.elapsed() >= config.sync.broadcast_interval; - } - } - } -} diff --git a/gateway/src/main_service/tests.rs b/gateway/src/main_service/tests.rs index d98c0131..1a43b154 100644 --- a/gateway/src/main_service/tests.rs +++ b/gateway/src/main_service/tests.rs @@ -3,17 +3,44 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use crate::config::{load_config_figment, Config}; +use crate::config::{load_config_figment, Config, MutualConfig}; +use tempfile::TempDir; -async fn create_test_state() -> Proxy { +struct TestState { + proxy: Proxy, + _temp_dir: TempDir, +} + +impl std::ops::Deref for TestState { + type Target = Proxy; + fn deref(&self) -> &Self::Target { + &self.proxy + } +} + +async fn create_test_state() -> TestState { let figment = load_config_figment(None); let mut config = figment.focus("core").extract::().unwrap(); - let cargo_dir = env!("CARGO_MANIFEST_DIR"); - config.proxy.cert_chain = format!("{cargo_dir}/assets/cert.pem"); - config.proxy.cert_key = format!("{cargo_dir}/assets/cert.key"); - Proxy::new(config, None) + let temp_dir = TempDir::new().expect("failed to create temp dir"); + config.sync.data_dir = temp_dir.path().to_string_lossy().to_string(); + let options = ProxyOptions { + config, + my_app_id: None, + tls_config: TlsConfig { + certs: "".to_string(), + key: "".to_string(), + mutual: MutualConfig { + ca_certs: "".to_string(), + }, + }, + }; + let proxy = Proxy::new(options) .await - .expect("failed to create app state") + .expect("failed to create app state"); + TestState { + proxy, + _temp_dir: temp_dir, + } } #[tokio::test] @@ -32,14 +59,12 @@ async fn test_config() { .unwrap(); info.reg_time = SystemTime::UNIX_EPOCH; - info.last_seen = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info); let mut info1 = state .lock() .new_client_by_id("test-id-1", "app-id-1", 
"test-pubkey-1") .unwrap(); info1.reg_time = SystemTime::UNIX_EPOCH; - info1.last_seen = SystemTime::UNIX_EPOCH; insta::assert_debug_snapshot!(info1); let wg_config = state.lock().generate_wg_config().unwrap(); insta::assert_snapshot!(wg_config); diff --git a/gateway/src/models.rs b/gateway/src/models.rs index ec476cff..37caa274 100644 --- a/gateway/src/models.rs +++ b/gateway/src/models.rs @@ -60,7 +60,6 @@ pub struct InstanceInfo { pub ip: Ipv4Addr, pub public_key: String, pub reg_time: SystemTime, - pub last_seen: SystemTime, #[serde(skip)] pub connections: Arc, } diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 73b947cc..26bc1f1b 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -11,11 +11,13 @@ use std::{ }; use anyhow::{bail, Context, Result}; +use or_panic::ResultOrPanic; use sni::extract_sni; -pub(crate) use tls_terminate::create_acceptor; +pub(crate) use tls_terminate::create_acceptor_with_cert_resolver; use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream}, + runtime::Runtime, time::timeout, }; use tracing::{debug, error, info, info_span, Instrument}; @@ -60,10 +62,6 @@ async fn take_sni(stream: &mut TcpStream) -> Result<(Option, Vec)> { Ok((None, buffer)) } -fn is_subdomain(sni: &str, base_domain: &str) -> bool { - sni.ends_with(base_domain) -} - #[derive(Debug)] struct DstInfo { app_id: String, @@ -72,14 +70,7 @@ struct DstInfo { is_h2: bool, } -fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { - // format: [-][s]. 
- let subdomain = sni - .strip_suffix(dotted_base_domain) - .context("invalid sni format")?; - if subdomain.contains('.') { - bail!("only one level of subdomain is supported, got sni={sni}, subdomain={subdomain}"); - } +fn parse_dst_info(subdomain: &str) -> Result { let mut parts = subdomain.split('-'); let app_id = parts.next().context("no app id found")?.to_owned(); if app_id.is_empty() { @@ -131,11 +122,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0); -async fn handle_connection( - mut inbound: TcpStream, - state: Proxy, - dotted_base_domain: &str, -) -> Result<()> { +async fn handle_connection(mut inbound: TcpStream, state: Proxy) -> Result<()> { let timeouts = &state.config.proxy.timeouts; let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound)) .await @@ -144,8 +131,10 @@ async fn handle_connection( let Some(sni) = sni else { bail!("no sni found"); }; - if is_subdomain(&sni, dotted_base_domain) { - let dst = parse_destination(&sni, dotted_base_domain)?; + + let (subdomain, base_domain) = sni.split_once('.').context("invalid sni")?; + if state.cert_resolver.get().contains_wildcard(base_domain) { + let dst = parse_dst_info(subdomain)?; debug!("dst: {dst:?}"); if dst.is_tls { tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await @@ -160,19 +149,7 @@ async fn handle_connection( } #[inline(never)] -pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { - let workers_rt = tokio::runtime::Builder::new_multi_thread() - .thread_name("proxy-worker") - .enable_all() - .worker_threads(config.workers) - .build() - .context("Failed to build Tokio runtime")?; - - let dotted_base_domain = { - let base_domain = config.base_domain.as_str(); - let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain); - Arc::new(format!(".{base_domain}")) - }; +pub async fn proxy_main(rt: &Runtime, config: &ProxyConfig, proxy: Proxy) 
-> Result<()> { let listener = TcpListener::bind((config.listen_addr, config.listen_port)) .await .with_context(|| { @@ -195,16 +172,12 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { info!(%from, "new connection"); let proxy = proxy.clone(); - let dotted_base_domain = dotted_base_domain.clone(); - workers_rt.spawn( + rt.spawn( async move { let _conn_entered = conn_entered; let timeouts = &proxy.config.proxy.timeouts; - let result = timeout( - timeouts.total, - handle_connection(inbound, proxy, &dotted_base_domain), - ) - .await; + let result = + timeout(timeouts.total, handle_connection(inbound, proxy)).await; match result { Ok(Ok(_)) => { info!("connection closed"); @@ -233,17 +206,24 @@ fn next_connection_id() -> usize { } pub fn start(config: ProxyConfig, app_state: Proxy) -> Result<()> { - // Create a new single-threaded runtime - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .context("Failed to build Tokio runtime")?; - std::thread::Builder::new() .name("proxy-main".to_string()) .spawn(move || { + // Create a new single-threaded runtime + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .or_panic("Failed to build Tokio runtime"); + + let worker_rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("proxy-worker") + .enable_all() + .worker_threads(config.workers) + .build() + .or_panic("Failed to build Tokio runtime"); + // Run the proxy_main function in this runtime - if let Err(err) = rt.block_on(proxy_main(&config, app_state)) { + if let Err(err) = rt.block_on(proxy_main(&worker_rt, &config, app_state)) { error!( "error on {}:{}: {err:?}", config.listen_addr, config.listen_port @@ -260,64 +240,40 @@ mod tests { #[test] fn test_parse_destination() { - let base_domain = ".example.com"; - // Test basic app_id only - let result = parse_destination("myapp.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp").unwrap(); 
assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 80); assert!(!result.is_tls); // Test app_id with custom port - let result = parse_destination("myapp-8080.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-8080").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 8080); assert!(!result.is_tls); // Test app_id with TLS - let result = parse_destination("myapp-443s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-443s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 443); assert!(result.is_tls); // Test app_id with custom port and TLS - let result = parse_destination("myapp-8443s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-8443s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 8443); assert!(result.is_tls); // Test default port but ends with s - let result = parse_destination("myapps.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapps").unwrap(); assert_eq!(result.app_id, "myapps"); assert_eq!(result.port, 80); assert!(!result.is_tls); // Test default port but ends with s in port part - let result = parse_destination("myapp-s.example.com", base_domain).unwrap(); + let result = parse_dst_info("myapp-s").unwrap(); assert_eq!(result.app_id, "myapp"); assert_eq!(result.port, 443); assert!(result.is_tls); } - - #[test] - fn test_parse_destination_errors() { - let base_domain = ".example.com"; - - // Test invalid domain suffix - assert!(parse_destination("myapp.wrong.com", base_domain).is_err()); - - // Test multiple subdomains - assert!(parse_destination("invalid.myapp.example.com", base_domain).is_err()); - - // Test invalid port format - assert!(parse_destination("myapp-65536.example.com", base_domain).is_err()); - assert!(parse_destination("myapp-abc.example.com", base_domain).is_err()); - - // Test too many parts - assert!(parse_destination("myapp-8080-extra.example.com", 
base_domain).is_err()); - - // Test empty app_id - assert!(parse_destination("-8080.example.com", base_domain).is_err()); - assert!(parse_destination("myapp-8080ss.example.com", base_domain).is_err()); - } } diff --git a/gateway/src/proxy/tls_terminate.rs b/gateway/src/proxy/tls_terminate.rs index ad19ebf4..d1db93f8 100644 --- a/gateway/src/proxy/tls_terminate.rs +++ b/gateway/src/proxy/tls_terminate.rs @@ -8,23 +8,20 @@ use std::sync::Arc; use std::task::{Context, Poll}; use anyhow::{anyhow, bail, Context as _, Result}; -use fs_err as fs; use hyper::body::Incoming; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use hyper_util::rt::tokio::TokioIo; -use rustls::pki_types::pem::PemObject; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::version::{TLS12, TLS13}; use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; -use tracing::{debug, info}; +use tracing::debug; -use or_panic::ResultOrPanic; +use crate::cert_store::CertResolver; use crate::config::{CryptoProvider, ProxyConfig, TlsVersion}; use crate::main_service::Proxy; @@ -99,20 +96,19 @@ where } } -pub(crate) fn create_acceptor(config: &ProxyConfig, h2: bool) -> Result { - let cert_pem = fs::read(&config.cert_chain).context("failed to read certificate")?; - let key_pem = fs::read(&config.cert_key).context("failed to read private key")?; - let certs = CertificateDer::pem_slice_iter(cert_pem.as_slice()) - .collect::, _>>() - .context("failed to parse certificate")?; - let key = - PrivateKeyDer::from_pem_slice(key_pem.as_slice()).context("failed to parse private key")?; - - let provider = match config.tls_crypto_provider { +/// Create a TLS acceptor using CertResolver for SNI-based certificate resolution +/// +/// The CertResolver allows atomic certificate updates without recreating the acceptor. 
+pub(crate) fn create_acceptor_with_cert_resolver( + proxy_config: &ProxyConfig, + cert_resolver: Arc, + h2: bool, +) -> Result { + let provider = match proxy_config.tls_crypto_provider { CryptoProvider::AwsLcRs => rustls::crypto::aws_lc_rs::default_provider(), CryptoProvider::Ring => rustls::crypto::ring::default_provider(), }; - let supported_versions = config + let supported_versions = proxy_config .tls_versions .iter() .map(|v| match v { @@ -120,11 +116,12 @@ pub(crate) fn create_acceptor(config: &ProxyConfig, h2: bool) -> Result &TLS13, }) .collect::>(); + let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider)) .with_protocol_versions(&supported_versions) - .context("Failed to build TLS config")? + .context("failed to build TLS config")? .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_cert_resolver(cert_resolver); if h2 { config.alpn_protocols = vec![b"h2".to_vec()]; @@ -152,27 +149,6 @@ fn empty_response(status: StatusCode) -> Result> { } impl Proxy { - /// Reload the TLS acceptor with fresh certificates - pub fn reload_certificates(&self) -> Result<()> { - info!("Reloading TLS certificates"); - // Replace the acceptor with the new one - if let Ok(mut acceptor) = self.acceptor.write() { - *acceptor = create_acceptor(&self.config.proxy, false)?; - info!("TLS certificates successfully reloaded"); - } else { - bail!("Failed to acquire write lock for TLS acceptor"); - } - - if let Ok(mut acceptor) = self.h2_acceptor.write() { - *acceptor = create_acceptor(&self.config.proxy, true)?; - info!("TLS certificates successfully reloaded"); - } else { - bail!("Failed to acquire write lock for TLS acceptor"); - } - - Ok(()) - } - pub(crate) async fn handle_this_node( &self, inbound: TcpStream, @@ -213,7 +189,7 @@ impl Proxy { json_response(&app_info) } "/acme-info" => { - let acme_info = self.acme_info().await.context("Failed to get acme info")?; + let acme_info = self.acme_info(None).context("Failed to get acme info")?; 
json_response(&acme_info) } _ => empty_response(StatusCode::NOT_FOUND), @@ -278,15 +254,9 @@ impl Proxy { inbound, }; let acceptor = if h2 { - self.h2_acceptor - .read() - .or_panic("lock should never fail") - .clone() + &self.h2_acceptor } else { - self.acceptor - .read() - .or_panic("lock should never fail") - .clone() + &self.acceptor }; let tls_stream = timeout( self.config.proxy.timeouts.handshake, @@ -315,7 +285,7 @@ impl Proxy { let addresses = self .lock() .select_top_n_hosts(app_id) - .with_context(|| format!("app {app_id} not found"))?; + .with_context(|| format!("app <{app_id}> not found"))?; debug!("selected top n hosts: {addresses:?}"); let tls_stream = self.tls_accept(inbound, buffer, h2).await?; let (outbound, _counter) = timeout( diff --git a/gateway/src/web_routes.rs b/gateway/src/web_routes.rs index 1bd57f2b..5f72735d 100644 --- a/gateway/src/web_routes.rs +++ b/gateway/src/web_routes.rs @@ -7,12 +7,28 @@ use anyhow::Result; use rocket::{get, response::content::RawHtml, routes, Route, State}; mod route_index; +mod wavekv_sync; #[get("/")] async fn index(state: &State) -> Result, String> { route_index::index(state).await.map_err(|e| format!("{e}")) } +#[get("/health")] +fn health() -> &'static str { + "OK" +} + pub fn routes() -> Vec { routes![index] } + +/// Health endpoint for simple liveness checks +pub fn health_routes() -> Vec { + routes![health] +} + +/// WaveKV sync endpoint (for main server, requires mTLS gateway auth) +pub fn wavekv_sync_routes() -> Vec { + routes![wavekv_sync::sync_store] +} diff --git a/gateway/src/web_routes/wavekv_sync.rs b/gateway/src/web_routes/wavekv_sync.rs new file mode 100644 index 00000000..dead1141 --- /dev/null +++ b/gateway/src/web_routes/wavekv_sync.rs @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync HTTP endpoints +//! +//! Sync data is encoded using msgpack + gzip compression for efficiency. 
+ +use crate::{ + kv::{decode, encode}, + main_service::Proxy, +}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use ra_tls::traits::CertExt; +use rocket::{ + data::{Data, ToByteUnit}, + http::{ContentType, Status}, + mtls::{oid::Oid, Certificate}, + post, State, +}; +use std::io::{Read, Write}; +use tracing::warn; +use wavekv::sync::{SyncMessage, SyncResponse}; + +/// Wrapper to implement CertExt for Rocket's Certificate +struct RocketCert<'a>(&'a Certificate<'a>); + +impl CertExt for RocketCert<'_> { + fn get_extension_der(&self, oid: &[u64]) -> anyhow::Result>> { + let oid = Oid::from(oid).map_err(|_| anyhow::anyhow!("failed to create OID from slice"))?; + let Some(ext) = self.0.extensions().iter().find(|ext| ext.oid == oid) else { + return Ok(None); + }; + Ok(Some(ext.value.to_vec())) + } +} + +/// Decode compressed msgpack data +fn decode_sync_message(data: &[u8]) -> Result { + // Decompress + let mut decoder = GzDecoder::new(data); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed).map_err(|e| { + warn!("failed to decompress sync message: {e}"); + Status::BadRequest + })?; + + decode(&decompressed).map_err(|e| { + warn!("failed to decode sync message: {e}"); + Status::BadRequest + }) +} + +/// Encode and compress sync response +fn encode_sync_response(response: &SyncResponse) -> Result, Status> { + let encoded = encode(response).map_err(|e| { + warn!("failed to encode sync response: {e}"); + Status::InternalServerError + })?; + + // Compress + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(&encoded).map_err(|e| { + warn!("failed to compress sync response: {e}"); + Status::InternalServerError + })?; + encoder.finish().map_err(|e| { + warn!("failed to finish compression: {e}"); + Status::InternalServerError + }) +} + +/// Verify that the request is from a gateway with the same app_id (mTLS verification) +fn verify_gateway_peer(state: &Proxy, cert: Option>) -> Result<(), Status> 
{ + // Skip verification if not running in dstack (test mode) + if state.config.debug.insecure_skip_attestation { + return Ok(()); + } + + let Some(cert) = cert else { + warn!("WaveKV sync: client certificate required but not provided"); + return Err(Status::Unauthorized); + }; + + let remote_app_id = RocketCert(&cert).get_app_id().map_err(|e| { + warn!("WaveKV sync: failed to extract app_id from certificate: {e}"); + Status::Unauthorized + })?; + + let Some(remote_app_id) = remote_app_id else { + warn!("WaveKV sync: certificate does not contain app_id"); + return Err(Status::Unauthorized); + }; + + if state.my_app_id() != Some(remote_app_id.as_slice()) { + warn!( + "WaveKV sync: app_id mismatch, expected {:?}, got {:?}", + state.my_app_id(), + remote_app_id + ); + return Err(Status::Forbidden); + } + + Ok(()) +} + +/// Handle sync request (msgpack + gzip encoded) +#[post("/wavekv/sync/", data = "")] +pub async fn sync_store( + state: &State, + cert: Option>, + store: &str, + data: Data<'_>, +) -> Result<(ContentType, Vec), Status> { + verify_gateway_peer(state, cert)?; + + let Some(ref wavekv_sync) = state.wavekv_sync else { + return Err(Status::ServiceUnavailable); + }; + + // Read and decode request + let bytes = data + .open(16.mebibytes()) + .into_bytes() + .await + .map_err(|_| Status::BadRequest)?; + let msg = decode_sync_message(&bytes)?; + + // Reject sync from node_id == 0 + if msg.sender_id == 0 { + warn!("rejected sync from invalid node_id 0"); + return Err(Status::BadRequest); + } + + // Handle sync based on store type + let response = match store { + "persistent" => wavekv_sync.handle_persistent_sync(msg), + "ephemeral" => wavekv_sync.handle_ephemeral_sync(msg), + _ => return Err(Status::NotFound), + } + .map_err(|e| { + tracing::error!("{store} sync failed: {e}"); + Status::InternalServerError + })?; + + // Encode response + let encoded = encode_sync_response(&response)?; + + Ok((ContentType::new("application", "x-msgpack-gz"), encoded)) +} diff 
--git a/gateway/templates/dashboard.html b/gateway/templates/dashboard.html index 56750204..6fe3b7eb 100644 --- a/gateway/templates/dashboard.html +++ b/gateway/templates/dashboard.html @@ -34,7 +34,7 @@ border-collapse: collapse; background-color: white; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); margin: 20px 0; } @@ -93,14 +93,14 @@ font-size: 12px; white-space: nowrap; z-index: 1; - box-shadow: 0 2px 4px rgba(0,0,0,0.2); + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2); } .info-section { background: white; padding: 20px; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); } .info-group { @@ -152,6 +152,242 @@ text-overflow: ellipsis; white-space: nowrap; } + + .last-seen-cell { + white-space: nowrap; + } + + .last-seen-row { + margin-bottom: 4px; + } + + .last-seen-row:last-child { + margin-bottom: 0; + } + + .observer-label { + color: #666; + font-size: 0.9em; + } + + .node-status { + font-weight: bold; + } + + .node-status.up { + color: #4CAF50; + } + + .node-status.down { + color: #f44336; + } + + .status-controls { + display: flex; + gap: 5px; + } + + .status-btn { + padding: 4px 8px; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 12px; + font-weight: bold; + transition: opacity 0.2s; + } + + .status-btn:hover { + opacity: 0.8; + } + + .status-btn.up { + background-color: #4CAF50; + color: white; + } + + .status-btn.down { + background-color: #f44336; + color: white; + } + + .status-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + /* Certificate config styles */ + .action-btn { + padding: 6px 12px; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 13px; + font-weight: bold; + transition: opacity 0.2s; + margin-right: 5px; + } + + .action-btn:hover { + opacity: 0.8; + } + + .action-btn.primary { + background-color: #4CAF50; + color: white; + } + + .action-btn.danger { + background-color: #f44336; 
+ color: white; + } + + .action-btn.secondary { + background-color: #2196F3; + color: white; + } + + .action-btn.warning { + background-color: #ff9800; + color: white; + } + + .action-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .default-badge { + background-color: #4CAF50; + color: white; + padding: 2px 8px; + border-radius: 12px; + font-size: 11px; + font-weight: bold; + } + + .cert-status { + display: flex; + flex-direction: column; + gap: 4px; + } + + .cert-status-item { + font-size: 12px; + } + + .cert-status-item.has-cert { + color: #4CAF50; + } + + .cert-status-item.no-cert { + color: #f44336; + } + + .modal-overlay { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.5); + z-index: 1000; + justify-content: center; + align-items: center; + } + + .modal-overlay.active { + display: flex; + } + + .modal { + background: white; + padding: 30px; + border-radius: 8px; + box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3); + min-width: 400px; + max-width: 500px; + } + + .modal h3 { + margin-top: 0; + color: #333; + border-bottom: 2px solid #4CAF50; + padding-bottom: 10px; + } + + .modal-field { + margin-bottom: 15px; + } + + .modal-field label { + display: block; + margin-bottom: 5px; + font-weight: bold; + color: #555; + } + + .modal-field input, + .modal-field select { + width: 100%; + padding: 10px; + border: 1px solid #ddd; + border-radius: 4px; + box-sizing: border-box; + } + + .modal-field input[type="checkbox"] { + width: auto; + } + + .modal-actions { + display: flex; + justify-content: flex-end; + gap: 10px; + margin-top: 20px; + } + + .toast { + position: fixed; + bottom: 20px; + right: 20px; + padding: 15px 25px; + border-radius: 4px; + color: white; + font-weight: bold; + z-index: 2000; + animation: slideIn 0.3s ease; + } + + .toast.success { + background-color: #4CAF50; + } + + .toast.error { + background-color: #f44336; + } + + @keyframes slideIn { + from { + transform: 
translateX(100%); + opacity: 0; + } + to { + transform: translateX(0); + opacity: 1; + } + } + + .section-header { + display: flex; + justify-content: space-between; + align-items: center; + } + + .section-header h2 { + margin: 0; + } Dashboard \ No newline at end of file diff --git a/gateway/test-run/.env.example b/gateway/test-run/.env.example new file mode 100644 index 00000000..ff657175 --- /dev/null +++ b/gateway/test-run/.env.example @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Cloudflare API token with DNS edit permissions +# Required scopes: Zone.DNS (Edit), Zone.Zone (Read) +CF_API_TOKEN=your_cloudflare_api_token_here + +# Cloudflare Zone ID for your domain +CF_ZONE_ID=your_zone_id_here + +# Test domain (must be a wildcard domain managed by Cloudflare) +# Example: *.test.example.com +TEST_DOMAIN=*.test.example.com diff --git a/gateway/test-run/.gitignore b/gateway/test-run/.gitignore new file mode 100644 index 00000000..17972360 --- /dev/null +++ b/gateway/test-run/.gitignore @@ -0,0 +1,2 @@ +/run/ +.env diff --git a/gateway/test-run/cluster.sh b/gateway/test-run/cluster.sh new file mode 100755 index 00000000..23521bd1 --- /dev/null +++ b/gateway/test-run/cluster.sh @@ -0,0 +1,442 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Gateway cluster management script for manual testing + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="${SCRIPT_DIR}/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +TMUX_SESSION="gateway-cluster" + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + 
+show_help() { + echo "Gateway Cluster Management Script" + echo "" + echo "Usage: $0 " + echo "" + echo "Commands:" + echo " start Start a 3-node gateway cluster in tmux" + echo " stop Stop the cluster (keep tmux session)" + echo " reg Register a random instance" + echo " status Show cluster status" + echo " clean Destroy cluster and clean all data" + echo " attach Attach to tmux session" + echo " help Show this help" + echo "" +} + +# Generate certificates +generate_certs() { + mkdir -p "$CERTS_DIR" + mkdir -p "$RUN_DIR/certbot/live" + + # Generate CA certificate + if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]]; then + log_info "Creating CA certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null + openssl req -x509 -new -nodes \ + -key "$CERTS_DIR/gateway-ca.key" \ + -sha256 -days 365 \ + -out "$CERTS_DIR/gateway-ca.cert" \ + -subj "/CN=Test CA/O=Gateway Test" \ + 2>/dev/null + fi + + # Generate RPC certificate signed by CA + if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]]; then + log_info "Creating RPC certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null + openssl req -new \ + -key "$CERTS_DIR/gateway-rpc.key" \ + -out "$CERTS_DIR/gateway-rpc.csr" \ + -subj "/CN=localhost" \ + 2>/dev/null + cat > "$CERTS_DIR/ext.cnf" << EXTEOF +authorityKeyIdentifier=keyid,issuer +basicConstraints=CA:FALSE +keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +EXTEOF + openssl x509 -req \ + -in "$CERTS_DIR/gateway-rpc.csr" \ + -CA "$CERTS_DIR/gateway-ca.cert" \ + -CAkey "$CERTS_DIR/gateway-ca.key" \ + -CAcreateserial \ + -out "$CERTS_DIR/gateway-rpc.cert" \ + -days 365 \ + -sha256 \ + -extfile "$CERTS_DIR/ext.cnf" \ + 2>/dev/null + rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf" + fi + + # Generate proxy certificates + local proxy_cert_dir="$RUN_DIR/certbot/live" + if [[ ! 
-f "$proxy_cert_dir/cert.pem" ]]; then + log_info "Creating proxy certificates..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$proxy_cert_dir/key.pem" \ + -out "$proxy_cert_dir/cert.pem" \ + -days 365 \ + -subj "/CN=localhost" \ + 2>/dev/null + fi + + # Generate unique WireGuard key pair for each node + for i in 1 2 3; do + if [[ ! -f "$CERTS_DIR/wg-node${i}.key" ]]; then + log_info "Generating WireGuard keys for node ${i}..." + wg genkey > "$CERTS_DIR/wg-node${i}.key" + wg pubkey < "$CERTS_DIR/wg-node${i}.key" > "$CERTS_DIR/wg-node${i}.pub" + fi + done +} + +# Generate node config +generate_config() { + local node_id=$1 + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + local other_nodes="" + local peer_urls="" + + # Read WireGuard keys for this node + local wg_private_key=$(cat "$CERTS_DIR/wg-node${node_id}.key") + local wg_public_key=$(cat "$CERTS_DIR/wg-node${node_id}.pub") + + for i in 1 2 3; do + if [[ $i -ne $node_id ]]; then + local peer_rpc_port=$((13000 + i * 10 + 2)) + if [[ -n "$other_nodes" ]]; then + other_nodes="$other_nodes, $i" + peer_urls="$peer_urls, \"$i:https://localhost:$peer_rpc_port\"" + else + other_nodes="$i" + peer_urls="\"$i:https://localhost:$peer_rpc_port\"" + fi + fi + done + + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat > "$RUN_DIR/node${node_id}.toml" << EOF +log_level = "info" +address = "0.0.0.0" +port = ${rpc_port} + +[tls] +key = "${abs_run_dir}/certs/gateway-rpc.key" +certs = "${abs_run_dir}/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "${abs_run_dir}/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.test.local" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = ${debug_port} +address = "127.0.0.1" + 
+[core.admin] +enabled = true +port = ${admin_port} +address = "127.0.0.1" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://localhost:${rpc_port}" +bootnode = "" +node_id = ${node_id} +data_dir = "${RUN_DIR}/wavekv_node${node_id}" + +[core.certbot] +enabled = false + +[core.wg] +private_key = "${wg_private_key}" +public_key = "${wg_public_key}" +listen_port = ${wg_port} +ip = "${wg_ip}" +reserved_net = ["10.0.3${node_id}.1/31"] +client_ip_range = "10.0.3${node_id}.1/24" +config_path = "${RUN_DIR}/wg_node${node_id}.conf" +interface = "gw-test${node_id}" +endpoint = "127.0.0.1:${wg_port}" + +[core.proxy] +cert_chain = "${RUN_DIR}/certbot/live/cert.pem" +cert_key = "${RUN_DIR}/certbot/live/key.pem" +base_domain = "test.local" +listen_addr = "0.0.0.0" +listen_port = ${proxy_port} +tappd_port = 8090 +external_port = ${proxy_port} + +[core.recycle] +enabled = true +interval = "30s" +timeout = "120s" +node_timeout = "300s" +EOF +} + +# Build gateway binary +build_gateway() { + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_info "Building gateway..." + (cd "$SCRIPT_DIR/.." && cargo build --release) + fi +} + +# Start cluster +cmd_start() { + build_gateway + generate_certs + + # Check if tmux session exists + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + log_warn "Cluster already running. Use 'clean' to restart." + cmd_status + return 0 + fi + + log_info "Generating configs..." + mkdir -p "$RUN_DIR" "$LOG_DIR" + for i in 1 2 3; do + generate_config $i + mkdir -p "$RUN_DIR/wavekv_node${i}" + done + + log_info "Starting cluster in tmux session '$TMUX_SESSION'..." + + # Create wrapper scripts that keep running even if gateway exits + for i in 1 2 3; do + cat > "$RUN_DIR/run_node${i}.sh" << RUNEOF +#!/bin/bash +cd "$SCRIPT_DIR" +while true; do + echo "Starting node ${i}..." + sudo RUST_LOG=info $GATEWAY_BIN -c $RUN_DIR/node${i}.toml 2>&1 | tee -a $LOG_DIR/node${i}.log + echo "Node ${i} exited. 
Press Ctrl+C to stop, or wait 3s to restart..." + sleep 3 +done +RUNEOF + chmod +x "$RUN_DIR/run_node${i}.sh" + done + + # Create tmux session + tmux new-session -d -s "$TMUX_SESSION" -n "node1" + tmux send-keys -t "$TMUX_SESSION:node1" "$RUN_DIR/run_node1.sh" Enter + + sleep 1 + + # Add windows for other nodes + tmux new-window -t "$TMUX_SESSION" -n "node2" + tmux send-keys -t "$TMUX_SESSION:node2" "$RUN_DIR/run_node2.sh" Enter + + tmux new-window -t "$TMUX_SESSION" -n "node3" + tmux send-keys -t "$TMUX_SESSION:node3" "$RUN_DIR/run_node3.sh" Enter + + # Add a shell window + tmux new-window -t "$TMUX_SESSION" -n "shell" + + sleep 3 + + log_info "Cluster started!" + echo "" + cmd_status + echo "" + log_info "Use '$0 attach' to view logs" +} + +# Stop cluster +cmd_stop() { + log_info "Stopping cluster..." + sudo pkill -9 -f "dstack-gateway.*node[123].toml" 2>/dev/null || true + sudo ip link delete gw-test1 2>/dev/null || true + sudo ip link delete gw-test2 2>/dev/null || true + sudo ip link delete gw-test3 2>/dev/null || true + log_info "Cluster stopped" +} + +# Clean everything +cmd_clean() { + cmd_stop + + # Kill tmux session + tmux kill-session -t "$TMUX_SESSION" 2>/dev/null || true + + log_info "Cleaning data..." 
+ sudo rm -rf "$RUN_DIR/wavekv_node"* + sudo rm -f "$RUN_DIR/gateway-state-node"*.json + rm -f "$RUN_DIR/wg_node"*.conf + rm -f "$RUN_DIR/node"*.toml + rm -f "$RUN_DIR/run_node"*.sh + rm -rf "$LOG_DIR" + + log_info "Cleaned" +} + +# Show status +cmd_status() { + echo -e "${BLUE}=== Gateway Cluster Status ===${NC}" + echo "" + + for i in 1 2 3; do + local rpc_port=$((13000 + i * 10 + 2)) + local proxy_port=$((13000 + i * 10 + 4)) + local debug_port=$((13000 + i * 10 + 5)) + local admin_port=$((13000 + i * 10 + 6)) + + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + echo -e "Node $i: ${GREEN}RUNNING${NC}" + else + echo -e "Node $i: ${RED}STOPPED${NC}" + fi + echo " RPC: https://localhost:${rpc_port}" + echo " Proxy: https://localhost:${proxy_port}" + echo " Debug: http://localhost:${debug_port}" + echo " Admin: http://localhost:${admin_port}" + echo "" + done + + # Show instance count from first running node + for i in 1 2 3; do + local debug_port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if [[ -n "$response" ]]; then + local n_instances=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('instances', [])))" 2>/dev/null || echo "?") + local n_nodes=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('nodes', [])))" 2>/dev/null || echo "?") + echo -e "${BLUE}Cluster State:${NC}" + echo " Nodes: $n_nodes" + echo " Instances: $n_instances" + fi + break + fi + done +} + +# Register a random instance +cmd_reg() { + # Find a running node + local debug_port="" + for i in 1 2 3; do + local port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + debug_port=$port + break + fi + done + + if [[ -z "$debug_port" ]]; then + log_error "No running nodes 
found. Start cluster first." + exit 1 + fi + + # Generate random WireGuard key pair + local private_key=$(wg genkey) + local public_key=$(echo "$private_key" | wg pubkey) + + # Generate random IDs + local app_id="app-$(openssl rand -hex 4)" + local instance_id="inst-$(openssl rand -hex 4)" + + log_info "Registering instance..." + log_info " App ID: $app_id" + log_info " Instance ID: $instance_id" + log_info " Public Key: $public_key" + + local response=$(curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null) + + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'wg' in d" 2>/dev/null; then + local client_ip=$(echo "$response" | python3 -c "import sys,json; print(json.load(sys.stdin)['wg']['client_ip'])" 2>/dev/null) + log_info "Registered successfully!" + echo -e " Client IP: ${GREEN}$client_ip${NC}" + echo "" + echo "Instance details:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + else + log_error "Registration failed:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + exit 1 + fi +} + +# Attach to tmux +cmd_attach() { + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + tmux attach -t "$TMUX_SESSION" + else + log_error "No cluster running" + exit 1 + fi +} + +# Main +case "${1:-help}" in + start) + cmd_start + ;; + stop) + cmd_stop + ;; + clean) + cmd_clean + ;; + status) + cmd_status + ;; + reg) + cmd_reg + ;; + attach) + cmd_attach + ;; + help|--help|-h) + show_help + ;; + *) + log_error "Unknown command: $1" + show_help + exit 1 + ;; +esac diff --git a/gateway/test-run/e2e/certs/gateway-ca.cert b/gateway/test-run/e2e/certs/gateway-ca.cert new file mode 100644 index 00000000..a33ca300 --- /dev/null +++ b/gateway/test-run/e2e/certs/gateway-ca.cert @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- 
+MIIFTTCCAzWgAwIBAgIUQ35rCxX0ip4rTKbamxO5GvQ8h5UwDQYJKoZIhvcNAQEL +BQAwNjEYMBYGA1UEAwwPR2F0ZXdheSBUZXN0IENBMQ0wCwYDVQQKDARUZXN0MQsw +CQYDVQQGEwJVUzAeFw0yNjAxMTgwMTQ4MTJaFw0zNjAxMTYwMTQ4MTJaMDYxGDAW +BgNVBAMMD0dhdGV3YXkgVGVzdCBDQTENMAsGA1UECgwEVGVzdDELMAkGA1UEBhMC +VVMwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQD6A2nInhqWiOdaVSOQ +5KWMs6MI+CiGKPoQZh8r/+fiwkvk22PYXsTPJpdQEIYr2QfnvOrHgXZFz3yJ0ahM +16PxSNXD52kSxyfbk/q4mSLVOLY8NFRwUF1Eu9wwwrgSXYRks0uuDuYiI4bW6xvR +kwFLPvT+OUz2pYXvoJXyuYfxDRUa2dR4z4bV+rHrPInB0pBjw9bykszi6UsHmhir +r1SBZH2V0panXmvuXS5gGL4Gq953LKuDyAZfQboK4QldATJg2kKvj9+tty9yVn3i +iOjJRbOCf/IC9U4PJIIQcwxCDcT6VBCXE/X55W3NNwuWH/hHQ0pHdPL+oKA7G1L3 +WG1OCInVmaJboLku0OjPCxJ/n2TAACgo1Q2XzZ5rW6tvssFueIrR2/9HfJ9M1b8Q +HhnO305Z4EQkMIFuOI64wH3LEanjoLmK2D9cY3Bb4E9ddrstErUT9PO8RA5UiFUN +qycfGB/68eTVADtMjsJDdMyLFix4LLPaddjD8KdA218V/rZ+nIBtAMy2bbIsRPbm +yXz5LpNeIsGOokvpCe3iLT85o8UEvuJt0OB9H8HikX9WiHgub6MIYledqWz6H/cf +LsJjeKap/RZnOTs49rwyvrTkGhS+dIUVlP67P6i1SEJS/HRnYO91cXoRJk0AE8aH +uUPpAVTWEKF6w6PtqFbbr9oeKQIDAQABo1MwUTAdBgNVHQ4EFgQUqQfwkQvNdkWA +mmlLl8VeH1xqXP8wHwYDVR0jBBgwFoAUqQfwkQvNdkWAmmlLl8VeH1xqXP8wDwYD +VR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEAnOLEjvNdtpKp2sbKTpoK +GwDx4imvjAbfSrrjEbRBTYqfeI1FLCL7lws6QWq0orWoZzP9IILuGfA0/MMhqbSl +zPMSp6qFEWchuk4/Z4vthko3VM63wF0G1gMdtGQ+d4ZHELYJIOenVDwA64KVZ/qd +NYjGFompEIn8ciYHJuurvx3awSRbnDtCurtIFOQBZJvTdX03SuQOlYm6hV8D5wOt +9g8Fb3pauIFvnN6bngL68xj5Gc5mPxjFdu8XgKDVffxrgDe3eWpszcMDkQRqmJDm +bq8veJ58dVJpHNSgCOVojEgqhvDBDRO5nalEoFKxp56yyNveh1fBS087r1pfoEfE +2qX5InW6jVBe+N824IwH45S2unpB8k2utVXKOtPMF0LpYvKq++adv38vuw/iyhd6 +sjJxEhzHytAnH6nBTFM9drwlNhsK7VkRyunqJbQATjrayfgV+Qo4NKZIRsmroZpx +c/aznoLB+fgwXeST81wYiTBu1jc7tWL28IhLUhNXwWRuBa/j4UtrrYUd15iqZHPN +sJORAZB7kljogjNso2+pn1yWgAiS33bIikhXcXoSHmRI3yySolw+pTem5hbb/C6a +7IjKZSmw6QO3uGRgP4AHnQ7NPtAwigrLLsMPhtZYXRV2eS1HLagqN0Ia9c45iNVb +b5rqOJkqXy5xOjVLA+bk9Kg= +-----END CERTIFICATE----- diff --git a/gateway/test-run/e2e/certs/gateway-ca.key b/gateway/test-run/e2e/certs/gateway-ca.key 
new file mode 100644 index 00000000..c81dbf3e --- /dev/null +++ b/gateway/test-run/e2e/certs/gateway-ca.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQD6A2nInhqWiOda +VSOQ5KWMs6MI+CiGKPoQZh8r/+fiwkvk22PYXsTPJpdQEIYr2QfnvOrHgXZFz3yJ +0ahM16PxSNXD52kSxyfbk/q4mSLVOLY8NFRwUF1Eu9wwwrgSXYRks0uuDuYiI4bW +6xvRkwFLPvT+OUz2pYXvoJXyuYfxDRUa2dR4z4bV+rHrPInB0pBjw9bykszi6UsH +mhirr1SBZH2V0panXmvuXS5gGL4Gq953LKuDyAZfQboK4QldATJg2kKvj9+tty9y +Vn3iiOjJRbOCf/IC9U4PJIIQcwxCDcT6VBCXE/X55W3NNwuWH/hHQ0pHdPL+oKA7 +G1L3WG1OCInVmaJboLku0OjPCxJ/n2TAACgo1Q2XzZ5rW6tvssFueIrR2/9HfJ9M +1b8QHhnO305Z4EQkMIFuOI64wH3LEanjoLmK2D9cY3Bb4E9ddrstErUT9PO8RA5U +iFUNqycfGB/68eTVADtMjsJDdMyLFix4LLPaddjD8KdA218V/rZ+nIBtAMy2bbIs +RPbmyXz5LpNeIsGOokvpCe3iLT85o8UEvuJt0OB9H8HikX9WiHgub6MIYledqWz6 +H/cfLsJjeKap/RZnOTs49rwyvrTkGhS+dIUVlP67P6i1SEJS/HRnYO91cXoRJk0A +E8aHuUPpAVTWEKF6w6PtqFbbr9oeKQIDAQABAoIB/xz6RP543cqf/HBnUn3ZXxaB ++/fdWBhjkYnlkypvZzOnVpvevzrqYOijddePk1ewR3oiT5x2RYdfZ9wPGBi1hjZj +MoGuXS/kktDEsxA+CvIgopOrouq+zS60DKfQ2kC1QRhysWfCmiIXOacFoTkP3QiO +48ydu6lqcnL7klJzsx86x6f4ESkjzvKrAQyIvBtBvuqnPVK/jVkQrePZ+/GEHzeL +YSLUR7HWDTUwNgaFr/e98nGmG8vlJ5cENyYEvkdRnT5yy6mN56qkbYS4hFyhRfFi +VDa9/0xnM2YEmQDXMf/xIyQHpqfqi12bKoXF0XYuIBRrvhF3RAGU64Z05aeE9xQQ +g0NufzIJjONT2S/Dpr23j/vx0QLtFz9TRR5trxpmgD2BQa5xDvp//AsO0SwhM53x +a8BAnKiacbZSzG/r/OztBqXsjWpdUW70Hf0myESxkhXWZLIU/DHiqc80azcPtUmT +MRWXRPy/nUgZkiS7fgjPiLVNgGtU+XaoLpLz+TgPb4YCVq59xirpnSU5fwf/WyGd +fb5jMt6Jl8nxoHfJPWSyyB9YHxY6Xzcq7lhfWySuS3jRun33yyB4IiIGM1c/rM3y +bO7XvjKzyS5XA6lyZUyADrtvl5k+ZYmqjA/ffBxdzzlvL1TFKP0OB06F1mj46nNF +8HAGvlI7alqKtbOcj2ECggEBAP3F+/e45dkZM6lnqb1nq0WScwtS6e0gV/B5iSJ9 +eNeQZ9EnjM611IfutR0dTFzaiKjb3Yd5Ek43oTZhIuIHihOiuUYyLHQTeVfhFOcI +kitlTqHfWwRGePiWfUgjk+Zfa/KmJ5ev0FgHLzRdZRF41km6eEN7Homz0bNXn84b +a96qqv7J8Gmm17IxMUs8bb02YeAJM7of9jA9k9jwIrBW+XIKNgCnHGxStl7Nn2YN +3qr0Aj5e0hWHaGcuAxsuxm/m3uBGasUWZyChouPqX83XqJR0dHdUDll4rfOQ2ica 
+FY3b2LcCbQY3AwtVQcQDBnEnLgaidtW5Ivi2qqTKiwxtT0kCggEBAPw0+7oWEvh0 +gotm+pnF7A2asT+GI1mnsXJFuXqkJSpILoGkz2zYay+UW5bcYmsmrPzZyrakb51n +Gn2KcujnmWy+Sk2o1rc87Rc+qVE74YgoCQuZK2ydCrD9CdyCh7oSB/FG9BFRXu44 +mpxedxrr1jGetjJlniTOkOpCrKGtLQadzgbe0eeGBBF0DiLXARRRJTvP9CztQH7g +wj/E5SpBfZoyYwFfMrubSpOExBKNlv/wNY64WBaS3W3ahVE5wHzHI+i+kSAXtWkT +iuR42SJSF7yydfav1O167Lr/y+FBfkNzo08zsCUzwBiAdusIR8yZbQ0eQJnyOukh +5d7QZUII9+ECggEBAOzSIRSJWP3jVeHGWpHlt+CCDZhItQLUBxzj3kTwgJ/yI9/8 +n5ub9g0wh5X27HdOfO/P1okBREL4CRr9RRdX39P5LBtE4VUlgzyuUNpVlkqnDN1k +2cRAm82oapuyj+gRrmRQCGy25p/vfG7KpXHLqXY+bNLUh6gLxisuH3SxBFZUQKTr +AM8nouyomY7TgrlrkaUIEVylTRKxtFJjrouPbtOskb7ENHMmMQiBrToIwX4Znipk +RHtQ1O5M8xsf6JEvC1iSfjsUcAL0tFUrOGKY2bpIfxOIvqdiRjshN2P8JJcwzanj +uqhtGAswceIgzJc17+7DGFDUp70gglisp1xeefkCggEBANYyeYh1ru7spOKYN0Xa +Xry/IMJ+vg8q6P3QUdLjDd13KGha/P/IXmAudAsQaVXvpwOoRQ4RYeog4tK0fxtn +d1pv0tNaDeHaENKpGUwwuz7UIbqD/+ljBu2COpnZEkTpg21bgXYj0ago0sbzQ9zN +Z0EFNmBfBYzlExaiQdOeLJtt8sjK/SLRIytfkZHtYLFMqX6/AvYVGa2oXdGi+66D +qJUJLiTAIWpMXW4kWBIZxqDf9dycm9OwL/dYm9l8XwqaZtkI4GCNQjlXq6KXMHKB +nj9Yoe89Lm3y9JNtJE7PPNk0oQJnN7ag2Qj2MgkzIyeVNpTpmJwmqfnOHFi8TQNk +coECggEAV0Khwzyev9BcKVblFeeGwl+x2q/cULjQJoHE3fVv2XJtYgKXj2s1dPhJ +I9uLIp/ezC9iKK3visqxm259oD/EbifpWGAaO5/BrAlt/belxgfXwL5CccN6Gk93 +SiuBTXovUSIMCrfHPxV9DamQxWPchq4y5MaNPYgg1xD/XiIT7i/AQHP4nfnLEPUk +7dEuEllyJ3ULS7lnScT3UHl7clOOPR8vOw9zQL7syRHu/ga0vDhDElPQnpoFMHn3 +UKX5/iR4Kwuqy8dzQkzAXkcHrnXuyQCfAJx3dOhOaNzS6WZXrINcXv/Oupei9gzH +IRkSoTZbdQxjMxNMiM/WTV5KpX82iA== +-----END PRIVATE KEY----- diff --git a/gateway/test-run/e2e/certs/gateway-rpc.cert b/gateway/test-run/e2e/certs/gateway-rpc.cert new file mode 100644 index 00000000..bebfa090 --- /dev/null +++ b/gateway/test-run/e2e/certs/gateway-rpc.cert @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEoDCCAoigAwIBAgIUQ7gdJo6XEeikpmphfBX0h/w0FAowDQYJKoZIhvcNAQEL +BQAwNjEYMBYGA1UEAwwPR2F0ZXdheSBUZXN0IENBMQ0wCwYDVQQKDARUZXN0MQsw +CQYDVQQGEwJVUzAeFw0yNjAxMTgwMTQ4MTJaFw0yNzAxMTgwMTQ4MTJaMB0xGzAZ 
+BgNVBAMMEmdhdGV3YXkudGVzdC5sb2NhbDCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBAKX00TJvL88TCjmlWx3Z81jP5uKXhgm1SRhyZXmYK27pLCHUF4wz +FfOAC0N7J6ljbzv5kToaicOeb67Ghy/0Y+XrnRMYyAVSlO5dYSFrg3yUcd+8UyIt +AQC+Ni5GLj+BNh08MNjWbHshpi7D21N/XDI3+PNpxL1bG94QqBhM7esJq1bVGxQd +a2geYDj7Zw0yPYQOkRShHoB811Csd0f28MtVdIlEgAx/7CrNX1K4ywNm4vosiLxY +8us8j1PMJ4Qjk/KQDHAq3l43pw+xBUE459NHIg21I493vDqEf71BPLnd/szsy5Vg +XaliNpJf+7gjS20CdBhzgBDIXeEwuX1iSVECAwEAAaOBvjCBuzAJBgNVHRMEAjAA +MAsGA1UdDwQEAwIF4DBhBgNVHREEWjBYghJnYXRld2F5LnRlc3QubG9jYWyCCWdh +dGV3YXktMYIJZ2F0ZXdheS0ygglnYXRld2F5LTOCCWxvY2FsaG9zdIcEfwAAAYcE +rB4AFYcErB4AFocErB4AFzAdBgNVHQ4EFgQUxyKOvTDhXPMVw9nomGvphXRoIHUw +HwYDVR0jBBgwFoAUqQfwkQvNdkWAmmlLl8VeH1xqXP8wDQYJKoZIhvcNAQELBQAD +ggIBAFcVamp+SfBa3hjdnSZdKGoq34PDPtfI+XhtJnxCDeUX63y8Ikb+Q/hmkYRq +cPrjWZVlnpcCVGdnNQMpsY5h37o0+IUDLIkL1Gwkimk9q74Ee31UT3vUwMHb44Ed +pR5je1gqp3QszLsLu3wQs1AYISJUCNJLsK8JxvNTlfSQNvljWqfFo1qF7Agzl66X +v87OhcCUtICoc1wGE1w78diHyefQAanCzLUBEfr7lKtSiXCcis7iCOfZ2NJ/faQu +V26/mM6RXUB/KPUhliX5RSBAzdjVW4oo/dTzZ1BCEBhz5+RNXw4RjBsqohW7JQQv +j4UbytVvctHwVicuePDTekMH0LNE0+OtpCglgn5VAlsRvNdHPRmVPBjffEE3ECUx +HiZ9h7Re4GavqnNaS4dZbgMdR4hprT/cN/XHpPtd42I6v3yzUvdgXwGPyNnbH0Ff +iBtGGunigRcS69uv7sAh6WyWp97jsd6f8QH5ktJyrzqvlXV1jhxx43eed34C46dQ +J8RS+Cf4Sd7r7XzgAb9BTOfM6ETSF/7rHK4sDbC8xHlpO42R/U7biC1cUe5rZhz6 +lDDcx5B6xA0dayaG49jzYm7FSRFb5CfRh2ZV/2o7LGGaO65JyZgkGx1e+COtkinX +ijtYWQ5II6sk5KbnnFr0mJg7LYehYxV4PSmPSTX4iPAJnqgM +-----END CERTIFICATE----- diff --git a/gateway/test-run/e2e/certs/gateway-rpc.key b/gateway/test-run/e2e/certs/gateway-rpc.key new file mode 100644 index 00000000..b5660a61 --- /dev/null +++ b/gateway/test-run/e2e/certs/gateway-rpc.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCl9NEyby/PEwo5 +pVsd2fNYz+bil4YJtUkYcmV5mCtu6Swh1BeMMxXzgAtDeyepY287+ZE6GonDnm+u +xocv9GPl650TGMgFUpTuXWEha4N8lHHfvFMiLQEAvjYuRi4/gTYdPDDY1mx7IaYu +w9tTf1wyN/jzacS9WxveEKgYTO3rCatW1RsUHWtoHmA4+2cNMj2EDpEUoR6AfNdQ 
+rHdH9vDLVXSJRIAMf+wqzV9SuMsDZuL6LIi8WPLrPI9TzCeEI5PykAxwKt5eN6cP +sQVBOOfTRyINtSOPd7w6hH+9QTy53f7M7MuVYF2pYjaSX/u4I0ttAnQYc4AQyF3h +MLl9YklRAgMBAAECggEAPQwt2EumXpo2bLYzKmv+ZHE2EayDlhal6ORMB8q+T3Je +1aLbdqtkK8qyWgR3tovpYzqO/by9aMRjePt2x2EzTmS5x0iaa7rRJk4baNvP5ogE +y7TPMAc2EzvlWmheouW5Lk/x+BIIndLm+tT5XWHAXIjSf1gtEyrsuWePLkE+U/MG ++xZ6pfd5pkUfO7XKqPnUVdKPIpwvl/Wzn1vsz5UOgoVg95VVAhA/n0aMJWHBNbTP +3sHvjK1iSMPWNfjhXmC5+3mqd7lJK41dRHQ0U0vLnCCzw8RXR6BOdY5HBAZDlVg/ +eJ4nn+TVUZxesEm+x7KBKiyfGLxVwgJtUEU3GzqkMQKBgQDOkzMie/+QnphLs0R7 +sIHaxqMwPaPcIddqnOn2j+QRhxqk9idAIfz1uX7yXn0dNT8JdnpJYTGoWvaDsJD8 +RMTautOZN6DZ0ZuEWWfCaeJo9QzJAo2QDeV22aLGnkP5DmU5JwtmZyBADaYy3Ybr +BJF4bTYK6oj0OEjoHf3B+xf83wKBgQDNqbUGLn/6KHg3mShNynuo3ubjDq6SON8n +PlHbHLiMBNuAThwN2AV42PdQfwXgnJuUKA6xNIiHXZ7Jp2dD+eZ4sfwD65N69k34 +z9c/nkq90U/iUcpPYDCj3eEfVc/G9z/5Px1pG/BKkGAjAQB+D4P30XsvJLR9rSu/ +eKuPvIxPzwKBgDScUrKep/j6G0l0T6W8z2Wbn2Yi3L+sssNJUWDlRq2cHhITSu3P +ejBO3OD3ZZ/xtqs/TGex5Ea/W/cwGczV6tjWKhvkigfPlW8Aoidmdi5K8sWi69Db +aSx6wzUYi7E7lFYY9pNPAmytzT05JCpo0G++SLxA/T5Ns2vCb6VewL47AoGBAI8Q +T7nWHPZSspXSd8PtZ6ooLKqkKvHSmAD/jAeE6ieUtXCCZWeH7v6Kxzd6tQbzShJ8 +7wN8DMFFcdDLH72cmCM7hJjhhf0SW1kKk6xQm6OBeDVyOe6PdiZ3kUOv+NJqalki ++32Djtr/pbCT4NjQSDfaw/seaGPIU9dkxMs/GMfnAoGAV7VVmTNIihNt8IUwRXVZ +VsS3sqrMH3nL8aL5O10LlE1861czcUMl1xBLjm6ErWmSG30SzhrFW6fpeptlECmn +CSUIbYS1ONXHW8UBR+mHYBOXE75fzz/9zJ3BbN+/JPOy0XNtRBJSuPC9Wt0CApu2 +mchnIUsxbtRi5ZmdhObFrz4= +-----END PRIVATE KEY----- diff --git a/gateway/test-run/e2e/configs/gateway-1.toml b/gateway/test-run/e2e/configs/gateway-1.toml new file mode 100644 index 00000000..30f088e4 --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-1.toml @@ -0,0 +1,53 @@ +# Gateway Node 1 configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/etc/gateway/certs/gateway-rpc.key" +certs = "/etc/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/etc/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = 
"gateway.test.local" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://gateway-1:9012" +bootnode = "https://gateway-2:9012" +node_id = 1 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.41.1/24" +reserved_net = ["10.0.41.1/31"] +client_ip_range = "10.0.41.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test1" +endpoint = "gateway-1:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/configs/gateway-2.toml b/gateway/test-run/e2e/configs/gateway-2.toml new file mode 100644 index 00000000..bd361b04 --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-2.toml @@ -0,0 +1,53 @@ +# Gateway Node 2 configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/etc/gateway/certs/gateway-rpc.key" +certs = "/etc/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/etc/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.test.local" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://gateway-2:9012" +bootnode = "https://gateway-1:9012" +node_id = 2 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.42.1/24" +reserved_net = ["10.0.42.1/31"] 
+client_ip_range = "10.0.42.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test2" +endpoint = "gateway-2:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/configs/gateway-3.toml b/gateway/test-run/e2e/configs/gateway-3.toml new file mode 100644 index 00000000..e2e19c9f --- /dev/null +++ b/gateway/test-run/e2e/configs/gateway-3.toml @@ -0,0 +1,53 @@ +# Gateway Node 3 configuration for E2E testing +log_level = "debug" +address = "0.0.0.0" +port = 9012 + +[tls] +key = "/etc/gateway/certs/gateway-rpc.key" +certs = "/etc/gateway/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "/etc/gateway/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.test.local" + +[core.admin] +enabled = true +port = 9016 +address = "0.0.0.0" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = 9015 +address = "0.0.0.0" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://gateway-3:9012" +bootnode = "https://gateway-1:9012" +node_id = 3 +data_dir = "/var/lib/gateway/wavekv" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = 9013 +ip = "10.0.43.1/24" +reserved_net = ["10.0.43.1/31"] +client_ip_range = "10.0.43.1/24" +config_path = "/var/lib/gateway/wg.conf" +interface = "wg-test3" +endpoint = "gateway-3:9013" + +[core.proxy] +listen_addr = "0.0.0.0" +listen_port = 9014 +tappd_port = 8090 +external_port = 9014 diff --git a/gateway/test-run/e2e/docker-compose.yml b/gateway/test-run/e2e/docker-compose.yml new file mode 100644 index 00000000..48421456 --- /dev/null +++ b/gateway/test-run/e2e/docker-compose.yml @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# E2E test environment for dstack-gateway certbot functionality 
+# Uses mock services: Pebble (ACME) + mock-cf-dns-api (Cloudflare DNS) + dstack-simulator + +networks: + certbot-test: + driver: bridge + ipam: + config: + - subnet: 172.30.0.0/24 + +volumes: + pebble-certs: + +services: + # ==================== Mock Services ==================== + + # Mock Cloudflare DNS API + mock-cf-dns-api: + image: kvin/mock-cf-dns-api:latest + container_name: mock-cf-dns-api + networks: + certbot-test: + ipv4_address: 172.30.0.10 + ports: + - "18080:8080" + environment: + - PORT=8080 + - DEBUG=true + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"] + interval: 5s + timeout: 3s + retries: 5 + + # Pebble - Let's Encrypt test server (custom build with HTTP support) + pebble: + image: kvin/pebble:latest + container_name: pebble + command: ["-http", "-dnsserver", "172.30.0.10:53"] + networks: + certbot-test: + ipv4_address: 172.30.0.11 + ports: + - "14000:14000" # ACME directory + - "15000:15000" # Management interface + environment: + - PEBBLE_VA_NOSLEEP=1 + - PEBBLE_VA_ALWAYS_VALID=1 # Skip actual DNS validation for testing + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:14000/dir"] + interval: 5s + timeout: 3s + retries: 10 + + # dstack-simulator - provides mock dstack agent API + dstack-simulator: + image: ${SIMULATOR_IMAGE:-dstack-simulator:test} + container_name: dstack-simulator + networks: + certbot-test: + ipv4_address: 172.30.0.5 + ports: + - "18090:8090" # HTTP API + volumes: + - ../../../sdk/simulator:/sdk:ro + - ./simulator/dstack.toml:/app/dstack.toml:ro + working_dir: /app + command: ["/sdk/dstack-simulator", "-c", "/app/dstack.toml"] + environment: + - RUST_LOG=info,dstack_guest_agent=debug + healthcheck: + test: ["CMD", "curl", "-sf", "http://localhost:8090/Info"] + interval: 5s + timeout: 3s + retries: 10 + + # ==================== Gateway Cluster ==================== + + # Gateway Node 1 - Will be the first to request 
certificate + gateway-1: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-1 + networks: + certbot-test: + ipv4_address: 172.30.0.21 + ports: + - "19012:9012" # RPC + - "19014:9014" # Proxy + - "19015:9015" # Debug + - "19016:9016" # Admin + volumes: + - ./configs/gateway-1.toml:/etc/gateway/gateway.toml:ro + - ./certs:/etc/gateway/certs:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - DSTACK_AGENT_ADDRESS=http://172.30.0.5:8090 + depends_on: + mock-cf-dns-api: + condition: service_healthy + pebble: + condition: service_healthy + dstack-simulator: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + extra_hosts: + # Pebble returns localhost in directory URLs, so we need to resolve localhost to pebble's IP + - "localhost:172.30.0.11" + + # Gateway Node 2 - Will sync certificate from Node 1 + gateway-2: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-2 + networks: + certbot-test: + ipv4_address: 172.30.0.22 + ports: + - "19022:9012" # RPC + - "19024:9014" # Proxy + - "19025:9015" # Debug + - "19026:9016" # Admin + volumes: + - ./configs/gateway-2.toml:/etc/gateway/gateway.toml:ro + - ./certs:/etc/gateway/certs:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - DSTACK_AGENT_ADDRESS=http://172.30.0.5:8090 + depends_on: + gateway-1: + condition: service_healthy + dstack-simulator: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + + # Gateway Node 3 - Will sync certificate from cluster + gateway-3: + image: ${GATEWAY_IMAGE:-dstack-gateway:test} + container_name: gateway-3 + networks: + certbot-test: + ipv4_address: 172.30.0.23 + 
ports: + - "19032:9012" # RPC + - "19034:9014" # Proxy + - "19035:9015" # Debug + - "19036:9016" # Admin + volumes: + - ./configs/gateway-3.toml:/etc/gateway/gateway.toml:ro + - ./certs:/etc/gateway/certs:ro + tmpfs: + - /var/lib/gateway + environment: + - RUST_LOG=info,dstack_gateway=debug,certbot=debug + - DSTACK_AGENT_ADDRESS=http://172.30.0.5:8090 + depends_on: + gateway-2: + condition: service_healthy + dstack-simulator: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9015/health"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 30s + cap_add: + - NET_ADMIN + + # ==================== Test Runner ==================== + + test-runner: + image: alpine:latest + container_name: test-runner + networks: + certbot-test: + ipv4_address: 172.30.0.100 + volumes: + - ./test.sh:/test.sh:ro + entrypoint: ["/bin/sh", "-c", "apk add --no-cache curl openssl jq && /bin/sh /test.sh"] + depends_on: + gateway-1: + condition: service_healthy + gateway-2: + condition: service_healthy + gateway-3: + condition: service_healthy diff --git a/gateway/test-run/e2e/pebble-config.json b/gateway/test-run/e2e/pebble-config.json new file mode 100644 index 00000000..41411088 --- /dev/null +++ b/gateway/test-run/e2e/pebble-config.json @@ -0,0 +1,18 @@ +{ + "pebble": { + "listenAddress": "0.0.0.0:14000", + "managementListenAddress": "0.0.0.0:15000", + "certificate": "/etc/pebble/certs/localhost/cert.pem", + "privateKey": "/etc/pebble/certs/localhost/key.pem", + "httpPort": 5002, + "tlsPort": 5001, + "ocspResponderURL": "", + "externalAccountBindingRequired": false, + "domainBlocklist": [], + "retryAfter": { + "authz": 3, + "order": 5 + }, + "certificateValidityPeriod": 157680000 + } +} diff --git a/gateway/test-run/e2e/run-e2e.sh b/gateway/test-run/e2e/run-e2e.sh new file mode 100755 index 00000000..206ff895 --- /dev/null +++ b/gateway/test-run/e2e/run-e2e.sh @@ -0,0 +1,176 @@ +#!/bin/bash +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# 
+# SPDX-License-Identifier: Apache-2.0 + +# E2E test runner for dstack-gateway +# Builds gateway and simulator images, then runs the test suite + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +log_info() { echo -e "${BLUE}[INFO]${NC} $1"; } +log_success() { echo -e "${GREEN}[OK]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +# Parse arguments +SKIP_BUILD=false +SKIP_SIMULATOR_BUILD=false +KEEP_RUNNING=false +CLEAN=false + +while [[ $# -gt 0 ]]; do + case $1 in + --skip-build) + SKIP_BUILD=true + shift + ;; + --skip-simulator-build) + SKIP_SIMULATOR_BUILD=true + shift + ;; + --keep-running) + KEEP_RUNNING=true + shift + ;; + --clean) + CLEAN=true + shift + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --skip-build Skip building gateway image" + echo " --skip-simulator-build Skip building simulator" + echo " --keep-running Keep containers running after test" + echo " --clean Clean up containers and images" + echo " -h, --help Show this help" + exit 0 + ;; + *) + log_error "Unknown option: $1" + exit 1 + ;; + esac +done + +cd "$SCRIPT_DIR" + +# Clean up if requested +if $CLEAN; then + log_info "Cleaning up..." + docker compose down -v --remove-orphans 2>/dev/null || true + docker rmi dstack-gateway:test dstack-simulator:test 2>/dev/null || true + log_success "Cleanup complete" + exit 0 +fi + +# Step 1: Build simulator if needed (musl static build) +if ! $SKIP_SIMULATOR_BUILD; then + log_info "Building dstack-simulator (musl static)..." 
+ cd "$REPO_ROOT" + cargo build --release -p dstack-guest-agent --target x86_64-unknown-linux-musl + + # Copy binary to simulator directory + cp target/x86_64-unknown-linux-musl/release/dstack-guest-agent sdk/simulator/ + ln -sf dstack-guest-agent sdk/simulator/dstack-simulator + log_success "Simulator built: sdk/simulator/dstack-simulator" +fi + +# Create minimal simulator docker image (alpine for musl) +log_info "Creating simulator docker image..." +cat > /tmp/Dockerfile.simulator << 'EOF' +FROM alpine:latest +RUN apk add --no-cache curl ca-certificates +WORKDIR /app +EOF +docker build -t dstack-simulator:test -f /tmp/Dockerfile.simulator . +rm /tmp/Dockerfile.simulator +log_success "Simulator image created: dstack-simulator:test" + +# Step 2: Build gateway if needed (musl static build) +if ! $SKIP_BUILD; then + log_info "Building dstack-gateway (musl static)..." + cd "$REPO_ROOT" + cargo build --release -p dstack-gateway --target x86_64-unknown-linux-musl + + # Copy binary to e2e directory + cp target/x86_64-unknown-linux-musl/release/dstack-gateway "$SCRIPT_DIR/" + log_success "Gateway built: $SCRIPT_DIR/dstack-gateway" +fi + +# Step 3: Create gateway docker image (alpine for musl) +log_info "Creating gateway docker image..." +cd "$SCRIPT_DIR" + +cat > Dockerfile.gateway << 'EOF' +FROM alpine:latest + +RUN apk add --no-cache \ + wireguard-tools \ + iproute2 \ + curl \ + ca-certificates + +COPY dstack-gateway /usr/local/bin/dstack-gateway + +RUN chmod +x /usr/local/bin/dstack-gateway + +ENTRYPOINT ["/usr/local/bin/dstack-gateway", "-c", "/etc/gateway/gateway.toml"] +EOF + +docker build -t dstack-gateway:test -f Dockerfile.gateway . +rm Dockerfile.gateway +log_success "Gateway image created: dstack-gateway:test" + +# Step 4: Generate certificates if not exist +if [ ! -f "certs/gateway.crt" ]; then + log_info "Generating test certificates..." 
+ ./setup.sh + log_success "Certificates generated" +fi + +# Step 5: Run docker compose +log_info "Starting e2e test environment..." +docker compose down -v --remove-orphans 2>/dev/null || true + +export GATEWAY_IMAGE=dstack-gateway:test +export SIMULATOR_IMAGE=dstack-simulator:test + +docker compose up -d mock-cf-dns-api pebble dstack-simulator +log_info "Waiting for mock services to be healthy..." +sleep 5 + +docker compose up -d gateway-1 gateway-2 gateway-3 +log_info "Waiting for gateway cluster to be healthy..." +sleep 10 + +# Step 6: Run tests +log_info "Running tests..." +docker compose run --rm test-runner +TEST_EXIT_CODE=$? + +# Step 7: Cleanup +if ! $KEEP_RUNNING; then + log_info "Stopping containers..." + docker compose down -v --remove-orphans +fi + +if [ $TEST_EXIT_CODE -eq 0 ]; then + log_success "All tests passed!" +else + log_error "Tests failed with exit code: $TEST_EXIT_CODE" +fi + +exit $TEST_EXIT_CODE diff --git a/gateway/test-run/e2e/setup.sh b/gateway/test-run/e2e/setup.sh new file mode 100755 index 00000000..fee0739e --- /dev/null +++ b/gateway/test-run/e2e/setup.sh @@ -0,0 +1,188 @@ +#!/bin/bash +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Setup script for E2E test environment + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Colors +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } + +# Generate self-signed certificates for RPC TLS +generate_certs() { + local certs_dir="$SCRIPT_DIR/certs" + + if [[ -f "$certs_dir/gateway-rpc.cert" ]]; then + log_info "Certificates already exist, skipping generation" + return 0 + fi + + log_info "Generating self-signed certificates for RPC TLS..." 
+ mkdir -p "$certs_dir" + + # Generate CA key and certificate + openssl genrsa -out "$certs_dir/gateway-ca.key" 4096 + openssl req -x509 -new -nodes \ + -key "$certs_dir/gateway-ca.key" \ + -sha256 -days 3650 \ + -out "$certs_dir/gateway-ca.cert" \ + -subj "/CN=Gateway Test CA/O=Test/C=US" + + # Generate server key + openssl genrsa -out "$certs_dir/gateway-rpc.key" 2048 + + # Generate server CSR with SANs + cat > "$certs_dir/server.cnf" << EOF +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +CN = gateway.test.local + +[v3_req] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = gateway.test.local +DNS.2 = gateway-1 +DNS.3 = gateway-2 +DNS.4 = gateway-3 +DNS.5 = localhost +IP.1 = 127.0.0.1 +IP.2 = 172.30.0.21 +IP.3 = 172.30.0.22 +IP.4 = 172.30.0.23 +EOF + + openssl req -new \ + -key "$certs_dir/gateway-rpc.key" \ + -out "$certs_dir/gateway-rpc.csr" \ + -config "$certs_dir/server.cnf" + + # Sign the certificate + openssl x509 -req \ + -in "$certs_dir/gateway-rpc.csr" \ + -CA "$certs_dir/gateway-ca.cert" \ + -CAkey "$certs_dir/gateway-ca.key" \ + -CAcreateserial \ + -out "$certs_dir/gateway-rpc.cert" \ + -days 365 \ + -sha256 \ + -extensions v3_req \ + -extfile "$certs_dir/server.cnf" + + # Clean up + rm -f "$certs_dir/gateway-rpc.csr" "$certs_dir/server.cnf" "$certs_dir/gateway-ca.srl" + + log_info "Certificates generated in $certs_dir" +} + +# Build simulator (musl static) +build_simulator() { + local repo_dir="$SCRIPT_DIR/../../.." + + if [[ -n "$SKIP_SIMULATOR_BUILD" ]]; then + log_info "Skipping simulator build (SKIP_SIMULATOR_BUILD is set)" + return 0 + fi + + log_info "Building dstack-simulator (musl static)..." 
+ cd "$repo_dir" + cargo build --release -p dstack-guest-agent --target x86_64-unknown-linux-musl + + # Copy binary to simulator directory + cp target/x86_64-unknown-linux-musl/release/dstack-guest-agent sdk/simulator/ + ln -sf dstack-guest-agent sdk/simulator/dstack-simulator + + log_info "Simulator built: sdk/simulator/dstack-simulator" + + # Create simulator docker image (alpine for musl) + log_info "Building simulator Docker image..." + cat > /tmp/Dockerfile.simulator << 'EOF' +FROM alpine:latest +RUN apk add --no-cache curl ca-certificates +WORKDIR /app +EOF + docker build -t dstack-simulator:test -f /tmp/Dockerfile.simulator . + rm /tmp/Dockerfile.simulator + + log_info "Simulator Docker image built: dstack-simulator:test" +} + +# Build gateway Docker image (musl static) +build_gateway_image() { + local repo_dir="$SCRIPT_DIR/../../.." + + if [[ -n "$SKIP_BUILD" ]]; then + log_info "Skipping gateway build (SKIP_BUILD is set)" + return 0 + fi + + log_info "Building dstack-gateway (musl static)..." + cd "$repo_dir" + cargo build --release -p dstack-gateway --target x86_64-unknown-linux-musl + + log_info "Building Docker image..." + + # Create a minimal Dockerfile for testing (alpine for musl) + cat > "$SCRIPT_DIR/Dockerfile.gateway" << 'EOF' +FROM alpine:latest + +RUN apk add --no-cache \ + ca-certificates \ + curl \ + iproute2 \ + wireguard-tools + +COPY dstack-gateway /usr/local/bin/dstack-gateway + +RUN chmod +x /usr/local/bin/dstack-gateway + +WORKDIR /etc/gateway + +ENTRYPOINT ["/usr/local/bin/dstack-gateway", "-c", "/etc/gateway/gateway.toml"] +EOF + + cp "$repo_dir/target/x86_64-unknown-linux-musl/release/dstack-gateway" "$SCRIPT_DIR/" + docker build -t dstack-gateway:test -f "$SCRIPT_DIR/Dockerfile.gateway" "$SCRIPT_DIR" + rm -f "$SCRIPT_DIR/dstack-gateway" "$SCRIPT_DIR/Dockerfile.gateway" + + log_info "Gateway Docker image built: dstack-gateway:test" +} + +# Main +main() { + log_info "Setting up E2E test environment..." 
+ + generate_certs + build_simulator + build_gateway_image + + log_info "" + log_info "Setup complete! Run the tests with:" + log_info " ./run-e2e.sh" + log_info "" + log_info "Or run individual services:" + log_info " docker compose up -d mock-cf-dns-api pebble dstack-simulator" + log_info " docker compose up gateway-1 gateway-2 gateway-3" + log_info "" + log_info "View mock CF DNS API: http://localhost:18080" + log_info "View Pebble mgmt: https://localhost:15000" + log_info "View Simulator: http://localhost:18090" +} + +main "$@" diff --git a/gateway/test-run/e2e/simulator/dstack.toml b/gateway/test-run/e2e/simulator/dstack.toml new file mode 100644 index 00000000..eabb6084 --- /dev/null +++ b/gateway/test-run/e2e/simulator/dstack.toml @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# dstack-simulator configuration for e2e testing (HTTP mode for containers) + +[default] +workers = 4 +max_blocking = 32 +ident = "dstack Simulator (E2E)" +temp_dir = "/tmp" +keep_alive = 10 +log_level = "debug" + +[default.core] +# Use files from sdk/simulator directory (mounted as /sdk) +keys_file = "/sdk/appkeys.json" +compose_file = "/sdk/app-compose.json" +sys_config_file = "/sdk/sys-config.json" + +[default.core.simulator] +enabled = true +attestation_file = "/sdk/attestation.bin" + +# HTTP endpoints for container access +# All interfaces must be defined for the simulator to start + +[internal-v0] +address = "0.0.0.0" +port = 8091 +reuse = true + +# Gateway uses /prpc endpoint on internal interface +[internal] +address = "0.0.0.0" +port = 8090 +reuse = true + +[external] +address = "0.0.0.0" +port = 8092 +reuse = true + +[guest-api] +address = "0.0.0.0" +port = 8093 +reuse = true diff --git a/gateway/test-run/e2e/test.sh b/gateway/test-run/e2e/test.sh new file mode 100755 index 00000000..9c1db1a2 --- /dev/null +++ b/gateway/test-run/e2e/test.sh @@ -0,0 +1,332 @@ +#!/bin/sh +# SPDX-FileCopyrightText: 2024-2025 Phala 
Network +# +# SPDX-License-Identifier: Apache-2.0 + +# E2E test script for dstack-gateway certbot functionality +# This script runs inside the test-runner container + +set -e + +# ==================== Configuration ==================== + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +# Gateway endpoints +GATEWAY_PROXIES="gateway-1:9014 gateway-2:9014 gateway-3:9014" +GATEWAY_DEBUG_URLS="http://gateway-1:9015 http://gateway-2:9015 http://gateway-3:9015" +GATEWAY_ADMIN="http://gateway-1:9016" + +# External services +MOCK_CF_API="http://mock-cf-dns-api:8080" +PEBBLE_DIR="http://pebble:14000/dir" + +# Certificate domains to test (base domains, certs will be issued for *.domain) +CERT_DOMAINS="test0.local test1.local test2.local" + +# Cloudflare mock settings +CF_API_TOKEN="test-token" +CF_API_URL="http://mock-cf-dns-api:8080/client/v4" +ACME_URL="http://pebble:14000/dir" + +# Test counters +TESTS_PASSED=0 +TESTS_FAILED=0 + +# ==================== Logging ==================== + +log_info() { printf "${BLUE}[INFO]${NC} %s\n" "$1"; } +log_warn() { printf "${YELLOW}[WARN]${NC} %s\n" "$1"; } +log_error() { printf "${RED}[ERROR]${NC} %s\n" "$1"; } +log_success() { printf "${GREEN}[PASS]${NC} %s\n" "$1"; } +log_fail() { printf "${RED}[FAIL]${NC} %s\n" "$1"; } + +log_section() { + printf "\n" + log_info "==========================================" + log_info "$1" + log_info "==========================================" +} + +log_phase() { + printf "\n" + log_info "Phase $1: $2" + log_info "------------------------------------------" +} + +# ==================== Test Utilities ==================== + +# Run a test and record result +run_test() { + local name="$1" + local result="$2" + + if [ "$result" = "0" ]; then + log_success "$name" + TESTS_PASSED=$((TESTS_PASSED + 1)) + else + log_fail "$name" + TESTS_FAILED=$((TESTS_FAILED + 1)) + fi +} + +# Wait for HTTP service to respond +wait_for_service() { + local 
url="$1" + local name="$2" + local max_wait="${3:-60}" + local waited=0 + + log_info "Waiting for $name..." + while [ $waited -lt $max_wait ]; do + if curl -sf "$url" > /dev/null 2>&1; then + log_info "$name is ready" + return 0 + fi + sleep 2 + waited=$((waited + 2)) + done + + log_error "$name failed to become ready within ${max_wait}s" + return 1 +} + +# ==================== Domain Helpers ==================== + +# Convert base domain to test SNI: test0.local -> gateway.test0.local +# Uses "gateway" as it's a special app_id that proxies to gateway's own endpoints +get_test_sni() { + echo "gateway.${1}" +} + +# Convert base domain to wildcard format for certificate SAN check +get_wildcard_domain() { + echo "*.${1}" +} + +# ==================== Certificate Helpers ==================== + +# Get certificate via openssl s_client +get_cert_pem() { + local host="$1" + local sni="$2" + echo | timeout 5 openssl s_client -connect "$host" -servername "$sni" 2>/dev/null +} + +get_cert_serial() { + get_cert_pem "$1" "$2" | openssl x509 -noout -serial 2>/dev/null | cut -d= -f2 +} + +get_cert_issuer() { + get_cert_pem "$1" "$2" | openssl x509 -noout -issuer 2>/dev/null +} + +get_cert_san() { + get_cert_pem "$1" "$2" | openssl x509 -noout -ext subjectAltName 2>/dev/null +} + +# ==================== Test Functions ==================== + +test_http_health() { + curl -sf "$1" > /dev/null +} + +test_certificate_issued() { + local host="$1" + local sni="$2" + [ -n "$(get_cert_serial "$host" "$sni")" ] +} + +test_certificates_match() { + local sni="$1" + local serial1="" serial2="" serial3="" + local i=1 + + for proxy in $GATEWAY_PROXIES; do + eval "serial${i}=\"\$(get_cert_serial \"\$proxy\" \"\$sni\")\"" + log_info "Gateway $i cert serial ($sni): $(eval echo \$serial$i)" >&2 + i=$((i + 1)) + done + + [ "$serial1" = "$serial2" ] && [ "$serial2" = "$serial3" ] && [ -n "$serial1" ] +} + +test_certificate_from_pebble() { + local sni="$1" + local proxy=$(echo "$GATEWAY_PROXIES" | cut 
-d' ' -f1) + get_cert_issuer "$proxy" "$sni" | grep -qi "pebble" +} + +test_sni_cert_selection() { + local host="$1" + local sni="$2" + local expected_wildcard="$3" + get_cert_san "$host" "$sni" | grep -q "$expected_wildcard" +} + +test_proxy_tls_health() { + local host="$1" + local gateway_sni="$2" + curl -sf --connect-to "${gateway_sni}:9014:${host}" -k "https://${gateway_sni}:9014/health" > /dev/null 2>&1 +} + +# ==================== Setup ==================== + +setup_certbot_config() { + log_info "Configuring certbot via Admin API..." + + # Set ACME URL + log_info "Setting ACME URL: ${ACME_URL}" + if ! curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.SetCertbotConfig" \ + -H "Content-Type: application/json" \ + -d '{"acme_url": "'"${ACME_URL}"'"}' > /dev/null; then + log_error "Failed to set certbot config" + return 1 + fi + + # Create DNS credential + log_info "Creating DNS credential..." + if ! curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.CreateDnsCredential" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "test-cloudflare", + "provider_type": "cloudflare", + "cf_api_token": "'"${CF_API_TOKEN}"'", + "cf_api_url": "'"${CF_API_URL}"'", + "set_as_default": true, + "dns_txt_ttl": 1, + "max_dns_wait": 0 + }' > /dev/null; then + log_error "Failed to create DNS credential" + return 1 + fi + + # Add domains and trigger renewal + for domain in $CERT_DOMAINS; do + log_info "Adding domain: $domain" + curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.AddZtDomain" \ + -H "Content-Type: application/json" \ + -d '{"domain": "'"${domain}"'"}' > /dev/null || true + + log_info "Triggering renewal for: $domain" + curl -sf -X POST "${GATEWAY_ADMIN}/prpc/Admin.RenewZtDomainCert" \ + -H "Content-Type: application/json" \ + -d '{"domain": "'"${domain}"'", "force": true}' > /dev/null || \ + log_warn "Renewal request failed for $domain (may retry)" + done + + return 0 +} + +# ==================== Main ==================== + +main() { + log_section "dstack-gateway Certbot E2E 
Test" + + # Phase 1: Mock services + log_phase 1 "Verify mock services" + run_test "Mock CF DNS API health" "$(test_http_health "${MOCK_CF_API}/health"; echo $?)" + run_test "Pebble ACME directory" "$(test_http_health "${PEBBLE_DIR}"; echo $?)" + + # Phase 2: Gateway cluster + log_phase 2 "Verify gateway cluster" + local i=1 + for url in $GATEWAY_DEBUG_URLS; do + run_test "Gateway $i health" "$(test_http_health "${url}/health"; echo $?)" + i=$((i + 1)) + done + + # Phase 3: Configure certbot + log_phase 3 "Configure certbot" + if ! setup_certbot_config; then + log_error "Failed to setup certbot configuration" + fi + + # Phase 4: Certificate issuance + log_phase 4 "Certificate issuance" + local first_domain=$(echo "$CERT_DOMAINS" | cut -d' ' -f1) + local first_sni=$(get_test_sni "$first_domain") + local first_proxy=$(echo "$GATEWAY_PROXIES" | cut -d' ' -f1) + + log_info "Waiting for certificates (up to 120s)..." + local waited=0 + while [ $waited -lt 120 ]; do + if test_certificate_issued "$first_proxy" "$first_sni"; then + log_info "Certificate detected for $first_sni" + break + fi + sleep 5 + waited=$((waited + 5)) + log_info "Waiting... (${waited}s)" + done + + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + run_test "Certificate issued for $domain" \ + "$(test_certificate_issued "$first_proxy" "$sni"; echo $?)" + done + + log_info "Waiting 20s for cluster sync..." 
+ sleep 20 + + # Phase 5: Certificate consistency + log_phase 5 "Certificate consistency" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + run_test "All gateways have same cert for $domain" \ + "$(test_certificates_match "$sni"; echo $?)" + run_test "Cert for $domain issued by Pebble" \ + "$(test_certificate_from_pebble "$sni"; echo $?)" + done + + # Phase 6: SNI-based selection + log_phase 6 "SNI-based certificate selection" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + local wildcard=$(get_wildcard_domain "$domain") + run_test "SNI $sni returns $wildcard cert" \ + "$(test_sni_cert_selection "$first_proxy" "$sni" "$wildcard"; echo $?)" + done + + # Phase 7: Proxy TLS health + log_phase 7 "Proxy TLS health endpoint" + for domain in $CERT_DOMAINS; do + local sni=$(get_test_sni "$domain") + local i=1 + for proxy in $GATEWAY_PROXIES; do + run_test "Gateway $i TLS health ($sni)" \ + "$(test_proxy_tls_health "$proxy" "$sni"; echo $?)" + i=$((i + 1)) + done + done + + # Phase 8: DNS records (informational) + log_phase 8 "DNS-01 challenge records" + local records=$(curl -sf "${MOCK_CF_API}/api/records" 2>/dev/null || echo "") + if echo "$records" | grep -q "TXT"; then + log_success "DNS TXT records found" + else + log_info "No DNS TXT records (expected if certs cached)" + fi + + # Summary + log_section "Test Summary" + log_info "Passed: $TESTS_PASSED" + log_info "Failed: $TESTS_FAILED" + log_info "Domains: $(echo "$CERT_DOMAINS" | wc -w)" + + if [ $TESTS_FAILED -eq 0 ]; then + log_success "All tests passed!" + exit 0 + else + log_fail "Some tests failed!" 
+        exit 1
+    fi
+}
+
+main
diff --git a/gateway/test-run/test_certbot.sh b/gateway/test-run/test_certbot.sh
new file mode 100755
index 00000000..26a6ba97
--- /dev/null
+++ b/gateway/test-run/test_certbot.sh
@@ -0,0 +1,563 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: © 2025 Phala Network
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Distributed Certbot E2E test script.
+# Tests certificate issuance (via Let's Encrypt staging) on one gateway node
+# and certificate synchronization to a second gateway node.
+
+set -m
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR"
+
+# Print CLI usage for this test driver.
+show_help() {
+    echo "Usage: $0 [OPTIONS]"
+    echo ""
+    echo "Distributed Certbot E2E Test"
+    echo ""
+    echo "Options:"
+    echo "  --fresh      Clean everything and request new certificate from ACME"
+    echo "  --sync-only  Keep existing cert, only test sync between nodes"
+    echo "  --clean      Clean all test data and exit"
+    echo "  -h, --help   Show this help message"
+    echo ""
+    echo "Default (no options): Keep ACME account, request new certificate"
+    echo ""
+    echo "Examples:"
+    echo "  $0              # Keep account, new cert"
+    echo "  $0 --fresh      # Fresh start, new account and cert"
+    echo "  $0 --sync-only  # Test sync with existing cert"
+    echo "  $0 --clean      # Clean up all test data"
+}
+
+# Parse arguments into MODE: default | fresh | sync-only | clean
+MODE="default"
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --fresh)
+            MODE="fresh"
+            shift
+            ;;
+        --sync-only)
+            MODE="sync-only"
+            shift
+            ;;
+        --clean)
+            MODE="clean"
+            shift
+            ;;
+        -h|--help)
+            show_help
+            exit 0
+            ;;
+        *)
+            echo "Unknown option: $1"
+            show_help
+            exit 1
+            ;;
+    esac
+done
+
+# Load environment variables from .env (CF_API_TOKEN, CF_ZONE_ID, TEST_DOMAIN)
+if [[ -f ".env" ]]; then
+    source ".env"
+else
+    echo "ERROR: .env file not found!"
+    echo ""
+    echo "Please create a .env file with the following variables:"
+    echo "  CF_API_TOKEN="
+    echo "  CF_ZONE_ID="
+    echo "  TEST_DOMAIN="
+    echo ""
+    echo "The domain must be managed by Cloudflare and the API token must have"
+    echo "permissions to manage DNS records and CAA records."
+    exit 1
+fi
+
+# Validate required environment variables
+if [[ -z "$CF_API_TOKEN" ]]; then
+    echo "ERROR: CF_API_TOKEN is not set in .env"
+    exit 1
+fi
+
+if [[ -z "$CF_ZONE_ID" ]]; then
+    echo "ERROR: CF_ZONE_ID is not set in .env"
+    exit 1
+fi
+
+if [[ -z "$TEST_DOMAIN" ]]; then
+    echo "ERROR: TEST_DOMAIN is not set in .env"
+    exit 1
+fi
+
+GATEWAY_BIN="$SCRIPT_DIR/../../target/release/dstack-gateway"
+RUN_DIR="run"
+CERTS_DIR="$RUN_DIR/certs"
+CA_CERT="$CERTS_DIR/gateway-ca.cert"
+LOG_DIR="$RUN_DIR/logs"
+CURRENT_TEST="test_certbot"
+
+# Let's Encrypt staging URL (avoids production rate limits during testing)
+ACME_STAGING_URL="https://acme-staging-v02.api.letsencrypt.org/directory"
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m'
+
+log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
+log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
+
+# Kill test gateway processes and tear down test WireGuard interfaces.
+# Registered on EXIT so it runs on any script termination.
+cleanup() {
+    log_info "Cleaning up..."
+    sudo pkill -9 -f "dstack-gateway.*certbot_node[12].toml" >/dev/null 2>&1 || true
+    sudo ip link delete certbot-test1 2>/dev/null || true
+    sudo ip link delete certbot-test2 2>/dev/null || true
+    sleep 1
+    stty sane 2>/dev/null || true
+}
+
+trap cleanup EXIT
+
+# Generate a gateway TOML config for node $1 with certbot enabled.
+# Ports are derived from the node id (base 14000); each node uses the
+# other node as its sync bootnode.
+generate_certbot_config() {
+    local node_id=$1
+    local rpc_port=$((14000 + node_id * 10 + 2))
+    local wg_port=$((14000 + node_id * 10 + 3))
+    local proxy_port=$((14000 + node_id * 10 + 4))
+    local debug_port=$((14000 + node_id * 10 + 5))
+    local wg_ip="10.0.4${node_id}.1/24"
+
+    # Build peer config
+    local other_node=$((3 - node_id))  # If node_id=1, other=2; if node_id=2, other=1
+    local other_rpc_port=$((14000 + other_node * 10 + 2))
+
+    local abs_run_dir="$SCRIPT_DIR/$RUN_DIR"
+    local certbot_dir="$abs_run_dir/certbot_node${node_id}"
+
+    mkdir -p "$certbot_dir"
+
+    cat > "$RUN_DIR/certbot_node${node_id}.toml" << EOF
+log_level = "info"
+address = "0.0.0.0"
+port = ${rpc_port}
+
+[tls]
+key = "${abs_run_dir}/certs/gateway-rpc.key"
+certs = "${abs_run_dir}/certs/gateway-rpc.cert"
+
+[tls.mutual]
+ca_certs = "${abs_run_dir}/certs/gateway-ca.cert"
+mandatory = false
+
+[core]
+kms_url = ""
+rpc_domain = "gateway.tdxlab.dstack.org"
+
+[core.debug]
+insecure_enable_debug_rpc = true
+insecure_skip_attestation = true
+port = ${debug_port}
+address = "127.0.0.1"
+
+[core.sync]
+enabled = true
+interval = "5s"
+timeout = "10s"
+my_url = "https://localhost:${rpc_port}"
+bootnode = "https://localhost:${other_rpc_port}"
+node_id = ${node_id}
+data_dir = "${RUN_DIR}/wavekv_certbot_node${node_id}"
+
+[core.certbot]
+enabled = true
+workdir = "${certbot_dir}"
+acme_url = "${ACME_STAGING_URL}"
+cf_api_token = "${CF_API_TOKEN}"
+cf_zone_id = "${CF_ZONE_ID}"
+auto_set_caa = true
+domain = "${TEST_DOMAIN}"
+renew_interval = "1h"
+renew_before_expiration = "720h"
+renew_timeout = "5m"
+
+[core.wg]
+private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y="
+public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM="
+listen_port = ${wg_port}
+ip = "${wg_ip}"
+reserved_net = ["10.0.4${node_id}.1/31"]
+client_ip_range = "10.0.4${node_id}.1/24"
+config_path = "${RUN_DIR}/wg_certbot_node${node_id}.conf"
+interface = "certbot-test${node_id}"
+endpoint = "127.0.0.1:${wg_port}"
+
+[core.proxy]
+cert_chain = "${certbot_dir}/live/cert.pem"
+cert_key = "${certbot_dir}/live/key.pem"
+base_domain = "tdxlab.dstack.org"
+listen_addr = "0.0.0.0"
+listen_port = ${proxy_port}
+tappd_port = 8090
+external_port = ${proxy_port}
+EOF
+    log_info "Generated certbot_node${node_id}.toml (rpc=${rpc_port}, debug=${debug_port}, proxy=${proxy_port})"
+}
+
+# Start gateway node $1 and poll up to 30s for it to either obtain a
+# certificate, keep running (request in progress), or exit with an error.
+start_certbot_node() {
+    local node_id=$1
+    local config="$RUN_DIR/certbot_node${node_id}.toml"
+    local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log"
+
+    log_info "Starting certbot node ${node_id}..."
+    mkdir -p "$RUN_DIR/wavekv_certbot_node${node_id}"
+    mkdir -p "$LOG_DIR"
+    ( sudo RUST_LOG=info "$GATEWAY_BIN" -c "$config" > "$log_file" 2>&1 & )
+
+    # Wait for process to either stabilize or fail
+    local max_wait=30
+    local waited=0
+    while [[ $waited -lt $max_wait ]]; do
+        sleep 2
+        waited=$((waited + 2))
+
+        if ! pgrep -f "dstack-gateway.*${config}" > /dev/null; then
+            # Process exited, check why
+            log_error "Certbot node ${node_id} exited after ${waited}s"
+            echo "--- Log output ---"
+            cat "$log_file"
+            echo "--- End log ---"
+
+            # Check for rate limit error
+            if grep -q "rateLimited" "$log_file"; then
+                log_error "Let's Encrypt rate limit hit. Wait a few minutes and retry."
+            fi
+            return 1
+        fi
+
+        # Check if cert files exist (indicates successful init)
+        local certbot_dir="$RUN_DIR/certbot_node${node_id}"
+        if [[ -f "$certbot_dir/live/cert.pem" ]] && [[ -f "$certbot_dir/live/key.pem" ]]; then
+            log_info "Certbot node ${node_id} started and certificate obtained"
+            return 0
+        fi
+
+        log_info "Waiting for node ${node_id} to initialize... (${waited}s)"
+    done
+
+    # Process still running but no cert yet - might still be requesting
+    if pgrep -f "dstack-gateway.*${config}" > /dev/null; then
+        log_info "Certbot node ${node_id} still running, certificate request in progress"
+        return 0
+    fi
+
+    log_error "Certbot node ${node_id} failed to start within ${max_wait}s"
+    cat "$log_file"
+    return 1
+}
+
+# Hard-kill the gateway process for node $1.
+stop_certbot_node() {
+    local node_id=$1
+    log_info "Stopping certbot node ${node_id}..."
+    sudo pkill -9 -f "dstack-gateway.*certbot_node${node_id}.toml" >/dev/null 2>&1 || true
+    sleep 1
+}
+
+# Get debug sync data from a node's debug RPC port.
+debug_get_sync_data() {
+    local debug_port=$1
+    curl -s "http://localhost:${debug_port}/prpc/GetSyncData" \
+        -H "Content-Type: application/json" \
+        -d '{}' 2>/dev/null
+}
+
+# Check if KvStore has cert data for the domain.
+check_kvstore_cert() {
+    local debug_port=$1
+    local response=$(debug_get_sync_data "$debug_port")
+
+    # The cert data would be in the persistent store
+    # For now, check if we can get any data
+    if [[ -z "$response" ]]; then
+        return 1
+    fi
+
+    # NOTE(review): the embedded python only verifies the response is valid
+    # JSON — it prints 'ok' unconditionally and never inspects 'cert/' keys.
+    echo "$response" | python3 -c "
+import sys, json
+try:
+    d = json.load(sys.stdin)
+    # Check if there are any keys that start with 'cert/'
+    # This is a simplified check
+    print('ok')
+    sys.exit(0)
+except Exception as e:
+    print(f'error: {e}', file=sys.stderr)
+    sys.exit(1)
+" 2>/dev/null
+}
+
+# Check if proxy is using a valid certificate by connecting via TLS.
+# Returns 0 whenever any certificate is served (warns on unknown issuers).
+check_proxy_cert() {
+    local proxy_port=$1
+
+    # Use gateway.{base_domain} as the SNI for health endpoint
+    local gateway_host="gateway.tdxlab.dstack.org"
+
+    # Use openssl to check the certificate
+    local cert_info=$(echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null)
+
+    if [[ -z "$cert_info" ]]; then
+        log_error "Failed to connect to proxy on port ${proxy_port}"
+        return 1
+    fi
+
+    # Check if the certificate is valid (not self-signed test cert)
+    # For staging certs, the issuer should contain "Staging" or "(STAGING)"
+    local issuer=$(echo "$cert_info" | openssl x509 -noout -issuer 2>/dev/null)
+
+    if echo "$issuer" | grep -qi "staging\|fake\|test"; then
+        log_info "Proxy on port ${proxy_port} is using Let's Encrypt staging certificate"
+        log_info "Issuer: $issuer"
+        return 0
+    elif echo "$issuer" | grep -qi "let's encrypt\|letsencrypt"; then
+        log_info "Proxy on port ${proxy_port} is using Let's Encrypt certificate"
+        log_info "Issuer: $issuer"
+        return 0
+    else
+        log_warn "Proxy on port ${proxy_port} certificate issuer: $issuer"
+        # Still return success if we got a certificate
+        return 0
+    fi
+}
+
+# Get certificate expiry (notAfter date) from the proxy TLS endpoint.
+get_proxy_cert_expiry() {
+    local proxy_port=$1
+    # Use gateway.{base_domain} as the SNI for health endpoint
+    local gateway_host="gateway.tdxlab.dstack.org"
+    echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \
+        openssl x509 -noout -enddate 2>/dev/null | \
+        cut -d= -f2
+}
+
+# Get certificate serial number from the proxy TLS endpoint.
+get_proxy_cert_serial() {
+    local proxy_port=$1
+    local gateway_host="gateway.tdxlab.dstack.org"
+    echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \
+        openssl x509 -noout -serial 2>/dev/null | \
+        cut -d= -f2
+}
+
+# Get certificate issuer from proxy.
+get_proxy_cert_issuer() {
+    local proxy_port=$1
+    local gateway_host="gateway.tdxlab.dstack.org"
+    echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \
+        openssl x509 -noout -issuer 2>/dev/null
+}
+
+# Poll the proxy until a certificate is served or the timeout (default 300s)
+# elapses. Success criterion: openssl reports a non-empty expiry date.
+wait_for_cert() {
+    local proxy_port=$1
+    local timeout_secs=${2:-300}  # Default 5 minutes
+    local start_time=$(date +%s)
+
+    log_info "Waiting for certificate to be issued (timeout: ${timeout_secs}s)..."
+
+    while true; do
+        local current_time=$(date +%s)
+        local elapsed=$((current_time - start_time))
+
+        if [[ $elapsed -ge $timeout_secs ]]; then
+            log_error "Timeout waiting for certificate"
+            return 1
+        fi
+
+        # Try to get certificate info
+        local expiry=$(get_proxy_cert_expiry "$proxy_port")
+        if [[ -n "$expiry" ]]; then
+            log_info "Certificate detected! Expiry: $expiry"
+            return 0
+        fi
+
+        log_info "Waiting... (${elapsed}s elapsed)"
+        sleep 10
+    done
+}
+
+# ============================================================
+# Main Test
+# ============================================================
+
+# Remove all certbot test state (cert dirs, wavekv data, gateway state files).
+do_clean() {
+    log_info "Cleaning all certbot test data..."
+    cleanup
+    sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2"
+    sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2"
+    sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json"
+    log_info "Done."
+}
+
+main() {
+    log_info "=========================================="
+    log_info "Distributed Certbot E2E Test"
+    log_info "=========================================="
+    log_info "Test domain: $TEST_DOMAIN"
+    log_info "ACME URL: $ACME_STAGING_URL"
+    log_info "Mode: $MODE"
+    log_info ""
+
+    # Handle --clean mode
+    if [[ "$MODE" == "clean" ]]; then
+        do_clean
+        return 0
+    fi
+
+    # Handle --sync-only mode: check if cert exists
+    if [[ "$MODE" == "sync-only" ]]; then
+        if [[ ! -f "$RUN_DIR/certbot_node1/live/cert.pem" ]]; then
+            log_error "No existing certificate found. Run without --sync-only first."
+            return 1
+        fi
+        log_info "Using existing certificate for sync test"
+    fi
+
+    # Clean up processes and state
+    cleanup
+
+    # Decide what to clean based on mode
+    case "$MODE" in
+        fresh)
+            # Clean everything including ACME account
+            log_info "Fresh mode: cleaning all data including ACME account"
+            sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2"
+            ;;
+        sync-only)
+            # Keep node1 cert, only clean node2 and wavekv
+            log_info "Sync-only mode: keeping node1 certificate"
+            sudo rm -rf "$RUN_DIR/certbot_node2"
+            ;;
+        *)
+            # Default: keep ACME account (credentials.json), clean certs
+            log_info "Default mode: keeping ACME account, requesting new certificate"
+            # Backup credentials if exists
+            if [[ -f "$RUN_DIR/certbot_node1/credentials.json" ]]; then
+                sudo cp "$RUN_DIR/certbot_node1/credentials.json" /tmp/certbot_credentials_backup.json
+            fi
+            sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2"
+            # Restore credentials
+            if [[ -f /tmp/certbot_credentials_backup.json ]]; then
+                mkdir -p "$RUN_DIR/certbot_node1"
+                sudo mv /tmp/certbot_credentials_backup.json "$RUN_DIR/certbot_node1/credentials.json"
+            fi
+            ;;
+    esac
+
+    # Always clean wavekv and gateway state
+    sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2"
+    sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json"
+
+    # Generate configs
+    log_info "Generating node configurations..."
+    generate_certbot_config 1
+    generate_certbot_config 2
+
+    # Start Node 1 first - it will request the certificate
+    log_info ""
+    log_info "=========================================="
+    log_info "Phase 1: Start Node 1 and request certificate"
+    log_info "=========================================="
+
+    if ! start_certbot_node 1; then
+        log_error "Failed to start node 1"
+        return 1
+    fi
+
+    # Wait for certificate to be issued
+    local proxy_port_1=14014
+    if ! wait_for_cert "$proxy_port_1" 300; then
+        log_error "Node 1 failed to obtain certificate"
+        cat "$LOG_DIR/${CURRENT_TEST}_node1.log" | tail -50
+        return 1
+    fi
+
+    # Get Node 1's certificate info
+    local node1_serial=$(get_proxy_cert_serial "$proxy_port_1")
+    local node1_expiry=$(get_proxy_cert_expiry "$proxy_port_1")
+    log_info "Node 1 certificate serial: $node1_serial"
+    log_info "Node 1 certificate expiry: $node1_expiry"
+
+    # Show certificate source logs for Node 1
+    log_info ""
+    log_info "Node 1 certificate source:"
+    grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node1.log" 2>/dev/null | sed 's/^/ /'
+
+    # Start Node 2 - it should sync the certificate from Node 1
+    log_info ""
+    log_info "=========================================="
+    log_info "Phase 2: Start Node 2 and verify sync"
+    log_info "=========================================="
+
+    if ! start_certbot_node 2; then
+        log_error "Failed to start node 2"
+        return 1
+    fi
+
+    # Wait for Node 2 to sync and load the certificate
+    local proxy_port_2=14024
+    sleep 10  # Give time for sync
+
+    if ! wait_for_cert "$proxy_port_2" 60; then
+        log_error "Node 2 failed to obtain certificate via sync"
+        cat "$LOG_DIR/${CURRENT_TEST}_node2.log" | tail -50
+        return 1
+    fi
+
+    # Get Node 2's certificate info
+    local node2_serial=$(get_proxy_cert_serial "$proxy_port_2")
+    local node2_expiry=$(get_proxy_cert_expiry "$proxy_port_2")
+    log_info "Node 2 certificate serial: $node2_serial"
+    log_info "Node 2 certificate expiry: $node2_expiry"
+
+    # Show certificate source logs for Node 2
+    log_info ""
+    log_info "Node 2 certificate source:"
+    grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node2.log" 2>/dev/null | sed 's/^/ /'
+
+    # Verify both nodes have the same certificate
+    log_info ""
+    log_info "=========================================="
+    log_info "Verification"
+    log_info "=========================================="
+
+    if [[ "$node1_serial" == "$node2_serial" ]]; then
+        log_info "SUCCESS: Both nodes have the same certificate (serial: $node1_serial)"
+    else
+        log_error "FAILURE: Certificate mismatch!"
+        log_error "  Node 1 serial: $node1_serial"
+        log_error "  Node 2 serial: $node2_serial"
+        return 1
+    fi
+
+    # Check that proxy is actually using the certificate
+    check_proxy_cert "$proxy_port_1"
+    check_proxy_cert "$proxy_port_2"
+
+    log_info ""
+    log_info "=========================================="
+    log_info "All tests passed!"
+    log_info "=========================================="
+
+    return 0
+}
+
+# Run main
+main
+exit $?
diff --git a/gateway/test-run/test_suite.sh b/gateway/test-run/test_suite.sh new file mode 100755 index 00000000..ddb00814 --- /dev/null +++ b/gateway/test-run/test_suite.sh @@ -0,0 +1,2130 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# WaveKV integration test script + +# Don't use set -e as it causes issues with cleanup and test flow +# set -e + +# Disable job control messages (prevents "Killed" messages from messing up output) +set +m + +# Fix terminal output - ensure proper line endings +stty -echoctl 2>/dev/null || true + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="/home/kvin/sdc/home/wavekv/dstack/target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +CURRENT_TEST="" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +cleanup() { + log_info "Cleaning up..." 
+ # Kill only dstack-gateway processes started by this test (matching our specific config path) + # Use absolute path to avoid killing system dstack-gateway processes + pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + sleep 1 + # Only delete WireGuard interfaces with sudo (these are our test interfaces) + sudo ip link delete wavekv-test1 2>/dev/null || true + sudo ip link delete wavekv-test2 2>/dev/null || true + sudo ip link delete wavekv-test3 2>/dev/null || true + # Clean up all wavekv data directories to prevent peer list contamination + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" 2>/dev/null || true + rm -f "$RUN_DIR/gateway-state-node"*.json 2>/dev/null || true + sleep 1 + stty sane 2>/dev/null || true +} + +trap cleanup EXIT + +# Generate node configs +# Usage: generate_config [bootnode_url] +generate_config() { + local node_id=$1 + local bootnode_url=${2:-""} + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + + # Use absolute paths to avoid Rocket's relative path resolution issues + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat >"$RUN_DIR/node${node_id}.toml" </dev/null | grep -q ":${port} "; then + return 0 + fi + sleep 1 + ((waited++)) + done + return 1 +} + +ensure_wg_interface() { + local node_id=$1 + local iface="wavekv-test${node_id}" + + # Check if interface exists, create if not + if ! ip link show "$iface" >/dev/null 2>&1; then + log_info "Creating WireGuard interface ${iface}..." 
+ sudo ip link add "$iface" type wireguard || { + log_error "Failed to create WireGuard interface ${iface}" + return 1 + } + fi + return 0 +} + +start_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log" + + # Calculate ports for this node + local admin_port=$((13000 + node_id * 10 + 6)) + local rpc_port=$((13000 + node_id * 10 + 2)) + + log_info "Starting node ${node_id}..." + + # Kill any existing test process for this node first (use absolute path to be precise) + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Wait for ports to be free + if ! wait_for_port_free $admin_port; then + log_error "Port $admin_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${admin_port} " || true + return 1 + fi + if ! wait_for_port_free $rpc_port; then + log_error "Port $rpc_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${rpc_port} " || true + return 1 + fi + + # Ensure WireGuard interface exists before starting + if ! ensure_wg_interface "$node_id"; then + return 1 + fi + + mkdir -p "$RUN_DIR/wavekv_node${node_id}" + mkdir -p "$LOG_DIR" + (RUST_LOG=info "$GATEWAY_BIN" -c "$config" >"$log_file" 2>&1 &) + sleep 2 + + if pgrep -f "dstack-gateway.*${config}" >/dev/null; then + log_info "Node ${node_id} started successfully" + return 0 + else + log_error "Node ${node_id} failed to start" + cat "$log_file" + return 1 + fi +} + +stop_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local admin_port=$((13000 + node_id * 10 + 6)) + + log_info "Stopping node ${node_id}..." 
+ # Kill only the specific test process using absolute config path + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Verify the port is free, otherwise force kill by PID + if ! wait_for_port_free $admin_port; then + log_warn "Node ${node_id} port still in use, forcing cleanup..." + # Find and kill the process holding the port + local pid=$(netstat -tlnp 2>/dev/null | grep ":${admin_port} " | awk '{print $7}' | cut -d'/' -f1) + if [[ -n "$pid" ]]; then + kill -9 "$pid" 2>/dev/null || true + sleep 1 + fi + fi + + # Reset terminal to fix any broken line endings + stty sane 2>/dev/null || true +} + +# Get WaveKV status via Admin.WaveKvStatus RPC +# Usage: get_status +get_status() { + local admin_port=$1 + curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.WaveKvStatus" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +get_n_keys() { + local admin_port=$1 + get_status "$admin_port" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d['persistent']['n_keys'])" 2>/dev/null || echo "0" +} + +# Register CVM via debug port (no attestation required) +# Usage: debug_register_cvm +# Returns: JSON response +debug_register_cvm() { + local debug_port=$1 + local public_key=$2 + local app_id=${3:-"testapp"} + local instance_id=${4:-"testinstance"} + curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null +} + +# Check if debug service is available +# Usage: check_debug_service +check_debug_service() { + local debug_port=$1 + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.Info" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'base_domain' in d" 
2>/dev/null; then + return 0 + else + return 1 + fi +} + +# Verify register response is successful (has wg config, no error) +# Usage: verify_register_response +verify_register_response() { + local response="$1" + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) + assert 'wg' in d, 'missing wg config' + assert 'client_ip' in d['wg'], 'missing client_ip' + print(d['wg']['client_ip']) +except Exception as e: + print(f'ERROR: {e}', file=sys.stderr) + sys.exit(1) +" 2>/dev/null +} + +# Get sync data from debug port (peer_addrs, nodes, instances) +# Usage: debug_get_sync_data +# Returns: JSON response with my_node_id, peer_addrs, nodes, instances +debug_get_sync_data() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Check if node has synced peer address from another node +# Usage: has_peer_addr +# Returns: 0 if peer address exists, 1 otherwise +has_peer_addr() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + peer_addrs = d.get('peer_addrs', []) + for pa in peer_addrs: + if pa.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Check if node has synced node info from another node +# Usage: has_node_info +# Returns: 0 if node info exists, 1 otherwise +has_node_info() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + nodes = d.get('nodes', []) + for n in nodes: + if n.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Get number of peer addresses from 
sync data +# Usage: get_n_peer_addrs +get_n_peer_addrs() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('peer_addrs', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of node infos from sync data +# Usage: get_n_nodes +get_n_nodes() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('nodes', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of instances from KvStore sync data +# Usage: get_n_instances +get_n_instances() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Get Proxy State from debug port (in-memory state) +# Usage: debug_get_proxy_state +# Returns: JSON response with instances and allocated_addresses +debug_get_proxy_state() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/GetProxyState" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Get number of instances from ProxyState (in-memory) +# Usage: get_n_proxy_state_instances +get_n_proxy_state_instances() { + local debug_port=$1 + local response=$(debug_get_proxy_state "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Check KvStore and ProxyState instance consistency +# Usage: check_instance_consistency +# Returns: 0 if consistent, 1 otherwise +check_instance_consistency() { + local debug_port=$1 + local kvstore_instances=$(get_n_instances "$debug_port") + local proxystate_instances=$(get_n_proxy_state_instances "$debug_port") + + if [[ "$kvstore_instances" 
-eq "$proxystate_instances" ]]; then + return 0 + else + log_error "Instance count mismatch: KvStore=$kvstore_instances, ProxyState=$proxystate_instances" + return 1 + fi +} + +# ============================================================================= +# Test 1: Single node persistence +# ============================================================================= +test_persistence() { + log_info "========== Test 1: Persistence ==========" + cleanup + + generate_config 1 + + # Start node and let it write some data + start_node 1 + + local admin_port=13016 + local initial_keys=$(get_n_keys $admin_port) + log_info "Initial keys: $initial_keys" + + # The gateway auto-writes some data (peer_addr, etc) + sleep 2 + local keys_after_write=$(get_n_keys $admin_port) + log_info "Keys after startup: $keys_after_write" + + # Stop and restart + stop_node 1 + log_info "Restarting node 1..." + start_node 1 + + local keys_after_restart=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after_restart" + + if [[ "$keys_after_restart" -ge "$keys_after_write" ]]; then + log_info "Persistence test PASSED" + return 0 + else + log_error "Persistence test FAILED: expected >= $keys_after_write keys, got $keys_after_restart" + return 1 + fi +} + +# ============================================================================= +# Test 2: Multi-node sync +# ============================================================================= +test_multi_node_sync() { + log_info "========== Test 2: Multi-node Sync ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # 
Wait for sync + log_info "Waiting for nodes to sync..." + sleep 10 + + # Use debug RPC to check actual synced data + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local nodes1=$(get_n_nodes $debug_port1) + local nodes2=$(get_n_nodes $debug_port2) + + log_info "Node 1: peer_addrs=$peer_addrs1, nodes=$nodes1" + log_info "Node 2: peer_addrs=$peer_addrs2, nodes=$nodes2" + + # For true sync, each node should have: + # - At least 2 peer addresses (both nodes' addresses) + # - At least 2 node infos (both nodes' info) + local sync_ok=true + + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + sync_ok=false + fi + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + sync_ok=false + fi + if ! has_node_info $debug_port1 2; then + log_error "Node 1 missing node_info for node 2" + sync_ok=false + fi + if ! has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Multi-node sync test PASSED" + return 0 + else + log_error "Multi-node sync test FAILED: nodes did not sync peer data" + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 3: Node recovery after disconnect +# ============================================================================= +test_node_recovery() { + log_info "========== Test 3: Node Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other 
+ setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial sync + sleep 5 + + # Stop node 2 + log_info "Stopping node 2 to simulate disconnect..." + stop_node 2 + + # Wait and let node 1 continue + sleep 3 + + # Check node 1 has its own data + local peer_addrs1_before=$(get_n_peer_addrs $debug_port1) + log_info "Node 1 peer_addrs before node 2 restart: $peer_addrs1_before" + + # Restart node 2 + log_info "Restarting node 2..." + start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 10 + + # After recovery, node 2 should have synced node 1's data + local sync_ok=true + + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1 after recovery" + sync_ok=false + fi + if ! has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1 after recovery" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Node recovery test PASSED" + return 0 + else + log_error "Node recovery test FAILED: node 2 did not sync data from node 1" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 4: Status endpoint structure (Admin.WaveKvStatus RPC) +# ============================================================================= +test_status_endpoint() { + log_info "========== Test 4: Status Endpoint ==========" + cleanup + + generate_config 1 + start_node 1 + + local admin_port=13016 + local status=$(get_status $admin_port) + + # Verify all expected fields exist + local checks_passed=0 + local total_checks=6 + + echo "$status" | python3 -c " +import sys, json +d = json.load(sys.stdin) +assert d['enabled'] == True, 'enabled should be True' +assert 'persistent' in d, 'missing persistent' +assert 'ephemeral' in d, 'missing ephemeral' +assert d['persistent']['wal_enabled'] == True, 'persistent wal should be enabled' 
+assert d['ephemeral']['wal_enabled'] == False, 'ephemeral wal should be disabled' +assert 'peers' in d['persistent'], 'missing peers in persistent' +print('All status checks passed') +" && checks_passed=1 + + if [[ $checks_passed -eq 1 ]]; then + log_info "Status endpoint test PASSED" + return 0 + else + log_error "Status endpoint test FAILED" + log_info "Status response: $status" + return 1 + fi +} + +# ============================================================================= +# Test 5: Cross-node data sync verification (KvStore + ProxyState) +# ============================================================================= +test_cross_node_data_sync() { + log_info "========== Test 5: Cross-node Data Sync ==========" + cleanup + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial connection + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register a client on node 1 via debug port + log_info "Registering client on node 1 via debug port..." + local register_response=$(debug_register_cvm $debug_port1 "testkey12345678901234567890123456789012345=" "app1" "inst1") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync (need at least 3 sync intervals of 5s for data to propagate) + log_info "Waiting for sync..." 
+ sleep 20 + + # Check KvStore instance count on both nodes + local kv_instances1=$(get_n_instances $debug_port1) + local kv_instances2=$(get_n_instances $debug_port2) + + # Check ProxyState instance count on both nodes + local ps_instances1=$(get_n_proxy_state_instances $debug_port1) + local ps_instances2=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1: KvStore=$kv_instances1, ProxyState=$ps_instances1" + log_info "Node 2: KvStore=$kv_instances2, ProxyState=$ps_instances2" + + local test_passed=true + + # Verify KvStore sync + if [[ "$kv_instances1" -lt 1 ]] || [[ "$kv_instances2" -lt 1 ]]; then + log_error "KvStore sync failed: kv_instances1=$kv_instances1, kv_instances2=$kv_instances2" + test_passed=false + fi + + # Verify ProxyState sync (node 2 should have loaded instance from KvStore) + if [[ "$ps_instances1" -lt 1 ]] || [[ "$ps_instances2" -lt 1 ]]; then + log_error "ProxyState sync failed: ps_instances1=$ps_instances1, ps_instances2=$ps_instances2" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv_instances1" -ne "$ps_instances1" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv_instances1, ProxyState=$ps_instances1" + test_passed=false + fi + if [[ "$kv_instances2" -ne "$ps_instances2" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv_instances2, ProxyState=$ps_instances2" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Cross-node data sync test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 6: prpc DebugRegisterCvm endpoint (on separate debug port) +# 
============================================================================= +test_prpc_register() { + log_info "========== Test 6: prpc DebugRegisterCvm ==========" + cleanup + + generate_config 1 + start_node 1 + + local debug_port=13015 + + # Verify debug service is available first + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + log_info "Debug service is available" + + # Register via debug port + local register_response=$(debug_register_cvm $debug_port "prpctest12345678901234567890123456789012=" "deadbeef" "cafebabe") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "prpc DebugRegisterCvm test FAILED" + return 1 + fi + + log_info "DebugRegisterCvm success: client_ip=$client_ip" + log_info "prpc DebugRegisterCvm test PASSED" + return 0 +} + +# ============================================================================= +# Test 7: prpc Info endpoint +# ============================================================================= +test_prpc_info() { + log_info "========== Test 7: prpc Info ==========" + cleanup + + generate_config 1 + start_node 1 + + local port=13012 + + # Call Info via prpc + # Note: trim: "Tproxy." removes "Tproxy.Gateway." 
prefix, so endpoint is just /prpc/Info + local info_response=$(curl -sk --cacert "$CA_CERT" \ + -X POST "https://localhost:${port}/prpc/Info" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null) + + log_info "Info response: $info_response" + + # Verify response has expected fields and no error + echo "$info_response" | python3 -c " +import sys, json +d = json.load(sys.stdin) +if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) +assert 'base_domain' in d, 'missing base_domain' +assert 'external_port' in d, 'missing external_port' +print('prpc Info check passed') +" && { + log_info "prpc Info test PASSED" + return 0 + } || { + log_error "prpc Info test FAILED" + return 1 + } +} + +# ============================================================================= +# Test 8: Client registration and data persistence +# ============================================================================= +test_client_registration_persistence() { + log_info "========== Test 8: Client Registration Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client via debug port + log_info "Registering client..." 
+ local register_response=$(debug_register_cvm $debug_port "persisttest1234567890123456789012345678901=" "persist_app" "persist_inst") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before restart: $keys_before" + + # Restart node + stop_node 1 + start_node 1 + + # Check keys after restart + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after" + + if [[ "$keys_after" -ge "$keys_before" ]] && [[ "$keys_before" -gt 2 ]]; then + log_info "Client registration persistence test PASSED" + return 0 + else + log_error "Client registration persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 9: Stress test - multiple writes +# ============================================================================= +test_stress_writes() { + log_info "========== Test 9: Stress Test ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + local num_clients=10 + local success_count=0 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + log_info "Registering $num_clients clients via debug port..." 
+ for i in $(seq 1 $num_clients); do + local key=$(printf "stresstest%02d12345678901234567890123456=" "$i") + local app_id=$(printf "stressapp%02d" "$i") + local inst_id=$(printf "stressinst%02d" "$i") + local response=$(debug_register_cvm $debug_port "$key" "$app_id" "$inst_id") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + + log_info "Successfully registered $success_count/$num_clients clients" + + sleep 2 + + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after stress test: $keys_after" + + # We expect successful registrations to create keys + if [[ "$success_count" -eq "$num_clients" ]] && [[ "$keys_after" -gt 2 ]]; then + log_info "Stress test PASSED" + return 0 + else + log_error "Stress test FAILED: success_count=$success_count, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 10: Network partition simulation (KvStore + ProxyState consistency) +# ============================================================================= +test_network_partition() { + log_info "========== Test 10: Network Partition Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Let them sync initially + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Stop node 2 (simulate partition) + log_info "Simulating network partition - stopping node 2..." 
+ stop_node 2 + + # Register clients on node 1 while node 2 is down + log_info "Registering clients on node 1 during partition..." + local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "partition%02d123456789012345678901234567=" "$i") + local response=$(debug_register_cvm $debug_port1 "$key" "partition_app$i" "partition_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients during partition" + + local kv1_during=$(get_n_instances $debug_port1) + local ps1_during=$(get_n_proxy_state_instances $debug_port1) + log_info "Node 1 during partition: KvStore=$kv1_during, ProxyState=$ps1_during" + + # Restore node 2 + log_info "Healing partition - restarting node 2..." + start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 15 + + # Check KvStore and ProxyState on both nodes after recovery + local kv1_after=$(get_n_instances $debug_port1) + local kv2_after=$(get_n_instances $debug_port2) + local ps1_after=$(get_n_proxy_state_instances $debug_port1) + local ps2_after=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1 after recovery: KvStore=$kv1_after, ProxyState=$ps1_after" + log_info "Node 2 after recovery: KvStore=$kv2_after, ProxyState=$ps2_after" + + local test_passed=true + + # Verify basic sync + if [[ "$success_count" -ne 3 ]] || [[ "$kv1_during" -lt 3 ]]; then + log_error "Registration or KvStore write failed during partition" + test_passed=false + fi + + # Verify node 2 synced KvStore + if [[ "$kv2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 KvStore sync failed: kv2_after=$kv2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify node 2 ProxyState sync + if [[ "$ps2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 ProxyState sync failed: ps2_after=$ps2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify consistency on each node + if [[ 
"$kv1_after" -ne "$ps1_after" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv1_after, ProxyState=$ps1_after" + test_passed=false + fi + if [[ "$kv2_after" -ne "$ps2_after" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv2_after, ProxyState=$ps2_after" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Network partition recovery test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 11: Three-node cluster (KvStore + ProxyState consistency) +# ============================================================================= +test_three_node_cluster() { + log_info "========== Test 11: Three-node Cluster ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + generate_config 3 + + start_node 1 + start_node 2 + start_node 3 + + # Register peers so all nodes can discover each other + setup_peers 1 2 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form + sleep 10 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register client on node 1 + log_info "Registering client on node 1..." 
+ local response=$(debug_register_cvm $debug_port1 "threenode12345678901234567890123456789=" "threenode_app" "threenode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes (need at least 2 sync intervals of 5s) + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1: KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node cluster test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "KvStore from node 3: $(debug_get_sync_data $debug_port3)" + log_info "ProxyState from 
node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + log_info "ProxyState from node 3: $(debug_get_proxy_state $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 12: WAL file integrity +# ============================================================================= +test_wal_integrity() { + log_info "========== Test 12: WAL File Integrity ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local success_count=0 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients via debug port + for i in $(seq 1 5); do + local key=$(printf "waltest%02d1234567890123456789012345678901=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "wal_app$i" "wal_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/5 clients" + + if [[ "$success_count" -ne 5 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + sleep 2 + stop_node 1 + + # Check WAL file exists and has content + local wal_file="$RUN_DIR/wavekv_node1/node_1.wal" + if [[ -f "$wal_file" ]]; then + local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null) + log_info "WAL file size: $wal_size bytes" + + if [[ "$wal_size" -gt 100 ]]; then + log_info "WAL file integrity test PASSED" + return 0 + else + log_error "WAL file integrity test FAILED: WAL file too small ($wal_size bytes)" + return 1 + fi + else + log_error "WAL file not found: $wal_file" + return 1 + fi +} + +# ============================================================================= +# Test 13: Three-node cluster with bootnode (no dynamic peer setup) +# 
============================================================================= +test_three_node_bootnode() { + log_info "========== Test 13: Three-node Cluster with Bootnode ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + # Node 1 is the bootnode (no bootnode config) + # Node 2 and 3 use node 1 as bootnode + local bootnode_url="https://localhost:13012" + + generate_config 1 "" + generate_config 2 "$bootnode_url" + generate_config 3 "$bootnode_url" + + # Start node 1 first (bootnode) + start_node 1 + sleep 2 + + # Start node 2 and 3, they will discover each other via bootnode + start_node 2 + start_node 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form via bootnode discovery + log_info "Waiting for nodes to discover each other via bootnode..." + sleep 15 + + # Verify debug service is available on all nodes + for port in $debug_port1 $debug_port2 $debug_port3; do + if ! check_debug_service $port; then + log_error "Debug service not available on port $port" + return 1 + fi + done + + # Check peer discovery - each node should know about the others + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local peer_addrs3=$(get_n_peer_addrs $debug_port3) + + log_info "Peer addresses: node1=$peer_addrs1, node2=$peer_addrs2, node3=$peer_addrs3" + + # Register client on node 2 (not the bootnode) + log_info "Registering client on node 2..." 
+ local response=$(debug_register_cvm $debug_port2 "bootnode12345678901234567890123456789=" "bootnode_app" "bootnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1 (bootnode): KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify peer discovery worked (each node should have at least 2 peer addresses) + if [[ "$peer_addrs1" -lt 2 ]] || [[ "$peer_addrs2" -lt 2 ]] || [[ "$peer_addrs3" -lt 2 ]]; then + log_error "Peer discovery via bootnode failed: peer_addrs1=$peer_addrs1, peer_addrs2=$peer_addrs2, peer_addrs3=$peer_addrs3" + test_passed=false + fi + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node bootnode 
cluster test PASSED" + return 0 + else + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + log_info "Sync data from node 3: $(debug_get_sync_data $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 14: Node ID reuse rejection +# ============================================================================= +test_node_id_reuse_rejected() { + log_info "========== Test 14: Node ID Reuse Rejected ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + # Start node 1 and node 2, let them sync + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + local admin_port1=13016 + + # Wait for initial sync + log_info "Waiting for initial sync between node 1 and node 2..." + sleep 10 + + # Verify both nodes have synced + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + return 1 + fi + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + return 1 + fi + log_info "Initial sync completed successfully" + + # Get initial key count on node 1 + local keys_before=$(get_n_keys $admin_port1) + log_info "Keys on node 1 before node 2 restart: $keys_before" + + # Stop node 2 and delete its data (simulating a fresh node trying to reuse the ID) + log_info "Stopping node 2 and deleting its data..." + stop_node 2 + rm -rf "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node2.json" + + # Restart node 2 - it will have a new UUID but same node_id + log_info "Restarting node 2 with fresh data (new UUID, same node_id)..." 
+ start_node 2 + + # Re-register peers + setup_peers 1 2 + + # Wait for sync attempt + sleep 15 + + # Check node 2's log for UUID mismatch error + local log_file="${LOG_DIR}/${CURRENT_TEST}_node2.log" + if grep -q "UUID mismatch" "$log_file" 2>/dev/null; then + log_info "Found UUID mismatch error in node 2 log (expected)" + else + log_warn "UUID mismatch error not found in log (may still be rejected)" + fi + + # Node 1 should have rejected sync from new node 2 + # Check if node 1's data is still intact (keys should not decrease) + local keys_after=$(get_n_keys $admin_port1) + log_info "Keys on node 1 after node 2 restart: $keys_after" + + # The new node 2 should NOT have received data from node 1 + # because node 1 should reject sync due to UUID mismatch + local kv2=$(get_n_instances $debug_port2) + log_info "Node 2 instances after restart: $kv2" + + # Verify node 1's data is intact + if [[ "$keys_after" -lt "$keys_before" ]]; then + log_error "Node 1 lost data after node 2 restart with reused ID" + return 1 + fi + + # The test passes if: + # 1. Node 1's data is intact + # 2. Either UUID mismatch was logged OR node 2 didn't get full sync + log_info "Node ID reuse rejection test PASSED" + return 0 +} + +# ============================================================================= +# Test 15: Periodic persistence +# ============================================================================= +test_periodic_persistence() { + log_info "========== Test 15: Periodic Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients to create data + log_info "Registering clients to create data..." 
+ local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "persist%02d123456789012345678901234567890=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "persist_app$i" "persist_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients" + + if [[ "$success_count" -ne 3 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before waiting for persist: $keys_before" + + # Wait for periodic persistence (persist_interval is 5s in test config) + log_info "Waiting for periodic persistence (8s)..." + sleep 8 + + # Check log for periodic persist message + local log_file="${LOG_DIR}/${CURRENT_TEST}_node1.log" + if grep -q "periodic persist completed" "$log_file" 2>/dev/null; then + log_info "Found periodic persist message in log" + else + log_error "Periodic persist message not found in log - test FAILED" + return 1 + fi + + # Stop node + stop_node 1 + + # Check WAL file exists and has content + local wal_file="$RUN_DIR/wavekv_node1/node_1.wal" + if [[ ! -f "$wal_file" ]]; then + log_error "WAL file not found: $wal_file" + return 1 + fi + + local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null) + log_info "WAL file size after periodic persist: $wal_size bytes" + + # Restart node and verify data is recovered + log_info "Restarting node to verify persistence..." 
+    start_node 1
+
+    local keys_after=$(get_n_keys $admin_port)
+    log_info "Keys after restart: $keys_after"
+
+    if [[ "$keys_after" -ge "$keys_before" ]]; then
+        log_info "Periodic persistence test PASSED"
+        return 0
+    else
+        log_error "Periodic persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after"
+        return 1
+    fi
+}
+
+# =============================================================================
+# Admin RPC helper functions
+# =============================================================================
+
+# Call Admin.SetNodeUrl RPC
+# Usage: admin_set_node_url <admin_port> <node_id> <url>
+admin_set_node_url() {
+    local admin_port=$1
+    local node_id=$2
+    local url=$3
+    curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeUrl" \
+        -H "Content-Type: application/json" \
+        -d "{\"id\": $node_id, \"url\": \"$url\"}" 2>/dev/null
+}
+
+# Register peers between nodes via Admin RPC
+# This is needed since we removed peer_node_ids/peer_urls from config
+# Usage: setup_peers <node_id> [<node_id>...]
+# Example: setup_peers 1 2 3  # Sets up peers between nodes 1, 2, and 3
+setup_peers() {
+    local node_ids=("$@")
+
+    for src_node in "${node_ids[@]}"; do
+        local src_admin_port=$((13000 + src_node * 10 + 6))
+
+        for dst_node in "${node_ids[@]}"; do
+            if [[ "$src_node" != "$dst_node" ]]; then
+                local dst_rpc_port=$((13000 + dst_node * 10 + 2))
+                local dst_url="https://localhost:${dst_rpc_port}"
+                admin_set_node_url "$src_admin_port" "$dst_node" "$dst_url"
+            fi
+        done
+    done
+
+    # Wait for peers to be registered
+    sleep 1
+}
+
+# Call Admin.SetNodeStatus RPC
+# Usage: admin_set_node_status <admin_port> <node_id> <status>
+# status: "up" or "down"
+admin_set_node_status() {
+    local admin_port=$1
+    local node_id=$2
+    local status=$3
+    curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeStatus" \
+        -H "Content-Type: application/json" \
+        -d "{\"id\": $node_id, \"status\": \"$status\"}" 2>/dev/null
+}
+
+# Call Admin.Status RPC to get all nodes
+# Usage: admin_get_status <admin_port>
+admin_get_status() {
+    local admin_port=$1
+
curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.Status" \
+        -H "Content-Type: application/json" \
+        -d '{}' 2>/dev/null
+}
+
+# Get peer URL from sync data
+# Usage: get_peer_url_from_sync <debug_port> <node_id>
+get_peer_url_from_sync() {
+    local debug_port=$1
+    local node_id=$2
+    local response=$(debug_get_sync_data "$debug_port")
+    echo "$response" | python3 -c "
+import sys, json
+try:
+    d = json.load(sys.stdin)
+    for pa in d.get('peer_addrs', []):
+        if pa.get('node_id') == $node_id:
+            print(pa.get('url', ''))
+            sys.exit(0)
+    print('')
+except:
+    print('')
+" 2>/dev/null
+}
+
+# =============================================================================
+# Test 16: Admin.SetNodeUrl RPC
+# =============================================================================
+test_admin_set_node_url() {
+    log_info "========== Test 16: Admin.SetNodeUrl RPC =========="
+    cleanup
+
+    rm -rf "$RUN_DIR/wavekv_node1"
+
+    generate_config 1
+    start_node 1
+
+    local admin_port=13016
+    local debug_port=13015
+
+    # Verify debug service is available
+    if ! check_debug_service $debug_port; then
+        log_error "Debug service not available"
+        return 1
+    fi
+
+    # Set URL for a new node (node 2) via Admin RPC
+    local new_url="https://new-node2.example.com:8011"
+    log_info "Setting node 2 URL via Admin.SetNodeUrl..."
+ local response=$(admin_set_node_url $admin_port 2 "$new_url") + log_info "SetNodeUrl response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeUrl returned error: $response" + return 1 + fi + + # Wait for data to be written + sleep 2 + + # Verify the URL was stored in KvStore + local stored_url=$(get_peer_url_from_sync $debug_port 2) + log_info "Stored URL for node 2: $stored_url" + + if [[ "$stored_url" == "$new_url" ]]; then + log_info "Admin.SetNodeUrl test PASSED" + return 0 + else + log_error "Admin.SetNodeUrl test FAILED: expected '$new_url', got '$stored_url'" + log_info "Sync data: $(debug_get_sync_data $debug_port)" + return 1 + fi +} + +# ============================================================================= +# Test 17: Admin.SetNodeStatus RPC +# ============================================================================= +test_admin_set_node_status() { + log_info "========== Test 17: Admin.SetNodeStatus RPC ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # First set a URL for node 2 so we have a peer + admin_set_node_url $admin_port 2 "https://node2.example.com:8011" + sleep 1 + + # Set node 2 status to "down" + log_info "Setting node 2 status to 'down' via Admin.SetNodeStatus..." + local response=$(admin_set_node_status $admin_port 2 "down") + log_info "SetNodeStatus response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + sleep 1 + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up' via Admin.SetNodeStatus..." 
+ response=$(admin_set_node_status $admin_port 2 "up") + log_info "SetNodeStatus response: $response" + + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + # Test invalid status + log_info "Testing invalid status..." + response=$(admin_set_node_status $admin_port 2 "invalid") + if echo "$response" | grep -q '"error"'; then + log_info "Invalid status correctly rejected" + else + log_warn "Invalid status was not rejected (may be acceptable)" + fi + + log_info "Admin.SetNodeStatus test PASSED" + return 0 +} + +# ============================================================================= +# Test 18: Node down excluded from RegisterCvm response +# ============================================================================= +test_node_status_register_exclude() { + log_info "========== Test 18: Node Down Excluded from Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local admin_port1=13016 + local admin_port2=13026 + local debug_port1=13015 + + # Wait for sync + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Set node 2 status to "down" via node 1's admin API + log_info "Setting node 2 status to 'down'..." + admin_set_node_status $admin_port1 2 "down" + sleep 2 + + # Register a client on node 1 + log_info "Registering client on node 1 (node 2 is down)..." 
+ local response=$(debug_register_cvm $debug_port1 "downtest12345678901234567890123456789012=" "downtest_app" "downtest_inst") + log_info "Register response: $response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Check gateways list in response - should NOT include node 2 + local has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "yes" ]]; then + log_error "Node 2 (down) was included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (down) correctly excluded from registration response" + fi + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up'..." + admin_set_node_status $admin_port1 2 "up" + sleep 2 + + # Register another client + log_info "Registering client on node 1 (node 2 is now up)..." 
+ response=$(debug_register_cvm $debug_port1 "uptest123456789012345678901234567890123=" "uptest_app" "uptest_inst2") + + # Check gateways list - should now include node 2 + has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "no" ]]; then + log_error "Node 2 (up) was NOT included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (up) correctly included in registration response" + fi + + log_info "Node down excluded from registration test PASSED" + return 0 +} + +# ============================================================================= +# Test 19: Node down rejects RegisterCvm requests +# ============================================================================= +test_node_status_register_reject() { + log_info "========== Test 19: Node Down Rejects Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client when node is up (should succeed) + log_info "Registering client when node 1 is up..." + local response=$(debug_register_cvm $debug_port "upnode123456789012345678901234567890123=" "upnode_app" "upnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed when node was up" + return 1 + fi + log_info "Registration succeeded when node was up (IP: $client_ip)" + + # Set node 1 status to "down" (marking itself as down) + log_info "Setting node 1 status to 'down'..." 
+ admin_set_node_status $admin_port 1 "down" + sleep 2 + + # Try to register a client when node is down (should fail) + log_info "Attempting to register client when node 1 is down..." + response=$(debug_register_cvm $debug_port "downnode12345678901234567890123456789012=" "downnode_app" "downnode_inst") + log_info "Register response: $response" + + # Check if response contains error about node being down + if echo "$response" | grep -qi "error"; then + log_info "Registration correctly rejected when node is down" + if echo "$response" | grep -qi "marked as down"; then + log_info "Error message mentions 'marked as down' (correct)" + fi + else + log_error "Registration was NOT rejected when node is down" + log_info "Response: $response" + return 1 + fi + + # Set node 1 status back to "up" + log_info "Setting node 1 status to 'up'..." + admin_set_node_status $admin_port 1 "up" + sleep 2 + + # Register a client again (should succeed) + log_info "Registering client when node 1 is back up..." + response=$(debug_register_cvm $debug_port "backup123456789012345678901234567890123=" "backup_app" "backup_inst") + client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed when node was back up" + return 1 + fi + log_info "Registration succeeded when node was back up (IP: $client_ip)" + + log_info "Node down rejects registration test PASSED" + return 0 +} + +# ============================================================================= +# Clean command - remove all generated files +# ============================================================================= +clean() { + log_info "Cleaning up generated files..." 
+ + # Kill only test gateway processes (matching our specific config path) + pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + sleep 1 + + # Remove WireGuard interfaces (only our test interfaces need sudo) + sudo ip link delete wavekv-test1 2>/dev/null || true + sudo ip link delete wavekv-test2 2>/dev/null || true + sudo ip link delete wavekv-test3 2>/dev/null || true + + # Remove run directory (contains all generated files including certs) + rm -rf "$RUN_DIR" + + log_info "Cleanup complete" +} + +# ============================================================================= +# Ensure all certificates exist (CA + RPC + proxy) +# ============================================================================= +ensure_certs() { + # Create directories + mkdir -p "$CERTS_DIR" + mkdir -p "$RUN_DIR/certbot/live" + + # Generate CA certificate if not exists + if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]] || [[ ! -f "$CERTS_DIR/gateway-ca.cert" ]]; then + log_info "Creating CA certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null + openssl req -x509 -new -nodes \ + -key "$CERTS_DIR/gateway-ca.key" \ + -sha256 -days 365 \ + -out "$CERTS_DIR/gateway-ca.cert" \ + -subj "/CN=Test CA/O=WaveKV Test" \ + 2>/dev/null + fi + + # Generate RPC certificate signed by CA if not exists + if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]] || [[ ! -f "$CERTS_DIR/gateway-rpc.cert" ]]; then + log_info "Creating RPC certificate signed by CA..." 
+ openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null + openssl req -new \ + -key "$CERTS_DIR/gateway-rpc.key" \ + -out "$CERTS_DIR/gateway-rpc.csr" \ + -subj "/CN=localhost" \ + 2>/dev/null + # Create certificate with SAN for localhost + cat >"$CERTS_DIR/ext.cnf" </dev/null + rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf" + fi + + # Generate proxy certificates (for TLS termination) + local proxy_cert_dir="$RUN_DIR/certbot/live" + if [[ ! -f "$proxy_cert_dir/cert.pem" ]] || [[ ! -f "$proxy_cert_dir/key.pem" ]]; then + log_info "Creating proxy certificates..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$proxy_cert_dir/key.pem" \ + -out "$proxy_cert_dir/cert.pem" \ + -days 365 \ + -subj "/CN=localhost" \ + 2>/dev/null + fi +} + +# ============================================================================= +# Main +# ============================================================================= +main() { + # Handle clean command + if [[ "${1:-}" == "clean" ]]; then + clean + exit 0 + fi + + # Handle cfg command - generate node configuration + if [[ "${1:-}" == "cfg" ]]; then + local node_id="${2:-}" + if [[ -z "$node_id" ]]; then + log_error "Usage: $0 cfg " + log_info "Example: $0 cfg 1" + exit 1 + fi + + # Ensure certificates exist + ensure_certs + + # Generate config for the specified node + generate_config "$node_id" + log_info "Configuration generated: $RUN_DIR/node${node_id}.toml" + exit 0 + fi + + # Handle ls command - list all test cases + if [[ "${1:-}" == "ls" ]]; then + echo "Available test cases:" + echo "" + echo "Quick tests:" + echo " test_persistence - Single node persistence" + echo " test_status_endpoint - Status endpoint structure" + echo " test_prpc_register - prpc DebugRegisterCvm endpoint" + echo " test_prpc_info - prpc Info endpoint" + echo " test_wal_integrity - WAL file integrity" + echo "" + echo "Sync tests:" + echo " test_multi_node_sync - Multi-node sync" + echo " test_node_recovery - Node recovery 
after disconnect" + echo " test_cross_node_data_sync - Cross-node data sync verification" + echo "" + echo "Advanced tests:" + echo " test_client_registration_persistence - Client registration and persistence" + echo " test_stress_writes - Stress test - multiple writes" + echo " test_network_partition - Network partition simulation" + echo " test_three_node_cluster - Three-node cluster" + echo " test_three_node_bootnode - Three-node cluster with bootnode" + echo " test_node_id_reuse_rejected - Node ID reuse rejection" + echo " test_periodic_persistence - Periodic persistence" + echo "" + echo "Admin RPC tests:" + echo " test_admin_set_node_url - Admin.SetNodeUrl RPC" + echo " test_admin_set_node_status - Admin.SetNodeStatus RPC" + echo " test_node_status_register_exclude - Node down excluded from registration" + echo " test_node_status_register_reject - Node down rejects registration" + echo "" + echo "Usage:" + echo " $0 - Run all tests" + echo " $0 quick - Run quick tests only" + echo " $0 sync - Run sync tests only" + echo " $0 advanced - Run advanced tests only" + echo " $0 admin - Run admin RPC tests only" + echo " $0 case - Run specific test case" + echo " $0 ls - List all test cases" + echo " $0 clean - Clean up generated files" + exit 0 + fi + + # Handle case command - run specific test case + if [[ "${1:-}" == "case" ]]; then + local test_case="${2:-}" + if [[ -z "$test_case" ]]; then + log_error "Usage: $0 case " + log_info "Run '$0 ls' to see all available test cases" + exit 1 + fi + + # Check if gateway binary exists + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_error "Gateway binary not found: $GATEWAY_BIN" + log_info "Please run: cargo build --release" + exit 1 + fi + + # Ensure certificates exist + ensure_certs + + # Check if test function exists + if ! 
declare -f "$test_case" >/dev/null; then + log_error "Test case not found: $test_case" + log_info "Use '$0 case' to see available test cases" + exit 1 + fi + + # Run the specific test + log_info "Running test case: $test_case" + CURRENT_TEST="$test_case" + if $test_case; then + log_info "Test PASSED: $test_case" + cleanup + exit 0 + else + log_error "Test FAILED: $test_case" + cleanup + exit 1 + fi + fi + + log_info "Starting WaveKV integration tests..." + + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_error "Gateway binary not found: $GATEWAY_BIN" + log_info "Please run: cargo build --release" + exit 1 + fi + + # Ensure all certificates exist (RPC + proxy) + ensure_certs + + local failed=0 + local passed=0 + local failed_tests=() + + run_test() { + local test_name=$1 + CURRENT_TEST="$test_name" + if $test_name; then + ((passed++)) + else + ((failed++)) + failed_tests+=("$test_name") + fi + cleanup + } + + # Run selected test or all tests + local test_filter="${1:-all}" + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "quick" ]]; then + run_test test_persistence + run_test test_status_endpoint + run_test test_prpc_register + run_test test_prpc_info + run_test test_wal_integrity + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "sync" ]]; then + run_test test_multi_node_sync + run_test test_node_recovery + run_test test_cross_node_data_sync + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "advanced" ]]; then + run_test test_client_registration_persistence + run_test test_stress_writes + run_test test_network_partition + run_test test_three_node_cluster + run_test test_three_node_bootnode + run_test test_node_id_reuse_rejected + run_test test_periodic_persistence + fi + + if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "admin" ]]; then + run_test test_admin_set_node_url + run_test test_admin_set_node_status + run_test test_node_status_register_exclude + run_test test_node_status_register_reject + fi + + echo "" + 
log_info "==========================================" + log_info "Tests passed: $passed" + if [[ $failed -gt 0 ]]; then + log_error "Tests failed: $failed" + echo "" + log_error "Failed test cases:" + for test_name in "${failed_tests[@]}"; do + log_error " - $test_name" + done + echo "" + log_info "To rerun a failed test:" + log_info " $0 case " + log_info "Example:" + if [[ ${#failed_tests[@]} -gt 0 ]]; then + log_info " $0 case ${failed_tests[0]}" + fi + fi + log_info "==========================================" + + return $failed +} + +# Run if executed directly +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + main "$@" +fi diff --git a/guest-agent/src/main.rs b/guest-agent/src/main.rs index 3183e61f..de4fb3c5 100644 --- a/guest-agent/src/main.rs +++ b/guest-agent/src/main.rs @@ -205,7 +205,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); let figment = config::load_config_figment(args.config.as_deref()); diff --git a/kms/src/main.rs b/kms/src/main.rs index 8584eec9..eddfbdc9 100644 --- a/kms/src/main.rs +++ b/kms/src/main.rs @@ -82,7 +82,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse(); diff --git a/sdk/rust/tests/test_tappd_client.rs b/sdk/rust/tests/test_tappd_client.rs index 0d363b94..b5afb571 100644 --- a/sdk/rust/tests/test_tappd_client.rs +++ b/sdk/rust/tests/test_tappd_client.rs @@ -11,27 +11,18 @@ use std::env; async fn test_tappd_client_creation() { // Test client creation with default endpoint let _client = TappdClient::new(None); - - // This should succeed without 
panicking - assert!(true); } #[tokio::test] async fn test_tappd_client_with_custom_endpoint() { // Test client creation with custom endpoint let _client = TappdClient::new(Some("/custom/path/tappd.sock")); - - // This should succeed without panicking - assert!(true); } #[tokio::test] async fn test_tappd_client_with_http_endpoint() { // Test client creation with HTTP endpoint let _client = TappdClient::new(Some("http://localhost:8080")); - - // This should succeed without panicking - assert!(true); } // Integration tests that require a running tappd service diff --git a/supervisor/client/src/main.rs b/supervisor/client/src/main.rs index b993984c..c3b13abd 100644 --- a/supervisor/client/src/main.rs +++ b/supervisor/client/src/main.rs @@ -50,7 +50,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + fmt().with_env_filter(filter).with_ansi(false).init(); } let cli = Cli::parse(); diff --git a/supervisor/src/main.rs b/supervisor/src/main.rs index 752b511c..291ef320 100644 --- a/supervisor/src/main.rs +++ b/supervisor/src/main.rs @@ -90,10 +90,12 @@ fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .with_writer(file) + .with_ansi(false) .init(); } else { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) + .with_ansi(false) .init(); } #[cfg(unix)] diff --git a/tdx-attest/src/dummy.rs b/tdx-attest/src/dummy.rs index 65314a4f..77991e37 100644 --- a/tdx-attest/src/dummy.rs +++ b/tdx-attest/src/dummy.rs @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 -use cc_eventlog::TdxEventLog; use num_enum::FromPrimitive; use thiserror::Error; @@ -48,10 +47,7 @@ pub fn extend_rtmr(_index: u32, _event_type: u32, _digest: [u8; 48]) -> Result<( pub fn get_report(_report_data: &TdxReportData) -> Result { Err(TdxAttestError::NotSupported) } -pub fn 
get_quote(
-    _report_data: &TdxReportData,
-    _att_key_id_list: Option<&[TdxUuid]>,
-) -> Result<(TdxUuid, Vec<u8>)> {
+pub fn get_quote(_report_data: &TdxReportData) -> Result<Vec<u8>> {
     let _ = _report_data;
     Err(TdxAttestError::NotSupported)
 }
diff --git a/tools/mock-cf-dns-api/Dockerfile b/tools/mock-cf-dns-api/Dockerfile
new file mode 100644
index 00000000..081b9625
--- /dev/null
+++ b/tools/mock-cf-dns-api/Dockerfile
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: 2024-2025 Phala Network
+#
+# SPDX-License-Identifier: Apache-2.0
+
+FROM python:3.12-slim
+
+WORKDIR /app
+
+# Install dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application
+COPY app.py .
+
+# Environment variables
+ENV PORT=8080
+ENV DEBUG=false
+
+# Expose port
+EXPOSE 8080
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD curl -f http://localhost:8080/health || exit 1
+
+# Run with gunicorn for production
+CMD ["gunicorn", "--bind", "0.0.0.0:8080", "--workers", "2", "--threads", "4", "app:app"]
diff --git a/tools/mock-cf-dns-api/app.py b/tools/mock-cf-dns-api/app.py
new file mode 100644
index 00000000..5703db79
--- /dev/null
+++ b/tools/mock-cf-dns-api/app.py
@@ -0,0 +1,848 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: 2024-2025 Phala Network
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Mock Cloudflare DNS API Server
+
+A mock server that simulates Cloudflare's DNS API for testing purposes.
+Supports the following endpoints used by certbot: +- POST /client/v4/zones/{zone_id}/dns_records - Create DNS record +- GET /client/v4/zones/{zone_id}/dns_records - List DNS records +- DELETE /client/v4/zones/{zone_id}/dns_records/{record_id} - Delete DNS record +""" + +import os +import uuid +import time +import json +from datetime import datetime +from flask import Flask, request, jsonify, render_template_string +from functools import wraps + +app = Flask(__name__) + +# In-memory storage for DNS records +# Structure: {zone_id: {record_id: record_data}} +dns_records = {} + +# Request/Response logs for debugging +request_logs = [] +MAX_LOGS = 100 + +# Valid API tokens (for testing, accept any non-empty token or use env var) +VALID_TOKENS = os.environ.get("CF_API_TOKENS", "").split(",") if os.environ.get("CF_API_TOKENS") else None + + +def log_request(zone_id, method, path, req_data, resp_data, status_code): + """Log API requests for the management UI.""" + log_entry = { + "timestamp": datetime.now().isoformat(), + "zone_id": zone_id, + "method": method, + "path": path, + "request": req_data, + "response": resp_data, + "status_code": status_code, + } + request_logs.insert(0, log_entry) + if len(request_logs) > MAX_LOGS: + request_logs.pop() + + +def generate_record_id(): + """Generate a Cloudflare-style record ID.""" + return uuid.uuid4().hex[:32] + + +def get_current_time(): + """Get current time in Cloudflare format.""" + return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000000Z") + + +def verify_auth(f): + """Decorator to verify Bearer token authentication.""" + @wraps(f) + def decorated(*args, **kwargs): + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return jsonify({ + "success": False, + "errors": [{"code": 10000, "message": "Authentication error"}], + "messages": [], + "result": None + }), 401 + + token = auth_header[7:] # Remove "Bearer " prefix + + # If VALID_TOKENS is set, validate against it; 
otherwise accept any token
+        if VALID_TOKENS and token not in VALID_TOKENS:
+            return jsonify({
+                "success": False,
+                "errors": [{"code": 10000, "message": "Invalid API token"}],
+                "messages": [],
+                "result": None
+            }), 403
+
+        return f(*args, **kwargs)
+    return decorated
+
+
+def cf_response(result, success=True, errors=None, messages=None):
+    """Create a Cloudflare-style API response."""
+    return {
+        "success": success,
+        "errors": errors or [],
+        "messages": messages or [],
+        "result": result
+    }
+
+
+def cf_error(message, code=1000):
+    """Create a Cloudflare-style error response."""
+    return cf_response(None, success=False, errors=[{"code": code, "message": message}])
+
+
+# ==================== DNS Record Endpoints ====================
+
+@app.route("/client/v4/zones/<zone_id>/dns_records", methods=["POST"])
+@verify_auth
+def create_dns_record(zone_id):
+    """Create a new DNS record."""
+    data = request.get_json()
+
+    if not data:
+        resp = cf_error("Invalid request body")
+        log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", None, resp, 400)
+        return jsonify(resp), 400
+
+    record_type = data.get("type")
+    name = data.get("name")
+
+    if not record_type or not name:
+        resp = cf_error("Missing required fields: type, name")
+        log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 400)
+        return jsonify(resp), 400
+
+    # Initialize zone if not exists
+    if zone_id not in dns_records:
+        dns_records[zone_id] = {}
+
+    record_id = generate_record_id()
+    now = get_current_time()
+
+    # Build record based on type
+    record = {
+        "id": record_id,
+        "zone_id": zone_id,
+        "zone_name": f"zone-{zone_id[:8]}.example.com",
+        "name": name,
+        "type": record_type,
+        "ttl": data.get("ttl", 1),
+        "proxied": data.get("proxied", False),
+        "proxiable": False,
+        "locked": False,
+        "created_on": now,
+        "modified_on": now,
+        "meta": {
+            "auto_added": False,
+            "managed_by_apps": False,
+            "managed_by_argo_tunnel": False
+        }
+    }
+
+    # Handle different record types
+    if
record_type == "TXT": + record["content"] = data.get("content", "") + elif record_type == "CAA": + caa_data = data.get("data", {}) + record["data"] = caa_data + # Format content as Cloudflare does + flags = caa_data.get("flags", 0) + tag = caa_data.get("tag", "") + value = caa_data.get("value", "") + record["content"] = f'{flags} {tag} "{value}"' + elif record_type == "A": + record["content"] = data.get("content", "") + elif record_type == "AAAA": + record["content"] = data.get("content", "") + elif record_type == "CNAME": + record["content"] = data.get("content", "") + else: + record["content"] = data.get("content", "") + + dns_records[zone_id][record_id] = record + + resp = cf_response(record) + log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 200) + + print(f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}") + + return jsonify(resp), 200 + + +@app.route("/client/v4/zones//dns_records", methods=["GET"]) +@verify_auth +def list_dns_records(zone_id): + """List DNS records for a zone.""" + zone_records = dns_records.get(zone_id, {}) + records_list = list(zone_records.values()) + + # Filter by type if specified + record_type = request.args.get("type") + if record_type: + records_list = [r for r in records_list if r["type"] == record_type] + + # Filter by name if specified + name = request.args.get("name") + if name: + records_list = [r for r in records_list if r["name"] == name] + + # Get pagination params + page = int(request.args.get("page", 1)) + per_page = int(request.args.get("per_page", 100)) + + # Pagination + total_count = len(records_list) + total_pages = max(1, (total_count + per_page - 1) // per_page) + start_idx = (page - 1) * per_page + end_idx = start_idx + per_page + page_records = records_list[start_idx:end_idx] + + resp = { + "success": True, + "errors": [], + "messages": [], + "result": page_records, + "result_info": { + "page": page, + "per_page": per_page, + "count": len(page_records), + 
"total_count": total_count,
+            "total_pages": total_pages
+        }
+    }
+    log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200)
+
+    return jsonify(resp), 200
+
+
+@app.route("/client/v4/zones/<zone_id>/dns_records/<record_id>", methods=["GET"])
+@verify_auth
+def get_dns_record(zone_id, record_id):
+    """Get a specific DNS record."""
+    zone_records = dns_records.get(zone_id, {})
+    record = zone_records.get(record_id)
+
+    if not record:
+        resp = cf_error("Record not found", 81044)
+        log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404)
+        return jsonify(resp), 404
+
+    resp = cf_response(record)
+    log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200)
+
+    return jsonify(resp), 200
+
+
+@app.route("/client/v4/zones/<zone_id>/dns_records/<record_id>", methods=["PUT"])
+@verify_auth
+def update_dns_record(zone_id, record_id):
+    """Update a DNS record."""
+    zone_records = dns_records.get(zone_id, {})
+    record = zone_records.get(record_id)
+
+    if not record:
+        resp = cf_error("Record not found", 81044)
+        log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404)
+        return jsonify(resp), 404
+
+    data = request.get_json()
+    if not data:
+        resp = cf_error("Invalid request body")
+        log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400)
+        return jsonify(resp), 400
+
+    # Update allowed fields
+    for field in ["name", "type", "content", "ttl", "proxied", "data"]:
+        if field in data:
+            record[field] = data[field]
+
+    record["modified_on"] = get_current_time()
+
+    resp = cf_response(record)
+    log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200)
+
+    print(f"[UPDATE] Zone: {zone_id}, Record: {record_id}")
+
+    return jsonify(resp), 200
+
+
+@app.route("/client/v4/zones/<zone_id>/dns_records/<record_id>", methods=["DELETE"])
+@verify_auth
+def delete_dns_record(zone_id, record_id):
+    """Delete a DNS record."""
+    zone_records =
dns_records.get(zone_id, {}) + + if record_id not in zone_records: + resp = cf_error("Record not found", 81044) + log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + return jsonify(resp), 404 + + del zone_records[record_id] + + resp = cf_response({"id": record_id}) + log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + + print(f"[DELETE] Zone: {zone_id}, Record: {record_id}") + + return jsonify(resp), 200 + + +# ==================== Zone Endpoints ==================== + +# Pre-configured zones for testing +# Can be configured via MOCK_ZONES environment variable (JSON format) +# Example: MOCK_ZONES='[{"id":"zone123","name":"example.com"},{"id":"zone456","name":"test.local"}]' +DEFAULT_ZONES = [ + {"id": "mock-zone-test-local", "name": "test.local"}, + {"id": "mock-zone-example-com", "name": "example.com"}, + {"id": "mock-zone-test0-local", "name": "test0.local"}, + {"id": "mock-zone-test1-local", "name": "test1.local"}, + {"id": "mock-zone-test2-local", "name": "test2.local"}, +] + + +def get_configured_zones(): + """Get zones from environment or use defaults.""" + zones_json = os.environ.get("MOCK_ZONES") + if zones_json: + try: + return json.loads(zones_json) + except json.JSONDecodeError: + print(f"Warning: Invalid MOCK_ZONES JSON, using defaults") + return DEFAULT_ZONES + + +@app.route("/client/v4/zones", methods=["GET"]) +@verify_auth +def list_zones(): + """List all zones (paginated).""" + page = int(request.args.get("page", 1)) + per_page = int(request.args.get("per_page", 50)) + name_filter = request.args.get("name") + + zones = get_configured_zones() + + # Filter by name if specified + if name_filter: + zones = [z for z in zones if z["name"] == name_filter] + + # Build full zone objects + full_zones = [] + for z in zones: + full_zones.append({ + "id": z["id"], + "name": z["name"], + "status": "active", + "paused": False, + "type": "full", + "development_mode": 0, + 
"name_servers": [
+                "ns1.mock-cloudflare.com",
+                "ns2.mock-cloudflare.com"
+            ],
+            "created_on": "2024-01-01T00:00:00.000000Z",
+            "modified_on": get_current_time(),
+        })
+
+    # Pagination
+    total_count = len(full_zones)
+    total_pages = max(1, (total_count + per_page - 1) // per_page)
+    start_idx = (page - 1) * per_page
+    end_idx = start_idx + per_page
+    page_zones = full_zones[start_idx:end_idx]
+
+    result = {
+        "success": True,
+        "errors": [],
+        "messages": [],
+        "result": page_zones,
+        "result_info": {
+            "page": page,
+            "per_page": per_page,
+            "count": len(page_zones),
+            "total_count": total_count,
+            "total_pages": total_pages
+        }
+    }
+
+    log_request("*", "GET", "/zones", dict(request.args), result, 200)
+    print(f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}")
+
+    return jsonify(result), 200
+
+
+@app.route("/client/v4/zones/<zone_id>", methods=["GET"])
+@verify_auth
+def get_zone(zone_id):
+    """Get zone details (mock)."""
+    # Try to find zone in configured zones
+    zones = get_configured_zones()
+    zone_info = next((z for z in zones if z["id"] == zone_id), None)
+
+    if zone_info:
+        zone_name = zone_info["name"]
+    else:
+        # Fallback for unknown zone IDs
+        zone_name = f"zone-{zone_id[:8]}.example.com"
+
+    zone = {
+        "id": zone_id,
+        "name": zone_name,
+        "status": "active",
+        "paused": False,
+        "type": "full",
+        "development_mode": 0,
+        "name_servers": [
+            "ns1.mock-cloudflare.com",
+            "ns2.mock-cloudflare.com"
+        ],
+        "created_on": "2024-01-01T00:00:00.000000Z",
+        "modified_on": get_current_time(),
+    }
+
+    resp = cf_response(zone)
+    log_request(zone_id, "GET", f"/zones/{zone_id}", None, resp, 200)
+
+    return jsonify(resp), 200
+
+
+# ==================== Management UI ====================
+
+MANAGEMENT_HTML = """
+
+
+
+
+ Mock Cloudflare DNS API - Management
+
+
+
+
+

Mock Cloudflare DNS API

+

Testing server for ACME DNS-01 challenges

+
+ +
+
+

{{ zone_count }}

+

Zones

+
+
+

{{ record_count }}

+

DNS Records

+
+
+

{{ request_count }}

+

API Requests

+
+
+ +
+
+

DNS Records

+ +
+
+ {% if records %} + + + + + + + + + + + + + {% for record in records %} + + + + + + + + + {% endfor %} + +
Zone IDTypeNameContentCreatedActions
{{ record.zone_id[:12] }}...{{ record.type }}{{ record.name }}{{ record.content }}{{ record.created_on[:19] }} + +
+ {% else %} +
+

No DNS records yet. Records created via API will appear here.

+
+ {% endif %} +
+
+ +
+
+

Recent API Requests

+ +
+
+ {% if logs %} + {% for log in logs %} +
+ {{ log.timestamp }} + {{ log.method }} + {{ log.path }} + + ({{ log.status_code }}) + + {% if log.request %} +
+ Request/Response +
Request: {{ log.request | tojson(indent=2) }}
+
Response: {{ log.response | tojson(indent=2) }}
+
+ {% endif %} +
+ {% endfor %} + {% else %} +
+

No API requests yet.

+
+ {% endif %} +
+
+ +
+
+

API Usage

+
+
+

Base URL:

+

+            
+
+
+
+
+
+
+
+
+"""
+
+
+@app.route("/")
+def management_ui():
+    """Render the management UI."""
+    all_records = []
+    for zone_id, records in dns_records.items():
+        all_records.extend(records.values())
+
+    # Sort by created time, newest first
+    all_records.sort(key=lambda r: r.get("created_on", ""), reverse=True)
+
+    return render_template_string(
+        MANAGEMENT_HTML,
+        zone_count=len(dns_records),
+        record_count=sum(len(r) for r in dns_records.values()),
+        request_count=len(request_logs),
+        records=all_records,
+        logs=request_logs[:20],
+        port=os.environ.get("PORT", 8080)
+    )
+
+
+# ==================== Management API ====================
+
+@app.route("/api/records", methods=["DELETE"])
+def clear_all_records():
+    """Clear all DNS records."""
+    dns_records.clear()
+    return jsonify({"success": True})
+
+
+@app.route("/api/records/<zone_id>/<record_id>", methods=["DELETE"])
+def delete_record_ui(zone_id, record_id):
+    """Delete a specific record from UI."""
+    if zone_id in dns_records and record_id in dns_records[zone_id]:
+        del dns_records[zone_id][record_id]
+    return jsonify({"success": True})
+
+
+@app.route("/api/logs", methods=["DELETE"])
+def clear_logs():
+    """Clear request logs."""
+    request_logs.clear()
+    return jsonify({"success": True})
+
+
+@app.route("/api/records", methods=["GET"])
+def get_all_records():
+    """Get all records as JSON."""
+    all_records = []
+    for zone_id, records in dns_records.items():
+        all_records.extend(records.values())
+    return jsonify(all_records)
+
+
+@app.route("/health")
+def health():
+    """Health check endpoint."""
+    return jsonify({"status": "healthy", "records": sum(len(r) for r in dns_records.values())})
+
+
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 8080))
+    debug = os.environ.get("DEBUG", "false").lower() == "true"
+
+    print(f"""
+    ╔═══════════════════════════════════════════════════════════════╗
+    ║              Mock Cloudflare DNS API Server                    ║
+    ╠═══════════════════════════════════════════════════════════════╣
+    ║  Management UI: 
http://localhost:{port}/ ║ + ║ API Base URL: http://localhost:{port}/client/v4 ║ + ╚═══════════════════════════════════════════════════════════════╝ + """) + + app.run(host="0.0.0.0", port=port, debug=debug) diff --git a/tools/mock-cf-dns-api/docker-compose.yml b/tools/mock-cf-dns-api/docker-compose.yml new file mode 100644 index 00000000..951bd68b --- /dev/null +++ b/tools/mock-cf-dns-api/docker-compose.yml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +version: "3.8" + +services: + mock-cf-dns-api: + image: kvin/mock-cf-dns-api:latest + ports: + - "8080:8080" + environment: + - PORT=8080 + - DEBUG=false + # Optional: comma-separated list of valid API tokens + # - CF_API_TOKENS=token1,token2 + restart: unless-stopped diff --git a/tools/mock-cf-dns-api/requirements.txt b/tools/mock-cf-dns-api/requirements.txt new file mode 100644 index 00000000..a36229c7 --- /dev/null +++ b/tools/mock-cf-dns-api/requirements.txt @@ -0,0 +1,2 @@ +flask>=3.0.0 +gunicorn>=21.0.0 diff --git a/verifier/src/main.rs b/verifier/src/main.rs index 1bd72a91..c3e5a4a6 100644 --- a/verifier/src/main.rs +++ b/verifier/src/main.rs @@ -161,7 +161,11 @@ async fn run_oneshot(file_path: &str, config: &Config) -> anyhow::Result<()> { #[rocket::main] async fn main() -> Result<()> { - tracing_subscriber::fmt::try_init().ok(); + { + use tracing_subscriber::{fmt, EnvFilter}; + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + fmt().with_env_filter(filter).with_ansi(false).init(); + } let cli = Cli::parse(); diff --git a/vmm/src/main.rs b/vmm/src/main.rs index bb1f7873..40033ed5 100644 --- a/vmm/src/main.rs +++ b/vmm/src/main.rs @@ -159,7 +159,7 @@ async fn main() -> Result<()> { { use tracing_subscriber::{fmt, EnvFilter}; let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - fmt().with_env_filter(filter).init(); + 
fmt().with_env_filter(filter).with_ansi(false).init(); } let args = Args::parse();