diff --git a/function/function.go b/function/function.go index 609792f..516b09b 100644 --- a/function/function.go +++ b/function/function.go @@ -11,10 +11,12 @@ import ( // After creates a function that invokes func once it's called n or more times func After(n int, fn interface{}) func(args ...interface{}) []reflect.Value { + // Catch programming error while constructing the closure + MustBeFunction(fn) return func(args ...interface{}) []reflect.Value { n-- if n < 1 { - return invokeFunc(fn, args...) + return unsafeInvokeFunc(fn, args...) } return nil } @@ -22,11 +24,12 @@ func After(n int, fn interface{}) func(args ...interface{}) []reflect.Value { // Before creates a function that invokes func once it's called less than n times func Before(n int, fn interface{}) func(args ...interface{}) []reflect.Value { + // Catch programming error while constructing the closure + MustBeFunction(fn) var res []reflect.Value - return func(args ...interface{}) []reflect.Value { if n > 0 { - res = invokeFunc(fn, args...) + res = unsafeInvokeFunc(fn, args...) } if n <= 0 { fn = nil @@ -69,11 +72,12 @@ func Delay(delay time.Duration, fn interface{}, args ...interface{}) { // Schedule invoke function every duration time, util close the returned bool chan func Schedule(d time.Duration, fn interface{}, args ...interface{}) chan bool { + // Catch programming error while constructing the closure + MustBeFunction(fn) quit := make(chan bool) - go func() { for { - invokeFunc(fn, args...) + unsafeInvokeFunc(fn, args...) select { case <-time.After(d): case <-quit: diff --git a/function/function_util.go b/function/function_util.go index b809c56..dcc08a4 100644 --- a/function/function_util.go +++ b/function/function_util.go @@ -14,6 +14,15 @@ func invokeFunc(fn interface{}, args ...interface{}) []reflect.Value { return fv.Call(params) } +func unsafeInvokeFunc(fn interface{}, args ...interface{}) []reflect.Value { + fv := reflect.ValueOf(fn) + params := make([]reflect.Value, len(args)) + for i, item := range args { + params[i] = reflect.ValueOf(item) + } + return fv.Call(params) +} + func functionValue(function interface{}) reflect.Value { v := reflect.ValueOf(function) if v.Kind() != reflect.Func { @@ -21,3 +30,10 @@ func functionValue(function interface{}) reflect.Value { } return v } + +func MustBeFunction(function interface{}) { + v := reflect.ValueOf(function) + if v.Kind() != reflect.Func { + panic(fmt.Sprintf("Invalid function type, value of type %T", function)) + } +}