Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,26 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin

func matchAcceptEncoding(accepted []string, encoding string) bool {
for _, v := range accepted {
if strings.Contains(v, encoding) {
return true
// Split off any parameters (e.g. ";q=0.5") and trim whitespace
name, params, _ := strings.Cut(strings.TrimSpace(v), ";")
name = strings.TrimSpace(name)

if !strings.EqualFold(name, encoding) {
continue
}

// Check for explicit q=0, which means the client refused this encoding
if params != "" {
params = strings.TrimSpace(params)
if strings.HasPrefix(params, "q=") {
qval := strings.TrimSpace(params[2:])
if qval == "0" || qval == "0." || qval == "0.0" || qval == "0.00" || qval == "0.000" {
return false
}
}
}

return true
}
return false
}
Expand Down
55 changes: 55 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,61 @@ func TestCompressorWildcards(t *testing.T) {
}
}

// TestMatchAcceptEncoding verifies that Accept-Encoding negotiation uses
// proper token matching rather than substring matching. The current
// implementation uses strings.Contains, which incorrectly matches:
// - "gzip;q=0" as gzip (client explicitly refused gzip)
// - "xgzip" as gzip (not a real encoding, but contains "gzip" as a substring)
func TestMatchAcceptEncoding(t *testing.T) {
tests := []struct {
name string
accepted []string
encoding string
want bool
}{
{
name: "exact match",
accepted: []string{"gzip"},
encoding: "gzip",
want: true,
},
{
name: "q=0 means refused",
accepted: []string{"gzip;q=0"},
encoding: "gzip",
want: false,
},
{
name: "substring should not match",
accepted: []string{"xgzip"},
encoding: "gzip",
want: false,
},
{
name: "encoding with positive q-value",
accepted: []string{"gzip;q=0.5"},
encoding: "gzip",
want: true,
},
{
name: "encoding with whitespace",
accepted: []string{" gzip "},
encoding: "gzip",
want: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := matchAcceptEncoding(tc.accepted, tc.encoding)
if got != tc.want {
t.Errorf("matchAcceptEncoding(%v, %q) = %v, want %v",
tc.accepted, tc.encoding, got, tc.want)
}
})
}
}

func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+path, nil)
if err != nil {
Expand Down
15 changes: 14 additions & 1 deletion middleware/recoverer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func Recoverer(next http.Handler) http.Handler {
PrintPrettyStack(rvr)
}

if r.Header.Get("Connection") != "Upgrade" {
if !headerContainsToken(r.Header, "Connection", "Upgrade") {
w.WriteHeader(http.StatusInternalServerError)
}
}
Expand Down Expand Up @@ -201,3 +201,16 @@ func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (st

return buf.String(), nil
}

// headerContainsToken checks whether a comma-separated, case-insensitive
// HTTP header contains a specific token (RFC 7230 §3.2.6).
func headerContainsToken(h http.Header, headerName, token string) bool {
for _, v := range h[http.CanonicalHeaderKey(headerName)] {
for _, s := range strings.Split(v, ",") {
if strings.EqualFold(strings.TrimSpace(s), token) {
return true
}
}
}
return false
}
63 changes: 63 additions & 0 deletions middleware/recoverer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,69 @@ func TestRecoverer(t *testing.T) {
t.Fatal("First func call line should start with ->.")
}

// TestRecovererUpgradeConnectionDetection verifies that the Recoverer does not
// write a 500 status code when the Connection header indicates an upgrade.
// HTTP headers are case-insensitive (RFC 7230 §3.2) and the Connection header
// can contain multiple comma-separated tokens (e.g. "keep-alive, Upgrade").
// The current implementation only matches the exact string "Upgrade".
func TestRecovererUpgradeConnectionDetection(t *testing.T) {
tests := []struct {
name string
connHeader string
expect500 bool
}{
{
name: "exact Upgrade is not 500",
connHeader: "Upgrade",
expect500: false,
},
{
name: "lowercase upgrade is not 500",
connHeader: "upgrade",
expect500: false,
},
{
name: "Upgrade in token list is not 500",
connHeader: "keep-alive, Upgrade",
expect500: false,
},
{
name: "no Connection header is 500",
connHeader: "",
expect500: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
oldRecovererErrorWriter := recovererErrorWriter
defer func() { recovererErrorWriter = oldRecovererErrorWriter }()
recovererErrorWriter = &bytes.Buffer{}

r := chi.NewRouter()
r.Use(Recoverer)
r.Get("/", panickingHandler)

w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
if tc.connHeader != "" {
req.Header.Set("Connection", tc.connHeader)
}

r.ServeHTTP(w, req)

got500 := w.Code == http.StatusInternalServerError
if got500 != tc.expect500 {
t.Errorf("Connection: %q — got status %d, expected 500=%v",
tc.connHeader, w.Code, tc.expect500)
}
})
}
}

func TestRecovererAbortHandler(t *testing.T) {
defer func() {
rcv := recover()
Expand Down