-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ccmtaylor compose
- Loading branch information
Showing
4 changed files
with
199 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package sqlhooks | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
) | ||
|
||
// Compose allows for composing multiple Hooks into one. | ||
// It runs every callback on every hook in argument order, | ||
// even if previous hooks return an error. | ||
// If multiple hooks return errors, the error return value will be | ||
// MultipleErrors, which allows for introspecting the errors if necessary. | ||
func Compose(hooks ...Hooks) Hooks { | ||
return composed(hooks) | ||
} | ||
|
||
type composed []Hooks | ||
|
||
func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
var errors []error | ||
for _, hook := range c { | ||
c, err := hook.Before(ctx, query, args...) | ||
if err != nil { | ||
errors = append(errors, err) | ||
} | ||
if c != nil { | ||
ctx = c | ||
} | ||
} | ||
return ctx, wrapErrors(nil, errors) | ||
} | ||
|
||
func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
var errors []error | ||
for _, hook := range c { | ||
var err error | ||
c, err := hook.After(ctx, query, args...) | ||
if err != nil { | ||
errors = append(errors, err) | ||
} | ||
if c != nil { | ||
ctx = c | ||
} | ||
} | ||
return ctx, wrapErrors(nil, errors) | ||
} | ||
|
||
func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error { | ||
var errors []error | ||
for _, hook := range c { | ||
if onErrorer, ok := hook.(OnErrorer); ok { | ||
if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause { | ||
errors = append(errors, err) | ||
} | ||
} | ||
} | ||
return wrapErrors(cause, errors) | ||
} | ||
|
||
func wrapErrors(def error, errors []error) error { | ||
switch len(errors) { | ||
case 0: | ||
return def | ||
case 1: | ||
return errors[0] | ||
default: | ||
return MultipleErrors(errors) | ||
} | ||
} | ||
|
||
// MultipleErrors is an error that contains multiple errors. | ||
type MultipleErrors []error | ||
|
||
func (m MultipleErrors) Error() string { | ||
return fmt.Sprint("multiple errors:", []error(m)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// +build go1.13 | ||
|
||
package sqlhooks | ||
|
||
import "errors" | ||
|
||
// Is returns true if any of the wrapped errors is target according to errors.Is() | ||
func (m MultipleErrors) Is(target error) bool { | ||
for _, err := range m { | ||
if errors.Is(err, target) { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
// Is tries to convert each wrapped error to target with errors.As() and returns true that succeeds. | ||
// If none of the errors are convertible, returns false. | ||
func (m MultipleErrors) As(target interface{}) bool { | ||
for _, err := range m { | ||
if errors.As(err, &target) { | ||
return true | ||
} | ||
} | ||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package sqlhooks | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
var ( | ||
oops = errors.New("oops") | ||
oopsHook = &testHooks{ | ||
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
return ctx, oops | ||
}, | ||
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
return ctx, oops | ||
}, | ||
onError: func(ctx context.Context, err error, query string, args ...interface{}) error { | ||
return oops | ||
}, | ||
} | ||
okHook = &testHooks{ | ||
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
return ctx, nil | ||
}, | ||
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { | ||
return ctx, nil | ||
}, | ||
onError: func(ctx context.Context, err error, query string, args ...interface{}) error { | ||
return nil | ||
}, | ||
} | ||
) | ||
|
||
func TestCompose(t *testing.T) { | ||
for _, it := range []struct { | ||
name string | ||
hooks Hooks | ||
want error | ||
}{ | ||
{"happy case", Compose(okHook, okHook), nil}, | ||
{"no hooks", Compose(), nil}, | ||
{"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})}, | ||
{"single error", Compose(okHook, oopsHook, okHook), oops}, | ||
} { | ||
t.Run(it.name, func(t *testing.T) { | ||
t.Run("Before", func(t *testing.T) { | ||
_, got := it.hooks.Before(context.Background(), "query") | ||
if !reflect.DeepEqual(it.want, got) { | ||
t.Errorf("unexpected error. want: %q, got: %q", it.want, got) | ||
} | ||
}) | ||
t.Run("After", func(t *testing.T) { | ||
_, got := it.hooks.After(context.Background(), "query") | ||
if !reflect.DeepEqual(it.want, got) { | ||
t.Errorf("unexpected error. want: %q, got: %q", it.want, got) | ||
} | ||
}) | ||
t.Run("OnError", func(t *testing.T) { | ||
cause := errors.New("crikey") | ||
want := it.want | ||
if want == nil { | ||
want = cause | ||
} | ||
got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query") | ||
if !reflect.DeepEqual(want, got) { | ||
t.Errorf("unexpected error. want: %q, got: %q", want, got) | ||
} | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
func TestWrapErrors(t *testing.T) { | ||
var ( | ||
err1 = errors.New("oops") | ||
err2 = errors.New("oops2") | ||
) | ||
for _, it := range []struct { | ||
name string | ||
def error | ||
errors []error | ||
want error | ||
}{ | ||
{"no errors", err1, nil, err1}, | ||
{"single error", nil, []error{err1}, err1}, | ||
{"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})}, | ||
} { | ||
t.Run(it.name, func(t *testing.T) { | ||
if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) { | ||
t.Errorf("unexpected wrapping. want: %q, got %q", want, got) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters