diff --git a/tools/oprt2/pkg/attunehooks/authenticators/authenticator.go b/tools/oprt2/pkg/attunehooks/authenticators/authenticator.go index 7f117461..8f1e84c2 100644 --- a/tools/oprt2/pkg/attunehooks/authenticators/authenticator.go +++ b/tools/oprt2/pkg/attunehooks/authenticators/authenticator.go @@ -1,13 +1,24 @@ package authenticators import ( + "context" "errors" + "log/slog" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/token" "github.com/gravitational/shared-workflows/tools/oprt2/pkg/commandrunner" "github.com/gravitational/shared-workflows/tools/oprt2/pkg/config" ) // FromConfig builds an Attune authenticator hook from the provided config. -func FromConfig(config config.Authenticator) (commandrunner.Hook, error) { - return nil, errors.New("not implemented") +func FromConfig(ctx context.Context, config config.Authenticator, logger *slog.Logger) (commandrunner.Hook, error) { + switch { + case config.MTLS != nil: + return mtls.FromConfig(ctx, config.MTLS, logger) + case config.Token != nil: + return token.FromConfig(config.Token) + default: + return nil, errors.New("no or unknown Attune authenticator specified") + } } diff --git a/tools/oprt2/pkg/attunehooks/authenticators/mtls/authenticator.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/authenticator.go new file mode 100644 index 00000000..ec12dee8 --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/authenticator.go @@ -0,0 +1,163 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mtls + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/url" + "os/exec" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/commandrunner" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/logging" +) + +// Authenticator authenticates with the Attune control plane with mTLS authentication. +// It does this by creating a local TCP proxy that wraps the connection in TLS, forwarding +// it to a reverse proxy in front of Attune. The reverse proxy handles authentication, and +// strips the TLS layer. +type Authenticator struct { + attuneEndpointHost string + attuneEndpointPort string + certprovider certprovider.Provider + logger *slog.Logger + + // State vars + stopProxy func() error + proxyAddress string +} + +var _ commandrunner.Hook = (*Authenticator)(nil) + +// Authenticator creates a new Authenticator. +func NewAuthenticator(ctx context.Context, attuneEndpoint string, certprovider certprovider.Provider, opts ...AuthenticatorOption) (a *Authenticator, err error) { + a = &Authenticator{ + certprovider: certprovider, + logger: logging.DiscardLogger, + } + + if err := a.setHostPort(attuneEndpoint); err != nil { + return nil, err + } + + for _, opt := range opts { + opt(a) + } + + if err := a.setup(ctx); err != nil { + return nil, err + } + + return a, nil +} + +// Name is the name of the authenticator. +func (a *Authenticator) Name() string { + return "mTLS" +} + +func (a *Authenticator) setHostPort(attuneEndpoint string) error { + attuneEndpointURL, err := url.Parse(attuneEndpoint) + if err == nil { + a.attuneEndpointHost = attuneEndpointURL.Hostname() + if a.attuneEndpointHost == "" { + return fmt.Errorf("the Attune endpoint does not contain a hostname: %q", attuneEndpoint) + } + + a.attuneEndpointPort = attuneEndpointURL.Port() + if a.attuneEndpointPort != "" { + return nil + } + + switch attuneEndpointURL.Scheme { + case "https": + a.attuneEndpointPort = "443" + return nil + case "http": + a.attuneEndpointPort = "80" + return nil + } + } + + if host, port, err := net.SplitHostPort(attuneEndpoint); err == nil { + a.attuneEndpointHost = host + a.attuneEndpointPort = port + return nil + } + + return fmt.Errorf("failed to parse Attune endpoint: %q", attuneEndpoint) +} + +func (a *Authenticator) setup(ctx context.Context) error { + // Start a TCP proxy + proxyServeCtx, cancelProxyServe := context.WithCancel(ctx) + t2tp := proxy.NewTCP2TLS(a.attuneEndpointHost, a.attuneEndpointPort, proxy.WithLogger(a.logger), proxy.WithClientCertificateProvider(a.certprovider)) + proxyServeErr := make(chan error) + go func() { + defer close(proxyServeErr) + proxyServeErr <- t2tp.ListenAndServe(proxyServeCtx) + }() + + a.stopProxy = func() error { + cancelProxyServe() + actualProxyServeErr := <-proxyServeErr + if actualProxyServeErr != nil { + actualProxyServeErr = fmt.Errorf("the TLS2TCP proxy failed while serving: %w", actualProxyServeErr) + } + return actualProxyServeErr + } + + proxyAddress, err := t2tp.GetAddress(ctx) + if err != nil { + return fmt.Errorf("failed to get TLS2TCP listening address: %w", err) + } + a.proxyAddress = proxyAddress.String() + + return nil +} + +// Command adds mTLS authentication to the Attune command. +func (a *Authenticator) Command(_ context.Context, cmd *exec.Cmd) error { + cmd.Env = append( + cmd.Env, + // This value is only here because the Attune CLI requires it to be set. It is + // meaningless, and the ingres gateway replaces on backend requests. + "ATTUNE_API_TOKEN=dummy-value", + // This must be HTTP to avoid dealing with trust issues, and because the proxy is + // only aware of TCP, downwards. The proxy always binds to localhost anyway, so + // there isn't a security risk here that we are concerned about. + "ATTUNE_API_ENDPOINT=http://"+a.proxyAddress, + ) + return nil +} + +// Close closes the authenticator. +func (a *Authenticator) Close(ctx context.Context) error { + if a.stopProxy == nil { + return nil + } + + if err := a.stopProxy(); err != nil { + return fmt.Errorf("failed to stop TCP2TLS proxy: %w", err) + } + + return nil +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/certprovider/README.md b/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider/README.md similarity index 100% rename from tools/oprt2/pkg/attunehooks/authenticators/certprovider/README.md rename to tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider/README.md diff --git a/tools/oprt2/pkg/attunehooks/authenticators/certprovider/provider.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider/provider.go similarity index 77% rename from tools/oprt2/pkg/attunehooks/authenticators/certprovider/provider.go rename to tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider/provider.go index 434690c6..ac656fac 100644 --- a/tools/oprt2/pkg/attunehooks/authenticators/certprovider/provider.go +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider/provider.go @@ -19,6 +19,9 @@ package certprovider import ( "context" "crypto/tls" + "errors" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/config" ) // Provider provides a certificate for client authentication. @@ -26,3 +29,8 @@ type Provider interface { // Gets a keypair for use with mTLS authentication. GetClientCertificate(context.Context) (*tls.Certificate, error) } + +// FromConfig builds a cert provider from the provided config. +func FromConfig(config config.CertificateProvider) (Provider, error) { + return nil, errors.New("not implemented") +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/mtls/config.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/config.go new file mode 100644 index 00000000..f850105f --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/config.go @@ -0,0 +1,40 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mtls + +import ( + "context" + "fmt" + "log/slog" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/config" +) + +func FromConfig(ctx context.Context, config *config.MTLSAuthenticator, logger *slog.Logger) (*Authenticator, error) { + certProvider, err := certprovider.FromConfig(config.CertificateSource) + if err != nil { + return nil, fmt.Errorf("failed to get mTLS certificate source: %w", err) + } + + authenticator, err := NewAuthenticator(ctx, config.Endpoint, certProvider, WithLogger(logger)) + if err != nil { + return nil, fmt.Errorf("failed to create mTLS authenticator: %w", err) + } + + return authenticator, nil +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/mtls/options.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/options.go new file mode 100644 index 00000000..4ea3702a --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/options.go @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mtls + +import ( + "log/slog" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/logging" +) + +type AuthenticatorOption func(a *Authenticator) + +func WithLogger(logger *slog.Logger) AuthenticatorOption { + return func(a *Authenticator) { + if logger == nil { + logger = logging.DiscardLogger + } + a.logger = logger + } +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/options.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/options.go new file mode 100644 index 00000000..a8243fcf --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/options.go @@ -0,0 +1,43 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package proxy + +import ( + "log/slog" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/logging" +) + +type TCP2TLSOption func(*TCP2TLS) + +// WithClientCertificateProvider configures the proxy to use mTLS authentication via certs provided by the provider. +func WithClientCertificateProvider(provider certprovider.Provider) TCP2TLSOption { + return func(t2t *TCP2TLS) { + t2t.clientCertProvider = provider + } +} + +// WithLogger configures the proxy with the provided logger. +func WithLogger(logger *slog.Logger) TCP2TLSOption { + return func(t2t *TCP2TLS) { + if logger == nil { + logger = logging.DiscardLogger + } + t2t.logger = logger + } +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/tcp2tls.go b/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/tcp2tls.go new file mode 100644 index 00000000..f70e125d --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/mtls/proxy/tcp2tls.go @@ -0,0 +1,249 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package proxy + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "net" + "sync" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/attunehooks/authenticators/mtls/certprovider" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/ctxcopy" + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/logging" +) + +// TCP2TLS accepts TCP connections and wraps them in a (m)TLS tunnel to a specific destination. +type TCP2TLS struct { + listeningAddress *net.TCPAddr + started chan struct{} + closeStartedChan func() // Used to ensure that `close(started)` is only called once + destinationHost string + destinationPort string + clientCertProvider certprovider.Provider + logger *slog.Logger +} + +// NewTCP2TLS creates a new TCP2TLS proxy instance. +func NewTCP2TLS(destinationHost, destinationPort string, opts ...TCP2TLSOption) *TCP2TLS { + started := make(chan struct{}) + t2t := &TCP2TLS{ + started: started, + closeStartedChan: sync.OnceFunc(func() { close(started) }), + destinationHost: destinationHost, + destinationPort: destinationPort, + logger: logging.DiscardLogger, + } + + for _, opt := range opts { + opt(t2t) + } + + return t2t +} + +// Close closes the proxy. +func (t2t *TCP2TLS) Close() { + t2t.closeStartedChan() +} + +// ListenAndServe starts the proxy. +func (t2t *TCP2TLS) ListenAndServe(ctx context.Context) error { + proxyListener, err := t2t.listen() + if err != nil { + t2t.closeStartedChan() + return fmt.Errorf("proxy listener failed to start: %w", err) + } + + // The documentation for net.TCPListener states that Addr will always return a *net.TCPAddr + t2t.listeningAddress = proxyListener.Addr().(*net.TCPAddr) + t2t.logger.InfoContext(ctx, "Proxy is listening", "localAddress", t2t.listeningAddress.String()) + t2t.closeStartedChan() + + // Track incoming connectionsInProgress to ensure that they are all completed prior to returning (and cleanup) + var connectionsInProgress sync.WaitGroup + + // Close the listener when the context is cancelled + listenerCloseErr := make(chan error) + go func() { + defer close(listenerCloseErr) + <-ctx.Done() + + connectionsInProgress.Wait() // Don't close the socket until all active connections are complete + listenerCloseErr <- proxyListener.Close() + }() + + // Accept and proxy connections + for { + clientConnection, err := proxyListener.AcceptTCP() + if err != nil { + // Ignore errors caused by context cancellation, and stop accepting new connections + if t2t.isServeStopping(ctx) { + break + } + + t2t.logger.WarnContext(ctx, "Failed to accept client connection", "error", err.Error()) + continue + } + t2t.logger.DebugContext(ctx, "Accepted connection", "clientAddress", clientConnection.RemoteAddr().String()) + + connectionsInProgress.Go(func() { + if err := t2t.proxyConnection(ctx, clientConnection); err != nil { + t2t.logger.WarnContext(ctx, "An error occurred while handling proxy connection", + "clientAddress", clientConnection.RemoteAddr().String(), + "destinationHost", t2t.destinationHost, + "destinationPort", t2t.destinationPort, + "error", err.Error(), + ) + } + }) + } + + // This will block until all connections in progress are completed + if err := <-listenerCloseErr; err != nil { + return fmt.Errorf("failed to close proxy listener (socket leak): %w", err) + } + return nil +} + +func (t2t *TCP2TLS) getTLSConfig(ctx context.Context) (*tls.Config, error) { + config := &tls.Config{ + ServerName: t2t.destinationHost, + } + + if t2t.clientCertProvider != nil { + cert, err := t2t.clientCertProvider.GetClientCertificate(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client certificate from provider: %w", err) + } + + config.Certificates = append(config.Certificates, *cert) + } + + return config, nil +} + +// listen creates a new localhost TCP socket to listen for incoming connections. +func (t2t *TCP2TLS) listen() (*net.TCPListener, error) { + addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort("127.0.0.1", "0")) + if err != nil { + return nil, fmt.Errorf("failed to bind to a free TCP port on loopback address: %w", err) + } + + proxyListener, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to listen for TCP connections on %q: %w", addr.AddrPort().String(), err) + } + + return proxyListener, nil +} + +func (t2t *TCP2TLS) proxyConnection(ctx context.Context, clientConnection *net.TCPConn) (err error) { + defer func() { + closeErr := clientConnection.Close() + if closeErr != nil { + closeErr = fmt.Errorf("failed to close client connection (connection leak): %w", closeErr) + } + err = errors.Join(err, closeErr) + }() + + // Build a new TLS dialer + config, err := t2t.getTLSConfig(ctx) + if err != nil { + return fmt.Errorf("failed to build TLS config: %w", err) + } + + destinationDialer := &tls.Dialer{ + Config: config, + } + + // Establish a new TLS connection to the reverse proxy. Don't share connections to avoid + // connection sharing bugs + destinationAddr := net.JoinHostPort(t2t.destinationHost, t2t.destinationPort) + destinationConnection, err := destinationDialer.DialContext(ctx, "tcp", destinationAddr) + if err != nil { + return fmt.Errorf("failed to connect to destination address: %w", err) + } + defer func() { + closeErr := destinationConnection.Close() + if closeErr != nil { + closeErr = fmt.Errorf("failed to close destination connection (connection leak): %w", closeErr) + } + err = errors.Join(err, closeErr) + }() + + // Docs state that this will always be a `*tls.Conn` + destinationTLSConnection := destinationConnection.(*tls.Conn) + deadline, ok := ctx.Deadline() + if ok { + if err := destinationTLSConnection.SetDeadline(deadline); err != nil { + return fmt.Errorf("failed to set TLS connection deadline: %w", err) + } + } + + // Read and write until the connection is closed + // This will not allocate a buffer because the TCP connection implements + // io.ReadFrom _and_ io.WriteTo. + readFromDestinationErr := ctxcopy.CopyConcurrently(ctx, destinationTLSConnection, clientConnection) + writeFromDestinationErr := ctxcopy.CopyConcurrently(ctx, clientConnection, destinationTLSConnection) + + // Wait for all reads and writes to complete + readErr := <-readFromDestinationErr + writeErr := <-writeFromDestinationErr + + if !t2t.isServeStopping(ctx) { + if readErr != nil { + readErr = fmt.Errorf("failed to read all data from the destination stream") + } + + if writeErr != nil { + writeErr = fmt.Errorf("failed to write all data to the destination stream") + } + + return errors.Join(readErr, writeErr) + } + + return nil +} + +func (t2t *TCP2TLS) isServeStopping(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} + +// Gets the address that the proxy is listening on. This will block until the proxy begins listening, +// the listener errors, or the context is cancelled. +func (t2t *TCP2TLS) GetAddress(ctx context.Context) (*net.TCPAddr, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-t2t.started: + } + + if t2t.listeningAddress == nil { + return nil, fmt.Errorf("failed to get listening address (listen failed)") + } + + return t2t.listeningAddress, nil +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/token/authenticator.go b/tools/oprt2/pkg/attunehooks/authenticators/token/authenticator.go new file mode 100644 index 00000000..6d6ba27b --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/token/authenticator.go @@ -0,0 +1,71 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package token + +import ( + "context" + "fmt" + "net/url" + "os/exec" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/commandrunner" +) + +// Authenticator authenticates with the Attune control plane with Attune-supported token authentication. +type Authenticator struct { + attuneEndpoint string + token string +} + +var _ commandrunner.Hook = (*Authenticator)(nil) + +// NewAuthenticator creates a new Authenticator +func NewAuthenticator(attuneEndpoint, token string) (*Authenticator, error) { + attuneEndpointURL, err := url.Parse(attuneEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid Attune endpoint URL %q: %w", attuneEndpoint, err) + } + + if attuneEndpointURL.Hostname() == "" { + return nil, fmt.Errorf("the Attune endpoint URL is missing a hostname: %q", attuneEndpoint) + } + + return &Authenticator{ + attuneEndpoint: attuneEndpoint, + token: token, + }, nil +} + +// Name is the name of the authenticator. +func (a *Authenticator) Name() string { + return "token" +} + +// Command adds token authentication to the Attune command. +func (a *Authenticator) Command(_ context.Context, cmd *exec.Cmd) error { + cmd.Env = append( + cmd.Env, + "ATTUNE_API_TOKEN="+a.token, + "ATTUNE_API_ENDPOINT="+a.attuneEndpoint, + ) + return nil +} + +// Close closes the authenticator. +func (a *Authenticator) Close(_ context.Context) error { + return nil +} diff --git a/tools/oprt2/pkg/attunehooks/authenticators/token/config.go b/tools/oprt2/pkg/attunehooks/authenticators/token/config.go new file mode 100644 index 00000000..784a5a6e --- /dev/null +++ b/tools/oprt2/pkg/attunehooks/authenticators/token/config.go @@ -0,0 +1,33 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package token + +import ( + "fmt" + + "github.com/gravitational/shared-workflows/tools/oprt2/pkg/config" +) + +// Config creates a new token authenticator from the provided config. +func FromConfig(config *config.TokenAuthenticator) (*Authenticator, error) { + hook, err := NewAuthenticator(config.Endpoint, config.Token) + if err != nil { + return nil, fmt.Errorf("failed to build token authenticator: %w", err) + } + + return hook, nil +} diff --git a/tools/oprt2/pkg/config/config.go b/tools/oprt2/pkg/config/config.go index 689522f4..2022bbee 100644 --- a/tools/oprt2/pkg/config/config.go +++ b/tools/oprt2/pkg/config/config.go @@ -16,9 +16,34 @@ package config +// CertificateProvider defines the configuration for providing client certificates. +type CertificateProvider struct { + // Not implemented +} + +// MTLSAuthenticator defines the configuration for authenticating with Attune via mTLS. +type MTLSAuthenticator struct { + // Endpoint is the Attune control plane endpoint to use. + Endpoint string + // CertificateSource is the certificate provider to use for client authentication. + CertificateSource CertificateProvider +} + +// TokenAuthenticator defines the configuration for authenticating with Attune via tokens. +type TokenAuthenticator struct { + // Endpoint is the Attune control plane endpoint to use. + Endpoint string + // Token is the token to provide the control plane. + Token string +} + // Authenticator defines Attune authentication configuration. +// Only one field may be specified. type Authenticator struct { - // Not implemented + // Token authenticates with Attune via a Token. + Token *TokenAuthenticator + // MTLS authenticates with Attune via mTLS. + MTLS *MTLSAuthenticator } // S3FileManager is a file manager that uses files stored in a remote S3 bucket. diff --git a/tools/oprt2/pkg/ctxcopy/contextualcopy.go b/tools/oprt2/pkg/ctxcopy/contextualcopy.go new file mode 100644 index 00000000..dc312a87 --- /dev/null +++ b/tools/oprt2/pkg/ctxcopy/contextualcopy.go @@ -0,0 +1,60 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ctxcopy + +import ( + "context" + "io" +) + +// Copy is a contextual version of [io.Copy]. If the context is cancelled, calls to +// read and write functions will error, stopping the copy. +// This retains all properties of [io.Copy], including support for [io.WriterTo] and +// [io.ReaderFrom]. +func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + if dstWriterTo, ok := dst.(writerToWriter); ok { + dst = newContextWriterTo(ctx, dstWriterTo) + } else { + dst = newContextWriter(ctx, dst) + } + + if srcReaderFrom, ok := src.(readerFromReader); ok { + src = newContextReaderFrom(ctx, srcReaderFrom) + } else { + src = newContextReader(ctx, src) + } + + return io.Copy(dst, src) +} + +// copyConcurrently copies all context from src to dst without blocking. When the copy is complete +// (or it fails), an error will be sent and the channel will close. +// This calls [Copy], so it inherits all properties of [Copy]. +// The returned channel will be closed when the internal goroutine exits, so it should not be closed +// by the caller. +func CopyConcurrently(ctx context.Context, dst io.Writer, src io.Reader) <-chan error { + done := make(chan error, 1) + + go func() { + defer close(done) + + _, err := Copy(ctx, dst, src) + done <- err + }() + + return done +} diff --git a/tools/oprt2/pkg/ctxcopy/io.go b/tools/oprt2/pkg/ctxcopy/io.go new file mode 100644 index 00000000..03a369ea --- /dev/null +++ b/tools/oprt2/pkg/ctxcopy/io.go @@ -0,0 +1,140 @@ +/* + * Copyright 2025 Gravitational, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ctxcopy + +import ( + "context" + "io" +) + +// Implement Reader, Writer, ReaderFrom, WriterFrom interfaces via wrappers that are context cancellable. + +// Readers + +type contextReader struct { + ctx context.Context + reader io.Reader +} + +func newContextReader(ctx context.Context, reader io.Reader) *contextReader { + return &contextReader{ + ctx: ctx, + reader: reader, + } +} + +var _ io.Reader = &contextReader{} + +// Read satisfies the [io.Read] interface. +func (cr *contextReader) Read(p []byte) (int, error) { + select { + case <-cr.ctx.Done(): + return 0, cr.ctx.Err() + default: + } + + return cr.reader.Read(p) +} + +type readerFromReader interface { + io.Reader + io.ReaderFrom +} + +type contextReaderFrom struct { + *contextReader + reader io.ReaderFrom +} + +var _ io.Reader = &contextReaderFrom{} +var _ io.ReaderFrom = &contextReaderFrom{} + +func newContextReaderFrom(ctx context.Context, reader readerFromReader) *contextReaderFrom { + return &contextReaderFrom{ + contextReader: newContextReader(ctx, reader), + reader: reader, + } +} + +// ReadFrom satisfies the [io.ReaderFrom] interface +func (crf *contextReaderFrom) ReadFrom(r io.Reader) (int64, error) { + select { + case <-crf.ctx.Done(): + return 0, crf.ctx.Err() + default: + } + + return crf.reader.ReadFrom(r) +} + +// Writers + +type contextWriter struct { + ctx context.Context + writer io.Writer +} + +func newContextWriter(ctx context.Context, writer io.Writer) *contextWriter { + return &contextWriter{ + ctx: ctx, + writer: writer, + } +} + +var _ io.Writer = &contextWriter{} + +// Write satisfies the [io.Write] interface. +func (cw *contextWriter) Write(p []byte) (int, error) { + select { + case <-cw.ctx.Done(): + return 0, cw.ctx.Err() + default: + } + + return cw.writer.Write(p) +} + +type writerToWriter interface { + io.Writer + io.WriterTo +} + +type contextWriterTo struct { + *contextWriter + writer io.WriterTo +} + +var _ io.Writer = &contextWriterTo{} +var _ io.WriterTo = &contextWriterTo{} + +func newContextWriterTo(ctx context.Context, writer writerToWriter) *contextWriterTo { + return &contextWriterTo{ + contextWriter: newContextWriter(ctx, writer), + writer: writer, + } +} + +// WriteTo satisfies the [io.WriterTo] interface. +func (cwt *contextWriterTo) WriteTo(w io.Writer) (int64, error) { + select { + case <-cwt.ctx.Done(): + return 0, cwt.ctx.Err() + default: + } + + return cwt.writer.WriteTo(w) +} diff --git a/tools/oprt2/pkg/ospackages/publishers/attune/config.go b/tools/oprt2/pkg/ospackages/publishers/attune/config.go index e4ebef92..3ad3c0b1 100644 --- a/tools/oprt2/pkg/ospackages/publishers/attune/config.go +++ b/tools/oprt2/pkg/ospackages/publishers/attune/config.go @@ -41,7 +41,7 @@ var _ ospackages.APTPublisher = (*publisherFromConfig)(nil) // FromConfig creates a new Attune publisher instance from the provided config and Attune runner. func FromConfig(ctx context.Context, config config.AttuneAPTPackagePublisher, logger *slog.Logger) (ospackages.APTPublisher, error) { - authenticator, err := authenticators.FromConfig(config.Authentication) + authenticator, err := authenticators.FromConfig(ctx, config.Authentication, logger) if err != nil { return nil, fmt.Errorf("failed to create Attune authenticator: %w", err) }