Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 1 addition & 71 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,73 +1,3 @@
# morecontext

context.Context helpers to make your life slightly easier.

--
import "github.com/gopuff/morecontext"


## Usage

#### func ForSignals

```go
func ForSignals(sigs ...os.Signal) context.Context
```
ForSignals returns a context.Context that will be cancelled if the given signals
(or SIGTERM and SIGINT by default, if none are passed) are received by the
process.

#### func WithMessage

```go
func WithMessage(ctx context.Context, format string, args ...interface{}) context.Context
```
WithMessage is a helper for creating a MessageContext instance, with more
context about exactly which context cancellation occurred.

#### type MessageContext

```go
type MessageContext struct {
context.Context
Message string
}
```

MessageContext is a context.Context wrapper that is intended to include some
extra metadata about which context was cancelled. Useful for distinguishing e.g.
http request cancellation vs deadline vs process sigterm handling.

#### func (MessageContext) Err

```go
func (c MessageContext) Err() error
```
Err returns an error with extra context messaging

#### type MessageError

```go
type MessageError struct {
Message string
Original error
}
```

MessageError implements error but separates the message from the original error
in case you want it.

#### func (*MessageError) Error

```go
func (c *MessageError) Error() string
```
Error implements error and prints out the message plus the metadata about which
context was cancelled.

#### func (*MessageError) Unwrap

```go
func (c *MessageError) Unwrap() error
```
Unwrap supports errors.Is/errors.As
context.Context helpers and implementations to make your life slightly easier.
43 changes: 43 additions & 0 deletions reason.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package morecontext

import (
"context"
"fmt"
)

// A CancelReasonContext is a cancellable context whose cancellation requires a
// reason, and whose Err() message will include the reason, if this context was
// the one cancelled.
type CancelReasonContext struct {
context.Context
Reason error
}

var _ context.Context = &CancelReasonContext{}

// If this context was the one cancelled, we return the reason it was
// cancelled. If the underlying context was cancelled, then we return its
// message.
func (crc CancelReasonContext) Err() error {
if crc.Reason != nil {
return fmt.Errorf("context cancelled: %w", crc.Reason)
}
return crc.Context.Err()
}

// WithCancelReason returns a context implementation that must be cancelled,
// and whose cancellation must include a Reason error that will be returned by
// any calls to `Err`.
func WithCancelReason(ctx context.Context) (*CancelReasonContext, func(error)) {
ctx, cancel := context.WithCancel(ctx)
crc := CancelReasonContext{
Context: ctx,
}

c := func(err error) {
crc.Reason = err
cancel()
}

return &crc, c
}
33 changes: 33 additions & 0 deletions reason_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package morecontext

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestReasonContext(t *testing.T) {
asrt := assert.New(t)
ctx, cancel := WithCancelReason(context.Background())

cancel(fmt.Errorf("foo bar baz"))

err := ctx.Err()
asrt.Error(err)
asrt.Contains(err.Error(), "foo bar baz")
}

func TestReasonParentCancel(t *testing.T) {
asrt := assert.New(t)
ctx, c1 := context.WithCancel(context.Background())
ctx, cancel := WithCancelReason(ctx)
defer cancel(nil)

c1()

err := ctx.Err()
asrt.Error(err)
asrt.NotContains(err.Error(), "foo bar baz")
}
38 changes: 4 additions & 34 deletions signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,63 +5,33 @@ import (
"fmt"
"os"
"os/signal"
"sync"
"syscall"
)

// sigCtx is a context that will be cancelled if certain signals are received.
// Its Err will include details if this is the reason it was cancelled.
type sigCtx struct {
context.Context

exitSignal os.Signal
m sync.Mutex
}

// Err implements context.Context.Err but includes the os.Signal that caused
// the context cancellation.
func (sc *sigCtx) Err() error {
sc.m.Lock()
defer sc.m.Unlock()
err := sc.Context.Err()

return &MessageError{
Message: fmt.Sprintf("context cancelled: got signal %s", sc.exitSignal.String()),
Original: err,
}
}

// ForSignals returns a context.Context that will be cancelled if the given
// signals (or SIGTERM and SIGINT by default, if none are passed) are received
// by the process.
func ForSignals(sigs ...os.Signal) context.Context {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := WithCancelReason(context.Background())

// If no signals are returnd we will use a sensible default set.
// If no signals are included we will use a sensible default set.
if len(sigs) == 0 {
sigs = []os.Signal{syscall.SIGTERM, syscall.SIGINT}
}

sc := &sigCtx{Context: ctx}

ch := make(chan os.Signal, 2)
signal.Notify(ch, sigs...)

go func() {

i := 0
for sig := range ch {
i++
if i > 1 {
os.Exit(1)
}
sc.m.Lock()
sc.exitSignal = sig
sc.m.Unlock()

cancel()
cancel(fmt.Errorf("got signal %s", sig))
}
}()

return sc
return ctx
}