Skip to content

Commit

Permalink
feat: Support singleton providers (#501)
Browse files Browse the repository at this point in the history
* feat: Support singleton providers

This change adds support for provider functions that are
not reinvoked even if requested by multiple other providers.
Instead, their value is cached and reused between invocations.

To make this possible, we change how bindings are stored:
instead of just a function reference, we now store a binding object
which records whether the binding is a singleton,
and records the resolved singleton value (if any).

Resolves #500

* refac(bindings): hide singleton status

Don't require callAnyFunction to be aware of
whether a binding is a singleton or not.
  • Loading branch information
abhinav authored Feb 17, 2025
1 parent 7f94c90 commit 3b9af5b
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 16 deletions.
78 changes: 69 additions & 9 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,59 @@ import (
"strings"
)

// binding is a single binding registered with Kong.
type binding struct {
// fn is a function that returns a value of the target type.
fn reflect.Value

// val is a value of the target type.
// Must be set if done and singleton are true.
val reflect.Value

// singleton indicates whether the binding is a singleton.
// If true, the binding will be resolved once and cached.
singleton bool

// done indicates whether a singleton binding has been resolved.
// If singleton is false, this field is ignored.
done bool
}

// newValueBinding builds a binding with an already resolved value.
func newValueBinding(v reflect.Value) *binding {
return &binding{val: v, done: true, singleton: true}
}

// newFunctionBinding builds a binding with a function
// that will return a value of the target type.
//
// The function signature must be func(...) (T, error) or func(...) T
// where parameters are recursively resolved.
func newFunctionBinding(f reflect.Value, singleton bool) *binding {
return &binding{fn: f, singleton: singleton}
}

// Get returns the pre-resolved value for the binding,
// or false if the binding is not resolved.
func (b *binding) Get() (v reflect.Value, ok bool) {
return b.val, b.done
}

// Set sets the value of the binding to the given value,
// marking it as resolved.
//
// If the binding is not a singleton, this method does nothing.
func (b *binding) Set(v reflect.Value) {
if b.singleton {
b.val = v
b.done = true
}
}

// A map of type to function that returns a value of that type.
//
// The function should have the signature func(...) (T, error). Arguments are recursively resolved.
type bindings map[reflect.Type]any
type bindings map[reflect.Type]*binding

func (b bindings) String() string {
out := []string{}
Expand All @@ -21,17 +70,18 @@ func (b bindings) String() string {

func (b bindings) add(values ...any) bindings {
for _, v := range values {
v := v
b[reflect.TypeOf(v)] = func() (any, error) { return v, nil }
val := reflect.ValueOf(v)
b[val.Type()] = newValueBinding(val)
}
return b
}

func (b bindings) addTo(impl, iface any) {
b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil }
val := reflect.ValueOf(impl)
b[reflect.TypeOf(iface).Elem()] = newValueBinding(val)
}

func (b bindings) addProvider(provider any) error {
func (b bindings) addProvider(provider any, singleton bool) error {
pv := reflect.ValueOf(provider)
t := pv.Type()
if t.Kind() != reflect.Func {
Expand All @@ -47,7 +97,7 @@ func (b bindings) addProvider(provider any) error {
}
}
rt := pv.Type().Out(0)
b[rt] = provider
b[rt] = newFunctionBinding(pv, singleton)
return nil
}

Expand Down Expand Up @@ -148,19 +198,29 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
t := f.Type()
for i := 0; i < t.NumIn(); i++ {
pt := t.In(i)
argf, ok := bindings[pt]
binding, ok := bindings[pt]
if !ok {
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
}

// Don't need to call the function if the value is already resolved.
if val, ok := binding.Get(); ok {
in = append(in, val)
continue
}

// Recursively resolve binding functions.
argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
argv, err := callAnyFunction(binding.fn, bindings)
if err != nil {
return nil, fmt.Errorf("%s: %w", pt, err)
}
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
}
in = append(in, reflect.ValueOf(argv[0]))

val := reflect.ValueOf(argv[0])
binding.Set(val)
in = append(in, val)
}
outv := f.Call(in)
out = make([]any, len(outv))
Expand Down
17 changes: 13 additions & 4 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,19 @@ func (c *Context) BindTo(impl, iface any) {
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where
// ... will be recursively injected with bound values.
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindToProvider(provider any) error {
return c.bindings.addProvider(provider)
return c.bindings.addProvider(provider, false /* singleton */)
}

// BindSingletonProvider allows binding of provider functions.
// The provider will be called once and the result cached.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindSingletonProvider(provider any) error {
return c.bindings.addProvider(provider, true /* singleton */)
}

// Value returns the value for a particular path element.
Expand Down Expand Up @@ -792,7 +801,7 @@ func (c *Context) RunNode(node *Node, binds ...any) (err error) {
methodt := t.Method(i)
if strings.HasPrefix(methodt.Name, "Provide") {
method := p.Method(i)
if err := methodBinds.addProvider(method.Interface()); err != nil {
if err := methodBinds.addProvider(method.Interface(), false /* singleton */); err != nil {
return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, err)
}
}
Expand Down
24 changes: 21 additions & 3 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,33 @@ func BindTo(impl, iface any) Option {

// BindToProvider binds an injected value to a provider function.
//
// The provider function must have the signature:
// The provider function must have one of the following signatures:
//
// func(...) (T, error)
// func(...) T
//
// func() (any, error)
// Where arguments to the function are injected by Kong.
//
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
func BindToProvider(provider any) Option {
return OptionFunc(func(k *Kong) error {
return k.bindings.addProvider(provider)
return k.bindings.addProvider(provider, false /* singleton */)
})
}

// BindSingletonProvider binds an injected value to a provider function.
// The provider function must have the signature:
//
// func(...) (T, error)
// func(...) T
//
// Unlike [BindToProvider], the provider function will only be called
// at most once, and the result will be cached and reused
// across multiple recipients of the injected value.
func BindSingletonProvider(provider any) Option {
return OptionFunc(func(k *Kong) error {
return k.bindings.addProvider(provider, true /* singleton */)
})
}

Expand Down
37 changes: 37 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,43 @@ func TestBindToProvider(t *testing.T) {
assert.True(t, cli.Called)
}

func TestBindSingletonProvider(t *testing.T) {
type (
Connection struct{}
ClientA struct{ conn *Connection }
ClientB struct{ conn *Connection }
)

var numConnections int
newConnection := func() *Connection {
numConnections++
return &Connection{}
}

var cli struct{}
app, err := New(&cli,
BindSingletonProvider(newConnection),
BindToProvider(func(conn *Connection) *ClientA {
return &ClientA{conn: conn}
}),
BindToProvider(func(conn *Connection) *ClientB {
return &ClientB{conn: conn}
}),
)
assert.NoError(t, err)

ctx, err := app.Parse([]string{})
assert.NoError(t, err)

_, err = ctx.Call(func(a *ClientA, b *ClientB) {
assert.NotZero(t, a.conn)
assert.NotZero(t, b.conn)

assert.Equal(t, 1, numConnections, "expected newConnection to be called only once")
})
assert.NoError(t, err)
}

func TestFlagNamer(t *testing.T) {
var cli struct {
SomeFlag string
Expand Down

0 comments on commit 3b9af5b

Please sign in to comment.