diff --git a/context.go b/context.go index 82220730..fc01e106 100644 --- a/context.go +++ b/context.go @@ -76,6 +76,13 @@ type Context struct { methodsAllowed []methodTyp // allowed methods in case of a 405 methodNotAllowed bool + poisoned bool // prevents returning to pool +} + +// PreventReuse marks the Context as being unable to be returned to the pool +// and allowing it to be safely used outside of the running request. +func (x *Context) PreventReuse() { + x.poisoned = true } // Reset a routing context to its initial state. diff --git a/mux.go b/mux.go index 71652dd1..af9973b3 100644 --- a/mux.go +++ b/mux.go @@ -88,7 +88,10 @@ func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Serve the request and once its done, put the request context back in the sync pool mx.handler.ServeHTTP(w, r) - mx.pool.Put(rctx) + + if !rctx.poisoned { + mx.pool.Put(rctx) + } } // Use appends a middleware handler to the Mux middleware stack. diff --git a/mux_test.go b/mux_test.go index d69a6f8a..86b5ba70 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1980,6 +1980,29 @@ func TestServerBaseContext(t *testing.T) { } } +func TestServerPoisonedContext(t *testing.T) { + r := NewRouter() + var firstContext *Context + r.Get("/first", func(w http.ResponseWriter, r *http.Request) { + firstContext, _ = r.Context().Value(RouteCtxKey).(*Context) + firstContext.PreventReuse() + }) + r.Get("/second", func(w http.ResponseWriter, r *http.Request) { + context, _ := r.Context().Value(RouteCtxKey).(*Context) + if context == firstContext { + t.Fatalf("expected to get a fresh context instance for each request") + } + }) + + ts := httptest.NewUnstartedServer(r) + ts.Start() + + defer ts.Close() + + testRequest(t, ts, "GET", "/first", nil) + testRequest(t, ts, "GET", "/second", nil) +} + func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { req, err := http.NewRequest(method, ts.URL+path, body) if err != nil {