1
0
mirror of https://github.com/duke-git/lancet.git synced 2026-02-04 12:52:28 +08:00

function: catch earlier programming error (#8)

Place at first line of the function body a function type safe guard
check.
This commit is contained in:
donutloop
2021-12-30 13:22:00 +01:00
committed by GitHub
parent 0b0eb695e8
commit 613785b07c
2 changed files with 25 additions and 5 deletions

View File

@@ -11,10 +11,12 @@ import (
// After creates a function that invokes func once it's called n or more times
func After(n int, fn interface{}) func(args ...interface{}) []reflect.Value {
// Catch programming error while constructing the closure
MustBeFunction(fn)
return func(args ...interface{}) []reflect.Value {
n--
if n < 1 {
return invokeFunc(fn, args...)
return unsafeInvokeFunc(fn, args...)
}
return nil
}
@@ -22,11 +24,12 @@ func After(n int, fn interface{}) func(args ...interface{}) []reflect.Value {
// Before creates a function that invokes func once it's called less than n times
func Before(n int, fn interface{}) func(args ...interface{}) []reflect.Value {
// Catch programming error while constructing the closure
MustBeFunction(fn)
var res []reflect.Value
return func(args ...interface{}) []reflect.Value {
if n > 0 {
res = invokeFunc(fn, args...)
res = unsafeInvokeFunc(fn, args...)
}
if n <= 0 {
fn = nil
@@ -69,11 +72,12 @@ func Delay(delay time.Duration, fn interface{}, args ...interface{}) {
// Schedule invoke function every duration time, util close the returned bool chan
func Schedule(d time.Duration, fn interface{}, args ...interface{}) chan bool {
// Catch programming error while constructing the closure
MustBeFunction(fn)
quit := make(chan bool)
go func() {
for {
invokeFunc(fn, args...)
unsafeInvokeFunc(fn, args...)
select {
case <-time.After(d):
case <-quit:

View File

@@ -14,6 +14,15 @@ func invokeFunc(fn interface{}, args ...interface{}) []reflect.Value {
return fv.Call(params)
}
func unsafeInvokeFunc(fn interface{}, args ...interface{}) []reflect.Value {
fv := reflect.ValueOf(fn)
params := make([]reflect.Value, len(args))
for i, item := range args {
params[i] = reflect.ValueOf(item)
}
return fv.Call(params)
}
func functionValue(function interface{}) reflect.Value {
v := reflect.ValueOf(function)
if v.Kind() != reflect.Func {
@@ -21,3 +30,10 @@ func functionValue(function interface{}) reflect.Value {
}
return v
}
func MustBeFunction(function interface{}) {
v := reflect.ValueOf(function)
if v.Kind() != reflect.Func {
panic(fmt.Sprintf("Invalid function type, value of type %T", function))
}
}