diff --git a/middleware/request_id.go b/middleware/request_id.go index e1d4ccb7..de139b9b 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -15,10 +15,10 @@ import ( ) // Key to use when setting the request ID. -type ctxKeyRequestID int +type ctxKeyRequestID any -// RequestIDKey is the key that holds the unique request ID in a request context. -const RequestIDKey ctxKeyRequestID = 0 +// requestIDKey is the key that holds the unique request ID in a request context. +var requestIDKey ctxKeyRequestID = 0 // RequestIDHeader is the name of the HTTP Header which contains the request id. // Exported so that it can be changed by developers @@ -59,6 +59,30 @@ func init() { prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) } +// RequestIDWithCustomKey is a middleware that injects a request ID into the context of each +// request, storing it under both the default RequestIDKey and the provided custom key. +// This allows retrieving the request ID via GetReqID(ctx) as well as via the custom key. +// Panics if requestIDKey is empty. +func RequestIDWithCustomKey(reqIDKey string) func(http.Handler) http.Handler { + if reqIDKey == "" { + panic("chi/middleware: RequestIDWithCustomKey expects a non-empty key") + } + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // Set the request ID under the custom key as well as the default RequestIDKey + // to permit to get the request ID via GetReqID(ctx) as well as via the custom key. + requestIDKey = reqIDKey + myid := reqid.Add(1) + requestID := fmt.Sprintf("%s-%06d", prefix, myid) + ctx = context.WithValue(ctx, reqIDKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} + // RequestID is a middleware that injects a request ID into the context of each // request. A request ID is a string of the form "host.example.com/random-0001", // where "random" is a base62 random string that uniquely identifies this go @@ -72,7 +96,7 @@ func RequestID(next http.Handler) http.Handler { myid := reqid.Add(1) requestID = fmt.Sprintf("%s-%06d", prefix, myid) } - ctx = context.WithValue(ctx, RequestIDKey, requestID) + ctx = context.WithValue(ctx, requestIDKey, requestID) next.ServeHTTP(w, r.WithContext(ctx)) } return http.HandlerFunc(fn) @@ -84,7 +108,7 @@ func GetReqID(ctx context.Context) string { if ctx == nil { return "" } - if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + if reqID, ok := ctx.Value(requestIDKey).(string); ok { return reqID } return "" diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index cf07f185..ed682659 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -69,3 +69,76 @@ func TestRequestID(t *testing.T) { } } } + +func TestRequestIDWithCustomKey(t *testing.T) { + tests := map[string]struct { + customKey string + request func() *http.Request + expectPanic bool + }{ + "Sets request ID under custom key": { + "x-custom-id", + func() *http.Request { + req, _ := http.NewRequest("GET", "/", nil) + return req + }, + false, + }, + "Custom key value matches GetReqID": { + "x-trace-id", + func() *http.Request { + req, _ := http.NewRequest("GET", "/", nil) + return req + }, + false, + }, + "Panics on empty key": { + "", + func() *http.Request { + req, _ := http.NewRequest("GET", "/", nil) + return req + }, + true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Fatalf("[%s] expected panic but did not get one", name) + } + }() + RequestIDWithCustomKey(test.customKey) + return + } + + var gotDefault string + var gotCustom string + + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(RequestIDWithCustomKey(test.customKey)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + gotDefault = GetReqID(r.Context()) + if v, ok := r.Context().Value(test.customKey).(string); ok { + gotCustom = v + } + w.Write([]byte(gotDefault)) + }) + r.ServeHTTP(w, test.request()) + + if gotDefault == "" { + t.Fatalf("[%s] expected default RequestIDKey to be set in context", name) + } + if gotCustom == "" { + t.Fatalf("[%s] expected custom key %q to be set in context", name, test.customKey) + } + if gotDefault != gotCustom { + t.Fatalf("[%s] default id %q and custom id %q should be equal", name, gotDefault, gotCustom) + } + }) + } +}