diff --git a/middleware/logger.go b/middleware/logger.go index cff9bd20..2d2553a5 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -160,7 +160,7 @@ func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed t } func (l *defaultLogEntry) Panic(v interface{}, stack []byte) { - PrintPrettyStack(v) + PrintPrettyStackColor(v, l.useColor) } func init() { diff --git a/middleware/recoverer.go b/middleware/recoverer.go index 81342dfa..b7410554 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -51,10 +51,19 @@ func Recoverer(next http.Handler) http.Handler { // for ability to test the PrintPrettyStack function var recovererErrorWriter io.Writer = os.Stderr +// PrintPrettyStack prints a formatted, coloured stack trace to stderr. func PrintPrettyStack(rvr interface{}) { + PrintPrettyStackColor(rvr, true) +} + +// PrintPrettyStackColor prints a formatted stack trace to stderr. +// When useColor is false ANSI colour codes are suppressed, which is useful +// for terminals that do not support them (e.g. on Windows) or when output is +// being captured. +func PrintPrettyStackColor(rvr interface{}, useColor bool) { debugStack := debug.Stack() s := prettyStack{} - out, err := s.parse(debugStack, rvr) + out, err := s.parse(debugStack, rvr, useColor) if err == nil { recovererErrorWriter.Write(out) } else { @@ -66,9 +75,8 @@ func PrintPrettyStack(rvr interface{}) { type prettyStack struct { } -func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) { +func (s prettyStack) parse(debugStack []byte, rvr interface{}, useColor bool) ([]byte, error) { var err error - useColor := true buf := &bytes.Buffer{} cW(buf, false, bRed, "\n") diff --git a/middleware/wrap_writer.go b/middleware/wrap_writer.go index 367e0fcd..b2de8752 100644 --- a/middleware/wrap_writer.go +++ b/middleware/wrap_writer.go @@ -208,8 +208,10 @@ func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { if f.basicWriter.tee != nil { + // Route through basicWriter.Write so that data is also written to the + // tee writer. basicWriter.Write already increments basicWriter.bytes, + // so we must NOT add n again here (that would double-count). n, err := io.Copy(&f.basicWriter, r) - f.basicWriter.bytes += int(n) return n, err } rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) diff --git a/middleware/wrap_writer_test.go b/middleware/wrap_writer_test.go index 7e8f6ab2..dfbbaa40 100644 --- a/middleware/wrap_writer_test.go +++ b/middleware/wrap_writer_test.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -84,3 +85,27 @@ func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { assertEqual(t, 11, wrap.BytesWritten()) }) } + +// TestHttpFancyWriterReadFromByteCountWithTee is a regression test for +// https://github.com/go-chi/chi/issues/1067. +// httpFancyWriter.ReadFrom was adding n to basicWriter.bytes even when the +// write went through basicWriter.Write (which already increments the counter), +// resulting in double-counting the bytes when a Tee writer was set. +func TestHttpFancyWriterReadFromByteCountWithTee(t *testing.T) { + original := &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } + f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: original}} + + var teeBuf bytes.Buffer + f.Tee(&teeBuf) + + const input = "hello world" + n, err := f.ReadFrom(strings.NewReader(input)) + assertNoError(t, err) + assertEqual(t, int64(len(input)), n) + // Before the fix, BytesWritten() returned 22 (double-counted). + assertEqual(t, len(input), f.BytesWritten()) + assertEqual(t, []byte(input), teeBuf.Bytes()) +}