1515package config
1616
1717import (
18+ "context"
1819 "crypto/tls"
1920 "crypto/x509"
2021 "net/http"
2122 "os"
23+ "time"
2224
2325 "github.com/spf13/pflag"
2426 "go.thethings.network/lorawan-stack-migrate/pkg/source"
2527 "go.thethings.network/lorawan-stack/v3/pkg/types"
2628 "google.golang.org/grpc"
2729 "google.golang.org/grpc/credentials"
30+ "google.golang.org/grpc/credentials/insecure"
2831)
2932
33+ const dialTimeout = 10 * time .Second
34+
3035func New () (* Config , * pflag.FlagSet ) {
3136 var (
3237 config = & Config {}
@@ -100,13 +105,6 @@ func (c *Config) Initialize() error {
100105 if err := c .JoinEUI .UnmarshalText ([]byte (c .joinEUI )); err != nil {
101106 return errInvalidJoinEUI .WithAttributes ("join_eui" , c .joinEUI )
102107 }
103-
104- if ! c .insecure || c .caPath != "" {
105- if err := setCustomCA (c .caPath ); err != nil {
106- return err
107- }
108- }
109-
110108 err := c .dialGRPC (
111109 grpc .FailOnNonTempDialError (true ),
112110 grpc .WithBlock (),
@@ -120,32 +118,45 @@ func (c *Config) Initialize() error {
120118}
121119
122120func (c * Config ) dialGRPC (opts ... grpc.DialOption ) error {
123- if c .insecure && c .caPath == "" {
124- opts = append (opts , grpc .WithInsecure ())
125- }
126- if tls := http .DefaultTransport .(* http.Transport ).TLSClientConfig ; tls != nil {
127- opts = append (opts , grpc .WithTransportCredentials (credentials .NewTLS (tls )))
121+ if c .insecure {
122+ opts = append (opts , grpc .WithTransportCredentials (insecure .NewCredentials ()))
123+ } else {
124+ tlsConfig , err := generateTLSConfig (c .caPath )
125+ if err != nil {
126+ return err
127+ }
128+ opts = append (opts , grpc .WithTransportCredentials (credentials .NewTLS (tlsConfig )))
128129 }
130+
131+ ctx , cancel := context .WithTimeout (context .Background (), dialTimeout )
132+ defer cancel ()
133+
129134 var err error
130- c .ClientConn , err = grpc .Dial ( c .url , opts ... )
135+ c .ClientConn , err = grpc .DialContext ( ctx , c .url , opts ... )
131136 if err != nil {
132137 return err
133138 }
134139 return nil
135140}
136141
137- func setCustomCA (path string ) error {
138- pemBytes , err := os .ReadFile (path )
139- if err != nil {
140- return err
142+ // GenerateTLSConfig generates a TLS configuration.
143+ func generateTLSConfig (caPath string ) (cfg * tls.Config , err error ) {
144+ cfg = http .DefaultTransport .(* http.Transport ).TLSClientConfig
145+ if cfg == nil {
146+ cfg = & tls.Config {}
141147 }
142- rootCAs := http .DefaultTransport .(* http.Transport ).TLSClientConfig .RootCAs
143- if rootCAs == nil {
144- if rootCAs , err = x509 .SystemCertPool (); err != nil {
145- rootCAs = x509 .NewCertPool ()
148+ if cfg .RootCAs == nil {
149+ if cfg .RootCAs , err = x509 .SystemCertPool (); err != nil {
150+ cfg .RootCAs = x509 .NewCertPool ()
146151 }
147152 }
148- rootCAs .AppendCertsFromPEM (pemBytes )
149- http .DefaultTransport .(* http.Transport ).TLSClientConfig = & tls.Config {RootCAs : rootCAs }
150- return nil
153+ if caPath == "" {
154+ return cfg , nil
155+ }
156+ pemBytes , err := os .ReadFile (caPath )
157+ if err != nil {
158+ return nil , err
159+ }
160+ cfg .RootCAs .AppendCertsFromPEM (pemBytes )
161+ return cfg , nil
151162}
0 commit comments