diff --git a/README.md b/README.md index bd744e5..7e1c094 100644 --- a/README.md +++ b/README.md @@ -1,73 +1,3 @@ # morecontext -context.Context helpers to make your life slightly easier. - --- - import "github.com/gopuff/morecontext" - - -## Usage - -#### func ForSignals - -```go -func ForSignals(sigs ...os.Signal) context.Context -``` -ForSignals returns a context.Context that will be cancelled if the given signals -(or SIGTERM and SIGINT by default, if none are passed) are received by the -process. - -#### func WithMessage - -```go -func WithMessage(ctx context.Context, format string, args ...interface{}) context.Context -``` -WithMessage is a helper for creating a MessageContext instance, with more -context about exactly which context cancellation occurred. - -#### type MessageContext - -```go -type MessageContext struct { - context.Context - Message string -} -``` - -MessageContext is a context.Context wrapper that is intended to include some -extra metadata about which context was cancelled. Useful for distinguishing e.g. -http request cancellation vs deadline vs process sigterm handling. - -#### func (MessageContext) Err - -```go -func (c MessageContext) Err() error -``` -Err returns an error with extra context messaging - -#### type MessageError - -```go -type MessageError struct { - Message string - Original error -} -``` - -MessageError implements error but separates the message from the original error -in case you want it. - -#### func (*MessageError) Error - -```go -func (c *MessageError) Error() string -``` -Error implements error and prints out the message plus the metadata about which -context was cancelled. - -#### func (*MessageError) Unwrap - -```go -func (c *MessageError) Unwrap() error -``` -Unwrap supports errors.Is/errors.As +context.Context helpers and implementations to make your life slightly easier. diff --git a/reason.go b/reason.go new file mode 100644 index 0000000..b66d13b --- /dev/null +++ b/reason.go @@ -0,0 +1,43 @@ +package morecontext + +import ( + "context" + "fmt" +) + +// A CancelReasonContext is a cancellable context whose cancellation requires a +// reason, and whose Err() message will include the reason, if this context was +// the one cancelled. +type CancelReasonContext struct { + context.Context + Reason error +} + +var _ context.Context = &CancelReasonContext{} + +// If this context was the one cancelled, we return the reason it was +// cancelled. If the underlying context was cancelled, then we return its +// message. +func (crc CancelReasonContext) Err() error { + if crc.Reason != nil { + return fmt.Errorf("context cancelled: %w", crc.Reason) + } + return crc.Context.Err() +} + +// WithCancelReason returns a context implementation that must be cancelled, +// and whose cancellation must include a Reason error that will be returned by +// any calls to `Err`. +func WithCancelReason(ctx context.Context) (*CancelReasonContext, func(error)) { + ctx, cancel := context.WithCancel(ctx) + crc := CancelReasonContext{ + Context: ctx, + } + + c := func(err error) { + crc.Reason = err + cancel() + } + + return &crc, c +} diff --git a/reason_test.go b/reason_test.go new file mode 100644 index 0000000..502d28b --- /dev/null +++ b/reason_test.go @@ -0,0 +1,33 @@ +package morecontext + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReasonContext(t *testing.T) { + asrt := assert.New(t) + ctx, cancel := WithCancelReason(context.Background()) + + cancel(fmt.Errorf("foo bar baz")) + + err := ctx.Err() + asrt.Error(err) + asrt.Contains(err.Error(), "foo bar baz") +} + +func TestReasonParentCancel(t *testing.T) { + asrt := assert.New(t) + ctx, c1 := context.WithCancel(context.Background()) + ctx, cancel := WithCancelReason(ctx) + defer cancel(nil) + + c1() + + err := ctx.Err() + asrt.Error(err) + asrt.NotContains(err.Error(), "foo bar baz") +} diff --git a/signal.go b/signal.go index cc4bda6..d470933 100644 --- a/signal.go +++ b/signal.go @@ -5,63 +5,33 @@ import ( "fmt" "os" "os/signal" - "sync" "syscall" ) -// sigCtx is a context that will be cancelled if certain signals are received. -// Its Err will include details if this is the reason it was cancelled. -type sigCtx struct { - context.Context - - exitSignal os.Signal - m sync.Mutex -} - -// Err implements context.Context.Err but includes the os.Signal that caused -// the context cancellation. -func (sc *sigCtx) Err() error { - sc.m.Lock() - defer sc.m.Unlock() - err := sc.Context.Err() - - return &MessageError{ - Message: fmt.Sprintf("context cancelled: got signal %s", sc.exitSignal.String()), - Original: err, - } -} - // ForSignals returns a context.Context that will be cancelled if the given // signals (or SIGTERM and SIGINT by default, if none are passed) are received // by the process. func ForSignals(sigs ...os.Signal) context.Context { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := WithCancelReason(context.Background()) - // If no signals are returnd we will use a sensible default set. + // If no signals are included we will use a sensible default set. if len(sigs) == 0 { sigs = []os.Signal{syscall.SIGTERM, syscall.SIGINT} } - sc := &sigCtx{Context: ctx} - ch := make(chan os.Signal, 2) signal.Notify(ch, sigs...) go func() { - i := 0 for sig := range ch { i++ if i > 1 { os.Exit(1) } - sc.m.Lock() - sc.exitSignal = sig - sc.m.Unlock() - - cancel() + cancel(fmt.Errorf("got signal %s", sig)) } }() - return sc + return ctx }