Skip to content

Commit

Permalink
Ccmtaylor compose (#29)
Browse files Browse the repository at this point in the history
Ccmtaylor compose
  • Loading branch information
Gustavo Chaín authored Nov 25, 2019
2 parents e598f34 + 50840d6 commit e01e589
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 1 deletion.
76 changes: 76 additions & 0 deletions compose.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sqlhooks

import (
"context"
"fmt"
)

// Compose allows for composing multiple Hooks into one.
// It runs every callback on every hook in argument order,
// even if previous hooks return an error.
// If multiple hooks return errors, the error return value will be
// MultipleErrors, which allows for introspecting the errors if necessary.
func Compose(hooks ...Hooks) Hooks {
return composed(hooks)
}

type composed []Hooks

func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
var errors []error
for _, hook := range c {
c, err := hook.Before(ctx, query, args...)
if err != nil {
errors = append(errors, err)
}
if c != nil {
ctx = c
}
}
return ctx, wrapErrors(nil, errors)
}

func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
var errors []error
for _, hook := range c {
var err error
c, err := hook.After(ctx, query, args...)
if err != nil {
errors = append(errors, err)
}
if c != nil {
ctx = c
}
}
return ctx, wrapErrors(nil, errors)
}

func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error {
var errors []error
for _, hook := range c {
if onErrorer, ok := hook.(OnErrorer); ok {
if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause {
errors = append(errors, err)
}
}
}
return wrapErrors(cause, errors)
}

func wrapErrors(def error, errors []error) error {
switch len(errors) {
case 0:
return def
case 1:
return errors[0]
default:
return MultipleErrors(errors)
}
}

// MultipleErrors is an error that contains multiple errors.
type MultipleErrors []error

func (m MultipleErrors) Error() string {
return fmt.Sprint("multiple errors:", []error(m))
}
26 changes: 26 additions & 0 deletions compose_1_13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// +build go1.13

package sqlhooks

import "errors"

// Is returns true if any of the wrapped errors is target according to errors.Is()
func (m MultipleErrors) Is(target error) bool {
for _, err := range m {
if errors.Is(err, target) {
return true
}
}
return false
}

// Is tries to convert each wrapped error to target with errors.As() and returns true that succeeds.
// If none of the errors are convertible, returns false.
func (m MultipleErrors) As(target interface{}) bool {
for _, err := range m {
if errors.As(err, &target) {
return true
}
}
return false
}
96 changes: 96 additions & 0 deletions compose_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sqlhooks

import (
"context"
"errors"
"reflect"
"testing"
)

var (
oops = errors.New("oops")
oopsHook = &testHooks{
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, oops
},
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, oops
},
onError: func(ctx context.Context, err error, query string, args ...interface{}) error {
return oops
},
}
okHook = &testHooks{
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
},
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
},
onError: func(ctx context.Context, err error, query string, args ...interface{}) error {
return nil
},
}
)

func TestCompose(t *testing.T) {
for _, it := range []struct {
name string
hooks Hooks
want error
}{
{"happy case", Compose(okHook, okHook), nil},
{"no hooks", Compose(), nil},
{"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})},
{"single error", Compose(okHook, oopsHook, okHook), oops},
} {
t.Run(it.name, func(t *testing.T) {
t.Run("Before", func(t *testing.T) {
_, got := it.hooks.Before(context.Background(), "query")
if !reflect.DeepEqual(it.want, got) {
t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
}
})
t.Run("After", func(t *testing.T) {
_, got := it.hooks.After(context.Background(), "query")
if !reflect.DeepEqual(it.want, got) {
t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
}
})
t.Run("OnError", func(t *testing.T) {
cause := errors.New("crikey")
want := it.want
if want == nil {
want = cause
}
got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query")
if !reflect.DeepEqual(want, got) {
t.Errorf("unexpected error. want: %q, got: %q", want, got)
}
})
})
}
}

func TestWrapErrors(t *testing.T) {
var (
err1 = errors.New("oops")
err2 = errors.New("oops2")
)
for _, it := range []struct {
name string
def error
errors []error
want error
}{
{"no errors", err1, nil, err1},
{"single error", nil, []error{err1}, err1},
{"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})},
} {
t.Run(it.name, func(t *testing.T) {
if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) {
t.Errorf("unexpected wrapping. want: %q, got %q", want, got)
}
})
}
}
2 changes: 1 addition & 1 deletion sqlhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (h *testHooks) After(ctx context.Context, query string, args ...interface{}
return h.after(ctx, query, args...)
}

func (h *testHooks) ErrHook(ctx context.Context, err error, query string, args ...interface{}) error {
func (h *testHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
return h.onError(ctx, err, query, args...)
}

Expand Down

0 comments on commit e01e589

Please sign in to comment.