Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ func (b *Build) Implementation() *Generator {
// Hooks defines a set of method interceptors for methods included in
// http.ResponseWriter as well as some others. You can think of them as
// middleware for the function calls they target. See Wrap for more details.
//
// For each method, the exact matching hook takes precedence. For example,
// WriteString calls the WriteString hook when it is configured, even if a
// Write hook is also configured. If the exact hook is not configured, most
// methods call through to the underlying ResponseWriter directly.
//
// Two compatibility fallbacks preserve the behavior users had before Wrap
// learned about newer optional interfaces:
// - If the underlying ResponseWriter implements io.StringWriter and
// WriteString is called, but only the Write hook is configured, WriteString
// is routed through the Write hook with []byte(s). If neither hook is
// configured, WriteString calls the underlying WriteString method directly.
// - If the underlying ResponseWriter implements both http.Flusher and
// FlushError, and FlushError is called, but only the Flush hook is
// configured, FlushError is routed through the Flush hook while preserving
// the error returned by the underlying FlushError method. If neither hook is
// configured, FlushError calls the underlying FlushError method directly.
type Hooks struct {
`)
for _, iface := range ifaces {
Expand Down Expand Up @@ -113,6 +130,17 @@ type Hooks struct {
for _, fn := range iface.Funcs {
g.Printf("if hooks.%s != nil {\n", fn.Name)
g.Printf("state.%s = hooks.%s(t%d.%s)\n", fieldName(fn.Name), fn.Name, i, fn.Name)
if fn.Name == "FlushError" {
// http.ResponseController.Flush prefers FlushError over Flush.
// Preserve existing Flush hooks when wrapping writers that expose both.
g.Printf("} else if state.flush != nil {\n")
g.Printf("state.flushError = func() (err error) { hooks.Flush(func() { err = t%d.FlushError() })(); return err }\n", i)
} else if fn.Name == "WriteString" {
// io.WriteString prefers WriteString over Write. Preserve existing Write
// hooks when wrapping writers that expose both methods.
g.Printf("} else if state.write != nil {\n")
g.Printf("state.writeString = func(s string) (int, error) { return state.write([]byte(s)) }\n")
}
g.Printf("}\n")
}
g.Printf("}\n")
Expand Down
21 changes: 21 additions & 0 deletions wrap_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions wrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,72 @@ package httpsnoop

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)

type flushErrorResponseWriter struct {
h http.Header
err error
}

func (w *flushErrorResponseWriter) Header() http.Header {
if w.h == nil {
w.h = http.Header{}
}
return w.h
}

func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
func (w *flushErrorResponseWriter) WriteHeader(code int) {}
func (w *flushErrorResponseWriter) Flush() {}
Comment thread
felixge marked this conversation as resolved.
func (w *flushErrorResponseWriter) FlushError() error { return w.err }

func TestWrap_preservesWriteHookForWriteString(t *testing.T) {
var got string
w := Wrap(httptest.NewRecorder(), Hooks{
Write: func(next WriteFunc) WriteFunc {
return func(p []byte) (int, error) {
got = string(p)
return next(p)
}
},
})

if _, ok := w.(io.StringWriter); !ok {
t.Fatal("wrapped writer should expose io.StringWriter")
}
if _, err := io.WriteString(w, "hello"); err != nil {
t.Fatal(err)
}
if got != "hello" {
t.Fatalf("Write hook saw %q, want %q", got, "hello")
}
}

func TestWrap_preservesFlushHookForFlushError(t *testing.T) {
flushed := false
wantErr := errors.New("flush failed")
w := Wrap(&flushErrorResponseWriter{err: wantErr}, Hooks{
Flush: func(next FlushFunc) FlushFunc {
return func() {
flushed = true
next()
}
},
})

if err := http.NewResponseController(w).Flush(); !errors.Is(err, wantErr) {
t.Fatalf("got err %v, want %v", err, wantErr)
}
if !flushed {
t.Fatal("Flush hook was not called")
}
}

func TestWrap_integration(t *testing.T) {
tests := []struct {
Name string
Expand Down
Loading