Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions middleware/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
)

// Key to use when setting the request ID.
type ctxKeyRequestID int
type ctxKeyRequestID any

// RequestIDKey is the key that holds the unique request ID in a request context.
const RequestIDKey ctxKeyRequestID = 0
// requestIDKey is the key that holds the unique request ID in a request context.
var requestIDKey ctxKeyRequestID = 0

// RequestIDHeader is the name of the HTTP Header which contains the request id.
// Exported so that it can be changed by developers
Expand Down Expand Up @@ -59,6 +59,30 @@ func init() {
prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10])
}

// RequestIDWithCustomKey is a middleware that injects a request ID into the context of each
// request, storing it under both the default RequestIDKey and the provided custom key.
// This allows retrieving the request ID via GetReqID(ctx) as well as via the custom key.
// Panics if requestIDKey is empty.
func RequestIDWithCustomKey(reqIDKey string) func(http.Handler) http.Handler {
if reqIDKey == "" {
panic("chi/middleware: RequestIDWithCustomKey expects a non-empty key")
}

return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Set the request ID under the custom key as well as the default RequestIDKey
// to permit to get the request ID via GetReqID(ctx) as well as via the custom key.
requestIDKey = reqIDKey
myid := reqid.Add(1)
requestID := fmt.Sprintf("%s-%06d", prefix, myid)
ctx = context.WithValue(ctx, reqIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}

// RequestID is a middleware that injects a request ID into the context of each
// request. A request ID is a string of the form "host.example.com/random-0001",
// where "random" is a base62 random string that uniquely identifies this go
Expand All @@ -72,7 +96,7 @@ func RequestID(next http.Handler) http.Handler {
myid := reqid.Add(1)
requestID = fmt.Sprintf("%s-%06d", prefix, myid)
}
ctx = context.WithValue(ctx, RequestIDKey, requestID)
ctx = context.WithValue(ctx, requestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
Expand All @@ -84,7 +108,7 @@ func GetReqID(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqID, ok := ctx.Value(RequestIDKey).(string); ok {
if reqID, ok := ctx.Value(requestIDKey).(string); ok {
return reqID
}
return ""
Expand Down
73 changes: 73 additions & 0 deletions middleware/request_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,76 @@ func TestRequestID(t *testing.T) {
}
}
}

func TestRequestIDWithCustomKey(t *testing.T) {
tests := map[string]struct {
customKey string
request func() *http.Request
expectPanic bool
}{
"Sets request ID under custom key": {
"x-custom-id",
func() *http.Request {
req, _ := http.NewRequest("GET", "/", nil)
return req
},
false,
},
"Custom key value matches GetReqID": {
"x-trace-id",
func() *http.Request {
req, _ := http.NewRequest("GET", "/", nil)
return req
},
false,
},
"Panics on empty key": {
"",
func() *http.Request {
req, _ := http.NewRequest("GET", "/", nil)
return req
},
true,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
if test.expectPanic {
defer func() {
if r := recover(); r == nil {
t.Fatalf("[%s] expected panic but did not get one", name)
}
}()
RequestIDWithCustomKey(test.customKey)
return
}

var gotDefault string
var gotCustom string

w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RequestIDWithCustomKey(test.customKey))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
gotDefault = GetReqID(r.Context())
if v, ok := r.Context().Value(test.customKey).(string); ok {
gotCustom = v
}
w.Write([]byte(gotDefault))
})
r.ServeHTTP(w, test.request())

if gotDefault == "" {
t.Fatalf("[%s] expected default RequestIDKey to be set in context", name)
}
if gotCustom == "" {
t.Fatalf("[%s] expected custom key %q to be set in context", name, test.customKey)
}
if gotDefault != gotCustom {
t.Fatalf("[%s] default id %q and custom id %q should be equal", name, gotDefault, gotCustom)
}
})
}
}