Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ type Context struct {

methodsAllowed []methodTyp // allowed methods in case of a 405
methodNotAllowed bool
poisoned bool // prevents returning to pool
}

// PreventReuse marks the Context as being unable to be returned to the pool
// and allowing it to be safely used outside of the running request.
func (x *Context) PreventReuse() {
x.poisoned = true
}

// Reset a routing context to its initial state.
Expand Down
5 changes: 4 additions & 1 deletion mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Serve the request and once its done, put the request context back in the sync pool
mx.handler.ServeHTTP(w, r)
mx.pool.Put(rctx)

if !rctx.poisoned {
mx.pool.Put(rctx)
}
}

// Use appends a middleware handler to the Mux middleware stack.
Expand Down
23 changes: 23 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1980,6 +1980,29 @@ func TestServerBaseContext(t *testing.T) {
}
}

func TestServerPoisonedContext(t *testing.T) {
r := NewRouter()
var firstContext *Context
r.Get("/first", func(w http.ResponseWriter, r *http.Request) {
firstContext, _ = r.Context().Value(RouteCtxKey).(*Context)
firstContext.PreventReuse()
})
r.Get("/second", func(w http.ResponseWriter, r *http.Request) {
context, _ := r.Context().Value(RouteCtxKey).(*Context)
if context == firstContext {
t.Fatalf("expected to get a fresh context instance for each request")
}
})

ts := httptest.NewUnstartedServer(r)
ts.Start()

defer ts.Close()

testRequest(t, ts, "GET", "/first", nil)
testRequest(t, ts, "GET", "/second", nil)
}

func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
Expand Down